241 lines
8.1 KiB
Python
241 lines
8.1 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
BEVFusion 正确的可视化脚本
|
|||
|
|
使用正确的数据键名: masks_bev(预测分割)、gt_masks_bev(真值分割);boxes_3d 目前未被可视化
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import pickle
|
|||
|
|
import numpy as np
|
|||
|
|
import cv2
|
|||
|
|
from pathlib import Path
|
|||
|
|
import argparse
|
|||
|
|
from tqdm import tqdm
|
|||
|
|
import matplotlib.pyplot as plt
|
|||
|
|
import matplotlib.patches as mpatches
|
|||
|
|
import torch
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Segmentation class names (nuScenes map layers). The position of each name
# in this list is the class id / mask channel index used by the visualizer:
#   0 drivable_area, 1 ped_crossing, 2 walkway,
#   3 stop_line,     4 carpark_area, 5 divider
SEGMENTATION_CLASSES = (
    'drivable_area ped_crossing walkway stop_line carpark_area divider'
).split()
|
|||
|
|
|
|||
|
|
# Display colour for each segmentation class, as [R, G, B] in 0-255.
# Entry i colours class id i (same order as the class-name list); the
# values are written into an RGB image that is shown with matplotlib.
SEGMENTATION_COLORS = [
    [255, 195, 128],  # 0 drivable_area - light orange
    [255, 99, 71],    # 1 ped_crossing  - tomato red
    [255, 140, 0],    # 2 walkway       - dark orange
    [180, 62, 39],    # 3 stop_line     - brown
    [233, 150, 70],   # 4 carpark_area  - brownish orange
    [160, 82, 45],    # 5 divider       - sienna
]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class BEVVisualizer:
    """Render BEVFusion BEV segmentation results (prediction vs. ground truth).

    Results are loaded from a pickle file. Each element of the loaded
    sequence is expected to be a dict that may hold ``'masks_bev'``
    (predicted masks) and ``'gt_masks_bev'`` (ground-truth masks). Either
    value may be a torch tensor or a NumPy array.
    """

    def __init__(self, result_file, frame_interval=1, score_threshold=0.5):
        """Load inference results and store rendering options.

        Args:
            result_file: Path of the pickled inference results.
            frame_interval: Stride used when sampling results for a video;
                must be >= 1.
            score_threshold: A pixel whose best per-class score is below this
                value is treated as background (drawn black) instead of being
                assigned to the arg-max class. Assumes predicted masks hold
                per-class probabilities in [0, 1] — TODO confirm against the
                segmentation head's output.

        Raises:
            ValueError: If ``frame_interval`` is smaller than 1.
        """
        if frame_interval < 1:
            raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")
        print(f"加载推理结果: {result_file}")
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only open result files produced by a trusted inference run.
        with open(result_file, 'rb') as f:
            self.results = pickle.load(f)
        print(f"成功加载 {len(self.results)} 个样本")
        self.frame_interval = frame_interval
        self.score_threshold = score_threshold

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _to_numpy(masks):
        """Return *masks* as a NumPy array, moving torch tensors to the CPU."""
        if hasattr(masks, 'detach'):
            masks = masks.detach()
        if hasattr(masks, 'cpu'):
            masks = masks.cpu().numpy()
        return np.asarray(masks)

    @staticmethod
    def _to_label_map(masks, threshold):
        """Collapse *masks* into an (H, W) integer label map.

        A 3-D input is treated as per-class scores of shape (C, H, W). Each
        pixel gets its arg-max class, or -1 (background) when its best score
        is below *threshold*. A 2-D input is taken to already be a label
        map and is returned as an integer array.

        Raises:
            ValueError: If *masks* is neither 2-D nor 3-D.
        """
        masks = np.asarray(masks)
        if masks.ndim == 2:
            return masks.astype(np.int64, copy=False)
        if masks.ndim != 3:
            raise ValueError(
                f"expected a (C, H, W) or (H, W) mask, got shape {masks.shape}")
        label_map = np.argmax(masks, axis=0)
        # argmax of an all-zero / low-score pixel is 0 (drivable_area); mark
        # such pixels as background so they are not mis-coloured and counted.
        label_map[masks.max(axis=0) < threshold] = -1
        return label_map

    @staticmethod
    def _colorize(label_map):
        """Map an (H, W) label map to an (H, W, 3) uint8 RGB image.

        Labels without a palette entry (e.g. background -1) stay black.
        """
        palette = np.asarray(SEGMENTATION_COLORS, dtype=np.uint8)
        rgb = np.zeros(label_map.shape + (3,), dtype=np.uint8)
        valid = (label_map >= 0) & (label_map < len(palette))
        rgb[valid] = palette[label_map[valid]]
        return rgb

    @staticmethod
    def _legend_handles(label_map):
        """Build legend patches with pixel counts for classes present in *label_map*."""
        handles = []
        for class_id, (name, color) in enumerate(
                zip(SEGMENTATION_CLASSES, SEGMENTATION_COLORS)):
            count = int(np.count_nonzero(label_map == class_id))
            if count > 0:
                handles.append(mpatches.Patch(color=np.array(color) / 255.0,
                                              label=f'{name}: {count} px'))
        return handles

    def _draw_panel(self, ax, masks, title, empty_text, show_legend):
        """Draw one segmentation panel on *ax*, or *empty_text* when *masks* is None."""
        ax.axis('off')
        if masks is None:
            ax.text(0.5, 0.5, empty_text, ha='center', va='center',
                    fontsize=20, transform=ax.transAxes)
            return
        label_map = self._to_label_map(self._to_numpy(masks), self.score_threshold)
        ax.imshow(self._colorize(label_map))
        ax.set_title(title, fontsize=14)
        if show_legend:
            handles = self._legend_handles(label_map)
            if handles:
                ax.legend(handles=handles, loc='upper right',
                          fontsize=10, framealpha=0.8)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def visualize_single_bev(self, sample_idx, save_path=None, show_info=True):
        """Plot predicted (left) and ground-truth (right) BEV segmentation.

        Args:
            sample_idx: Index into the loaded results.
            save_path: If given, the figure is saved there and closed;
                otherwise it is shown interactively with ``plt.show()``.
            show_info: Print the saved path when ``save_path`` is given.
        """
        result = self.results[sample_idx]

        fig, axes = plt.subplots(1, 2, figsize=(20, 10))
        fig.suptitle(f'Sample {sample_idx}', fontsize=16, fontweight='bold')

        self._draw_panel(axes[0], result.get('masks_bev'),
                         'BEV Segmentation Map', 'No Segmentation',
                         show_legend=True)
        self._draw_panel(axes[1], result.get('gt_masks_bev'),
                         'Ground Truth Segmentation', 'No GT Data',
                         show_legend=False)

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=100, bbox_inches='tight')
            plt.close(fig)
            if show_info:
                print(f"保存: {save_path}")
        else:
            plt.show()

    def generate_video(self, output_path, fps=10):
        """Render every ``frame_interval``-th sample and encode them as an MP4.

        Frames are written to a private temporary directory next to
        *output_path*, which is removed afterwards even if encoding fails.

        Raises:
            ValueError: If there are no samples to render.
            RuntimeError: If a frame cannot be read back or the video writer
                cannot be opened.
        """
        import shutil
        import tempfile

        print(f"开始生成视频: {output_path}")

        sample_indices = list(range(0, len(self.results), self.frame_interval))
        print(f"总样本数: {len(self.results)}")
        print(f"帧间隔: {self.frame_interval}")
        print(f"将生成: {len(sample_indices)} 帧")
        print()
        if not sample_indices:
            raise ValueError("no samples to render: the result file is empty")

        # Unique scratch directory: never collides with pre-existing files and
        # can be removed wholesale in the finally block.
        temp_dir = Path(tempfile.mkdtemp(prefix='temp_frames_',
                                         dir=Path(output_path).parent))
        try:
            frames = []
            for idx in tqdm(sample_indices, desc="生成帧"):
                frame_path = temp_dir / f'frame_{idx:05d}.png'
                self.visualize_single_bev(idx, save_path=frame_path, show_info=False)
                frames.append(str(frame_path))

            print()
            print(f"生成了 {len(frames)} 帧,开始合成视频...")

            first_frame = cv2.imread(frames[0])
            if first_frame is None:
                raise RuntimeError(f"failed to read frame: {frames[0]}")
            h, w = first_frame.shape[:2]

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
            if not video_writer.isOpened():
                raise RuntimeError(f"cannot open video writer for: {output_path}")
            try:
                for frame_path in tqdm(frames, desc="写入视频"):
                    frame = cv2.imread(frame_path)
                    if frame is None:
                        raise RuntimeError(f"failed to read frame: {frame_path}")
                    # bbox_inches='tight' can make frame sizes differ slightly;
                    # the writer may silently drop frames of the wrong size.
                    if frame.shape[:2] != (h, w):
                        frame = cv2.resize(frame, (w, h))
                    video_writer.write(frame)
            finally:
                video_writer.release()

            print(f"✅ 视频已保存: {output_path}")
        finally:
            print("清理临时文件...")
            shutil.rmtree(temp_dir, ignore_errors=True)

        print("✅ 完成!")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """Command-line entry point.

    Loads a BEVFusion results pickle and either saves per-sample BEV images
    (``--mode quick``) or encodes a video of every ``--frame-interval``-th
    sample (``--mode video``) into ``--output-dir``.
    """
    parser = argparse.ArgumentParser(description='BEVFusion结果可视化(修复版)')
    parser.add_argument('--result-file', type=str,
                        default='results_epoch19.pkl',
                        help='推理结果文件')
    parser.add_argument('--output-dir', type=str,
                        default='visualizations',
                        help='输出目录')
    parser.add_argument('--mode', type=str,
                        choices=['quick', 'video'],
                        default='quick',
                        help='可视化模式')
    parser.add_argument('--num-samples', type=int, default=10,
                        help='快速可视化的样本数')
    parser.add_argument('--video-fps', type=int, default=10,
                        help='视频帧率')
    parser.add_argument('--frame-interval', type=int, default=6,
                        help='帧间隔(下采样)')

    args = parser.parse_args()

    # Reject values that would make range() fail or produce no frames.
    if args.frame_interval < 1:
        parser.error('--frame-interval must be >= 1')
    if args.num_samples < 0:
        parser.error('--num-samples must be >= 0')

    # Initialise the visualiser (loads the results file).
    visualizer = BEVVisualizer(
        result_file=args.result_file,
        frame_interval=args.frame_interval
    )

    os.makedirs(args.output_dir, exist_ok=True)

    if args.mode == 'quick':
        # Clamp so a short result file does not raise IndexError mid-run.
        num_samples = min(args.num_samples, len(visualizer.results))
        print(f"快速可视化前 {num_samples} 个样本...")
        for idx in tqdm(range(num_samples)):
            save_path = os.path.join(args.output_dir, f'sample_{idx:04d}_bev.png')
            visualizer.visualize_single_bev(idx, save_path=save_path)
        print(f"✅ 完成!图像保存在: {args.output_dir}/")

    elif args.mode == 'video':
        video_path = os.path.join(args.output_dir, 'bevfusion_results.mp4')
        visualizer.generate_video(
            output_path=video_path,
            fps=args.video_fps
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|