bev-project/tools/analyze_model_params.py

174 lines
5.0 KiB
Python

#!/usr/bin/env python3
"""分析BEVFusion模型参数量和结构"""
import torch
import sys
import os
def count_parameters(model):
"""统计模型参数"""
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
return total_params, trainable_params
def analyze_module_params(model, prefix=''):
"""分析各模块参数量"""
results = []
for name, module in model.named_children():
full_name = f"{prefix}.{name}" if prefix else name
# 统计当前模块参数
params = sum(p.numel() for p in module.parameters())
trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
if params > 0:
results.append({
'name': full_name,
'total': params,
'trainable': trainable,
'module_type': type(module).__name__
})
# 递归子模块(只递归一层)
if len(list(module.children())) > 0:
sub_results = analyze_module_params(module, full_name)
results.extend(sub_results)
return results
def format_params(num):
"""格式化参数数量"""
if num >= 1e9:
return f"{num/1e9:.2f}B"
elif num >= 1e6:
return f"{num/1e6:.2f}M"
elif num >= 1e3:
return f"{num/1e3:.2f}K"
else:
return str(num)
if __name__ == "__main__":
print("="*90)
print(f"{'🔍 BEVFusion 模型参数分析':^90}")
print("="*90)
print()
# 加载checkpoint
checkpoint_path = "/workspace/bevfusion/runs/enhanced_from_epoch19/epoch_7.pth"
if not os.path.exists(checkpoint_path):
print(f"❌ Checkpoint不存在: {checkpoint_path}")
sys.exit(1)
print(f"📂 加载checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
print(f"✅ Checkpoint加载成功")
print()
# 统计参数
print(""*90)
print("📊 总体参数统计")
print(""*90)
total_params = sum(p.numel() for p in state_dict.values())
print(f" 总参数量: {format_params(total_params)} ({total_params:,})")
# 按模块分组统计
module_stats = {}
for key, param in state_dict.items():
# 提取模块名(第一级)
parts = key.split('.')
if len(parts) > 0:
if parts[0] == 'model':
module = parts[1] if len(parts) > 1 else 'model'
else:
module = parts[0]
if module not in module_stats:
module_stats[module] = 0
module_stats[module] += param.numel()
print()
print(""*90)
print("📊 各模块参数分布")
print(""*90)
print()
print(f"{'模块名':<30} {'参数量':<15} {'占比':<10} {'可视化':<30}")
print("-"*90)
# 排序显示
sorted_modules = sorted(module_stats.items(), key=lambda x: x[1], reverse=True)
for module, params in sorted_modules:
percentage = params / total_params * 100
bar_length = int(percentage / 2)
bar = "" * bar_length
print(f"{module:<30} {format_params(params):<15} {percentage:>6.2f}% {bar}")
print()
print(""*90)
print("🎯 剪枝潜力分析")
print(""*90)
# 分析各模块的剪枝潜力
pruning_potential = []
for module, params in sorted_modules[:10]:
if 'encoder' in module.lower():
potential = "🟢 高 (30-50%)"
elif 'decoder' in module.lower():
potential = "🟡 中 (20-30%)"
elif 'head' in module.lower():
potential = "🟡 中 (15-25%)"
elif 'neck' in module.lower() or 'fpn' in module.lower():
potential = "🟢 高 (25-40%)"
else:
potential = "🔵 低 (10-20%)"
pruning_potential.append((module, format_params(params), percentage, potential))
print()
print(f"{'模块':<30} {'参数量':<15} {'占比':<10} {'剪枝潜力':<20}")
print("-"*90)
for module, params, pct, potential in pruning_potential:
print(f"{module:<30} {params:<15} {pct:>6.2f}% {potential}")
print()
print(""*90)
print("💡 剪枝建议")
print(""*90)
print(f"""
总参数量: {format_params(total_params)}
剪枝目标:
保守目标: 110M → 70M (-36%, 保持>99%精度)
激进目标: 110M → 50M (-55%, 保持>97%精度)
优先剪枝模块:
1. Camera Encoder (如果参数量大)
2. Neck/FPN层
3. Decoder
4. 保留Head (影响精度)
预期效果:
- 推理速度: 1.5-2x加速
- 显存占用: -30-40%
- 模型大小: 515MB → 200-280MB
""")
print("="*90)
print(f"✅ 分析完成! 建议将此报告保存用于剪枝策略制定")
print("="*90)