#!/usr/bin/env bash
#
# Prepare the model-pruning tooling and environment for BEVFusion.
#
# Steps:
#   1. check for torch-pruning, installing it with pip if it is missing
#   2. create tools/pruning/ and pruning_results/
#   3. write tools/pruning/test_pruning.py (availability check)
#   4. run that availability check
#   5. write the pruning plan to pruning_results/pruning_plan.md
#   6. print a summary
#
# All relative paths below are resolved against the repository root,
# which this script assumes is /workspace/bevfusion.

set -euo pipefail

cd /workspace/bevfusion || { echo "❌ 无法进入目录: /workspace/bevfusion" >&2; exit 1; }

echo "========================================================================"
echo "准备模型剪枝工具"
echo "========================================================================"
echo ""

# 1. Check for torch-pruning and install it if it is missing.
echo "========== 1. 检查torch-pruning =========="
echo ""

#######################################
# Report whether torch-pruning is importable; if not, try to pip-install it.
# Globals:
#   PYTHON_BIN    (read, optional) interpreter to use;
#                 default /opt/conda/bin/python
#   PIP_INDEX_URL (read, optional) package index;
#                 default https://pypi.tuna.tsinghua.edu.cn/simple
# Outputs:
#   Status messages on stdout.
# Returns:
#   0 in every path below: an install failure is reported, not fatal, so the
#   setup can continue with PyTorch's built-in pruning instead.
#######################################
ensure_torch_pruning() {
  local py="${PYTHON_BIN:-/opt/conda/bin/python}"
  local index_url="${PIP_INDEX_URL:-https://pypi.tuna.tsinghua.edu.cn/simple}"
  local tp_version

  if "$py" -c "import torch_pruning" 2>/dev/null; then
    echo "✅ torch-pruning已安装"
    tp_version=$("$py" -c "import torch_pruning as tp; print(tp.__version__)" 2>/dev/null || echo "未知")
    echo " 版本: $tp_version"
    return 0
  fi

  echo "⏳ torch-pruning未安装,正在安装..."
  # pip runs as the `if` condition on purpose: under `set -e` a bare failing
  # command would abort the script before a later `$?` check, leaving the
  # failure branch unreachable.
  if "$py" -m pip install torch-pruning -i "$index_url"; then
    echo "✅ torch-pruning安装成功"
  else
    echo "⚠️ torch-pruning安装失败,尝试使用PyTorch内置剪枝"
  fi
}

ensure_torch_pruning
echo ""

# 2. Create the working directories: tools/pruning/ for pruning scripts and
#    pruning_results/ for their outputs. `mkdir -p` makes this idempotent.
echo "========== 2. 创建工具目录 =========="
echo ""

mkdir -p tools/pruning pruning_results

echo "✅ 目录已创建:"
echo " tools/pruning/ # 剪枝脚本"
echo " pruning_results/ # 剪枝结果"
echo ""

# 3. Write tools/pruning/test_pruning.py, a small report that checks whether
#    torch-pruning and torch.nn.utils.prune are importable and whether the
#    baseline checkpoint loads. The heredoc delimiter is quoted so the Python
#    source is written verbatim (no shell expansion of $, ` or \).
echo "========== 3. 创建剪枝测试脚本 =========="
echo ""

# Make this step independent of the directory-creation step having run.
mkdir -p tools/pruning

cat > tools/pruning/test_pruning.py << 'PYTHON_EOF'
#!/usr/bin/env python3
"""
测试torch-pruning是否可用
"""

import os
import sys

# Repository root: added to sys.path and used to resolve the default
# checkpoint, so the report does not depend on the current directory.
REPO_ROOT = '/workspace/bevfusion'
sys.path.insert(0, REPO_ROOT)

# Checkpoint probed by default; pass a path as the first argument to override.
DEFAULT_CHECKPOINT = os.path.join(REPO_ROOT, 'runs/enhanced_from_epoch19/epoch_23.pth')


def test_torch_pruning():
    """测试torch-pruning"""
    try:
        import torch_pruning as tp
    except ImportError:
        print("❌ torch-pruning不可用")
        return False
    print("✅ torch-pruning可用")
    # Guard the attribute lookup so a missing __version__ is reported
    # instead of escaping as an uncaught AttributeError.
    print(f" 版本: {getattr(tp, '__version__', 'unknown')}")
    return True


def test_pytorch_builtin():
    """测试PyTorch内置剪枝"""
    try:
        import torch.nn.utils.prune  # noqa: F401  (availability probe only)
    except ImportError:
        print("❌ PyTorch内置剪枝不可用")
        return False
    print("✅ PyTorch内置剪枝可用")
    return True


def test_checkpoint_load(checkpoint_file=DEFAULT_CHECKPOINT):
    """测试checkpoint加载"""
    try:
        # Imported inside the try so a missing torch is reported, not raised.
        import torch

        # NOTE(review): newer torch releases change the torch.load
        # weights_only default, which may reject checkpoints holding arbitrary
        # Python objects -- confirm against the installed torch version.
        checkpoint = torch.load(checkpoint_file, map_location='cpu')

        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
            total_params = sum(v.numel() for v in state_dict.values() if isinstance(v, torch.Tensor))

            print("✅ Checkpoint加载成功")
            print(f" 参数量: {total_params/1e6:.2f}M")
            print(f" Epoch: {checkpoint.get('epoch', 'unknown')}")
            return True

        print("⚠️ Checkpoint格式异常")
        return False

    except Exception as e:
        print(f"❌ Checkpoint加载失败: {e}")
        return False


