bev-project/visualize_bev_correct.py

241 lines
8.1 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
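"""
BEVFusion 正确的可视化脚本
使用正确的数据键名: masks_bev, gt_masks_bev (boxes_3d 暂未可视化)

Visualize BEVFusion BEV map-segmentation results: the predicted ``masks_bev``
and the ground-truth ``gt_masks_bev`` are drawn side by side, either as
individual PNG files (quick mode) or as an MP4 video (video mode).
``boxes_3d`` is not drawn by this script.
"""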
import argparse
import os
import pickle
import tempfile
from pathlib import Path

import cv2
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
# Segmentation classes (nuScenes map classes) paired with the RGB colour used
# to draw each one. Index i of SEGMENTATION_CLASSES and SEGMENTATION_COLORS
# always refers to the same class, because both lists are derived from this
# single table.
_SEGMENTATION_PALETTE = (
    ('drivable_area', [255, 195, 128]),  # light orange
    ('ped_crossing', [255, 99, 71]),     # tomato red
    ('walkway', [255, 140, 0]),          # dark orange
    ('stop_line', [180, 62, 39]),        # brown
    ('carpark_area', [233, 150, 70]),    # brownish orange
    ('divider', [160, 82, 45]),          # sienna
)

# Class name for channel / label id i.
SEGMENTATION_CLASSES = [class_name for class_name, _ in _SEGMENTATION_PALETTE]
# RGB colour (0-255) for channel / label id i.
SEGMENTATION_COLORS = [rgb for _, rgb in _SEGMENTATION_PALETTE]
class BEVVisualizer:
    """Draw BEVFusion BEV map-segmentation results (prediction vs. ground truth).

    ``self.results`` holds whatever the pickle file contains. It is indexed by
    sample number, and each entry is queried with ``'masks_bev'`` /
    ``'gt_masks_bev'`` membership tests, so a sequence of dicts is assumed.
    """

    # Class-level default so instances created without __init__ still work.
    score_thresh = 0.5

    def __init__(self, result_file, frame_interval=1, score_thresh=0.5):
        """Load the pickled inference results.

        Args:
            result_file: Path of the pickle file produced by the inference run.
            frame_interval: Sample stride used by ``generate_video``.
            score_thresh: Threshold applied to every channel of a ``(C, H, W)``
                mask; a pixel belongs to class ``k`` when
                ``mask[k] >= score_thresh``.
        """
        print(f"加载推理结果: {result_file}")
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file -- only open result files you generated yourself or trust.
        with open(result_file, 'rb') as f:
            self.results = pickle.load(f)
        print(f"成功加载 {len(self.results)} 个样本")
        self.frame_interval = frame_interval
        self.score_thresh = score_thresh

    @staticmethod
    def _to_numpy(x):
        """Return *x* as a NumPy array, detaching/moving torch tensors to the CPU."""
        if hasattr(x, 'cpu'):
            if hasattr(x, 'detach'):
                x = x.detach()
            x = x.cpu().numpy()
        return np.asarray(x)

    def _render_masks(self, masks):
        """Colourize a BEV mask.

        Args:
            masks: Either a ``(C, H, W)`` array of per-class scores / binary
                masks, or an ``(H, W)`` integer label map. Torch tensors are
                accepted.

        Returns:
            ``(image, counts)``. ``image`` is an ``(H, W, 3)`` uint8 RGB array in
            which pixels not assigned to any class stay black. ``counts`` maps
            each class name that occurs to its pixel count.

        Raises:
            ValueError: If ``masks`` is neither 2-D nor 3-D.
        """
        masks = self._to_numpy(masks)
        if masks.ndim == 3:
            # Channels may overlap (multi-label), so threshold each channel on
            # its own. An argmax would label empty pixels as class 0 and hide
            # every class that overlaps a lower-indexed one.
            num_classes = min(len(SEGMENTATION_CLASSES), masks.shape[0])
            class_masks = [masks[k] >= self.score_thresh for k in range(num_classes)]
        elif masks.ndim == 2:
            # Already a label map; ids outside the known classes stay black.
            class_masks = [masks == k for k in range(len(SEGMENTATION_CLASSES))]
        else:
            raise ValueError(f"expected a 2-D or 3-D mask, got shape {masks.shape}")

        image = np.zeros((*masks.shape[-2:], 3), dtype=np.uint8)
        counts = {}
        # Higher class ids are painted over lower ones where they overlap.
        for class_id, class_mask in enumerate(class_masks):
            count = int(np.count_nonzero(class_mask))
            if count:
                image[class_mask] = SEGMENTATION_COLORS[class_id]
                counts[SEGMENTATION_CLASSES[class_id]] = count
        return image, counts

    def visualize_single_bev(self, sample_idx, save_path=None, show_info=True):
        """Plot predicted and ground-truth BEV segmentation side by side.

        Args:
            sample_idx: Index into ``self.results``.
            save_path: If truthy, the figure is saved there and closed;
                otherwise it is displayed with ``plt.show()``.
            show_info: Print the saved path (only when ``save_path`` is set).
        """
        result = self.results[sample_idx]
        fig, axes = plt.subplots(1, 2, figsize=(20, 10))
        try:
            fig.suptitle(f'Sample {sample_idx}', fontsize=16, fontweight='bold')

            # 1. Predicted segmentation, with per-class pixel counts in the legend.
            if 'masks_bev' in result:
                vis_img, counts = self._render_masks(result['masks_bev'])
                axes[0].imshow(vis_img)
                axes[0].set_title('BEV Segmentation Map', fontsize=14)
                patches = [
                    mpatches.Patch(
                        color=np.array(SEGMENTATION_COLORS[i]) / 255.0,
                        label=f'{cls_name}: {counts[cls_name]} px',
                    )
                    for i, cls_name in enumerate(SEGMENTATION_CLASSES)
                    if cls_name in counts
                ]
                if patches:
                    axes[0].legend(handles=patches, loc='upper right',
                                   fontsize=10, framealpha=0.8)
            else:
                axes[0].text(0.5, 0.5, 'No Segmentation',
                             ha='center', va='center', fontsize=20)
            axes[0].axis('off')

            # 2. Ground-truth segmentation (same colour scheme).
            if 'gt_masks_bev' in result:
                vis_img_gt, _ = self._render_masks(result['gt_masks_bev'])
                axes[1].imshow(vis_img_gt)
                axes[1].set_title('Ground Truth Segmentation', fontsize=14)
            else:
                axes[1].text(0.5, 0.5, 'No GT Data',
                             ha='center', va='center', fontsize=20)
            axes[1].axis('off')

            plt.tight_layout()
            if save_path:
                fig.savefig(save_path, dpi=100, bbox_inches='tight')
        except BaseException:
            # Do not leak the figure when drawing or saving fails.
            plt.close(fig)
            raise

        if save_path:
            plt.close(fig)
            if show_info:
                print(f"保存: {save_path}")
        else:
            plt.show()

    @staticmethod
    def _write_video(frame_paths, output_path, fps):
        """Encode the given image files, in order, into an 'mp4v' video.

        The video size is taken from the first frame.

        Raises:
            RuntimeError: If a frame cannot be read or the writer cannot be opened.
        """
        first_frame = cv2.imread(frame_paths[0])
        if first_frame is None:
            raise RuntimeError(f"cannot read frame: {frame_paths[0]}")
        h, w = first_frame.shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
        if not video_writer.isOpened():
            raise RuntimeError(f"cannot open video writer for {output_path}")
        try:
            for frame_path in tqdm(frame_paths, desc="写入视频"):
                frame = cv2.imread(frame_path)
                if frame is None:
                    raise RuntimeError(f"cannot read frame: {frame_path}")
                # bbox_inches='tight' can change the PNG size from frame to
                # frame. VideoWriter expects every frame to match the size it
                # was opened with (mismatched frames are typically dropped
                # silently).
                if frame.shape[:2] != (h, w):
                    frame = cv2.resize(frame, (w, h))
                video_writer.write(frame)
        finally:
            video_writer.release()

    def generate_video(self, output_path, fps=10):
        """Render every ``frame_interval``-th sample and encode them as a video.

        The frames are written as PNGs into a uniquely named temporary
        directory next to ``output_path``. That directory is removed
        afterwards, even on error.

        Args:
            output_path: Destination video file.
            fps: Frame rate of the output video.

        Raises:
            ValueError: If ``frame_interval`` < 1 or there are no samples.
            RuntimeError: If a frame cannot be read back or the video writer
                cannot be opened.
        """
        print(f"开始生成视频: {output_path}")
        if self.frame_interval < 1:
            raise ValueError(f"frame_interval must be >= 1, got {self.frame_interval}")

        sample_indices = list(range(0, len(self.results), self.frame_interval))
        print(f"总样本数: {len(self.results)}")
        print(f"帧间隔: {self.frame_interval}")
        print(f"将生成: {len(sample_indices)}")
        print()
        if not sample_indices:
            raise ValueError("no samples to render: the results are empty")

        # A unique temp dir never collides with an existing 'temp_frames'
        # directory and is cleaned up even if rendering or encoding fails.
        with tempfile.TemporaryDirectory(prefix='temp_frames_',
                                         dir=Path(output_path).parent) as tmp:
            temp_dir = Path(tmp)
            frames = []
            for idx in tqdm(sample_indices, desc="生成帧"):
                frame_path = temp_dir / f'frame_{idx:05d}.png'
                self.visualize_single_bev(idx, save_path=frame_path, show_info=False)
                frames.append(str(frame_path))
            print()
            print(f"生成了 {len(frames)} 帧,开始合成视频...")

            self._write_video(frames, output_path, fps)
            print(f"✅ 视频已保存: {output_path}")
            print("清理临时文件...")
        print("✅ 完成!")
def main():
    """Command-line entry point: parse arguments and run quick or video mode.

    Quick mode saves one PNG per sample for the first ``--num-samples``
    samples (capped at the number of available results). Video mode renders
    every ``--frame-interval``-th sample into ``bevfusion_results.mp4`` in
    ``--output-dir``.
    """
    parser = argparse.ArgumentParser(description='BEVFusion结果可视化修复版')
    parser.add_argument('--result-file', type=str,
                        default='results_epoch19.pkl',
                        help='推理结果文件')
    parser.add_argument('--output-dir', type=str,
                        default='visualizations',
                        help='输出目录')
    parser.add_argument('--mode', type=str,
                        choices=['quick', 'video'],
                        default='quick',
                        help='可视化模式')
    parser.add_argument('--num-samples', type=int, default=10,
                        help='快速可视化的样本数')
    parser.add_argument('--video-fps', type=int, default=10,
                        help='视频帧率')
    parser.add_argument('--frame-interval', type=int, default=6,
                        help='帧间隔(下采样)')
    args = parser.parse_args()

    # Validate before loading the (potentially large) results file; a stride
    # < 1 would otherwise crash range() deep inside video generation.
    if args.mode == 'video' and args.frame_interval < 1:
        parser.error('--frame-interval must be >= 1')

    # Initialize the visualizer (loads the pickled results).
    visualizer = BEVVisualizer(
        result_file=args.result_file,
        frame_interval=args.frame_interval
    )
    os.makedirs(args.output_dir, exist_ok=True)

    if args.mode == 'quick':
        # Never index past the end of the results (previously an IndexError
        # when --num-samples exceeded the number of samples).
        num_available = len(visualizer.results)
        num_samples = min(args.num_samples, num_available)
        if num_samples < args.num_samples:
            print(f"结果中只有 {num_available} 个样本,已将 --num-samples 截断为 {num_samples}")
        print(f"快速可视化前 {num_samples} 个样本...")
        for idx in tqdm(range(num_samples)):
            save_path = os.path.join(args.output_dir, f'sample_{idx:04d}_bev.png')
            visualizer.visualize_single_bev(idx, save_path=save_path)
        print(f"✅ 完成!图像保存在: {args.output_dir}/")
    elif args.mode == 'video':
        video_path = os.path.join(args.output_dir, 'bevfusion_results.mp4')
        visualizer.generate_video(
            output_path=video_path,
            fps=args.video_fps
        )


if __name__ == '__main__':
    main()