bev-project/mmdet3d/datasets/pipelines/loading_vector.py

128 lines
4.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
矢量地图数据加载Pipeline
"""
import pickle
import numpy as np
from mmdet.datasets import PIPELINES
@PIPELINES.register_module()
class LoadVectorMap:
    """Load vector-map annotations for a sample from a preprocessed pkl file.

    The pkl file is loaded once at construction. Judging from how
    ``__call__`` indexes it, it is a mapping from a sample identifier to a
    dict of per-class lists of polylines, each polyline being a dict with a
    ``'points'`` entry. NOTE(review): this schema is inferred from usage;
    confirm against the preprocessing script.

    Args:
        vector_map_file (str): Path to the pickled vector-map data.
        num_vec_classes (int): Number of vector classes. Stored and shown in
            ``__repr__``; not otherwise used by this transform.
        max_num_vectors (int): Fixed number of vector slots per sample.
            Surplus vectors are truncated; empty slots are padded.
        num_points_per_vec (int): Required number of points per vector.
            Vectors with any other point count are skipped.
    """

    def __init__(
        self,
        vector_map_file='data/nuscenes/vector_maps_bevfusion.pkl',
        num_vec_classes=3,
        max_num_vectors=50,
        num_points_per_vec=20,
    ):
        self.num_vec_classes = num_vec_classes
        self.max_num_vectors = max_num_vectors
        self.num_points_per_vec = num_points_per_vec

        # Load the whole vector-map table once; it is looked up per sample.
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file. Only point this at pkl files produced by a trusted
        # preprocessing step.
        print(f"加载矢量地图数据: {vector_map_file}")
        with open(vector_map_file, 'rb') as f:
            self.vector_maps = pickle.load(f)
        print(f"✅ 加载了 {len(self.vector_maps)} 个样本的矢量地图")

        # Raw-data class key -> training label. Insertion order also fixes
        # the order in which classes are emitted in __call__, and therefore
        # which classes are dropped first when truncating.
        self.class_map = {
            'divider': 0,
            'boundary': 1,
            'ped_crossing': 2,
        }

    def _iter_valid_vectors(self, vectors_raw):
        """Yield ``(label, points)`` for each usable vector, in class order.

        Classes are visited in ``self.class_map`` order. Within a class, the
        original list order is kept. A vector is usable only if it has
        exactly ``num_points_per_vec`` points; other vectors are skipped.
        """
        for class_name, label in self.class_map.items():
            for vec in vectors_raw.get(class_name, []):
                points = np.array(vec['points'], dtype=np.float32)
                if len(points) == self.num_points_per_vec:
                    yield label, points

    def __call__(self, results):
        """Attach fixed-size vector-map targets to ``results``.

        Args:
            results (dict): Pipeline dict. The sample is looked up by
                ``results['token']``, or by ``results['sample_idx']`` when
                the ``'token'`` key is absent. If neither identifier is
                available, a warning is printed and the sample gets only
                padding. An identifier missing from the loaded table also
                yields only padding, without a warning.

        Returns:
            dict: ``results`` with two keys added:
                - ``'gt_vectors_labels'``: int64 array of shape
                  ``(max_num_vectors,)``; -1 marks a padding slot.
                - ``'gt_vectors_points'``: float32 array of shape
                  ``(max_num_vectors, num_points_per_vec, 2)``; padding
                  slots are all zeros.
        """
        sample_token = results.get('token', results.get('sample_idx', None))
        if sample_token is None:
            print("警告: 无法获取sample_token使用空矢量地图")
            vectors_raw = {}
        else:
            vectors_raw = self.vector_maps.get(sample_token, {})

        # Preallocate padded outputs. Labels default to -1 (invalid / padding)
        # and points to zeros. This keeps the output shapes fixed even when no
        # slots get filled (e.g. max_num_vectors == 0).
        gt_labels = np.full(self.max_num_vectors, -1, dtype=np.int64)
        gt_points = np.zeros(
            (self.max_num_vectors, self.num_points_per_vec, 2),
            dtype=np.float32,
        )

        # Fill slots in emission order. zip() stops once every slot is used,
        # so truncation is sequential (not random): vectors from classes that
        # come later in class_map are the first to be dropped.
        # Points are assumed to be 2-D (x, y), matching the padding shape; any
        # other coordinate width raises a broadcasting ValueError here.
        vectors = self._iter_valid_vectors(vectors_raw)
        for slot, (label, points) in zip(range(self.max_num_vectors), vectors):
            gt_labels[slot] = label
            gt_points[slot] = points

        results['gt_vectors_labels'] = gt_labels
        results['gt_vectors_points'] = gt_points
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(num_vec_classes={self.num_vec_classes}, '
        repr_str += f'max_num_vectors={self.max_num_vectors}, '
        repr_str += f'num_points_per_vec={self.num_points_per_vec})'
        return repr_str