"""Vector map data loading pipeline."""
import pickle

import numpy as np
from mmdet.datasets import PIPELINES


@PIPELINES.register_module()
class LoadVectorMap:
    """Load vector map annotations from a preprocessed pickle file.

    The pickle file is read once at construction time. It is accessed as a
    mapping from a sample identifier to a dict whose keys are the class
    names in ``class_map`` and whose values are lists of dicts, each with a
    ``'points'`` entry.

    Args:
        vector_map_file: Path to the pickled vector map data.
        num_vec_classes: Number of vector classes (stored and reported by
            ``__repr__``; the class-name mapping itself is fixed).
        max_num_vectors: Fixed number of vectors emitted per sample; extra
            vectors are dropped and missing ones are padded.
        num_points_per_vec: Number of points each vector must have.
    """

    def __init__(
        self,
        vector_map_file='data/nuscenes/vector_maps_bevfusion.pkl',
        num_vec_classes=3,
        max_num_vectors=50,
        num_points_per_vec=20,
    ):
        self.num_vec_classes = num_vec_classes
        self.max_num_vectors = max_num_vectors
        self.num_points_per_vec = num_points_per_vec

        # Load the vector map data.
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # `vector_map_file` at files produced by a trusted preprocessing step.
        print(f"加载矢量地图数据: {vector_map_file}")
        with open(vector_map_file, 'rb') as f:
            self.vector_maps = pickle.load(f)
        print(f"✅ 加载了 {len(self.vector_maps)} 个样本的矢量地图")

        # Class name -> label id. Insertion order also fixes the order in
        # which classes are emitted (and therefore which survive truncation).
        self.class_map = {
            'divider': 0,
            'boundary': 1,
            'ped_crossing': 2,
        }

    def __call__(self, results):
        """Attach the vector map targets of the current sample to ``results``.

        Args:
            results: dict carrying the sample identifier under ``'token'``
                or, failing that, ``'sample_idx'``.

        Returns:
            The same ``results`` dict with two entries added:
            ``'gt_vectors_labels'``, an int64 array of shape
            ``(max_num_vectors,)`` where ``-1`` marks padding, and
            ``'gt_vectors_points'``, a float32 array of shape
            ``(max_num_vectors, num_points_per_vec, 2)``.
        """
        # Fall back to 'sample_idx' also when 'token' is present but None.
        sample_token = results.get('token')
        if sample_token is None:
            sample_token = results.get('sample_idx')

        if sample_token is None:
            print("警告: 无法获取sample_token,使用空矢量地图")
            vectors_raw = {}
        else:
            # Unknown samples (or a stored None) are treated as an empty map.
            vectors_raw = self.vector_maps.get(sample_token) or {}

        gt_labels, gt_points = self._build_targets(vectors_raw)
        results['gt_vectors_labels'] = gt_labels
        results['gt_vectors_points'] = gt_points
        return results

    def _build_targets(self, vectors_raw):
        """Convert one sample's raw vectors into fixed-size label/point arrays."""
        point_shape = (self.num_points_per_vec, 2)

        # Preallocate in the padded state: label -1 (invalid / padding class)
        # and all-zero points. Valid vectors overwrite the leading rows, so
        # the output shape is correct even when no vector is kept.
        gt_labels = np.full(self.max_num_vectors, -1, dtype=np.int64)
        gt_points = np.zeros(
            (self.max_num_vectors,) + point_shape, dtype=np.float32)

        filled = 0
        for class_name, class_id in self.class_map.items():
            for vec in vectors_raw.get(class_name, []):
                # Truncate in order: once every slot is used, later vectors
                # are dropped.
                if filled >= self.max_num_vectors:
                    return gt_labels, gt_points

                points = np.array(vec['points'], dtype=np.float32)
                # Keep only vectors with exactly the expected (P, 2) layout;
                # anything else could not be stacked with the padding rows.
                if points.shape != point_shape:
                    continue

                gt_labels[filled] = class_id
                gt_points[filled] = points
                filled += 1

        return gt_labels, gt_points

    def __repr__(self):
        return (
            f'{self.__class__.__name__}('
            f'num_vec_classes={self.num_vec_classes}, '
            f'max_num_vectors={self.max_num_vectors}, '
            f'num_points_per_vec={self.num_points_per_vec})'
        )