bev-project/scripts/utils/plot_training_curves.py

132 lines
4.7 KiB
Python

#!/usr/bin/env python3
"""
绘制训练Loss曲线和性能趋势
"""
import re
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
def parse_training_log(log_file='enhanced_training_6gpus.log', iters_per_epoch=None):
    """Parse a training log into index-aligned per-iteration loss series.

    A line is used only if it contains ``'Epoch ['`` and ``'loss:'`` and both
    its ``Epoch [e][i/N]`` counter and its total ``loss: <value>`` field can be
    parsed.  Every recorded line appends exactly one entry to each series, so
    ``data[key][k]`` refers to the same log line for every key.

    Args:
        log_file: Path to the text training log.
        iters_per_epoch: Iterations per epoch used to compute the global
            iteration index ``(epoch - 1) * iters_per_epoch + i``.  ``None``
            (default) takes ``N`` from each line's ``[i/N]`` counter.

    Returns:
        ``defaultdict(list)`` with keys:

        * ``iterations`` - global iteration index.
        * ``epochs`` - epoch number from the counter.
        * ``total_loss`` - value of the ``loss:`` field.
        * ``map_loss`` - mean of the ``loss/map/<category>/focal`` values found
          on the line (NaN when none are present on that line).
        * ``object_loss`` - ``loss/object/loss_heatmap`` value (NaN when absent).

        ``map_loss`` / ``object_loss`` are empty lists if no line contained
        them at all, so callers can keep testing ``len(...) > 0``.

    Raises:
        OSError: If ``log_file`` cannot be opened.
    """
    number = r'([\d.]+(?:[eE][-+]?\d+)?)'
    epoch_re = re.compile(r'Epoch \[(\d+)\]\[(\d+)/(\d+)\]')
    # The lookbehind keeps keys that merely end in "loss" (e.g. "aux_loss: ")
    # or namespaced keys from being mistaken for the total loss.
    total_re = re.compile(r'(?<![\w/])loss: ' + number)
    map_res = [
        re.compile(rf'loss/map/{re.escape(category)}/focal: ' + number)
        for category in ('drivable_area', 'ped_crossing', 'walkway',
                         'stop_line', 'carpark_area', 'divider')
    ]
    obj_re = re.compile(r'loss/object/loss_heatmap: ' + number)

    data = defaultdict(list)
    # errors='replace' so stray non-UTF-8 bytes in a log do not abort parsing.
    with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:  # stream line by line; logs can be large
            if 'Epoch [' not in line or 'loss:' not in line:
                continue
            iter_match = epoch_re.search(line)
            loss_match = total_re.search(line)
            # Require both so every series stays aligned with 'iterations'.
            if iter_match is None or loss_match is None:
                continue

            epoch = int(iter_match.group(1))
            iter_in_epoch = int(iter_match.group(2))
            per_epoch = (int(iter_match.group(3)) if iters_per_epoch is None
                         else iters_per_epoch)
            data['iterations'].append((epoch - 1) * per_epoch + iter_in_epoch)
            data['epochs'].append(epoch)
            data['total_loss'].append(float(loss_match.group(1)))

            # Segmentation: average over whichever map categories are logged.
            map_losses = [float(m.group(1))
                          for m in (r.search(line) for r in map_res) if m]
            data['map_loss'].append(
                float(np.mean(map_losses)) if map_losses else np.nan)

            # Detection: heatmap loss only.
            obj_match = obj_re.search(line)
            data['object_loss'].append(
                float(obj_match.group(1)) if obj_match else np.nan)

    # A series that never appeared is reported as empty rather than all-NaN.
    for key in ('map_loss', 'object_loss'):
        if np.all(np.isnan(data[key])):
            data[key] = []
    return data
def _common_prefix(xs, ys):
    """Return ``xs`` and ``ys`` truncated to their common (shorter) length."""
    n = min(len(xs), len(ys))
    return xs[:n], ys[:n]


