207 lines
7.1 KiB
Python
207 lines
7.1 KiB
Python
#!/usr/bin/env python
|
||
"""
|
||
基于推理结果进行可视化 - 参考visualize_single.py
|
||
"""
|
||
import argparse
|
||
import copy
|
||
import os
|
||
import pickle
|
||
|
||
import mmcv
|
||
import numpy as np
|
||
import torch
|
||
from mmcv import Config
|
||
from mmcv.runner import load_checkpoint
|
||
from torchpack.utils.config import configs
|
||
|
||
from mmdet3d.core import LiDARInstance3DBoxes
|
||
from mmdet3d.core.utils import visualize_camera, visualize_lidar, visualize_map
|
||
from mmdet3d.datasets import build_dataloader, build_dataset
|
||
|
||
|
||
def recursive_eval(obj, globals=None):
    """Resolve ``"${expr}"`` placeholder strings inside a nested config, in place.

    Dicts and lists are walked recursively and each entry is overwritten with
    its resolved value; the (mutated) ``obj`` itself is returned. A string of
    the exact form ``"${expr}"`` is replaced by ``eval(expr, globals)``, and
    that result is resolved again, so placeholders may expand to further
    placeholders. Any other value is returned unchanged.

    Args:
        obj: config node (dict, list, string or any other value).
        globals: namespace used by ``eval``. When omitted, a deep copy of the
            top-level ``obj`` is used, so expressions can reference the
            top-level config keys.

    NOTE: this uses ``eval``; only call it on trusted configuration data.
    """
    namespace = copy.deepcopy(obj) if globals is None else globals

    if isinstance(obj, dict):
        for key in obj:
            obj[key] = recursive_eval(obj[key], namespace)
        return obj

    if isinstance(obj, list):
        for position, item in enumerate(obj):
            obj[position] = recursive_eval(item, namespace)
        return obj

    is_placeholder = isinstance(obj, str) and obj.startswith("${") and obj.endswith("}")
    if not is_placeholder:
        return obj

    # Strip the leading "${" and trailing "}" and evaluate the expression.
    resolved = eval(obj[2:-1], namespace)
    return recursive_eval(resolved, namespace)
|
||
|
||
|
||
def load_inference_results(result_file):
    """Load and return the pickled inference results stored at ``result_file``.

    Prints the path being loaded and the number of entries in the loaded
    object (its ``len``).

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling; only
    pass result files you produced yourself or otherwise trust.
    """
    print(f"加载推理结果: {result_file}")
    with open(result_file, mode="rb") as handle:
        loaded = pickle.load(handle)
    print(f"结果包含 {len(loaded)} 个样本")
    return loaded
|
||
|
||
|
||
# Fallback class lists used when the config does not define its own.
_DEFAULT_MAP_CLASSES = [
    'drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider',
]
_DEFAULT_OBJECT_CLASSES = [
    'car', 'truck', 'construction_vehicle', 'bus', 'trailer',
    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone',
]


