bev-project/EVAL_FROM_PKL.py

99 lines
3.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
"""
Evaluate from an existing pkl results file, skipping the inference step.

Loads predictions pickled by an earlier run and passes them to the nuScenes
validation dataset's evaluate() method (see main()).
"""
import os
import pickle
import sys
import torch
# Environment setup.
# NOTE: the dynamic loader reads LD_LIBRARY_PATH once when the process starts
# (and `torch` is already imported above), so editing it here only affects
# child processes. PYTHONPATH is likewise only read at interpreter start-up,
# so the bevfusion root is also inserted into sys.path directly, presumably so
# that `mmdet3d` resolves to the bevfusion checkout in *this* process.
BEVFUSION_ROOT = '/workspace/bevfusion'


def prepend_env_path(name, prefix):
    """Prepend *prefix* to the path-list environment variable *name*.

    When the variable is unset or empty, no trailing separator is added. An
    empty entry in PATH (and, with glibc, LD_LIBRARY_PATH) means "the current
    directory", which would silently widen the search path.
    """
    current = os.environ.get(name, '')
    os.environ[name] = f'{prefix}{os.pathsep}{current}' if current else prefix


prepend_env_path('PATH', '/opt/conda/bin')
prepend_env_path(
    'LD_LIBRARY_PATH',
    '/opt/conda/lib/python3.8/site-packages/torch/lib:/opt/conda/lib:/usr/local/cuda/lib64',
)
prepend_env_path('PYTHONPATH', BEVFUSION_ROOT)
if BEVFUSION_ROOT not in sys.path:
    sys.path.insert(0, BEVFUSION_ROOT)
def main(
    config_file=(
        'configs/nuscenes/det/transfusion/secfpn/camera+lidar/swint_v0p075/'
        'multitask_BEV2X_phase4b_rmtppad_segmentation.yaml'
    ),
    results_file='/data/eval_fast/epoch1_fast_20251119_133104/fast_results.pkl',
):
    """Evaluate a previously saved prediction pickle without re-running inference.

    Builds the nuScenes validation dataset from ``config_file``, loads the
    results list pickled in ``results_file``, prints a short summary of the
    first sample and calls ``dataset.evaluate`` requesting the ``bbox`` and
    ``map`` metrics.

    Args:
        config_file: Path of the model config. The default is relative, so
            the script is expected to run from the repository root.
        results_file: Pickle holding the per-sample results list produced by
            an earlier inference run. Only load files you trust: unpickling
            can execute arbitrary code.

    Returns:
        0 when evaluation completed. Returns 1 when the results file is empty
        or ``dataset.evaluate`` raised; in that case the traceback is printed.
    """
    import traceback

    from mmcv import Config
    from mmdet3d.datasets import build_dataset

    # Load the config.
    # NOTE(review): the upstream bevfusion tools load these YAML configs with a
    # recursive loader that merges parent default.yaml files before wrapping
    # them in mmcv.Config. Config.fromfile reads only this one file, so
    # confirm the keys used below (dataset_root, evaluation.pipeline,
    # object_classes, map_classes, data.val) are really present.
    cfg = Config.fromfile(config_file)

    print("正在创建数据集...")
    # Build the full val dataset config by hand (modelled on nuscenes/default.yaml).
    val_config = {
        'type': 'NuScenesDataset',
        'dataset_root': cfg.dataset_root,
        # os.path.join works whether or not dataset_root ends with a slash.
        'ann_file': os.path.join(cfg.dataset_root, 'nuscenes_infos_val.pkl'),
        'pipeline': cfg.evaluation.pipeline,
        'object_classes': cfg.object_classes,
        'map_classes': cfg.map_classes,
        'modality': {
            'use_lidar': True,
            'use_camera': True,
            'use_radar': False,
            'use_map': False,
            'use_external': False,
        },
        'test_mode': True,
        'use_valid_flag': False,
        'box_type_3d': 'LiDAR',
        'load_interval': getattr(cfg.data.val, 'load_interval', 1),
    }
    dataset = build_dataset(val_config)
    print(f"数据集大小: {len(dataset)}")

    # Load the existing results (trusted, locally produced pickle).
    print(f"正在加载结果文件: {results_file}")
    with open(results_file, 'rb') as f:
        results = pickle.load(f)
    print(f"结果文件包含 {len(results)} 个样本")
    if not results:
        print("结果文件为空, 无法评估")
        return 1
    if len(results) != len(dataset):
        # Presumably evaluation needs one result per dataset sample; warn early
        # so a subset/partial results file is easy to diagnose.
        print(f"警告: 结果数量 ({len(results)}) 与数据集大小 ({len(dataset)}) 不一致")

    # Inspect the format of the first result.
    sample = results[0]
    if isinstance(sample, dict):
        print(f"样本键: {list(sample.keys())}")
        if 'boxes_3d' in sample:
            print(f"检测框数量: {len(sample['boxes_3d'])}")
        if 'masks_bev' in sample:
            print(f"BEV分割mask形状: {sample['masks_bev'].shape}")
    else:
        print(f"样本类型: {type(sample).__name__}")

    # Run the evaluation.
    print("开始评估...")
    eval_kwargs = dict(
        metric=['bbox', 'map'],
        save_best=None,
        rule=None,
        logger=None,
    )
    try:
        eval_results = dataset.evaluate(results, **eval_kwargs)
    except Exception as e:
        # Top-level boundary of the script: report and signal failure.
        print(f"评估失败: {e}")
        traceback.print_exc()
        return 1

    print("\n" + "=" * 50)
    print("评估结果:")
    print("=" * 50)
    for key, value in eval_results.items():
        if isinstance(value, float):
            print(f"{key}: {value:.4f}")
        else:
            print(f"{key}: {value}")
    return 0
if __name__ == '__main__':
    # Use main()'s return value as the process exit status (None -> 0), so
    # shell scripts and CI can detect a failed evaluation.
    sys.exit(main())