#!/usr/bin/env python3 """ 绘制训练Loss曲线和性能趋势 """ import re import matplotlib.pyplot as plt import numpy as np from collections import defaultdict def parse_training_log(log_file='enhanced_training_6gpus.log'): """解析训练日志""" data = defaultdict(list) with open(log_file, 'r') as f: lines = f.readlines() for line in lines: if 'Epoch [' in line and 'loss:' in line: # 提取iteration iter_match = re.search(r'Epoch \[(\d+)\]\[(\d+)/\d+\]', line) if iter_match: epoch = int(iter_match.group(1)) iter_in_epoch = int(iter_match.group(2)) global_iter = (epoch - 1) * 10299 + iter_in_epoch data['iterations'].append(global_iter) data['epochs'].append(epoch) # 提取total loss loss_match = re.search(r'loss: ([\d.]+)', line) if loss_match: data['total_loss'].append(float(loss_match.group(1))) # 提取分割loss map_losses = [] for category in ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider']: match = re.search(f'loss/map/{category}/focal: ([\\d.]+)', line) if match: map_losses.append(float(match.group(1))) if map_losses: data['map_loss'].append(np.mean(map_losses)) # 提取检测loss obj_match = re.search(r'loss/object/loss_heatmap: ([\d.]+)', line) if obj_match: data['object_loss'].append(float(obj_match.group(1))) return data def plot_curves(data, output_dir='visualizations'): """绘制训练曲线""" import os os.makedirs(output_dir, exist_ok=True) # 1. Total Loss曲线 plt.figure(figsize=(14, 6)) plt.subplot(1, 2, 1) plt.plot(data['iterations'], data['total_loss'], linewidth=0.8, alpha=0.6, color='blue') plt.xlabel('Iteration', fontsize=12) plt.ylabel('Total Loss', fontsize=12) plt.title('Training Total Loss', fontsize=14, fontweight='bold') plt.grid(True, alpha=0.3) # 2. 分任务Loss plt.subplot(1, 2, 2) if len(data['map_loss']) > 0: plt.plot(data['iterations'][:len(data['map_loss'])], data['map_loss'], label='Map Loss (Seg)', linewidth=1, alpha=0.7, color='green') if len(data['object_loss']) > 0: plt.plot(data['iterations'][:len(data['object_loss'])], data['object_loss'], label='Object Loss (Det)', linewidth=1, alpha=0.7, color='red') plt.xlabel('Iteration', fontsize=12) plt.ylabel('Loss', fontsize=12) plt.title('Task-wise Loss', fontsize=14, fontweight='bold') plt.legend(fontsize=10) plt.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(f'{output_dir}/training_loss_curves.png', dpi=150, bbox_inches='tight') print(f"✅ Loss曲线已保存: {output_dir}/training_loss_curves.png") # 3. 各Epoch平均Loss epoch_avg_loss = defaultdict(list) for i, (ep, loss) in enumerate(zip(data['epochs'], data['total_loss'])): epoch_avg_loss[ep].append(loss) epochs = sorted(epoch_avg_loss.keys()) avg_losses = [np.mean(epoch_avg_loss[ep]) for ep in epochs] min_losses = [np.min(epoch_avg_loss[ep]) for ep in epochs] plt.figure(figsize=(10, 6)) plt.plot(epochs, avg_losses, 'o-', label='Average Loss', linewidth=2, markersize=8) plt.plot(epochs, min_losses, 's--', label='Min Loss', linewidth=2, markersize=6, alpha=0.7) plt.xlabel('Epoch', fontsize=12) plt.ylabel('Loss', fontsize=12) plt.title('Loss per Epoch', fontsize=14, fontweight='bold') plt.legend(fontsize=11) plt.grid(True, alpha=0.3) plt.savefig(f'{output_dir}/epoch_loss_trend.png', dpi=150, bbox_inches='tight') print(f"✅ Epoch趋势已保存: {output_dir}/epoch_loss_trend.png") if __name__ == '__main__': print("="*80) print("📈 绘制训练曲线") print("="*80) data = parse_training_log() print(f"\n📊 解析统计:") print(f" 总iterations: {len(data['iterations'])}") print(f" 训练epochs: {len(set(data['epochs']))}") print(f" Loss数据点: {len(data['total_loss'])}") if len(data['total_loss']) > 0: print(f"\n 起始Loss: {data['total_loss'][0]:.4f}") print(f" 当前Loss: {data['total_loss'][-1]:.4f}") print(f" 下降幅度: {(data['total_loss'][0] - data['total_loss'][-1])/data['total_loss'][0]*100:.1f}%") print("\n绘制中...") plot_curves(data) print("\n✅ 完成!") print("="*80)