"""
从nuScenes提取矢量地图标注
适配BEVFusion的数据格式

Extract vector-map annotations (dividers, road boundaries, pedestrian
crossings) from nuScenes in a format usable by BEVFusion.

使用方法:
    python tools/data_converter/extract_vector_map_bevfusion.py \
        --dataroot data/nuscenes \
        --output data/nuscenes/vector_maps_bevfusion.pkl
"""
|
||
|
||
import pickle
|
||
import argparse
|
||
import numpy as np
|
||
from pathlib import Path
|
||
from tqdm import tqdm
|
||
from pyquaternion import Quaternion
|
||
|
||
try:
    from nuscenes.nuscenes import NuScenes
    from nuscenes.map_expansion.map_api import NuScenesMap
except ImportError as err:
    # Only a missing/broken devkit should trigger this hint; a bare ``except:``
    # would also swallow KeyboardInterrupt and unrelated errors.
    print("请先安装nuscenes-devkit: pip install nuscenes-devkit")
    # Show the real cause (e.g. a missing sub-dependency of the devkit).
    print(f"  ImportError: {err}")
    # ``exit`` is injected by the ``site`` module for interactive use and is
    # not guaranteed to exist; SystemExit is always available.
    raise SystemExit(1) from err


def extract_vector_maps_for_bevfusion(
    nusc_root='data/nuscenes',
    version='v1.0-trainval',
    output_file='data/nuscenes/vector_maps_bevfusion.pkl',
    x_range=(-50, 50),
    y_range=(-50, 50),
    num_points_per_vec=20,
    info_splits=None,
):
    """Extract vector-map annotations for every sample and pickle them.

    The sample list is read from the BEVFusion info files
    ``{nusc_root}/nuscenes_infos_{split}.pkl``. For each sample the ego pose
    of its LIDAR_TOP sample_data is looked up and the surrounding map
    elements are extracted with ``extract_vectors_in_range``. A sample whose
    extraction fails is reported and stored with empty vector lists.

    Args:
        nusc_root: nuScenes data root directory.
        version: nuScenes version string, e.g. ``'v1.0-trainval'``.
        output_file: path of the output pickle; missing parent directories
            are created.
        x_range: BEV x-axis range ``(min, max)`` in metres.
        y_range: BEV y-axis range ``(min, max)`` in metres.
        num_points_per_vec: number of points each vector is resampled to.
        info_splits: info-file splits to process. ``None`` selects
            ``('test',)`` for test versions and ``('train', 'val')``
            otherwise.

    Returns:
        None. Writes ``{sample_token: vectors}`` to ``output_file``, where
        ``vectors`` has the keys ``'divider'``, ``'boundary'`` and
        ``'ped_crossing'``.

    Raises:
        FileNotFoundError: if none of the requested info files exist.
    """
    if info_splits is None:
        info_splits = _default_info_splits(version)
    # Load the (cheap) info files first so a wrong path fails before the
    # (slow) NuScenes database load.
    samples_list = _load_sample_infos(nusc_root, info_splits)
    print(f"共有 {len(samples_list)} 个样本")

    print(f"加载nuScenes数据集: {version}")
    nusc = NuScenes(version=version, dataroot=nusc_root, verbose=True)

    vector_maps = {}
    map_apis = {}  # location name -> NuScenesMap, built once per location

    for info in tqdm(samples_list, desc="提取矢量地图"):
        sample_token = info['token']
        try:
            sample = nusc.get('sample', sample_token)

            # Ego pose at the time of the LIDAR_TOP sweep.
            lidar_data = nusc.get('sample_data', sample['data']['LIDAR_TOP'])
            ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])

            # The map location is stored on the scene's log record.
            scene = nusc.get('scene', sample['scene_token'])
            map_name = nusc.get('log', scene['log_token'])['location']
            if map_name not in map_apis:
                map_apis[map_name] = NuScenesMap(dataroot=nusc_root, map_name=map_name)

            vector_maps[sample_token] = extract_vectors_in_range(
                map_apis[map_name],
                ego_pose,
                x_range=x_range,
                y_range=y_range,
                num_points=num_points_per_vec,
            )
        except Exception as e:
            # Top-level per-sample boundary: report and keep going.
            print(f"警告: 样本 {sample_token} 提取失败: {e}")
            vector_maps[sample_token] = {'divider': [], 'boundary': [], 'ped_crossing': []}

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"\n保存到: {output_file}")
    with open(output_path, 'wb') as f:
        pickle.dump(vector_maps, f)

    print_statistics(vector_maps)

    print("✅ 矢量地图提取完成!")


