# bev-project/VISUALIZE_INFERENCE_RESULTS.py
# (web-scrape metadata preserved as comments so the file remains valid Python:
#  207 lines, 7.1 KiB, Python, "Raw Normal View History", 2025-11-21 10:50:51 +08:00)
#!/usr/bin/env python
"""
基于推理结果进行可视化 - 参考visualize_single.py
"""
import argparse
import copy
import os
import pickle
import mmcv
import numpy as np
import torch
from mmcv import Config
from mmcv.runner import load_checkpoint
from torchpack.utils.config import configs
from mmdet3d.core import LiDARInstance3DBoxes
from mmdet3d.core.utils import visualize_camera, visualize_lidar, visualize_map
from mmdet3d.datasets import build_dataloader, build_dataset
def recursive_eval(obj, globals=None):
    """Resolve ``${expr}`` placeholder strings in a nested config, in place.

    On the first call, a deep copy of the whole structure is captured as the
    evaluation namespace, so a placeholder may reference any top-level key.
    Dicts and lists are rewritten in place; resolved values are themselves
    re-resolved in case one placeholder expands to another.

    NOTE(review): uses ``eval`` on config contents — only feed trusted files.
    """
    if globals is None:
        globals = copy.deepcopy(obj)
    if isinstance(obj, dict):
        for k in list(obj):
            obj[k] = recursive_eval(obj[k], globals)
        return obj
    if isinstance(obj, list):
        for idx in range(len(obj)):
            obj[idx] = recursive_eval(obj[idx], globals)
        return obj
    if isinstance(obj, str) and obj.startswith("${") and obj.endswith("}"):
        # Evaluate the expression against the snapshot, then keep resolving.
        return recursive_eval(eval(obj[2:-1], globals), globals)
    return obj
def load_inference_results(result_file):
    """Deserialize pickled inference results and report the sample count.

    Args:
        result_file: path to a pickle file holding a sequence of per-sample
            result dicts.

    Returns:
        The unpickled object (expected to support ``len()``).
    """
    print(f"加载推理结果: {result_file}")
    with open(result_file, "rb") as handle:
        loaded = pickle.load(handle)
    print(f"结果包含 {len(loaded)} 个样本")
    return loaded
