128 lines
4.2 KiB
Python
128 lines
4.2 KiB
Python
|
|
"""
矢量地图数据加载Pipeline (vector-map data loading pipeline).
"""
|
|||
|
|
|
|||
|
|
import pickle
|
|||
|
|
import numpy as np
|
|||
|
|
from mmdet.datasets import PIPELINES
|
|||
|
|
|
|||
|
|
|
|||
|
|
@PIPELINES.register_module()
|
|||
|
|
class LoadVectorMap:
    """Load vector-map annotations for each sample from a pre-processed pickle.

    The pickle is read once at construction time. ``__call__`` then looks up
    the current sample and converts its polylines into fixed-size arrays.

    Args:
        vector_map_file (str): Path to the pickled vector-map file. It is
            read with ``.get(token)``, and each entry is read as
            ``{class_name: [vec, ...]}`` where every ``vec`` has a
            ``'points'`` item (schema inferred from how ``__call__`` reads
            it — confirm against the preprocessing script).
        num_vec_classes (int): Number of vector classes. It is only stored
            and shown by ``__repr__``; it does not affect filtering.
        max_num_vectors (int): Number of vector slots per sample after
            truncation and padding.
        num_points_per_vec (int): Required number of points per vector.
            Vectors with any other point count are skipped (not resampled).
    """

    def __init__(
        self,
        vector_map_file='data/nuscenes/vector_maps_bevfusion.pkl',
        num_vec_classes=3,
        max_num_vectors=50,
        num_points_per_vec=20,
    ):
        self.num_vec_classes = num_vec_classes
        self.max_num_vectors = max_num_vectors
        self.num_points_per_vec = num_points_per_vec

        # Load the whole vector-map file once; __call__ only does lookups.
        print(f"加载矢量地图数据: {vector_map_file}")
        # SECURITY NOTE: pickle.load can execute arbitrary code. Only point
        # this at files you generated yourself or otherwise trust.
        with open(vector_map_file, 'rb') as f:
            self.vector_maps = pickle.load(f)

        print(f"✅ 加载了 {len(self.vector_maps)} 个样本的矢量地图")

        # Class name -> label id. Its insertion order also fixes the order in
        # which classes are emitted (and therefore which vectors survive
        # truncation).
        self.class_map = {
            'divider': 0,
            'boundary': 1,
            'ped_crossing': 2,
        }

    def _collect_vectors(self, vectors_raw):
        """Return ``(label, points)`` pairs for vectors with the expected point count.

        Classes are visited in ``self.class_map`` order. Within a class, the
        order from the file is kept. A class entry that is missing or ``None``
        contributes nothing.
        """
        collected = []
        for class_name, label in self.class_map.items():
            for vec in vectors_raw.get(class_name) or []:
                points = np.array(vec['points'], dtype=np.float32)
                # Vectors with a different number of points are dropped.
                if len(points) == self.num_points_per_vec:
                    collected.append((label, points))
        return collected

    def __call__(self, results):
        """Attach the current sample's vector-map annotations to ``results``.

        Args:
            results (dict): Pipeline results. The sample key is read from
                ``'token'``. If that key is absent or ``None``, it falls back
                to ``'sample_idx'``. Samples without a key, or with no entry
                in the map, get an all-padding result.

        Returns:
            dict: ``results`` with two added entries:

            - ``'gt_vectors_labels'``: int64 array of shape
              ``(max_num_vectors,)``. Padded slots hold ``-1``.
            - ``'gt_vectors_points'``: float32 array of shape
              ``(max_num_vectors, num_points_per_vec, 2)``. Padded slots are
              zeros.

        Raises:
            ValueError: if a kept vector's points are not 2-D ``(x, y)``
                coordinates (they cannot be written into the padded array).
        """
        sample_token = results.get('token')
        if sample_token is None:
            sample_token = results.get('sample_idx')

        if sample_token is None:
            print("警告: 无法获取sample_token,使用空矢量地图")
            vectors_raw = {}
        else:
            # Unknown samples, or samples stored as None, have no vectors.
            vectors_raw = self.vector_maps.get(sample_token) or {}

        # Deterministic truncation: keep the first max_num_vectors in class
        # order (no random sampling).
        gt_vectors = self._collect_vectors(vectors_raw)[:self.max_num_vectors]

        # Preallocate the padded outputs so the shapes are fixed even when
        # there are no vectors or max_num_vectors == 0.
        gt_labels = np.full(self.max_num_vectors, -1, dtype=np.int64)
        gt_points = np.zeros(
            (self.max_num_vectors, self.num_points_per_vec, 2),
            dtype=np.float32,
        )
        for slot, (label, points) in enumerate(gt_vectors):
            gt_labels[slot] = label
            gt_points[slot] = points

        results['gt_vectors_labels'] = gt_labels
        results['gt_vectors_points'] = gt_points
        return results

    def __repr__(self):
        return (
            f'{self.__class__.__name__}'
            f'(num_vec_classes={self.num_vec_classes}, '
            f'max_num_vectors={self.max_num_vectors}, '
            f'num_points_per_vec={self.num_points_per_vec})'
        )
|
|||
|
|
|