def _plot_loss_curves(data, path):
    """Save total loss (left) and per-task losses (right) vs. iteration to *path*."""
    iterations = data.get('iterations', [])
    fig, (ax_total, ax_task) = plt.subplots(1, 2, figsize=(14, 6))
    try:
        # 1. Total loss curve. Truncate to the shorter series so x/y lengths
        #    always match instead of raising inside matplotlib.
        x, y = _common_prefix(iterations, data.get('total_loss', []))
        ax_total.plot(x, y, linewidth=0.8, alpha=0.6, color='blue')
        ax_total.set_xlabel('Iteration', fontsize=12)
        ax_total.set_ylabel('Total Loss', fontsize=12)
        ax_total.set_title('Training Total Loss', fontsize=14, fontweight='bold')
        ax_total.grid(True, alpha=0.3)

        # 2. Per-task losses; a series is drawn only if it has any entries.
        for key, label, color in (
            ('map_loss', 'Map Loss (Seg)', 'green'),
            ('object_loss', 'Object Loss (Det)', 'red'),
        ):
            values = data.get(key, [])
            if len(values) > 0:
                x, y = _common_prefix(iterations, values)
                ax_task.plot(x, y, label=label, linewidth=1, alpha=0.7, color=color)
        ax_task.set_xlabel('Iteration', fontsize=12)
        ax_task.set_ylabel('Loss', fontsize=12)
        ax_task.set_title('Task-wise Loss', fontsize=14, fontweight='bold')
        # Only add a legend when something is labeled; matplotlib warns otherwise.
        handles, _ = ax_task.get_legend_handles_labels()
        if handles:
            ax_task.legend(fontsize=10)
        ax_task.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(path, dpi=150, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until explicitly closed.
        plt.close(fig)


def _plot_epoch_trend(epochs, losses, path):
    """Save the mean and minimum total loss per epoch to *path*."""
    per_epoch = defaultdict(list)
    # zip() stops at the shorter list, so unequal lengths are tolerated.
    for epoch, loss in zip(epochs, losses):
        per_epoch[epoch].append(loss)
    ordered = sorted(per_epoch)
    avg_losses = [np.mean(per_epoch[ep]) for ep in ordered]
    min_losses = [np.min(per_epoch[ep]) for ep in ordered]

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(ordered, avg_losses, 'o-', label='Average Loss', linewidth=2, markersize=8)
        ax.plot(ordered, min_losses, 's--', label='Min Loss', linewidth=2, markersize=6, alpha=0.7)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.set_title('Loss per Epoch', fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
        fig.savefig(path, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)


def plot_curves(data, output_dir='visualizations'):
    """Render training-loss figures from parsed log data and save them as PNGs.

    Writes two files into ``output_dir`` (created if missing):

    * ``training_loss_curves.png`` - total loss vs. iteration (left) and the
      ``map_loss`` / ``object_loss`` series vs. iteration (right).
    * ``epoch_loss_trend.png`` - mean and minimum total loss per epoch.

    Args:
        data: Mapping with list values under ``iterations``, ``epochs``,
            ``total_loss`` and optionally ``map_loss`` / ``object_loss``
            (the shape produced by ``parse_training_log``).  Task series are
            paired with ``iterations`` by position.
        output_dir: Directory for the output images.
    """
    import os
    os.makedirs(output_dir, exist_ok=True)

    loss_path = os.path.join(output_dir, 'training_loss_curves.png')
    _plot_loss_curves(data, loss_path)
    print(f"✅ Loss曲线已保存: {loss_path}")

    trend_path = os.path.join(output_dir, 'epoch_loss_trend.png')
    _plot_epoch_trend(data.get('epochs', []), data.get('total_loss', []), trend_path)
    print(f"✅ Epoch趋势已保存: {trend_path}")
if __name__ == '__main__':
    # Script entry point: parse the log, print summary statistics, then plot.
    import argparse

    parser = argparse.ArgumentParser(
        description='Plot training loss curves from a training log.')
    parser.add_argument('log_file', nargs='?',
                        default='enhanced_training_6gpus.log',
                        help='training log to parse')
    parser.add_argument('--output-dir', default='visualizations',
                        help='directory for the generated PNG files')
    args = parser.parse_args()

    print("="*80)
    print("📈 绘制训练曲线")
    print("="*80)

    data = parse_training_log(args.log_file)
    print(f"\n📊 解析统计:")
    print(f" 总iterations: {len(data['iterations'])}")
    print(f" 训练epochs: {len(set(data['epochs']))}")
    print(f" Loss数据点: {len(data['total_loss'])}")

    losses = data['total_loss']
    if losses:
        start, current = losses[0], losses[-1]
        print(f"\n 起始Loss: {start:.4f}")
        print(f" 当前Loss: {current:.4f}")
        # The relative drop is undefined when the starting loss is zero.
        if start != 0:
            print(f" 下降幅度: {(start - current) / start * 100:.1f}%")

    print("\n绘制中...")
    plot_curves(data, output_dir=args.output_dir)
    print("\n✅ 完成!")
    print("="*80)