#!/usr/bin/env python3
"""
Prune a BEVFusion checkpoint with L1-norm structured pruning
(keep the highest-magnitude output channels of Conv and Linear layers).

Requires PyTorch 1.10+.
"""

import sys

import torch

sys.path.insert(0, '/workspace/bevfusion')
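
# Example invocation (a sketch; the script and checkpoint names are
# placeholders, not taken from the original):
#   python prune_bevfusion.py \
#       --checkpoint checkpoints/bevfusion.pth \
#       --output checkpoints/bevfusion_pruned70.pth \
#       --target-ratio 0.7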


def load_checkpoint(checkpoint_path):
    """Load a checkpoint and return (state_dict, full checkpoint)."""
    print(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
        epoch = checkpoint.get('epoch', 'unknown')
        print(f"  Epoch: {epoch}")
    else:
        state_dict = checkpoint

    return state_dict, checkpoint


def count_parameters(state_dict):
    """Total element count across all tensors in the state_dict."""
    return sum(v.numel() for v in state_dict.values() if isinstance(v, torch.Tensor))
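
# Note (added clarification): the sum above covers every tensor in the
# state_dict, so BatchNorm running statistics and other buffers are counted
# alongside trainable weights and biases.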


def prune_weights_l1(state_dict, module_prefix, prune_ratio):
    """
    Apply L1-norm structured pruning to one module.

    Args:
        state_dict: model weight dict
        module_prefix: module prefix, e.g. 'encoders.camera.backbone'
        prune_ratio: fraction to remove, e.g. 0.3 removes 30%
    """
    pruned_keys = []

    for key in list(state_dict.keys()):
        if key.startswith(module_prefix) and 'weight' in key:
            # Only prune Conv and Linear layers.
            if len(state_dict[key].shape) in [2, 4]:  # Linear or Conv
                original_params = state_dict[key].numel()
                weight = state_dict[key]

                if len(weight.shape) == 4:  # Conv2d: (out, in, h, w)
                    # Prune along the output-channel dimension,
                    # scoring each channel by its L1 norm.
                    l1_norm = weight.abs().sum(dim=[1, 2, 3])
                    num_channels = weight.shape[0]
                    # max(1, ...) guards against pruning a layer to zero channels.
                    num_keep = max(1, int(num_channels * (1 - prune_ratio)))

                    # Keep the channels with the largest L1 norm.
                    _, indices = torch.topk(l1_norm, num_keep)
                    indices = indices.sort()[0]

                    # Prune the weight.
                    state_dict[key] = weight[indices]

                    # Prune the matching bias as well.
                    bias_key = key.replace('weight', 'bias')
                    if bias_key in state_dict:
                        state_dict[bias_key] = state_dict[bias_key][indices]

                    pruned_params = state_dict[key].numel()
                    pruned_keys.append((key, original_params, pruned_params))

                elif len(weight.shape) == 2:  # Linear: (out, in)
                    # Prune along the output-feature dimension.
                    l1_norm = weight.abs().sum(dim=1)
                    num_features = weight.shape[0]
                    num_keep = max(1, int(num_features * (1 - prune_ratio)))

                    _, indices = torch.topk(l1_norm, num_keep)
                    indices = indices.sort()[0]

                    state_dict[key] = weight[indices]

                    bias_key = key.replace('weight', 'bias')
                    if bias_key in state_dict:
                        state_dict[bias_key] = state_dict[bias_key][indices]

                    pruned_params = state_dict[key].numel()
                    pruned_keys.append((key, original_params, pruned_params))

    return pruned_keys
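
# Added caveat (not in the original script): prune_weights_l1 shrinks a layer's
# *output* channels only. The consumers of those channels -- the next layer's
# input dimension and any following BatchNorm parameters/running stats -- are
# left untouched, so the pruned checkpoint only fits a model whose config is
# adjusted to the new widths.

# Minimal sanity-check sketch (toy tensors; the key names are hypothetical):
def _demo_prune_single_conv():
    sd = {
        'encoders.camera.backbone.conv1.weight': torch.randn(64, 32, 3, 3),
        'encoders.camera.backbone.conv1.bias': torch.randn(64),
    }
    prune_weights_l1(sd, 'encoders.camera.backbone', prune_ratio=0.25)
    # max(1, int(64 * (1 - 0.25))) = 48 output channels survive.
    assert sd['encoders.camera.backbone.conv1.weight'].shape == (48, 32, 3, 3)
    assert sd['encoders.camera.backbone.conv1.bias'].shape == (48,)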


def smart_prune(state_dict, target_params_ratio=0.7):
    """
    Smart pruning: apply a per-module pruning plan sized to each module.

    Args:
        state_dict: model weights
        target_params_ratio: target fraction of parameters to keep
            (0.7 keeps 70%). Reported for reference; the hand-tuned plan
            below is not rescaled to hit it exactly.
    """

    print("=" * 80)
    print("Starting smart pruning")
    print("=" * 80)

    original_params = count_parameters(state_dict)
    target_params = int(original_params * target_params_ratio)

    print(f"Original parameters: {original_params/1e6:.2f}M")
    print(f"Target parameters: {target_params/1e6:.2f}M")
    print(f"To prune: {(1-target_params_ratio)*100:.1f}%")
    print()

    # Pruning strategy: assign each module a ratio according to its size.
    pruning_plan = {
        'encoders.camera.backbone': 0.20,    # camera backbone: 20% (largest module)
        'heads.map.aspp': 0.25,              # ASPP: 25%
        'decoder': 0.15,                     # decoder: 15%
        'encoders.camera.vtransform': 0.10,  # VTransform: 10%
    }

    print("Pruning plan:")
    for module, ratio in pruning_plan.items():
        print(f"  {module:40s}: prune {ratio*100:.0f}%")
    print()

    # Execute the plan.
    all_pruned_keys = []

    for module_prefix, prune_ratio in pruning_plan.items():
        print(f"\nPruning module: {module_prefix} (ratio: {prune_ratio*100:.0f}%)")
        print("-" * 60)

        pruned_keys = prune_weights_l1(state_dict, module_prefix, prune_ratio)

        for key, orig, pruned in pruned_keys:
            reduction = (1 - pruned/orig) * 100
            print(f"  {key:50s}: {orig:10,} → {pruned:10,} (-{reduction:.1f}%)")

        all_pruned_keys.extend(pruned_keys)

    # Report the final numbers.
    final_params = count_parameters(state_dict)
    actual_reduction = (1 - final_params/original_params) * 100

    print("\n" + "=" * 80)
    print("Pruning results")
    print("=" * 80)
    print(f"Original parameters: {original_params/1e6:.2f}M")
    print(f"Pruned parameters: {final_params/1e6:.2f}M")
    print(f"Actual reduction: {actual_reduction:.1f}%")
    print(f"Model size: {final_params*4/1024/1024:.2f}MB (FP32)")
    print()

    return state_dict, final_params
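
# Added caveat (not in the original script): because smart_prune changes layer
# widths, nn.Module.load_state_dict will raise size-mismatch errors against the
# unmodified model config (even with strict=False), so the config's channel
# counts must be updated before the fine-tuning step printed by main() below.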


def main():
    """Entry point."""
    import argparse

    parser = argparse.ArgumentParser(description='Prune a BEVFusion model')
    parser.add_argument('--checkpoint', required=True, help='input checkpoint path')
    parser.add_argument('--output', required=True, help='output checkpoint path')
    # Note: '%' must be escaped as '%%' in argparse help strings.
    parser.add_argument('--target-ratio', type=float, default=0.70,
                        help='target fraction of parameters to keep (0.7 keeps 70%%)')

    args = parser.parse_args()

    print("=" * 80)
    print("BEVFusion model pruning")
    print("=" * 80)
    print(f"Input: {args.checkpoint}")
    print(f"Output: {args.output}")
    print(f"Target keep ratio: {args.target_ratio*100:.0f}%")
    print()

    # Load the checkpoint.
    state_dict, checkpoint = load_checkpoint(args.checkpoint)

    print(f"Original parameters: {count_parameters(state_dict)/1e6:.2f}M")
    print()

    # Run the pruning.
    pruned_state_dict, final_params = smart_prune(state_dict, args.target_ratio)

    # Save the pruned model.
    print("Saving the pruned model...")

    if 'state_dict' in checkpoint:
        checkpoint['state_dict'] = pruned_state_dict
        # Drop the optimizer state; it must be re-initialized after pruning.
        if 'optimizer' in checkpoint:
            del checkpoint['optimizer']
    else:
        checkpoint = pruned_state_dict

    torch.save(checkpoint, args.output)

    print("\n✅ Pruning complete!")
    print(f"  Output file: {args.output}")
    print(f"  Final parameters: {final_params/1e6:.2f}M")
    print(f"  Model size: {final_params*4/1024/1024:.2f}MB (FP32)")
    print()
    print("=" * 80)
    print("Next step: fine-tuning")
    print("=" * 80)
    print()
    print("Command:")
    print("  torchpack dist-run -np 8 python tools/train.py \\")
    print("    configs/.../multitask_enhanced_phase1_HIGHRES.yaml \\")
    print(f"    --load_from {args.output} \\")
    print("    --cfg-options max_epochs=3 optimizer.lr=5.0e-6")
    print()


if __name__ == '__main__':
    main()