170 lines
5.8 KiB
Python
170 lines
5.8 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
简化版模型复杂度分析
|
|||
|
|
不依赖thop,直接统计参数
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import torch
|
|||
|
|
import sys
|
|||
|
|
# Put the BEVFusion checkout at the front of the import path. main() imports
# `mmdet3d` after this runs, presumably expecting the copy inside this checkout
# (verify). NOTE(review): the location is hard-coded to the original workspace.
_BEVFUSION_ROOT = '/workspace/bevfusion'
sys.path.insert(0, _BEVFUSION_ROOT)


def count_parameters(model):
    """Return the parameter counts of ``model``.

    Args:
        model: any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        ``(total, trainable)``: ``total`` is the number of elements across all
        of ``model.parameters()``; ``trainable`` counts only the elements of
        parameters whose ``requires_grad`` flag is set.
    """
    sizes_and_flags = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total = sum(size for size, _ in sizes_and_flags)
    trainable = sum(size for size, needs_grad in sizes_and_flags if needs_grad)
    return total, trainable


def analyze_by_module(model):
    """Map each direct child of ``model`` to its parameter count.

    Returns:
        dict of child name -> number of parameter elements in that child
        (including its descendants), in ``named_children()`` order.
        Parameters registered directly on ``model`` itself are not attributed
        to any entry.
    """
    return {
        child_name: sum(param.numel() for param in child.parameters())
        for child_name, child in model.named_children()
    }


