bev-project/tools/pruning/prune_bevfusion_builtin.py

#!/usr/bin/env python3
"""
使用PyTorch内置剪枝功能剪枝BEVFusion模型
适用于PyTorch 1.10+
"""
import sys

import torch

sys.path.insert(0, '/workspace/bevfusion')
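

# Illustrative sketch, not used by the pruning below: the torch.nn.utils.prune
# built-in API masks weights on a live nn.Module instead of editing a
# state_dict. The Conv2d layer here is a hypothetical stand-in.
def _builtin_pruning_demo():
    import torch.nn as nn
    import torch.nn.utils.prune as prune

    conv = nn.Conv2d(64, 128, kernel_size=3)
    # Zero out the 30% of output channels (dim=0) with the smallest L1 norm.
    prune.ln_structured(conv, name='weight', amount=0.3, n=1, dim=0)
    # Fold weight_orig * weight_mask back into a plain weight tensor.
    prune.remove(conv, 'weight')
    return conv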


def load_checkpoint(checkpoint_path):
    """Load a checkpoint and return (state_dict, full checkpoint)."""
    print(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
        epoch = checkpoint.get('epoch', 'unknown')
        print(f"  Epoch: {epoch}")
    else:
        state_dict = checkpoint
    return state_dict, checkpoint


def count_parameters(state_dict):
    """Count the total number of tensor elements in a state_dict."""
    return sum(v.numel() for v in state_dict.values() if isinstance(v, torch.Tensor))


def prune_weights_l1(state_dict, module_prefix, prune_ratio):
    """Prune Conv/Linear weights under a module prefix by output-channel L1 norm.

    Args:
        state_dict: model weight dict (modified in place)
        module_prefix: module prefix, e.g. 'encoders.camera.backbone'
        prune_ratio: fraction to prune; 0.3 removes 30% of output channels

    Note: only the pruned layer's weight and bias are resized. Layers that
    consume the pruned channels downstream (including any following norm
    layers) are not adjusted, so the resulting checkpoint requires a model
    definition rebuilt to the new shapes before fine-tuning.
    """
    pruned_keys = []
    for key in list(state_dict.keys()):
        if not (key.startswith(module_prefix) and 'weight' in key):
            continue
        weight = state_dict[key]
        # Only prune Conv2d (out, in, h, w) and Linear (out, in) weights.
        if weight.dim() not in (2, 4):
            continue
        original_params = weight.numel()
        # Per-output-channel L1 norm: sum |w| over all remaining dims.
        l1_norm = weight.abs().flatten(1).sum(dim=1)
        num_keep = int(weight.shape[0] * (1 - prune_ratio))
        # Keep the channels with the largest L1 norm, in their original order.
        _, indices = torch.topk(l1_norm, num_keep)
        indices = indices.sort()[0]
        state_dict[key] = weight[indices]
        # Prune the matching bias, if present.
        bias_key = key.replace('weight', 'bias')
        if bias_key in state_dict:
            state_dict[bias_key] = state_dict[bias_key][indices]
        pruned_keys.append((key, original_params, state_dict[key].numel()))
    return pruned_keys
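

# A minimal sanity check for prune_weights_l1 on a synthetic state_dict;
# the key names below mimic a backbone conv block but are hypothetical.
def _prune_weights_l1_demo():
    toy = {
        'encoders.camera.backbone.conv1.weight': torch.randn(8, 3, 3, 3),
        'encoders.camera.backbone.conv1.bias': torch.randn(8),
    }
    pruned = prune_weights_l1(toy, 'encoders.camera.backbone', prune_ratio=0.25)
    # int(8 * (1 - 0.25)) = 6 output channels survive, in weight and bias alike.
    assert toy['encoders.camera.backbone.conv1.weight'].shape == (6, 3, 3, 3)
    assert toy['encoders.camera.backbone.conv1.bias'].shape == (6,)
    return pruned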


def smart_prune(state_dict, target_params_ratio=0.7):
    """Smart pruning: per-module ratios sized to each module (hard-coded plan).

    Args:
        state_dict: model weights (modified in place)
        target_params_ratio: target fraction of parameters to keep (0.7 keeps
            70%). Used for reporting only; the actual reduction follows the
            per-module plan below.
    """
    print("=" * 80)
    print("Starting smart pruning")
    print("=" * 80)
    original_params = count_parameters(state_dict)
    target_params = int(original_params * target_params_ratio)
    print(f"Original parameters: {original_params/1e6:.2f}M")
    print(f"Target parameters:   {target_params/1e6:.2f}M")
    print(f"To prune:            {(1-target_params_ratio)*100:.1f}%")
    print()
    # Pruning strategy: assign each module its own ratio, by module size.
    pruning_plan = {
        'encoders.camera.backbone': 0.20,    # camera backbone (largest module)
        'heads.map.aspp': 0.25,              # ASPP
        'decoder': 0.15,                     # decoder
        'encoders.camera.vtransform': 0.10,  # view transform
    }
    print("Pruning plan:")
    for module, ratio in pruning_plan.items():
        print(f"  {module:40s}: prune {ratio*100:.0f}%")
    print()
    # Apply the plan module by module.
    for module_prefix, prune_ratio in pruning_plan.items():
        print(f"\nPruning module: {module_prefix} (ratio: {prune_ratio*100:.0f}%)")
        print("-" * 60)
        pruned_keys = prune_weights_l1(state_dict, module_prefix, prune_ratio)
        for key, orig, pruned in pruned_keys:
            reduction = (1 - pruned/orig) * 100
            print(f"  {key:50s}: {orig:10,} → {pruned:10,} (-{reduction:.1f}%)")
    # Report the final parameter count.
    final_params = count_parameters(state_dict)
    actual_reduction = (1 - final_params/original_params) * 100
    print("\n" + "=" * 80)
    print("Pruning results")
    print("=" * 80)
    print(f"Original parameters: {original_params/1e6:.2f}M")
    print(f"Pruned parameters:   {final_params/1e6:.2f}M")
    print(f"Actual reduction:    {actual_reduction:.1f}%")
    print(f"Model size:          {final_params*4/1024/1024:.2f}MB (FP32)")
    print()
    return state_dict, final_params
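

# A minimal end-to-end sketch of smart_prune on a synthetic two-module
# state_dict; the key prefixes match the hard-coded plan, sizes are arbitrary.
def _smart_prune_demo():
    toy = {
        'encoders.camera.backbone.conv.weight': torch.randn(32, 16, 3, 3),
        'decoder.fc.weight': torch.randn(64, 128),
        'decoder.fc.bias': torch.randn(64),
    }
    pruned, final_params = smart_prune(toy, target_params_ratio=0.8)
    # backbone keeps int(32 * 0.80) = 25 channels; decoder keeps int(64 * 0.85) = 54.
    assert pruned['encoders.camera.backbone.conv.weight'].shape[0] == 25
    assert pruned['decoder.fc.weight'].shape[0] == 54
    return final_params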


def main():
    """CLI entry point."""
    import argparse
    parser = argparse.ArgumentParser(description='Prune a BEVFusion model')
    parser.add_argument('--checkpoint', required=True, help='input checkpoint path')
    parser.add_argument('--output', required=True, help='output checkpoint path')
    parser.add_argument('--target-ratio', type=float, default=0.70,
                        help='target fraction of parameters to keep '
                             '(0.7 = keep 70%%)')
    args = parser.parse_args()
    print("=" * 80)
    print("BEVFusion model pruning")
    print("=" * 80)
    print(f"Input:       {args.checkpoint}")
    print(f"Output:      {args.output}")
    print(f"Target keep: {args.target_ratio*100:.0f}%")
    print()
    # Load the checkpoint.
    state_dict, checkpoint = load_checkpoint(args.checkpoint)
    print(f"Original parameters: {count_parameters(state_dict)/1e6:.2f}M")
    print()
    # Prune.
    pruned_state_dict, final_params = smart_prune(state_dict, args.target_ratio)
    # Save the pruned model.
    print("Saving pruned model...")
    if 'state_dict' in checkpoint:
        checkpoint['state_dict'] = pruned_state_dict
        # Drop the optimizer state; it must be re-initialized after pruning.
        if 'optimizer' in checkpoint:
            del checkpoint['optimizer']
    else:
        checkpoint = pruned_state_dict
    torch.save(checkpoint, args.output)
    print("\n✅ Pruning complete!")
    print(f"  Output file:      {args.output}")
    print(f"  Final parameters: {final_params/1e6:.2f}M")
    print(f"  Model size:       {final_params*4/1024/1024:.2f}MB (FP32)")
    print()
    print("=" * 80)
    print("Next step: fine-tuning")
    print("=" * 80)
    print()
    print("Command:")
    print("  torchpack dist-run -np 8 python tools/train.py \\")
    print("    configs/.../multitask_enhanced_phase1_HIGHRES.yaml \\")
    print(f"    --load_from {args.output} \\")
    print("    --cfg-options max_epochs=3 optimizer.lr=5.0e-6")
    print()


if __name__ == '__main__':
    main()
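
# Example invocation (paths below are placeholders):
#   python tools/pruning/prune_bevfusion_builtin.py \
#       --checkpoint work_dirs/bevfusion/latest.pth \
#       --output work_dirs/bevfusion/pruned_0.70.pth \
#       --target-ratio 0.70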