176 lines
5.9 KiB
Python
176 lines
5.9 KiB
Python
"""
|
||
可视化矢量地图标注
|
||
|
||
用于验证提取的矢量地图数据是否正确
|
||
"""
|
||
|
||
import pickle
|
||
import argparse
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from matplotlib.patches import Polygon as MPLPolygon
|
||
from matplotlib.collections import LineCollection
|
||
|
||
|
||
def visualize_vector_map(vector_map_file, num_samples=5, save_dir='vis_vector_map'):
    """
    Render vector-map samples to PNG files for visual sanity checking.

    Each selected sample is drawn on a fixed [0, 1] x [0, 1] canvas (axes labelled as
    normalized coordinates): dividers in blue, boundaries in red and pedestrian
    crossings in green, plus a text box with the per-class instance counts.

    Args:
        vector_map_file: Path to the pickled vector map. The code reads it as a dict
            mapping sample tokens to dicts whose optional 'divider', 'boundary' and
            'ped_crossing' entries are lists of dicts with a 'points' key; each
            non-empty 'points' value must index as a 2-D array with at least two
            columns (x, y).
        num_samples: How many samples to render; the first ``num_samples`` tokens in
            the dict's iteration order are used (the selection is not random).
        save_dir: Output directory (created if missing). Images are written as
            ``sample_001.png``, ``sample_002.png``, ...
    """
    import os

    os.makedirs(save_dir, exist_ok=True)

    # (key in each sample dict, matplotlib colour code, legend label). Drawing follows
    # this order, which also fixes the order of entries in the de-duplicated legend.
    layers = (
        ('divider', 'b', 'Divider'),
        ('boundary', 'r', 'Boundary'),
        ('ped_crossing', 'g', 'Ped Crossing'),
    )

    # Load data. NOTE(review): pickle.load can execute arbitrary code -- only open
    # vector-map files from a trusted source.
    print(f"加载矢量地图: {vector_map_file}")
    with open(vector_map_file, 'rb') as f:
        vector_maps = pickle.load(f)

    print(f"总样本数: {len(vector_maps)}")

    # Take the first num_samples tokens in iteration order (deterministic, not random).
    sample_tokens = list(vector_maps)[:num_samples]

    for idx, token in enumerate(sample_tokens):
        vectors = vector_maps[token]

        fig, ax = plt.subplots(figsize=(12, 12))
        try:
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.set_aspect('equal')
            ax.grid(True, alpha=0.3)
            # str() lets non-string tokens be abbreviated too; identical for str tokens.
            ax.set_title(f'Sample {idx+1}: {str(token)[:8]}...')
            ax.set_xlabel('X (normalized)')
            ax.set_ylabel('Y (normalized)')

            counts = {}
            for key, color, label in layers:
                instances = vectors.get(key, [])
                counts[key] = len(instances)
                for vec in instances:
                    points = np.asarray(vec['points'])
                    if len(points) == 0:
                        continue  # nothing to draw for an empty polyline
                    ax.plot(points[:, 0], points[:, 1], f'{color}-',
                            linewidth=2, alpha=0.7, label=label)
                    ax.scatter(points[:, 0], points[:, 1], c=color, s=10, alpha=0.5)

            # Every polyline carries its class label; keep one legend entry per label.
            handles, labels = ax.get_legend_handles_labels()
            by_label = dict(zip(labels, handles))
            if by_label:  # skip an empty legend box when the sample has no geometry
                ax.legend(list(by_label.values()), list(by_label.keys()), loc='upper right')

            ax.text(0.02, 0.98,
                    f"Dividers: {counts['divider']}\nBoundaries: {counts['boundary']}"
                    f"\nCrossings: {counts['ped_crossing']}",
                    transform=ax.transAxes,
                    verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

            save_path = os.path.join(save_dir, f'sample_{idx+1:03d}.png')
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"保存到: {save_path}")
        finally:
            # Close this specific figure even if drawing or saving fails, so open
            # figures do not accumulate across samples.
            plt.close(fig)

    print(f"\n✅ 可视化完成!图像保存在: {save_dir}/")
|
||
|
||
|
||
def print_vector_statistics(vector_map_file):
    """
    Print per-class instance and point statistics for a vector-map pickle.

    For each class ('divider', 'boundary', 'ped_crossing') this prints the total
    number of instances, the mean number of instances per sample and, when at least
    one instance exists, the mean number of points per instance.

    Args:
        vector_map_file: Path to the pickled dict mapping sample tokens to dicts whose
            optional per-class entries are lists of dicts with a 'points' sequence.

    A pickle with no samples reports 0.00 instances per sample instead of raising
    ZeroDivisionError.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
    with open(vector_map_file, 'rb') as f:
        vector_maps = pickle.load(f)

    num_samples = len(vector_maps)

    # (key in each sample dict, heading printed for that class), in print order.
    categories = (
        ('divider', 'Dividers'),
        ('boundary', 'Boundaries'),
        ('ped_crossing', 'Ped Crossings'),
    )

    # Per class: one point count per instance, so len() is the instance total.
    point_counts = {key: [] for key, _ in categories}
    for vectors in vector_maps.values():
        for key, _ in categories:
            point_counts[key].extend(len(vec['points']) for vec in vectors.get(key, []))

    print("\n========== 矢量地图统计 ==========")
    print(f"总样本数: {num_samples}")
    for key, heading in categories:
        counts = point_counts[key]
        total = len(counts)
        # Guard the division: an empty pickle would otherwise raise ZeroDivisionError.
        per_sample = total / num_samples if num_samples else 0.0
        print(f"\n{heading}:")
        print(f" 总数: {total}")
        print(f" 平均/样本: {per_sample:.2f}")
        if counts:
            print(f" 平均点数: {np.mean(counts):.1f}")
    print("================================\n")
|
||
|
||
|
||
def main():
    """
    Command-line entry point.

    Always prints the statistics for ``--vector-map-file``; afterwards renders
    sample images unless ``--stats-only`` is given.
    """
    cli = argparse.ArgumentParser(description='可视化矢量地图')

    # (flag, keyword arguments for add_argument), registered in this order so the
    # --help listing keeps the same sequence.
    option_specs = (
        ('--vector-map-file', dict(type=str,
                                   default='data/nuscenes/vector_maps_bevfusion.pkl',
                                   help='矢量地图文件')),
        ('--num-samples', dict(type=int, default=5,
                               help='可视化的样本数')),
        ('--save-dir', dict(type=str, default='vis_vector_map',
                            help='保存目录')),
        ('--stats-only', dict(action='store_true',
                              help='只打印统计信息')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()

    # Statistics are printed unconditionally.
    print_vector_statistics(opts.vector_map_file)

    if opts.stats_only:
        return

    visualize_vector_map(
        opts.vector_map_file,
        num_samples=opts.num_samples,
        save_dir=opts.save_dir,
    )
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()
|
||
|