bev-project/VISUALIZE_INFERENCE.py

275 lines
10 KiB
Python
Raw Normal View History

2025-11-21 10:50:51 +08:00
#!/usr/bin/env python
"""
可视化单batch推理结果 - BEV分割和3D检测
"""
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image
import argparse
# OpenCV is an optional dependency: when it is missing, ``cv2`` is bound to
# None so that callers can test for it before drawing.
try:
    import cv2
except ImportError:
    cv2 = None

if cv2 is None:
    print("Warning: cv2 not available, some visualization features may be limited")
# Display colour ('#RRGGBB') used for each nuScenes BEV map class.
NUSCENES_COLORS = dict(
    drivable_area='#00FF00',  # green  - drivable area
    ped_crossing='#FFFF00',   # yellow - pedestrian crossing
    walkway='#FFA500',        # orange - walkway
    stop_line='#FF0000',      # red    - stop line
    carpark_area='#800080',   # purple - car park area
    divider='#000000',        # black  - divider
)
# BEV segmentation class names. The visualisation code pairs MAP_CLASSES[i]
# with mask channel i, so the order of this list matters.
MAP_CLASSES = [
    "drivable_area",  # channel 0
    "ped_crossing",   # channel 1
    "walkway",        # channel 2
    "stop_line",      # channel 3
    "carpark_area",   # channel 4
    "divider",        # channel 5
]
def _hex_to_rgb(hex_color):
    """Convert a '#RRGGBB' string into an (r, g, b) tuple of 0-255 ints."""
    return tuple(int(hex_color[pos:pos + 2], 16) for pos in (1, 3, 5))


def _mask_to_numpy(mask):
    """Return ``mask`` as a NumPy array, copying torch tensors to the CPU."""
    if torch.is_tensor(mask):
        return mask.detach().cpu().numpy()
    return np.asarray(mask)


def _paint_masks(masks, class_names):
    """Paint each mask's > 0.5 region in its class colour onto one RGB image.

    Masks are painted in order, so where classes overlap the later class
    covers the earlier one.
    """
    canvas = np.zeros((*masks[0].shape, 3), dtype=np.uint8)
    for mask, name in zip(masks, class_names):
        canvas[mask > 0.5] = _hex_to_rgb(NUSCENES_COLORS[name])
    return canvas


