bev-project/visualize_results.py

176 lines
5.3 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
可视化nuScenes推理结果的简化脚本
直接使用pickle文件生成可视化
"""
import pickle
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import argparse
def _to_numpy(value):
    """Return ``value`` as a NumPy array.

    Plain arrays and sequences go straight through ``np.asarray``. Objects that
    expose a ``.tensor`` attribute are unwrapped first, and anything with a
    ``.detach`` method is treated as a torch tensor and copied to host memory.
    NOTE(review): the ``.tensor`` unwrapping assumes mmdet3d-style box
    containers in the pickle -- confirm against the model that produced it.
    """
    value = getattr(value, 'tensor', value)
    if hasattr(value, 'detach'):
        value = value.detach().cpu().numpy()
    return np.asarray(value)


def _bev_box_outline(center, size, yaw):
    """Return the closed (5, 2) bird's-eye-view outline of a rotated rectangle.

    ``size`` holds the rectangle's two extents (along its own x and y axes).
    The rectangle is rotated counter-clockwise by ``yaw`` radians and shifted
    to ``center``. The first corner is repeated at the end so that ``ax.plot``
    draws all four edges.
    """
    unit_square = np.array(
        [[-1.0, -1.0], [1.0, -1.0], [1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0]]
    )
    half_extent = np.asarray(size, dtype=float) / 2
    cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
    rotation = np.array([[cos_yaw, -sin_yaw], [sin_yaw, cos_yaw]])
    return (unit_square * half_extent) @ rotation.T + center


def visualize_detection_bev(results, idx=0, output_path='detection_bev.png',
                            *, score_thr=0.3, bev_range=50.0):
    """Plot the 3D detection result of one sample as a bird's-eye-view image.

    Each kept box is drawn as three things: its rotated footprint, a dot at its
    center, and a label with the class name and score. If the sample has no
    ``'boxes_3d'`` entry, an empty labelled grid is still saved.

    Args:
        results: Indexable collection of per-sample result dicts. A dict with
            ``'boxes_3d'`` must also contain ``'scores_3d'`` and
            ``'labels_3d'``.
        idx: Index of the sample in ``results`` to plot.
        output_path: Path of the PNG to write.
        score_thr: Boxes scoring below this value are skipped (default 0.3).
        bev_range: Half-width in meters of the plotted square; both axes span
            ``[-bev_range, bev_range]`` (default 50).
    """
    # Detection class names, indexed by integer label.
    class_names = [
        'car', 'truck', 'construction_vehicle', 'bus', 'trailer',
        'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
    ]
    colors = plt.cm.tab10(np.linspace(0, 1, 10))

    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        result = results[idx]
        if 'boxes_3d' in result:
            boxes = _to_numpy(result['boxes_3d'])
            scores = _to_numpy(result['scores_3d'])
            labels = _to_numpy(result['labels_3d']).astype(int)
            for box, score, label in zip(boxes, scores, labels):
                if score < score_thr:
                    continue
                # Assumed row layout: [x, y, z, size_x, size_y, size_z, yaw, ...]
                # -- TODO confirm against the model that produced the pickle.
                center = box[:2]
                outline = _bev_box_outline(center, box[3:5], box[6])
                color = colors[label % 10]
                ax.plot(outline[:, 0], outline[:, 1], color=color, linewidth=2)
                ax.scatter(center[0], center[1], color=color, s=50)
                if 0 <= label < len(class_names):
                    class_name = class_names[label]
                else:
                    class_name = f'cls_{label}'
                # Axes.text defaults to clip_on=False, so without clip_on=True
                # the labels of boxes outside the plotted range would be drawn
                # outside the axes.
                ax.text(center[0], center[1], f'{class_name}\n{score:.2f}',
                        fontsize=8, color='white', clip_on=True,
                        bbox=dict(boxstyle='round', facecolor=color, alpha=0.7))
        ax.set_xlim(-bev_range, bev_range)
        ax.set_ylim(-bev_range, bev_range)
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)
        ax.set_xlabel('X (meters)', fontsize=14)
        ax.set_ylabel('Y (meters)', fontsize=14)
        ax.set_title('3D Detection Results (BEV View)', fontsize=16)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even if reading the sample or saving fails.
        plt.close(fig)
    print(f"✅ 已保存: {output_path}")
def visualize_segmentation(results, idx=0, output_path='segmentation.png',
                           *, threshold=0.5):
    """Render the BEV segmentation result of one sample as a color image.

    Channel ``c`` of ``seg_pred`` is thresholded and painted with the color of
    the ``c``-th class listed below. Classes are painted in list order, so
    where masks overlap the later class wins. At most the first six channels
    are used. The title and a legend for all six classes are always drawn, so
    a sample without ``'seg_pred'`` still produces an image.

    Args:
        results: Indexable collection of per-sample result dicts.
        idx: Index of the sample in ``results`` to plot.
        output_path: Path of the PNG to write.
        threshold: A pixel of channel ``c`` is painted when its value is
            strictly greater than this (default 0.5).

    Raises:
        ValueError: if ``seg_pred`` is not 3-dimensional (expected (C, H, W)).
    """
    from matplotlib.patches import Patch

    # Map classes in channel order, each paired with an RGB color. Both lists
    # live at function level so the legend never depends on 'seg_pred' existing.
    class_names = [
        'drivable_area', 'ped_crossing', 'walkway',
        'stop_line', 'carpark_area', 'divider'
    ]
    palette = np.array([
        [128, 64, 128],   # drivable_area
        [244, 35, 232],   # ped_crossing
        [70, 70, 70],     # walkway
        [220, 20, 60],    # stop_line
        [157, 234, 50],   # carpark_area
        [255, 255, 0],    # divider
    ], dtype=np.uint8)

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        result = results[idx]
        if 'seg_pred' in result:
            seg_pred = result['seg_pred']
            if hasattr(seg_pred, 'detach'):  # torch-like tensor, possibly on GPU
                seg_pred = seg_pred.detach().cpu().numpy()
            seg_pred = np.asarray(seg_pred)
            # Expected layout (C, H, W); unpacking raises ValueError otherwise.
            _, height, width = seg_pred.shape
            color_map = np.zeros((height, width, 3), dtype=np.uint8)
            for channel, rgb in enumerate(palette[:seg_pred.shape[0]]):
                color_map[seg_pred[channel] > threshold] = rgb
            # NOTE(review): imshow draws row 0 at the top; confirm this matches
            # the orientation of the detection BEV plot.
            ax.imshow(color_map)
        ax.set_title('BEV Segmentation', fontsize=16)
        ax.axis('off')
        legend_elements = [
            Patch(facecolor=rgb / 255, label=name)
            for name, rgb in zip(class_names, palette)
        ]
        ax.legend(handles=legend_elements, loc='upper right')
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even if reading the sample or saving fails.
        plt.close(fig)
    print(f"✅ 已保存: {output_path}")
def main():
    """Command-line entry point.

    Loads a results pickle and writes a detection PNG and a segmentation PNG
    for each of the first ``--samples`` samples into ``--output-dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--results', default='results.pkl', help='结果pickle文件')
    parser.add_argument('--output-dir', default='visualizations', help='输出目录')
    parser.add_argument('--samples', type=int, default=5, help='可视化样本数')
    args = parser.parse_args()

    banner = '=' * 80
    print(f"\n{banner}")
    print("可视化nuScenes推理结果")
    print(f"{banner}\n")

    # parents=True so nested output paths such as out/run1/vis also work.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"加载结果: {args.results}")
    # SECURITY: pickle.load can execute arbitrary code while loading; only open
    # results files that you produced yourself or otherwise trust.
    with open(args.results, 'rb') as f:
        results = pickle.load(f)
    print(f"结果数量: {len(results)}")

    # Computed once; clamped so a negative --samples means "none".
    num_samples = max(0, min(args.samples, len(results)))
    print(f"可视化前 {num_samples} 个样本\n")

    for i in range(num_samples):
        print(f"处理样本 {i+1}/{num_samples}...")
        det_path = output_dir / f"sample_{i:04d}_detection.png"
        seg_path = output_dir / f"sample_{i:04d}_segmentation.png"
        visualize_detection_bev(results, i, str(det_path))
        visualize_segmentation(results, i, str(seg_path))

    print(f"\n{banner}")
    print(f"✅ 完成!结果保存在: {args.output_dir}/")
    print(f"{banner}\n")


if __name__ == '__main__':
    main()