bev-project/tools/convert_to_carla_opendrive.py

387 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
将nuScenes/MapTR矢量地图转换为CARLA OpenDRIVE格式
使用方法:
python tools/convert_to_carla_opendrive.py \
--vector-maps data/nuscenes/vector_maps_bevfusion.pkl \
--nuscenes-root data/nuscenes \
--output carla_maps/nuscenes_map.xodr
"""
import pickle
import argparse
import numpy as np
from lxml import etree
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict
from pyquaternion import Quaternion
class NuScenesToOpenDRIVE:
    """Convert nuScenes/MapTR vector maps into a simplified CARLA OpenDRIVE file.

    Pipeline: per-sample local vectors -> global map-frame polylines ->
    greedy de-duplication -> one single-lane road per divider -> OpenDRIVE XML.
    """

    # Vector categories read from each sample's entry in the vector maps.
    CATEGORIES = ('divider', 'boundary', 'ped_crossing')

    def __init__(self, vector_maps_path, nuscenes_root, bev_range=50,
                 version='v1.0-trainval',
                 info_filename='nuscenes_infos_train.pkl'):
        """Load the vector maps, the nuScenes devkit and the sample infos.

        Args:
            vector_maps_path: Pickle mapping sample token -> dict of
                category -> list of vectors, each with a ``'points'`` entry.
            nuscenes_root: nuScenes dataroot; the info pickle is read from here.
            bev_range: Half-extent (metres) that normalized [0, 1] coordinates
                are mapped back to by ``denormalize``.
            version: nuScenes version string passed to the devkit.
            info_filename: Info pickle inside ``nuscenes_root`` whose
                ``'token'`` entries drive the conversion.
        """
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        print(f"加载矢量地图: {vector_maps_path}")
        with open(vector_maps_path, 'rb') as f:
            self.vector_maps = pickle.load(f)
        print(f"加载样本数: {len(self.vector_maps)}")

        # Deferred import: the devkit is only needed once a converter is built.
        from nuscenes.nuscenes import NuScenes
        self.nusc = NuScenes(version=version, dataroot=nuscenes_root, verbose=False)
        self.bev_range = bev_range

        info_path = Path(nuscenes_root) / info_filename
        with open(info_path, 'rb') as f:
            infos = pickle.load(f)
        # Infos are either a bare list or a dict wrapping the list under 'infos'.
        self.infos = infos['infos'] if isinstance(infos, dict) else infos

    def convert(self, output_path):
        """Run the full conversion and write the OpenDRIVE file.

        Steps:
            1. Transform all per-sample local vectors into global coordinates.
            2. Merge near-duplicate vectors.
            3. Build the (simplified) road network.
            4. Serialize it as OpenDRIVE XML to ``output_path``.
        """
        print("\n开始转换为OpenDRIVE格式...")
        print("步骤1: 转换为全局坐标...")
        global_vectors = self.convert_to_global_coords()
        print("步骤2: 聚合和去重...")
        merged_vectors = self.merge_vectors(global_vectors)
        print("步骤3: 构建道路网络...")
        roads = self.build_road_network(merged_vectors)
        print("步骤4: 生成OpenDRIVE XML...")
        self.generate_opendrive_xml(roads, output_path)
        print(f"\n✅ OpenDRIVE文件已保存: {output_path}")

        total_roads = len(roads)
        total_lanes = sum(len(r['lanes']) for r in roads)
        print("\n统计:")
        print(f" Roads: {total_roads}")
        print(f" Lanes: {total_lanes}")

    def convert_to_global_coords(self):
        """Transform every sample's local vectors into global coordinates.

        Samples absent from the vector maps, or whose sample / LIDAR_TOP /
        ego_pose records cannot be found, are skipped. Vectors with no points
        are skipped as well.

        Returns:
            dict mapping each category to a list of
            ``{'points': [[x, y], ...], 'token': sample_token}`` entries.
        """
        global_vectors = {category: [] for category in self.CATEGORIES}
        skipped = 0
        for info in tqdm(self.infos, desc="转换坐标"):
            token = info['token']
            if token not in self.vector_maps:
                continue
            vectors = self.vector_maps[token]

            ego_pose = self._lookup_ego_pose(token)
            if ego_pose is None:
                skipped += 1
                continue

            for category in self.CATEGORIES:
                for vec in vectors.get(category, []):
                    # Normalized coordinates; presumably (20, 2) — verify against the generator.
                    points_norm = np.asarray(vec['points'], dtype=float)
                    if points_norm.ndim != 2 or points_norm.shape[0] == 0:
                        continue  # nothing to transform (would break the matmul below)
                    # NOTE(review): points are treated as ego-frame here; MapTR-style GT is
                    # often expressed in the LiDAR frame — confirm against the generator.
                    points_ego = self.denormalize(points_norm)
                    points_global = self.ego_to_global(points_ego, ego_pose)
                    global_vectors[category].append({
                        'points': points_global.tolist(),
                        'token': token,
                    })
        if skipped:
            print(f"  跳过{skipped}个无法获取位姿的样本")
        return global_vectors

    def _lookup_ego_pose(self, sample_token):
        """Return the ego_pose record of the sample's LIDAR_TOP data, or None if missing."""
        try:
            sample = self.nusc.get('sample', sample_token)
            lidar_data = self.nusc.get('sample_data', sample['data']['LIDAR_TOP'])
            return self.nusc.get('ego_pose', lidar_data['ego_pose_token'])
        except KeyError:
            # The devkit raises KeyError for unknown tokens; treat as "no pose".
            return None

    def denormalize(self, points_norm):
        """Map normalized coordinates in [0, 1] to metres in [-bev_range, bev_range]."""
        return points_norm * (2 * self.bev_range) - self.bev_range

    def ego_to_global(self, points_ego, ego_pose):
        """Transform (N, 2) ego-frame points into the global map frame.

        Computes ``p_global = R[:2, :2] @ p_ego + t[:2]`` from the pose's
        rotation quaternion and translation; only the planar part is used.
        """
        rot_matrix = Quaternion(ego_pose['rotation']).rotation_matrix[:2, :2]
        translation = np.asarray(ego_pose['translation'][:2])
        return points_ego @ rot_matrix.T + translation

    def merge_vectors(self, global_vectors, distance_threshold=2.0):
        """Greedily merge near-duplicate vectors within each category.

        Vectors are visited in input order. Each not-yet-used vector becomes a
        cluster seed and absorbs every later unused vector whose symmetric
        Hausdorff distance to the seed is below ``distance_threshold``. A
        cluster is replaced by the point-wise mean of its members, after each
        member has been oriented to run in the same direction as the seed
        (otherwise a reversed duplicate would average to the polyline midpoint).

        Candidate pairs are pre-filtered with a KD-tree over bounding boxes:
        Hausdorff distance < d implies every bbox extreme (min/max of x and y)
        differs by < d, so the Chebyshev-ball query never drops a true match.

        Args:
            global_vectors: dict category -> list of ``{'points': ...}``.
            distance_threshold: Hausdorff distance (metres) below which two
                vectors are treated as duplicates.

        Returns:
            dict category -> list of ``{'points': [[x, y], ...]}``.
        """
        from scipy.spatial import cKDTree

        merged = {category: [] for category in self.CATEGORIES}
        for category in self.CATEGORIES:
            vectors = global_vectors.get(category, [])
            if not vectors:
                continue
            print(f" 聚合{category}: {len(vectors)}个矢量...")

            polylines = [np.asarray(v['points'], dtype=float) for v in vectors]
            # (n, 4) rows of [min_x, min_y, max_x, max_y].
            bboxes = np.array([np.concatenate([p.min(axis=0), p.max(axis=0)])
                               for p in polylines])
            tree = cKDTree(bboxes)
            used = np.zeros(len(polylines), dtype=bool)

            for i, current in enumerate(polylines):
                if used[i]:
                    continue
                used[i] = True
                members = [current]
                # Sorted so members are visited in the same order as a full scan.
                candidates = sorted(
                    j for j in tree.query_ball_point(bboxes[i], r=distance_threshold, p=np.inf)
                    if j > i and not used[j]
                )
                for j in candidates:
                    if self.hausdorff_distance(current, polylines[j]) < distance_threshold:
                        used[j] = True
                        members.append(self._align_direction(current, polylines[j]))
                merged_points = np.mean(members, axis=0) if len(members) > 1 else current
                merged[category].append({'points': merged_points.tolist()})

        print(f" 去重后: divider={len(merged['divider'])}, "
              f"boundary={len(merged['boundary'])}, "
              f"crossing={len(merged['ped_crossing'])}")
        return merged

    @staticmethod
    def _align_direction(reference, other):
        """Return ``other`` reversed if that matches ``reference`` point-wise better.

        Only reversal is handled; a cyclic start-point shift (possible for
        closed polygons) is not corrected.
        """
        if reference.shape != other.shape:
            return other
        forward = np.linalg.norm(reference - other, axis=1).sum()
        backward = np.linalg.norm(reference - other[::-1], axis=1).sum()
        return other[::-1] if backward < forward else other

    def hausdorff_distance(self, points1, points2):
        """Symmetric Hausdorff distance between two (N, 2) / (M, 2) point sets."""
        from scipy.spatial.distance import cdist
        dist_matrix = cdist(points1, points2)
        d1 = dist_matrix.min(axis=1).max()
        d2 = dist_matrix.min(axis=0).max()
        return max(d1, d2)

    def build_road_network(self, merged_vectors, lane_width=3.5):
        """Build a simplified road list: one single-lane road per merged divider.

        No topology analysis is done; each divider polyline becomes its own
        reference line with one driving lane of width ``lane_width``. The lane
        has id -1 because OpenDRIVE numbers lanes right of the reference line
        with negative ids. Dividers with fewer than two points or zero total
        length are skipped.

        Returns:
            list of road dicts with keys ``id``, ``length``, ``start``, ``hdg``,
            ``geometry`` ((N, 2) array) and ``lanes``.
        """
        roads = []
        for idx, divider in enumerate(merged_vectors.get('divider', [])):
            points = np.asarray(divider['points'], dtype=float)
            if len(points) < 2:
                continue
            segment_lengths = np.linalg.norm(np.diff(points, axis=0), axis=1)
            total_length = float(segment_lengths.sum())
            if total_length <= 0.0:
                continue  # all points coincide: no usable reference line
            # Heading of the first non-degenerate segment (duplicate leading
            # points would otherwise give atan2(0, 0)).
            first = int(np.argmax(segment_lengths > 0))
            dx, dy = points[first + 1] - points[first]
            roads.append({
                'id': idx,
                'length': total_length,
                'start': points[0],
                'hdg': float(np.arctan2(dy, dx)),
                'geometry': points,
                'lanes': [
                    {
                        'id': -1,
                        'type': 'driving',
                        'level': 0,
                        'width': lane_width,
                    }
                ],
            })
        return roads

    def generate_opendrive_xml(self, roads, output_path):
        """Serialize ``roads`` to an OpenDRIVE (rev 1.6 header) file at ``output_path``.

        The header's north/south/east/west fields hold the bounding box of all
        road geometries (0 when there are no roads); the date is today's date.
        Parent directories of ``output_path`` are created if needed.
        """
        import datetime

        root = etree.Element('OpenDRIVE')

        if roads:
            all_points = np.concatenate(
                [np.asarray(road['geometry'], dtype=float) for road in roads])
            west, south = all_points.min(axis=0)
            east, north = all_points.max(axis=0)
        else:
            west = south = east = north = 0.0

        etree.SubElement(root, 'header',
                         revMajor='1',
                         revMinor='6',
                         name='nuScenes Map Converted',
                         version='1.00',
                         date=datetime.date.today().isoformat(),
                         north=f"{north:.6f}",
                         south=f"{south:.6f}",
                         east=f"{east:.6f}",
                         west=f"{west:.6f}")

        for road_data in roads:
            root.append(self.create_road_xml(road_data))

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        etree.ElementTree(root).write(
            str(output_path),
            pretty_print=True,
            xml_declaration=True,
            encoding='UTF-8',
        )

    def create_road_xml(self, road_data):
        """Build the ``<road>`` element for one road dict from ``build_road_network``.

        The reference line is a chain of straight ``<line>`` geometries, one per
        non-degenerate polyline segment. A single lane section holds the center
        lane and the road's lanes on the right side.
        """
        road = etree.Element('road',
                             name=f"road_{road_data['id']}",
                             length=f"{float(road_data['length']):.6f}",
                             id=str(road_data['id']),
                             junction='-1')

        # 1. Plan view. Values use 6 decimals: rounding s and length separately to
        # 2 decimals left centimetre gaps between consecutive geometries.
        plan_view = etree.SubElement(road, 'planView')
        points = np.asarray(road_data['geometry'], dtype=float)
        s = 0.0
        for (x0, y0), (x1, y1) in zip(points[:-1], points[1:]):
            dx, dy = x1 - x0, y1 - y0
            length = float(np.hypot(dx, dy))
            if length <= 0.0:
                continue  # zero-length geometry is invalid; skipping leaves s unchanged
            geometry = etree.SubElement(plan_view, 'geometry',
                                        s=f"{s:.6f}",
                                        x=f"{x0:.6f}",
                                        y=f"{y0:.6f}",
                                        hdg=f"{np.arctan2(dy, dx):.6f}",
                                        length=f"{length:.6f}")
            etree.SubElement(geometry, 'line')
            s += length

        # 2. Lanes: a single lane section starting at s=0.
        lanes = etree.SubElement(road, 'lanes')
        lane_section = etree.SubElement(lanes, 'laneSection', s='0.0')
        center = etree.SubElement(lane_section, 'center')
        etree.SubElement(center, 'lane', id='0', type='none', level='0')

        right = etree.SubElement(lane_section, 'right')
        for lane_info in road_data['lanes']:
            # Right-side lanes must have negative ids in OpenDRIVE; force the sign
            # so callers passing positive ids still produce a valid file.
            lane = etree.SubElement(right, 'lane',
                                    id=str(-abs(int(lane_info['id']))),
                                    type=lane_info['type'],
                                    level=str(lane_info['level']))
            etree.SubElement(lane, 'width',
                             sOffset='0.0',
                             a=str(lane_info['width']),
                             b='0.0',
                             c='0.0',
                             d='0.0')
        return road
def main():
    """Command-line entry point.

    Parses the CLI options, runs the nuScenes -> OpenDRIVE conversion and
    prints short instructions for loading the result into CARLA.
    """
    cli = argparse.ArgumentParser(description='转换nuScenes矢量地图为OpenDRIVE')
    option_specs = (
        ('--vector-maps', dict(type=str, required=True,
                               help='矢量地图pkl文件路径')),
        ('--nuscenes-root', dict(type=str, default='data/nuscenes',
                                 help='nuScenes数据根目录')),
        ('--output', dict(type=str, required=True,
                          help='输出OpenDRIVE文件路径(.xodr)')),
        ('--bev-range', dict(type=float, default=50,
                             help='BEV范围')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    NuScenesToOpenDRIVE(
        opts.vector_maps,
        opts.nuscenes_root,
        opts.bev_range,
    ).convert(opts.output)

    xodr_name = Path(opts.output).name
    for message in (
        "\n✅ 转换完成!",
        f"OpenDRIVE文件: {opts.output}",
        "\n在CARLA中使用:",
        f" 1. 复制到: CARLA/Import/{xodr_name}",
        " 2. 启动CARLA并导入地图",
    ):
        print(message)


if __name__ == '__main__':
    main()