""" 从nuScenes提取矢量地图标注 适配BEVFusion的数据格式 使用方法: python tools/data_converter/extract_vector_map_bevfusion.py \ --dataroot data/nuscenes \ --output data/nuscenes/vector_maps_bevfusion.pkl """ import pickle import argparse import numpy as np from pathlib import Path from tqdm import tqdm from pyquaternion import Quaternion try: from nuscenes.nuscenes import NuScenes from nuscenes.map_expansion.map_api import NuScenesMap except: print("请先安装nuscenes-devkit: pip install nuscenes-devkit") exit(1) def extract_vector_maps_for_bevfusion( nusc_root='data/nuscenes', version='v1.0-trainval', output_file='data/nuscenes/vector_maps_bevfusion.pkl', x_range=[-50, 50], y_range=[-50, 50], num_points_per_vec=20, ): """ 提取矢量地图并保存 Args: nusc_root: nuScenes数据根目录 version: 数据版本 output_file: 输出文件路径 x_range: BEV x轴范围(米) y_range: BEV y轴范围(米) num_points_per_vec: 每个矢量重采样的点数 """ print(f"加载nuScenes数据集: {version}") nusc = NuScenes(version=version, dataroot=nusc_root, verbose=True) # 加载BEVFusion的info文件以获取sample列表 info_file = f'{nusc_root}/nuscenes_infos_{"train" if "train" in version else "val"}.pkl' print(f"加载BEVFusion info文件: {info_file}") with open(info_file, 'rb') as f: infos = pickle.load(f) print(f"共有 {len(infos['infos'] if isinstance(infos, dict) else infos)} 个样本") vector_maps = {} map_apis = {} # 缓存map API # 遍历所有样本 samples_list = infos['infos'] if isinstance(infos, dict) else infos for info in tqdm(samples_list, desc="提取矢量地图"): sample_token = info['token'] try: sample = nusc.get('sample', sample_token) # 获取lidar数据和ego pose lidar_token = sample['data']['LIDAR_TOP'] lidar_data = nusc.get('sample_data', lidar_token) ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token']) # 获取地图API(通过scene获取log) scene = nusc.get('scene', sample['scene_token']) log = nusc.get('log', scene['log_token']) map_name = log['location'] if map_name not in map_apis: map_apis[map_name] = NuScenesMap(dataroot=nusc_root, map_name=map_name) nusc_map = map_apis[map_name] # 提取该位置的矢量 vectors = extract_vectors_in_range( nusc_map, ego_pose, x_range=x_range, y_range=y_range, num_points=num_points_per_vec ) vector_maps[sample_token] = vectors except Exception as e: print(f"警告: 样本 {sample_token} 提取失败: {e}") vector_maps[sample_token] = {'divider': [], 'boundary': [], 'ped_crossing': []} # 保存 print(f"\n保存到: {output_file}") with open(output_file, 'wb') as f: pickle.dump(vector_maps, f) # 统计 print_statistics(vector_maps) print("✅ 矢量地图提取完成!") def extract_vectors_in_range(nusc_map, ego_pose, x_range, y_range, num_points=20): """ 在ego附近的范围内提取矢量元素 Returns: dict: { 'divider': [{'points': [[x,y], ...]}, ...], 'boundary': [...], 'ped_crossing': [...] } """ from shapely.geometry import box as ShapelyBox, LineString, Polygon # 定义查询区域(全局坐标) ego_x, ego_y = ego_pose['translation'][:2] patch_box = ( ego_x + x_range[0], ego_y + y_range[0], ego_x + x_range[1], ego_y + y_range[1], ) vectors = { 'divider': [], 'boundary': [], 'ped_crossing': [], } # 1. 提取车道分隔线和道路分隔线 for layer_name in ['lane_divider', 'road_divider']: try: records = nusc_map.get_records_in_patch( patch_box, layer_names=[layer_name], mode='intersect' ) for record_token in records: record = nusc_map.get(layer_name, record_token) line = nusc_map.extract_line(record['line_token']) if line is None or line.is_empty: continue # 转换到ego坐标系 points_global = np.array(line.coords) points_ego = transform_to_ego(points_global, ego_pose) # 过滤范围外的点 in_range = ( (points_ego[:, 0] >= x_range[0]) & (points_ego[:, 0] <= x_range[1]) & (points_ego[:, 1] >= y_range[0]) & (points_ego[:, 1] <= y_range[1]) ) if in_range.sum() < 2: continue points_filtered = points_ego[in_range] # 重采样到固定点数 points_resampled = resample_polyline(points_filtered, num_points) # 归一化到[0, 1] points_norm = normalize_points(points_resampled, x_range, y_range) vectors['divider'].append({ 'points': points_norm.tolist(), 'type': layer_name, }) except Exception as e: pass # 某些地图层可能不存在 # 2. 提取道路边界 try: for layer_name in ['road_segment']: records = nusc_map.get_records_in_patch( patch_box, layer_names=[layer_name], mode='intersect' ) for record_token in records: record = nusc_map.get(layer_name, record_token) polygon = nusc_map.extract_polygon(record['polygon_token']) if polygon is None or polygon.is_empty: continue # 提取外边界 boundary_points = np.array(polygon.exterior.coords) boundary_ego = transform_to_ego(boundary_points, ego_pose) # 简化和重采样 simplified = simplify_polyline(boundary_ego, tolerance=1.0) if len(simplified) >= 2: resampled = resample_polyline(simplified, num_points) normalized = normalize_points(resampled, x_range, y_range) vectors['boundary'].append({ 'points': normalized.tolist() }) except Exception as e: pass # 3. 提取人行横道 try: for layer_name in ['ped_crossing']: records = nusc_map.get_records_in_patch( patch_box, layer_names=[layer_name], mode='intersect' ) for record_token in records: record = nusc_map.get(layer_name, record_token) polygon = nusc_map.extract_polygon(record['polygon_token']) if polygon is None or polygon.is_empty: continue # 人行横道用多边形表示 crossing_points = np.array(polygon.exterior.coords) crossing_ego = transform_to_ego(crossing_points, ego_pose) # 重采样 resampled = resample_polyline(crossing_ego, num_points) normalized = normalize_points(resampled, x_range, y_range) vectors['ped_crossing'].append({ 'points': normalized.tolist() }) except Exception as e: pass return vectors def transform_to_ego(points_global, ego_pose): """将全局坐标转换到ego车辆坐标系""" # 平移 translation = np.array(ego_pose['translation'][:2]) points_centered = points_global - translation # 旋转(逆旋转) quat = Quaternion(ego_pose['rotation']) rot_matrix = quat.rotation_matrix[:2, :2] points_ego = points_centered @ rot_matrix.T return points_ego def normalize_points(points, x_range, y_range): """将真实坐标归一化到[0, 1]""" points_norm = points.copy() points_norm[:, 0] = (points[:, 0] - x_range[0]) / (x_range[1] - x_range[0]) points_norm[:, 1] = (points[:, 1] - y_range[0]) / (y_range[1] - y_range[0]) return points_norm def resample_polyline(points, num_points): """ 将多段线重采样到固定点数 使用线性插值在原始点之间均匀采样 """ if len(points) < 2: # 点太少,padding padded = np.zeros((num_points, 2)) padded[:len(points)] = points return padded try: from scipy.interpolate import interp1d # 计算累积距离 dists = np.sqrt(np.sum(np.diff(points, axis=0)**2, axis=1)) cum_dists = np.concatenate([[0], np.cumsum(dists)]) if cum_dists[-1] < 1e-6: # 所有点重合 return np.tile(points[0], (num_points, 1)) # 均匀采样距离 sample_dists = np.linspace(0, cum_dists[-1], num_points) # 插值 interp_x = interp1d(cum_dists, points[:, 0], kind='linear') interp_y = interp1d(cum_dists, points[:, 1], kind='linear') resampled = np.stack([ interp_x(sample_dists), interp_y(sample_dists) ], axis=-1) return resampled except: # 插值失败,返回原始点(padding或截断) if len(points) >= num_points: return points[:num_points] else: padded = np.zeros((num_points, 2)) padded[:len(points)] = points return padded def simplify_polyline(points, tolerance=1.0): """简化多段线(减少点数)""" try: from shapely.geometry import LineString if len(points) < 2: return points line = LineString(points) simplified = line.simplify(tolerance, preserve_topology=False) return np.array(simplified.coords) except: return points def print_statistics(vector_maps): """打印统计信息""" total_dividers = sum(len(v['divider']) for v in vector_maps.values()) total_boundaries = sum(len(v['boundary']) for v in vector_maps.values()) total_crossings = sum(len(v['ped_crossing']) for v in vector_maps.values()) print("\n========== 矢量地图统计 ==========") print(f"样本数: {len(vector_maps)}") print(f"总divider数: {total_dividers} (平均: {total_dividers/len(vector_maps):.1f}/样本)") print(f"总boundary数: {total_boundaries} (平均: {total_boundaries/len(vector_maps):.1f}/样本)") print(f"总crossing数: {total_crossings} (平均: {total_crossings/len(vector_maps):.1f}/样本)") print("================================") def main(): parser = argparse.ArgumentParser(description='提取nuScenes矢量地图') parser.add_argument('--dataroot', type=str, default='data/nuscenes', help='nuScenes数据根目录') parser.add_argument('--version', type=str, default='v1.0-trainval', help='数据版本') parser.add_argument('--output', type=str, default='data/nuscenes/vector_maps_bevfusion.pkl', help='输出文件路径') parser.add_argument('--x-range', type=float, nargs=2, default=[-50, 50], help='BEV x轴范围') parser.add_argument('--y-range', type=float, nargs=2, default=[-50, 50], help='BEV y轴范围') parser.add_argument('--num-points', type=int, default=20, help='每个矢量的点数') args = parser.parse_args() extract_vector_maps_for_bevfusion( nusc_root=args.dataroot, version=args.version, output_file=args.output, x_range=args.x_range, y_range=args.y_range, num_points_per_vec=args.num_points, ) if __name__ == '__main__': main()