#!/usr/bin/env python3
"""
Phase 4B RMT-PPAD 评估结果可视化脚本
解析评估结果文件并生成可视化输出
"""

import argparse
import os
import pickle
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def parse_results(results_file):
    """Load a pickled MMDetection3D evaluation-results file.

    Args:
        results_file: Path to the ``.pkl`` file to load.

    Returns:
        The unpickled object, or ``None`` when ``results_file`` does not exist
        (an error message is printed in that case).

    Warning:
        ``pickle.load`` can execute arbitrary code embedded in the file; only
        pass result files you produced yourself or otherwise trust.
    """
    print(f"📂 读取结果文件: {results_file}")

    if os.path.exists(results_file):
        with open(results_file, 'rb') as handle:
            return pickle.load(handle)

    print(f"❌ 文件不存在: {results_file}")
    return None


def visualize_bbox_metrics(results, save_path="bbox_metrics.png"):
    """Plot per-class 3D detection AP metrics as three side-by-side bar charts.

    Reads ``results['bbox']``. Every entry whose key starts with ``car``,
    ``truck``, ``bus`` or ``pedestrian`` and whose value is a dict contributes
    one class; its ``ap``, ``ap_0.5`` and ``ap_0.75`` values are plotted
    (missing keys count as 0). The class label is the key's text before the
    first underscore.

    Args:
        results: Evaluation-results mapping; only the ``'bbox'`` entry is used.
        save_path: Path of the PNG written at 300 dpi.

    Returns:
        None. Returns early, without creating a figure, when ``'bbox'`` is
        absent or no matching per-class metric dicts are found.
    """
    print("📊 可视化3D检测指标...")

    if 'bbox' not in results:
        print("⚠️ 没有找到bbox指标")
        return

    bbox_results = results['bbox']
    target_prefixes = ('car', 'truck', 'bus', 'pedestrian')

    # Collect per-class metrics.
    classes, aps, ap_50s, ap_75s = [], [], [], []
    for class_name, metrics in bbox_results.items():
        # Skip non-target keys, and scalar entries under a matching key
        # (e.g. ``'car_AP': 0.5``) which would otherwise crash on ``.get``.
        if not class_name.startswith(target_prefixes) or not isinstance(metrics, dict):
            continue
        classes.append(class_name.split('_')[0])  # drop the suffix, e.g. "_AP"
        aps.append(metrics.get('ap', 0))
        ap_50s.append(metrics.get('ap_0.5', 0))
        ap_75s.append(metrics.get('ap_0.75', 0))

    if not classes:
        print("⚠️ 没有找到有效的类别指标")
        return

    # (values, legend label, subplot title, bar colour) for each subplot.
    panels = (
        (aps, 'AP@0.5:0.95', 'Average Precision (AP@0.5:0.95)', 'skyblue'),
        (ap_50s, 'AP@0.5', 'Average Precision (AP@0.5)', 'lightgreen'),
        (ap_75s, 'AP@0.75', 'Average Precision (AP@0.75)', 'salmon'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    x = np.arange(len(classes))
    # Each subplot shows a single series, so bars are centred on their tick
    # (offsetting them by +/- width misaligned them with the class labels).
    width = 0.5

    for ax, (values, label, title, color) in zip(axes, panels):
        ax.bar(x, values, width, label=label, color=color)
        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(classes, rotation=45)
        ax.set_ylabel('AP')
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"✅ 3D检测指标图表已保存: {save_path}")


def visualize_map_metrics(results, save_path="map_metrics.png"):
    """Plot per-class BEV segmentation IoU and Dice as grouped bar charts.

    Reads ``results['map']``. Every entry whose value is a dict containing
    both ``'IoU'`` and ``'Dice'`` contributes one class. A trailing numeric
    suffix (``_<n>`` or ``_<n>.<m>``) is stripped from the key. When several
    keys collapse to the same name, only the first one encountered is plotted.

    Args:
        results: Evaluation-results mapping; only the ``'map'`` entry is used.
        save_path: Path of the PNG written at 300 dpi.

    Returns:
        None. Returns early, without creating a figure, when ``'map'`` is
        absent or no usable per-class metric dicts are found.
    """
    import re  # local import so this block stands alone

    print("🗺️ 可视化BEV分割指标...")

    if 'map' not in results:
        print("⚠️ 没有找到map指标")
        return

    map_results = results['map']

    classes, ious, dices = [], [], []
    seen = set()
    for class_name, metrics in map_results.items():
        # Scalar entries (e.g. an overall mIoU float) cannot hold per-class metrics.
        if not isinstance(metrics, dict) or 'IoU' not in metrics or 'Dice' not in metrics:
            continue
        # Strip only a *trailing* numeric suffix such as "_0" or "_0.5". Chained
        # str.replace("_0", "")... removed these substrings anywhere in the name
        # and mangled multi-digit suffixes (e.g. "road_10" -> "road0").
        clean_name = re.sub(r'_\d+(?:\.\d+)?$', '', class_name) or class_name
        if clean_name in seen:
            continue  # keep the first occurrence only
        seen.add(clean_name)
        classes.append(clean_name)
        ious.append(metrics['IoU'])
        dices.append(metrics['Dice'])

    if not classes:
        print("⚠️ 没有找到有效的分割指标")
        return

    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(len(classes))
    width = 0.35

    bars_iou = ax.bar(x - width / 2, ious, width, label='IoU', color='lightblue', alpha=0.8)
    bars_dice = ax.bar(x + width / 2, dices, width, label='Dice', color='lightcoral', alpha=0.8)

    # NOTE(review): these CJK labels may render as empty boxes unless a font
    # with CJK glyphs is configured in matplotlib.
    ax.set_xlabel('类别')
    ax.set_ylabel('分数')
    ax.set_title('BEV分割性能指标')
    ax.set_xticks(x)
    ax.set_xticklabels(classes, rotation=45, ha='right')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Annotate every bar with its value, just above the bar top.
    for bars, values in ((bars_iou, ious), (bars_dice, dices)):
        for bar, value in zip(bars, values):
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                    f'{value:.3f}', ha='center', va='bottom', fontsize=8)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"✅ BEV分割指标图表已保存: {save_path}")


