bev-project/tools/analysis/model_complexity_simple.py

170 lines
5.8 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
简化版模型复杂度分析
不依赖thop直接统计参数
"""
import torch
import sys
sys.path.insert(0, '/workspace/bevfusion')
def count_parameters(model):
    """Count the parameter elements of ``model``.

    Iterates ``model.parameters()``, so a tensor shared between submodules is
    counted only once.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        A ``(total, trainable)`` tuple of ints, where ``trainable`` only
        includes parameters with ``requires_grad`` set.
    """
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total = sum(numel for numel, _ in sizes)
    trainable = sum(numel for numel, needs_grad in sizes if needs_grad)
    return total, trainable
def analyze_by_module(model):
    """Map each direct child of ``model`` to its parameter-element count.

    Only ``named_children()`` is inspected, so parameters registered directly
    on ``model`` itself are not attributed to any entry.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        ``dict`` of child name -> number of parameter elements, in child
        registration order.
    """
    return {
        child_name: sum(weight.numel() for weight in child.parameters())
        for child_name, child in model.named_children()
    }
def _param_count(module):
    """Return the number of parameter elements owned by ``module``."""
    return sum(p.numel() for p in module.parameters())


def _size_mb(num_params, bytes_per_param):
    """Return the size in MiB of ``num_params`` values stored at ``bytes_per_param`` bytes each."""
    return num_params * bytes_per_param / 1024 / 1024


def _percent(part, whole):
    """Return ``part`` as a percentage of ``whole``, or 0.0 when ``whole`` is 0."""
    return part / whole * 100 if whole else 0.0


def _print_section(title):
    """Print ``title`` between two 80-character ``=`` rules, preceded by a blank line."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)


def _print_children(module):
    """Print the parameter count of each direct child of ``module`` that has parameters."""
    for name, child in module.named_children():
        params = _param_count(child)
        if params > 0:
            print(f" {name:28s}: {params:12,} ({params/1e6:6.2f}M)")


def _load_model(config_file, checkpoint_file):
    """Build the model described by ``config_file``, optionally load weights, and return it in eval mode."""
    from mmcv import Config
    from mmdet3d.models import build_model

    print("加载配置...")
    # NOTE(review): the usage example passes a .yaml config. If this is the
    # mit-han-lab BEVFusion layout, its configs are normally loaded with
    # torchpack `configs.load(..., recursive=True)` + `recursive_eval`, which
    # merges parent default.yaml files. Confirm Config.fromfile yields the
    # full model config here.
    cfg = Config.fromfile(config_file)

    print("构建模型...")
    model = build_model(cfg.model)

    if checkpoint_file:
        print("加载checkpoint...")
        # SECURITY: torch.load unpickles arbitrary objects; only pass trusted checkpoint files.
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        # Accept both {'state_dict': ...} training checkpoints and bare state dicts.
        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        incompatible = model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates key mismatches silently; report them so a
        # checkpoint that does not match the config is noticed.
        if incompatible.missing_keys:
            print(f"警告: checkpoint缺少 {len(incompatible.missing_keys)} 个模型参数键")
        if incompatible.unexpected_keys:
            print(f"警告: checkpoint包含 {len(incompatible.unexpected_keys)} 个模型中不存在的键")

    model.eval()
    return model