def visualize_bev_segmentation(pred_masks, gt_masks=None, sample_idx=0, save_path=None):
    """Plot per-class BEV segmentation masks, optionally against ground truth.

    The figure has one column per map class plus a final overlay column.
    Row 0 shows the predictions and row 1 the ground truth; when no ground
    truth is given, row 1 only holds an "unavailable" placeholder.

    The previous implementation created a fixed 2x4 grid but indexed six
    prediction columns, which raised IndexError on the fifth class. The grid
    is now sized from the number of classes actually present.

    Args:
        pred_masks: Predicted masks indexed by class on the first axis
            (torch tensor or array-like); values above 0.5 are drawn as
            foreground. Expected layout is [num_classes, H, W] with channel
            order matching MAP_CLASSES -- confirm against the producer.
        gt_masks: Optional ground-truth masks in the same layout.
        sample_idx: Sample index, used only in the figure title.
        save_path: If truthy, the figure is written to this path at 300 dpi.

    Raises:
        ValueError: If ``pred_masks`` contains no class channel.
    """
    class_names = MAP_CLASSES[:len(pred_masks)]
    if not class_names:
        raise ValueError('pred_masks must contain at least one class channel')
    preds = [_mask_to_numpy(pred_masks[i]) for i in range(len(class_names))]

    # One column per class plus one for the all-class overlay.
    n_cols = len(class_names) + 1
    fig, axes = plt.subplots(2, n_cols, figsize=(4 * n_cols, 8))
    fig.suptitle(f'BEV分割结果 - 样本 {sample_idx}', fontsize=16)
    # Start with every panel blank; unused panels then need no special case.
    for ax in axes.flat:
        ax.axis('off')

    # Row 0: predictions per class, then all classes overlaid.
    for ax, name, mask in zip(axes[0], class_names, preds):
        ax.imshow(_paint_masks([mask], [name]))
        ax.set_title(f'{name} (Pred)')
    axes[0, -1].imshow(_paint_masks(preds, class_names))
    axes[0, -1].set_title('Prediction 叠加')

    if gt_masks is not None:
        # Row 1: ground truth per class, then all classes overlaid.
        n_gt = min(len(class_names), len(gt_masks))
        gts = [_mask_to_numpy(gt_masks[i]) for i in range(n_gt)]
        for ax, name, mask in zip(axes[1], class_names, gts):
            ax.imshow(_paint_masks([mask], [name]))
            ax.set_title(f'{name} (GT)')
        if gts:
            axes[1, -1].imshow(_paint_masks(gts, class_names))
            axes[1, -1].set_title('Ground Truth 叠加')
    else:
        # No ground truth: show a placeholder in the first panel of row 1.
        placeholder = axes[1, 0]
        placeholder.text(0.5, 0.5, 'Ground Truth\n不可用',
                         transform=placeholder.transAxes, ha='center', va='center',
                         fontsize=12,
                         bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))
        placeholder.set_title('Ground Truth (N/A)')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"BEV分割可视化已保存: {save_path}")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def visualize_3d_detection(boxes_3d, scores_3d, labels_3d, sample_idx=0, save_path=None):
    """Plot the centres of 3D detections on a bird's-eye-view canvas.

    Each box is reduced to its centre (``box[0]``, ``box[1]``) and mapped to
    a pixel on a square canvas covering [-50 m, +50 m] on both axes at
    0.167 m per pixel. A marker coloured by confidence (white at 0, pure red
    at 1) and a text label with the score are drawn for every centre that
    falls inside the canvas.

    Fixes over the previous version:
      * Marker colours are passed to Matplotlib in the 0-1 range. The old
        0-255 tuple made ``ax.scatter`` raise ValueError for every box, and
        the broad ``except`` hid it, so nothing was plotted or counted.
      * Scores are clamped to [0, 1] before being used as alpha.
      * Pixel coordinates use floor instead of truncation, so centres just
        outside the negative edge are no longer snapped into pixel 0.
      * The canvas is allocated from the computed ``bev_size``.

    Args:
        boxes_3d: Iterable of boxes (torch tensors or array-likes). Only the
            first two entries are read, as centre x / y. The original code
            comment describes a LiDAR frame (x forward, y left) in metres --
            confirm against the producer.
        scores_3d: Per-box confidences; each must be a single scalar.
        labels_3d: Per-box class labels. Not drawn; only iterated in
            lock-step with the boxes.
        sample_idx: Sample index, used only in the figure title.
        save_path: If truthy, the figure is written to this path at 300 dpi.
    """
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))

    # Canvas geometry: +/- bev_range metres at bev_resolution metres/pixel.
    bev_range = 50.0
    bev_resolution = 0.167
    bev_size = int(2 * bev_range / bev_resolution)  # 598
    # Light-grey background.
    bev_image = np.full((bev_size, bev_size, 3), 240, dtype=np.uint8)

    detection_count = 0
    if len(boxes_3d) > 0:
        print(f"正在可视化 {len(boxes_3d)} 个3D检测框...")
        for i, (box, score, _label) in enumerate(zip(boxes_3d, scores_3d, labels_3d)):
            # Only the parsing of one box is best-effort: a malformed entry
            # is reported and skipped, while drawing errors still surface.
            try:
                if torch.is_tensor(box):
                    box = box.detach().cpu().numpy()
                if torch.is_tensor(score):
                    score = score.detach().cpu().numpy()
                # .item() raises ValueError unless there is exactly one value.
                score = float(np.asarray(score).item())
                center_x, center_y = float(box[0]), float(box[1])
                # Canvas origin is the image centre; image y grows downward,
                # so the y axis is flipped.
                bev_x = int(np.floor((center_x + bev_range) / bev_resolution))
                bev_y = int(np.floor((bev_range - center_y) / bev_resolution))
            except (IndexError, TypeError, ValueError, OverflowError) as e:
                print(f"处理检测框 {i} 时出错: {e}")
                continue

            if not (0 <= bev_x < bev_size and 0 <= bev_y < bev_size):
                continue

            # Higher confidence -> more saturated red and more opaque marker.
            alpha = min(1.0, max(0.0, score))
            fade = 255 - int(255 * alpha)
            color_u8 = (255, fade, fade)
            if cv2 is not None:
                cv2.circle(bev_image, (bev_x, bev_y), 4, color_u8, -1)
            # Matplotlib expects RGB components in [0, 1], not [0, 255].
            color_unit = tuple(channel / 255.0 for channel in color_u8)
            ax.scatter(bev_x, bev_y, c=[color_unit], s=50, alpha=alpha,
                       edgecolors='black', linewidth=1)
            ax.text(bev_x + 5, bev_y - 5, f'{score:.2f}', fontsize=8,
                    bbox=dict(boxstyle="round,pad=0.1", facecolor="white", alpha=0.8))
            detection_count += 1

        title = f'3D检测结果 - 样本 {sample_idx} ({detection_count}/{len(boxes_3d)} 个检测框显示)'
    else:
        title = f'3D检测结果 - 样本 {sample_idx} (0 个检测框)'

    ax.imshow(bev_image)
    ax.set_title(title, fontsize=14)

    ax.set_xlabel('BEV X (pixels)')
    ax.set_ylabel('BEV Y (pixels)')
    ax.grid(True, alpha=0.3)
    # Annotate the metric extent covered by the canvas.
    ax.text(10, 20, f'范围: [-{bev_range}m, +{bev_range}m]', fontsize=10,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"3D检测可视化已保存: {save_path}")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def main():
    """Command-line entry point: load pickled results and visualise one sample.

    Reads ``--result_file`` (a pickle holding an indexable collection of
    per-sample dicts), selects ``--sample_idx`` and, depending on which keys
    the sample has, renders the BEV segmentation ('masks_bev', optional
    'gt_masks_bev') and the 3D detections ('boxes_3d', 'scores_3d',
    'labels_3d') into ``--output_dir``. Finally prints per-class foreground
    pixel statistics and the detection count.

    Fixes over the previous version:
      * The per-class statistics line was a broken format string that printed
        the literal ".2f"; it now prints the class name, pixel count and
        percentage.
      * An invalid ``--sample_idx`` produces a usage error instead of a
        traceback.
      * Statistics iterate over the channels actually present instead of
        assuming one channel per MAP_CLASSES entry.
    """
    parser = argparse.ArgumentParser(description='可视化单batch推理结果')
    parser.add_argument('--result_file', type=str,
                        default='/data/infer_test/20251120_124755/one_batch_results.pkl',
                        help='推理结果文件路径')
    parser.add_argument('--sample_idx', type=int, default=0,
                        help='要可视化的样本索引')
    parser.add_argument('--output_dir', type=str, default='./visualization_output',
                        help='可视化结果保存目录')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    print(f"加载推理结果: {args.result_file}")
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point --result_file at results you produced or trust.
    with open(args.result_file, 'rb') as f:
        results = pickle.load(f)
    print(f"结果包含 {len(results)} 个样本")

    try:
        sample = results[args.sample_idx]
    except (IndexError, KeyError):
        # parser.error prints the message with usage and exits with status 2.
        parser.error(f"sample_idx {args.sample_idx} 超出范围 (共 {len(results)} 个样本)")
    print(f"样本 {args.sample_idx} 的keys: {list(sample.keys())}")

    # 1. BEV segmentation.
    if 'masks_bev' in sample:
        print(f"BEV分割形状: {sample['masks_bev'].shape}")
        pred_masks = sample['masks_bev']
        gt_masks = sample.get('gt_masks_bev', None)
        seg_save_path = os.path.join(
            args.output_dir, f'bev_segmentation_sample_{args.sample_idx}.png')
        visualize_bev_segmentation(pred_masks, gt_masks, args.sample_idx, seg_save_path)

    # 2. 3D detection (needs boxes, scores and labels together).
    if 'boxes_3d' in sample and 'scores_3d' in sample and 'labels_3d' in sample:
        boxes_3d = sample['boxes_3d']
        scores_3d = sample['scores_3d']
        labels_3d = sample['labels_3d']
        print(f"3D检测框数量: {len(boxes_3d)}")
        if len(scores_3d) > 0:
            print(f"检测置信度范围: {scores_3d.min():.3f} - {scores_3d.max():.3f}")
        det_save_path = os.path.join(
            args.output_dir, f'3d_detection_sample_{args.sample_idx}.png')
        visualize_3d_detection(boxes_3d, scores_3d, labels_3d, args.sample_idx, det_save_path)

    # 3. Summary statistics.
    print("\n" + "=" * 50)
    print(f"样本 {args.sample_idx} 统计信息:")
    print("=" * 50)
    if 'masks_bev' in sample:
        # zip stops at the shorter input, so fewer channels than class names
        # is handled without an IndexError.
        for class_name, mask in zip(MAP_CLASSES, sample['masks_bev']):
            pixel_count = int((mask > 0.5).sum())
            total_pixels = mask.shape[0] * mask.shape[1]
            percentage = pixel_count / total_pixels * 100 if total_pixels else 0.0
            print(f"  {class_name}: {pixel_count} 像素 ({percentage:.2f}%)")
    if 'boxes_3d' in sample:
        print(f"3D检测框数量: {len(sample['boxes_3d'])}")
    print(f"\n可视化结果已保存到: {args.output_dir}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()