#!/usr/bin/env python3
"""Analyze a BEVFusion checkpoint file directly, without building the full model.

Prints total/per-module parameter counts, size estimates at FP32/FP16/INT8,
and pruning/quantization suggestions.

Usage:
    python analyze_checkpoint.py <checkpoint_file>
"""
import sys
from collections import defaultdict

import torch


def _tensor_items(state_dict):
    """Yield (key, tensor) pairs from state_dict, skipping non-tensor entries."""
    for key, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            yield key, value


def _load_state_dict(checkpoint_file):
    """Load the checkpoint on CPU and return its state_dict, printing metadata.

    Handles both a wrapped checkpoint ({'state_dict': ..., 'epoch': ..., 'meta': ...})
    and a raw state_dict.
    """
    print("加载checkpoint...")
    # map_location='cpu' so no GPU is required for analysis.
    # NOTE(review): weights_only is left at the torch default; if 'meta' stores
    # custom objects, behavior differs across torch versions — confirm against
    # the torch version used for training.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
        print("Checkpoint包含: state_dict + 其他元数据")
        if 'epoch' in checkpoint:
            print(f"Epoch: {checkpoint['epoch']}")
        if 'meta' in checkpoint:
            print(f"Meta信息: {checkpoint.get('meta', {})}")
    else:
        state_dict = checkpoint
        print("Checkpoint类型: 纯state_dict")
    print()
    return state_dict


def _print_param_stats(state_dict):
    """Print total parameter count and FP32/FP16/INT8 size estimates; return the total."""
    print("=" * 80)
    print("参数统计")
    print("=" * 80)
    total_params = 0
    layer_count = 0
    for _, value in _tensor_items(state_dict):
        total_params += value.numel()
        layer_count += 1
    print(f"总参数量: {total_params:,} ({total_params/1e6:.2f}M)")
    print(f"层数: {layer_count}")
    print(f"模型大小(FP32): {total_params * 4 / 1024 / 1024:.2f} MB")
    print(f"模型大小(FP16): {total_params * 2 / 1024 / 1024:.2f} MB")
    print(f"模型大小(INT8): {total_params * 1 / 1024 / 1024:.2f} MB")
    print()
    return total_params


def _print_module_distribution(state_dict, total_params):
    """Print the per-top-level-module parameter breakdown.

    Returns:
        (module_params, sorted_modules): defaultdict of top-level module name ->
        parameter count, and the same items sorted by count descending.
    """
    print("=" * 80)
    print("各模块参数分布")
    print("=" * 80)
    module_params = defaultdict(int)
    for key, value in _tensor_items(state_dict):
        # Group by the first dotted component of the parameter name.
        module_params[key.split('.')[0]] += value.numel()
    sorted_modules = sorted(module_params.items(), key=lambda x: x[1], reverse=True)
    for name, params in sorted_modules:
        percentage = params / total_params * 100
        size_mb = params * 4 / 1024 / 1024
        print(f"{name:30s}: {params:12,} ({params/1e6:6.2f}M, {percentage:5.2f}%, {size_mb:6.2f}MB)")
    return module_params, sorted_modules


def _print_submodule_breakdown(state_dict, top_module, title, total_params):
    """Print parameter counts of 2-3 level submodules under `top_module`.

    Shared by the encoders and heads sections (previously duplicated code).
    Only submodules with more than 0.1M parameters are shown.
    """
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)
    submodules = defaultdict(int)
    for key, value in _tensor_items(state_dict):
        parts = key.split('.')
        if len(parts) >= 2 and parts[0] == top_module:
            # Use three name levels when available, otherwise two.
            depth = 3 if len(parts) >= 3 else 2
            submodules['.'.join(parts[:depth])] += value.numel()
    for name, params in sorted(submodules.items(), key=lambda x: x[1], reverse=True):
        if params > 1e5:  # only show submodules > 0.1M params
            percentage = params / total_params * 100
            print(f"{name:50s}: {params:12,} ({params/1e6:6.2f}M, {percentage:5.2f}%)")


