#!/usr/bin/env python3
"""Analyze BEVFusion model parameter counts and structure."""

import os
import sys

import torch

def count_parameters(model):
    """Count the parameters of a model.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        Tuple ``(total_params, trainable_params)`` — total number of
        parameter elements, and the subset with ``requires_grad=True``.
    """
    total = 0
    trainable = 0
    # Single pass over the parameters instead of two generator sums.
    for param in model.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return total, trainable
def analyze_module_params(model, prefix=''):
    """Recursively collect per-module parameter statistics.

    Walks the module tree depth-first (full depth, not just one level)
    and records an entry for every submodule that owns at least one
    parameter.

    Args:
        model: Root ``torch.nn.Module`` to inspect.
        prefix: Dotted path of *model* inside a larger tree ('' at the root).

    Returns:
        List of dicts with keys ``name`` (dotted path), ``total``,
        ``trainable`` (parameter counts) and ``module_type`` (class name).
    """
    stats = []
    for child_name, child in model.named_children():
        qualified = f"{prefix}.{child_name}" if prefix else child_name

        # Parameter counts for this submodule (includes its descendants).
        total = sum(p.numel() for p in child.parameters())
        trainable = sum(p.numel() for p in child.parameters() if p.requires_grad)

        if total > 0:
            stats.append({
                'name': qualified,
                'total': total,
                'trainable': trainable,
                'module_type': type(child).__name__,
            })

        # Descend into containers; leaf modules have no children.
        if any(True for _ in child.children()):
            stats.extend(analyze_module_params(child, qualified))

    return stats
def format_params(num):
    """Format a parameter count with a B/M/K suffix (e.g. ``1.23M``).

    Values below 1000 are returned as a plain decimal string.
    """
    # Largest unit first so the first matching threshold wins,
    # mirroring an if/elif cascade.
    for threshold, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
        if num >= threshold:
            return f"{num / threshold:.2f}{suffix}"
    return str(num)
if __name__ == "__main__":
    print("="*90)
    print(f"{'🔍 BEVFusion 模型参数分析':^90}")
    print("="*90)
    print()

    # Checkpoint path: the first CLI argument overrides the default
    # location (backward compatible — no argument keeps the old path).
    default_ckpt = "/workspace/bevfusion/runs/enhanced_from_epoch19/epoch_7.pth"
    checkpoint_path = sys.argv[1] if len(sys.argv) > 1 else default_ckpt

    if not os.path.exists(checkpoint_path):
        print(f"❌ Checkpoint不存在: {checkpoint_path}")
        sys.exit(1)

    print(f"📂 加载checkpoint: {checkpoint_path}")
    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Training checkpoints wrap the weights under 'state_dict';
    # a bare weight file already *is* the state dict.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    print("✅ Checkpoint加载成功")
    print()

    # ---- overall parameter count ----
    print("━"*90)
    print("📊 总体参数统计")
    print("━"*90)

    total_params = sum(p.numel() for p in state_dict.values())
    print(f"  总参数量: {format_params(total_params)} ({total_params:,})")

    # ---- group parameter counts by top-level module name ----
    module_stats = {}
    for key, param in state_dict.items():
        parts = key.split('.')
        # Strip a leading 'model.' wrapper if the keys carry one.
        if parts[0] == 'model':
            module = parts[1] if len(parts) > 1 else 'model'
        else:
            module = parts[0]

        if module not in module_stats:
            module_stats[module] = 0
        module_stats[module] += param.numel()

    print()
    print("━"*90)
    print("📊 各模块参数分布")
    print("━"*90)
    print()
    print(f"{'模块名':<30} {'参数量':<15} {'占比':<10} {'可视化':<30}")
    print("-"*90)

    # Largest modules first.
    sorted_modules = sorted(module_stats.items(), key=lambda x: x[1], reverse=True)

    for module, params in sorted_modules:
        percentage = params / total_params * 100
        bar_length = int(percentage / 2)
        bar = "█" * bar_length
        print(f"{module:<30} {format_params(params):<15} {percentage:>6.2f}%  {bar}")

    print()
    print("━"*90)
    print("🎯 剪枝潜力分析")
    print("━"*90)

    # ---- estimate pruning headroom for the ten largest modules ----
    pruning_potential = []

    for module, params in sorted_modules[:10]:
        if 'encoder' in module.lower():
            potential = "🟢 高 (30-50%)"
        elif 'decoder' in module.lower():
            potential = "🟡 中 (20-30%)"
        elif 'head' in module.lower():
            potential = "🟡 中 (15-25%)"
        elif 'neck' in module.lower() or 'fpn' in module.lower():
            potential = "🟢 高 (25-40%)"
        else:
            potential = "🔵 低 (10-20%)"

        # BUG FIX: recompute the share for *this* module. The original
        # reused the stale `percentage` left over from the display loop
        # above, so every row showed the last module's percentage.
        percentage = params / total_params * 100
        pruning_potential.append((module, format_params(params), percentage, potential))

    print()
    print(f"{'模块':<30} {'参数量':<15} {'占比':<10} {'剪枝潜力':<20}")
    print("-"*90)

    for module, params, pct, potential in pruning_potential:
        print(f"{module:<30} {params:<15} {pct:>6.2f}%  {potential}")

    print()
    print("━"*90)
    print("💡 剪枝建议")
    print("━"*90)

    print(f"""
总参数量: {format_params(total_params)}

剪枝目标:
  保守目标: 110M → 70M (-36%, 保持>99%精度)
  激进目标: 110M → 50M (-55%, 保持>97%精度)

优先剪枝模块:
  1. Camera Encoder (如果参数量大)
  2. Neck/FPN层
  3. Decoder
  4. 保留Head (影响精度)

预期效果:
  - 推理速度: 1.5-2x加速
  - 显存占用: -30-40%
  - 模型大小: 515MB → 200-280MB
""")

    print("="*90)
    print("✅ 分析完成! 建议将此报告保存用于剪枝策略制定")
    print("="*90)