#!/usr/bin/env python3
"""
BEVFusion visualization script (corrected version).

Uses the correct result keys: ``masks_bev`` (predicted BEV segmentation) and
``gt_masks_bev`` (ground-truth BEV segmentation). ``boxes_3d`` is not drawn
by this script.
"""

import argparse
import os
import pickle
import tempfile
from pathlib import Path

import cv2
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm


# BEV map-segmentation classes (nuScenes map layers) paired with their RGB
# display colours. The position of each entry is the class id used when
# rendering; it is assumed to match the channel order of the model's masks.
_SEGMENTATION_PALETTE = (
    ('drivable_area', (255, 195, 128)),  # 0 - orange
    ('ped_crossing', (255, 99, 71)),     # 1 - tomato red
    ('walkway', (255, 140, 0)),          # 2 - dark orange
    ('stop_line', (180, 62, 39)),        # 3 - brown
    ('carpark_area', (233, 150, 70)),    # 4 - brownish orange
    ('divider', (160, 82, 45)),          # 5 - sienna brown
)

# Class names, indexed by class id.
SEGMENTATION_CLASSES = [name for name, _ in _SEGMENTATION_PALETTE]

# RGB colour (0-255 ints) of each class id, as [r, g, b] lists.
SEGMENTATION_COLORS = [list(rgb) for _, rgb in _SEGMENTATION_PALETTE]


