#!/usr/bin/env python
"""
Visualize single-batch inference results: BEV segmentation masks and 3D detections.

NOTE: figure titles contain Chinese text; matplotlib needs a CJK-capable font
configured (e.g. via ``rcParams['font.sans-serif']``) or those glyphs render as boxes.
"""
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image
import argparse

try:
    import cv2
except ImportError:
    cv2 = None
    print("Warning: cv2 not available, some visualization features may be limited")

# NuScenes map class -> display color ('#RRGGBB').
NUSCENES_COLORS = {
    'drivable_area': '#00FF00',  # green  - drivable area
    'ped_crossing': '#FFFF00',   # yellow - pedestrian crossing
    'walkway': '#FFA500',        # orange - walkway
    'stop_line': '#FF0000',      # red    - stop line
    'carpark_area': '#800080',   # purple - car park area
    'divider': '#000000',        # black  - divider
}

# BEV segmentation class order; index i corresponds to channel i of the mask tensor.
MAP_CLASSES = [
    'drivable_area',  # 0
    'ped_crossing',   # 1
    'walkway',        # 2
    'stop_line',      # 3
    'carpark_area',   # 4
    'divider',        # 5
]


def _hex_to_rgb(hex_color):
    """Convert a ``'#RRGGBB'`` string into an ``(r, g, b)`` tuple of ints in [0, 255]."""
    digits = hex_color.lstrip('#')
    return tuple(int(digits[k:k + 2], 16) for k in (0, 2, 4))


def _to_numpy(value):
    """Return ``value`` as a NumPy array, moving torch tensors to the CPU first."""
    if torch.is_tensor(value):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _colorize_mask(mask, hex_color, threshold=0.5, out=None):
    """Paint every pixel of ``mask`` greater than ``threshold`` with ``hex_color``.

    Args:
        mask: 2D mask (tensor or array-like).
        hex_color: ``'#RRGGBB'`` color string.
        threshold: pixels strictly above this value are painted.
        out: optional ``(H, W, 3)`` uint8 image painted in place (used to build
            overlays); when ``None`` a new black image is allocated.

    Returns:
        The painted ``(H, W, 3)`` uint8 image.
    """
    mask = _to_numpy(mask)
    if out is None:
        out = np.zeros((*mask.shape, 3), dtype=np.uint8)
    out[mask > threshold] = _hex_to_rgb(hex_color)
    return out


def _draw_mask_row(row_axes, masks, tag, overlay_title):
    """Draw one figure row: one panel per class, then an all-class overlay in the last panel."""
    if not masks:
        return
    overlay = np.zeros((*masks[0].shape, 3), dtype=np.uint8)
    for ax, class_name, mask in zip(row_axes, MAP_CLASSES, masks):
        color = NUSCENES_COLORS[class_name]
        ax.imshow(_colorize_mask(mask, color))
        ax.set_title(f'{class_name} ({tag})')
        ax.axis('off')
        # Later classes overwrite earlier ones where they overlap.
        _colorize_mask(mask, color, out=overlay)
    overlay_ax = row_axes[-1]
    overlay_ax.imshow(overlay)
    overlay_ax.set_title(overlay_title)
    overlay_ax.axis('off')


def visualize_bev_segmentation(pred_masks, gt_masks=None, sample_idx=0, save_path=None,
                               show=True):
    """
    Visualize BEV segmentation results.

    Layout: row 0 shows each predicted class mask followed by an overlay of all
    predicted classes; row 1 shows the same for the ground truth, or an
    "unavailable" notice when ``gt_masks`` is None.

    Args:
        pred_masks: predicted masks indexed by class (``MAP_CLASSES`` order);
            the original docstring gives shape [6, H, W].
        gt_masks: ground-truth masks with the same layout, or None.
        sample_idx: sample index, used only in the title.
        save_path: if truthy, the figure is saved there (dpi=300).
        show: call ``plt.show()`` before closing the figure.
    """
    num_classes = len(MAP_CLASSES)
    # One column per class plus a final overlay column (the old 2x4 grid could
    # not hold 6 classes and raised IndexError).
    num_cols = num_classes + 1
    fig, axes = plt.subplots(2, num_cols, figsize=(4 * num_cols, 8), squeeze=False)
    fig.suptitle(f'BEV分割结果 - 样本 {sample_idx}', fontsize=16)
    for ax in axes.flat:
        ax.axis('off')

    pred_np = [_to_numpy(m) for m in pred_masks[:num_classes]]
    _draw_mask_row(axes[0], pred_np, 'Pred', 'Prediction 叠加')

    if gt_masks is not None:
        gt_np = [_to_numpy(m) for m in gt_masks[:num_classes]]
        _draw_mask_row(axes[1], gt_np, 'GT', 'Ground Truth 叠加')
    else:
        # No ground truth: show a notice in the first panel of the second row.
        axes[1, 0].text(0.5, 0.5, 'Ground Truth\n不可用', transform=axes[1, 0].transAxes,
                        ha='center', va='center', fontsize=12,
                        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))
        axes[1, 0].set_title('Ground Truth (N/A)')

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"BEV分割可视化已保存: {save_path}")
    if show:
        plt.show()
    plt.close(fig)


