"""Visualize vector map annotations.

Used to check that extracted vector map data looks correct: prints summary
statistics and renders a handful of samples to PNG files.
"""

import argparse
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
# The two imports below are not used in this module; they are kept so that the
# module's importable names stay unchanged.
from matplotlib.collections import LineCollection
from matplotlib.patches import Polygon as MPLPolygon

# One row per map element class, in drawing / reporting order:
# (key in the per-sample dict, matplotlib color code, legend label, statistics heading)
_VECTOR_CLASSES = (
    ('divider', 'b', 'Divider', 'Dividers'),
    ('boundary', 'r', 'Boundary', 'Boundaries'),
    ('ped_crossing', 'g', 'Ped Crossing', 'Ped Crossings'),
)


def _load_vector_maps(vector_map_file):
    """Load the pickled ``{sample_token: vectors}`` mapping from *vector_map_file*."""
    # NOTE(security): unpickling can execute arbitrary code. Only open files
    # that come from a trusted source.
    with open(vector_map_file, 'rb') as f:
        return pickle.load(f)


def _draw_polylines(ax, polylines, color, label):
    """Plot every non-empty polyline in *polylines* on *ax* as a line plus its vertices."""
    for vec in polylines:
        points = np.array(vec['points'])
        if len(points) == 0:
            continue
        # Every line carries the same label; the caller de-duplicates the legend.
        ax.plot(points[:, 0], points[:, 1], f'{color}-',
                linewidth=2, alpha=0.7, label=label)
        ax.scatter(points[:, 0], points[:, 1], c=color, s=10, alpha=0.5)


def visualize_vector_map(vector_map_file, num_samples=5, save_dir='vis_vector_map'):
    """Render the first *num_samples* samples of a vector map file to PNG images.

    Args:
        vector_map_file: Path to the pickled vector map file.
        num_samples: Number of samples to render. Samples are taken from the
            start of the mapping in its iteration order (not randomly).
        save_dir: Output directory; created if it does not exist. Images are
            written as ``sample_001.png``, ``sample_002.png``, ...
    """
    os.makedirs(save_dir, exist_ok=True)

    # Load the data.
    print(f"加载矢量地图: {vector_map_file}")
    vector_maps = _load_vector_maps(vector_map_file)
    print(f"总样本数: {len(vector_maps)}")

    # Take the first few samples.
    sample_tokens = list(vector_maps)[:num_samples]

    for idx, token in enumerate(sample_tokens, start=1):
        vectors = vector_maps[token]

        fig, ax = plt.subplots(figsize=(12, 12))
        try:
            # Both axes are fixed to [0, 1].
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.set_aspect('equal')
            ax.grid(True, alpha=0.3)
            # str() keeps string tokens unchanged and tolerates non-string keys.
            ax.set_title(f'Sample {idx}: {str(token)[:8]}...')
            ax.set_xlabel('X (normalized)')
            ax.set_ylabel('Y (normalized)')

            # Dividers in blue, boundaries in red, pedestrian crossings in green.
            for key, color, label, _ in _VECTOR_CLASSES:
                _draw_polylines(ax, vectors.get(key, []), color, label)

            # De-duplicate legend entries: one entry per label.
            handles, labels = ax.get_legend_handles_labels()
            by_label = dict(zip(labels, handles))
            ax.legend(by_label.values(), by_label.keys(), loc='upper right')

            # Per-sample element counts shown in the top-left corner.
            num_dividers = len(vectors.get('divider', []))
            num_boundaries = len(vectors.get('boundary', []))
            num_crossings = len(vectors.get('ped_crossing', []))
            ax.text(0.02, 0.98,
                    f'Dividers: {num_dividers}\nBoundaries: {num_boundaries}\nCrossings: {num_crossings}',
                    transform=ax.transAxes, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

            # Save through the figure object rather than pyplot's implicit
            # "current figure".
            save_path = f'{save_dir}/sample_{idx:03d}.png'
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
        print(f"保存到: {save_path}")

    print(f"\n✅ 可视化完成!图像保存在: {save_dir}/")


def print_vector_statistics(vector_map_file):
    """Print element counts and average polyline lengths for a vector map file.

    For each element class this prints the total number of elements, the
    average number per sample, and (when at least one element exists) the
    average number of points per element. An empty file reports 0.00 as the
    per-sample average instead of raising ``ZeroDivisionError``.

    Args:
        vector_map_file: Path to the pickled vector map file.
    """
    vector_maps = _load_vector_maps(vector_map_file)
    num_maps = len(vector_maps)

    totals = {key: 0 for key, _, _, _ in _VECTOR_CLASSES}
    point_counts = {key: [] for key, _, _, _ in _VECTOR_CLASSES}

    for vectors in vector_maps.values():
        for key, _, _, _ in _VECTOR_CLASSES:
            elements = vectors.get(key, [])
            totals[key] += len(elements)
            point_counts[key].extend(len(vec['points']) for vec in elements)

    print("\n========== 矢量地图统计 ==========")
    print(f"总样本数: {num_maps}")

    for key, _, _, heading in _VECTOR_CLASSES:
        total = totals[key]
        # Guard against an empty mapping.
        per_sample = total / num_maps if num_maps else 0.0
        print(f"\n{heading}:")
        print(f" 总数: {total}")
        print(f" 平均/样本: {per_sample:.2f}")
        if point_counts[key]:
            print(f" 平均点数: {np.mean(point_counts[key]):.1f}")

    print("================================\n")


def main():
    """Parse command-line arguments, print statistics, then optionally render samples."""
    parser = argparse.ArgumentParser(description='可视化矢量地图')
    parser.add_argument('--vector-map-file', type=str,
                        default='data/nuscenes/vector_maps_bevfusion.pkl',
                        help='矢量地图文件')
    parser.add_argument('--num-samples', type=int, default=5,
                        help='可视化的样本数')
    parser.add_argument('--save-dir', type=str, default='vis_vector_map',
                        help='保存目录')
    parser.add_argument('--stats-only', action='store_true',
                        help='只打印统计信息')
    args = parser.parse_args()

    # Statistics are always printed.
    print_vector_statistics(args.vector_map_file)

    # Rendering is skipped with --stats-only.
    if not args.stats_only:
        visualize_vector_map(
            args.vector_map_file,
            num_samples=args.num_samples,
            save_dir=args.save_dir,
        )


if __name__ == '__main__':
    main()