# Source: bev-project/tools/visualize_vector_map.py (176 lines, 5.9 KiB)

"""
Visualize vector-map annotations.

Used to check visually that extracted vector-map data is correct. Run as a
script to print statistics for a vector-map pickle and, unless
``--stats-only`` is given, render the first few samples to PNG files.
"""
import pickle
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon as MPLPolygon
from matplotlib.collections import LineCollection
def _draw_vector_layer(ax, vector_list, color, label):
    """Draw one map-element category as polylines with vertex markers.

    Args:
        ax: Matplotlib axes to draw on.
        vector_list: Iterable of dicts, each with a ``'points'`` entry.
            ``np.array(points)`` is indexed as ``[:, 0]`` (x) and ``[:, 1]``
            (y), so an (N, 2)-like layout is assumed -- confirm with the
            producer of the pickle.
        color: Single-letter matplotlib colour code (e.g. ``'b'``).
        label: Legend label attached to every polyline of this category;
            duplicate labels are collapsed by the caller.
    """
    for vec in vector_list:
        points = np.array(vec['points'])
        if len(points) == 0:
            continue  # nothing to draw for an empty element
        ax.plot(points[:, 0], points[:, 1], f'{color}-',
                linewidth=2, alpha=0.7, label=label)
        ax.scatter(points[:, 0], points[:, 1], c=color, s=10, alpha=0.5)


def _draw_sample(ax, vectors, idx, token):
    """Draw all layers of one sample plus a de-duplicated legend and counts."""
    # (dict key, matplotlib colour, legend label)
    layers = (
        ('divider', 'b', 'Divider'),          # blue
        ('boundary', 'r', 'Boundary'),        # red
        ('ped_crossing', 'g', 'Ped Crossing'),  # green
    )

    # Axes are fixed to the unit square: coordinates are expected to be
    # normalized, matching the axis labels below.
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)
    # str() so that non-string tokens do not break the slice.
    ax.set_title(f'Sample {idx+1}: {str(token)[:8]}...')
    ax.set_xlabel('X (normalized)')
    ax.set_ylabel('Y (normalized)')

    for key, color, label in layers:
        _draw_vector_layer(ax, vectors.get(key, []), color, label)

    # Every polyline carries its category label; keep one legend entry per
    # label, and skip the legend entirely when nothing was drawn.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    if by_label:
        ax.legend(list(by_label.values()), list(by_label.keys()),
                  loc='upper right')

    num_dividers = len(vectors.get('divider', []))
    num_boundaries = len(vectors.get('boundary', []))
    num_crossings = len(vectors.get('ped_crossing', []))
    ax.text(0.02, 0.98,
            f'Dividers: {num_dividers}\nBoundaries: {num_boundaries}\nCrossings: {num_crossings}',
            transform=ax.transAxes,
            verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))


def visualize_vector_map(vector_map_file, num_samples=5, save_dir='vis_vector_map'):
    """Render the first few samples of a vector-map pickle to PNG files.

    One figure per sample is written to ``save_dir`` as ``sample_001.png``,
    ``sample_002.png``, ... Dividers are drawn in blue, boundaries in red
    and pedestrian crossings in green.

    Args:
        vector_map_file: Path to the vector-map ``.pkl`` file. It is read as
            a dict mapping sample token -> dict of element lists keyed by
            ``'divider'``, ``'boundary'`` and ``'ped_crossing'`` (inferred
            from how it is accessed here).
        num_samples: How many samples to render, taken in the dict's
            iteration order (deterministic, not random). Values <= 0 render
            nothing.
        save_dir: Output directory; created if it does not exist.
    """
    import os

    os.makedirs(save_dir, exist_ok=True)

    print(f"加载矢量地图: {vector_map_file}")
    # NOTE(security): pickle.load can execute arbitrary code -- only pass
    # files you produced or otherwise trust.
    with open(vector_map_file, 'rb') as f:
        vector_maps = pickle.load(f)
    print(f"总样本数: {len(vector_maps)}")

    # First `num_samples` tokens; clamp so a negative value does not turn
    # into "all but the last few" via negative slicing.
    sample_tokens = list(vector_maps.keys())[:max(num_samples, 0)]
    for idx, token in enumerate(sample_tokens):
        fig, ax = plt.subplots(figsize=(12, 12))
        try:
            _draw_sample(ax, vector_maps[token], idx, token)
            save_path = os.path.join(save_dir, f'sample_{idx+1:03d}.png')
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"保存到: {save_path}")
        finally:
            # Release the figure even when drawing/saving fails so open
            # figures do not pile up.
            plt.close(fig)
    print(f"\n✅ 可视化完成!图像保存在: {save_dir}/")
def print_vector_statistics(vector_map_file):
    """Print per-category element counts and average point counts.

    Args:
        vector_map_file: Path to a pickle read as a dict mapping sample
            token -> dict with optional ``'divider'``, ``'boundary'`` and
            ``'ped_crossing'`` lists, where each element is a dict holding a
            ``'points'`` sequence (inferred from how it is accessed here).
            Only pass trusted files: ``pickle.load`` can execute arbitrary
            code.

    For an empty map, the per-sample averages are reported as 0.00 rather
    than raising ``ZeroDivisionError``.
    """
    # (dict key, section title printed in the report)
    categories = (
        ('divider', 'Dividers'),
        ('boundary', 'Boundaries'),
        ('ped_crossing', 'Ped Crossings'),
    )

    with open(vector_map_file, 'rb') as f:
        vector_maps = pickle.load(f)
    num_samples = len(vector_maps)

    # Point count of every element, grouped by category. The element total
    # of a category is simply the length of its list.
    point_counts = {key: [] for key, _ in categories}
    for vectors in vector_maps.values():
        for key, _ in categories:
            point_counts[key].extend(
                len(vec['points']) for vec in vectors.get(key, []))

    print("\n========== 矢量地图统计 ==========")
    print(f"总样本数: {num_samples}")
    for key, title in categories:
        counts = point_counts[key]
        total = len(counts)
        # Guard the empty-map case, which would otherwise divide by zero.
        per_sample = total / num_samples if num_samples else 0.0
        print(f"\n{title}:")
        print(f" 总数: {total}")
        print(f" 平均/样本: {per_sample:.2f}")
        if counts:
            print(f" 平均点数: {np.mean(counts):.1f}")
    print("================================\n")
def _build_arg_parser():
    """Create the command-line parser for this script."""
    cli = argparse.ArgumentParser(description='可视化矢量地图')
    cli.add_argument('--vector-map-file', type=str,
                     default='data/nuscenes/vector_maps_bevfusion.pkl',
                     help='矢量地图文件')
    cli.add_argument('--num-samples', type=int, default=5,
                     help='可视化的样本数')
    cli.add_argument('--save-dir', type=str, default='vis_vector_map',
                     help='保存目录')
    cli.add_argument('--stats-only', action='store_true',
                     help='只打印统计信息')
    return cli


def main():
    """CLI entry point: always print statistics, then render samples unless
    ``--stats-only`` was given."""
    opts = _build_arg_parser().parse_args()

    print_vector_statistics(opts.vector_map_file)

    if opts.stats_only:
        return  # statistics only, no images

    visualize_vector_map(
        opts.vector_map_file,
        num_samples=opts.num_samples,
        save_dir=opts.save_dir,
    )
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()