387 lines
12 KiB
Python
Executable File
387 lines
12 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
|
||
BEVFusion推理和可视化脚本
|
||
基于epoch_19.pth在nuScenes验证集上进行推理并生成可视化结果
|
||
"""
|
||
|
||
import argparse
|
||
import os
|
||
import sys
|
||
import torch
|
||
import numpy as np
|
||
import cv2
|
||
import matplotlib.pyplot as plt
|
||
from pathlib import Path
|
||
from tqdm import tqdm
|
||
|
||
# 添加项目路径
|
||
sys.path.insert(0, '/workspace/bevfusion')
|
||
|
||
from mmcv import Config
|
||
from mmcv.parallel import MMDataParallel
|
||
from mmdet3d.datasets import build_dataloader, build_dataset
|
||
from mmdet3d.models import build_model
|
||
from mmcv.runner import load_checkpoint
|
||
|
||
|
||
def parse_args():
    """Parse the command-line options of the inference/visualization script.

    Returns:
        argparse.Namespace with the attributes ``config``, ``checkpoint``,
        ``samples``, ``output_dir``, ``show_score_thr`` and ``device``.
    """
    parser = argparse.ArgumentParser(
        description='BEVFusion Inference and Visualization')

    # (flag, keyword arguments for add_argument), registered in this order.
    option_specs = (
        ('--config', dict(
            default='configs/nuscenes/det/transfusion/secfpn/camera+lidar/swint_v0p075/multitask.yaml',
            help='配置文件路径')),
        ('--checkpoint', dict(
            default='runs/run-326653dc-74184412/epoch_19.pth',
            help='checkpoint文件路径')),
        ('--samples', dict(type=int, default=10,
                           help='推理样本数量')),
        ('--output-dir', dict(default='inference_results',
                              help='输出目录')),
        ('--show-score-thr', dict(type=float, default=0.3,
                                  help='检测框显示的置信度阈值')),
        ('--device', dict(default='cuda:0',
                          help='推理设备')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
|
||
|
||
|
||
def _load_config(config_path):
    """Load *config_path* into an ``mmcv.Config``.

    BEVFusion's YAML configs inherit from ``default.yaml`` files in their
    parent directories and use ``${...}`` expressions. ``Config.fromfile``
    reads only the leaf file and evaluates nothing, so for YAML files the
    torchpack recursive loader + ``recursive_eval`` is used when those
    modules are importable (as in the BEVFusion repo's own tools).
    Otherwise, and for non-YAML files, ``Config.fromfile`` is used.
    """
    if str(config_path).endswith(('.yaml', '.yml')):
        try:
            from torchpack.utils.config import configs
            from mmdet3d.utils import recursive_eval
        except ImportError:
            pass  # fall back to plain mmcv loading below
        else:
            # NOTE: torchpack's `configs` is a process-wide singleton; this
            # script loads exactly one config, so accumulation is not an issue.
            configs.load(config_path, recursive=True)
            return Config(recursive_eval(configs), filename=config_path)
    return Config.fromfile(config_path)


def setup_model(config_path, checkpoint_path, device='cuda:0'):
    """Build the model described by *config_path* and load its weights.

    Args:
        config_path: config file. ``.yaml``/``.yml`` files go through
            ``_load_config``'s recursive loader; others go through
            ``Config.fromfile``.
        checkpoint_path: weights file passed to ``load_checkpoint``. It is
            loaded onto the CPU first, then the model is moved to *device*.
        device: torch device string the model is moved to.

    Returns:
        ``(model, cfg)``: the model in eval mode on *device*, and the config.
    """
    print(f"\n{'='*80}")
    print("加载模型配置和权重")
    print(f"{'='*80}")

    # Load the config (resolving BEVFusion's hierarchical YAML when possible).
    cfg = _load_config(config_path)

    # Build the model from the full `model` section of the config.
    print("\n1. 构建模型...")
    model = build_model(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'),
    )

    # Load weights on the CPU so this works whatever device they were saved from.
    print(f"\n2. 加载checkpoint: {checkpoint_path}")
    checkpoint = load_checkpoint(model, checkpoint_path, map_location='cpu')

    model.eval()
    model = model.to(device)

    print(f"\n✅ 模型加载完成")
    print(f" - Device: {device}")
    print(f" - Checkpoint epoch: {checkpoint.get('meta', {}).get('epoch', 'N/A')}")

    return model, cfg
|
||
|
||
|
||
def build_val_dataloader(cfg, samples=10):
    """Build a single-process, non-shuffling loader over ``cfg.data.val``.

    If *samples* is smaller than the dataset, the dataset is replaced by a
    ``Subset`` of *samples* indices evenly spaced over the whole split
    (``np.linspace`` from 0 to the last index).

    Returns:
        ``(data_loader, dataset)``, where *dataset* is the possibly-subsetted
        dataset the loader iterates over.
    """
    print(f"\n{'='*80}")
    print("构建数据加载器")
    print(f"{'='*80}")

    dataset = build_dataset(cfg.data.val)
    full_size = len(dataset)

    if samples < full_size:
        print(f"\n⚠️ 限制样本数量: {full_size} → {samples}")
        # Evenly spaced indices so the subset spans the whole split.
        chosen = np.linspace(0, full_size - 1, samples, dtype=int)
        dataset = torch.utils.data.Subset(dataset, chosen)

    loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=0,
        dist=False,
        shuffle=False,
    )

    print(f"\n✅ 数据加载器构建完成")
    print(f" - 样本数量: {len(dataset)}")
    print(f" - 数据集: nuScenes validation set")

    return loader, dataset
|
||
|
||
|
||
def visualize_bev_segmentation(seg_pred, classes, save_path, seg_thr=0.5,
                               apply_sigmoid=True):
    """Colour a BEV segmentation map, save it to *save_path*, and return it.

    Args:
        seg_pred: ``(C, H, W)`` tensor or array; channel ``i`` belongs to
            ``classes[i]``. Treated as logits unless *apply_sigmoid* is False.
        classes: class names in channel order. Names missing from the palette
            below are painted white and left out of the legend.
        save_path: output image path.
        seg_thr: a pixel takes a class colour where that class's probability
            is strictly greater than this value.
        apply_sigmoid: apply ``torch.sigmoid`` before thresholding. Pass
            False when *seg_pred* already holds probabilities.

    Returns:
        ``(H, W, 3)`` uint8 colour image. Where masks overlap, the class
        appearing later in *classes* wins, because it is painted last.
    """
    from matplotlib.patches import Patch

    probs = torch.as_tensor(seg_pred).detach().float()
    if apply_sigmoid:
        probs = torch.sigmoid(probs)
    # One device->host copy for all channels instead of one per class.
    probs = probs.cpu().numpy()

    h, w = probs.shape[1:]
    color_map = np.zeros((h, w, 3), dtype=np.uint8)

    # RGB colour per class.
    colors = {
        'drivable_area': [128, 64, 128],  # purple
        'ped_crossing': [244, 35, 232],   # pink
        'walkway': [70, 70, 70],          # grey
        'stop_line': [220, 20, 60],       # red
        'carpark_area': [157, 234, 50],   # green
        'divider': [255, 255, 0],         # yellow
    }

    for idx, class_name in enumerate(classes):
        color_map[probs[idx] > seg_thr] = colors.get(class_name, [255, 255, 255])

    fig = plt.figure(figsize=(10, 10))
    plt.imshow(color_map)
    plt.title('BEV Segmentation')
    plt.axis('off')

    legend_elements = [
        Patch(facecolor=np.array(colors[name]) / 255, label=name)
        for name in classes if name in colors
    ]
    plt.legend(handles=legend_elements, loc='upper right')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

    return color_map
|
||
|
||
|
||
def visualize_3d_boxes(img, boxes_3d, labels, scores, save_path, score_thr=0.3):
|
||
"""可视化3D检测框(投影到图像上)"""
|
||
# 简化版:在BEV视图上显示框
|
||
fig, ax = plt.subplots(figsize=(12, 12))
|
||
|
||
# 绘制检测框
|
||
if len(boxes_3d) > 0:
|
||
# 提取中心点和尺寸
|
||
centers = boxes_3d[:, :2] # x, y
|
||
sizes = boxes_3d[:, 3:5] # length, width
|
||
yaws = boxes_3d[:, 6] # yaw angle
|
||
|
||
# 定义类别颜色
|
||
class_colors = plt.cm.tab10(np.linspace(0, 1, 10))
|
||
|
||
for i, (center, size, yaw, label, score) in enumerate(zip(centers, sizes, yaws, labels, scores)):
|
||
if score < score_thr:
|
||
continue
|
||
|
||
# 绘制框
|
||
color = class_colors[label % 10]
|
||
|
||
# 计算框的四个角点
|
||
l, w = size[0], size[1]
|
||
corners = np.array([
|
||
[-l/2, -w/2],
|
||
[l/2, -w/2],
|
||
[l/2, w/2],
|
||
[-l/2, w/2],
|
||
[-l/2, -w/2] # 闭合
|
||
])
|
||
|
||
# 旋转
|
||
rot_mat = np.array([
|
||
[np.cos(yaw), -np.sin(yaw)],
|
||
[np.sin(yaw), np.cos(yaw)]
|
||
])
|
||
corners = corners @ rot_mat.T
|
||
corners += center
|
||
|
||
# 绘制
|
||
ax.plot(corners[:, 0], corners[:, 1], color=color, linewidth=2)
|
||
ax.scatter(center[0], center[1], color=color, s=50, zorder=10)
|
||
|
||
# 添加标签
|
||
ax.text(center[0], center[1], f'{score:.2f}',
|
||
fontsize=8, color='white',
|
||
bbox=dict(boxstyle='round', facecolor=color, alpha=0.7))
|
||
|
||
ax.set_xlim(-50, 50)
|
||
ax.set_ylim(-50, 50)
|
||
ax.set_aspect('equal')
|
||
ax.grid(True, alpha=0.3)
|
||
ax.set_xlabel('X (meters)')
|
||
ax.set_ylabel('Y (meters)')
|
||
ax.set_title('3D Detection Results (BEV)')
|
||
|
||
plt.tight_layout()
|
||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||
plt.close()
|
||
|
||
|
||
def visualize_combined(img, seg_pred, boxes_3d, labels, scores, classes, save_path,
                       score_thr=0.3, seg_thr=0.5, apply_sigmoid=True):
    """Save a side-by-side figure: BEV segmentation (left), BEV boxes (right).

    Args:
        img: unused; kept for signature compatibility.
        seg_pred: ``(C, H, W)`` tensor or array; channel ``i`` belongs to
            ``classes[i]``. Treated as logits unless *apply_sigmoid* is False.
        boxes_3d: ``(N, >=7)`` array; columns 0-1 are the centre, 3-4 the
            footprint size and 6 the yaw in radians.
        labels: ``N`` integer class ids; colour is ``tab10[label % 10]``.
        scores: ``N`` confidences; boxes with ``score < score_thr`` are skipped.
        classes: segmentation class names in channel order.
        save_path: output image path.
        score_thr: minimum score for a box to be drawn.
        seg_thr: a pixel takes a class colour where that class's probability
            is strictly greater than this value.
        apply_sigmoid: apply ``torch.sigmoid`` before thresholding. Pass
            False when *seg_pred* already holds probabilities.
    """
    fig = plt.figure(figsize=(20, 10))

    # 1. BEV segmentation
    ax1 = plt.subplot(1, 2, 1)
    probs = torch.as_tensor(seg_pred).detach().float()
    if apply_sigmoid:
        probs = torch.sigmoid(probs)
    probs = probs.cpu().numpy()  # single host copy for all channels

    seg_h, seg_w = probs.shape[1:]
    color_map = np.zeros((seg_h, seg_w, 3), dtype=np.uint8)

    colors = {
        'drivable_area': [128, 64, 128],
        'ped_crossing': [244, 35, 232],
        'walkway': [70, 70, 70],
        'stop_line': [220, 20, 60],
        'carpark_area': [157, 234, 50],
        'divider': [255, 255, 0],
    }

    # Later classes overwrite earlier ones where masks overlap.
    for idx, class_name in enumerate(classes):
        color_map[probs[idx] > seg_thr] = colors.get(class_name, [255, 255, 255])

    ax1.imshow(color_map)
    ax1.set_title('BEV Segmentation', fontsize=16)
    ax1.axis('off')

    # 2. 3D detections, drawn as BEV footprints
    # NOTE(review): assumes columns 3/4 are the local x/y extents and that yaw
    # rotates counter-clockwise; confirm against the mmdet3d box convention.
    ax2 = plt.subplot(1, 2, 2)

    if len(boxes_3d) > 0:
        class_colors = plt.cm.tab10(np.linspace(0, 1, 10))

        for box, label, score in zip(boxes_3d, labels, scores):
            if score < score_thr:
                continue

            center = box[:2]
            length, width, yaw = box[3], box[4], box[6]
            color = class_colors[label % 10]

            corners = np.array([
                [-length / 2, -width / 2], [length / 2, -width / 2],
                [length / 2, width / 2], [-length / 2, width / 2],
                [-length / 2, -width / 2],  # close the polygon
            ])
            rot_mat = np.array([
                [np.cos(yaw), -np.sin(yaw)],
                [np.sin(yaw), np.cos(yaw)],
            ])
            corners = corners @ rot_mat.T + center

            ax2.plot(corners[:, 0], corners[:, 1], color=color, linewidth=2)
            ax2.scatter(center[0], center[1], color=color, s=50, zorder=10)

    ax2.set_xlim(-50, 50)
    ax2.set_ylim(-50, 50)
    ax2.set_aspect('equal')
    ax2.grid(True, alpha=0.3)
    ax2.set_xlabel('X (meters)', fontsize=12)
    ax2.set_ylabel('Y (meters)', fontsize=12)
    ax2.set_title('3D Object Detection (BEV)', fontsize=16)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

    print(f" ✅ 已保存: {save_path}")
|
||
|
||
|
||
def inference(model, data_loader, output_dir, cfg, score_thr=0.3):
|
||
"""执行推理并可视化"""
|
||
print(f"\n{'='*80}")
|
||
print("开始推理和可视化")
|
||
print(f"{'='*80}\n")
|
||
|
||
# 创建输出目录
|
||
os.makedirs(output_dir, exist_ok=True)
|
||
|
||
# 获取类别信息
|
||
map_classes = cfg.get('map_classes', [
|
||
'drivable_area', 'ped_crossing', 'walkway',
|
||
'stop_line', 'carpark_area', 'divider'
|
||
])
|
||
|
||
object_classes = cfg.get('object_classes', [
|
||
'car', 'truck', 'construction_vehicle', 'bus', 'trailer',
|
||
'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
|
||
])
|
||
|
||
model.eval()
|
||
|
||
with torch.no_grad():
|
||
for idx, data in enumerate(tqdm(data_loader, desc="推理中")):
|
||
# 前向传播
|
||
result = model(return_loss=False, rescale=True, **data)
|
||
|
||
# 提取结果
|
||
if isinstance(result, dict):
|
||
# 分割结果
|
||
seg_pred = result.get('seg_pred', None)
|
||
# 检测结果
|
||
bbox_results = result.get('bbox_results', None)
|
||
else:
|
||
# 兼容性处理
|
||
bbox_results = result[0] if len(result) > 0 else None
|
||
seg_pred = None
|
||
|
||
# 可视化
|
||
sample_name = f"sample_{idx:04d}"
|
||
|
||
# 如果有分割结果
|
||
if seg_pred is not None:
|
||
seg_path = os.path.join(output_dir, f"{sample_name}_segmentation.png")
|
||
visualize_bev_segmentation(seg_pred[0], map_classes, seg_path)
|
||
|
||
# 如果有检测结果
|
||
if bbox_results is not None and 'boxes_3d' in bbox_results:
|
||
boxes_3d = bbox_results['boxes_3d'].tensor.cpu().numpy()
|
||
labels = bbox_results['labels_3d'].cpu().numpy()
|
||
scores = bbox_results['scores_3d'].cpu().numpy()
|
||
|
||
det_path = os.path.join(output_dir, f"{sample_name}_detection.png")
|
||
visualize_3d_boxes(None, boxes_3d, labels, scores, det_path, score_thr)
|
||
|
||
# 综合可视化
|
||
if seg_pred is not None and bbox_results is not None:
|
||
combined_path = os.path.join(output_dir, f"{sample_name}_combined.png")
|
||
visualize_combined(
|
||
None, seg_pred[0],
|
||
boxes_3d, labels, scores,
|
||
map_classes, combined_path, score_thr
|
||
)
|
||
|
||
print(f"\n{'='*80}")
|
||
print(f"✅ 推理完成!")
|
||
print(f" - 结果保存在: {output_dir}")
|
||
print(f" - 共处理样本: {len(data_loader)}")
|
||
print(f"{'='*80}\n")
|
||
|
||
|
||
def main():
    """Script entry point.

    Parses the CLI, prints a run summary, then loads the model, builds the
    validation loader, and runs inference and visualization.
    """
    args = parse_args()

    # Run summary, printed one line at a time in this order.
    summary_lines = (
        f"\n{'='*80}",
        "BEVFusion推理和可视化",
        f"{'='*80}",
        f"配置文件: {args.config}",
        f"Checkpoint: {args.checkpoint}",
        f"样本数量: {args.samples}",
        f"输出目录: {args.output_dir}",
        f"置信度阈值: {args.show_score_thr}",
        f"{'='*80}\n",
    )
    for line in summary_lines:
        print(line)

    # 1. Load the model.
    model, cfg = setup_model(args.config, args.checkpoint, args.device)

    # 2. Build the data loader (the dataset itself is not needed here).
    data_loader, _ = build_val_dataloader(cfg, args.samples)

    # 3. Run inference and visualization.
    inference(model, data_loader, args.output_dir, cfg, args.show_score_thr)


if __name__ == '__main__':
    main()
|
||
|