def visualize_sample(sample_idx, sample, cfg, metas=None, out_dir="viz_inference"):
    """Visualize one inference sample: BEV segmentation map and LiDAR boxes.

    Args:
        sample_idx: index of the sample (used in filenames and log output).
        sample: result dict; may contain 'boxes_3d'/'scores_3d'/'labels_3d'
            (3D detections) and 'masks_bev' (per-class BEV segmentation logits
            or probabilities). Tensors or numpy arrays are both accepted.
        cfg: config object; 'map_classes'/'object_classes' are read if present.
        metas: optional meta-info dict; a dummy one is built when absent.
            NOTE(review): 'metas' is constructed but never passed to any
            visualizer below — confirm whether it should feed visualize_camera.
        out_dir: root directory for the rendered PNGs ('map/' and 'lidar/').

    Side effects:
        Creates directories and writes PNG files under out_dir; prints stats.
    """
    os.makedirs(out_dir, exist_ok=True)

    print(f"\n--- 可视化样本 {sample_idx} ---")

    # --- Prepare detection results -------------------------------------
    bboxes = None
    labels = None
    scores = None

    if 'boxes_3d' in sample and len(sample['boxes_3d']) > 0:
        boxes_3d = sample['boxes_3d']
        scores_3d = sample['scores_3d']
        labels_3d = sample['labels_3d']

        print(f"3D检测框数量: {len(boxes_3d)}")
        if len(scores_3d) > 0:
            # BUGFIX: original line was the garbled literal print(".3f");
            # reconstructed as reporting the top confidence — TODO confirm intent.
            print(f"最高置信度: {float(max(scores_3d)):.3f}")

        # Move everything to numpy before filtering/indexing.
        if torch.is_tensor(boxes_3d):
            boxes_3d = boxes_3d.cpu().numpy()
        if torch.is_tensor(scores_3d):
            scores_3d = scores_3d.cpu().numpy()
        if torch.is_tensor(labels_3d):
            labels_3d = labels_3d.cpu().numpy()

        # Drop low-confidence detections.
        if len(scores_3d) > 0:
            keep = scores_3d >= 0.1  # confidence threshold
            boxes_3d = boxes_3d[keep]
            scores_3d = scores_3d[keep]
            labels_3d = labels_3d[keep]

        if len(boxes_3d) > 0:
            # Shift z from box bottom to box center before wrapping.
            boxes_3d = boxes_3d.copy()
            boxes_3d[:, 2] -= boxes_3d[:, 5] / 2
            bboxes = LiDARInstance3DBoxes(boxes_3d, box_dim=9)
            scores = scores_3d
            labels = labels_3d.astype(np.int64)
            print(f"过滤后检测框数量: {len(boxes_3d)}")

    # --- Prepare segmentation results ----------------------------------
    masks = None
    if 'masks_bev' in sample:
        masks_bev = sample['masks_bev']
        print(f"BEV分割形状: {masks_bev.shape}")
        if torch.is_tensor(masks_bev):
            masks_bev = masks_bev.cpu().numpy()
        # Binarize at 0.5.
        # BUGFIX: np.bool was removed in NumPy 1.24 — use the builtin bool.
        masks = (masks_bev >= 0.5).astype(bool)

    # Dummy metas for visualization when the caller supplies none.
    if metas is None:
        metas = {
            "timestamp": f"sample_{sample_idx}",
            "token": f"token_{sample_idx}",
            "lidar2image": [np.eye(4)],        # identity projection
            "filename": ["dummy_image.jpg"],   # placeholder image path
        }

    # --- Render BEV segmentation map -----------------------------------
    if masks is not None:
        map_out_dir = os.path.join(out_dir, "map")
        os.makedirs(map_out_dir, exist_ok=True)
        map_path = os.path.join(map_out_dir, f"sample_{sample_idx:03d}.png")
        visualize_map(
            map_path,
            masks,
            classes=cfg.map_classes if hasattr(cfg, 'map_classes') else
            ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider']
        )
        print(f"BEV分割可视化已保存: {map_path}")

        # Per-class pixel statistics.
        class_names = ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider']
        print("BEV分割统计:")
        for i, name in enumerate(class_names[:len(masks)]):
            pixel_count = masks[i].sum()
            percentage = pixel_count / masks[i].size * 100
            # BUGFIX: original line was the garbled literal print("15s");
            # reconstructed as a padded per-class summary — TODO confirm format.
            print(f"  {name:15s}: {pixel_count} px ({percentage:.2f}%)")

    # --- Render LiDAR view with detections -----------------------------
    if bboxes is not None:
        # Synthetic point cloud, used purely as a visual backdrop.
        # BEV range: [-50, 50] x [-50, 50] x [-5, 3]
        num_points = 10000
        points = np.random.rand(num_points, 5)
        points[:, 0] = (points[:, 0] - 0.5) * 100  # x: -50 to 50
        points[:, 1] = (points[:, 1] - 0.5) * 100  # y: -50 to 50
        points[:, 2] = points[:, 2] * 8 - 5        # z: -5 to 3
        points[:, 3] = np.ones(num_points)         # intensity
        points[:, 4] = np.zeros(num_points)        # ring (unused)

        lidar_out_dir = os.path.join(out_dir, "lidar")
        os.makedirs(lidar_out_dir, exist_ok=True)
        lidar_path = os.path.join(lidar_out_dir, f"sample_{sample_idx:03d}.png")
        visualize_lidar(
            lidar_path,
            points,
            bboxes=bboxes,
            labels=labels,
            xlim=[-50, 50],  # BEV range
            ylim=[-50, 50],
            classes=cfg.object_classes if hasattr(cfg, 'object_classes') else
            ['car', 'truck', 'construction_vehicle', 'bus', 'trailer',
             'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone']
        )
        print(f"LiDAR检测可视化已保存: {lidar_path}")

    print(f"样本 {sample_idx} 可视化完成")
def main():
    """CLI entry point: load config and pickled results, render one sample."""
    arg_parser = argparse.ArgumentParser(description='可视化推理结果')
    arg_parser.add_argument('--result-file', type=str,
                            default='/data/infer_test/20251120_124755/one_batch_results.pkl',
                            help='推理结果文件路径')
    arg_parser.add_argument('--config', type=str,
                            default='configs/nuscenes/det/transfusion/secfpn/camera+lidar/swint_v0p075/multitask_BEV2X_phase4b_rmtppad_segmentation.yaml',
                            help='配置文件路径')
    arg_parser.add_argument('--sample-idx', type=int, default=0,
                            help='要可视化的样本索引')
    arg_parser.add_argument('--out-dir', type=str, default='viz_inference_results',
                            help='可视化结果保存目录')
    args = arg_parser.parse_args()

    # Load the YAML config and resolve ${...} cross-references.
    configs.load(args.config, recursive=True)
    cfg = Config(recursive_eval(configs), filename=args.config)

    # Load inference results and validate the requested index.
    results = load_inference_results(args.result_file)
    if args.sample_idx >= len(results):
        print(f"样本索引 {args.sample_idx} 超出范围 (总共 {len(results)} 个样本)")
        return

    # Render the chosen sample.
    visualize_sample(args.sample_idx, results[args.sample_idx], cfg,
                     out_dir=args.out_dir)

    print("\n✅ 可视化完成!")
    print(f"结果保存在: {args.out_dir}")
    print("\n包含的文件:")
    print("- map/sample_XXX.png: BEV分割结果")
    print("- lidar/sample_XXX.png: LiDAR检测结果")


if __name__ == '__main__':
    main()