bev-project/archive_scripts/准备剪枝工具.sh

296 lines
6.6 KiB
Bash
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/bin/bash
# Prepare the model-pruning tooling and environment.
#
# Strict mode: abort on a failing command (-e), on use of an unset
# variable (-u), and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# Every relative path used below is resolved against the project root, so
# fail loudly (on stderr) if it cannot be entered.
cd /workspace/bevfusion || {
  echo "❌ 无法进入项目目录: /workspace/bevfusion" >&2
  exit 1
}

echo "========================================================================"
echo "准备模型剪枝工具"
echo "========================================================================"
echo ""
# 1. Check whether torch-pruning is importable; install it when it is not.
echo "========== 1. 检查torch-pruning =========="
echo ""
if /opt/conda/bin/python -c "import torch_pruning" 2>/dev/null; then
  echo "✅ torch-pruning已安装"
  # If the version query fails for any reason, report "未知" instead of failing.
  VERSION=$(/opt/conda/bin/python -c "import torch_pruning as tp; print(tp.__version__)" 2>/dev/null || echo "未知")
  echo " 版本: $VERSION"
else
  echo "⏳ torch-pruning未安装正在安装..."
  # pip must run as the `if` condition: when `set -e` is active, a bare failing
  # command terminates the script immediately, so a later `[ $? -eq 0 ]` test
  # would never be reached and the fallback message below would be dead code.
  if /opt/conda/bin/pip install torch-pruning -i https://pypi.tuna.tsinghua.edu.cn/simple; then
    echo "✅ torch-pruning安装成功"
  else
    # Installation failure is not fatal: continue so the built-in pruning
    # path can still be checked later.
    echo "⚠️ torch-pruning安装失败尝试使用PyTorch内置剪枝"
  fi
fi
echo ""
# 2. Create the directories that hold the pruning scripts and their results.
printf '%s\n' "========== 2. 创建工具目录 ==========" ""
# One mkdir per directory, in the original order.
for dir in tools/pruning pruning_results; do
  mkdir -p "$dir"
done
printf '%s\n' \
  "✅ 目录已创建:" \
  " tools/pruning/ # 剪枝脚本" \
  " pruning_results/ # 剪枝结果" \
  ""
# 3. Generate a small Python script that checks the pruning prerequisites.
echo "========== 3. 创建剪枝测试脚本 =========="
echo ""
# Ensure the target directory exists so this step also works when run on its own.
mkdir -p tools/pruning
# Quoted delimiter: the Python source below is written verbatim, with no
# shell expansion of `$`, backticks or backslashes.
cat > tools/pruning/test_pruning.py << 'PYTHON_EOF'
#!/usr/bin/env python3
"""
Check whether the pruning prerequisites are available:
torch-pruning, PyTorch's built-in pruning module, and the baseline checkpoint.
"""
import sys

sys.path.insert(0, '/workspace/bevfusion')


def test_torch_pruning():
    """Return True if torch_pruning can be imported; print its version."""
    try:
        import torch_pruning as tp
        print("✅ torch-pruning可用")
        # Fall back to "未知" when the package exposes no __version__.
        print(f" 版本: {getattr(tp, '__version__', '未知')}")
        return True
    except ImportError:
        print("❌ torch-pruning不可用")
        return False


def test_pytorch_builtin():
    """Return True if torch.nn.utils.prune can be imported."""
    try:
        import torch.nn.utils.prune as prune  # noqa: F401 (import check only)
        print("✅ PyTorch内置剪枝可用")
        return True
    except ImportError:
        print("❌ PyTorch内置剪枝不可用")
        return False


def test_checkpoint_load():
    """Return True if the baseline checkpoint loads and has a 'state_dict'."""
    checkpoint_file = 'runs/enhanced_from_epoch19/epoch_23.pth'
    try:
        # Imported inside the try block so that a missing torch is reported
        # as a failed check instead of crashing the script with a traceback.
        import torch
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
            total_params = sum(
                v.numel() for v in state_dict.values()
                if isinstance(v, torch.Tensor)
            )
            print("✅ Checkpoint加载成功")
            print(f" 参数量: {total_params/1e6:.2f}M")
            print(f" Epoch: {checkpoint.get('epoch', 'unknown')}")
            return True
        else:
            print("⚠️ Checkpoint格式异常")
            return False
    except Exception as e:
        print(f"❌ Checkpoint加载失败: {e}")
        return False


