429 lines
16 KiB
Python
429 lines
16 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
BEVFusion 推理结果可视化工具
|
|||
|
|
支持生成图像和视频
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import pickle
|
|||
|
|
import numpy as np
|
|||
|
|
import cv2
|
|||
|
|
from pathlib import Path
|
|||
|
|
import argparse
|
|||
|
|
from tqdm import tqdm
|
|||
|
|
import mmcv
|
|||
|
|
from nuscenes.nuscenes import NuScenes
|
|||
|
|
from nuscenes.utils.data_classes import LidarPointCloud, Box
|
|||
|
|
from nuscenes.utils.geometry_utils import view_points, box_in_image, BoxVisibility
|
|||
|
|
from pyquaternion import Quaternion
|
|||
|
|
import matplotlib.pyplot as plt
|
|||
|
|
import matplotlib.patches as mpatches
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 分割类别颜色映射(与官方保持一致)
|
|||
|
|
# RGB colour assigned to each segmentation / object class name.
# (The original author's note says this table is kept consistent with the
# official colour scheme.)
SEGMENTATION_COLORS = dict(
    # Map-layer classes
    drivable_area=(255, 195, 128),         # orange
    ped_crossing=(255, 99, 71),            # tomato red
    walkway=(255, 140, 0),                 # dark orange
    carpark=(233, 150, 70),                # brownish orange
    # Vehicle classes
    car=(220, 20, 60),                     # crimson
    truck=(255, 61, 99),                   # pink
    bus=(255, 140, 0),                     # dark orange
    trailer=(255, 127, 80),                # coral
    construction_vehicle=(233, 150, 70),   # brownish orange
    # Vulnerable road users and two-wheelers
    pedestrian=(0, 0, 230),                # blue
    motorcycle=(255, 61, 99),              # pink
    bicycle=(112, 128, 144),               # slate gray
    # Static objects
    traffic_cone=(184, 134, 11),           # dark goldenrod
    barrier=(160, 82, 45),                 # brown
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class BEVFusionVisualizer:
    """Visualize BEVFusion inference results as still images and videos.

    Results are loaded from a pickle file with ``load_results`` and rendered
    either on top of nuScenes camera images or in bird's-eye view (BEV).

    NOTE(review): result ``i`` is assumed to correspond to
    ``self.nusc.sample[i]``. The original code marks this lookup as a
    temporary placeholder, so confirm it against the actual result layout.
    """

    def __init__(self, nusc_root, nusc_version='v1.0-trainval', max_samples=None):
        """
        Initialize the visualizer.

        Args:
            nusc_root: Root directory of the nuScenes dataset.
            nusc_version: nuScenes version string.
            max_samples: Maximum number of samples to keep after loading
                (useful for quick tests); a falsy value keeps all of them.
        """
        self.nusc = NuScenes(version=nusc_version, dataroot=nusc_root, verbose=True)
        self.max_samples = max_samples
        # Filled in by load_results(); starting with an empty list means the
        # attribute always exists, even before any results were loaded.
        self.results = []

    def load_results(self, result_path):
        """
        Load inference results from a pickle file into ``self.results``.

        Args:
            result_path: Path of the pickle file holding the results.
        """
        print(f"加载推理结果: {result_path}")
        # NOTE: unpickling can execute arbitrary code - only load result
        # files that come from a trusted source.
        with open(result_path, 'rb') as f:
            self.results = pickle.load(f)
        print(f"成功加载 {len(self.results)} 个样本的结果")

        if self.max_samples:
            self.results = self.results[:self.max_samples]
            print(f"限制为前 {self.max_samples} 个样本")

    @staticmethod
    def _find_seg_mask(result, keys):
        """Return the value of the first of ``keys`` present in ``result``, else None."""
        for key in keys:
            if key in result:
                return result[key]
        return None

    @staticmethod
    def _to_label_map(seg_mask):
        """Return ``seg_mask`` as a NumPy array, reducing a 3-D input with argmax over axis 0."""
        # Tensor-like objects (anything exposing .cpu()) are moved to host
        # memory and converted first.
        if hasattr(seg_mask, 'cpu'):
            seg_mask = seg_mask.cpu().numpy()
        seg_mask = np.asarray(seg_mask)
        if seg_mask.ndim == 3:
            # (C, H, W) -> (H, W): keep the class with the highest score
            seg_mask = np.argmax(seg_mask, axis=0)
        return seg_mask

    def _load_camera_image(self, sample_idx, camera_name):
        """Load the RGB image of ``camera_name`` for the nuScenes sample at ``sample_idx``."""
        sample_token = self.nusc.sample[sample_idx]['token']
        sample = self.nusc.get('sample', sample_token)
        cam_data = self.nusc.get('sample_data', sample['data'][camera_name])
        img_path = os.path.join(self.nusc.dataroot, cam_data['filename'])
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None; fail with a clear
            # message instead of an opaque error inside cvtColor.
            raise FileNotFoundError(f"Cannot read camera image: {img_path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _save_image_figure(image, title, save_path):
        """Write ``image`` to ``save_path`` as a titled, axis-free matplotlib figure."""
        plt.figure(figsize=(16, 9))
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()
        print(f"保存到: {save_path}")

    def visualize_segmentation_on_camera(self, sample_idx, camera_name='CAM_FRONT',
                                         alpha=0.5, save_path=None):
        """
        Overlay the segmentation result on a camera image.

        Args:
            sample_idx: Index of the sample in ``self.results``.
            camera_name: Name of the nuScenes camera channel.
            alpha: Blending weight of the segmentation layer.
            save_path: If given, the overlay is also written to this path.

        Returns:
            The blended RGB image, or None when the sample has no
            segmentation result.

        Raises:
            FileNotFoundError: If the camera image cannot be read.
        """
        result = self.results[sample_idx]

        # 'pts_seg' keeps priority; 'masks_bev' is accepted as well so this
        # method handles the same result formats as visualize_bev().
        seg_mask = self._find_seg_mask(result, ('pts_seg', 'masks_bev'))
        if seg_mask is None:
            print(f"样本 {sample_idx} 没有分割结果")
            return None
        seg_mask = self._to_label_map(seg_mask)

        img = self._load_camera_image(sample_idx, camera_name)

        # Colourise the mask at the image resolution and blend it in
        seg_vis = self.create_segmentation_overlay(seg_mask, img.shape[:2])
        overlay = cv2.addWeighted(img, 1 - alpha, seg_vis, alpha, 0)

        if save_path:
            self._save_image_figure(
                overlay,
                f'Sample {sample_idx} - {camera_name} - Segmentation',
                save_path,
            )

        return overlay

    def create_segmentation_overlay(self, seg_mask, target_shape):
        """
        Create a colour visualization of a segmentation mask.

        Args:
            seg_mask: Segmentation mask, (H, W) or (C, H, W).
            target_shape: Size of the output image as (H, W).

        Returns:
            A ``uint8`` RGB image of shape ``(H, W, 3)``; pixels with label 0
            stay black.
        """
        seg_mask = self._to_label_map(seg_mask)

        h, w = target_shape
        vis = np.zeros((h, w, 3), dtype=np.uint8)

        # Simplified colouring: one 'tab10' colour per label present in the
        # mask, not a fixed class-to-colour mapping.
        unique_labels = np.unique(seg_mask)
        # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in
        # Matplotlib 3.9.
        colors = plt.get_cmap('tab10')(np.linspace(0, 1, len(unique_labels)))

        for idx, label in enumerate(unique_labels):
            if label == 0:  # treated as background
                continue
            mask = seg_mask == label
            # Simplified: the BEV mask is only resized to the image size. A
            # faithful overlay would need the real camera projection.
            mask_resized = cv2.resize(mask.astype(np.uint8), (w, h),
                                      interpolation=cv2.INTER_NEAREST)
            color = (colors[idx][:3] * 255).astype(np.uint8)
            vis[mask_resized > 0] = color

        return vis

    def visualize_3d_detection(self, sample_idx, camera_name='CAM_FRONT',
                               save_path=None):
        """
        Visualize the 3D detection result of a sample on a camera image.

        Args:
            sample_idx: Index of the sample in ``self.results``.
            camera_name: Name of the nuScenes camera channel.
            save_path: If given, the image is also written to this path.

        Returns:
            The RGB camera image, or None when the sample has no detection
            result.

        Raises:
            FileNotFoundError: If the camera image cannot be read.
        """
        result = self.results[sample_idx]

        if 'pts_bbox' not in result:
            print(f"样本 {sample_idx} 没有检测结果")
            return None

        img = self._load_camera_image(sample_idx, camera_name)

        # TODO: project the 3D boxes into the image. That needs the camera
        # intrinsics/extrinsics; for now only the plain image is shown.

        if save_path:
            self._save_image_figure(
                img,
                f'Sample {sample_idx} - {camera_name} - 3D Detection',
                save_path,
            )

        return img

    def visualize_bev(self, sample_idx, save_path=None):
        """
        Visualize segmentation and detection results in bird's-eye view.

        Args:
            sample_idx: Index of the sample in ``self.results``.
            save_path: If given, the figure is written to this path;
                otherwise it is shown interactively.
        """
        result = self.results[sample_idx]

        fig, axes = plt.subplots(1, 2, figsize=(20, 10))

        # 1. Segmentation panel - accepts either result key
        seg_mask = self._find_seg_mask(result, ('masks_bev', 'pts_seg'))
        if seg_mask is not None:
            axes[0].imshow(self._to_label_map(seg_mask), cmap='tab20')
            axes[0].set_title(f'BEV Segmentation - Sample {sample_idx}')
        else:
            axes[0].text(0.5, 0.5, 'No Segmentation Data',
                         ha='center', va='center', fontsize=20)
            axes[0].set_title('BEV Segmentation')
        axes[0].axis('off')

        # 2. Detection panel - boxes may sit at the top level or under 'pts_bbox'
        boxes_3d = None
        if 'boxes_3d' in result:
            boxes_3d = result['boxes_3d']
        elif 'pts_bbox' in result:
            boxes_3d = result['pts_bbox']['boxes_3d']

        if boxes_3d is not None:
            # Simplified: a blank white BEV canvas with the box count in the
            # title. Drawing the boxes needs a 3D-to-BEV projection (TODO).
            bev_size = 200
            bev_img = np.full((bev_size, bev_size, 3), 255, dtype=np.uint8)
            axes[1].imshow(bev_img)
            axes[1].set_title(f'BEV Detection ({len(boxes_3d)} boxes)')
        else:
            axes[1].text(0.5, 0.5, 'No Detection Data',
                         ha='center', va='center', fontsize=20)
            axes[1].set_title('BEV Detection')
        axes[1].axis('off')

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"保存到: {save_path}")
        else:
            plt.show()
        # Always release the figure so long runs do not accumulate open figures
        plt.close(fig)

    def generate_video(self, output_path, fps=10, camera_name='CAM_FRONT',
                       vis_type='both', frame_interval=1):
        """
        Render the loaded results frame by frame and assemble them into a video.

        Args:
            output_path: Path of the output video file.
            fps: Frame rate of the video.
            camera_name: Name of the nuScenes camera channel.
            vis_type: Visualization type: 'seg', 'det', 'bev' or 'both'
                ('both' currently renders the same BEV view as 'bev').
            frame_interval: Sampling step (1 = every frame, 2 = every second
                frame, 6 = every sixth frame).

        Raises:
            ValueError: If ``frame_interval`` is smaller than 1.
            RuntimeError: If the video writer cannot be opened.
        """
        import tempfile  # only needed here, kept local to the method

        if frame_interval < 1:
            raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")

        sample_indices = range(0, len(self.results), frame_interval)
        print(f"开始生成视频: {output_path}")
        print(f"总样本数: {len(self.results)}")
        print(f"帧间隔: {frame_interval} (将处理 {len(sample_indices)} 帧)")

        video_file = Path(output_path)
        video_file.parent.mkdir(parents=True, exist_ok=True)

        # A fresh temporary directory avoids picking up stale frames from an
        # earlier run and is removed automatically, even when rendering fails.
        with tempfile.TemporaryDirectory(prefix='temp_frames_',
                                         dir=video_file.parent) as tmp:
            temp_dir = Path(tmp)

            frames = []
            for idx in tqdm(sample_indices, desc="生成帧"):
                frame_path = temp_dir / f'frame_{idx:04d}.png'

                if vis_type == 'seg':
                    self.visualize_segmentation_on_camera(idx, camera_name,
                                                          save_path=frame_path)
                elif vis_type == 'det':
                    self.visualize_3d_detection(idx, camera_name, save_path=frame_path)
                else:  # 'bev' and 'both' share the BEV rendering
                    self.visualize_bev(idx, save_path=frame_path)

                # Renderers write nothing when a sample lacks the needed data
                if frame_path.exists():
                    frames.append(str(frame_path))

            if not frames:
                print("没有生成任何帧!")
                return

            print(f"合成视频,共 {len(frames)} 帧...")
            first_frame = cv2.imread(frames[0])
            h, w = first_frame.shape[:2]

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_writer = cv2.VideoWriter(str(video_file), fourcc, fps, (w, h))
            if not video_writer.isOpened():
                raise RuntimeError(f"Cannot open video writer for: {video_file}")

            try:
                for frame_path in tqdm(frames, desc="写入视频"):
                    frame = cv2.imread(frame_path)
                    # Figures saved with bbox_inches='tight' can differ in
                    # size, and VideoWriter silently drops frames whose size
                    # does not match the one it was opened with.
                    if frame.shape[:2] != (h, w):
                        frame = cv2.resize(frame, (w, h))
                    video_writer.write(frame)
            finally:
                video_writer.release()

            print(f"视频已保存到: {output_path}")

    def quick_visualize(self, num_samples=10, output_dir='visualizations'):
        """
        Quickly visualize the first N samples.

        For every sample a BEV figure and a camera overlay are written to
        ``output_dir``.

        Args:
            num_samples: Number of samples to visualize.
            output_dir: Output directory (created if missing).
        """
        output_path = Path(output_dir)
        # parents=True so nested output directories work as well
        output_path.mkdir(parents=True, exist_ok=True)

        num_samples = min(num_samples, len(self.results))

        print(f"快速可视化前 {num_samples} 个样本...")
        for idx in tqdm(range(num_samples)):
            # BEV view
            bev_path = output_path / f'sample_{idx:04d}_bev.png'
            self.visualize_bev(idx, save_path=bev_path)

            # Camera overlay
            cam_path = output_path / f'sample_{idx:04d}_camera.png'
            self.visualize_segmentation_on_camera(idx, save_path=cam_path)

        print(f"可视化完成!保存在: {output_dir}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """Parse the command-line options and run the requested visualization mode.

    Modes:
        quick: write per-sample images for the first ``--num-samples`` samples.
        video: render all loaded results into a BEV video.
        all:   run the quick visualization first, then generate the video.
    """
    parser = argparse.ArgumentParser(description='BEVFusion结果可视化')
    parser.add_argument('--result-file', type=str,
                        default='results_epoch19.pkl',
                        help='推理结果文件')
    parser.add_argument('--nusc-root', type=str,
                        default='data/nuscenes',
                        help='nuScenes数据集根目录')
    parser.add_argument('--nusc-version', type=str,
                        default='v1.0-trainval',
                        help='nuScenes版本')
    parser.add_argument('--output-dir', type=str,
                        default='visualizations',
                        help='输出目录')
    parser.add_argument('--mode', type=str,
                        choices=['quick', 'video', 'all'],
                        default='quick',
                        help='可视化模式')
    parser.add_argument('--num-samples', type=int, default=10,
                        help='快速可视化的样本数')
    parser.add_argument('--video-fps', type=int, default=10,
                        help='视频帧率')
    parser.add_argument('--frame-interval', type=int, default=1,
                        help='帧间隔(1=所有帧,6=每隔5帧采样)')
    parser.add_argument('--camera', type=str, default='CAM_FRONT',
                        help='相机名称')

    args = parser.parse_args()

    # Only the quick mode limits how many results are kept in memory; the
    # video modes need the full result list.
    max_samples = args.num_samples if args.mode == 'quick' else None
    visualizer = BEVFusionVisualizer(
        nusc_root=args.nusc_root,
        nusc_version=args.nusc_version,
        max_samples=max_samples
    )

    visualizer.load_results(args.result_file)

    # Create the output directory (including missing parents) once, so every
    # mode - not just 'video' - can write into a nested --output-dir.
    os.makedirs(args.output_dir, exist_ok=True)

    if args.mode in ('quick', 'all'):
        visualizer.quick_visualize(
            num_samples=args.num_samples,
            output_dir=args.output_dir
        )

    if args.mode in ('video', 'all'):
        video_path = os.path.join(args.output_dir, 'bevfusion_results.mp4')
        visualizer.generate_video(
            output_path=video_path,
            fps=args.video_fps,
            camera_name=args.camera,
            vis_type='bev',
            frame_interval=args.frame_interval
        )


if __name__ == '__main__':
    main()
|
|||
|
|
|