bev-project/VISUALIZE_INFERENCE.py

275 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
"""
可视化单batch推理结果 - BEV分割和3D检测
"""
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image
import argparse
# OpenCV is optional: when it is missing, ``cv2`` is bound to ``None`` so the
# rest of the module can test for it before calling any cv2 function.
try:
    import cv2
except ImportError:
    cv2 = None
    print("Warning: cv2 not available, some visualization features may be limited")
# Display colour ('#RRGGBB') used for each NuScenes BEV map class.
NUSCENES_COLORS = dict(
    drivable_area='#00FF00',  # green  - drivable area
    ped_crossing='#FFFF00',   # yellow - pedestrian crossing
    walkway='#FFA500',        # orange - walkway
    stop_line='#FF0000',      # red    - stop line
    carpark_area='#800080',   # purple - car park area
    divider='#000000',        # black  - divider
)

# BEV segmentation class names; the position in this list is the class index
# (0 = drivable_area ... 5 = divider) used to index the mask channels.
MAP_CLASSES = [
    'drivable_area', 'ped_crossing', 'walkway',
    'stop_line', 'carpark_area', 'divider',
]
def _mask_to_numpy(mask):
    """Return *mask* as a NumPy array, copying torch tensors to host memory first."""
    if torch.is_tensor(mask):
        return mask.detach().cpu().numpy()
    return np.asarray(mask)


def _hex_to_rgb(hex_color):
    """Convert a '#RRGGBB' string into an ``(r, g, b)`` tuple of 0-255 ints."""
    return tuple(int(hex_color[k:k + 2], 16) for k in (1, 3, 5))


def _draw_mask_row(row_axes, masks, tag, overlay_title, threshold=0.5):
    """Draw one coloured panel per class mask, plus an all-class overlay in the last axis.

    ``row_axes[:-1]`` receive the per-class panels (paired with ``MAP_CLASSES``
    in order); ``row_axes[-1]`` receives the overlay, in which later classes
    overwrite earlier ones where they overlap.
    """
    overlay = None
    for ax, class_name, mask in zip(row_axes[:-1], MAP_CLASSES, masks):
        rgb = _hex_to_rgb(NUSCENES_COLORS[class_name])
        active = mask > threshold
        panel = np.zeros((*mask.shape, 3), dtype=np.uint8)
        panel[active] = rgb
        ax.imshow(panel)
        ax.set_title(f'{class_name} ({tag})')
        if overlay is None:
            overlay = np.zeros_like(panel)
        overlay[active] = rgb
    if overlay is not None:
        row_axes[-1].imshow(overlay)
        row_axes[-1].set_title(overlay_title)