if __name__ == '__main__':
    checkpoint_file = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_CHECKPOINT

    print("=" * 60)
    print("剪枝工具可用性测试")
    print("=" * 60)
    print()

    print("1. torch-pruning:")
    has_tp = test_torch_pruning()
    print()

    print("2. PyTorch内置剪枝:")
    has_builtin = test_pytorch_builtin()
    print()

    print("3. Checkpoint加载:")
    can_load = test_checkpoint_load(checkpoint_file)
    print()

    print("=" * 60)
    print("测试结果总结")
    print("=" * 60)

    if has_tp:
        print("✅ 推荐使用: torch-pruning (功能更强大)")
    elif has_builtin:
        print("✅ 可以使用: PyTorch内置剪枝 (基本功能)")
    else:
        print("❌ 需要安装剪枝工具")

    if can_load:
        print("✅ Checkpoint可正常加载")
    else:
        print("❌ Checkpoint加载有问题")

    print()
    # Only claim readiness when a pruning backend exists and the checkpoint
    # loads; the exit status stays 0 because this is a report, not a gate.
    if (has_tp or has_builtin) and can_load:
        print("准备就绪!可以开始剪枝。")
    else:
        print("⚠️ 尚未就绪,请先解决上述问题。")

PYTHON_EOF

chmod +x tools/pruning/test_pruning.py

echo "✅ 测试脚本已创建: tools/pruning/test_pruning.py"
echo ""

# 4. Run the availability report written in step 3. Its exit status is
#    captured explicitly so a failure (missing interpreter, crashing script)
#    is reported before the setup stops, instead of `set -e` ending the run
#    without saying which step failed.
echo "========== 4. 运行可用性测试 =========="
echo ""

test_status=0
"${PYTHON_BIN:-/opt/conda/bin/python}" tools/pruning/test_pruning.py || test_status=$?
if (( test_status != 0 )); then
  echo "❌ 可用性测试运行失败 (exit ${test_status})" >&2
  exit "$test_status"
fi

echo ""

# 5. Write the pruning plan to pruning_results/pruning_plan.md.
#
# The plan body stays in a quoted heredoc because it contains ``` fences,
# which an unquoted heredoc would treat as command substitution. Quoting also
# disables $(...) expansion, so the timestamp cannot come from inside the
# heredoc; the title and timestamp are printed with printf first.
# Layer 1's target reads ~22M so it agrees with its own ~5.5M expected
# reduction and with the 7.56M / 38M totals further down.
echo "========== 5. 生成剪枝方案 =========="
echo ""

# Make this step independent of the directory-creation step having run.
mkdir -p pruning_results

{
  printf '# BEVFusion剪枝实施方案\n\n**生成时间**: %s\n' "$(date)"
  cat << 'PLAN_EOF'
**Baseline**: epoch_23.pth (45.72M参数)

---

## 🎯 剪枝目标

```
参数量: 45.72M → 32M (-30%)
模型大小: 174MB → 122MB (FP32)
推理加速: ~30%
精度损失: <2%
```

---

## 📋 分层剪枝计划

### Layer 1: Camera Backbone (27.55M → ~22M)

**方法**: 通道剪枝
```
SwinTransformer各stage剪枝:
Stage 1: 96 channels → 80 (-17%)
Stage 2: 192 channels → 160 (-17%)
Stage 3: 384 channels → 320 (-17%)
Stage 4: 768 channels → 640 (-17%)

预期减少: ~5.5M参数
```

**实施**:
```python
# 使用torch-pruning
import torch_pruning as tp

# 对SwinTransformer剪枝
pruner = tp.pruner.MetaPruner(
    model.encoders['camera'].backbone,
    example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    pruning_ratio=0.17,  # 17%
)
pruner.step()
```

---

### Layer 2: ASPP模块 (4.13M → 3M)

**方法**: 通道剪枝
```
ASPP通道数: 512 → 384 (-25%)
各分支同比例剪枝

预期减少: ~1.1M参数
```

---

### Layer 3: Decoder (4.58M → 3.9M)

**方法**: 通道剪枝
```
Decoder通道: [128, 256] → [96, 192] (-25%)

预期减少: ~0.7M参数
```

---

### Layer 4: Camera VTransform (2.61M → 2.35M)

**方法**: 轻度剪枝
```
通道剪枝: 10%

预期减少: ~0.26M参数
```

---

## 📊 预期结果

```
总计减少: 7.56M参数 (-16.5%)
保守估计: 剪枝后38M参数

如果效果好,可进一步剪枝到32M
```

---

## 🚀 下一步

1. 创建剪枝脚本
2. 小规模测试剪枝
3. 评估剪枝效果
4. 全量剪枝
5. 微调3 epochs

PLAN_EOF
} > pruning_results/pruning_plan.md

echo "✅ 剪枝方案已生成: pruning_results/pruning_plan.md"
echo ""

# 6. Print a summary of the completed steps and the suggested next steps.

#######################################
# Print the closing summary banner (fixed text; quoted heredoc, so nothing
# inside it is expanded).
# Outputs:
#   The multi-line summary on stdout.
#######################################
print_summary() {
  cat << 'SUMMARY_EOF'
========================================================================
准备工作完成!
========================================================================

已完成:
 ✅ torch-pruning检查/安装
 ✅ 工具目录创建
 ✅ 测试脚本创建
 ✅ 剪枝方案生成

下一步:
 1. 查看剪枝方案: cat pruning_results/pruning_plan.md
 2. 创建剪枝脚本 (明天)
 3. 开始剪枝实施 (明天)

当前状态:
 - Stage 1训练: 继续进行中 (不受影响)
 - 优化准备: ✅ 就绪

========================================================================
SUMMARY_EOF
}

print_summary
|