222 lines
7.6 KiB
Python
222 lines
7.6 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
使用PyTorch内置剪枝功能剪枝BEVFusion模型
|
|||
|
|
适用于PyTorch 1.10+
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import torch
|
|||
|
|
import torch.nn as nn
|
|||
|
|
import torch.nn.utils.prune as prune
|
|||
|
|
import sys
|
|||
|
|
import copy
|
|||
|
|
from collections import OrderedDict
|
|||
|
|
|
|||
|
|
# Put a hard-coded directory at the front of the import search path
# (presumably the BEVFusion repository root - TODO confirm; nothing in the
# visible code imports from it).
sys.path.insert(0, '/workspace/bevfusion')
|
|||
|
|
|
|||
|
|
def load_checkpoint(checkpoint_path):
    """Load a checkpoint file onto the CPU and locate its weights.

    Args:
        checkpoint_path: Path of the checkpoint file to read.

    Returns:
        A ``(state_dict, checkpoint)`` tuple. When the loaded object has a
        ``'state_dict'`` entry, that entry is returned as ``state_dict`` and
        the stored epoch (or ``'unknown'``) is printed. Otherwise the loaded
        object itself is treated as the state dict and returned twice.
    """
    print(f"加载checkpoint: {checkpoint_path}")

    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    loaded = torch.load(checkpoint_path, map_location='cpu')

    # Bare state dict: the file content *is* the weights.
    if 'state_dict' not in loaded:
        return loaded, loaded

    weights = loaded['state_dict']
    epoch = loaded.get('epoch', 'unknown')
    print(f" Epoch: {epoch}")
    return weights, loaded
|
|||
|
|
|
|||
|
|
def count_parameters(state_dict):
    """Return the total number of elements over all tensor values.

    Entries of ``state_dict`` that are not ``torch.Tensor`` instances are
    ignored.
    """
    total = 0
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            total += value.numel()
    return total
|
|||
|
|
|
|||
|
|
def prune_weights_l1(state_dict, module_prefix, prune_ratio):
    """
    Prune output channels of Linear/Conv weights in a module by L1 norm.

    Every entry of ``state_dict`` whose key starts with ``module_prefix``,
    ends with ``weight`` and holds a 2-D (Linear-style) or 4-D (Conv2d-style)
    tensor is sliced along dim 0, keeping the rows/channels with the largest
    L1 norm in their original order. A sibling ``...bias`` entry of matching
    length is sliced with the same indices. ``state_dict`` is modified in
    place.

    Only dim 0 of the matched weight and its bias are changed. Tensors that
    consume the removed channels (e.g. the input dimension of a following
    layer or 1-D normalization parameters) are NOT adjusted here.

    Args:
        state_dict: Mapping of parameter names to tensors.
        module_prefix: Key prefix selecting the module, e.g.
            'encoders.camera.backbone'.
        prune_ratio: Fraction of output channels to remove, e.g. 0.3 removes
            30%. Must satisfy ``0 <= prune_ratio < 1``.

    Returns:
        A list of ``(key, original_numel, pruned_numel)`` tuples, one per
        pruned weight tensor.

    Raises:
        ValueError: If ``prune_ratio`` is outside ``[0, 1)``.
    """
    if not 0.0 <= prune_ratio < 1.0:
        raise ValueError(f"prune_ratio must be in [0, 1), got {prune_ratio!r}")

    suffix = 'weight'
    pruned_keys = []

    # Snapshot the keys: entries are reassigned while walking the dict.
    for key in list(state_dict):
        # Match on the key *suffix* so that a module whose name merely
        # contains "weight" is not mistaken for a weight tensor.
        if not (key.startswith(module_prefix) and key.endswith(suffix)):
            continue

        weight = state_dict[key]
        # Only Linear-style (2-D) and Conv2d-style (4-D) tensors are pruned.
        if not isinstance(weight, torch.Tensor) or weight.dim() not in (2, 4):
            continue

        original_params = weight.numel()
        num_out = weight.shape[0]
        # Keep at least one channel so a layer is never pruned to nothing
        # (the min() only matters for an already-empty tensor).
        num_keep = min(num_out, max(1, int(num_out * (1 - prune_ratio))))

        # L1 norm of each output channel: sum |w| over every non-output dim.
        l1_norm = weight.abs().sum(dim=list(range(1, weight.dim())))

        # Keep the largest-norm channels, restored to their original order.
        _, indices = torch.topk(l1_norm, num_keep)
        indices = indices.sort()[0]

        state_dict[key] = weight[indices]

        # Swap only the trailing "weight" for "bias" to find the sibling.
        bias_key = key[:-len(suffix)] + 'bias'
        bias = state_dict.get(bias_key)
        if (
            isinstance(bias, torch.Tensor)
            and bias.dim() == 1
            and bias.shape[0] == num_out
        ):
            state_dict[bias_key] = bias[indices]

        pruned_keys.append((key, original_params, state_dict[key].numel()))

    return pruned_keys