class BEVVisualizer:
    """Render BEVFusion BEV segmentation results as images or an MP4 video.

    Each loaded result is a per-sample dict that may contain ``'masks_bev'``
    (prediction) and/or ``'gt_masks_bev'`` (ground truth).
    """

    def __init__(self, result_file, frame_interval=1, score_threshold=0.5):
        """Load pickled inference results.

        Args:
            result_file: Path to a pickle file holding a sequence of
                per-sample results (indexed by position).
            frame_interval: ``generate_video`` uses every
                ``frame_interval``-th sample. Must be >= 1.
            score_threshold: For (C, H, W) masks, class k is drawn where
                ``masks[k] >= score_threshold``. NOTE(review): assumes the
                pickle stores per-class probabilities (or 0/1 masks), not
                logits - confirm against the inference script.

        Raises:
            ValueError: If ``frame_interval`` is smaller than 1.
        """
        if frame_interval < 1:
            # A step of 0 (or a negative step) would break range() in generate_video.
            raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")
        print(f"加载推理结果: {result_file}")
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # result files you trust.
        with open(result_file, 'rb') as f:
            self.results = pickle.load(f)
        print(f"成功加载 {len(self.results)} 个样本")
        self.frame_interval = frame_interval
        self.score_threshold = score_threshold

    @staticmethod
    def _to_numpy(data):
        """Return ``data`` as a NumPy array, detaching/moving torch tensors to the CPU first."""
        if hasattr(data, 'detach'):
            data = data.detach()
        if hasattr(data, 'cpu'):
            data = data.cpu().numpy()
        return np.asarray(data)

    @staticmethod
    def _render_masks(masks, colors, threshold=0.5):
        """Paint BEV segmentation masks into an RGB image.

        Args:
            masks: Either a (C, H, W) array/tensor of per-class scores or
                binary masks, where class k covers ``masks[k] >= threshold``,
                or an (H, W) label map where pixel value k selects class k.
            colors: Sequence of RGB triples; ``colors[k]`` colours class k.
                Classes/channels beyond ``len(colors)`` are ignored.
            threshold: Score cut-off applied to (C, H, W) input.

        Returns:
            ``(image, counts)``: ``image`` is an (H, W, 3) uint8 array in
            which pixels covered by no class stay black; ``counts[k]`` is the
            number of pixels belonging to class k.

        Raises:
            ValueError: If ``masks`` is neither 2-D nor 3-D.
        """
        masks = BEVVisualizer._to_numpy(masks)
        if masks.ndim == 3:
            # Threshold every channel independently instead of taking an
            # argmax: argmax has no "background" and would paint every empty
            # pixel as class 0, and it cannot show overlapping classes.
            num_classes = min(len(colors), masks.shape[0])
            class_masks = [masks[k] >= threshold for k in range(num_classes)]
        elif masks.ndim == 2:
            class_masks = [masks == k for k in range(len(colors))]
        else:
            raise ValueError(
                f"Expected a (C, H, W) or (H, W) mask array, got shape {masks.shape}")

        image = np.zeros((*masks.shape[-2:], 3), dtype=np.uint8)
        counts = []
        # Later classes are drawn on top of earlier ones where masks overlap.
        for mask, color in zip(class_masks, colors):
            image[mask] = color
            counts.append(int(np.count_nonzero(mask)))
        return image, counts

    def _draw_panel(self, ax, masks, title, empty_text, show_legend):
        """Draw one segmentation panel on ``ax``, or a placeholder text if ``masks`` is None."""
        if masks is None:
            ax.text(0.5, 0.5, empty_text, ha='center', va='center', fontsize=20)
            ax.axis('off')
            return

        image, counts = self._render_masks(masks, SEGMENTATION_COLORS, self.score_threshold)
        ax.imshow(image)
        ax.set_title(title, fontsize=14)
        ax.axis('off')

        if show_legend:
            # Only list classes that actually appear in this sample.
            handles = [
                mpatches.Patch(color=np.array(color) / 255.0, label=f'{name}: {count} px')
                for name, color, count in zip(SEGMENTATION_CLASSES, SEGMENTATION_COLORS, counts)
                if count > 0
            ]
            if handles:
                ax.legend(handles=handles, loc='upper right', fontsize=10, framealpha=0.8)

    def visualize_single_bev(self, sample_idx, save_path=None, show_info=True):
        """Plot the predicted and ground-truth BEV segmentation of one sample side by side.

        Args:
            sample_idx: Index into ``self.results``.
            save_path: If given, the figure is saved there and closed;
                otherwise it is displayed with ``plt.show()``.
            show_info: Print the saved path (only when ``save_path`` is given).
        """
        result = self.results[sample_idx]

        fig, axes = plt.subplots(1, 2, figsize=(20, 10))
        try:
            fig.suptitle(f'Sample {sample_idx}', fontsize=16, fontweight='bold')
            self._draw_panel(axes[0], result.get('masks_bev'),
                             'BEV Segmentation Map', 'No Segmentation', show_legend=True)
            self._draw_panel(axes[1], result.get('gt_masks_bev'),
                             'Ground Truth Segmentation', 'No GT Data', show_legend=False)
            fig.tight_layout()
            if save_path:
                fig.savefig(save_path, dpi=100, bbox_inches='tight')
        except Exception:
            # Don't leak the figure when drawing/saving fails (matters in the video loop).
            plt.close(fig)
            raise

        if not save_path:
            plt.show()
            return
        plt.close(fig)
        if show_info:
            print(f"保存: {save_path}")

    def generate_video(self, output_path, fps=10):
        """Render every ``frame_interval``-th sample and encode the frames as an MP4.

        Frames are rendered as PNGs into a temporary directory next to
        ``output_path``; that directory is removed afterwards, even if
        rendering or encoding fails.

        Args:
            output_path: Destination video file (encoded with the 'mp4v' FourCC).
            fps: Frames per second of the output video.

        Raises:
            ValueError: If there are no samples to render.
            RuntimeError: If a frame cannot be read back or the video writer
                cannot be opened.
        """
        print(f"开始生成视频: {output_path}")

        sample_indices = list(range(0, len(self.results), self.frame_interval))
        print(f"总样本数: {len(self.results)}")
        print(f"帧间隔: {self.frame_interval}")
        print(f"将生成: {len(sample_indices)} 帧")
        print()
        if not sample_indices:
            raise ValueError("No samples to render; cannot generate a video")

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        # A unique temporary directory cannot clash with leftovers of a previous
        # run and is cleaned up even when an exception is raised.
        with tempfile.TemporaryDirectory(prefix='temp_frames_', dir=output_path.parent) as tmp:
            temp_dir = Path(tmp)
            frames = []
            for idx in tqdm(sample_indices, desc="生成帧"):
                frame_path = temp_dir / f'frame_{idx:05d}.png'
                self.visualize_single_bev(idx, save_path=frame_path, show_info=False)
                frames.append(str(frame_path))

            print()
            print(f"生成了 {len(frames)} 帧,开始合成视频...")

            # The first frame fixes the video size.
            first_frame = cv2.imread(frames[0])
            if first_frame is None:
                raise RuntimeError(f"Failed to read frame: {frames[0]}")
            h, w = first_frame.shape[:2]

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
            if not video_writer.isOpened():
                raise RuntimeError(f"Failed to open video writer for: {output_path}")
            try:
                for frame_path in tqdm(frames, desc="写入视频"):
                    frame = cv2.imread(frame_path)
                    if frame is None:
                        raise RuntimeError(f"Failed to read frame: {frame_path}")
                    # bbox_inches='tight' can yield slightly different frame
                    # sizes; VideoWriter may silently drop frames whose size
                    # differs from the one it was opened with.
                    if frame.shape[:2] != (h, w):
                        frame = cv2.resize(frame, (w, h))
                    video_writer.write(frame)
            finally:
                video_writer.release()
            print(f"✅ 视频已保存: {output_path}")
            print("清理临时文件...")
        print("✅ 完成!")


