bev-project/archive_scripts/准备剪枝工具.sh

296 lines
6.6 KiB
Bash
Raw Normal View History

#!/bin/bash
# Prepare the tooling and environment used for model pruning.
# Abort on the first unhandled command failure.
set -e

# All relative paths used by this script resolve against the project root.
cd /workspace/bevfusion

# Opening banner: rule, title, rule, blank line.
banner_rule="========================================================================"
printf '%s\n' "$banner_rule" "准备模型剪枝工具" "$banner_rule" ""
# 1. Check whether torch-pruning is importable; install it when it is not.
echo "========== 1. 检查torch-pruning =========="
echo ""
if /opt/conda/bin/python -c "import torch_pruning" 2>/dev/null; then
  echo "✅ torch-pruning已安装"
  # Fall back to "未知" when the version query itself fails.
  VERSION=$(/opt/conda/bin/python -c "import torch_pruning as tp; print(tp.__version__)" 2>/dev/null || echo "未知")
  echo " 版本: $VERSION"
else
  echo "⏳ torch-pruning未安装正在安装..."
  # The install runs as the `if` condition on purpose: with `set -e` active,
  # a failing stand-alone command would terminate the script before any
  # later `$?` check, making the fallback message below unreachable.
  if /opt/conda/bin/pip install torch-pruning -i https://pypi.tuna.tsinghua.edu.cn/simple; then
    echo "✅ torch-pruning安装成功"
  else
    # Best effort: keep going so the built-in PyTorch pruning can be used.
    echo "⚠️ torch-pruning安装失败尝试使用PyTorch内置剪枝"
  fi
fi
echo ""
# 2. Create the directories for pruning scripts and pruning results.
echo "========== 2. 创建工具目录 =========="
echo ""
for pruning_dir in tools/pruning pruning_results; do
  mkdir -p "$pruning_dir"
done
# Report what was created (quoted delimiter: text is printed verbatim).
cat << 'DIRS_EOF'
✅ 目录已创建:
 tools/pruning/ # 剪枝脚本
 pruning_results/ # 剪枝结果

DIRS_EOF
# 3. Generate a small Python script that checks the pruning tooling.
echo "========== 3. 创建剪枝测试脚本 =========="
echo ""
# Ensure the target directory exists so this step also works on its own.
mkdir -p tools/pruning
# Quoted delimiter: the Python source is written verbatim, without any
# shell expansion of `$`, backticks or backslashes.
cat > tools/pruning/test_pruning.py << 'PYTHON_EOF'
#!/usr/bin/env python3
"""Check that the pruning tooling and the baseline checkpoint are usable."""
import sys

sys.path.insert(0, '/workspace/bevfusion')


def test_torch_pruning():
    """Return True when torch_pruning imports; print its version."""
    try:
        import torch_pruning as tp
        print(f"✅ torch-pruning可用")
        print(f" 版本: {tp.__version__}")
        return True
    except ImportError:
        print(f"❌ torch-pruning不可用")
        return False


def test_pytorch_builtin():
    """Return True when torch.nn.utils.prune imports."""
    try:
        import torch.nn.utils.prune as prune
        print(f"✅ PyTorch内置剪枝可用")
        return True
    except ImportError:
        print(f"❌ PyTorch内置剪枝不可用")
        return False


def test_checkpoint_load():
    """Return True when the checkpoint loads on CPU and has a 'state_dict'.

    The checkpoint path is relative, so it resolves against the current
    working directory of the process running this script.
    """
    import torch
    checkpoint_file = 'runs/enhanced_from_epoch19/epoch_23.pth'
    try:
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
            # Count elements of tensor entries only; skip anything else.
            total_params = sum(v.numel() for v in state_dict.values() if isinstance(v, torch.Tensor))
            print(f"✅ Checkpoint加载成功")
            print(f" 参数量: {total_params/1e6:.2f}M")
            print(f" Epoch: {checkpoint.get('epoch', 'unknown')}")
            return True
        else:
            print(f"⚠️ Checkpoint格式异常")
            return False
    except Exception as e:
        print(f"❌ Checkpoint加载失败: {e}")
        return False


