#!/usr/bin/env python3
"""
Prune a BEVFusion checkpoint with L1-norm structured (channel) pruning.
Operates directly on the state_dict tensors; requires PyTorch 1.10+.

Caveat: output channels are removed without adjusting the input channels of
downstream layers or the affine parameters of matching norm layers, so the
pruned checkpoint must be loaded into a model config whose widths are
reduced accordingly.
"""
import argparse
import sys

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

sys.path.insert(0, '/workspace/bevfusion')


def load_checkpoint(checkpoint_path):
    """Load a checkpoint and return (state_dict, full checkpoint dict)."""
    print(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
        epoch = checkpoint.get('epoch', 'unknown')
        print(f"  Epoch: {epoch}")
    else:
        state_dict = checkpoint
    return state_dict, checkpoint


def count_parameters(state_dict):
    """Count the total number of parameters in a state_dict."""
    return sum(v.numel() for v in state_dict.values() if isinstance(v, torch.Tensor))


def prune_weights_l1(state_dict, module_prefix, prune_ratio):
    """
    Apply L1-norm structured pruning to all Conv/Linear weights under a prefix.

    Args:
        state_dict: model weight dict
        module_prefix: module prefix, e.g. 'encoders.camera.backbone'
        prune_ratio: fraction to prune, e.g. 0.3 removes 30% of output channels
    """
    pruned_keys = []
    # Iterate over a snapshot of the keys, since values are replaced in place.
    for key in list(state_dict.keys()):
        if not (key.startswith(module_prefix) and 'weight' in key):
            continue
        weight = state_dict[key]
        # Only prune Conv2d (4-D) and Linear (2-D) weights; skip norm layers etc.
        if len(weight.shape) not in (2, 4):
            continue
        original_params = weight.numel()

        # Compute a per-output-channel L1 norm as the importance score.
        if len(weight.shape) == 4:  # Conv2d: (out, in, h, w)
            l1_norm = weight.abs().sum(dim=[1, 2, 3])
        else:  # Linear: (out, in)
            l1_norm = weight.abs().sum(dim=1)

        num_out = weight.shape[0]
        num_keep = int(num_out * (1 - prune_ratio))
        # Keep the channels with the largest L1 norm, preserving their order.
        _, indices = torch.topk(l1_norm, num_keep)
        indices = indices.sort()[0]
        state_dict[key] = weight[indices]

        # Prune the corresponding bias, if present.
        bias_key = key.replace('weight', 'bias')
        if bias_key in state_dict:
            state_dict[bias_key] = state_dict[bias_key][indices]

        pruned_keys.append((key, original_params, state_dict[key].numel()))
    return pruned_keys


def smart_prune(state_dict, target_params_ratio=0.7):
    """
    Prune the checkpoint according to a fixed per-module plan.

    Args:
        state_dict: model weights
        target_params_ratio: target parameter retention ratio used for
            reporting (0.7 = keep 70%); the actual reduction is determined
            by the hardcoded per-module plan below.
    """
    print("=" * 80)
    print("Starting pruning")
    print("=" * 80)
    original_params = count_parameters(state_dict)
    target_params = int(original_params * target_params_ratio)
    print(f"Original parameters: {original_params/1e6:.2f}M")
    print(f"Target parameters:   {target_params/1e6:.2f}M")
    print(f"To prune:            {(1-target_params_ratio)*100:.1f}%")
    print()

    # Pruning strategy: assign a different ratio per module.
    pruning_plan = {
        'encoders.camera.backbone': 0.20,    # camera backbone (largest module): prune 20%
        'heads.map.aspp': 0.25,              # ASPP: prune 25%
        'decoder': 0.15,                     # decoder: prune 15%
        'encoders.camera.vtransform': 0.10,  # view transform: prune 10%
    }

    print("Pruning plan:")
    for module, ratio in pruning_plan.items():
        print(f"  {module:40s}: prune {ratio*100:.0f}%")
    print()

    # Execute pruning.
    all_pruned_keys = []
    for module_prefix, prune_ratio in pruning_plan.items():
        print(f"\nPruning module: {module_prefix} (ratio: {prune_ratio*100:.0f}%)")
        print("-" * 60)
        pruned_keys = prune_weights_l1(state_dict, module_prefix, prune_ratio)
        for key, orig, pruned in pruned_keys:
            reduction = (1 - pruned/orig) * 100
            print(f"  {key:50s}: {orig:10,} → {pruned:10,} (-{reduction:.1f}%)")
        all_pruned_keys.extend(pruned_keys)

    # Summarize the final result.
    final_params = count_parameters(state_dict)
    actual_reduction = (1 - final_params/original_params) * 100
    print("\n" + "=" * 80)
    print("Pruning result")
    print("=" * 80)
    print(f"Original parameters: {original_params/1e6:.2f}M")
    print(f"Pruned parameters:   {final_params/1e6:.2f}M")
    print(f"Actual reduction:    {actual_reduction:.1f}%")
    print(f"Model size:          {final_params*4/1024/1024:.2f}MB (FP32)")
    print()
    return state_dict, final_params


def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description='Prune a BEVFusion model')
    parser.add_argument('--checkpoint', required=True, help='input checkpoint path')
    parser.add_argument('--output', required=True, help='output checkpoint path')
    # Note: '%' must be escaped as '%%' in argparse help strings.
    parser.add_argument('--target-ratio', type=float, default=0.70,
                        help='target parameter retention ratio (0.7 keeps 70%%)')
    args = parser.parse_args()

    print("=" * 80)
    print("BEVFusion model pruning")
    print("=" * 80)
    print(f"Input:  {args.checkpoint}")
    print(f"Output: {args.output}")
    print(f"Target retention: {args.target_ratio*100:.0f}%")
    print()

    # Load the checkpoint.
    state_dict, checkpoint = load_checkpoint(args.checkpoint)
    print(f"Original parameters: {count_parameters(state_dict)/1e6:.2f}M")
    print()

    # Run pruning.
    pruned_state_dict, final_params = smart_prune(state_dict, args.target_ratio)

    # Save the pruned model.
    print("Saving pruned model...")
    if 'state_dict' in checkpoint:
        checkpoint['state_dict'] = pruned_state_dict
        # Drop the optimizer state: it no longer matches the pruned tensor
        # shapes and must be re-initialized for fine-tuning.
        if 'optimizer' in checkpoint:
            del checkpoint['optimizer']
    else:
        checkpoint = pruned_state_dict
    torch.save(checkpoint, args.output)

    print("\n✅ Pruning complete!")
    print(f"  Output file:      {args.output}")
    print(f"  Final parameters: {final_params/1e6:.2f}M")
    print(f"  Model size:       {final_params*4/1024/1024:.2f}MB (FP32)")
    print()
    print("=" * 80)
    print("Next step: fine-tuning")
    print("=" * 80)
    print()
    print("Command:")
    print("  torchpack dist-run -np 8 python tools/train.py \\")
    print("    configs/.../multitask_enhanced_phase1_HIGHRES.yaml \\")
    print(f"    --load_from {args.output} \\")
    print("    --cfg-options max_epochs=3 optimizer.lr=5.0e-6")
    print()


if __name__ == '__main__':
    main()
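

# ---------------------------------------------------------------------------
# Appendix (editor's sketch): the original header mentions PyTorch's built-in
# pruning utilities, which the script above does not actually use. For
# reference, here is a minimal mask-based alternative built on
# torch.nn.utils.prune. Unlike the structural pruning above, it preserves
# tensor shapes (pruned channels are zeroed via a mask), so the result stays
# compatible with the original model config. `model` is assumed to be an
# already-built nn.Module; this helper is illustrative and is not called
# anywhere in this script.
# ---------------------------------------------------------------------------
def builtin_l1_structured_prune(model: nn.Module, amount: float = 0.3) -> nn.Module:
    """Zero out the `amount` fraction of output channels with the smallest
    L1 norm in every Conv2d/Linear module, using torch.nn.utils.prune."""
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            # n=1 selects the L1 norm; dim=0 prunes whole output channels/rows.
            prune.ln_structured(module, name='weight', amount=amount, n=1, dim=0)
            # Fold the mask into the weight and remove the reparametrization,
            # so the state_dict contains plain 'weight' tensors again.
            prune.remove(module, 'weight')
    return model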
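

# ---------------------------------------------------------------------------
# Appendix (editor's sketch): a quick sanity check that groups a checkpoint's
# parameter counts by top-level module name, useful for verifying the prefixes
# in `pruning_plan` and for comparing a checkpoint before and after pruning.
# Purely illustrative; not called anywhere in this script.
# ---------------------------------------------------------------------------
def report_module_sizes(checkpoint_path):
    """Print parameter counts grouped by top-level module name."""
    state_dict, _ = load_checkpoint(checkpoint_path)
    sizes = {}
    for key, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            top = key.split('.')[0]
            sizes[top] = sizes.get(top, 0) + value.numel()
    for module, numel in sorted(sizes.items(), key=lambda kv: -kv[1]):
        print(f"  {module:30s}: {numel/1e6:.2f}M")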