bev-project/inference_and_visualize.py

387 lines
12 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
BEVFusion推理和可视化脚本
基于epoch_19.pth在nuScenes验证集上进行推理并生成可视化结果
"""
import argparse
import os
import sys
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
# 添加项目路径
sys.path.insert(0, '/workspace/bevfusion')
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_model
from mmcv.runner import load_checkpoint
def parse_args():
    """Read the command-line options for the inference/visualisation run.

    Returns:
        argparse.Namespace with the attributes ``config``, ``checkpoint``,
        ``samples``, ``output_dir``, ``show_score_thr`` and ``device``.
    """
    parser = argparse.ArgumentParser(
        description='BEVFusion Inference and Visualization')

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--config', {
            'default': 'configs/nuscenes/det/transfusion/secfpn/camera+lidar/swint_v0p075/multitask.yaml',
            'help': '配置文件路径',
        }),
        ('--checkpoint', {
            'default': 'runs/run-326653dc-74184412/epoch_19.pth',
            'help': 'checkpoint文件路径',
        }),
        ('--samples', {
            'type': int,
            'default': 10,
            'help': '推理样本数量',
        }),
        ('--output-dir', {
            'default': 'inference_results',
            'help': '输出目录',
        }),
        ('--show-score-thr', {
            'type': float,
            'default': 0.3,
            'help': '检测框显示的置信度阈值',
        }),
        ('--device', {
            'default': 'cuda:0',
            'help': '推理设备',
        }),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()
def setup_model(config_path, checkpoint_path, device='cuda:0'):
    """Build the model described by a config file and load trained weights.

    Args:
        config_path: config file, parsed with ``Config.fromfile``.
        checkpoint_path: checkpoint loaded into the model; tensors are
            mapped to CPU while loading.
        device: device the model is moved to after loading.

    Returns:
        ``(model, cfg)``: the model in eval mode on ``device`` and the
        parsed config.
    """
    rule = '=' * 80
    print(f"\n{rule}")
    print("加载模型配置和权重")
    print(rule)

    # NOTE(review): upstream BEVFusion tools appear to load their YAML configs
    # through torchpack's recursive loader + recursive_eval, so parent
    # default.yaml files and ${...} expressions get resolved. Confirm that
    # Config.fromfile yields a complete cfg.model for this YAML.
    cfg = Config.fromfile(config_path)

    # Step 1: instantiate the network from the model section of the config.
    print("\n1. 构建模型...")
    model_kwargs = {
        'train_cfg': cfg.get('train_cfg'),
        'test_cfg': cfg.get('test_cfg'),
    }
    model = build_model(cfg.model, **model_kwargs)

    # Step 2: load weights on CPU first, then switch to eval and move over.
    print(f"\n2. 加载checkpoint: {checkpoint_path}")
    checkpoint = load_checkpoint(model, checkpoint_path, map_location='cpu')
    model.eval()
    model = model.to(device)

    print("\n✅ 模型加载完成")
    print(f" - Device: {device}")
    saved_epoch = checkpoint.get('meta', {}).get('epoch', 'N/A')
    print(f" - Checkpoint epoch: {saved_epoch}")
    return model, cfg
def build_val_dataloader(cfg, samples=10):
    """Build an unshuffled, single-process data loader over ``cfg.data.val``.

    If ``samples`` is smaller than the dataset, ``samples`` indices spaced
    evenly from the first to the last element are kept through
    ``torch.utils.data.Subset``. Otherwise the whole dataset is used.

    Args:
        cfg: config providing the ``data.val`` dataset section.
        samples: maximum number of samples to run; must be at least 1.

    Returns:
        ``(data_loader, dataset)``; ``dataset`` is the subset when one was
        taken.

    Raises:
        ValueError: if ``samples`` is smaller than 1. This is checked before
            the (expensive) dataset build.
    """
    # Fail fast. 0 used to give a silently empty run, and a negative value
    # only surfaced later as an opaque error from np.linspace.
    if samples < 1:
        raise ValueError(f"samples must be >= 1, got {samples}")

    print(f"\n{'='*80}")
    print("构建数据加载器")
    print(f"{'='*80}")

    dataset = build_dataset(cfg.data.val)
    total = len(dataset)

    if samples < total:
        # Separator added: the two counts used to be printed fused together.
        print(f"\n⚠️ 限制样本数量: {total} → {samples}")
        # Evenly spaced indices, as plain Python ints for the dataset's __getitem__.
        indices = np.linspace(0, total - 1, samples, dtype=int).tolist()
        dataset = torch.utils.data.Subset(dataset, indices)

    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=0,
        dist=False,
        shuffle=False,
    )

    print(f"\n✅ 数据加载器构建完成")
    print(f" - 样本数量: {len(dataset)}")
    print(f" - 数据集: nuScenes validation set")
    return data_loader, dataset
def visualize_bev_segmentation(seg_pred, classes, save_path):
    """Colour-code a BEV segmentation prediction and save it as an image.

    Args:
        seg_pred: ``(C, H, W)`` per-class scores, as a torch tensor (any
            device) or array-like. Sigmoid is applied, and a pixel is painted
            with a class colour when its probability exceeds 0.5.
        classes: class names, one per channel of ``seg_pred``, same order.
            Classes missing from the palette are painted white and left out
            of the legend. Where masks overlap, the later class wins.
        save_path: destination image path.

    Returns:
        The ``(H, W, 3)`` uint8 colour map that was plotted.
    """
    from matplotlib.patches import Patch

    palette = {
        'drivable_area': [128, 64, 128],  # purple
        'ped_crossing': [244, 35, 232],   # pink
        'walkway': [70, 70, 70],          # grey
        'stop_line': [220, 20, 60],       # red
        'carpark_area': [157, 234, 50],   # green
        'divider': [255, 255, 0],         # yellow
    }
    fallback = [255, 255, 255]

    # One sigmoid and one device->host copy for all channels, instead of a
    # `.cpu().numpy()` per class. as_tensor also lets numpy input through.
    probs = torch.sigmoid(torch.as_tensor(seg_pred)).cpu().numpy()
    h, w = probs.shape[1:]
    color_map = np.zeros((h, w, 3), dtype=np.uint8)
    for idx, class_name in enumerate(classes):
        color_map[probs[idx] > 0.5] = palette.get(class_name, fallback)

    plt.figure(figsize=(10, 10))
    plt.imshow(color_map)
    plt.title('BEV Segmentation')
    plt.axis('off')

    legend_elements = [
        Patch(facecolor=np.array(palette[name]) / 255, label=name)
        for name in classes if name in palette
    ]
    plt.legend(handles=legend_elements, loc='upper right')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    return color_map
def visualize_3d_boxes(img, boxes_3d, labels, scores, save_path, score_thr=0.3,
                       bev_range=50.0):
    """Draw 3D detections as rotated footprints in a top-down (BEV) plot.

    Nothing is projected onto an image. ``img`` is unused and kept only for
    call compatibility.

    Args:
        img: unused.
        boxes_3d: ``(N, >=7)`` array. Columns 0-1 are read as the centre
            (x, y), columns 3-4 as the footprint (length, width) and column 6
            as yaw in radians.
            NOTE(review): this column layout and yaw convention are assumed;
            verify them against the mmdet3d box class that produced the
            tensor.
        labels: ``(N,)`` class ids, used only to choose a tab10 colour.
        scores: ``(N,)`` confidences, printed next to each drawn box.
        save_path: destination image path.
        score_thr: boxes with ``score < score_thr`` are not drawn.
        bev_range: the plot shows ``[-bev_range, bev_range]`` metres on both
            axes. Defaults to 50, the previously hard-coded window.
    """
    _, ax = plt.subplots(figsize=(12, 12))

    if len(boxes_3d) > 0:
        class_colors = plt.cm.tab10(np.linspace(0, 1, 10))
        # Closed unit rectangle; scaling it by (length, width) gives the footprint.
        unit_outline = np.array([
            [-0.5, -0.5], [0.5, -0.5], [0.5, 0.5], [-0.5, 0.5], [-0.5, -0.5],
        ])
        for box, label, score in zip(boxes_3d, labels, scores):
            if score < score_thr:
                continue
            center, size, yaw = box[:2], box[3:5], box[6]
            # int() so that float-typed label arrays still index the colour table.
            color = class_colors[int(label) % 10]

            cos_y, sin_y = np.cos(yaw), np.sin(yaw)
            rot_mat = np.array([[cos_y, -sin_y], [sin_y, cos_y]])
            corners = (unit_outline * size) @ rot_mat.T + center

            ax.plot(corners[:, 0], corners[:, 1], color=color, linewidth=2)
            ax.scatter(center[0], center[1], color=color, s=50, zorder=10)
            ax.text(center[0], center[1], f'{score:.2f}',
                    fontsize=8, color='white',
                    bbox=dict(boxstyle='round', facecolor=color, alpha=0.7))

    ax.set_xlim(-bev_range, bev_range)
    ax.set_ylim(-bev_range, bev_range)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)
    ax.set_xlabel('X (meters)')
    ax.set_ylabel('Y (meters)')
    ax.set_title('3D Detection Results (BEV)')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
def visualize_combined(img, seg_pred, boxes_3d, labels, scores, classes, save_path,
                       score_thr=0.3, bev_range=50.0):
    """Save a side-by-side figure: BEV segmentation (left), BEV detections (right).

    Args:
        img: unused; kept for call compatibility.
        seg_pred: ``(C, H, W)`` per-class scores, as a torch tensor (any
            device) or array-like. Sigmoid is applied, and a pixel is painted
            when its probability exceeds 0.5.
        boxes_3d: ``(N, >=7)`` array. Columns 0-1 are the centre (x, y),
            columns 3-4 the footprint (length, width) and column 6 the yaw.
            NOTE(review): layout and yaw convention are assumed; verify them
            against the mmdet3d box class in use.
        labels: ``(N,)`` class ids, used only to pick a tab10 colour.
        scores: ``(N,)`` confidences; boxes with ``score < score_thr`` are
            skipped.
        classes: class names, one per channel of ``seg_pred``.
        save_path: destination image path.
        score_thr: minimum score for a box to be drawn.
        bev_range: the detection panel shows ``[-bev_range, bev_range]``
            metres on both axes (default 50).

    NOTE(review): the left panel is drawn in pixel coordinates via imshow,
    while the right panel uses metric x/y. Confirm both panels share an
    orientation before comparing them.
    """
    palette = {
        'drivable_area': [128, 64, 128],
        'ped_crossing': [244, 35, 232],
        'walkway': [70, 70, 70],
        'stop_line': [220, 20, 60],
        'carpark_area': [157, 234, 50],
        'divider': [255, 255, 0],
    }
    plt.figure(figsize=(20, 10))

    # Left panel: segmentation. One sigmoid + host copy for all channels.
    ax1 = plt.subplot(1, 2, 1)
    probs = torch.sigmoid(torch.as_tensor(seg_pred)).cpu().numpy()
    h, w = probs.shape[1:]
    color_map = np.zeros((h, w, 3), dtype=np.uint8)
    for idx, class_name in enumerate(classes):
        color_map[probs[idx] > 0.5] = palette.get(class_name, [255, 255, 255])
    ax1.imshow(color_map)
    ax1.set_title('BEV Segmentation', fontsize=16)
    ax1.axis('off')

    # Right panel: detections as rotated footprints, in metres.
    ax2 = plt.subplot(1, 2, 2)
    if len(boxes_3d) > 0:
        class_colors = plt.cm.tab10(np.linspace(0, 1, 10))
        unit_outline = np.array([
            [-0.5, -0.5], [0.5, -0.5], [0.5, 0.5], [-0.5, 0.5], [-0.5, -0.5],
        ])
        for box, label, score in zip(boxes_3d, labels, scores):
            if score < score_thr:
                continue
            center, size, yaw = box[:2], box[3:5], box[6]
            # int() so that float-typed label arrays still index the colour table.
            color = class_colors[int(label) % 10]
            cos_y, sin_y = np.cos(yaw), np.sin(yaw)
            rot_mat = np.array([[cos_y, -sin_y], [sin_y, cos_y]])
            corners = (unit_outline * size) @ rot_mat.T + center
            ax2.plot(corners[:, 0], corners[:, 1], color=color, linewidth=2)
            ax2.scatter(center[0], center[1], color=color, s=50, zorder=10)

    ax2.set_xlim(-bev_range, bev_range)
    ax2.set_ylim(-bev_range, bev_range)
    ax2.set_aspect('equal')
    ax2.grid(True, alpha=0.3)
    ax2.set_xlabel('X (meters)', fontsize=12)
    ax2.set_ylabel('Y (meters)', fontsize=12)
    ax2.set_title('3D Object Detection (BEV)', fontsize=16)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" ✅ 已保存: {save_path}")
def _scatter_target(model):
    """Return mmcv ``scatter`` targets for ``model``'s device, or None if none are needed."""
    # An MMDataParallel wrapper already unpacks and scatters batches itself.
    if isinstance(model, MMDataParallel):
        return None
    first_param = next(model.parameters(), None)
    if first_param is not None and first_param.device.type == 'cuda':
        index = first_param.device.index
        return [index if index is not None else 0]
    # [-1] is mmcv's scatter target for CPU (as used by MMDataParallel on CPU).
    return [-1]


