#!/usr/bin/env python3
"""
BEVFusion 推理结果可视化工具

支持生成图像和视频

BEVFusion inference-result visualization tool; renders images and videos.
"""

import argparse
import os
import pickle
import tempfile
from pathlib import Path

import cv2
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import mmcv
import numpy as np
from nuscenes.nuscenes import NuScenes
from nuscenes.utils.data_classes import Box, LidarPointCloud
from nuscenes.utils.geometry_utils import BoxVisibility, box_in_image, view_points
from pyquaternion import Quaternion
from tqdm import tqdm


# Segmentation / detection class name -> RGB colour (intended to match the
# official palette).
# NOTE(review): several classes share a colour (walkway/bus,
# carpark/construction_vehicle, truck/motorcycle), and nothing in the code
# visible alongside this table reads it -- confirm whether it is still used.
SEGMENTATION_COLORS = dict(
    drivable_area=(255, 195, 128),        # orange
    ped_crossing=(255, 99, 71),           # tomato red
    walkway=(255, 140, 0),                # dark orange
    carpark=(233, 150, 70),               # brown-orange
    car=(220, 20, 60),                    # crimson
    truck=(255, 61, 99),                  # pink
    bus=(255, 140, 0),                    # dark orange
    trailer=(255, 127, 80),               # coral
    construction_vehicle=(233, 150, 70),  # brown-orange
    pedestrian=(0, 0, 230),               # blue
    motorcycle=(255, 61, 99),             # pink
    bicycle=(112, 128, 144),              # slate grey
    traffic_cone=(184, 134, 11),          # dark goldenrod
    barrier=(160, 82, 45),                # sienna brown
)