def visualize_bev_segmentation(pred_masks, gt_masks=None, sample_idx=0, save_path=None):
    """Plot per-class BEV segmentation masks, optionally alongside ground truth.

    The figure has two rows. Row 0 shows one panel per predicted class followed
    by an overlay of all predicted classes; row 1 shows the same for the ground
    truth, or a "not available" placeholder when ``gt_masks`` is ``None``.
    The grid width is derived from the number of classes, so every class gets
    its own panel (the previous fixed 2x4 grid raised ``IndexError`` on the
    fifth class).

    Args:
        pred_masks: Predicted masks, indexable by class along the first axis
            (documented by the original author as shape ``[6, H, W]``).
            Torch tensors and NumPy arrays are both accepted. Values above
            0.5 are treated as foreground.
        gt_masks: Optional ground-truth masks with the same layout.
        sample_idx: Sample index, used only in the figure title.
        save_path: If truthy, the figure is saved there at 300 dpi.
    """
    pred = [_mask_to_numpy(m) for m in pred_masks]
    num_classes = min(len(MAP_CLASSES), len(pred))
    n_cols = num_classes + 1  # one panel per class + one overlay panel

    fig, axes = plt.subplots(2, n_cols, figsize=(3 * n_cols, 7), squeeze=False)
    fig.suptitle(f'BEV分割结果 - 样本 {sample_idx}', fontsize=16)
    # Hide ticks/frames everywhere up front; panels that stay empty remain blank.
    for ax in axes.flat:
        ax.axis('off')

    # NOTE(review): 'divider' is '#000000' and panels start black, so that
    # class is invisible in its own panel and in the overlay - confirm whether
    # a different colour or background is wanted.
    _draw_mask_row(axes[0], pred, 'Pred', 'Pred 叠加')

    if gt_masks is not None:
        gt = [_mask_to_numpy(m) for m in gt_masks]
        _draw_mask_row(axes[1], gt, 'GT', 'Ground Truth 叠加')
    else:
        # No ground truth: show a placeholder in the first cell of the second row.
        axes[1, 0].text(0.5, 0.5, 'Ground Truth\n不可用',
                        transform=axes[1, 0].transAxes, ha='center', va='center',
                        fontsize=12,
                        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))
        axes[1, 0].set_title('Ground Truth (N/A)')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"BEV分割可视化已保存: {save_path}")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def visualize_3d_detection(boxes_3d, scores_3d, labels_3d, sample_idx=0, save_path=None):
    """Plot the centres of 3D detection boxes on a top-down BEV canvas.

    Each box centre ``(box[0], box[1])`` is mapped from metres to pixels on a
    light-grey square canvas covering [-50, 50] m on both axes at 0.167 m per
    pixel. Centres inside the canvas are drawn as a marker whose redness and
    opacity grow with the score, annotated with the score to two decimals.

    Args:
        boxes_3d: Iterable of boxes (torch tensors or array-likes); only the
            first two elements of each box are read, as the x / y centre.
            The original author notes LiDAR axes: x forward, y left, z up.
        scores_3d: Iterable of confidence scores, one per box. Each score is
            clamped to [0, 1] for colouring.
        labels_3d: Class labels. Accepted for interface compatibility; not
            drawn.
        sample_idx: Sample index, used only in the figure title.
        save_path: If truthy, the figure is saved there at 300 dpi.
    """
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))

    bev_range = 50.0        # half-extent of the BEV area, metres
    bev_resolution = 0.167  # metres per pixel
    bev_size = int(2 * bev_range / bev_resolution)  # 598
    # Light-grey background, sized from the same numbers used for projection.
    bev_image = np.full((bev_size, bev_size, 3), 240, dtype=np.uint8)

    total = len(boxes_3d)
    detection_count = 0
    if total > 0:
        print(f"正在可视化 {total} 个3D检测框...")

    for i, (box, score) in enumerate(zip(boxes_3d, scores_3d)):
        # Only the data conversion is guarded; drawing errors should surface.
        try:
            if torch.is_tensor(box):
                box = box.detach().cpu().numpy()
            score = float(score)
            center_x, center_y = float(box[0]), float(box[1])
            # Image x grows to the right, image y grows downwards (y flipped).
            # floor (not int truncation) so slightly out-of-range negatives
            # are not pulled onto pixel 0.
            bev_x = int(np.floor((center_x + bev_range) / bev_resolution))
            bev_y = int(np.floor((bev_range - center_y) / bev_resolution))
        except (TypeError, ValueError, IndexError, OverflowError, RuntimeError) as e:
            print(f"处理检测框 {i} 时出错: {e}")
            continue

        if not (0 <= bev_x < bev_size and 0 <= bev_y < bev_size):
            continue

        # Higher confidence -> more saturated red and more opaque marker.
        alpha = min(1.0, max(0.0, score))
        fade = 255 - int(255 * alpha)
        color_u8 = (255, fade, fade)

        if cv2 is not None:
            cv2.circle(bev_image, (bev_x, bev_y), 4, color_u8, -1)

        # matplotlib expects RGB components as floats in [0, 1]; passing the
        # 0-255 tuple raised ValueError for every detection.
        color_mpl = tuple(channel / 255.0 for channel in color_u8)
        ax.scatter(bev_x, bev_y, c=[color_mpl], s=50, alpha=alpha,
                   edgecolors='black', linewidth=1)
        ax.text(bev_x + 5, bev_y - 5, f'{score:.2f}', fontsize=8,
                bbox=dict(boxstyle="round,pad=0.1", facecolor="white", alpha=0.8))
        detection_count += 1

    ax.imshow(bev_image)
    if total > 0:
        ax.set_title(f'3D检测结果 - 样本 {sample_idx} ({detection_count}/{total} 个检测框显示)', fontsize=14)
    else:
        ax.set_title(f'3D检测结果 - 样本 {sample_idx} (0 个检测框)', fontsize=14)

    ax.set_xlabel('BEV X (pixels)')
    ax.set_ylabel('BEV Y (pixels)')
    ax.grid(True, alpha=0.3)

    # Annotate the metric extent covered by the canvas.
    ax.text(10, 20, f'范围: [-{bev_range}m, +{bev_range}m]', fontsize=10,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"3D检测可视化已保存: {save_path}")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def main():
    """Command-line entry point: load pickled inference results and visualise one sample.

    Reads ``--result_file`` (a pickle holding an indexable collection of
    per-sample dicts), selects ``--sample_idx``, and writes figures into
    ``--output_dir``:

    * BEV segmentation, when the sample has ``'masks_bev'`` (ground truth is
      taken from ``'gt_masks_bev'`` if present);
    * 3D detection, when the sample has ``'boxes_3d'``, ``'scores_3d'`` and
      ``'labels_3d'``.

    It then prints, per map class, the number and percentage of pixels whose
    predicted value exceeds 0.5. Exits through ``parser.error`` when the
    requested sample does not exist.
    """
    parser = argparse.ArgumentParser(description='可视化单batch推理结果')
    parser.add_argument('--result_file', type=str,
                        default='/data/infer_test/20251120_124755/one_batch_results.pkl',
                        help='推理结果文件路径')
    parser.add_argument('--sample_idx', type=int, default=0,
                        help='要可视化的样本索引')
    parser.add_argument('--output_dir', type=str, default='./visualization_output',
                        help='可视化结果保存目录')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point --result_file at result files from a trusted source.
    print(f"加载推理结果: {args.result_file}")
    with open(args.result_file, 'rb') as f:
        results = pickle.load(f)
    print(f"结果包含 {len(results)} 个样本")

    try:
        sample = results[args.sample_idx]
    except (IndexError, KeyError):
        parser.error(f"样本索引 {args.sample_idx} 超出范围 (共 {len(results)} 个样本)")
    print(f"样本 {args.sample_idx} 的keys: {list(sample.keys())}")

    # 1. BEV segmentation figure.
    if 'masks_bev' in sample:
        pred_masks = sample['masks_bev']
        print(f"BEV分割形状: {pred_masks.shape}")
        gt_masks = sample.get('gt_masks_bev', None)
        seg_save_path = os.path.join(
            args.output_dir, f'bev_segmentation_sample_{args.sample_idx}.png')
        visualize_bev_segmentation(pred_masks, gt_masks, args.sample_idx, seg_save_path)

    # 2. 3D detection figure.
    if 'boxes_3d' in sample and 'scores_3d' in sample and 'labels_3d' in sample:
        boxes_3d = sample['boxes_3d']
        scores_3d = sample['scores_3d']
        labels_3d = sample['labels_3d']
        print(f"3D检测框数量: {len(boxes_3d)}")
        if len(scores_3d) > 0:
            print(f"检测置信度范围: {scores_3d.min():.3f} - {scores_3d.max():.3f}")
        det_save_path = os.path.join(
            args.output_dir, f'3d_detection_sample_{args.sample_idx}.png')
        visualize_3d_detection(boxes_3d, scores_3d, labels_3d, args.sample_idx, det_save_path)

    # 3. Summary statistics.
    print("\n" + "=" * 50)
    print(f"样本 {args.sample_idx} 统计信息:")
    print("=" * 50)
    if 'masks_bev' in sample:
        for class_name, mask in zip(MAP_CLASSES, sample['masks_bev']):
            pixel_count = int((mask > 0.5).sum())
            total_pixels = mask.shape[0] * mask.shape[1]
            percentage = pixel_count / total_pixels * 100 if total_pixels else 0.0
            # The original printed the literal ".2f" here (a mangled f-string),
            # discarding the statistics it had just computed.
            print(f"  {class_name}: {pixel_count} 像素 ({percentage:.2f}%)")
    if 'boxes_3d' in sample:
        print(f"3D检测框数量: {len(sample['boxes_3d'])}")
    print(f"\n可视化结果已保存到: {args.output_dir}")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()