#!/usr/bin/env python3
"""
BEVFusion visualization script.

Loads a pickled list of per-sample inference results and renders the
predicted BEV map segmentation (key ``masks_bev``) next to the ground truth
(key ``gt_masks_bev``). Output is either one PNG per sample ("quick" mode)
or an MP4 video stitched from a down-sampled sequence of frames ("video"
mode). 3D boxes (``boxes_3d``) are not drawn by this script.
"""
import argparse
import os
import pickle
import tempfile
from pathlib import Path

import cv2
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

# Segmentation class names (nuScenes map layers); list index == mask channel.
SEGMENTATION_CLASSES = [
    'drivable_area',  # 0
    'ped_crossing',   # 1
    'walkway',        # 2
    'stop_line',      # 3
    'carpark_area',   # 4
    'divider',        # 5
]

# RGB display color per segmentation class (same order as SEGMENTATION_CLASSES).
SEGMENTATION_COLORS = [
    [255, 195, 128],  # drivable_area - orange
    [255, 99, 71],    # ped_crossing - tomato red
    [255, 140, 0],    # walkway - dark orange
    [180, 62, 39],    # stop_line - brown
    [233, 150, 70],   # carpark_area - brownish orange
    [160, 82, 45],    # divider - sienna
]

# Label used for pixels where no class is confidently present; drawn black.
BACKGROUND_LABEL = -1


def _to_numpy(x):
    """Return *x* as a NumPy array, detaching/moving torch tensors to the CPU."""
    if isinstance(x, torch.Tensor):
        # .detach() avoids "can't call numpy() on Tensor that requires grad".
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _to_label_map(masks, threshold):
    """Collapse segmentation masks into an (H, W) per-pixel class-label map.

    Args:
        masks: Either a (C, H, W) array/tensor of per-class scores, or an
            (H, W) array that is treated as an already-computed label map.
        threshold: For (C, H, W) input, pixels whose highest class score is
            below this value get BACKGROUND_LABEL instead of the argmax class.
            Without this, ``argmax`` assigns class 0 (drivable_area) to every
            pixel where all channels are zero. ``None`` disables the masking.

    Returns:
        An (H, W) array of class indices (BACKGROUND_LABEL for background).

    Raises:
        ValueError: if *masks* is neither 2-D nor 3-D.
    """
    masks = _to_numpy(masks)
    if masks.ndim == 2:
        return masks
    if masks.ndim != 3:
        raise ValueError(
            f'expected (C, H, W) or (H, W) masks, got shape {masks.shape}')
    label_map = np.argmax(masks, axis=0)
    if threshold is not None:
        # NOTE(review): assumes scores are per-class probabilities in [0, 1]
        # (e.g. sigmoid outputs / binary GT) - confirm for your results file.
        label_map[masks.max(axis=0) < threshold] = BACKGROUND_LABEL
    return label_map


def _colorize(label_map):
    """Turn an (H, W) label map into an (H, W, 3) uint8 RGB image.

    Labels outside the known classes (including BACKGROUND_LABEL) stay black.
    """
    height, width = label_map.shape
    image = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id, color in enumerate(SEGMENTATION_COLORS):
        image[label_map == class_id] = color
    return image


def _legend_handles(label_map):
    """Build legend patches ("<class>: <pixel count> px") for classes present."""
    handles = []
    for class_id, (name, color) in enumerate(
            zip(SEGMENTATION_CLASSES, SEGMENTATION_COLORS)):
        count = int(np.count_nonzero(label_map == class_id))
        if count > 0:
            handles.append(mpatches.Patch(color=np.array(color) / 255.0,
                                          label=f'{name}: {count} px'))
    return handles


def _draw_segmentation(ax, masks, title, threshold):
    """Draw one colorized segmentation panel, with legend, onto *ax*."""
    label_map = _to_label_map(masks, threshold)
    ax.imshow(_colorize(label_map))
    ax.set_title(title, fontsize=14)
    ax.axis('off')
    handles = _legend_handles(label_map)
    if handles:
        ax.legend(handles=handles, loc='upper right', fontsize=10,
                  framealpha=0.8)


def _draw_placeholder(ax, message):
    """Fill *ax* with a centered message when a panel has no data."""
    ax.text(0.5, 0.5, message, ha='center', va='center', fontsize=20)
    ax.axis('off')