class BEVFusionVisualizer:
    """Render BEVFusion inference results (segmentation / detection) as images and videos.

    Results are loaded from a pickle with :meth:`load_results`; each entry is
    addressed by its position in that list and paired with a nuScenes sample
    via :meth:`_get_sample`.
    """

    def __init__(self, nusc_root, nusc_version='v1.0-trainval', max_samples=None):
        """Initialise the visualizer.

        Args:
            nusc_root: Root directory of the nuScenes dataset.
            nusc_version: nuScenes version string passed to ``NuScenes``.
            max_samples: If truthy, only the first ``max_samples`` results are
                kept after loading (useful for quick tests).
        """
        self.nusc = NuScenes(version=nusc_version, dataroot=nusc_root, verbose=True)
        self.max_samples = max_samples
        # Filled by load_results(); initialised so the attribute always exists.
        self.results = []

    def load_results(self, result_path):
        """Load inference results from a pickle file into ``self.results``.

        Args:
            result_path: Path to the ``.pkl`` file written by inference.
        """
        print(f"加载推理结果: {result_path}")
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only load result files you produced yourself or otherwise trust.
        with open(result_path, 'rb') as f:
            self.results = pickle.load(f)
        print(f"成功加载 {len(self.results)} 个样本的结果")

        if self.max_samples:
            self.results = self.results[:self.max_samples]
            print(f"限制为前 {self.max_samples} 个样本")

    # ------------------------------------------------------------------
    # Result / dataset access helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _to_numpy(value):
        """Convert a torch tensor (on any device) or an array-like to a NumPy array."""
        if hasattr(value, 'detach'):
            value = value.detach()
        if hasattr(value, 'cpu'):
            value = value.cpu().numpy()
        return np.asarray(value)

    @classmethod
    def _get_seg_mask(cls, result):
        """Return the segmentation output stored in ``result`` as NumPy, or None.

        Checks ``masks_bev`` first, then ``pts_seg``; a key whose value is None
        is treated as absent.
        """
        for key in ('masks_bev', 'pts_seg'):
            value = result.get(key)
            if value is not None:
                return cls._to_numpy(value)
        return None

    @staticmethod
    def _get_detections(result):
        """Return ``(boxes_3d, scores_3d, labels_3d)`` from ``result``, or None.

        Supports both the flat layout (``result['boxes_3d']``) and the nested
        ``result['pts_bbox'][...]`` layout. Missing scores/labels are None.
        """
        if 'boxes_3d' in result:
            source = result
        elif 'pts_bbox' in result:
            source = result['pts_bbox']
        else:
            return None
        return source['boxes_3d'], source.get('scores_3d'), source.get('labels_3d')

    @staticmethod
    def _seg_to_label_map(seg_mask, score_thresh=0.5):
        """Collapse a segmentation output into an integer label map (0 = background).

        * ``(H, W)`` input is taken to already be a label map and returned as-is.
        * ``(C, H, W)`` input is treated as per-class scores: the pixel label is
          ``argmax + 1`` (so channel 0 becomes label 1 instead of being mistaken
          for background), or 0 where no channel reaches ``score_thresh``.

        NOTE(review): assumes (C, H, W) values are probabilities / binary masks
        with no dedicated background channel -- confirm against the model head.
        """
        seg_mask = np.asarray(seg_mask)
        if seg_mask.ndim != 3:
            return seg_mask
        labels = np.argmax(seg_mask, axis=0) + 1
        labels[seg_mask.max(axis=0) < score_thresh] = 0
        return labels

    @staticmethod
    def _label_color(label):
        """Return a uint8 RGB colour for integer ``label`` from the cycled tab10 palette.

        The colour depends only on the label value, so a class keeps the same
        colour in every frame of a video.
        """
        # plt.get_cmap replaces matplotlib.cm.get_cmap, removed in Matplotlib 3.9.
        cmap = plt.get_cmap('tab10')
        return (np.asarray(cmap(int(label) % cmap.N)[:3]) * 255).astype(np.uint8)

    def _get_sample(self, sample_idx):
        """Return the nuScenes ``sample`` record paired with result ``sample_idx``.

        Uses a ``token`` / ``sample_token`` stored in the result when present;
        otherwise falls back to ``self.nusc.sample[sample_idx]``.

        NOTE(review): the fallback assumes the results list is ordered exactly
        like ``nusc.sample``; that need not hold for split-based (e.g. val-only)
        inference outputs -- verify for your results file.
        """
        result = self.results[sample_idx]
        token = result.get('token') or result.get('sample_token')
        if token is None:
            token = self.nusc.sample[sample_idx]['token']
        return self.nusc.get('sample', token)

    def _load_camera_image(self, sample_idx, camera_name):
        """Load the RGB image of ``camera_name`` for result ``sample_idx``.

        Returns None (after printing a message) when the file cannot be read,
        because ``cv2.imread`` reports failure by returning None, not by raising.
        """
        sample = self._get_sample(sample_idx)
        cam_data = self.nusc.get('sample_data', sample['data'][camera_name])
        img_path = os.path.join(self.nusc.dataroot, cam_data['filename'])
        img = cv2.imread(img_path)
        if img is None:
            print(f"无法读取图像: {img_path}")
            return None
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _save_image(image, title, save_path):
        """Save ``image`` with ``title`` to ``save_path`` as an axis-free figure."""
        fig = plt.figure(figsize=(16, 9))
        try:
            plt.imshow(image)
            plt.title(title)
            plt.axis('off')
            plt.tight_layout()
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
        finally:
            # Release the figure even if saving fails; many frames are rendered.
            plt.close(fig)
        print(f"保存到: {save_path}")

    # ------------------------------------------------------------------
    # Visualisations
    # ------------------------------------------------------------------
    def visualize_segmentation_on_camera(self, sample_idx, camera_name='CAM_FRONT',
                                         alpha=0.5, save_path=None):
        """Overlay the segmentation result on a camera image.

        Args:
            sample_idx: Index into ``self.results``.
            camera_name: nuScenes camera channel, e.g. ``'CAM_FRONT'``.
            alpha: Weight of the segmentation colours in the blend.
            save_path: If given, the overlay is also saved there as a figure.

        Returns:
            The RGB overlay image, or None if the sample has no segmentation
            output or the camera image cannot be read.
        """
        seg_mask = self._get_seg_mask(self.results[sample_idx])
        if seg_mask is None:
            print(f"样本 {sample_idx} 没有分割结果")
            return None

        img = self._load_camera_image(sample_idx, camera_name)
        if img is None:
            return None

        seg_vis = self.create_segmentation_overlay(seg_mask, img.shape[:2])

        # Blend only where a class was painted, so unlabelled pixels keep their
        # original brightness (blending everywhere darkens the whole image).
        blended = cv2.addWeighted(img, 1 - alpha, seg_vis, alpha, 0)
        labelled = seg_vis.any(axis=2, keepdims=True)
        overlay = np.where(labelled, blended, img)

        if save_path:
            self._save_image(overlay, f'Sample {sample_idx} - {camera_name} - Segmentation',
                             save_path)
        return overlay

    def create_segmentation_overlay(self, seg_mask, target_shape):
        """Build a colour image of the segmentation, sized to ``target_shape``.

        Args:
            seg_mask: ``(H, W)`` label map or ``(C, H, W)`` per-class scores
                (NumPy array or torch tensor); see ``_seg_to_label_map``.
            target_shape: Output size ``(H, W)``.

        Returns:
            ``(H, W, 3)`` uint8 RGB image; background (label 0) stays black.
        """
        label_map = self._seg_to_label_map(self._to_numpy(seg_mask))
        h, w = target_shape
        vis = np.zeros((h, w, 3), dtype=np.uint8)

        for label in np.unique(label_map):
            if label == 0:  # background
                continue
            mask = (label_map == label).astype(np.uint8)
            # NOTE: simplification -- the BEV mask is stretched to the image
            # size, not projected with the camera intrinsics/extrinsics, so it
            # does not line up geometrically with the camera view.
            mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
            vis[mask_resized > 0] = self._label_color(label)

        return vis

    def visualize_3d_detection(self, sample_idx, camera_name='CAM_FRONT',
                               save_path=None):
        """Show the camera image for a sample that has 3D detection output.

        Args:
            sample_idx: Index into ``self.results``.
            camera_name: nuScenes camera channel.
            save_path: If given, the image is also saved there as a figure.

        Returns:
            The RGB camera image, or None if the sample has no detections or
            the image cannot be read.
        """
        if self._get_detections(self.results[sample_idx]) is None:
            print(f"样本 {sample_idx} 没有检测结果")
            return None

        img = self._load_camera_image(sample_idx, camera_name)
        if img is None:
            return None

        # TODO: project the 3D boxes into the image using the camera calibration;
        # boxes are not drawn yet (box-parameter conventions need confirming).
        if save_path:
            self._save_image(img, f'Sample {sample_idx} - {camera_name} - 3D Detection',
                             save_path)
        return img

    def visualize_bev(self, sample_idx, save_path=None):
        """Plot the BEV segmentation and a BEV detection panel side by side.

        Args:
            sample_idx: Index into ``self.results``.
            save_path: If given, save the figure there; otherwise ``plt.show()``.
        """
        result = self.results[sample_idx]
        fig, axes = plt.subplots(1, 2, figsize=(20, 10))
        try:
            # 1. Segmentation panel.
            seg_mask = self._get_seg_mask(result)
            if seg_mask is not None:
                label_map = self._seg_to_label_map(seg_mask)
                # Fixed vmin/vmax keep label -> colour stable across frames
                # (auto-scaling would recolour classes per frame).
                axes[0].imshow(label_map, cmap='tab20', vmin=0, vmax=19,
                               interpolation='nearest')
                axes[0].set_title(f'BEV Segmentation - Sample {sample_idx}')
            else:
                axes[0].text(0.5, 0.5, 'No Segmentation Data',
                             ha='center', va='center', fontsize=20)
                axes[0].set_title('BEV Segmentation')
            axes[0].axis('off')

            # 2. Detection panel (works for both flat and pts_bbox layouts).
            detections = self._get_detections(result)
            if detections is not None:
                boxes_3d = detections[0]
                bev_size = 200  # canvas size in pixels
                bev_img = np.full((bev_size, bev_size, 3), 255, dtype=np.uint8)
                # TODO: draw the boxes; currently only their count is shown.
                axes[1].imshow(bev_img)
                axes[1].set_title(f'BEV Detection ({len(boxes_3d)} boxes)')
            else:
                axes[1].text(0.5, 0.5, 'No Detection Data',
                             ha='center', va='center', fontsize=20)
                axes[1].set_title('BEV Detection')
            axes[1].axis('off')

            fig.tight_layout()
            if save_path:
                fig.savefig(save_path, dpi=150, bbox_inches='tight')
                print(f"保存到: {save_path}")
            else:
                plt.show()
        finally:
            plt.close(fig)

    # ------------------------------------------------------------------
    # Batch outputs
    # ------------------------------------------------------------------
    def _render_frame(self, sample_idx, camera_name, vis_type, frame_path):
        """Render one frame of type ``vis_type`` for ``sample_idx`` to ``frame_path``."""
        if vis_type == 'seg':
            self.visualize_segmentation_on_camera(sample_idx, camera_name,
                                                  save_path=frame_path)
        elif vis_type == 'det':
            self.visualize_3d_detection(sample_idx, camera_name, save_path=frame_path)
        else:
            # 'bev' and 'both' (and any other value) use the combined BEV figure.
            self.visualize_bev(sample_idx, save_path=frame_path)

    @staticmethod
    def _write_video(frame_paths, output_path, fps):
        """Encode the image files in ``frame_paths`` (in order) into an MP4.

        The video size is taken from the first frame. Later frames of another
        size are resized to match: ``cv2.VideoWriter`` may silently drop
        mismatched frames, and ``bbox_inches='tight'`` can make saved figure
        sizes differ slightly. Unreadable frames are skipped.

        Raises:
            RuntimeError: If the first frame is unreadable or the writer
                cannot be opened.
        """
        first_frame = cv2.imread(str(frame_paths[0]))
        if first_frame is None:
            raise RuntimeError(f"无法读取帧: {frame_paths[0]}")
        h, w = first_frame.shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
        if not video_writer.isOpened():
            raise RuntimeError(f"无法创建视频文件: {output_path}")
        try:
            for frame_path in tqdm(frame_paths, desc="写入视频"):
                frame = cv2.imread(str(frame_path))
                if frame is None:
                    continue
                if frame.shape[:2] != (h, w):
                    frame = cv2.resize(frame, (w, h), interpolation=cv2.INTER_AREA)
                video_writer.write(frame)
        finally:
            video_writer.release()

    def generate_video(self, output_path, fps=10, camera_name='CAM_FRONT',
                       vis_type='both', frame_interval=1):
        """Render every ``frame_interval``-th result and encode the frames as MP4.

        Args:
            output_path: Destination video path.
            fps: Frame rate of the output video.
            camera_name: Camera used by the ``'seg'`` / ``'det'`` frame types.
            vis_type: ``'seg'``, ``'det'``, ``'bev'`` or ``'both'``
                (``'both'`` currently renders the BEV figure).
            frame_interval: Step between rendered results (1 = every result,
                6 = every 6th result).

        Raises:
            ValueError: If ``frame_interval`` is less than 1.
        """
        if frame_interval < 1:
            raise ValueError(f"frame_interval 必须 >= 1, 当前为 {frame_interval}")

        sample_indices = range(0, len(self.results), frame_interval)
        print(f"开始生成视频: {output_path}")
        print(f"总样本数: {len(self.results)}")
        print(f"帧间隔: {frame_interval} (将处理 {len(sample_indices)} 帧)")

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # A private temporary directory is removed automatically, even if
        # rendering or encoding raises, and never collides with stale files.
        with tempfile.TemporaryDirectory(dir=output_path.parent,
                                         prefix='temp_frames_') as tmp:
            temp_dir = Path(tmp)
            frames = []
            for idx in tqdm(sample_indices, desc="生成帧"):
                frame_path = temp_dir / f'frame_{idx:04d}.png'
                self._render_frame(idx, camera_name, vis_type, frame_path)
                # Visualisers return early (without saving) when data is missing.
                if frame_path.exists():
                    frames.append(frame_path)

            if not frames:
                print("没有生成任何帧!")
                return

            print(f"合成视频,共 {len(frames)} 帧...")
            self._write_video(frames, output_path, fps)

        print(f"视频已保存到: {output_path}")

    def quick_visualize(self, num_samples=10, output_dir='visualizations'):
        """Save BEV and camera-overlay images for the first ``num_samples`` results.

        Args:
            num_samples: Number of results to render (capped at the number loaded).
            output_dir: Directory for the PNGs; created if needed.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        num_samples = min(num_samples, len(self.results))

        print(f"快速可视化前 {num_samples} 个样本...")
        for idx in tqdm(range(num_samples)):
            self.visualize_bev(idx, save_path=output_path / f'sample_{idx:04d}_bev.png')
            self.visualize_segmentation_on_camera(
                idx, save_path=output_path / f'sample_{idx:04d}_camera.png')

        print(f"可视化完成!保存在: {output_dir}")


def main():
    """Command-line entry point: parse arguments and run the selected mode.

    Modes:
        quick: save BEV and camera images for the first ``--num-samples`` results
               (only those results are loaded).
        video: render every ``--frame-interval``-th result into an MP4.
        all:   ``quick`` followed by ``video`` (all results are loaded).
    """
    parser = argparse.ArgumentParser(description='BEVFusion结果可视化')
    parser.add_argument('--result-file', type=str,
                        default='results_epoch19.pkl',
                        help='推理结果文件')
    parser.add_argument('--nusc-root', type=str,
                        default='data/nuscenes',
                        help='nuScenes数据集根目录')
    parser.add_argument('--nusc-version', type=str,
                        default='v1.0-trainval',
                        help='nuScenes版本')
    parser.add_argument('--output-dir', type=str,
                        default='visualizations',
                        help='输出目录')
    parser.add_argument('--mode', type=str,
                        choices=['quick', 'video', 'all'],
                        default='quick',
                        help='可视化模式')
    parser.add_argument('--num-samples', type=int, default=10,
                        help='快速可视化的样本数')
    parser.add_argument('--video-fps', type=int, default=10,
                        help='视频帧率')
    parser.add_argument('--frame-interval', type=int, default=1,
                        help='帧间隔(1=所有帧,6=每隔5帧采样)')
    parser.add_argument('--camera', type=str, default='CAM_FRONT',
                        help='相机名称')
    # Previously hard-coded to 'bev'; generate_video also supports the others.
    parser.add_argument('--vis-type', type=str,
                        choices=['seg', 'det', 'bev', 'both'],
                        default='bev',
                        help='视频帧的可视化类型')

    args = parser.parse_args()

    # Fail fast with a CLI error instead of after loading the whole dataset.
    if args.frame_interval < 1:
        parser.error('--frame-interval 必须 >= 1')

    # Only quick mode limits how many results are kept; video/all need them all.
    max_samples = args.num_samples if args.mode == 'quick' else None
    visualizer = BEVFusionVisualizer(
        nusc_root=args.nusc_root,
        nusc_version=args.nusc_version,
        max_samples=max_samples,
    )

    visualizer.load_results(args.result_file)

    if args.mode in ('quick', 'all'):
        visualizer.quick_visualize(
            num_samples=args.num_samples,
            output_dir=args.output_dir,
        )

    if args.mode in ('video', 'all'):
        os.makedirs(args.output_dir, exist_ok=True)
        video_path = os.path.join(args.output_dir, 'bevfusion_results.mp4')
        visualizer.generate_video(
            output_path=video_path,
            fps=args.video_fps,
            camera_name=args.camera,
            vis_type=args.vis_type,
            frame_interval=args.frame_interval,
        )


if __name__ == '__main__':
    main()