bev-project/scripts/train_three_tasks.sh

79 lines
2.6 KiB
Bash
Raw Normal View History

#!/bin/bash
# Three-task training script: 3D detection + BEV segmentation + vector map.
#
# Usage: bash scripts/train_three_tasks.sh
#
# Interactively asks for a training strategy:
#   1) two-stage (recommended): stage 1 trains with model.freeze_encoder=True,
#      stage 2 fine-tunes all three tasks jointly from stage 1's checkpoint;
#   2) single-stage: all three tasks trained together.
#
# Preconditions:
#   - the repository lives at /workspace/bevfusion (the script cd's there);
#   - data/nuscenes/vector_maps_bevfusion.pkl exists
#     (create it with scripts/extract_vector_map.sh).
# If PRETRAINED_MODEL exists it is passed as --load_from; otherwise training
# starts from scratch. The combined stdout/stderr of each run is tee'd to
# logs/three_tasks_{stage1,stage2,full}.log.
#
# Training options are passed as `--key value` pairs. NOTE(review): this
# assumes tools/train.py forwards unknown args to torchpack's configs.update()
# (as upstream BEVFusion does). That parser has no notion of mmcv's
# `--cfg-options`: the flag would be read as a config key and swallow the
# token after it (previously `model.freeze_encoder=true`).
set -euo pipefail

export PATH=/opt/conda/bin:$PATH

readonly WORK_DIR=/workspace/bevfusion
readonly CONFIG=configs/nuscenes/three_tasks/bevfusion_det_seg_vec.yaml
readonly VECTOR_MAP_PKL=data/nuscenes/vector_maps_bevfusion.pkl
readonly PRETRAINED_MODEL=runs/run-326653dc-74184412/epoch_20.pth
# Stage 1 writes its checkpoints here so that stage 2 can find them.
readonly STAGE1_RUN_DIR=runs/three_tasks
readonly STAGE1_EPOCHS=6
readonly NUM_GPUS=8
readonly LOG_DIR=logs

die() {
  printf '%s\n' "$*" >&2
  exit 1
}

#######################################
# Launch distributed training and tee its combined output to a log file.
# Arguments:
#   $1    - log file path
#   $2... - extra arguments appended after the config path
# Returns:
#   exit status of the training command (pipefail stops tee from masking it)
#######################################
run_train() {
  local log_file=$1
  shift
  torchpack dist-run -np "$NUM_GPUS" python tools/train.py \
    "$CONFIG" \
    "$@" \
    --data.workers_per_gpu 0 \
    2>&1 | tee "$log_file"
}