def _default_info_splits(version):
    """Return the info-file splits that belong to a nuScenes ``version``."""
    # The original ``"train" in version`` test matched 'v1.0-trainval' and
    # silently skipped the val split. The mmdet3d-style converter used by
    # BEVFusion presumably writes *_train.pkl and *_val.pkl for trainval/mini
    # and *_test.pkl for test -- pass ``info_splits`` if yours differ.
    return ('test',) if 'test' in version else ('train', 'val')


def _load_sample_infos(nusc_root, splits):
    """Load and concatenate the sample infos of ``splits``.

    Missing split files are skipped with a warning; raises FileNotFoundError
    when none of them exist.
    """
    samples = []
    found_any = False
    for split in splits:
        info_file = Path(nusc_root) / f'nuscenes_infos_{split}.pkl'
        if not info_file.is_file():
            print(f"警告: 未找到info文件 {info_file}, 跳过")
            continue
        found_any = True
        print(f"加载BEVFusion info文件: {info_file}")
        # Info files are generated locally; never unpickle untrusted files.
        with open(info_file, 'rb') as f:
            infos = pickle.load(f)
        # Info files are either {'infos': [...], ...} or a bare list.
        samples.extend(infos['infos'] if isinstance(infos, dict) else infos)
    if not found_any:
        raise FileNotFoundError(
            f"未找到任何BEVFusion info文件 (splits={list(splits)}, 目录={nusc_root})"
        )
    return samples


def extract_vectors_in_range(nusc_map, ego_pose, x_range, y_range, num_points=20):
    """Extract the map vectors around the ego vehicle in normalized BEV coordinates.

    Records are queried from ``nusc_map`` around the ego position, moved into
    the ego frame with ``transform_to_ego``, clipped to the BEV rectangle
    ``x_range`` x ``y_range``, resampled to ``num_points`` points and
    normalized to [0, 1] with ``normalize_points``.

    * ``divider``: ``lane_divider`` and ``road_divider`` lines.
    * ``boundary``: outer ring and holes of the union of all ``road_segment``
      polygons in the query patch, simplified with ``tolerance=1.0``.
    * ``ped_crossing``: outline of each ``ped_crossing`` polygon after it is
      clipped to the BEV rectangle.

    A record that leaves and re-enters the BEV rectangle yields one entry per
    clipped piece. If a map layer fails, a ``RuntimeWarning`` is issued and
    the remaining layers are still returned.

    Args:
        nusc_map: ``NuScenesMap`` of the sample's location.
        ego_pose: nuScenes ``ego_pose`` record with ``translation`` and
            ``rotation``.
        x_range: BEV x extent ``(min, max)`` in metres, ego frame.
        y_range: BEV y extent ``(min, max)`` in metres, ego frame.
        num_points: number of points each vector is resampled to.

    Returns:
        dict: {
            'divider': [{'points': [[x, y], ...], 'type': layer_name}, ...],
            'boundary': [{'points': [[x, y], ...]}, ...],
            'ped_crossing': [{'points': [[x, y], ...]}, ...],
        }
    """
    from shapely.geometry import box as ShapelyBox

    ego_x, ego_y = ego_pose['translation'][:2]
    # The BEV rectangle is axis-aligned in the ego frame, so once rotated
    # into the global frame it can stick out of an equally sized global box.
    # Query a global square that contains it for every heading; the exact
    # clipping happens afterwards in the ego frame.
    radius = max(np.hypot(x, y) for x in x_range for y in y_range)
    patch_box = (ego_x - radius, ego_y - radius, ego_x + radius, ego_y + radius)
    bev_patch = ShapelyBox(x_range[0], y_range[0], x_range[1], y_range[1])

    def finalize(points):
        # Fixed point count, then [0, 1] coordinates, as plain lists.
        resampled = resample_polyline(points, num_points)
        return normalize_points(resampled, x_range, y_range).tolist()

    vectors = {
        'divider': [],
        'boundary': [],
        'ped_crossing': [],
    }

    # 1. Lane dividers and road dividers.
    for layer_name in ('lane_divider', 'road_divider'):
        pieces = _extract_layer_safely(
            layer_name, _divider_pieces,
            nusc_map, layer_name, patch_box, bev_patch, ego_pose)
        for piece in pieces:
            vectors['divider'].append({
                'points': finalize(piece),
                'type': layer_name,
            })

    # 2. Road boundaries.
    for piece in _extract_layer_safely(
            'road_segment', _boundary_pieces,
            nusc_map, patch_box, bev_patch, ego_pose):
        vectors['boundary'].append({'points': finalize(piece)})

    # 3. Pedestrian crossings.
    for shell in _extract_layer_safely(
            'ped_crossing', _ped_crossing_shells,
            nusc_map, patch_box, bev_patch, ego_pose):
        vectors['ped_crossing'].append({'points': finalize(shell)})

    return vectors