def _world_to_bev_pixel(x, y, bev_range, bev_resolution):
    """Map LiDAR-frame (x, y) in meters to (col, row) indices on the BEV canvas.

    The canvas center is the LiDAR origin; +x goes right and +y (left in the
    LiDAR frame) goes up, hence the flipped row axis. ``floor`` (not ``int``)
    makes points just outside ``-bev_range`` produce negative indices instead of
    being truncated toward zero into the first row/column.
    """
    col = int(np.floor((x + bev_range) / bev_resolution))
    row = int(np.floor((bev_range - y) / bev_resolution))
    return col, row


def _score_to_colors(score):
    """Return ``(rgb_uint8, rgb_unit, alpha)`` for a confidence score.

    Higher scores give a deeper red. ``rgb_uint8`` is for drawing on the uint8
    canvas with cv2; ``rgb_unit`` is the same color scaled to [0, 1], which is
    what matplotlib's ``scatter(c=...)`` requires.
    """
    alpha = float(min(1.0, max(0.0, score)))
    intensity = int(255 * alpha)
    rgb = (255, 255 - intensity, 255 - intensity)
    return rgb, tuple(channel / 255.0 for channel in rgb), alpha


def visualize_3d_detection(boxes_3d, scores_3d, labels_3d, sample_idx=0, save_path=None,
                           bev_range=50.0, bev_resolution=0.167, show=True):
    """
    Visualize 3D detection centers on a bird's-eye-view canvas.

    Only ``box[0]`` and ``box[1]`` (center x, y in the LiDAR frame) are used;
    each detection is drawn as a dot annotated with its confidence.

    Args:
        boxes_3d: iterable of boxes; element ``[0]``/``[1]`` are center x/y in meters.
        scores_3d: per-box confidence scores.
        labels_3d: per-box class labels (currently not drawn).
        sample_idx: sample index, used only in the title.
        save_path: if truthy, the figure is saved there (dpi=300).
        bev_range: half-extent of the square view in meters, i.e. [-r, r] x [-r, r].
        bev_resolution: meters per pixel.
        show: call ``plt.show()`` before closing the figure.

    Returns:
        Number of detections that fell inside the view and were drawn.
    """
    # Box containers that wrap a raw tensor in ``.tensor`` (e.g. mmdet3d-style
    # boxes) are unwrapped; presumably that is what the pickle holds — TODO confirm.
    if hasattr(boxes_3d, 'tensor'):
        boxes_3d = boxes_3d.tensor

    fig, ax = plt.subplots(1, 1, figsize=(12, 10))

    # Canvas size derives from range/resolution (598 px for the defaults).
    bev_size = int(2 * bev_range / bev_resolution)
    bev_image = np.full((bev_size, bev_size, 3), 240, dtype=np.uint8)  # light-gray background

    num_boxes = len(boxes_3d)
    detection_count = 0
    if num_boxes > 0:
        print(f"正在可视化 {num_boxes} 个3D检测框...")
        for i, (box, score, _label) in enumerate(zip(boxes_3d, scores_3d, labels_3d)):
            # Best-effort per box: one malformed entry should not abort the figure.
            try:
                box = _to_numpy(box)
                score = _to_numpy(score).item()
                # LiDAR frame: x forward, y left, z up.
                bev_x, bev_y = _world_to_bev_pixel(box[0], box[1], bev_range, bev_resolution)
                if not (0 <= bev_x < bev_size and 0 <= bev_y < bev_size):
                    continue

                rgb, rgb_unit, alpha = _score_to_colors(score)
                if cv2 is not None:
                    cv2.circle(bev_image, (bev_x, bev_y), 4, rgb, -1)
                ax.scatter(bev_x, bev_y, c=[rgb_unit], s=50, alpha=alpha,
                           edgecolors='black', linewidth=1)
                ax.text(bev_x + 5, bev_y - 5, f'{score:.2f}', fontsize=8,
                        bbox=dict(boxstyle="round,pad=0.1", facecolor="white", alpha=0.8))
                detection_count += 1
            except Exception as e:
                print(f"处理检测框 {i} 时出错: {e}")
                continue

        ax.imshow(bev_image)
        ax.set_title(f'3D检测结果 - 样本 {sample_idx} ({detection_count}/{num_boxes} 个检测框显示)',
                     fontsize=14)
    else:
        ax.imshow(bev_image)
        ax.set_title(f'3D检测结果 - 样本 {sample_idx} (0 个检测框)', fontsize=14)

    ax.set_xlabel('BEV X (pixels)')
    ax.set_ylabel('BEV Y (pixels)')
    ax.grid(True, alpha=0.3)
    ax.text(10, 20, f'范围: [-{bev_range}m, +{bev_range}m]', fontsize=10,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"3D检测可视化已保存: {save_path}")
    if show:
        plt.show()
    plt.close(fig)
    return detection_count


