bev-project/tools/pruning/prune_bevfusion_builtin.py

#!/usr/bin/env python3
"""
Prune a BEVFusion checkpoint with L1-norm structured channel pruning.
Works with PyTorch 1.10+.

Note: despite the filename, this script edits the state_dict directly
rather than going through torch.nn.utils.prune.
"""

import sys

import torch

sys.path.insert(0, '/workspace/bevfusion')


def load_checkpoint(checkpoint_path):
    """Load a checkpoint and return (state_dict, full checkpoint dict)."""
    print(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
        epoch = checkpoint.get('epoch', 'unknown')
        print(f"  Epoch: {epoch}")
    else:
        state_dict = checkpoint
    return state_dict, checkpoint


def count_parameters(state_dict):
    """Count the total number of parameters (tensor elements) in a state_dict."""
    total = sum(v.numel() for v in state_dict.values() if isinstance(v, torch.Tensor))
    return total


def prune_weights_l1(state_dict, module_prefix, prune_ratio):
    """
    Apply L1-norm structured pruning to every Conv/Linear weight under a prefix.

    Args:
        state_dict: model weight dict (modified in place)
        module_prefix: module prefix, e.g. 'encoders.camera.backbone'
        prune_ratio: fraction to remove, e.g. 0.3 removes 30% of output channels

    Caution: this shrinks each layer's output width in isolation; the input
    channels of any downstream layer consuming those outputs must be adapted
    before the pruned checkpoint can be loaded into a model.
    """
    pruned_keys = []
    for key in state_dict.keys():
        if key.startswith(module_prefix) and 'weight' in key:
            # Only prune Conv and Linear layers (BN weights are 1-D and skipped).
            if len(state_dict[key].shape) in [2, 4]:  # Linear or Conv
                original_params = state_dict[key].numel()
                weight = state_dict[key]
                if len(weight.shape) == 4:  # Conv2d: (out, in, h, w)
                    # Prune along the output-channel dimension:
                    # compute the L1 norm of each output channel...
                    l1_norm = weight.abs().sum(dim=[1, 2, 3])
                    num_channels = weight.shape[0]
                    num_keep = int(num_channels * (1 - prune_ratio))
                    # ...and keep the channels with the largest norms.
                    _, indices = torch.topk(l1_norm, num_keep)
                    indices = indices.sort()[0]
                    state_dict[key] = weight[indices]
                    # Prune the matching bias as well.
                    bias_key = key.replace('weight', 'bias')
                    if bias_key in state_dict:
                        state_dict[bias_key] = state_dict[bias_key][indices]
                    pruned_params = state_dict[key].numel()
                    pruned_keys.append((key, original_params, pruned_params))
                elif len(weight.shape) == 2:  # Linear: (out, in)
                    # Same scheme along the output-feature dimension.
                    l1_norm = weight.abs().sum(dim=1)
                    num_features = weight.shape[0]
                    num_keep = int(num_features * (1 - prune_ratio))
                    _, indices = torch.topk(l1_norm, num_keep)
                    indices = indices.sort()[0]
                    state_dict[key] = weight[indices]
                    bias_key = key.replace('weight', 'bias')
                    if bias_key in state_dict:
                        state_dict[bias_key] = state_dict[bias_key][indices]
                    pruned_params = state_dict[key].numel()
                    pruned_keys.append((key, original_params, pruned_params))
    return pruned_keys
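

# Shape-level sanity check for the channel selection above (no BEVFusion
# checkpoint required; 'm' is a hypothetical module prefix): a Conv2d weight
# of shape (4, 8, 3, 3) pruned with ratio 0.5 keeps the 2 output channels
# with the largest L1 norms, so 288 parameters shrink to 144.
#
#   >>> sd = {'m.weight': torch.randn(4, 8, 3, 3), 'm.bias': torch.randn(4)}
#   >>> prune_weights_l1(sd, 'm', 0.5)
#   [('m.weight', 288, 144)]
#   >>> sd['m.weight'].shape, sd['m.bias'].shape
#   (torch.Size([2, 8, 3, 3]), torch.Size([2]))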


def smart_prune(state_dict, target_params_ratio=0.7):
    """
    Prune the model with a hand-tuned per-module plan.

    Args:
        state_dict: model weights (modified in place)
        target_params_ratio: fraction of parameters to keep, e.g. 0.7 keeps 70%.
            NOTE: only used for reporting; the actual reduction comes from the
            fixed per-module plan below.
    """
    print("=" * 80)
    print("Starting smart pruning")
    print("=" * 80)
    original_params = count_parameters(state_dict)
    target_params = int(original_params * target_params_ratio)
    print(f"Original parameters: {original_params/1e6:.2f}M")
    print(f"Target parameters: {target_params/1e6:.2f}M")
    print(f"To prune: {(1-target_params_ratio)*100:.1f}%")
    print()
    # Pruning strategy: larger modules get larger prune ratios.
    pruning_plan = {
        'encoders.camera.backbone': 0.20,    # camera backbone: prune 20% (largest module)
        'heads.map.aspp': 0.25,              # ASPP: prune 25%
        'decoder': 0.15,                     # decoder: prune 15%
        'encoders.camera.vtransform': 0.10,  # VTransform: prune 10%
    }
    print("Pruning plan:")
    for module, ratio in pruning_plan.items():
        print(f"  {module:40s}: prune {ratio*100:.0f}%")
    print()
    # Execute the plan.
    all_pruned_keys = []
    for module_prefix, prune_ratio in pruning_plan.items():
        print(f"\nPruning module: {module_prefix} (ratio: {prune_ratio*100:.0f}%)")
        print("-" * 60)
        pruned_keys = prune_weights_l1(state_dict, module_prefix, prune_ratio)
        for key, orig, pruned in pruned_keys:
            reduction = (1 - pruned/orig) * 100
            print(f"  {key:50s}: {orig:10,} -> {pruned:10,} (-{reduction:.1f}%)")
        all_pruned_keys.extend(pruned_keys)
    # Report the final result.
    final_params = count_parameters(state_dict)
    actual_reduction = (1 - final_params/original_params) * 100
    print("\n" + "=" * 80)
    print("Pruning result")
    print("=" * 80)
    print(f"Original parameters: {original_params/1e6:.2f}M")
    print(f"Pruned parameters: {final_params/1e6:.2f}M")
    print(f"Actual reduction: {actual_reduction:.1f}%")
    print(f"Model size: {final_params*4/1024/1024:.2f}MB (FP32)")
    print()
    return state_dict, final_params


def main():
    """CLI entry point."""
    import argparse
    parser = argparse.ArgumentParser(description='Prune a BEVFusion model')
    parser.add_argument('--checkpoint', required=True, help='input checkpoint path')
    parser.add_argument('--output', required=True, help='output checkpoint path')
    parser.add_argument('--target-ratio', type=float, default=0.70,
                        help='target fraction of parameters to keep (0.7 = keep 70%%)')
    args = parser.parse_args()
    print("=" * 80)
    print("BEVFusion model pruning")
    print("=" * 80)
    print(f"Input: {args.checkpoint}")
    print(f"Output: {args.output}")
    print(f"Target keep ratio: {args.target_ratio*100:.0f}%")
    print()
    # Load the checkpoint.
    state_dict, checkpoint = load_checkpoint(args.checkpoint)
    print(f"Original parameters: {count_parameters(state_dict)/1e6:.2f}M")
    print()
    # Run pruning (mutates state_dict in place).
    pruned_state_dict, final_params = smart_prune(state_dict, args.target_ratio)
    # Save the pruned model.
    print("Saving pruned model...")
    if 'state_dict' in checkpoint:
        checkpoint['state_dict'] = pruned_state_dict
        # Drop the optimizer state; it no longer matches the pruned shapes
        # and must be re-initialized for fine-tuning.
        if 'optimizer' in checkpoint:
            del checkpoint['optimizer']
    else:
        checkpoint = pruned_state_dict
    torch.save(checkpoint, args.output)
    print("\n✅ Pruning complete!")
    print(f"  Output file: {args.output}")
    print(f"  Final parameters: {final_params/1e6:.2f}M")
    print(f"  Model size: {final_params*4/1024/1024:.2f}MB (FP32)")
    print()
    print("=" * 80)
    print("Next step: fine-tuning")
    print("=" * 80)
    print()
    print("Command:")
    print("  torchpack dist-run -np 8 python tools/train.py \\")
    print("    configs/.../multitask_enhanced_phase1_HIGHRES.yaml \\")
    print(f"    --load_from {args.output} \\")
    print("    --cfg-options max_epochs=3 optimizer.lr=5.0e-6")
    print()


if __name__ == '__main__':
    main()
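

# Hypothetical post-hoc check (file name is illustrative): confirm the saved
# checkpoint loads and its parameter count matches the "Final parameters"
# figure printed above. Assumes the original checkpoint wrapped its weights
# under 'state_dict'; use the loaded object directly otherwise.
#
#   >>> ckpt = torch.load('pruned.pth', map_location='cpu')
#   >>> count_parameters(ckpt['state_dict']) / 1e6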