def _extract_layer_safely(layer_name, extract, *args):
    """Call ``extract(*args)``; on failure issue a warning and return ``[]``.

    Keeps one broken map layer from discarding the other layers of a sample
    without silently hiding the error (the old ``except: pass`` did).
    """
    import warnings

    try:
        return extract(*args)
    except Exception as err:
        warnings.warn(f"地图层 {layer_name} 提取失败: {err!r}", RuntimeWarning, stacklevel=2)
        return []


def _records_in_patch(nusc_map, layer_name, patch_box):
    """Return the records of ``layer_name`` that intersect ``patch_box``."""
    # get_records_in_patch returns {layer_name: [token, ...]}; iterating the
    # dict itself yields layer names, not record tokens.
    found = nusc_map.get_records_in_patch(
        patch_box, layer_names=[layer_name], mode='intersect')
    return [nusc_map.get(layer_name, token) for token in found.get(layer_name, [])]


def _coords_to_ego(coords, ego_pose):
    """Transform global (x, y[, z]) coordinates to an (N, 2) ego-frame array."""
    return transform_to_ego(np.asarray(coords, dtype=float)[:, :2], ego_pose)


def _geometry_parts(geom):
    """Return the single-part geometries contained in ``geom`` (``[]`` if empty)."""
    if geom.is_empty:
        return []
    return list(geom.geoms) if hasattr(geom, 'geoms') else [geom]


def _clip_polyline_to_bev(points_ego, bev_patch, merge=False):
    """Clip an ego-frame polyline to ``bev_patch``; return each piece as an (N, 2) array.

    With ``merge=True`` touching pieces are joined first. This is used for
    closed rings, whose start vertex would otherwise split one boundary in two.
    """
    from shapely.geometry import LineString
    from shapely.ops import linemerge

    if len(points_ego) < 2:
        return []
    clipped = LineString(points_ego).intersection(bev_patch)
    if merge and clipped.geom_type == 'MultiLineString':
        clipped = linemerge(clipped)
    # Intersections may also contain isolated Points (touching the border).
    return [
        np.asarray(part.coords)[:, :2]
        for part in _geometry_parts(clipped)
        if part.geom_type == 'LineString' and part.length > 0
    ]


def _divider_pieces(nusc_map, layer_name, patch_box, bev_patch, ego_pose):
    """Ego-frame, BEV-clipped pieces of every line of a divider layer."""
    pieces = []
    for record in _records_in_patch(nusc_map, layer_name, patch_box):
        line = nusc_map.extract_line(record['line_token'])
        if line is None or line.is_empty:
            continue
        pieces.extend(_clip_polyline_to_bev(_coords_to_ego(line.coords, ego_pose), bev_patch))
    return pieces


def _boundary_pieces(nusc_map, patch_box, bev_patch, ego_pose):
    """Ego-frame, BEV-clipped, simplified outline pieces of the road surface."""
    from shapely.ops import unary_union

    polygons = []
    for record in _records_in_patch(nusc_map, 'road_segment', patch_box):
        polygon = nusc_map.extract_polygon(record['polygon_token'])
        if polygon is None or polygon.is_empty:
            continue
        # buffer(0) repairs self-intersections that would break the union.
        polygons.append(polygon if polygon.is_valid else polygon.buffer(0))
    if not polygons:
        return []

    pieces = []
    # Union first: edges shared by adjacent road segments are seams, not
    # physical road boundaries, and disappear in the union.
    for road in _geometry_parts(unary_union(polygons)):
        if road.geom_type != 'Polygon':
            continue
        for ring in (road.exterior, *road.interiors):
            ring_ego = _coords_to_ego(ring.coords, ego_pose)
            for piece in _clip_polyline_to_bev(ring_ego, bev_patch, merge=True):
                simplified = simplify_polyline(piece, tolerance=1.0)
                if len(simplified) >= 2:
                    pieces.append(simplified)
    return pieces