class BEVVisualizer:
    """Render BEV segmentation results loaded from a pickled results list."""

    def __init__(self, result_file, frame_interval=1, seg_threshold=0.5):
        """Load inference results.

        Args:
            result_file: Path to a pickle holding a sequence of per-sample
                dicts (keys read: ``masks_bev``, ``gt_masks_bev``).
            frame_interval: Take every N-th sample when building a video;
                must be >= 1.
            seg_threshold: Minimum class score for a pixel to be colored;
                see ``_to_label_map``. ``None`` restores plain argmax.

        Raises:
            ValueError: if *frame_interval* is smaller than 1.
        """
        if frame_interval < 1:
            raise ValueError(f'frame_interval must be >= 1, got {frame_interval}')
        print(f"加载推理结果: {result_file}")
        # SECURITY: pickle.load can execute arbitrary code; only load result
        # files you produced yourself or otherwise trust.
        with open(result_file, 'rb') as f:
            self.results = pickle.load(f)
        print(f"成功加载 {len(self.results)} 个样本")
        self.frame_interval = frame_interval
        self.seg_threshold = seg_threshold

    def visualize_single_bev(self, sample_idx, save_path=None, show_info=True):
        """Plot prediction (left) and ground truth (right) for one sample.

        Args:
            sample_idx: Index into the loaded results.
            save_path: If given, save the figure there and close it;
                otherwise show it interactively.
            show_info: Print the saved path when saving.
        """
        result = self.results[sample_idx]

        fig, axes = plt.subplots(1, 2, figsize=(20, 10))
        fig.suptitle(f'Sample {sample_idx}', fontsize=16, fontweight='bold')

        # 1. Predicted BEV segmentation.
        if 'masks_bev' in result:
            _draw_segmentation(axes[0], result['masks_bev'],
                               'BEV Segmentation Map', self.seg_threshold)
        else:
            _draw_placeholder(axes[0], 'No Segmentation')

        # 2. Ground-truth BEV segmentation: (C, H, W) masks or (H, W) labels.
        if 'gt_masks_bev' in result:
            _draw_segmentation(axes[1], result['gt_masks_bev'],
                               'Ground Truth Segmentation', self.seg_threshold)
        else:
            _draw_placeholder(axes[1], 'No GT Data')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=100, bbox_inches='tight')
            plt.close(fig)
            if show_info:
                print(f"保存: {save_path}")
        else:
            plt.show()

    def generate_video(self, output_path, fps=10):
        """Render every ``frame_interval``-th sample and encode them as MP4.

        Frames are rendered to a temporary directory next to *output_path*,
        which is removed afterwards even if rendering or encoding fails.

        Raises:
            ValueError: if there are no samples to render.
            RuntimeError: if a frame cannot be read back or the video writer
                cannot be opened.
        """
        print(f"开始生成视频: {output_path}")

        # Down-sample the sequence by frame_interval.
        sample_indices = list(range(0, len(self.results), self.frame_interval))

        print(f"总样本数: {len(self.results)}")
        print(f"帧间隔: {self.frame_interval}")
        print(f"将生成: {len(sample_indices)} 帧")
        print()

        if not sample_indices:
            raise ValueError('no samples to render: the results file is empty')

        with tempfile.TemporaryDirectory(
                prefix='temp_frames_', dir=Path(output_path).parent) as tmp:
            temp_dir = Path(tmp)

            frames = []
            for idx in tqdm(sample_indices, desc="生成帧"):
                frame_path = temp_dir / f'frame_{idx:05d}.png'
                self.visualize_single_bev(idx, save_path=frame_path,
                                          show_info=False)
                frames.append(str(frame_path))

            print()
            print(f"生成了 {len(frames)} 帧,开始合成视频...")

            # The first frame fixes the video resolution.
            first_frame = cv2.imread(frames[0])
            if first_frame is None:
                raise RuntimeError(f'failed to read frame {frames[0]}')
            h, w = first_frame.shape[:2]

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
            if not video_writer.isOpened():
                raise RuntimeError(f'cannot open video writer for {output_path}')

            try:
                for frame_path in tqdm(frames, desc="写入视频"):
                    frame = cv2.imread(frame_path)
                    if frame is None:
                        raise RuntimeError(f'failed to read frame {frame_path}')
                    if frame.shape[:2] != (h, w):
                        # bbox_inches='tight' can change the PNG size between
                        # samples; VideoWriter expects one fixed frame size and
                        # may silently drop frames that differ.
                        frame = cv2.resize(frame, (w, h))
                    video_writer.write(frame)
            finally:
                video_writer.release()

            print(f"✅ 视频已保存: {output_path}")
            print("清理临时文件...")
        # Leaving the `with` block deletes the temporary frames.
        print("✅ 完成!")


def main():
    """Command-line entry point: parse arguments and run the chosen mode."""
    parser = argparse.ArgumentParser(description='BEVFusion结果可视化(修复版)')
    parser.add_argument('--result-file', type=str, default='results_epoch19.pkl',
                        help='推理结果文件')
    parser.add_argument('--output-dir', type=str, default='visualizations',
                        help='输出目录')
    parser.add_argument('--mode', type=str, choices=['quick', 'video'],
                        default='quick', help='可视化模式')
    parser.add_argument('--num-samples', type=int, default=10,
                        help='快速可视化的样本数')
    parser.add_argument('--video-fps', type=int, default=10,
                        help='视频帧率')
    parser.add_argument('--frame-interval', type=int, default=6,
                        help='帧间隔(下采样)')
    parser.add_argument('--seg-threshold', type=float, default=0.5,
                        help='min class score for a pixel to be colored '
                             '(0 = plain argmax for non-negative scores)')

    args = parser.parse_args()

    visualizer = BEVVisualizer(
        result_file=args.result_file,
        frame_interval=args.frame_interval,
        seg_threshold=args.seg_threshold,
    )

    os.makedirs(args.output_dir, exist_ok=True)

    if args.mode == 'quick':
        # Never index past the end of the results list.
        num_samples = min(args.num_samples, len(visualizer.results))
        print(f"快速可视化前 {num_samples} 个样本...")
        for idx in tqdm(range(num_samples)):
            save_path = os.path.join(args.output_dir, f'sample_{idx:04d}_bev.png')
            visualizer.visualize_single_bev(idx, save_path=save_path)
        print(f"✅ 完成!图像保存在: {args.output_dir}/")

    elif args.mode == 'video':
        video_path = os.path.join(args.output_dir, 'bevfusion_results.mp4')
        visualizer.generate_video(
            output_path=video_path,
            fps=args.video_fps,
        )


if __name__ == '__main__':
    main()