176 lines
5.9 KiB
Python
176 lines
5.9 KiB
Python
|
|
"""
|
|||
|
|
可视化矢量地图标注
|
|||
|
|
|
|||
|
|
用于验证提取的矢量地图数据是否正确
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import pickle
|
|||
|
|
import argparse
|
|||
|
|
import numpy as np
|
|||
|
|
import matplotlib.pyplot as plt
|
|||
|
|
from matplotlib.patches import Polygon as MPLPolygon
|
|||
|
|
from matplotlib.collections import LineCollection
|
|||
|
|
|
|||
|
|
|
|||
|
|
def visualize_vector_map(vector_map_file, num_samples=5, save_dir='vis_vector_map'):
    """Render the first few samples of a vector-map pickle as PNG images.

    Each rendered sample gets one figure:
    - Dividers are drawn in blue, boundaries in red and pedestrian crossings
      in green, each as a polyline with its vertices marked.
    - The canvas is fixed to [0, 1] x [0, 1]. The axes are labelled
      "normalized", so coordinates are presumably pre-normalized to that
      range (TODO confirm against the extractor).
    - A text box in the top-left corner shows how many vectors of each
      category the sample holds.

    Args:
        vector_map_file: Path to a pickle file holding a dict that maps a
            sample token to a dict ``{category: [{'points': ...}, ...]}``.
            NOTE: pickle.load can execute arbitrary code; only load files
            from a trusted source.
        num_samples: How many samples to render. The first ``num_samples``
            keys in the dict's iteration order are used (this is not a random
            draw). Values <= 0 render nothing.
        save_dir: Output directory, created if missing. Images are written as
            ``sample_001.png``, ``sample_002.png``, ...
    """
    import os

    os.makedirs(save_dir, exist_ok=True)

    # Load the data.
    print(f"加载矢量地图: {vector_map_file}")
    with open(vector_map_file, 'rb') as f:
        vector_maps = pickle.load(f)

    print(f"总样本数: {len(vector_maps)}")

    # (category key, matplotlib colour code, legend label), drawn in this order.
    categories = (
        ('divider', 'b', 'Divider'),
        ('boundary', 'r', 'Boundary'),
        ('ped_crossing', 'g', 'Ped Crossing'),
    )

    # Take the first `num_samples` tokens in dict order. Clamp at 0 so a
    # negative value does not silently slice off the tail instead.
    sample_tokens = list(vector_maps)[:max(num_samples, 0)]

    for idx, token in enumerate(sample_tokens):
        vectors = vector_maps[token]

        fig, ax = plt.subplots(figsize=(12, 12))
        try:
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.set_aspect('equal')
            ax.grid(True, alpha=0.3)
            # str() so non-string tokens can also be truncated for the title.
            ax.set_title(f'Sample {idx+1}: {str(token)[:8]}...')
            ax.set_xlabel('X (normalized)')
            ax.set_ylabel('Y (normalized)')

            for key, color, label in categories:
                for vec in vectors.get(key, []):
                    pts = np.asarray(vec['points'])
                    if pts.size == 0:
                        continue
                    # Promote a single (x, y) point to shape (1, 2) so the
                    # column indexing below works.
                    pts = np.atleast_2d(pts)
                    ax.plot(pts[:, 0], pts[:, 1], f'{color}-', linewidth=2,
                            alpha=0.7, label=label)
                    ax.scatter(pts[:, 0], pts[:, 1], c=color, s=10, alpha=0.5)

            # Every vector carries its category label; keep one legend entry
            # per label.
            handles, labels = ax.get_legend_handles_labels()
            by_label = dict(zip(labels, handles))
            if by_label:
                ax.legend(by_label.values(), by_label.keys(), loc='upper right')

            # Per-category vector counts for the summary box.
            num_dividers = len(vectors.get('divider', []))
            num_boundaries = len(vectors.get('boundary', []))
            num_crossings = len(vectors.get('ped_crossing', []))

            ax.text(0.02, 0.98,
                    f'Dividers: {num_dividers}\nBoundaries: {num_boundaries}\nCrossings: {num_crossings}',
                    transform=ax.transAxes,
                    verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

            # Save through the explicit figure rather than pyplot's "current
            # figure" state.
            save_path = os.path.join(save_dir, f'sample_{idx+1:03d}.png')
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"保存到: {save_path}")
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

    print(f"\n✅ 可视化完成!图像保存在: {save_dir}/")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def print_vector_statistics(vector_map_file):
    """Print per-category statistics for a vector-map pickle file.

    Categories are 'divider', 'boundary' and 'ped_crossing'. For each one,
    this prints:
    - the total number of vectors;
    - the mean number of vectors per sample;
    - the mean number of points per vector, only when at least one vector
      exists.

    An empty map file reports 0.00 vectors per sample instead of raising
    ``ZeroDivisionError``.

    Args:
        vector_map_file: Path to a pickle holding a dict of
            ``sample_token -> {category: [{'points': ...}, ...]}``.
            NOTE: pickle.load can execute arbitrary code; only load files
            from a trusted source.
    """
    with open(vector_map_file, 'rb') as f:
        vector_maps = pickle.load(f)

    num_maps = len(vector_maps)

    # (category key in each sample dict, section heading in the report)
    categories = (
        ('divider', 'Dividers'),
        ('boundary', 'Boundaries'),
        ('ped_crossing', 'Ped Crossings'),
    )

    # Gather everything before printing, so a malformed entry fails before
    # any partial report is emitted.
    collected = []
    for key, heading in categories:
        # One entry per vector: its number of points.
        point_counts = [len(vec['points'])
                        for vectors in vector_maps.values()
                        for vec in vectors.get(key, [])]
        collected.append((heading, point_counts))

    print("\n========== 矢量地图统计 ==========")
    print(f"总样本数: {num_maps}")

    for heading, point_counts in collected:
        total = len(point_counts)
        # Guard the empty-file case instead of dividing by zero.
        per_sample = total / num_maps if num_maps else 0.0

        print(f"\n{heading}:")
        print(f" 总数: {total}")
        print(f" 平均/样本: {per_sample:.2f}")
        if point_counts:
            print(f" 平均点数: {np.mean(point_counts):.1f}")

    print("================================\n")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """Command-line entry point.

    Parses the options, always prints the vector-map statistics, and then
    renders per-sample images unless ``--stats-only`` was given.
    """
    cli = argparse.ArgumentParser(description='可视化矢量地图')
    option = cli.add_argument
    option('--vector-map-file', type=str,
           default='data/nuscenes/vector_maps_bevfusion.pkl',
           help='矢量地图文件')
    option('--num-samples', type=int, default=5, help='可视化的样本数')
    option('--save-dir', type=str, default='vis_vector_map', help='保存目录')
    option('--stats-only', action='store_true', help='只打印统计信息')
    opts = cli.parse_args()

    map_file = opts.vector_map_file

    # Statistics are printed unconditionally.
    print_vector_statistics(map_file)

    # Guard clause: stop here when only statistics were requested.
    if opts.stats_only:
        return

    visualize_vector_map(map_file,
                         num_samples=opts.num_samples,
                         save_dir=opts.save_dir)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|||
|
|
|