if __name__ == '__main__':
    print("=" * 60)
    print("剪枝工具可用性测试")
    print("=" * 60)
    print()
    # torch-pruning availability
    print("1. torch-pruning:")
    has_tp = test_torch_pruning()
    print()
    # PyTorch built-in pruning availability
    print("2. PyTorch内置剪枝:")
    has_builtin = test_pytorch_builtin()
    print()
    # Checkpoint loading
    print("3. Checkpoint加载:")
    can_load = test_checkpoint_load()
    print()
    # Summary
    print("=" * 60)
    print("测试结果总结")
    print("=" * 60)
    if has_tp:
        print("✅ 推荐使用: torch-pruning (功能更强大)")
    elif has_builtin:
        print("✅ 可以使用: PyTorch内置剪枝 (基本功能)")
    else:
        print("❌ 需要安装剪枝工具")
    if can_load:
        print("✅ Checkpoint可正常加载")
    else:
        print("❌ Checkpoint加载有问题")
    print()
    # Only claim readiness when a pruning backend exists and the checkpoint
    # loaded. The exit status stays 0 either way, so a calling script that
    # runs with `set -e` is not aborted by a failed check.
    if (has_tp or has_builtin) and can_load:
        print("准备就绪!可以开始剪枝。")
    else:
        print("⚠️ 存在未通过的检查, 请先解决上述问题再开始剪枝。")
PYTHON_EOF
chmod +x tools/pruning/test_pruning.py
echo "✅ 测试脚本已创建: tools/pruning/test_pruning.py"
echo ""
# 4. Run the generated availability check with the conda interpreter.
printf '%s\n' "========== 4. 运行可用性测试 ==========" ""
/opt/conda/bin/python tools/pruning/test_pruning.py
printf '\n'
# 5. Write the pruning plan as a Markdown document.
echo "========== 5. 生成剪枝方案 =========="
echo ""
plan_file=pruning_results/pruning_plan.md
# Ensure the output directory exists so this step also works on its own.
mkdir -p pruning_results
# The title and timestamp are written with printf so that `date` is really
# evaluated. The rest of the plan stays in a quoted heredoc: it contains
# Markdown code fences (backticks), which an unquoted heredoc would treat
# as command substitution.
printf '# BEVFusion剪枝实施方案\n\n**生成时间**: %s\n\n' "$(date)" > "$plan_file"
# Blank lines around headings, fences and rules keep `---` from being
# parsed as a setext heading underline of the preceding text line.
cat >> "$plan_file" << 'PLAN_EOF'
**Baseline**: epoch_23.pth (45.72M参数)

---

## 🎯 剪枝目标

```
参数量: 45.72M → 32M (-30%)
模型大小: 174MB → 122MB (FP32)
推理加速: ~30%
精度损失: <2%
```

---

## 📋 分层剪枝计划

### Layer 1: Camera Backbone (27.55M → 20M)

**方法**: 通道剪枝

```
SwinTransformer各stage剪枝:
Stage 1: 96 channels → 80 (-17%)
Stage 2: 192 channels → 160 (-17%)
Stage 3: 384 channels → 320 (-17%)
Stage 4: 768 channels → 640 (-17%)
预期减少: ~5.5M参数
```

**实施**:

```python
# 使用torch-pruning
import torch_pruning as tp
# 对SwinTransformer剪枝
pruner = tp.pruner.MetaPruner(
    model.encoders['camera'].backbone,
    example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    pruning_ratio=0.17,  # 17%
)
pruner.step()
```

---

### Layer 2: ASPP模块 (4.13M → 3M)

**方法**: 通道剪枝

```
ASPP通道数: 512 → 384 (-25%)
各分支同比例剪枝
预期减少: ~1.1M参数
```

---

### Layer 3: Decoder (4.58M → 3.9M)

**方法**: 通道剪枝

```
Decoder通道: [128, 256] → [96, 192] (-25%)
预期减少: ~0.7M参数
```

---

### Layer 4: Camera VTransform (2.61M → 2.35M)

**方法**: 轻度剪枝

```
通道剪枝: 10%
预期减少: ~0.26M参数
```

---

## 📊 预期结果

```
总计减少: 7.56M参数 (-16.5%)
保守估计: 剪枝后38M参数
如果效果好可进一步剪枝到32M
```

---

## 🚀 下一步

1. 创建剪枝脚本
2. 小规模测试剪枝
3. 评估剪枝效果
4. 全量剪枝
5. 微调3 epochs
PLAN_EOF
echo "✅ 剪枝方案已生成: pruning_results/pruning_plan.md"
echo ""
# 6. Closing summary: completed steps, next steps and current status.
# Quoted delimiter: the text is printed verbatim, with no shell expansion.
cat << 'SUMMARY_EOF'
========================================================================
准备工作完成!
========================================================================

已完成:
 ✅ torch-pruning检查/安装
 ✅ 工具目录创建
 ✅ 测试脚本创建
 ✅ 剪枝方案生成

下一步:
 1. 查看剪枝方案: cat pruning_results/pruning_plan.md
 2. 创建剪枝脚本 (明天)
 3. 开始剪枝实施 (明天)

当前状态:
 - Stage 1训练: 继续进行中 (不受影响)
 - 优化准备: ✅ 就绪

========================================================================
SUMMARY_EOF