def print_summary(results, log_file=None):
    """Print a text summary of detection and segmentation metrics.

    Args:
        results: Evaluation-results mapping. ``results['bbox']`` is scanned for
            an ``'NDS'`` value and per-class dicts containing ``'ap'``;
            ``results['map']`` for per-class dicts containing ``'IoU'``
            (``'Dice'`` defaults to 0 when missing).
        log_file: Optional path. When given, the same metric lines (without
            the console decorations) plus a UTC timestamp from
            ``np.datetime64('now')`` are written there as UTF-8 text,
            overwriting any existing file.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("📈 Phase 4B RMT-PPAD 评估结果摘要")
    print(rule)

    # 3D detection metrics
    bbox_lines = []
    if 'bbox' in results:
        print("\n🚗 3D检测性能:")
        bbox_results = results['bbox']

        if 'NDS' in bbox_results:
            # Previously `print(".4f")`: a broken f-string that never showed the value.
            bbox_lines.append(f" NDS: {bbox_results['NDS']:.4f}")

        for class_name, metrics in bbox_results.items():
            if isinstance(metrics, dict) and 'ap' in metrics:
                bbox_lines.append(f" {class_name}: AP@0.5:0.95 = {metrics.get('ap', 0):.4f}, "
                                  f"AP@0.5 = {metrics.get('ap_0.5', 0):.4f}")

        for line in bbox_lines:
            print(line)

    # BEV segmentation metrics
    map_lines = []
    if 'map' in results:
        print("\n🗺️ BEV分割性能:")
        ious, dices = [], []

        for class_name, metrics in results['map'].items():
            if isinstance(metrics, dict) and 'IoU' in metrics:
                iou = metrics['IoU']
                dice = metrics.get('Dice', 0)
                map_lines.append(f" {class_name}: IoU = {iou:.4f}, Dice = {dice:.4f}")
                ious.append(iou)
                dices.append(dice)

        if ious:
            map_lines.append(f" 平均IoU: {sum(ious)/len(ious):.4f}, "
                             f"平均Dice: {sum(dices)/len(dices):.4f}")

        for line in map_lines:
            print(line)

    print("\n" + rule)

    # Save to the log file. Explicit UTF-8: the text contains CJK characters,
    # which the platform default encoding may not be able to represent.
    if log_file:
        with open(log_file, 'w', encoding='utf-8') as f:
            f.write("Phase 4B RMT-PPAD 评估结果摘要\n\n")
            f.write("生成时间: " + str(np.datetime64('now')) + "\n\n")

            if 'bbox' in results:
                f.write("3D检测性能:\n")
                f.writelines(line + "\n" for line in bbox_lines)

            if 'map' in results:
                f.write("\nBEV分割性能:\n")
                f.writelines(line + "\n" for line in map_lines)

        print(f"📝 详细结果已保存到: {log_file}")


def create_visualization_script(results_file, output_dir="visualization"):
    """Run the whole visualization pipeline for one results file.

    Creates ``output_dir`` (even if the results file later turns out to be
    missing), loads ``results_file`` with ``parse_results``, and on success
    writes the detection chart, the segmentation chart and a text summary
    into ``output_dir``, then prints where each artifact went.
    """
    os.makedirs(output_dir, exist_ok=True)

    results = parse_results(results_file)
    if results is not None:
        bbox_plot, map_plot, summary_file = (
            os.path.join(output_dir, filename)
            for filename in ("bbox_metrics.png", "map_metrics.png", "evaluation_summary.txt")
        )

        visualize_bbox_metrics(results, bbox_plot)
        visualize_map_metrics(results, map_plot)
        print_summary(results, summary_file)

        print(f"\n🎉 可视化完成!输出目录: {output_dir}")
        print(f" 📊 3D检测图表: {bbox_plot}")
        print(f" 🗺️ 分割图表: {map_plot}")
        print(f" 📝 结果摘要: {summary_file}")


def main():
    """Command-line entry point: parse arguments and run the visualization."""
    parser = argparse.ArgumentParser(description="Phase 4B RMT-PPAD 评估结果可视化工具")
    parser.add_argument("results_file", help="评估结果文件路径 (.pkl)")
    parser.add_argument(
        "--out-dir",
        default="visualization",
        help="输出目录 (默认: visualization)",
    )
    args = parser.parse_args()

    banner = (
        "🚀 Phase 4B RMT-PPAD 评估结果可视化工具",
        f"📁 结果文件: {args.results_file}",
        f"📂 输出目录: {args.out_dir}",
        "",
    )
    for line in banner:
        print(line)

    create_visualization_script(args.results_file, args.out_dir)


if __name__ == "__main__":
    main()