296 lines
6.6 KiB
Bash
Executable File
296 lines
6.6 KiB
Bash
Executable File
#!/bin/bash
|
||
# 准备模型剪枝工具和环境
|
||
|
||
set -e
|
||
|
||
cd /workspace/bevfusion
|
||
|
||
echo "========================================================================"
|
||
echo "准备模型剪枝工具"
|
||
echo "========================================================================"
|
||
echo ""
|
||
|
||
# 1. 检查torch-pruning
|
||
echo "========== 1. 检查torch-pruning =========="
|
||
echo ""
|
||
|
||
if /opt/conda/bin/python -c "import torch_pruning" 2>/dev/null; then
|
||
echo "✅ torch-pruning已安装"
|
||
VERSION=$(/opt/conda/bin/python -c "import torch_pruning as tp; print(tp.__version__)" 2>/dev/null || echo "未知")
|
||
echo " 版本: $VERSION"
|
||
else
|
||
echo "⏳ torch-pruning未安装,正在安装..."
|
||
/opt/conda/bin/pip install torch-pruning -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||
|
||
if [ $? -eq 0 ]; then
|
||
echo "✅ torch-pruning安装成功"
|
||
else
|
||
echo "⚠️ torch-pruning安装失败,尝试使用PyTorch内置剪枝"
|
||
fi
|
||
fi
|
||
|
||
echo ""
|
||
|
||
# 2. 创建剪枝工具目录
|
||
echo "========== 2. 创建工具目录 =========="
|
||
echo ""
|
||
|
||
mkdir -p tools/pruning
|
||
mkdir -p pruning_results
|
||
|
||
echo "✅ 目录已创建:"
|
||
echo " tools/pruning/ # 剪枝脚本"
|
||
echo " pruning_results/ # 剪枝结果"
|
||
echo ""
|
||
|
||
# 3. 创建简单的剪枝测试脚本
|
||
echo "========== 3. 创建剪枝测试脚本 =========="
|
||
echo ""
|
||
|
||
cat > tools/pruning/test_pruning.py << 'PYTHON_EOF'
|
||
#!/usr/bin/env python3
|
||
"""
|
||
测试torch-pruning是否可用
|
||
"""
|
||
|
||
import sys
|
||
sys.path.insert(0, '/workspace/bevfusion')
|
||
|
||
def test_torch_pruning():
|
||
"""测试torch-pruning"""
|
||
try:
|
||
import torch_pruning as tp
|
||
print(f"✅ torch-pruning可用")
|
||
print(f" 版本: {tp.__version__}")
|
||
return True
|
||
except ImportError:
|
||
print(f"❌ torch-pruning不可用")
|
||
return False
|
||
|
||
def test_pytorch_builtin():
|
||
"""测试PyTorch内置剪枝"""
|
||
try:
|
||
import torch.nn.utils.prune as prune
|
||
print(f"✅ PyTorch内置剪枝可用")
|
||
return True
|
||
except ImportError:
|
||
print(f"❌ PyTorch内置剪枝不可用")
|
||
return False
|
||
|
||
def test_checkpoint_load():
|
||
"""测试checkpoint加载"""
|
||
import torch
|
||
|
||
checkpoint_file = 'runs/enhanced_from_epoch19/epoch_23.pth'
|
||
|
||
try:
|
||
checkpoint = torch.load(checkpoint_file, map_location='cpu')
|
||
|
||
if 'state_dict' in checkpoint:
|
||
state_dict = checkpoint['state_dict']
|
||
total_params = sum(v.numel() for v in state_dict.values() if isinstance(v, torch.Tensor))
|
||
|
||
print(f"✅ Checkpoint加载成功")
|
||
print(f" 参数量: {total_params/1e6:.2f}M")
|
||
print(f" Epoch: {checkpoint.get('epoch', 'unknown')}")
|
||
return True
|
||
else:
|
||
print(f"⚠️ Checkpoint格式异常")
|
||
return False
|
||
|
||
except Exception as e:
|
||
print(f"❌ Checkpoint加载失败: {e}")
|
||
return False
|
||
|
||
if __name__ == '__main__':
|
||
print("=" * 60)
|
||
print("剪枝工具可用性测试")
|
||
print("=" * 60)
|
||
print()
|
||
|
||
# 测试torch-pruning
|
||
print("1. torch-pruning:")
|
||
has_tp = test_torch_pruning()
|
||
print()
|
||
|
||
# 测试PyTorch内置
|
||
print("2. PyTorch内置剪枝:")
|
||
has_builtin = test_pytorch_builtin()
|
||
print()
|
||
|
||
# 测试checkpoint
|
||
print("3. Checkpoint加载:")
|
||
can_load = test_checkpoint_load()
|
||
print()
|
||
|
||
# 总结
|
||
print("=" * 60)
|
||
print("测试结果总结")
|
||
print("=" * 60)
|
||
|
||
if has_tp:
|
||
print("✅ 推荐使用: torch-pruning (功能更强大)")
|
||
elif has_builtin:
|
||
print("✅ 可以使用: PyTorch内置剪枝 (基本功能)")
|
||
else:
|
||
print("❌ 需要安装剪枝工具")
|
||
|
||
if can_load:
|
||
print("✅ Checkpoint可正常加载")
|
||
else:
|
||
print("❌ Checkpoint加载有问题")
|
||
|
||
print()
|
||
print("准备就绪!可以开始剪枝。")
|
||
|
||
PYTHON_EOF
|
||
|
||
chmod +x tools/pruning/test_pruning.py
|
||
|
||
echo "✅ 测试脚本已创建: tools/pruning/test_pruning.py"
|
||
echo ""
|
||
|
||
# 4. 运行测试
|
||
echo "========== 4. 运行可用性测试 =========="
|
||
echo ""
|
||
|
||
/opt/conda/bin/python tools/pruning/test_pruning.py
|
||
|
||
echo ""
|
||
|
||
# 5. 生成剪枝方案文件
|
||
echo "========== 5. 生成剪枝方案 =========="
|
||
echo ""
|
||
|
||
cat > pruning_results/pruning_plan.md << 'PLAN_EOF'
|
||
# BEVFusion剪枝实施方案
|
||
|
||
**生成时间**: $(date)
|
||
**Baseline**: epoch_23.pth (45.72M参数)
|
||
|
||
---
|
||
|
||
## 🎯 剪枝目标
|
||
|
||
```
|
||
参数量: 45.72M → 32M (-30%)
|
||
模型大小: 174MB → 122MB (FP32)
|
||
推理加速: ~30%
|
||
精度损失: <2%
|
||
```
|
||
|
||
---
|
||
|
||
## 📋 分层剪枝计划
|
||
|
||
### Layer 1: Camera Backbone (27.55M → 20M)
|
||
|
||
**方法**: 通道剪枝
|
||
```
|
||
SwinTransformer各stage剪枝:
|
||
Stage 1: 96 channels → 80 (-17%)
|
||
Stage 2: 192 channels → 160 (-17%)
|
||
Stage 3: 384 channels → 320 (-17%)
|
||
Stage 4: 768 channels → 640 (-17%)
|
||
|
||
预期减少: ~5.5M参数
|
||
```
|
||
|
||
**实施**:
|
||
```python
|
||
# 使用torch-pruning
|
||
import torch_pruning as tp
|
||
|
||
# 对SwinTransformer剪枝
|
||
pruner = tp.pruner.MetaPruner(
|
||
model.encoders['camera'].backbone,
|
||
example_inputs,
|
||
importance=tp.importance.MagnitudeImportance(p=2),
|
||
pruning_ratio=0.17, # 17%
|
||
)
|
||
pruner.step()
|
||
```
|
||
|
||
---
|
||
|
||
### Layer 2: ASPP模块 (4.13M → 3M)
|
||
|
||
**方法**: 通道剪枝
|
||
```
|
||
ASPP通道数: 512 → 384 (-25%)
|
||
各分支同比例剪枝
|
||
|
||
预期减少: ~1.1M参数
|
||
```
|
||
|
||
---
|
||
|
||
### Layer 3: Decoder (4.58M → 3.9M)
|
||
|
||
**方法**: 通道剪枝
|
||
```
|
||
Decoder通道: [128, 256] → [96, 192] (-25%)
|
||
|
||
预期减少: ~0.7M参数
|
||
```
|
||
|
||
---
|
||
|
||
### Layer 4: Camera VTransform (2.61M → 2.35M)
|
||
|
||
**方法**: 轻度剪枝
|
||
```
|
||
通道剪枝: 10%
|
||
|
||
预期减少: ~0.26M参数
|
||
```
|
||
|
||
---
|
||
|
||
## 📊 预期结果
|
||
|
||
```
|
||
总计减少: 7.56M参数 (-16.5%)
|
||
保守估计: 剪枝后38M参数
|
||
|
||
如果效果好,可进一步剪枝到32M
|
||
```
|
||
|
||
---
|
||
|
||
## 🚀 下一步
|
||
|
||
1. 创建剪枝脚本
|
||
2. 小规模测试剪枝
|
||
3. 评估剪枝效果
|
||
4. 全量剪枝
|
||
5. 微调3 epochs
|
||
|
||
PLAN_EOF
|
||
|
||
echo "✅ 剪枝方案已生成: pruning_results/pruning_plan.md"
|
||
echo ""
|
||
|
||
# 6. 总结
|
||
echo "========================================================================"
|
||
echo "准备工作完成!"
|
||
echo "========================================================================"
|
||
echo ""
|
||
echo "已完成:"
|
||
echo " ✅ torch-pruning检查/安装"
|
||
echo " ✅ 工具目录创建"
|
||
echo " ✅ 测试脚本创建"
|
||
echo " ✅ 剪枝方案生成"
|
||
echo ""
|
||
echo "下一步:"
|
||
echo " 1. 查看剪枝方案: cat pruning_results/pruning_plan.md"
|
||
echo " 2. 创建剪枝脚本 (明天)"
|
||
echo " 3. 开始剪枝实施 (明天)"
|
||
echo ""
|
||
echo "当前状态:"
|
||
echo " - Stage 1训练: 继续进行中 (不受影响)"
|
||
echo " - 优化准备: ✅ 就绪"
|
||
echo ""
|
||
echo "========================================================================"
|
||
|