def _ped_crossing_shells(nusc_map, patch_box, bev_patch, ego_pose):
    """Ego-frame outlines of the crossings after clipping them to the BEV rectangle."""
    from shapely.geometry import Polygon

    shells = []
    for record in _records_in_patch(nusc_map, 'ped_crossing', patch_box):
        polygon = nusc_map.extract_polygon(record['polygon_token'])
        if polygon is None or polygon.is_empty:
            continue
        crossing = Polygon(_coords_to_ego(polygon.exterior.coords, ego_pose))
        if not crossing.is_valid:
            crossing = crossing.buffer(0)
        for part in _geometry_parts(crossing.intersection(bev_patch)):
            if part.geom_type == 'Polygon':
                shells.append(np.asarray(part.exterior.coords)[:, :2])
    return shells


def transform_to_ego(points_global, ego_pose):
    """Transform global-frame map points into the ego vehicle's BEV frame.

    Computes ``p_ego = R(-yaw) @ (p_global - t)`` for every point, where ``t``
    is the ego (x, y) translation and ``yaw`` is the heading encoded by the
    ego rotation quaternion. Only the heading is used, so the result is an
    exact 2-D rigid transform (roll and pitch are ignored).

    Args:
        points_global: array-like of shape (N, >=2); only x and y are used.
        ego_pose: record with ``translation`` ([x, y, z]) and ``rotation``
            (quaternion in [w, x, y, z] order).

    Returns:
        float ndarray of shape (N, 2) in the ego frame.
    """
    points = np.asarray(points_global, dtype=float)[:, :2]
    translation = np.asarray(ego_pose['translation'][:2], dtype=float)

    qw, qx, qy, qz = ego_pose['rotation']
    # Heading of the rotated x axis, atan2(R[1, 0], R[0, 0]) of the
    # quaternion's rotation matrix; the unnormalized form also tolerates
    # non-unit quaternions.
    yaw = np.arctan2(2.0 * (qw * qz + qx * qy), qw * qw + qx * qx - qy * qy - qz * qz)
    cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
    rotation = np.array([[cos_yaw, -sin_yaw],
                         [sin_yaw, cos_yaw]])

    # For row vectors ``d @ R`` equals ``(R.T @ d.T).T``, i.e. the inverse
    # rotation; ``d @ R.T`` would rotate by +yaw, the wrong direction.
    return (points - translation) @ rotation


def normalize_points(points, x_range, y_range):
    """Map metric BEV coordinates to [0, 1] over the given extent.

    ``x_range[0]`` maps to 0 and ``x_range[1]`` to 1 (likewise for y).
    Points outside the extent map outside [0, 1]. Columns beyond x and y
    are copied unchanged.

    Args:
        points: array-like of shape (N, >=2).
        x_range: ``(min, max)`` of the x axis.
        y_range: ``(min, max)`` of the y axis.

    Returns:
        A new float ndarray; the input is not modified.
    """
    # np.array copies; float dtype keeps integer inputs from being truncated.
    points_norm = np.array(points, dtype=float)
    points_norm[:, 0] = (points_norm[:, 0] - x_range[0]) / (x_range[1] - x_range[0])
    points_norm[:, 1] = (points_norm[:, 1] - y_range[0]) / (y_range[1] - y_range[0])
    return points_norm