|
|||
|
|
|
|||
|
|
def smart_prune(state_dict, target_params_ratio=0.7, pruning_plan=None):
    """
    Prune several module groups of a state dict using per-module ratios.

    Each ``(module_prefix, prune_ratio)`` pair of the plan is passed to
    ``prune_weights_l1``, which modifies ``state_dict`` in place. Progress
    and a before/after summary are printed.

    The plan's ratios are fixed; ``target_params_ratio`` does not drive them.
    It is used to report the target size and to print a warning when the
    pruned model is still larger than that target.

    Args:
        state_dict: Mapping of parameter names to tensors (modified in place).
        target_params_ratio: Desired fraction of parameters to keep, e.g. 0.7
            keeps 70%.
        pruning_plan: Optional mapping of module key prefix to prune ratio.
            When ``None`` the built-in plan below is used.

    Returns:
        A ``(state_dict, final_params)`` tuple: the same (mutated) mapping
        and the total parameter count after pruning.
    """
    if pruning_plan is None:
        # Default strategy: a different prune ratio per module group.
        pruning_plan = {
            'encoders.camera.backbone': 0.20,  # camera backbone: prune 20%
            'heads.map.aspp': 0.25,  # ASPP: prune 25%
            'decoder': 0.15,  # decoder: prune 15%
            'encoders.camera.vtransform': 0.10,  # VTransform: prune 10%
        }

    print("=" * 80)
    print("开始智能剪枝")
    print("=" * 80)

    original_params = count_parameters(state_dict)
    target_params = int(original_params * target_params_ratio)

    print(f"原始参数量: {original_params/1e6:.2f}M")
    print(f"目标参数量: {target_params/1e6:.2f}M")
    print(f"需要剪枝: {(1-target_params_ratio)*100:.1f}%")
    print()

    print("剪枝计划:")
    for module, ratio in pruning_plan.items():
        print(f" {module:40s}: 剪枝 {ratio*100:.0f}%")
    print()

    # Apply the plan one module group at a time.
    for module_prefix, prune_ratio in pruning_plan.items():
        print(f"\n剪枝模块: {module_prefix} (比例: {prune_ratio*100:.0f}%)")
        print("-" * 60)

        pruned_keys = prune_weights_l1(state_dict, module_prefix, prune_ratio)

        for key, orig, pruned in pruned_keys:
            # Guard against zero-element tensors when computing the ratio.
            reduction = (1 - pruned / orig) * 100 if orig else 0.0
            print(f" {key:50s}: {orig:10,} → {pruned:10,} (-{reduction:.1f}%)")

    # Final statistics; an empty state dict must not divide by zero.
    final_params = count_parameters(state_dict)
    if original_params:
        actual_reduction = (1 - final_params / original_params) * 100
    else:
        actual_reduction = 0.0

    print("\n" + "=" * 80)
    print("剪枝结果")
    print("=" * 80)
    print(f"原始参数量: {original_params/1e6:.2f}M")
    print(f"剪枝后参数量: {final_params/1e6:.2f}M")
    print(f"实际减少: {actual_reduction:.1f}%")
    print(f"模型大小: {final_params*4/1024/1024:.2f}MB (FP32)")
    if final_params > target_params:
        # The fixed plan is not tuned to the requested target; say so
        # instead of silently missing it.
        print(
            f"警告: 剪枝后参数量 {final_params/1e6:.2f}M "
            f"仍高于目标 {target_params/1e6:.2f}M"
        )
    print()

    return state_dict, final_params
|
|||
|
|
|
|||
|
|
def main():
    """Command-line entry point: load, prune and save a checkpoint.

    Parses ``--checkpoint``, ``--output`` and ``--target-ratio``, loads the
    checkpoint with ``load_checkpoint``, prunes it with ``smart_prune``,
    writes the result with ``torch.save`` and prints a suggested fine-tuning
    command. Any stored ``'optimizer'`` entry is dropped from a wrapped
    checkpoint before saving.
    """
    import argparse

    parser = argparse.ArgumentParser(description='剪枝BEVFusion模型')
    parser.add_argument('--checkpoint', required=True, help='输入checkpoint路径')
    parser.add_argument('--output', required=True, help='输出checkpoint路径')
    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%'; a bare '%' makes --help raise ValueError.
    parser.add_argument('--target-ratio', type=float, default=0.70,
                        help='目标参数保留比例 (0.7=保留70%%)')

    args = parser.parse_args()

    print("=" * 80)
    print("BEVFusion模型剪枝")
    print("=" * 80)
    print(f"输入: {args.checkpoint}")
    print(f"输出: {args.output}")
    print(f"目标保留: {args.target_ratio*100:.0f}%")
    print()

    # Load the checkpoint.
    state_dict, checkpoint = load_checkpoint(args.checkpoint)

    print(f"原始参数量: {count_parameters(state_dict)/1e6:.2f}M")
    print()

    # Run the pruning.
    pruned_state_dict, final_params = smart_prune(state_dict, args.target_ratio)

    # Save the pruned model.
    print("保存剪枝后的模型...")

    if 'state_dict' in checkpoint:
        checkpoint['state_dict'] = pruned_state_dict
        # Drop the optimizer state: it has to be re-initialised after
        # pruning.
        checkpoint.pop('optimizer', None)
    else:
        checkpoint = pruned_state_dict

    torch.save(checkpoint, args.output)

    print(f"\n✅ 剪枝完成!")
    print(f" 输出文件: {args.output}")
    print(f" 最终参数量: {final_params/1e6:.2f}M")
    print(f" 模型大小: {final_params*4/1024/1024:.2f}MB (FP32)")
    print()
    print("=" * 80)
    print("下一步: 微调训练")
    print("=" * 80)
    print()
    print("命令:")
    print(" torchpack dist-run -np 8 python tools/train.py \\")
    print(" configs/.../multitask_enhanced_phase1_HIGHRES.yaml \\")
    print(f" --load_from {args.output} \\")
    print(" --cfg-options max_epochs=3 optimizer.lr=5.0e-6")
    print()
|
|||
|
|
|
|||
|
|
# Script entry point: run the CLI only when this file is executed directly.
if __name__ == '__main__':
    main()
|
|||
|
|
|