#######################################
# Print the stage-1 checkpoint to start stage 2 from:
# epoch_${STAGE1_EPOCHS}.pth if present, otherwise the most recently
# modified *.pth in STAGE1_RUN_DIR.
# Outputs:
#   checkpoint path on stdout; warning on stderr when falling back
# Returns:
#   1 if no *.pth exists in STAGE1_RUN_DIR
#######################################
find_stage1_checkpoint() {
  local expected="$STAGE1_RUN_DIR/epoch_${STAGE1_EPOCHS}.pth"
  if [[ -f "$expected" ]]; then
    printf '%s\n' "$expected"
    return 0
  fi
  echo "⚠️ 阶段1模型不存在使用最新checkpoint" >&2
  local newest="" candidate
  for candidate in "$STAGE1_RUN_DIR"/*.pth; do
    # An unmatched glob stays literal; skip it (and non-regular files).
    [[ -f "$candidate" ]] || continue
    if [[ -z "$newest" || "$candidate" -nt "$newest" ]]; then
      newest=$candidate
    fi
  done
  [[ -n "$newest" ]] || return 1
  printf '%s\n' "$newest"
}

#######################################
# Strategy 1: stage 1 with model.freeze_encoder=True, then stage 2 joint
# fine-tuning from stage 1's checkpoint. Aborts before stage 2 if stage 1
# fails.
# Arguments:
#   optional "--load_from <ckpt>", forwarded to stage 1 only
#######################################
train_two_stage() {
  printf '\n%s\n' "========== 阶段1: 训练矢量地图head冻结其他 =========="
  # NOTE(review): --run-dir pins the output directory. This assumes upstream
  # BEVFusion behaviour, where omitting it gives an auto-named runs/run-*
  # directory (cf. PRETRAINED_MODEL) that stage 2 would never look in.
  run_train "$LOG_DIR/three_tasks_stage1.log" \
    "$@" \
    --run-dir "$STAGE1_RUN_DIR" \
    --model.freeze_encoder True \
    --max_epochs "$STAGE1_EPOCHS" \
    --optimizer.lr 2.0e-4 \
    || die "❌ 阶段1训练失败, 跳过阶段2 (见 $LOG_DIR/three_tasks_stage1.log)"

  printf '\n%s\n' "========== 阶段2: 三任务联合fine-tune =========="
  local stage1_model
  stage1_model=$(find_stage1_checkpoint) \
    || die "❌ 未找到阶段1 checkpoint: $STAGE1_RUN_DIR/*.pth"
  run_train "$LOG_DIR/three_tasks_stage2.log" \
    --load_from "$stage1_model" \
    --max_epochs 12 \
    --optimizer.lr 1.0e-4 \
    || die "❌ 阶段2训练失败 (见 $LOG_DIR/three_tasks_stage2.log)"
}

#######################################
# Strategy 2: train all three tasks together in a single run.
# Arguments:
#   optional "--load_from <ckpt>"
#######################################
train_all_tasks() {
  printf '\n%s\n' "========== 三任务同时训练 =========="
  run_train "$LOG_DIR/three_tasks_full.log" \
    "$@" \
    --max_epochs 20 \
    --optimizer.lr 1.5e-4 \
    || die "❌ 三任务训练失败 (见 $LOG_DIR/three_tasks_full.log)"
}

main() {
  cd "$WORK_DIR" || die "❌ 无法进入目录: $WORK_DIR"

  echo "=========================================="
  echo "BEVFusion 三任务训练"
  echo "任务: 3D检测 + BEV分割 + 矢量地图"
  echo "=========================================="

  # Abort early if the vector-map data has not been extracted yet.
  if [[ ! -f "$VECTOR_MAP_PKL" ]]; then
    echo "❌ 矢量地图数据不存在!"
    echo "请先运行: bash scripts/extract_vector_map.sh"
    exit 1
  fi

  # Warm-start from the existing checkpoint when available. This is an array,
  # so "no checkpoint" expands to nothing and the path stays a single word.
  local -a load_from_args=()
  if [[ -f "$PRETRAINED_MODEL" ]]; then
    echo "✅ 加载预训练模型: $PRETRAINED_MODEL"
    load_from_args=(--load_from "$PRETRAINED_MODEL")
  else
    echo "⚠️ 预训练模型不存在: $PRETRAINED_MODEL"
    echo "将从头开始训练(不推荐)"
  fi

  # tee cannot create its log file if the directory is missing.
  mkdir -p -- "$LOG_DIR"

  printf '\n%s\n' "请选择训练策略:"
  echo "1) 分阶段训练(推荐): 先冻结其他任务只训练矢量地图head"
  echo "2) 一步训练: 三任务同时训练"
  local strategy=""
  # read fails on EOF (no stdin); fall through to the validation below.
  read -r -p "请选择 [1/2]: " strategy || true

  # ${arr[@]+...} keeps an empty array safe under `set -u` on bash < 4.4.
  case "$strategy" in
    1) train_two_stage ${load_from_args[@]+"${load_from_args[@]}"} ;;
    2) train_all_tasks ${load_from_args[@]+"${load_from_args[@]}"} ;;
    *) die "❌ 无效选择: '$strategy' (请输入 1 或 2)" ;;
  esac

  printf '\n%s\n' "✅ 三任务训练完成!"
  echo "检查日志: logs/three_tasks_*.log"
}

main "$@"