132 lines
4.7 KiB
Python
132 lines
4.7 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
绘制训练Loss曲线和性能趋势
|
|
"""
|
|
|
|
import re
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from collections import defaultdict
|
|
|
|
def parse_training_log(log_file='enhanced_training_6gpus.log', iters_per_epoch=None):
    """Parse per-iteration losses from a text training log.

    A line is used as one data point when it contains both ``Epoch [`` and
    ``loss:`` and carries an ``Epoch [<epoch>][<iter>/<total>]`` header.
    Lines without that header are skipped, so every loss series stays
    index-aligned with ``iterations``.

    Args:
        log_file: Path to the training log.
        iters_per_epoch: Iterations per epoch, used to build a global
            iteration index. When ``None`` (default), the ``<total>`` printed
            in each line's header is used.

    Returns:
        ``defaultdict(list)`` with the keys:

        * ``iterations``: ``(epoch - 1) * iters_per_epoch + iter``.
        * ``epochs``: epoch number from the header.
        * ``total_loss``: the overall ``loss:`` value.
        * ``map_loss``: mean of the ``loss/map/<category>/focal`` values
          present on the line.
        * ``object_loss``: the ``loss/object/loss_heatmap`` value.

        Each loss series has exactly one entry per iteration, with ``nan``
        where the line lacks that value. A loss series that never appears in
        the log is left as an empty list.

    Raises:
        OSError: if ``log_file`` cannot be opened (e.g. FileNotFoundError).
    """
    nan = float('nan')
    # Plain/decimal/exponent numbers plus the ``nan``/``inf`` spellings that
    # float formatting produces for diverged losses.
    number = r'([-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|nan|inf))'
    header_re = re.compile(r'Epoch \[(\d+)\]\[(\d+)/(\d+)\]')
    # The lookbehind stops keys that merely end in "loss" (``depth_loss:``,
    # ``.../loss:``) from being mistaken for the overall loss.
    total_re = re.compile(r'(?<![\w/])loss: ' + number)
    map_res = [
        re.compile(f'loss/map/{re.escape(category)}/focal: ' + number)
        for category in ('drivable_area', 'ped_crossing', 'walkway',
                         'stop_line', 'carpark_area', 'divider')
    ]
    object_re = re.compile(r'loss/object/loss_heatmap: ' + number)

    def _first_float(pattern, text):
        """Return the first number captured by *pattern* in *text*, or nan."""
        match = pattern.search(text)
        return float(match.group(1)) if match else nan

    data = defaultdict(list)

    with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            if 'Epoch [' not in line or 'loss:' not in line:
                continue

            header = header_re.search(line)
            if header is None:
                # Without an iteration index the losses on this line cannot be
                # placed on the x-axis; recording them would misalign series.
                continue

            epoch = int(header.group(1))
            iter_in_epoch = int(header.group(2))
            per_epoch = (iters_per_epoch if iters_per_epoch is not None
                         else int(header.group(3)))
            data['iterations'].append((epoch - 1) * per_epoch + iter_in_epoch)
            data['epochs'].append(epoch)

            # Total loss
            data['total_loss'].append(_first_float(total_re, line))

            # Segmentation loss: mean over the map categories present.
            map_values = [float(m.group(1))
                          for m in (p.search(line) for p in map_res) if m]
            data['map_loss'].append(float(np.mean(map_values)) if map_values else nan)

            # Detection loss
            data['object_loss'].append(_first_float(object_re, line))

    # A series that never appeared is reported as empty rather than all-nan,
    # so callers can keep using ``len(...) > 0`` to test for its presence.
    for key in ('total_loss', 'map_loss', 'object_loss'):
        if all(np.isnan(v) for v in data[key]):
            data[key] = []

    return data
|
|
|
|
def _plot_loss_overview(data, path):
    """Save total loss (left) and task-wise losses (right) vs. iteration to *path*."""
    iterations = list(data.get('iterations', []))
    total = list(data.get('total_loss', []))

    fig, (ax_total, ax_task) = plt.subplots(1, 2, figsize=(14, 6))
    try:
        # 1. Total loss. Truncate to the shorter list so a length mismatch
        # cannot make matplotlib reject the x/y pair.
        n = min(len(iterations), len(total))
        ax_total.plot(iterations[:n], total[:n],
                      linewidth=0.8, alpha=0.6, color='blue')
        ax_total.set_xlabel('Iteration', fontsize=12)
        ax_total.set_ylabel('Total Loss', fontsize=12)
        ax_total.set_title('Training Total Loss', fontsize=14, fontweight='bold')
        ax_total.grid(True, alpha=0.3)

        # 2. Task-wise losses; each one is drawn only if the log contained it.
        plotted_any = False
        for key, label, color in (('map_loss', 'Map Loss (Seg)', 'green'),
                                  ('object_loss', 'Object Loss (Det)', 'red')):
            series = list(data.get(key, []))
            if not series:
                continue
            n = min(len(iterations), len(series))
            ax_task.plot(iterations[:n], series[:n],
                         label=label, linewidth=1, alpha=0.7, color=color)
            plotted_any = True
        ax_task.set_xlabel('Iteration', fontsize=12)
        ax_task.set_ylabel('Loss', fontsize=12)
        ax_task.set_title('Task-wise Loss', fontsize=14, fontweight='bold')
        if plotted_any:
            # legend() with no labelled artists only emits a warning.
            ax_task.legend(fontsize=10)
        ax_task.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(path, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def _plot_epoch_trend(data, path):
    """Save the per-epoch mean and minimum total loss to *path*."""
    # Group finite losses by epoch; nan/inf entries would poison mean/min.
    per_epoch = defaultdict(list)
    for epoch, loss in zip(data.get('epochs', []), data.get('total_loss', [])):
        if np.isfinite(loss):
            per_epoch[epoch].append(loss)

    epochs = sorted(per_epoch)
    avg_losses = [np.mean(per_epoch[ep]) for ep in epochs]
    min_losses = [np.min(per_epoch[ep]) for ep in epochs]

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(epochs, avg_losses, 'o-', label='Average Loss', linewidth=2, markersize=8)
        ax.plot(epochs, min_losses, 's--', label='Min Loss', linewidth=2, markersize=6, alpha=0.7)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.set_title('Loss per Epoch', fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
        fig.savefig(path, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)


def plot_curves(data, output_dir='visualizations'):
    """Render training-loss plots and save them as PNG files.

    Two images are written into ``output_dir``, which is created if missing:

    * ``training_loss_curves.png``: total loss vs. iteration, and the
      map/object task losses vs. iteration.
    * ``epoch_loss_trend.png``: mean and minimum total loss per epoch.

    Args:
        data: Mapping shaped like the result of ``parse_training_log``. Reads
            the ``iterations``, ``epochs``, ``total_loss``, ``map_loss`` and
            ``object_loss`` lists; missing keys are treated as empty.
        output_dir: Directory that receives the images.
    """
    import os

    os.makedirs(output_dir, exist_ok=True)

    curves_path = os.path.join(output_dir, 'training_loss_curves.png')
    _plot_loss_overview(data, curves_path)
    print(f"✅ Loss曲线已保存: {curves_path}")

    trend_path = os.path.join(output_dir, 'epoch_loss_trend.png')
    _plot_epoch_trend(data, trend_path)
    print(f"✅ Epoch趋势已保存: {trend_path}")
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the log, print summary statistics, then plot.
    import argparse

    parser = argparse.ArgumentParser(
        description='Plot training loss curves from a training log.')
    parser.add_argument('log_file', nargs='?', default='enhanced_training_6gpus.log',
                        help='training log to parse (default: %(default)s)')
    parser.add_argument('--output-dir', default='visualizations',
                        help='directory for the generated PNGs (default: %(default)s)')
    args = parser.parse_args()

    print("="*80)
    print("📈 绘制训练曲线")
    print("="*80)

    try:
        data = parse_training_log(args.log_file)
    except FileNotFoundError:
        raise SystemExit(f"❌ 找不到日志文件: {args.log_file}") from None

    print("\n📊 解析统计:")
    print(f" 总iterations: {len(data['iterations'])}")
    print(f" 训练epochs: {len(set(data['epochs']))}")
    print(f" Loss数据点: {len(data['total_loss'])}")

    # Summarise using finite values only, so a nan at either end of the run
    # does not turn the start/current/drop figures into nan.
    finite_losses = [v for v in data['total_loss'] if np.isfinite(v)]
    if finite_losses:
        first_loss, last_loss = finite_losses[0], finite_losses[-1]
        print(f"\n 起始Loss: {first_loss:.4f}")
        print(f" 当前Loss: {last_loss:.4f}")
        if first_loss != 0:
            # Relative drop is undefined when the starting loss is zero.
            print(f" 下降幅度: {(first_loss - last_loss) / first_loss * 100:.1f}%")

    print("\n绘制中...")
    plot_curves(data, output_dir=args.output_dir)

    print("\n✅ 完成!")
    print("="*80)
|
|
|
|
|