#!/usr/bin/env python3
"""
Analyze a checkpoint file directly.

No need to build the full model.
"""

import torch
import sys
from collections import defaultdict
def analyze_checkpoint(checkpoint_file):
    """Analyze a BEVFusion checkpoint file and print a parameter report.

    Loads only the checkpoint's tensors (no model construction) and prints:
    total parameter statistics, a per-top-level-module breakdown, detailed
    breakdowns of the ``encoders`` and ``heads`` sub-modules, and rough
    pruning/quantization size estimates.

    Args:
        checkpoint_file: Path to a ``.pth`` file. May be either a raw
            state_dict or a wrapper dict with a ``state_dict`` key plus
            optional ``epoch``/``meta`` metadata.

    Returns:
        tuple: ``(total_params, module_params)`` — the total tensor element
        count and a ``defaultdict`` mapping each top-level module name to
        its parameter count.
    """
    print("=" * 80)
    print("BEVFusion Checkpoint分析")
    print("=" * 80)
    print(f"文件: {checkpoint_file}")
    print()

    # NOTE(review): torch.load unpickles arbitrary Python objects on
    # torch < 2.6 — only run this on checkpoints from a trusted source.
    print("加载checkpoint...")
    checkpoint = torch.load(checkpoint_file, map_location='cpu')

    # Unwrap the state_dict when the checkpoint carries training metadata.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
        print("Checkpoint包含: state_dict + 其他元数据")
        if 'epoch' in checkpoint:
            print(f"Epoch: {checkpoint['epoch']}")
        if 'meta' in checkpoint:
            print(f"Meta信息: {checkpoint.get('meta', {})}")
    else:
        state_dict = checkpoint
        print("Checkpoint类型: 纯state_dict")

    print()

    # ---- overall parameter statistics ----
    print("=" * 80)
    print("参数统计")
    print("=" * 80)

    total_params = 0
    layer_count = 0

    # Only tensors count; skip any non-tensor entries in the state_dict.
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            total_params += value.numel()
            layer_count += 1

    print(f"总参数量: {total_params:,} ({total_params/1e6:.2f}M)")
    print(f"层数: {layer_count}")
    print(f"模型大小(FP32): {total_params * 4 / 1024 / 1024:.2f} MB")
    print(f"模型大小(FP16): {total_params * 2 / 1024 / 1024:.2f} MB")
    print(f"模型大小(INT8): {total_params * 1 / 1024 / 1024:.2f} MB")
    print()

    # ---- per-module distribution (grouped by the first dotted segment) ----
    print("=" * 80)
    print("各模块参数分布")
    print("=" * 80)

    module_params = defaultdict(int)

    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            continue
        # str.split always yields at least one element, so the first
        # segment is always a valid top-level module name.
        module_params[key.split('.')[0]] += value.numel()

    # Largest modules first.
    sorted_modules = sorted(module_params.items(), key=lambda x: x[1], reverse=True)

    for name, params in sorted_modules:
        percentage = params / total_params * 100
        size_mb = params * 4 / 1024 / 1024
        print(f"{name:30s}: {params:12,} ({params/1e6:6.2f}M, {percentage:5.2f}%, {size_mb:6.2f}MB)")

    def _report_submodules(prefix, title):
        # Shared breakdown for one top-level module (deduplicates the
        # former copy-pasted encoders/heads sections): group keys by their
        # first two (or three, when present) dotted segments and print
        # every group above 0.1M parameters, largest first.
        print("\n" + "=" * 80)
        print(title)
        print("=" * 80)

        submodules = defaultdict(int)

        for key, value in state_dict.items():
            if not isinstance(value, torch.Tensor):
                continue
            parts = key.split('.')
            if len(parts) >= 2 and parts[0] == prefix:
                depth = 3 if len(parts) >= 3 else 2
                submodules['.'.join(parts[:depth])] += value.numel()

        for name, params in sorted(submodules.items(), key=lambda x: x[1], reverse=True):
            if params > 1e5:  # only show modules > 0.1M parameters
                percentage = params / total_params * 100
                print(f"{name:50s}: {params:12,} ({params/1e6:6.2f}M, {percentage:5.2f}%)")

    _report_submodules('encoders', "Encoders子模块详细分析")
    _report_submodules('heads', "Heads子模块详细分析")

    # ---- optimization suggestions ----
    print("\n" + "=" * 80)
    print("优化建议")
    print("=" * 80)

    # Highlight the single largest module as the primary pruning target.
    if sorted_modules:
        largest = sorted_modules[0]
        print(f"\n1. 最大模块: {largest[0]}")
        print(f" 参数量: {largest[1]/1e6:.2f}M ({largest[1]/total_params*100:.1f}%)")
        print(" 建议: 优先剪枝此模块")
        print(f" 如果剪枝30%: 可减少 {largest[1]*0.3/1e6:.2f}M参数")

    print("\n2. 剪枝潜力估算:")
    print(f" 保守剪枝(20%): {total_params*0.8/1e6:.2f}M参数 → {total_params*0.8*4/1024/1024:.2f}MB (FP32)")
    print(f" 中等剪枝(40%): {total_params*0.6/1e6:.2f}M参数 → {total_params*0.6*4/1024/1024:.2f}MB (FP32)")
    print(f" 激进剪枝(60%): {total_params*0.4/1e6:.2f}M参数 → {total_params*0.4*4/1024/1024:.2f}MB (FP32)")

    print("\n3. 量化收益:")
    print(f" 原始FP32: {total_params*4/1024/1024:.2f} MB")
    print(f" FP16: {total_params*2/1024/1024:.2f} MB (-50%)")
    print(f" INT8: {total_params*1/1024/1024:.2f} MB (-75%)")

    # Recommended pipeline: assume 40% pruning, then INT8 quantization.
    pruned_params = total_params * 0.6
    print("\n4. 推荐优化路线:")
    print(f" 原始模型: {total_params/1e6:.2f}M, {total_params*4/1024/1024:.2f}MB (FP32)")
    print(" ↓ 剪枝40%")
    print(f" 剪枝模型: {pruned_params/1e6:.2f}M, {pruned_params*4/1024/1024:.2f}MB (FP32)")
    print(" ↓ INT8量化")
    print(f" 最终模型: {pruned_params/1e6:.2f}M, {pruned_params*1/1024/1024:.2f}MB (INT8)")
    print(" ")
    print(f" 总压缩比: {(1 - pruned_params*1/1024/1024 / (total_params*4/1024/1024))*100:.1f}%")
    print(" 预期加速: 2.5-3倍")
    print(" 精度损失: <3%")

    print("\n" + "=" * 80)
    print("分析完成!")
    print("=" * 80)

    return total_params, module_params
if __name__ == '__main__':
    # CLI entry point: exactly one positional argument (checkpoint path)
    # is required; otherwise show usage and bail out with a failure code.
    if len(sys.argv) < 2:
        print("用法: python analyze_checkpoint.py <checkpoint_file>")
        print("\n示例:")
        print(" python tools/analysis/analyze_checkpoint.py \\")
        print(" runs/enhanced_from_epoch19/epoch_23.pth")
        sys.exit(1)

    ckpt_path = sys.argv[1]

    # Run the analysis and echo the headline number once more at the end.
    n_params, _per_module = analyze_checkpoint(ckpt_path)
    print(f"\n✅ 总参数量: {n_params/1e6:.2f}M")