def _print_suggestions(sorted_modules, total_params):
    """Print pruning/quantization estimates.

    Args:
        sorted_modules: list of ``(name, param_count)`` pairs, largest first.
        total_params: total parameter count of the model.
    """
    if sorted_modules:
        # The list is sorted descending (stable), so the first entry is the largest.
        largest_name, largest_params = sorted_modules[0]
        print(f"\n1. 最大模块: {largest_name} ({largest_params/1e6:.2f}M, {_percent(largest_params, total_params):.1f}%)")
        print(" 建议: 优先剪枝此模块预期可减少30-40%参数")
    else:
        print("\n1. 最大模块: 无 (模型没有子模块)")

    # Pruning estimates: each entry is (label, fraction of parameters kept).
    print("\n2. 剪枝潜力估算:")
    for label, keep in (("保守剪枝(20%)", 0.8), ("中等剪枝(40%)", 0.6), ("激进剪枝(60%)", 0.4)):
        kept = total_params * keep
        print(f" {label}: {kept/1e6:.2f}M参数, {_size_mb(kept, 4):.2f}MB")

    fp32_mb = _size_mb(total_params, 4)
    print("\n3. 量化收益:")
    print(f" FP32→FP16: {fp32_mb:.2f}MB → {_size_mb(total_params, 2):.2f}MB (-50%)")
    print(f" FP32→INT8: {fp32_mb:.2f}MB → {_size_mb(total_params, 1):.2f}MB (-75%)")

    print("\n4. 推荐优化路线:")
    target_params = total_params * 0.6  # after 40% pruning
    print(f" Step 1: 剪枝40% → {target_params/1e6:.2f}M参数")
    print(f" Step 2: INT8量化 → {_size_mb(target_params, 1):.2f}MB模型")
    # NOTE: the two figures below are fixed strings, not measured from this model.
    print(" 预期速度提升: 2-3倍")
    print(" 预期精度损失: <3%")


def main(config_file, checkpoint_file=None):
    """Print a parameter-count and model-size report for a BEVFusion-style model.

    Args:
        config_file: path to the model config, read with ``mmcv.Config.fromfile``.
        checkpoint_file: optional checkpoint path. When falsy, no weights are
            loaded; the counts below come only from the model built from the config.
    """
    print("=" * 80)
    print("BEVFusion模型复杂度分析")
    print("=" * 80)
    print(f"配置文件: {config_file}")
    print(f"Checkpoint: {checkpoint_file}")
    print()

    model = _load_model(config_file, checkpoint_file)

    # Overall totals and raw storage size at several precisions.
    _print_section("总体统计")
    total_params, trainable_params = count_parameters(model)
    print(f"总参数量: {total_params:,} ({total_params/1e6:.2f}M)")
    print(f"可训练参数: {trainable_params:,} ({trainable_params/1e6:.2f}M)")
    print(f"模型大小(FP32): {_size_mb(total_params, 4):.2f} MB")
    print(f"模型大小(FP16): {_size_mb(total_params, 2):.2f} MB")
    print(f"模型大小(INT8): {_size_mb(total_params, 1):.2f} MB")

    # Per top-level child, largest first.
    _print_section("各模块参数统计")
    sorted_modules = sorted(analyze_by_module(model).items(), key=lambda x: x[1], reverse=True)
    for name, params in sorted_modules:
        print(f"{name:30s}: {params:12,} ({params/1e6:6.2f}M, {_percent(params, total_params):5.2f}%)")

    _print_section("Encoders详细分析")
    if hasattr(model, 'encoders'):
        for enc_name in ['camera', 'lidar', 'radar']:
            if enc_name in model.encoders:
                encoder = model.encoders[enc_name]
                enc_params = _param_count(encoder)
                print(f"\n{enc_name.upper()} Encoder: {enc_params:,} ({enc_params/1e6:.2f}M)")
                _print_children(encoder)

    _print_section("Heads详细分析")
    if hasattr(model, 'heads'):
        for head_name, head in model.heads.items():
            head_params = _param_count(head)
            print(f"\n{head_name.upper()} Head: {head_params:,} ({head_params/1e6:.2f}M)")
            _print_children(head)

    _print_section("优化建议")
    _print_suggestions(sorted_modules, total_params)

    print("\n" + "=" * 80)
    print("分析完成!")
    print("=" * 80)
if __name__ == '__main__':
    # Only the config is required. main() skips weight loading when the
    # checkpoint is falsy, and the parameter counts come from the model built
    # from the config, so the checkpoint argument is optional.
    if len(sys.argv) < 2:
        print("用法: python model_complexity_simple.py <config> [checkpoint]")
        print("\n示例:")
        print(" python tools/analysis/model_complexity_simple.py \\")
        print(" configs/.../multitask_enhanced_phase1_HIGHRES.yaml \\")
        print(" runs/enhanced_from_epoch19/epoch_23.pth")
        sys.exit(1)
    config_file = sys.argv[1]
    checkpoint_file = sys.argv[2] if len(sys.argv) > 2 else None
    main(config_file, checkpoint_file)