# bev-project/mmdet3d/datasets/pipelines/loading_vector.py
# (original listing metadata: 128 lines, 4.2 KiB, Python)
"""
矢量地图数据加载Pipeline
"""
import pickle
import warnings

import numpy as np
from mmdet.datasets import PIPELINES
@PIPELINES.register_module()
class LoadVectorMap:
    """Load vector-map (polyline) annotations for a sample.

    Annotations come from a preprocessed pickle file mapping a sample
    identifier to a dict of per-class vector lists, presumably of the form
    ``{'divider': [{'points': ...}, ...], 'boundary': [...],
    'ped_crossing': [...]}`` (inferred from how ``__call__`` reads it --
    confirm against the preprocessing script).

    Args:
        vector_map_file (str): Path to the preprocessed vector-map pickle.
        num_vec_classes (int): Number of vector classes. Stored for
            reference / ``__repr__`` only; the classes actually emitted are
            those in ``self.class_map``.
        max_num_vectors (int): Fixed number of vector slots per sample.
            Extra vectors are truncated, missing slots are padded.
        num_points_per_vec (int): Number of points a vector must have to be
            kept.
    """

    def __init__(
        self,
        vector_map_file='data/nuscenes/vector_maps_bevfusion.pkl',
        num_vec_classes=3,
        max_num_vectors=50,
        num_points_per_vec=20,
    ):
        self.num_vec_classes = num_vec_classes
        self.max_num_vectors = max_num_vectors
        self.num_points_per_vec = num_points_per_vec

        # Load the whole vector-map table once, up front.
        print(f"加载矢量地图数据: {vector_map_file}")
        # SECURITY: pickle.load can execute arbitrary code -- only load files
        # produced by a trusted preprocessing step.
        with open(vector_map_file, 'rb') as f:
            self.vector_maps = pickle.load(f)
        print(f"✅ 加载了 {len(self.vector_maps)} 个样本的矢量地图")

        # Raw-annotation key -> class id. Insertion order is also the order
        # in which classes are emitted, which decides what survives
        # truncation.
        self.class_map = {
            'divider': 0,
            'boundary': 1,
            'ped_crossing': 2,
        }

    def _collect_vectors(self, vectors_raw):
        """Return ``(class_id, points)`` pairs for every well-formed vector.

        Classes are visited in ``self.class_map`` order. A vector is kept only
        when its points form an array of shape ``(num_points_per_vec, 2)`` --
        the same layout used for padding -- so that all slots stack into one
        array. Other vectors are skipped, not resampled.
        """
        expected_shape = (self.num_points_per_vec, 2)
        collected = []
        for class_name, class_id in self.class_map.items():
            # `or []` also tolerates a class key stored with a None value.
            for vec in vectors_raw.get(class_name) or []:
                points = np.asarray(vec['points'], dtype=np.float32)
                if points.shape == expected_shape:
                    collected.append((class_id, points))
        return collected

    def __call__(self, results):
        """Attach fixed-size vector-map targets to ``results``.

        Args:
            results (dict): Pipeline dict. The sample identifier is read from
                ``results['token']``, falling back to
                ``results['sample_idx']``.

        Returns:
            dict: ``results`` with two added keys:

            * ``'gt_vectors_labels'``: int64 array of shape
              ``(max_num_vectors,)``; padding slots hold ``-1``.
            * ``'gt_vectors_points'``: float32 array of shape
              ``(max_num_vectors, num_points_per_vec, 2)``; padding slots
              are all zeros.
        """
        sample_token = results.get('token', results.get('sample_idx', None))
        if sample_token is None:
            # warnings de-duplicates, unlike a print on every sample.
            warnings.warn("警告: 无法获取sample_token使用空矢量地图")
            vectors_raw = {}
        else:
            # Samples absent from the map file get no vectors (all padding).
            vectors_raw = self.vector_maps.get(sample_token) or {}

        # Sequential truncation: vectors are kept in class_map order, so the
        # later classes (e.g. ped_crossing) are dropped first when over budget.
        vectors = self._collect_vectors(vectors_raw)[:self.max_num_vectors]

        # Preallocate padded outputs so the shape is fixed even when
        # max_num_vectors == 0 or no vector survives filtering.
        gt_labels = np.full(self.max_num_vectors, -1, dtype=np.int64)
        gt_points = np.zeros(
            (self.max_num_vectors, self.num_points_per_vec, 2),
            dtype=np.float32,
        )
        for slot, (class_id, points) in enumerate(vectors):
            gt_labels[slot] = class_id
            gt_points[slot] = points

        results['gt_vectors_labels'] = gt_labels
        results['gt_vectors_points'] = gt_points
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(num_vec_classes={self.num_vec_classes}, '
        repr_str += f'max_num_vectors={self.max_num_vectors}, '
        repr_str += f'num_points_per_vec={self.num_points_per_vec})'
        return repr_str