def main():
    """CLI entry point: load a pickled list of per-sample results and visualize one sample."""
    parser = argparse.ArgumentParser(description='可视化单batch推理结果')
    parser.add_argument('--result_file', type=str,
                        default='/data/infer_test/20251120_124755/one_batch_results.pkl',
                        help='推理结果文件路径')
    parser.add_argument('--sample_idx', type=int, default=0, help='要可视化的样本索引')
    parser.add_argument('--output_dir', type=str, default='./visualization_output',
                        help='可视化结果保存目录')
    parser.add_argument('--no_show', action='store_true',
                        help='only save figures, do not open interactive windows (headless runs)')
    args = parser.parse_args()
    show = not args.no_show

    os.makedirs(args.output_dir, exist_ok=True)

    print(f"加载推理结果: {args.result_file}")
    # SECURITY: pickle.load can execute arbitrary code; only load result files you trust.
    with open(args.result_file, 'rb') as f:
        results = pickle.load(f)

    print(f"结果包含 {len(results)} 个样本")
    try:
        sample = results[args.sample_idx]
    except (IndexError, KeyError):
        parser.error(f'--sample_idx {args.sample_idx} is out of range for {len(results)} samples')
    print(f"样本 {args.sample_idx} 的keys: {list(sample.keys())}")

    # 1. BEV segmentation
    if 'masks_bev' in sample:
        print(f"BEV分割形状: {sample['masks_bev'].shape}")
        pred_masks = sample['masks_bev']
        gt_masks = sample.get('gt_masks_bev', None)
        seg_save_path = os.path.join(args.output_dir,
                                     f'bev_segmentation_sample_{args.sample_idx}.png')
        visualize_bev_segmentation(pred_masks, gt_masks, args.sample_idx, seg_save_path,
                                   show=show)

    # 2. 3D detection
    if 'boxes_3d' in sample and 'scores_3d' in sample and 'labels_3d' in sample:
        boxes_3d = sample['boxes_3d']
        scores_3d = sample['scores_3d']
        labels_3d = sample['labels_3d']
        print(f"3D检测框数量: {len(boxes_3d)}")
        scores_np = _to_numpy(scores_3d)
        if scores_np.size > 0:
            print(f"检测置信度范围: {scores_np.min():.3f} - {scores_np.max():.3f}")
        det_save_path = os.path.join(args.output_dir, f'3d_detection_sample_{args.sample_idx}.png')
        visualize_3d_detection(boxes_3d, scores_3d, labels_3d, args.sample_idx, det_save_path,
                               show=show)

    # 3. Summary statistics
    print("\n" + "=" * 50)
    print(f"样本 {args.sample_idx} 统计信息:")
    print("=" * 50)

    if 'masks_bev' in sample:
        pred_masks = sample['masks_bev']
        for i, class_name in enumerate(MAP_CLASSES):
            mask = _to_numpy(pred_masks[i])
            pixel_count = int((mask > 0.5).sum())
            percentage = pixel_count / mask.size * 100 if mask.size else 0.0
            # Was `print(".2f")`, a mangled f-string that printed the literal ".2f".
            print(f"  {class_name}: {pixel_count} 像素 ({percentage:.2f}%)")

    if 'boxes_3d' in sample:
        print(f"3D检测框数量: {len(sample['boxes_3d'])}")

    print(f"\n可视化结果已保存到: {args.output_dir}")


if __name__ == '__main__':
    main()