275 lines
10 KiB
Python
275 lines
10 KiB
Python
#!/usr/bin/env python
|
||
"""
|
||
可视化单batch推理结果 - BEV分割和3D检测
|
||
"""
|
||
import os
|
||
import pickle
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import torch
|
||
from PIL import Image
|
||
import argparse
|
||
try:
|
||
import cv2
|
||
except ImportError:
|
||
cv2 = None
|
||
print("Warning: cv2 not available, some visualization features may be limited")
|
||
|
||
# BEV segmentation classes in mask-channel order: channel i of a BEV mask
# tensor is rendered as class MAP_CLASSES[i].
MAP_CLASSES = [
    'drivable_area',
    'ped_crossing',
    'walkway',
    'stop_line',
    'carpark_area',
    'divider',
]

# Display colour for each nuScenes map class, as '#RRGGBB' hex strings.
NUSCENES_COLORS = dict(
    drivable_area='#00FF00',  # green  - drivable area
    ped_crossing='#FFFF00',   # yellow - pedestrian crossing
    walkway='#FFA500',        # orange - walkway / sidewalk
    stop_line='#FF0000',      # red    - stop line
    carpark_area='#800080',   # purple - car park area
    divider='#000000',        # black  - divider line
)
|
||
|
||
def _hex_to_rgb(hex_color):
    """Convert a '#RRGGBB' string into an (r, g, b) tuple of ints in 0-255."""
    return tuple(int(hex_color[pos:pos + 2], 16) for pos in (1, 3, 5))


def _mask_to_numpy(mask):
    """Return ``mask`` as a NumPy array, copying torch tensors to the CPU first."""
    if torch.is_tensor(mask):
        return mask.detach().cpu().numpy()
    return np.asarray(mask)


def _draw_mask_row(row_axes, masks, tag, overlay_title):
    """
    Fill one subplot row: one panel per class in MAP_CLASSES, then an overlay
    of all classes in the last panel of the row.

    Pixels with value > 0.5 are painted in the class colour on a white
    background. White is used instead of black so that the black 'divider'
    class remains visible. In the overlay, later classes paint over earlier ones.
    """
    overlay = None
    for i, class_name in enumerate(MAP_CLASSES):
        positive = _mask_to_numpy(masks[i]) > 0.5
        rgb = _hex_to_rgb(NUSCENES_COLORS[class_name])

        panel = np.full((*positive.shape, 3), 255, dtype=np.uint8)
        panel[positive] = rgb
        if overlay is None:
            overlay = np.full_like(panel, 255)
        overlay[positive] = rgb

        ax = row_axes[i]
        ax.imshow(panel)
        ax.set_title(f'{class_name} ({tag})')
        ax.axis('off')

    ax = row_axes[len(MAP_CLASSES)]
    ax.imshow(overlay)
    ax.set_title(overlay_title)
    ax.axis('off')


def visualize_bev_segmentation(pred_masks, gt_masks=None, sample_idx=0, save_path=None):
    """
    Visualize BEV segmentation predictions and, optionally, the ground truth.

    Layout: a 2-row grid with one column per class in MAP_CLASSES plus a final
    overlay column. Row 0 shows the predictions. Row 1 shows the ground truth,
    or an "N/A" placeholder when ``gt_masks`` is None. A pixel counts as
    positive when its value is > 0.5.

    Args:
        pred_masks: Predicted masks, indexed per class as pred_masks[i].
            Each entry is a 2-D array or tensor; expected shape [6, H, W].
        gt_masks: Optional ground-truth masks with the same layout as pred_masks.
        sample_idx: Sample index, shown in the figure title.
        save_path: If truthy, the figure is saved to this path at 300 dpi.
    """
    num_cols = len(MAP_CLASSES) + 1  # one panel per class + one overlay panel
    fig, axes = plt.subplots(2, num_cols, figsize=(4 * num_cols, 8), squeeze=False)
    fig.suptitle(f'BEV分割结果 - 样本 {sample_idx}', fontsize=16)

    # Row 0: predictions.
    _draw_mask_row(axes[0], pred_masks, 'Pred', 'Prediction 叠加')

    # Row 1: ground truth, or a placeholder when it is not available.
    if gt_masks is not None:
        _draw_mask_row(axes[1], gt_masks, 'GT', 'Ground Truth 叠加')
    else:
        axes[1, 0].text(0.5, 0.5, 'Ground Truth\n不可用',
                        transform=axes[1, 0].transAxes, ha='center', va='center',
                        fontsize=12, bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))
        axes[1, 0].set_title('Ground Truth (N/A)')
        for ax in axes[1]:
            ax.axis('off')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"BEV分割可视化已保存: {save_path}")
    plt.show()
    # Release the figure; otherwise repeated calls keep accumulating open figures.
    plt.close(fig)