def _print_optimization_advice(total_params, sorted_modules):
    """Print pruning/quantization size estimates and a suggested optimization route."""
    print("\n" + "=" * 80)
    print("优化建议")
    print("=" * 80)
    if sorted_modules:
        largest = sorted_modules[0]
        print(f"\n1. 最大模块: {largest[0]}")
        print(f" 参数量: {largest[1]/1e6:.2f}M ({largest[1]/total_params*100:.1f}%)")
        print(" 建议: 优先剪枝此模块")
        print(f" 如果剪枝30%: 可减少 {largest[1]*0.3/1e6:.2f}M参数")
    print("\n2. 剪枝潜力估算:")
    print(f" 保守剪枝(20%): {total_params*0.8/1e6:.2f}M参数 → {total_params*0.8*4/1024/1024:.2f}MB (FP32)")
    print(f" 中等剪枝(40%): {total_params*0.6/1e6:.2f}M参数 → {total_params*0.6*4/1024/1024:.2f}MB (FP32)")
    print(f" 激进剪枝(60%): {total_params*0.4/1e6:.2f}M参数 → {total_params*0.4*4/1024/1024:.2f}MB (FP32)")
    print("\n3. 量化收益:")
    print(f" 原始FP32: {total_params*4/1024/1024:.2f} MB")
    print(f" FP16: {total_params*2/1024/1024:.2f} MB (-50%)")
    print(f" INT8: {total_params*1/1024/1024:.2f} MB (-75%)")
    pruned_params = total_params * 0.6  # recommended route assumes 40% pruning
    print("\n4. 推荐优化路线:")
    print(f" 原始模型: {total_params/1e6:.2f}M, {total_params*4/1024/1024:.2f}MB (FP32)")
    print(" ↓ 剪枝40%")
    print(f" 剪枝模型: {pruned_params/1e6:.2f}M, {pruned_params*4/1024/1024:.2f}MB (FP32)")
    print(" ↓ INT8量化")
    print(f" 最终模型: {pruned_params/1e6:.2f}M, {pruned_params*1/1024/1024:.2f}MB (INT8)")
    print(" ")
    print(f" 总压缩比: {(1 - pruned_params*1/1024/1024 / (total_params*4/1024/1024))*100:.1f}%")
    print(" 预期加速: 2.5-3倍")
    print(" 精度损失: <3%")


def analyze_checkpoint(checkpoint_file):
    """Analyze a checkpoint file and print parameter/size statistics.

    Args:
        checkpoint_file: path to a torch checkpoint (.pth), either a raw
            state_dict or a dict containing a 'state_dict' key.

    Returns:
        (total_params, module_params): total number of tensor elements, and a
        defaultdict mapping top-level module name -> parameter count.
    """
    print("=" * 80)
    print("BEVFusion Checkpoint分析")
    print("=" * 80)
    print(f"文件: {checkpoint_file}")
    print()

    state_dict = _load_state_dict(checkpoint_file)
    total_params = _print_param_stats(state_dict)
    module_params, sorted_modules = _print_module_distribution(state_dict, total_params)
    _print_submodule_breakdown(state_dict, 'encoders', "Encoders子模块详细分析", total_params)
    _print_submodule_breakdown(state_dict, 'heads', "Heads子模块详细分析", total_params)
    _print_optimization_advice(total_params, sorted_modules)

    print("\n" + "=" * 80)
    print("分析完成!")
    print("=" * 80)

    return total_params, module_params


if __name__ == '__main__':
    if len(sys.argv) < 2:
        # Restored the argument placeholder that was missing from the usage line.
        print("用法: python analyze_checkpoint.py <checkpoint_file>")
        print("\n示例:")
        print(" python tools/analysis/analyze_checkpoint.py \\")
        print(" runs/enhanced_from_epoch19/epoch_23.pth")
        sys.exit(1)

    checkpoint_file = sys.argv[1]
    total_params, module_params = analyze_checkpoint(checkpoint_file)
    print(f"\n✅ 总参数量: {total_params/1e6:.2f}M")