def resample_polyline(points, num_points):
    """Resample a polyline to ``num_points`` points equally spaced along its length.

    Uses linear interpolation over cumulative arc length, so the first and
    last output points equal the input's endpoints.

    Args:
        points: array-like of shape (N, >=2); only x and y are used.
        num_points: number of output points.

    Returns:
        float ndarray of shape (num_points, 2). An empty input yields zeros.
        A single point, or a polyline whose points all coincide, yields that
        point repeated.
    """
    pts = np.asarray(points, dtype=float)
    if len(pts) == 0:
        return np.zeros((num_points, 2))
    pts = pts[:, :2]

    seg_lengths = np.hypot(*np.diff(pts, axis=0).T)
    # Drop repeated vertices: zero-length segments would make the arc-length
    # abscissa non-increasing, which np.interp does not accept.
    moving = seg_lengths > 0
    pts = pts[np.concatenate(([True], moving))]
    cum_dists = np.concatenate(([0.0], np.cumsum(seg_lengths[moving])))

    if cum_dists[-1] < 1e-6:
        # All points coincide (or there is only one).
        return np.tile(pts[0], (num_points, 1))

    sample_dists = np.linspace(0.0, cum_dists[-1], num_points)
    return np.stack([
        np.interp(sample_dists, cum_dists, pts[:, 0]),
        np.interp(sample_dists, cum_dists, pts[:, 1]),
    ], axis=-1)


def simplify_polyline(points, tolerance=1.0):
    """Reduce the vertex count of a polyline (shapely Douglas-Peucker).

    Args:
        points: (N, 2) array of vertices.
        tolerance: maximum allowed deviation, in the units of ``points``.

    Returns:
        An (M, 2) ndarray with M <= N. Inputs with fewer than 2 points are
        returned unchanged, as is any input when shapely is not installed.
    """
    if len(points) < 2:
        return points
    try:
        from shapely.geometry import LineString
    except ImportError:
        # Optional dependency: fall back to the unsimplified polyline. Other
        # errors are real bugs and must not be swallowed.
        return points

    simplified = LineString(points).simplify(tolerance, preserve_topology=False)
    return np.array(simplified.coords)


def print_statistics(vector_maps):
    """Print per-category vector counts and per-sample averages.

    Args:
        vector_maps: mapping ``sample_token -> {'divider': [...],
            'boundary': [...], 'ped_crossing': [...]}``. Missing category
            keys count as empty.

    An empty mapping prints zero counts and averages instead of raising
    ZeroDivisionError.
    """
    num_samples = len(vector_maps)
    # Guard the averages against an empty result set.
    denom = num_samples if num_samples else 1

    total_dividers = sum(len(v.get('divider', [])) for v in vector_maps.values())
    total_boundaries = sum(len(v.get('boundary', [])) for v in vector_maps.values())
    total_crossings = sum(len(v.get('ped_crossing', [])) for v in vector_maps.values())

    print("\n========== 矢量地图统计 ==========")
    print(f"样本数: {num_samples}")
    print(f"总divider数: {total_dividers} (平均: {total_dividers/denom:.1f}/样本)")
    print(f"总boundary数: {total_boundaries} (平均: {total_boundaries/denom:.1f}/样本)")
    print(f"总crossing数: {total_crossings} (平均: {total_crossings/denom:.1f}/样本)")
    print("================================")


def main():
    """Parse the command-line options and run the vector-map extraction.

    Each option maps one-to-one onto a keyword argument of
    ``extract_vector_maps_for_bevfusion``.
    """
    parser = argparse.ArgumentParser(description='提取nuScenes矢量地图')

    # (flag, add_argument keyword arguments), registered in this order.
    cli_options = (
        ('--dataroot', {'type': str, 'default': 'data/nuscenes',
                        'help': 'nuScenes数据根目录'}),
        ('--version', {'type': str, 'default': 'v1.0-trainval',
                       'help': '数据版本'}),
        ('--output', {'type': str,
                      'default': 'data/nuscenes/vector_maps_bevfusion.pkl',
                      'help': '输出文件路径'}),
        ('--x-range', {'type': float, 'nargs': 2, 'default': [-50, 50],
                       'help': 'BEV x轴范围'}),
        ('--y-range', {'type': float, 'nargs': 2, 'default': [-50, 50],
                       'help': 'BEV y轴范围'}),
        ('--num-points', {'type': int, 'default': 20,
                          'help': '每个矢量的点数'}),
    )
    for flag, settings in cli_options:
        parser.add_argument(flag, **settings)

    opts = parser.parse_args()

    extract_vector_maps_for_bevfusion(
        nusc_root=opts.dataroot,
        version=opts.version,
        output_file=opts.output,
        x_range=opts.x_range,
        y_range=opts.y_range,
        num_points_per_vec=opts.num_points,
    )


if __name__ == '__main__':
    main()