|
||
|
||
def visualize_3d_detection(boxes_3d, scores_3d, labels_3d, sample_idx=0, save_path=None,
                           bev_range=50.0, bev_resolution=0.167):
    """
    Visualize 3D detection results as box centres on a bird's-eye-view (BEV) canvas.

    Every detection whose centre falls inside the square window
    [-bev_range, +bev_range] x [-bev_range, +bev_range] (metres) is drawn as a
    dot. The dot colour runs from white (score 0) to red (score >= 1), and the
    dot is annotated with "label:score". Detections outside the window, or ones
    that cannot be converted, are skipped.

    Args:
        boxes_3d: 3D detection boxes. For each box, box[0] and box[1] are read
            as the centre x (forward) and y (left) in the LiDAR frame.
            Objects exposing a ``.tensor`` attribute are unwrapped to that
            tensor first (assumed to be an mmdet3d-style box container --
            TODO confirm against the model output).
        scores_3d: Per-box confidence scores.
        labels_3d: Per-box class labels.
        sample_idx: Sample index, shown in the title.
        save_path: If truthy, the figure is saved to this path at 300 dpi.
        bev_range: Half-width of the BEV window in metres (default 50 m).
        bev_resolution: Metres per BEV pixel (the default 0.167 gives a 598x598 canvas).
    """
    if hasattr(boxes_3d, 'tensor'):
        # Box container classes yield container objects rather than raw vectors
        # when iterated, so use the underlying tensor (presumably (N, box_dim)).
        boxes_3d = boxes_3d.tensor

    fig, ax = plt.subplots(1, 1, figsize=(12, 10))

    bev_size = int(2 * bev_range / bev_resolution)
    bev_image = np.full((bev_size, bev_size, 3), 240, dtype=np.uint8)  # light-grey background

    num_boxes = len(boxes_3d)
    markers = []  # (x_px, y_px, rgb_0_255, alpha, label_text) per box inside the window
    if num_boxes > 0:
        print(f"正在可视化 {num_boxes} 个3D检测框...")

    for i, (box, score, label) in enumerate(zip(boxes_3d, scores_3d, labels_3d)):
        try:
            if torch.is_tensor(box):
                box = box.detach().cpu().numpy()
            if torch.is_tensor(score):
                score = score.item()
            if torch.is_tensor(label):
                label = label.item()
            score = float(score)

            # LiDAR frame: x forward, y left. Image frame: origin at the top-left,
            # columns grow with x and rows grow with -y (hence the flip on y).
            center_x, center_y = float(box[0]), float(box[1])
            # Use floor rather than int() truncation, so that centres just
            # beyond -bev_range are not rounded into pixel 0.
            bev_x = int(np.floor((center_x + bev_range) / bev_resolution))
            bev_y = int(np.floor((bev_range - center_y) / bev_resolution))
        except Exception as e:
            # Best effort: report the malformed box and keep drawing the others.
            print(f"处理检测框 {i} 时出错: {e}")
            continue

        if not (0 <= bev_x < bev_size and 0 <= bev_y < bev_size):
            continue

        # Higher confidence -> more saturated red. Clamp so alpha stays in [0, 1].
        alpha = min(1.0, max(0.0, score))
        color_intensity = int(255 * alpha)
        color = (255, 255 - color_intensity, 255 - color_intensity)

        if cv2 is not None:
            cv2.circle(bev_image, (bev_x, bev_y), 4, color, -1)
        markers.append((bev_x, bev_y, color, alpha, f'{label}:{score:.2f}'))

    ax.imshow(bev_image)
    for bev_x, bev_y, color, alpha, text in markers:
        # matplotlib expects RGB components in [0, 1]; the cv2 colour is 0-255.
        rgb01 = tuple(channel / 255.0 for channel in color)
        ax.scatter(bev_x, bev_y, c=[rgb01], s=50, alpha=alpha, edgecolors='black', linewidth=1)
        ax.text(bev_x + 5, bev_y - 5, text, fontsize=8,
                bbox=dict(boxstyle="round,pad=0.1", facecolor="white", alpha=0.8))

    detection_count = len(markers)
    if num_boxes > 0:
        ax.set_title(f'3D检测结果 - 样本 {sample_idx} ({detection_count}/{num_boxes} 个检测框显示)', fontsize=14)
    else:
        ax.set_title(f'3D检测结果 - 样本 {sample_idx} (0 个检测框)', fontsize=14)

    ax.set_xlabel('BEV X (pixels)')
    ax.set_ylabel('BEV Y (pixels)')
    ax.grid(True, alpha=0.3)

    # Show the metric extent of the canvas.
    ax.text(10, 20, f'范围: [-{bev_range}m, +{bev_range}m]', fontsize=10,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"3D检测可视化已保存: {save_path}")
    plt.show()
    # Release the figure; otherwise repeated calls keep accumulating open figures.
    plt.close(fig)
