#!/usr/bin/env python3
"""
Simplified model-complexity analysis.

Counts parameters directly instead of depending on ``thop``. It then prints:

- overall and per-module parameter totals;
- per-encoder and per-head breakdowns;
- storage-size estimates and rough pruning/quantization projections.

Usage:
    python tools/analysis/model_complexity_simple.py <config_file> [checkpoint_file]
"""
import sys

import torch

# Make the project at /workspace/bevfusion importable; `main` imports its
# `mmdet3d` package lazily.
sys.path.insert(0, '/workspace/bevfusion')

# Bytes per "MB" as printed by this script (binary MiB: 1024 * 1024).
_MB = 1024 * 1024


def count_parameters(model):
    """Return ``(total_params, trainable_params)`` for *model*.

    Only tensors yielded by ``model.parameters()`` are counted; registered
    buffers are not included. Trainable means ``requires_grad`` is True.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        n = param.numel()
        total_params += n
        if param.requires_grad:
            trainable_params += n
    return total_params, trainable_params


def analyze_by_module(model):
    """Return ``{child_name: parameter_count}`` for each direct child of *model*.

    NOTE(review): a parameter shared between two children is counted once
    per child here, so the per-child shares could exceed 100% in that case.
    """
    module_stats = {}
    for name, module in model.named_children():
        module_stats[name] = sum(p.numel() for p in module.parameters())
    return module_stats


def _percent(part, total):
    """Return ``part / total * 100``, or 0.0 when *total* is 0 (empty model)."""
    return part / total * 100 if total else 0.0


def _print_section(title):
    """Print *title* framed by 80-character '=' rules, preceded by a blank line."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)


def _print_children_breakdown(module):
    """Print every direct child of *module* that owns at least one parameter."""
    for name, child in module.named_children():
        params = sum(p.numel() for p in child.parameters())
        if params > 0:
            print(f"  {name:28s}: {params:12,} ({params/1e6:6.2f}M)")


def _load_model(config_file, checkpoint_file):
    """Build the model from *config_file* and optionally load *checkpoint_file*.

    Returns the model in eval mode. A checkpoint holding a top-level
    ``'state_dict'`` entry is unwrapped; otherwise the loaded object is used
    as the state dict directly.
    """
    # Imported lazily so the helpers above stay importable without mmcv/mmdet3d.
    from mmcv import Config
    from mmdet3d.models import build_model

    print("加载配置...")
    cfg = Config.fromfile(config_file)

    print("构建模型...")
    model = build_model(cfg.model)

    if checkpoint_file:
        print("加载checkpoint...")
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        result = model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates mismatches silently; surface them so a wrong
        # config/checkpoint pairing does not go unnoticed.
        missing = getattr(result, 'missing_keys', None) or []
        unexpected = getattr(result, 'unexpected_keys', None) or []
        if missing or unexpected:
            print(f"警告: checkpoint与模型不完全匹配 "
                  f"(missing_keys={len(missing)}, unexpected_keys={len(unexpected)})")

    model.eval()
    return model


def _print_overview(total_params, trainable_params):
    """Print overall parameter counts and storage estimates per precision."""
    _print_section("总体统计")
    print(f"总参数量: {total_params:,} ({total_params/1e6:.2f}M)")
    print(f"可训练参数: {trainable_params:,} ({trainable_params/1e6:.2f}M)")
    for label, bytes_per_param in (("FP32", 4), ("FP16", 2), ("INT8", 1)):
        print(f"模型大小({label}): {total_params * bytes_per_param / _MB:.2f} MB")


def _print_module_table(model, total_params):
    """Print direct children sorted by size and return the sorted (name, count) list."""
    _print_section("各模块参数统计")
    sorted_modules = sorted(analyze_by_module(model).items(),
                            key=lambda item: item[1], reverse=True)
    for name, params in sorted_modules:
        percentage = _percent(params, total_params)
        print(f"{name:30s}: {params:12,} ({params/1e6:6.2f}M, {percentage:5.2f}%)")
    return sorted_modules