def _extract_predictions(result):
    """Return ``(seg_pred, det)`` for the first sample of one forward-pass result."""
    if isinstance(result, dict):
        seg_pred = result.get('seg_pred', None)
        if seg_pred is not None:
            seg_pred = seg_pred[0]  # first sample of the batch
        return seg_pred, result.get('bbox_results', None)
    if len(result) == 0:
        return None, None
    det = result[0]
    seg_pred = None
    if isinstance(det, dict) and det.get('masks_bev') is not None:
        # NOTE(review): assumes 'masks_bev' holds (C, H, W) probabilities
        # (BEVFusion's segmentation head appears to return sigmoid outputs at
        # eval time). The visualizers apply sigmoid themselves, so convert
        # back to logits. Confirm this for the configured head.
        seg_pred = torch.logit(torch.as_tensor(det['masks_bev']).float(), eps=1e-6)
    return seg_pred, det


def _detections_as_numpy(det):
    """Return ``(boxes, labels, scores)`` numpy arrays from ``det``, or None without boxes."""
    if det is None or 'boxes_3d' not in det:
        return None
    return (
        det['boxes_3d'].tensor.cpu().numpy(),
        det['labels_3d'].cpu().numpy(),
        det['scores_3d'].cpu().numpy(),
    )


def inference(model, data_loader, output_dir, cfg, score_thr=0.3):
    """Run the model over ``data_loader`` and save per-sample visualizations.

    For every batch, the first sample's predictions are rendered to
    ``<output_dir>/sample_XXXX_segmentation.png``, ``..._detection.png`` and
    ``..._combined.png``. Each image is written only when the predictions it
    needs are present for *that* sample.

    Args:
        model: the model, bare or wrapped in ``MMDataParallel``.
        data_loader: iterable of batch dicts, passed as keyword arguments to
            the model.
        output_dir: directory for the images; created if missing.
        cfg: config. ``map_classes`` (with a built-in default list) names the
            segmentation channels.
        score_thr: detections below this score are not drawn.
    """
    from mmcv.parallel import scatter  # same package as the MMDataParallel import

    print(f"\n{'='*80}")
    print("开始推理和可视化")
    print(f"{'='*80}\n")

    os.makedirs(output_dir, exist_ok=True)

    map_classes = cfg.get('map_classes', [
        'drivable_area', 'ped_crossing', 'walkway',
        'stop_line', 'carpark_area', 'divider'
    ])

    # Batches are presumably mmcv-collated DataContainers (mmdet3d's
    # build_dataloader), which a bare module cannot consume. Unwrap them and
    # move them onto the model's device, the way MMDataParallel would.
    target = _scatter_target(model)

    model.eval()
    with torch.no_grad():
        for idx, data in enumerate(tqdm(data_loader, desc="推理中")):
            if target is not None:
                data = scatter(data, target)[0]
            result = model(return_loss=False, rescale=True, **data)

            seg_pred, det = _extract_predictions(result)
            # Recomputed every iteration, so a sample without boxes can never
            # reuse the previous sample's detections (or hit a NameError).
            detections = _detections_as_numpy(det)
            sample_name = f"sample_{idx:04d}"

            if seg_pred is not None:
                seg_path = os.path.join(output_dir, f"{sample_name}_segmentation.png")
                visualize_bev_segmentation(seg_pred, map_classes, seg_path)

            if detections is not None:
                boxes_3d, labels, scores = detections
                det_path = os.path.join(output_dir, f"{sample_name}_detection.png")
                visualize_3d_boxes(None, boxes_3d, labels, scores, det_path, score_thr)

            if seg_pred is not None and detections is not None:
                combined_path = os.path.join(output_dir, f"{sample_name}_combined.png")
                visualize_combined(
                    None, seg_pred,
                    boxes_3d, labels, scores,
                    map_classes, combined_path, score_thr
                )

    print(f"\n{'='*80}")
    print(f"✅ 推理完成!")
    print(f" - 结果保存在: {output_dir}")
    print(f" - 共处理样本: {len(data_loader)}")
    print(f"{'='*80}\n")
def main():
    """Script entry point: parse options, load the model, build the loader, run inference."""
    args = parse_args()

    rule = '=' * 80
    settings = (
        f"配置文件: {args.config}",
        f"Checkpoint: {args.checkpoint}",
        f"样本数量: {args.samples}",
        f"输出目录: {args.output_dir}",
        f"置信度阈值: {args.show_score_thr}",
    )
    print(f"\n{rule}")
    print("BEVFusion推理和可视化")
    print(rule)
    for line in settings:
        print(line)
    print(f"{rule}\n")

    # Pipeline: model -> validation loader -> inference + visualisation.
    model, cfg = setup_model(args.config, args.checkpoint, args.device)
    data_loader, _ = build_val_dataloader(cfg, args.samples)
    inference(model, data_loader, args.output_dir, cfg, args.show_score_thr)


if __name__ == '__main__':
    main()