def _to_numpy(value):
    """Return ``value`` as a NumPy array.

    Objects exposing a ``.tensor`` attribute (mmdet3d box structures such as
    ``LiDARInstance3DBoxes`` keep their data there) are unwrapped first, and
    torch tensors are detached and moved to the CPU.
    """
    if hasattr(value, 'tensor'):
        value = value.tensor
    if torch.is_tensor(value):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _prepare_detections(sample, score_thresh):
    """Build the boxes and labels to draw from the 3D detections in ``sample``.

    Returns ``(bboxes, labels)``, where ``bboxes`` is a ``LiDARInstance3DBoxes``
    and ``labels`` an int64 array. Returns ``None`` when ``sample`` has no
    ``'boxes_3d'`` entry, that entry is empty, or no detection reaches
    ``score_thresh``. ``'scores_3d'`` and ``'labels_3d'`` must be present
    whenever ``'boxes_3d'`` is.
    """
    if 'boxes_3d' not in sample:
        return None
    boxes_3d = _to_numpy(sample['boxes_3d'])
    if len(boxes_3d) == 0:
        return None
    scores_3d = _to_numpy(sample['scores_3d'])
    labels_3d = _to_numpy(sample['labels_3d'])

    print(f"3D检测框数量: {len(boxes_3d)}")
    if len(scores_3d) > 0:
        print(f"置信度范围: {scores_3d.min():.3f} - {scores_3d.max():.3f}")
        # Drop low-confidence detections.
        keep = scores_3d >= score_thresh
        boxes_3d = boxes_3d[keep]
        labels_3d = labels_3d[keep]

    if len(boxes_3d) == 0:
        return None

    # Copy before the in-place edit: the array may share memory with the
    # caller's tensor/array (``.numpy()`` / ``np.asarray`` do not copy).
    boxes_3d = boxes_3d.copy()
    # Lower z by half the box height (column 5).
    # NOTE(review): the original comment called this "bottom to center", but
    # subtracting h/2 moves z downward; confirm the z convention of the
    # stored results.
    boxes_3d[:, 2] -= boxes_3d[:, 5] / 2
    bboxes = LiDARInstance3DBoxes(boxes_3d, box_dim=boxes_3d.shape[1])
    print(f"过滤后检测框数量: {len(boxes_3d)}")
    return bboxes, labels_3d.astype(np.int64)


def _prepare_masks(sample, map_thresh):
    """Return the BEV segmentation of ``sample`` thresholded to a bool array.

    Returns ``None`` when ``sample`` has no ``'masks_bev'`` entry.
    """
    if 'masks_bev' not in sample:
        return None
    masks_bev = sample['masks_bev']
    print(f"BEV分割形状: {masks_bev.shape}")
    # Builtin ``bool`` instead of ``np.bool``: that alias was removed in NumPy 1.24.
    return (_to_numpy(masks_bev) >= map_thresh).astype(bool)


def _print_mask_stats(masks, class_names):
    """Print the foreground pixel count and coverage of each mask channel.

    Pairs ``class_names`` with the first axis of ``masks`` (presumably
    channel-first ``(C, H, W)`` -- confirm); extra entries on either side
    are ignored.
    """
    print("BEV分割统计:")
    for name, mask in zip(class_names, masks):
        pixel_count = int(mask.sum())
        percentage = pixel_count / mask.size * 100 if mask.size else 0.0
        print(f"  {name:15s}: {pixel_count} 像素 ({percentage:.2f}%)")


def _make_dummy_points(num_points=10000):
    """Return ``(num_points, 5)`` uniform random points for the BEV background.

    x and y span [-50, 50), z spans [-5, 3); column 3 (intensity) is all ones
    and column 4 is all zeros.
    """
    points = np.random.rand(num_points, 5)
    points[:, 0] = (points[:, 0] - 0.5) * 100
    points[:, 1] = (points[:, 1] - 0.5) * 100
    points[:, 2] = points[:, 2] * 8 - 5
    points[:, 3] = 1.0
    points[:, 4] = 0.0
    return points


