bev-project/inference_and_visualize.py

387 lines
12 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
BEVFusion推理和可视化脚本
基于epoch_19.pth在nuScenes验证集上进行推理并生成可视化结果
"""
import argparse
import os
import sys
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
# 添加项目路径
sys.path.insert(0, '/workspace/bevfusion')
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_model
from mmcv.runner import load_checkpoint
def parse_args(argv=None):
    """Parse the command-line options of the inference/visualization script.

    Args:
        argv: Optional sequence of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on; passing a list allows programmatic
            use without touching the process command line.

    Returns:
        argparse.Namespace with the attributes ``config``, ``checkpoint``,
        ``samples``, ``output_dir``, ``show_score_thr`` and ``device``.
    """
    parser = argparse.ArgumentParser(description='BEVFusion Inference and Visualization')
    parser.add_argument(
        '--config',
        default='configs/nuscenes/det/transfusion/secfpn/camera+lidar/swint_v0p075/multitask.yaml',
        help='配置文件路径',
    )
    parser.add_argument(
        '--checkpoint',
        default='runs/run-326653dc-74184412/epoch_19.pth',
        help='checkpoint文件路径',
    )
    parser.add_argument('--samples', type=int, default=10,
                        help='推理样本数量')
    parser.add_argument('--output-dir', default='inference_results',
                        help='输出目录')
    parser.add_argument('--show-score-thr', type=float, default=0.3,
                        help='检测框显示的置信度阈值')
    parser.add_argument('--device', default='cuda:0',
                        help='推理设备')
    return parser.parse_args(argv)
def setup_model(config_path, checkpoint_path, device='cuda:0'):
    """Build the model described by a config file and load its weights.

    Args:
        config_path: Path handed to ``Config.fromfile``.
        checkpoint_path: Path handed to ``load_checkpoint``.
        device: Device string the model is moved to after loading.

    Returns:
        Tuple ``(model, cfg)``: the model in eval mode on ``device`` and the
        loaded config object.
    """
    print(f"\n{'='*80}")
    print("加载模型配置和权重")
    print(f"{'='*80}")

    cfg = Config.fromfile(config_path)

    # The model section of the config is used as-is; train/test cfg are
    # looked up at the top level and are None when absent.
    print("\n1. 构建模型...")
    train_cfg = cfg.get('train_cfg')
    test_cfg = cfg.get('test_cfg')
    model = build_model(cfg.model, train_cfg=train_cfg, test_cfg=test_cfg)

    # Weights are first mapped to CPU; the move to `device` happens below.
    print(f"\n2. 加载checkpoint: {checkpoint_path}")
    checkpoint = load_checkpoint(model, checkpoint_path, map_location='cpu')

    model.eval()
    model = model.to(device)

    print(f"\n✅ 模型加载完成")
    print(f" - Device: {device}")
    # The epoch is read from the checkpoint's optional 'meta' mapping.
    print(f" - Checkpoint epoch: {checkpoint.get('meta', {}).get('epoch', 'N/A')}")
    return model, cfg
def build_val_dataloader(cfg, samples=10):
    """Build a non-distributed, unshuffled loader over the validation set.

    When ``samples`` is smaller than the dataset, only ``samples`` items
    spread evenly over the index range are kept.

    Args:
        cfg: Config object; ``cfg.data.val`` is passed to ``build_dataset``.
        samples: Maximum number of samples to keep.

    Returns:
        Tuple ``(data_loader, dataset)`` where ``dataset`` is either the full
        dataset or a ``torch.utils.data.Subset`` of it.
    """
    print(f"\n{'='*80}")
    print("构建数据加载器")
    print(f"{'='*80}")

    dataset = build_dataset(cfg.data.val)

    if samples < len(dataset):
        # The two counts were previously printed back to back with no
        # separator (e.g. "601910"), which made the message unreadable.
        print(f"\n⚠️ 限制样本数量: {len(dataset)} -> {samples}")
        # Evenly spaced indices; converted to plain Python ints so the
        # wrapped dataset's __getitem__ never sees NumPy scalar indices.
        indices = np.linspace(0, len(dataset) - 1, samples, dtype=int).tolist()
        dataset = torch.utils.data.Subset(dataset, indices)

    # One sample per batch, loaded in the main process, in dataset order.
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=0,
        dist=False,
        shuffle=False
    )
    print(f"\n✅ 数据加载器构建完成")
    print(f" - 样本数量: {len(dataset)}")
    print(f" - 数据集: nuScenes validation set")
    return data_loader, dataset
def visualize_bev_segmentation(seg_pred, classes, save_path):
    """Render per-class BEV masks as one colour image and save it.

    Args:
        seg_pred: 3-D tensor indexed as (class, row, column). A sigmoid is
            applied before thresholding at 0.5.
            NOTE(review): this treats the input as logits - confirm the
            model does not already return probabilities.
        classes: Class names, one per leading index of ``seg_pred``.
        save_path: Destination of the saved figure.

    Returns:
        The (H, W, 3) uint8 colour image that was plotted.
    """
    from matplotlib.patches import Patch

    # Fixed RGB colour per known map class; unknown classes are drawn white.
    palette = {
        'drivable_area': [128, 64, 128],
        'ped_crossing': [244, 35, 232],
        'walkway': [70, 70, 70],
        'stop_line': [220, 20, 60],
        'carpark_area': [157, 234, 50],
        'divider': [255, 255, 0],
    }
    fallback_rgb = [255, 255, 255]

    probabilities = torch.sigmoid(seg_pred)
    _, height, width = probabilities.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)

    # Classes are painted in order, so later classes overwrite earlier ones
    # wherever their masks overlap.
    for channel, class_name in enumerate(classes):
        active = probabilities[channel].cpu().numpy() > 0.5
        canvas[active] = palette.get(class_name, fallback_rgb)

    plt.figure(figsize=(10, 10))
    plt.imshow(canvas)
    plt.title('BEV Segmentation')
    plt.axis('off')

    # The legend lists only classes that have a dedicated palette entry.
    handles = []
    for class_name in classes:
        if class_name in palette:
            rgb = np.array(palette[class_name]) / 255
            handles.append(Patch(facecolor=rgb, label=class_name))
    plt.legend(handles=handles, loc='upper right')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    return canvas
def visualize_3d_boxes(img, boxes_3d, labels, scores, save_path, score_thr=0.3):
    """Draw detected boxes as rotated rectangles in a top-down plot.

    No image projection takes place: only the planar footprint of each box
    is plotted, on axes fixed to [-50, 50] in both directions.

    Args:
        img: Accepted for interface compatibility; never read.
        boxes_3d: 2-D array; columns 0-1 are used as the centre, columns 3-4
            as the two footprint extents and column 6 as the yaw angle.
        labels: Integer class index per box (selects the colour, modulo 10).
        scores: Confidence per box.
        save_path: Destination of the saved figure.
        score_thr: Boxes scoring below this value are not drawn.
    """
    _, axis = plt.subplots(figsize=(12, 12))

    if len(boxes_3d) > 0:
        # Ten colours sampled from tab10; labels wrap around modulo 10.
        palette = plt.cm.tab10(np.linspace(0, 1, 10))
        per_box = zip(boxes_3d[:, :2], boxes_3d[:, 3:5], boxes_3d[:, 6], labels, scores)

        for center, size, yaw, label, score in per_box:
            if score < score_thr:
                continue

            color = palette[label % 10]
            half_l = size[0] / 2
            half_w = size[1] / 2

            # Axis-aligned outline around the origin; the first corner is
            # repeated so the plotted polyline closes.
            outline = np.array([
                [-half_l, -half_w],
                [half_l, -half_w],
                [half_l, half_w],
                [-half_l, half_w],
                [-half_l, -half_w],
            ])

            # Rotate by yaw, then translate to the box centre.
            cos_yaw = np.cos(yaw)
            sin_yaw = np.sin(yaw)
            rotation = np.array([
                [cos_yaw, -sin_yaw],
                [sin_yaw, cos_yaw],
            ])
            outline = outline @ rotation.T + center

            axis.plot(outline[:, 0], outline[:, 1], color=color, linewidth=2)
            axis.scatter(center[0], center[1], color=color, s=50, zorder=10)
            # Score annotation at the box centre.
            axis.text(
                center[0], center[1], f'{score:.2f}',
                fontsize=8, color='white',
                bbox=dict(boxstyle='round', facecolor=color, alpha=0.7),
            )

    axis.set_xlim(-50, 50)
    axis.set_ylim(-50, 50)
    axis.set_aspect('equal')
    axis.grid(True, alpha=0.3)
    axis.set_xlabel('X (meters)')
    axis.set_ylabel('Y (meters)')
    axis.set_title('3D Detection Results (BEV)')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
def visualize_combined(img, seg_pred, boxes_3d, labels, scores, classes, save_path, score_thr=0.3):
    """Save a side-by-side figure: BEV segmentation (left), boxes (right).

    Args:
        img: Accepted for interface compatibility; never read.
        seg_pred: 3-D tensor indexed as (class, row, column); a sigmoid is
            applied and each class is thresholded at 0.5.
        boxes_3d: 2-D array; columns 0-1 are used as the centre, columns 3-4
            as the two footprint extents and column 6 as the yaw angle.
        labels: Integer class index per box (selects the colour, modulo 10).
        scores: Confidence per box.
        classes: Map class names, one per leading index of ``seg_pred``.
        save_path: Destination of the saved figure.
        score_thr: Boxes scoring below this value are not drawn.
    """
    plt.figure(figsize=(20, 10))

    # Left panel: colour-coded segmentation.
    seg_axis = plt.subplot(1, 2, 1)
    palette = {
        'drivable_area': [128, 64, 128],
        'ped_crossing': [244, 35, 232],
        'walkway': [70, 70, 70],
        'stop_line': [220, 20, 60],
        'carpark_area': [157, 234, 50],
        'divider': [255, 255, 0],
    }
    probabilities = torch.sigmoid(seg_pred)
    _, height, width = probabilities.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    # Later classes overwrite earlier ones where masks overlap; classes
    # without a palette entry are drawn white.
    for channel, class_name in enumerate(classes):
        active = probabilities[channel].cpu().numpy() > 0.5
        canvas[active] = palette.get(class_name, [255, 255, 255])
    seg_axis.imshow(canvas)
    seg_axis.set_title('BEV Segmentation', fontsize=16)
    seg_axis.axis('off')

    # Right panel: top-down box outlines.
    det_axis = plt.subplot(1, 2, 2)
    if len(boxes_3d) > 0:
        box_colors = plt.cm.tab10(np.linspace(0, 1, 10))
        per_box = zip(boxes_3d[:, :2], boxes_3d[:, 3:5], boxes_3d[:, 6], labels, scores)
        for center, size, yaw, label, score in per_box:
            if score < score_thr:
                continue
            color = box_colors[label % 10]
            half_l = size[0] / 2
            half_w = size[1] / 2
            # Closed rectangle around the origin, rotated by yaw and then
            # shifted to the box centre.
            outline = np.array([
                [-half_l, -half_w],
                [half_l, -half_w],
                [half_l, half_w],
                [-half_l, half_w],
                [-half_l, -half_w],
            ])
            cos_yaw = np.cos(yaw)
            sin_yaw = np.sin(yaw)
            rotation = np.array([
                [cos_yaw, -sin_yaw],
                [sin_yaw, cos_yaw],
            ])
            outline = outline @ rotation.T + center
            det_axis.plot(outline[:, 0], outline[:, 1], color=color, linewidth=2)
            det_axis.scatter(center[0], center[1], color=color, s=50, zorder=10)

    det_axis.set_xlim(-50, 50)
    det_axis.set_ylim(-50, 50)
    det_axis.set_aspect('equal')
    det_axis.grid(True, alpha=0.3)
    det_axis.set_xlabel('X (meters)', fontsize=12)
    det_axis.set_ylabel('Y (meters)', fontsize=12)
    det_axis.set_title('3D Object Detection (BEV)', fontsize=16)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" ✅ 已保存: {save_path}")
def inference(model, data_loader, output_dir, cfg, score_thr=0.3):
    """Run the model over a data loader and save per-sample visualizations.

    For each batch up to three images are written into ``output_dir``:
    ``sample_XXXX_segmentation.png``, ``sample_XXXX_detection.png`` and
    ``sample_XXXX_combined.png`` (the last only when both a segmentation and
    a usable detection result exist for that sample).

    Args:
        model: Callable model; invoked as
            ``model(return_loss=False, rescale=True, **data)``.
        data_loader: Iterable of keyword-argument dicts for the model.
        output_dir: Directory for the images; created if missing.
        cfg: Config object; ``map_classes`` is read from it, falling back to
            the six default map class names.
        score_thr: Minimum score for a box to be drawn.
    """
    print(f"\n{'='*80}")
    print("开始推理和可视化")
    print(f"{'='*80}\n")

    os.makedirs(output_dir, exist_ok=True)

    map_classes = cfg.get('map_classes', [
        'drivable_area', 'ped_crossing', 'walkway',
        'stop_line', 'carpark_area', 'divider'
    ])

    model.eval()
    with torch.no_grad():
        for idx, data in enumerate(tqdm(data_loader, desc="推理中")):
            # NOTE(review): the batch is passed to the bare model without any
            # scatter/unwrapping step (MMDataParallel is imported by this
            # file but unused). If the loader yields DataContainer-wrapped or
            # CPU-resident inputs this call may fail - confirm.
            result = model(return_loss=False, rescale=True, **data)

            if isinstance(result, dict):
                seg_pred = result.get('seg_pred', None)
                bbox_results = result.get('bbox_results', None)
            else:
                # Fallback: treat the first element of a sequence result as
                # the detection output; no segmentation is available.
                bbox_results = result[0] if len(result) > 0 else None
                seg_pred = None

            sample_name = f"sample_{idx:04d}"

            if seg_pred is not None:
                seg_path = os.path.join(output_dir, f"{sample_name}_segmentation.png")
                # seg_pred[0]: first entry along the leading axis
                # (presumably the batch dimension - TODO confirm).
                visualize_bev_segmentation(seg_pred[0], map_classes, seg_path)

            # Reset per sample so the combined plot can never pick up boxes
            # left over from an earlier iteration (or hit a NameError on the
            # first one) when this sample has no 'boxes_3d' entry.
            detections = None
            if bbox_results is not None and 'boxes_3d' in bbox_results:
                boxes_3d = bbox_results['boxes_3d'].tensor.cpu().numpy()
                labels = bbox_results['labels_3d'].cpu().numpy()
                scores = bbox_results['scores_3d'].cpu().numpy()
                detections = (boxes_3d, labels, scores)
                det_path = os.path.join(output_dir, f"{sample_name}_detection.png")
                visualize_3d_boxes(None, boxes_3d, labels, scores, det_path, score_thr)

            if seg_pred is not None and detections is not None:
                boxes_3d, labels, scores = detections
                combined_path = os.path.join(output_dir, f"{sample_name}_combined.png")
                visualize_combined(
                    None, seg_pred[0],
                    boxes_3d, labels, scores,
                    map_classes, combined_path, score_thr
                )

    print(f"\n{'='*80}")
    print(f"✅ 推理完成!")
    print(f" - 结果保存在: {output_dir}")
    print(f" - 共处理样本: {len(data_loader)}")
    print(f"{'='*80}\n")
def main():
    """Script entry point: parse options, load the model, run inference.

    Prints a summary of the effective options, then chains model setup,
    validation loader construction and the inference/visualization loop.
    """
    args = parse_args()

    # Echo the effective configuration before doing any heavy work.
    print(f"\n{'='*80}")
    print("BEVFusion推理和可视化")
    print(f"{'='*80}")
    print(f"配置文件: {args.config}")
    print(f"Checkpoint: {args.checkpoint}")
    print(f"样本数量: {args.samples}")
    print(f"输出目录: {args.output_dir}")
    print(f"置信度阈值: {args.show_score_thr}")
    print(f"{'='*80}\n")

    # Step 1: model and config.
    model, cfg = setup_model(
        args.config,
        args.checkpoint,
        args.device,
    )

    # Step 2: validation loader (the dataset handle itself is not needed here).
    data_loader, _dataset = build_val_dataloader(cfg, args.samples)

    # Step 3: inference plus visualization output.
    inference(
        model,
        data_loader,
        args.output_dir,
        cfg,
        args.show_score_thr,
    )
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()