if __name__ == '__main__':
    print("=" * 60)
    print("剪枝工具可用性测试")
    print("=" * 60)
    print()
    # Check torch-pruning
    print("1. torch-pruning:")
    has_tp = test_torch_pruning()
    print()
    # Check PyTorch built-in pruning
    print("2. PyTorch内置剪枝:")
    has_builtin = test_pytorch_builtin()
    print()
    # Check checkpoint loading
    print("3. Checkpoint加载:")
    can_load = test_checkpoint_load()
    print()
    # Summary
    print("=" * 60)
    print("测试结果总结")
    print("=" * 60)
    if has_tp:
        print("✅ 推荐使用: torch-pruning (功能更强大)")
    elif has_builtin:
        print("✅ 可以使用: PyTorch内置剪枝 (基本功能)")
    else:
        print("❌ 需要安装剪枝工具")
    if can_load:
        print("✅ Checkpoint可正常加载")
    else:
        print("❌ Checkpoint加载有问题")
    print()
    # Announce readiness only when a pruning backend exists AND the
    # checkpoint loads; otherwise the message would contradict the summary.
    if (has_tp or has_builtin) and can_load:
        print("准备就绪!可以开始剪枝。")
    else:
        print("⚠️ 尚未就绪,请先解决上述问题。")
    # The script still ends with exit status 0 whatever the results, so a
    # caller running under `set -e` is not aborted by a failed check.
PYTHON_EOF
chmod +x tools/pruning/test_pruning.py
echo "✅ 测试脚本已创建: tools/pruning/test_pruning.py"
echo ""
# 4. Execute the generated availability check with the conda interpreter.
printf '%s\n' "========== 4. 运行可用性测试 ==========" ""
/opt/conda/bin/python tools/pruning/test_pruning.py
printf '\n'
# 5. Write the pruning plan (Markdown) to pruning_results/pruning_plan.md.
echo "========== 5. 生成剪枝方案 =========="
echo ""
# Ensure the output directory exists so this step also works when run on its own.
mkdir -p pruning_results
{
  # The timestamp is emitted with printf, outside the here-document: the
  # here-document delimiter is quoted, so `$(date)` inside it would be written
  # literally. The delimiter must stay quoted because the Markdown body
  # contains backtick code fences, which an unquoted here-document would
  # treat as command substitution.
  printf '%s\n\n' '# BEVFusion剪枝实施方案'
  printf '**生成时间**: %s\n\n' "$(date)"
  cat << 'PLAN_EOF'
**Baseline**: epoch_23.pth (45.72M参数)

---

## 🎯 剪枝目标

```
参数量: 45.72M → 32M (-30%)
模型大小: 174MB → 122MB (FP32)
推理加速: ~30%
精度损失: <2%
```

---

## 📋 分层剪枝计划

### Layer 1: Camera Backbone (27.55M → 20M)

**方法**: 通道剪枝

```
SwinTransformer各stage剪枝:
  Stage 1: 96 channels → 80 (-17%)
  Stage 2: 192 channels → 160 (-17%)
  Stage 3: 384 channels → 320 (-17%)
  Stage 4: 768 channels → 640 (-17%)
预期减少: ~5.5M参数
```

**实施**:

```python
# 使用torch-pruning
import torch_pruning as tp

# 对SwinTransformer剪枝
pruner = tp.pruner.MetaPruner(
    model.encoders['camera'].backbone,
    example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    pruning_ratio=0.17,  # 17%
)
pruner.step()
```

---

### Layer 2: ASPP模块 (4.13M → 3M)

**方法**: 通道剪枝

```
ASPP通道数: 512 → 384 (-25%)
各分支同比例剪枝
预期减少: ~1.1M参数
```

---

### Layer 3: Decoder (4.58M → 3.9M)

**方法**: 通道剪枝

```
Decoder通道: [128, 256] → [96, 192] (-25%)
预期减少: ~0.7M参数
```

---

### Layer 4: Camera VTransform (2.61M → 2.35M)

**方法**: 轻度剪枝

```
通道剪枝: 10%
预期减少: ~0.26M参数
```

---

## 📊 预期结果

```
总计减少: 7.56M参数 (-16.5%)
保守估计: 剪枝后38M参数
如果效果好可进一步剪枝到32M
```

---

## 🚀 下一步

1. 创建剪枝脚本
2. 小规模测试剪枝
3. 评估剪枝效果
4. 全量剪枝
5. 微调3 epochs
PLAN_EOF
} > pruning_results/pruning_plan.md
echo "✅ 剪枝方案已生成: pruning_results/pruning_plan.md"
echo ""
# 6. Summary: print the closing report in one go. The delimiter is quoted so
# the text is emitted verbatim, exactly as the individual echo lines did.
cat << 'SUMMARY_EOF'
========================================================================
准备工作完成!
========================================================================

已完成:
 ✅ torch-pruning检查/安装
 ✅ 工具目录创建
 ✅ 测试脚本创建
 ✅ 剪枝方案生成

下一步:
 1. 查看剪枝方案: cat pruning_results/pruning_plan.md
 2. 创建剪枝脚本 (明天)
 3. 开始剪枝实施 (明天)

当前状态:
 - Stage 1训练: 继续进行中 (不受影响)
 - 优化准备: ✅ 就绪

========================================================================
SUMMARY_EOF