def visualize_sample(sample_idx, sample, cfg, metas=None, out_dir="viz_inference",
                     score_thresh=0.1, map_thresh=0.5):
    """Render the BEV segmentation map and LiDAR detection view of one result.

    Images are written to ``<out_dir>/map/sample_XXX.png`` (when the sample
    has ``'masks_bev'``) and ``<out_dir>/lidar/sample_XXX.png`` (when at least
    one detection survives ``score_thresh``).

    Args:
        sample_idx: index of the sample, used in log messages and filenames.
        sample: result dict; reads the optional keys ``'boxes_3d'`` (together
            with ``'scores_3d'`` and ``'labels_3d'``) and ``'masks_bev'``.
            Values may be tensors, arrays or box objects exposing ``.tensor``.
        cfg: config object; its ``map_classes`` / ``object_classes``
            attributes, when present, override the default class lists.
        metas: unused; kept for interface compatibility.
        out_dir: root output directory (created if missing).
        score_thresh: detections scoring below this are not drawn.
        map_thresh: BEV mask values ``>=`` this count as foreground.
    """
    os.makedirs(out_dir, exist_ok=True)
    print(f"\n--- 可视化样本 {sample_idx} ---")

    detections = _prepare_detections(sample, score_thresh)
    masks = _prepare_masks(sample, map_thresh)

    # BEV segmentation map.
    if masks is not None:
        map_classes = getattr(cfg, 'map_classes', _DEFAULT_MAP_CLASSES)
        map_out_dir = os.path.join(out_dir, "map")
        os.makedirs(map_out_dir, exist_ok=True)
        map_path = os.path.join(map_out_dir, f"sample_{sample_idx:03d}.png")
        visualize_map(map_path, masks, classes=map_classes)
        print(f"BEV分割可视化已保存: {map_path}")
        # Report stats with the same class names used for the rendered map.
        _print_mask_stats(masks, map_classes)

    # LiDAR BEV view of the detections. No point cloud is loaded here, so the
    # boxes are drawn over synthetic points.
    if detections is not None:
        bboxes, labels = detections
        lidar_out_dir = os.path.join(out_dir, "lidar")
        os.makedirs(lidar_out_dir, exist_ok=True)
        lidar_path = os.path.join(lidar_out_dir, f"sample_{sample_idx:03d}.png")
        visualize_lidar(
            lidar_path,
            _make_dummy_points(),
            bboxes=bboxes,
            labels=labels,
            xlim=[-50, 50],
            ylim=[-50, 50],
            classes=getattr(cfg, 'object_classes', _DEFAULT_OBJECT_CLASSES),
        )
        print(f"LiDAR检测可视化已保存: {lidar_path}")

    print(f"样本 {sample_idx} 可视化完成")
|
||
|
||
|
||
def main():
    """Command-line entry point: load a result pickle and visualize one sample.

    Loads the config given by ``--config``, resolves its ``${...}``
    expressions with ``recursive_eval``, loads the pickled results from
    ``--result-file`` and renders the sample chosen by ``--sample-idx`` into
    ``--out-dir``. An out-of-range index prints a message and returns early.
    """
    parser = argparse.ArgumentParser(description='可视化推理结果')
    parser.add_argument('--result-file', type=str,
                        default='/data/infer_test/20251120_124755/one_batch_results.pkl',
                        help='推理结果文件路径')
    parser.add_argument('--config', type=str,
                        default='configs/nuscenes/det/transfusion/secfpn/camera+lidar/swint_v0p075/multitask_BEV2X_phase4b_rmtppad_segmentation.yaml',
                        help='配置文件路径')
    parser.add_argument('--sample-idx', type=int, default=0,
                        help='要可视化的样本索引')
    parser.add_argument('--out-dir', type=str, default='viz_inference_results',
                        help='可视化结果保存目录')
    args = parser.parse_args()

    # Load the config and resolve its "${...}" placeholders.
    configs.load(args.config, recursive=True)
    cfg = Config(recursive_eval(configs), filename=args.config)

    results = load_inference_results(args.result_file)

    # Reject negative indices as well: Python would silently count from the
    # end, and the output filename would become e.g. "sample_-01.png".
    if not 0 <= args.sample_idx < len(results):
        print(f"样本索引 {args.sample_idx} 超出范围 (总共 {len(results)} 个样本)")
        return

    sample = results[args.sample_idx]
    visualize_sample(args.sample_idx, sample, cfg, out_dir=args.out_dir)

    print("\n✅ 可视化完成!")
    print(f"结果保存在: {args.out_dir}")
    print("\n包含的文件:")
    print("- map/sample_XXX.png: BEV分割结果")
    print("- lidar/sample_XXX.png: LiDAR检测结果")
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()
|