def _print_encoder_details(model):
    """Print totals and child breakdowns for the camera/lidar/radar encoders, if present."""
    _print_section("Encoders详细分析")
    if not hasattr(model, 'encoders'):
        return
    for enc_name in ('camera', 'lidar', 'radar'):
        if enc_name not in model.encoders:
            continue
        encoder = model.encoders[enc_name]
        enc_params = sum(p.numel() for p in encoder.parameters())
        print(f"\n{enc_name.upper()} Encoder: {enc_params:,} ({enc_params/1e6:.2f}M)")
        _print_children_breakdown(encoder)


def _print_head_details(model):
    """Print totals and child breakdowns for every entry in ``model.heads``, if present."""
    _print_section("Heads详细分析")
    if not hasattr(model, 'heads'):
        return
    for head_name, head in model.heads.items():
        head_params = sum(p.numel() for p in head.parameters())
        print(f"\n{head_name.upper()} Head: {head_params:,} ({head_params/1e6:.2f}M)")
        _print_children_breakdown(head)


def _print_suggestions(total_params, sorted_modules):
    """Print optimization suggestions.

    The reduction, speed-up and accuracy figures printed here are fixed
    heuristics, not measurements of this model.
    """
    _print_section("优化建议")

    # sorted_modules is in descending order, so the first entry is the largest.
    if sorted_modules:
        name, params = sorted_modules[0]
        print(f"\n1. 最大模块: {name} ({params/1e6:.2f}M, {_percent(params, total_params):.1f}%)")
        print("   建议: 优先剪枝此模块,预期可减少30-40%参数")
    else:
        print("\n1. 最大模块: 无 (模型没有子模块)")

    print("\n2. 剪枝潜力估算:")
    for label, keep_ratio in (("保守剪枝(20%)", 0.8), ("中等剪枝(40%)", 0.6), ("激进剪枝(60%)", 0.4)):
        kept = total_params * keep_ratio
        print(f"   {label}: {kept/1e6:.2f}M参数, {kept*4/_MB:.2f}MB")

    fp32_mb = total_params * 4 / _MB
    print("\n3. 量化收益:")
    print(f"   FP32→FP16: {fp32_mb:.2f}MB → {total_params*2/_MB:.2f}MB (-50%)")
    print(f"   FP32→INT8: {fp32_mb:.2f}MB → {total_params*1/_MB:.2f}MB (-75%)")

    print("\n4. 推荐优化路线:")
    target_params = total_params * 0.6  # parameters remaining after 40% pruning
    print(f"   Step 1: 剪枝40% → {target_params/1e6:.2f}M参数")
    print(f"   Step 2: INT8量化 → {target_params*1/_MB:.2f}MB模型")
    print("   预期速度提升: 2-3倍")
    print("   预期精度损失: <3%")


def main(config_file, checkpoint_file=None):
    """Build a model from *config_file* and print its parameter-complexity report.

    Args:
        config_file: Path to the mmcv config whose ``cfg.model`` is built.
        checkpoint_file: Optional checkpoint path; a falsy value skips weight
            loading. Parameter counts come from the built architecture, so
            loading weights does not change them.
    """
    print("=" * 80)
    print("BEVFusion模型复杂度分析")
    print("=" * 80)
    print(f"配置文件: {config_file}")
    print(f"Checkpoint: {checkpoint_file}")
    print()

    model = _load_model(config_file, checkpoint_file)

    total_params, trainable_params = count_parameters(model)
    _print_overview(total_params, trainable_params)
    sorted_modules = _print_module_table(model, total_params)
    _print_encoder_details(model)
    _print_head_details(model)
    _print_suggestions(total_params, sorted_modules)

    _print_section("分析完成!")


if __name__ == '__main__':
    if len(sys.argv) < 2:
        print("用法: python model_complexity_simple.py <config_file> [checkpoint_file]")
        print("\n示例:")
        print(" python tools/analysis/model_complexity_simple.py \\")
        print(" configs/.../multitask_enhanced_phase1_HIGHRES.yaml \\")
        print(" runs/enhanced_from_epoch19/epoch_23.pth")
        sys.exit(1)

    config_file = sys.argv[1]
    checkpoint_file = sys.argv[2] if len(sys.argv) > 2 else None
    main(config_file, checkpoint_file)