|
||
|
||
def main():
    """
    CLI entry point: load pickled inference results and visualize one sample.

    For the selected sample, this renders the BEV segmentation (when the sample
    has 'masks_bev') and the 3D detections (when it has 'boxes_3d', 'scores_3d'
    and 'labels_3d') into PNG files under --output_dir. It then prints
    per-class mask coverage and the detection count.
    """
    parser = argparse.ArgumentParser(description='可视化单batch推理结果')
    parser.add_argument('--result_file', type=str,
                        default='/data/infer_test/20251120_124755/one_batch_results.pkl',
                        help='推理结果文件路径')
    parser.add_argument('--sample_idx', type=int, default=0,
                        help='要可视化的样本索引')
    parser.add_argument('--output_dir', type=str, default='./visualization_output',
                        help='可视化结果保存目录')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    print(f"加载推理结果: {args.result_file}")
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point --result_file at results produced by a trusted run.
    with open(args.result_file, 'rb') as f:
        results = pickle.load(f)

    num_samples = len(results)
    print(f"结果包含 {num_samples} 个样本")
    # Fail with a clear CLI error instead of a bare IndexError traceback.
    # Negative indices are still accepted, as with Python sequence indexing.
    if not -num_samples <= args.sample_idx < num_samples:
        parser.error(f"--sample_idx {args.sample_idx} is out of range for {num_samples} samples")

    sample = results[args.sample_idx]
    print(f"样本 {args.sample_idx} 的keys: {list(sample.keys())}")

    # 1. BEV segmentation.
    if 'masks_bev' in sample:
        print(f"BEV分割形状: {sample['masks_bev'].shape}")
        pred_masks = sample['masks_bev']
        gt_masks = sample.get('gt_masks_bev', None)

        seg_save_path = os.path.join(args.output_dir, f'bev_segmentation_sample_{args.sample_idx}.png')
        visualize_bev_segmentation(pred_masks, gt_masks, args.sample_idx, seg_save_path)

    # 2. 3D detection.
    if 'boxes_3d' in sample and 'scores_3d' in sample and 'labels_3d' in sample:
        boxes_3d = sample['boxes_3d']
        scores_3d = sample['scores_3d']
        labels_3d = sample['labels_3d']

        print(f"3D检测框数量: {len(boxes_3d)}")
        if len(scores_3d) > 0:
            print(f"检测置信度范围: {scores_3d.min():.3f} - {scores_3d.max():.3f}")

        det_save_path = os.path.join(args.output_dir, f'3d_detection_sample_{args.sample_idx}.png')
        visualize_3d_detection(boxes_3d, scores_3d, labels_3d, args.sample_idx, det_save_path)

    # 3. Summary statistics.
    print("\n" + "=" * 50)
    print(f"样本 {args.sample_idx} 统计信息:")
    print("=" * 50)

    if 'masks_bev' in sample:
        pred_masks = sample['masks_bev']
        for i, class_name in enumerate(MAP_CLASSES):
            mask = pred_masks[i]
            pixel_count = (mask > 0.5).sum().item()
            percentage = pixel_count / (mask.shape[0] * mask.shape[1]) * 100
            # Was print(".2f"): a broken format string that printed the literal ".2f".
            print(f"  {class_name}: {pixel_count} 像素 ({percentage:.2f}%)")

    if 'boxes_3d' in sample:
        print(f"3D检测框数量: {len(sample['boxes_3d'])}")

    print(f"\n可视化结果已保存到: {args.output_dir}")
|
||
|
||
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()
|