def main():
    """Parse command-line arguments and run the requested visualization mode.

    ``quick`` saves one PNG per sample for the first ``--num-samples``
    samples; ``video`` renders every ``--frame-interval``-th sample into
    ``bevfusion_results.mp4`` inside ``--output-dir``.
    """
    parser = argparse.ArgumentParser(description='BEVFusion结果可视化(修复版)')
    parser.add_argument('--result-file', type=str,
                        default='results_epoch19.pkl',
                        help='推理结果文件')
    parser.add_argument('--output-dir', type=str,
                        default='visualizations',
                        help='输出目录')
    parser.add_argument('--mode', type=str,
                        choices=['quick', 'video'],
                        default='quick',
                        help='可视化模式')
    parser.add_argument('--num-samples', type=int, default=10,
                        help='快速可视化的样本数')
    parser.add_argument('--video-fps', type=int, default=10,
                        help='视频帧率')
    parser.add_argument('--frame-interval', type=int, default=6,
                        help='帧间隔(下采样)')

    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside range()/VideoWriter.
    if args.frame_interval < 1:
        parser.error('--frame-interval must be >= 1')
    if args.video_fps < 1:
        parser.error('--video-fps must be >= 1')

    # Initialize the visualizer (loads the pickled results).
    visualizer = BEVVisualizer(
        result_file=args.result_file,
        frame_interval=args.frame_interval
    )

    os.makedirs(args.output_dir, exist_ok=True)

    # Run the selected visualization mode.
    if args.mode == 'quick':
        # Clamp to the number of loaded samples so a large --num-samples does
        # not end in an IndexError.
        num_samples = min(args.num_samples, len(visualizer.results))
        print(f"快速可视化前 {num_samples} 个样本...")
        for idx in tqdm(range(num_samples)):
            save_path = os.path.join(args.output_dir, f'sample_{idx:04d}_bev.png')
            visualizer.visualize_single_bev(idx, save_path=save_path)
        print(f"✅ 完成!图像保存在: {args.output_dir}/")

    elif args.mode == 'video':
        video_path = os.path.join(args.output_dir, 'bevfusion_results.mp4')
        visualizer.generate_video(
            output_path=video_path,
            fps=args.video_fps
        )


if __name__ == '__main__':
    # Script entry point; main() returns None, so this exits with status 0.
    raise SystemExit(main())