bev-project/tools/data_converter/extract_vector_map_bevfusion.py

375 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
从nuScenes提取矢量地图标注
适配BEVFusion的数据格式
使用方法:
python tools/data_converter/extract_vector_map_bevfusion.py \
--dataroot data/nuscenes \
--output data/nuscenes/vector_maps_bevfusion.pkl
"""
import pickle
import argparse
import numpy as np
from pathlib import Path
from tqdm import tqdm
from pyquaternion import Quaternion
# nuscenes-devkit is required for both the dataset and the map API.
try:
    from nuscenes.nuscenes import NuScenes
    from nuscenes.map_expansion.map_api import NuScenesMap
except ImportError as err:
    # Only a missing/broken import is turned into the install hint; other errors
    # keep their traceback. SystemExit with a string argument prints it to stderr
    # and exits with status 1 (the site-provided ``exit()`` is not always defined).
    raise SystemExit(
        f"请先安装nuscenes-devkit: pip install nuscenes-devkit ({err})"
    ) from err
def extract_vector_maps_for_bevfusion(
    nusc_root='data/nuscenes',
    version='v1.0-trainval',
    output_file='data/nuscenes/vector_maps_bevfusion.pkl',
    x_range=(-50, 50),
    y_range=(-50, 50),
    num_points_per_vec=20,
    splits=None,
):
    """Extract vector-map annotations for every sample and pickle them.

    Sample tokens come from the BEVFusion info files
    ``{nusc_root}/nuscenes_infos_{split}.pkl``. The pickled result is a dict
    mapping each sample token to the output of ``extract_vectors_in_range``.
    A sample whose extraction raises gets empty ``divider``/``boundary``/
    ``ped_crossing`` lists, and a warning is printed.

    Args:
        nusc_root: nuScenes data root directory.
        version: nuScenes version passed to ``NuScenes`` (e.g. ``v1.0-trainval``).
        output_file: Output pickle path. Missing parent directories are created.
        x_range: Ego-frame BEV x range ``(min, max)``.
        y_range: Ego-frame BEV y range ``(min, max)``.
        num_points_per_vec: Number of points each vector is resampled to.
        splits: Info-file splits to process. ``None`` selects ``('test',)``
            when ``'test'`` is in ``version``, and ``('train', 'val')``
            otherwise, so that a trainval or mini run covers both training and
            validation samples. Split files that do not exist are skipped with
            a warning.

    Raises:
        FileNotFoundError: If none of the requested info files exists.
    """
    print(f"加载nuScenes数据集: {version}")
    nusc = NuScenes(version=version, dataroot=nusc_root, verbose=True)

    samples_list = _load_sample_infos(nusc_root, version, splits)
    print(f"共有 {len(samples_list)} 个样本")

    vector_maps = {}
    map_apis = {}  # cache: one NuScenesMap per map location
    for info in tqdm(samples_list, desc="提取矢量地图"):
        sample_token = info['token']
        try:
            nusc_map, ego_pose = _sample_map_and_pose(
                nusc, nusc_root, sample_token, map_apis)
            vectors = extract_vectors_in_range(
                nusc_map,
                ego_pose,
                x_range=x_range,
                y_range=y_range,
                num_points=num_points_per_vec,
            )
        except Exception as e:
            # Best effort: one broken sample must not abort the whole run.
            print(f"警告: 样本 {sample_token} 提取失败: {e}")
            vectors = {'divider': [], 'boundary': [], 'ped_crossing': []}
        vector_maps[sample_token] = vectors

    print(f"\n保存到: {output_file}")
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, 'wb') as f:
        pickle.dump(vector_maps, f)

    print_statistics(vector_maps)
    print("✅ 矢量地图提取完成!")


def _load_sample_infos(nusc_root, version, splits=None):
    """Return the concatenated sample-info lists of the requested BEVFusion info files."""
    if splits is None:
        splits = ('test',) if 'test' in version else ('train', 'val')
    samples = []
    tried = []
    for split in splits:
        info_file = Path(nusc_root) / f'nuscenes_infos_{split}.pkl'
        tried.append(str(info_file))
        if not info_file.is_file():
            print(f"警告: info文件不存在, 跳过: {info_file}")
            continue
        print(f"加载BEVFusion info文件: {info_file}")
        with open(info_file, 'rb') as f:
            infos = pickle.load(f)
        # Info files are either a plain list or a dict holding the list under 'infos'.
        samples.extend(infos['infos'] if isinstance(infos, dict) else infos)
    if not tried or len(tried) == sum(1 for p in tried if not Path(p).is_file()):
        raise FileNotFoundError(f"no BEVFusion info file found, tried: {tried}")
    return samples


def _sample_map_and_pose(nusc, nusc_root, sample_token, map_apis):
    """Return (NuScenesMap of the sample's location, ego pose of its LIDAR_TOP data)."""
    sample = nusc.get('sample', sample_token)
    lidar_data = nusc.get('sample_data', sample['data']['LIDAR_TOP'])
    ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
    # The map is selected by the location of the log the sample's scene belongs to.
    scene = nusc.get('scene', sample['scene_token'])
    map_name = nusc.get('log', scene['log_token'])['location']
    if map_name not in map_apis:
        map_apis[map_name] = NuScenesMap(dataroot=nusc_root, map_name=map_name)
    return map_apis[map_name], ego_pose
def extract_vectors_in_range(nusc_map, ego_pose, x_range, y_range, num_points=20):
    """Extract the vector map elements inside the ego-frame BEV rectangle.

    Map geometry is first transformed into the ego frame with
    ``transform_to_ego``. It is then clipped to
    ``[x_range[0], x_range[1]] x [y_range[0], y_range[1]]``, resampled to
    ``num_points`` points with ``resample_polyline``, and normalized with
    ``normalize_points``. An element that leaves and re-enters the rectangle
    yields one vector per visible piece.

    Args:
        nusc_map: ``NuScenesMap`` of the sample's location.
        ego_pose: nuScenes ``ego_pose`` record (``translation``, ``rotation``).
        x_range: Ego-frame BEV x range ``(min, max)``.
        y_range: Ego-frame BEV y range ``(min, max)``.
        num_points: Number of points per resampled vector.

    Returns:
        dict: {
            'divider': [{'points': [[x, y], ...], 'type': layer_name}, ...],
            'boundary': [{'points': [[x, y], ...]}, ...],
            'ped_crossing': [{'points': [[x, y], ...]}, ...],
        }
        Points are normalized by ``normalize_points``. Because geometry is
        clipped to the range first, they lie in [0, 1] up to floating-point
        error.

    Exceptions raised by the map API or by geometry operations propagate to
    the caller. They are no longer silently swallowed per layer.
    """
    from shapely.geometry import LineString, Polygon, box
    from shapely.ops import linemerge

    vectors = {
        'divider': [],
        'boundary': [],
        'ped_crossing': [],
    }

    # In the global frame the ego-frame rectangle is rotated by the ego yaw.
    # Query an axis-aligned global box that encloses it for any heading;
    # exact cropping happens below in the ego frame.
    ego_x, ego_y = ego_pose['translation'][:2]
    reach = float(np.hypot(max(abs(x_range[0]), abs(x_range[1])),
                           max(abs(y_range[0]), abs(y_range[1]))))
    patch_box = (ego_x - reach, ego_y - reach, ego_x + reach, ego_y + reach)
    # Clipping window in ego coordinates.
    ego_patch = box(x_range[0], y_range[0], x_range[1], y_range[1])

    # get_records_in_patch returns {layer_name: [record_token, ...]}.
    records = nusc_map.get_records_in_patch(
        patch_box,
        layer_names=['lane_divider', 'road_divider', 'road_segment', 'ped_crossing'],
        mode='intersect',
    )

    def to_ego(coords):
        """Global shapely coordinates -> (N, 2) ego-frame array."""
        return transform_to_ego(np.asarray(coords, dtype=float), ego_pose)

    def to_vector(points_ego):
        """Ego-frame polyline -> resampled, normalized list of [x, y]."""
        resampled = resample_polyline(points_ego, num_points)
        return normalize_points(resampled, x_range, y_range).tolist()

    # 1. Lane dividers and road dividers.
    for layer_name in ('lane_divider', 'road_divider'):
        for token in records.get(layer_name, []):
            record = nusc_map.get(layer_name, token)
            line = nusc_map.extract_line(record['line_token'])
            if line is None or line.is_empty:
                continue
            clipped = LineString(to_ego(line.coords)).intersection(ego_patch)
            for part in _iter_parts(clipped, LineString):
                vectors['divider'].append({
                    'points': to_vector(np.asarray(part.coords)),
                    'type': layer_name,
                })

    # 2. Road boundaries: the outer ring of each road segment.
    for token in records.get('road_segment', []):
        record = nusc_map.get('road_segment', token)
        polygon = nusc_map.extract_polygon(record['polygon_token'])
        if polygon is None or polygon.is_empty:
            continue
        # Clip the ring as a line so the window edges are not reported as boundary.
        clipped = LineString(to_ego(polygon.exterior.coords)).intersection(ego_patch)
        pieces = list(_iter_parts(clipped, LineString))
        if len(pieces) > 1:
            # Re-join pieces that share an endpoint (e.g. the ring's start/end vertex).
            pieces = list(_iter_parts(linemerge(pieces), LineString))
        for piece in pieces:
            simplified = simplify_polyline(np.asarray(piece.coords), tolerance=1.0)
            if len(simplified) >= 2:
                vectors['boundary'].append({'points': to_vector(simplified)})

    # 3. Pedestrian crossings: the outline of each crossing polygon, clipped.
    for token in records.get('ped_crossing', []):
        record = nusc_map.get('ped_crossing', token)
        polygon = nusc_map.extract_polygon(record['polygon_token'])
        if polygon is None or polygon.is_empty:
            continue
        crossing = Polygon(to_ego(polygon.exterior.coords))
        if not crossing.is_valid:
            crossing = crossing.buffer(0)  # repair self-intersections before clipping
        for part in _iter_parts(crossing.intersection(ego_patch), Polygon):
            vectors['ped_crossing'].append({
                'points': to_vector(np.asarray(part.exterior.coords)),
            })

    return vectors


def _iter_parts(geometry, part_type):
    """Yield the non-empty parts of a shapely geometry that are ``part_type`` instances."""
    if geometry.is_empty:
        return
    if isinstance(geometry, part_type):
        yield geometry
        return
    # Multi-part geometries and GeometryCollections expose their parts as ``.geoms``;
    # parts of any other type (e.g. stray Points from a clip) are dropped.
    for part in getattr(geometry, 'geoms', ()):
        yield from _iter_parts(part, part_type)
def transform_to_ego(points_global, ego_pose):
    """Transform 2-D points from the global frame into the ego vehicle frame.

    The pose is treated as the nuScenes ego-to-global transform
    ``p_global = R @ p_ego + t``, where ``rotation`` is a quaternion
    ``[w, x, y, z]``. The inverse ``p_ego = R.T @ (p_global - t)`` is applied.
    Only the x/y plane is used: the top-left 2x2 block of ``R`` and the first
    two entries of ``translation``.

    Args:
        points_global: Array-like of shape (N, 2). Extra columns are ignored.
        ego_pose: Dict with ``translation`` (length >= 2) and ``rotation``
            (quaternion, normalized here before use).

    Returns:
        np.ndarray of shape (N, 2) with float ego-frame coordinates.
    """
    points = np.asarray(points_global, dtype=float)
    if points.size == 0:
        return np.zeros((0, 2))
    points = np.atleast_2d(points)[:, :2]

    translation = np.asarray(ego_pose['translation'][:2], dtype=float)
    quat = np.asarray(ego_pose['rotation'], dtype=float)
    w, x, y, z = quat / np.linalg.norm(quat)
    # Top-left 2x2 block of the rotation matrix of the unit quaternion.
    rot_2d = np.array([
        [1.0 - 2.0 * (y * y + z * z), 2.0 * (x * y - w * z)],
        [2.0 * (x * y + w * z), 1.0 - 2.0 * (x * x + z * z)],
    ])
    # Row-vector form of R.T @ (p - t) is (p - t) @ R (inverse rotation).
    return (points - translation) @ rot_2d
def normalize_points(points, x_range, y_range):
    """Map ego-frame points from the BEV range to [0, 1] per axis.

    Column 0 is mapped from ``[x_range[0], x_range[1]]`` and column 1 from
    ``[y_range[0], y_range[1]]``. Points outside the range map outside
    [0, 1]. Any further columns are copied unchanged. The input is not
    modified.

    Args:
        points: Array-like of shape (N, 2) or wider.
        x_range: ``(min, max)`` of the x axis.
        y_range: ``(min, max)`` of the y axis.

    Returns:
        A new float ``np.ndarray`` with the same shape as ``points``.
    """
    # A float copy: with integer input, an in-place copy would truncate the result.
    normalized = np.array(points, dtype=float)
    lower = np.array([x_range[0], y_range[0]], dtype=float)
    span = np.array([x_range[1] - x_range[0], y_range[1] - y_range[0]], dtype=float)
    normalized[:, :2] = (normalized[:, :2] - lower) / span
    return normalized
def resample_polyline(points, num_points):
    """Resample a polyline to ``num_points`` points evenly spaced by arc length.

    Positions are linearly interpolated along the cumulative arc length, so the
    first and last input points are kept. Degenerate inputs are handled as
    follows:

    * no points -> ``num_points`` rows of zeros;
    * a single point, or all points coincident -> that point repeated.

    Args:
        points: Array-like of shape (N, 2). Extra columns are ignored.
        num_points: Number of output points.

    Returns:
        np.ndarray of shape (num_points, 2).
    """
    pts = np.asarray(points, dtype=float)
    if pts.size == 0:
        return np.zeros((num_points, 2))
    pts = np.atleast_2d(pts)[:, :2]
    if len(pts) == 1:
        # Repeat the point instead of padding with (0, 0), which would invent a
        # segment to the ego origin.
        return np.tile(pts[0], (num_points, 1))

    seg_lengths = np.linalg.norm(np.diff(pts, axis=0), axis=1)
    cum_dists = np.concatenate([[0.0], np.cumsum(seg_lengths)])
    if cum_dists[-1] < 1e-6:
        # All points coincide.
        return np.tile(pts[0], (num_points, 1))

    # cum_dists is non-decreasing, as np.interp requires; zero-length segments
    # join identical points, so either side of a tie gives the same position.
    sample_dists = np.linspace(0.0, cum_dists[-1], num_points)
    return np.stack([
        np.interp(sample_dists, cum_dists, pts[:, 0]),
        np.interp(sample_dists, cum_dists, pts[:, 1]),
    ], axis=-1)
def simplify_polyline(points, tolerance=1.0):
    """Reduce the vertex count of a polyline with shapely's ``simplify``.

    This is best effort. Inputs with fewer than two points are returned
    unchanged. If shapely cannot be imported or the simplification raises, the
    input is also returned unchanged.

    Args:
        points: Array-like polyline of shape (N, 2).
        tolerance: Simplification tolerance passed to ``LineString.simplify``
            (``preserve_topology=False``).

    Returns:
        The simplified coordinates as ``np.ndarray``, or ``points`` itself on
        the fallback paths.
    """
    try:
        from shapely.geometry import LineString
        if len(points) < 2:
            return points
        line = LineString(points)
        simplified = line.simplify(tolerance, preserve_topology=False)
        return np.array(simplified.coords)
    except Exception:
        # Deliberate fallback. Unlike the former bare ``except:``, this no longer
        # swallows KeyboardInterrupt or SystemExit.
        return points
def print_statistics(vector_maps):
    """Print the sample count and the total and per-sample count of each vector type.

    Args:
        vector_maps: Dict mapping sample token to a dict with ``'divider'``,
            ``'boundary'`` and ``'ped_crossing'`` lists.
    """
    num_samples = len(vector_maps)
    # Averages are 0.0 for an empty result instead of raising ZeroDivisionError.
    denom = max(num_samples, 1)
    totals = {
        key: sum(len(v[key]) for v in vector_maps.values())
        for key in ('divider', 'boundary', 'ped_crossing')
    }
    total_dividers = totals['divider']
    total_boundaries = totals['boundary']
    total_crossings = totals['ped_crossing']
    print("\n========== 矢量地图统计 ==========")
    print(f"样本数: {num_samples}")
    print(f"总divider数: {total_dividers} (平均: {total_dividers/denom:.1f}/样本)")
    print(f"总boundary数: {total_boundaries} (平均: {total_boundaries/denom:.1f}/样本)")
    print(f"总crossing数: {total_crossings} (平均: {total_crossings/denom:.1f}/样本)")
    print("================================")
def main():
    """Parse command-line arguments, validate them and run the extraction."""
    parser = argparse.ArgumentParser(description='提取nuScenes矢量地图')
    parser.add_argument('--dataroot', type=str, default='data/nuscenes',
                        help='nuScenes数据根目录')
    parser.add_argument('--version', type=str, default='v1.0-trainval',
                        help='数据版本')
    parser.add_argument('--output', type=str,
                        default='data/nuscenes/vector_maps_bevfusion.pkl',
                        help='输出文件路径')
    parser.add_argument('--x-range', type=float, nargs=2, default=[-50, 50],
                        help='BEV x轴范围')
    parser.add_argument('--y-range', type=float, nargs=2, default=[-50, 50],
                        help='BEV y轴范围')
    parser.add_argument('--num-points', type=int, default=20,
                        help='每个矢量的点数')
    args = parser.parse_args()

    # An empty or inverted range makes normalization divide by zero (or flip
    # axes). A non-positive point count leaves no points, or makes every
    # sample fail during resampling.
    for option, (low, high) in (('--x-range', args.x_range), ('--y-range', args.y_range)):
        if not low < high:
            parser.error(f'{option} 需要满足 min < max, 当前为: {low} {high}')
    if args.num_points < 1:
        parser.error('--num-points 必须 >= 1')

    extract_vector_maps_for_bevfusion(
        nusc_root=args.dataroot,
        version=args.version,
        output_file=args.output,
        x_range=args.x_range,
        y_range=args.y_range,
        num_points_per_vec=args.num_points,
    )


if __name__ == '__main__':
    main()