def _print_section(title):
    """Print ``title`` framed by 80-column ``=`` rules, preceded by a blank line."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)


def _percent(part, whole):
    """Return ``part`` as a percentage of ``whole``; 0.0 when ``whole`` is 0."""
    return part / whole * 100 if whole else 0.0


def _num_params(module):
    """Return the number of parameter elements in ``module`` (including descendants)."""
    return sum(p.numel() for p in module.parameters())


def _print_children(module):
    """Print the parameter count of every direct child of ``module`` that has parameters."""
    for name, child in module.named_children():
        params = _num_params(child)
        if params > 0:
            print(f" {name:28s}: {params:12,} ({params/1e6:6.2f}M)")


def _load_checkpoint(model, checkpoint_file):
    """Load ``checkpoint_file`` into ``model`` non-strictly and report mismatched keys."""
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only pass checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    # Accept both a {'state_dict': ...} wrapper and a bare state dict.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    # strict=False skips non-matching names instead of raising; report how many
    # were skipped so a config/checkpoint mismatch does not go unnoticed.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible is not None:
        print(f"  missing_keys: {len(incompatible.missing_keys)}, "
              f"unexpected_keys: {len(incompatible.unexpected_keys)}")


def _print_overall(total_params, trainable_params):
    """Print total/trainable parameter counts and weight sizes at 4, 2 and 1 byte(s) per element."""
    _print_section("总体统计")
    print(f"总参数量: {total_params:,} ({total_params/1e6:.2f}M)")
    print(f"可训练参数: {trainable_params:,} ({trainable_params/1e6:.2f}M)")
    # Sizes are parameter count x bytes per element only (buffers are not counted).
    print(f"模型大小(FP32): {total_params * 4 / 1024 / 1024:.2f} MB")
    print(f"模型大小(FP16): {total_params * 2 / 1024 / 1024:.2f} MB")
    print(f"模型大小(INT8): {total_params * 1 / 1024 / 1024:.2f} MB")


def _print_module_table(model, total_params):
    """Print each top-level child's parameter count, largest first.

    Returns:
        list of ``(name, param_count)`` sorted by count, descending.
    """
    _print_section("各模块参数统计")
    module_stats = analyze_by_module(model)
    sorted_modules = sorted(module_stats.items(), key=lambda x: x[1], reverse=True)
    for name, params in sorted_modules:
        percentage = _percent(params, total_params)
        print(f"{name:30s}: {params:12,} ({params/1e6:6.2f}M, {percentage:5.2f}%)")
    return sorted_modules


def _print_encoder_details(model):
    """Print per-child counts of the camera/lidar/radar encoders that ``model.encoders`` contains."""
    _print_section("Encoders详细分析")
    if hasattr(model, 'encoders'):
        for enc_name in ('camera', 'lidar', 'radar'):
            if enc_name in model.encoders:
                encoder = model.encoders[enc_name]
                enc_params = _num_params(encoder)
                print(f"\n{enc_name.upper()} Encoder: {enc_params:,} ({enc_params/1e6:.2f}M)")
                _print_children(encoder)


def _print_head_details(model):
    """Print per-child counts of every entry in ``model.heads``, if present."""
    _print_section("Heads详细分析")
    if hasattr(model, 'heads'):
        for head_name, head in model.heads.items():
            head_params = _num_params(head)
            print(f"\n{head_name.upper()} Head: {head_params:,} ({head_params/1e6:.2f}M)")
            _print_children(head)


def _print_suggestions(sorted_modules, total_params):
    """Print fixed-ratio pruning / quantization estimates derived from ``total_params``."""
    _print_section("优化建议")

    # sorted_modules is in descending order, so its first entry is the largest.
    # Skipped when the model has no children (max() on it used to raise).
    if sorted_modules:
        name, params = sorted_modules[0]
        print(f"\n1. 最大模块: {name} ({params/1e6:.2f}M, {_percent(params, total_params):.1f}%)")
        print(" 建议: 优先剪枝此模块,预期可减少30-40%参数")

    # Remaining-parameter estimates for fixed pruning ratios (FP32 sizes).
    print("\n2. 剪枝潜力估算:")
    for label, keep in (("保守剪枝(20%)", 0.8), ("中等剪枝(40%)", 0.6), ("激进剪枝(60%)", 0.4)):
        print(f" {label}: {total_params*keep/1e6:.2f}M参数, {total_params*keep*4/1024/1024:.2f}MB")

    print("\n3. 量化收益:")
    print(f" FP32→FP16: {total_params*4/1024/1024:.2f}MB → {total_params*2/1024/1024:.2f}MB (-50%)")
    print(f" FP32→INT8: {total_params*4/1024/1024:.2f}MB → {total_params*1/1024/1024:.2f}MB (-75%)")

    print("\n4. 推荐优化路线:")
    target_params = total_params * 0.6  # 40% pruning keeps 60% of the parameters
    print(f" Step 1: 剪枝40% → {target_params/1e6:.2f}M参数")
    print(f" Step 2: INT8量化 → {target_params*1/1024/1024:.2f}MB模型")
    # The speed-up / accuracy figures are fixed text, not derived from the model.
    print(" 预期速度提升: 2-3倍")
    print(" 预期精度损失: <3%")


def main(config_file, checkpoint_file):
    """Print a parameter-count report for the model described by ``config_file``.

    Builds the model from an mmcv config and optionally loads ``checkpoint_file``
    (non-strictly). It then prints:

    * overall counts and size estimates;
    * a per-top-level-module table;
    * per-child breakdowns of the camera/lidar/radar encoders and the heads,
      when present;
    * fixed-ratio pruning/quantization estimates.

    Args:
        config_file: path passed to ``mmcv.Config.fromfile``.
        checkpoint_file: checkpoint path; a falsy value skips weight loading.
    """
    from mmcv import Config
    from mmdet3d.models import build_model

    print("=" * 80)
    print("BEVFusion模型复杂度分析")
    print("=" * 80)
    print(f"配置文件: {config_file}")
    print(f"Checkpoint: {checkpoint_file}")
    print()

    print("加载配置...")
    cfg = Config.fromfile(config_file)

    print("构建模型...")
    model = build_model(cfg.model)

    if checkpoint_file:
        print("加载checkpoint...")
        _load_checkpoint(model, checkpoint_file)

    model.eval()

    total_params, trainable_params = count_parameters(model)
    _print_overall(total_params, trainable_params)
    sorted_modules = _print_module_table(model, total_params)
    _print_encoder_details(model)
    _print_head_details(model)
    _print_suggestions(sorted_modules, total_params)

    _print_section("分析完成!")


if __name__ == '__main__':
    # The checkpoint argument is optional: main() skips weight loading when it
    # is falsy, so only the config is required.
    if len(sys.argv) < 2:
        print("用法: python model_complexity_simple.py <config> [checkpoint]", file=sys.stderr)
        print("\n示例:", file=sys.stderr)
        print(" python tools/analysis/model_complexity_simple.py \\", file=sys.stderr)
        print(" configs/.../multitask_enhanced_phase1_HIGHRES.yaml \\", file=sys.stderr)
        print(" runs/enhanced_from_epoch19/epoch_23.pth", file=sys.stderr)
        sys.exit(1)

    config_file = sys.argv[1]
    checkpoint_file = sys.argv[2] if len(sys.argv) > 2 else None

    main(config_file, checkpoint_file)