#!/bin/bash # 准备模型剪枝工具和环境 set -e cd /workspace/bevfusion echo "========================================================================" echo "准备模型剪枝工具" echo "========================================================================" echo "" # 1. 检查torch-pruning echo "========== 1. 检查torch-pruning ==========" echo "" if /opt/conda/bin/python -c "import torch_pruning" 2>/dev/null; then echo "✅ torch-pruning已安装" VERSION=$(/opt/conda/bin/python -c "import torch_pruning as tp; print(tp.__version__)" 2>/dev/null || echo "未知") echo " 版本: $VERSION" else echo "⏳ torch-pruning未安装,正在安装..." /opt/conda/bin/pip install torch-pruning -i https://pypi.tuna.tsinghua.edu.cn/simple if [ $? -eq 0 ]; then echo "✅ torch-pruning安装成功" else echo "⚠️ torch-pruning安装失败,尝试使用PyTorch内置剪枝" fi fi echo "" # 2. 创建剪枝工具目录 echo "========== 2. 创建工具目录 ==========" echo "" mkdir -p tools/pruning mkdir -p pruning_results echo "✅ 目录已创建:" echo " tools/pruning/ # 剪枝脚本" echo " pruning_results/ # 剪枝结果" echo "" # 3. 创建简单的剪枝测试脚本 echo "========== 3. 创建剪枝测试脚本 ==========" echo "" cat > tools/pruning/test_pruning.py << 'PYTHON_EOF' #!/usr/bin/env python3 """ 测试torch-pruning是否可用 """ import sys sys.path.insert(0, '/workspace/bevfusion') def test_torch_pruning(): """测试torch-pruning""" try: import torch_pruning as tp print(f"✅ torch-pruning可用") print(f" 版本: {tp.__version__}") return True except ImportError: print(f"❌ torch-pruning不可用") return False def test_pytorch_builtin(): """测试PyTorch内置剪枝""" try: import torch.nn.utils.prune as prune print(f"✅ PyTorch内置剪枝可用") return True except ImportError: print(f"❌ PyTorch内置剪枝不可用") return False def test_checkpoint_load(): """测试checkpoint加载""" import torch checkpoint_file = 'runs/enhanced_from_epoch19/epoch_23.pth' try: checkpoint = torch.load(checkpoint_file, map_location='cpu') if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] total_params = sum(v.numel() for v in state_dict.values() if isinstance(v, torch.Tensor)) print(f"✅ Checkpoint加载成功") print(f" 参数量: {total_params/1e6:.2f}M") print(f" Epoch: {checkpoint.get('epoch', 'unknown')}") return True else: print(f"⚠️ Checkpoint格式异常") return False except Exception as e: print(f"❌ Checkpoint加载失败: {e}") return False if __name__ == '__main__': print("=" * 60) print("剪枝工具可用性测试") print("=" * 60) print() # 测试torch-pruning print("1. torch-pruning:") has_tp = test_torch_pruning() print() # 测试PyTorch内置 print("2. PyTorch内置剪枝:") has_builtin = test_pytorch_builtin() print() # 测试checkpoint print("3. Checkpoint加载:") can_load = test_checkpoint_load() print() # 总结 print("=" * 60) print("测试结果总结") print("=" * 60) if has_tp: print("✅ 推荐使用: torch-pruning (功能更强大)") elif has_builtin: print("✅ 可以使用: PyTorch内置剪枝 (基本功能)") else: print("❌ 需要安装剪枝工具") if can_load: print("✅ Checkpoint可正常加载") else: print("❌ Checkpoint加载有问题") print() print("准备就绪!可以开始剪枝。") PYTHON_EOF chmod +x tools/pruning/test_pruning.py echo "✅ 测试脚本已创建: tools/pruning/test_pruning.py" echo "" # 4. 运行测试 echo "========== 4. 运行可用性测试 ==========" echo "" /opt/conda/bin/python tools/pruning/test_pruning.py echo "" # 5. 生成剪枝方案文件 echo "========== 5. 生成剪枝方案 ==========" echo "" cat > pruning_results/pruning_plan.md << 'PLAN_EOF' # BEVFusion剪枝实施方案 **生成时间**: $(date) **Baseline**: epoch_23.pth (45.72M参数) --- ## 🎯 剪枝目标 ``` 参数量: 45.72M → 32M (-30%) 模型大小: 174MB → 122MB (FP32) 推理加速: ~30% 精度损失: <2% ``` --- ## 📋 分层剪枝计划 ### Layer 1: Camera Backbone (27.55M → 20M) **方法**: 通道剪枝 ``` SwinTransformer各stage剪枝: Stage 1: 96 channels → 80 (-17%) Stage 2: 192 channels → 160 (-17%) Stage 3: 384 channels → 320 (-17%) Stage 4: 768 channels → 640 (-17%) 预期减少: ~5.5M参数 ``` **实施**: ```python # 使用torch-pruning import torch_pruning as tp # 对SwinTransformer剪枝 pruner = tp.pruner.MetaPruner( model.encoders['camera'].backbone, example_inputs, importance=tp.importance.MagnitudeImportance(p=2), pruning_ratio=0.17, # 17% ) pruner.step() ``` --- ### Layer 2: ASPP模块 (4.13M → 3M) **方法**: 通道剪枝 ``` ASPP通道数: 512 → 384 (-25%) 各分支同比例剪枝 预期减少: ~1.1M参数 ``` --- ### Layer 3: Decoder (4.58M → 3.9M) **方法**: 通道剪枝 ``` Decoder通道: [128, 256] → [96, 192] (-25%) 预期减少: ~0.7M参数 ``` --- ### Layer 4: Camera VTransform (2.61M → 2.35M) **方法**: 轻度剪枝 ``` 通道剪枝: 10% 预期减少: ~0.26M参数 ``` --- ## 📊 预期结果 ``` 总计减少: 7.56M参数 (-16.5%) 保守估计: 剪枝后38M参数 如果效果好,可进一步剪枝到32M ``` --- ## 🚀 下一步 1. 创建剪枝脚本 2. 小规模测试剪枝 3. 评估剪枝效果 4. 全量剪枝 5. 微调3 epochs PLAN_EOF echo "✅ 剪枝方案已生成: pruning_results/pruning_plan.md" echo "" # 6. 总结 echo "========================================================================" echo "准备工作完成!" echo "========================================================================" echo "" echo "已完成:" echo " ✅ torch-pruning检查/安装" echo " ✅ 工具目录创建" echo " ✅ 测试脚本创建" echo " ✅ 剪枝方案生成" echo "" echo "下一步:" echo " 1. 查看剪枝方案: cat pruning_results/pruning_plan.md" echo " 2. 创建剪枝脚本 (明天)" echo " 3. 开始剪枝实施 (明天)" echo "" echo "当前状态:" echo " - Stage 1训练: 继续进行中 (不受影响)" echo " - 优化准备: ✅ 就绪" echo "" echo "========================================================================"