#!/bin/bash
# Three-task training script: 3D detection + BEV segmentation + vector map.
#
# Puts /opt/conda/bin first on PATH, switches to the BEVFusion checkout and
# prints a banner. The data, checkpoint and log paths used further down in
# this script are relative, so they resolve against /workspace/bevfusion.

export PATH=/opt/conda/bin:$PATH

# Stop right away if the project directory is unavailable; carrying on from
# some other working directory would make every relative path below wrong.
cd /workspace/bevfusion || {
  echo "❌ 无法进入目录: /workspace/bevfusion" >&2
  exit 1
}

echo "=========================================="
echo "BEVFusion 三任务训练"
echo "任务: 3D检测 + BEV分割 + 矢量地图"
echo "=========================================="

# Verify that the vector-map data file exists before launching anything.
# When it is missing, print how to generate it and exit with status 1.
if [[ ! -f "data/nuscenes/vector_maps_bevfusion.pkl" ]]; then
  echo "❌ 矢量地图数据不存在!"
  echo "请先运行: bash scripts/extract_vector_map.sh"
  exit 1
fi

# Decide whether training starts from an existing checkpoint.
#
# Sets:
#   PRETRAINED_MODEL - path of the checkpoint to look for.
#   LOAD_FROM        - "--load_from <path>" when the checkpoint file exists,
#                      otherwise the empty string (training from scratch).
PRETRAINED_MODEL="runs/run-326653dc-74184412/epoch_20.pth"
if [[ -f "$PRETRAINED_MODEL" ]]; then
  echo "✅ 加载预训练模型: $PRETRAINED_MODEL"
  LOAD_FROM="--load_from $PRETRAINED_MODEL"
else
  echo "⚠️ 预训练模型不存在: $PRETRAINED_MODEL"
  echo "将从头开始训练(不推荐)"
  LOAD_FROM=""
fi

# Training strategy selection and launch.
#
# Prompts for a strategy and runs the matching schedule:
#   "1"        -> two runs: stage 1 with model.freeze_encoder=true for
#                 6 epochs, then stage 2 loading the stage-1 checkpoint.
#   otherwise  -> a single 20-epoch run of all three tasks.
#
# Reads LOAD_FROM, which is expected to be either empty or the string
# "--load_from <path>" (assigned earlier in this script).
#
# NOTE(review): the trainer arguments are forwarded in their original order
# and spelling; confirm that tools/train.py accepts "--cfg-options key=value"
# together with the dotted "--data.workers_per_gpu 0" form.

# Without pipefail the status of "trainer | tee" is tee's status, so a
# crashed trainer would be reported as success.
set -o pipefail

# tee cannot open its log file when the directory does not exist.
mkdir -p logs || exit 1

#######################################
# Run one training job and mirror its combined output into a log file.
# Arguments:
#   $1   - log file path
#   $2.. - extra trainer arguments, placed right after the config path
# Returns:
#   status of the pipeline (non-zero if the trainer failed, given pipefail)
#######################################
run_training() {
  local log_file=$1
  shift
  torchpack dist-run -np 8 python tools/train.py \
    configs/nuscenes/three_tasks/bevfusion_det_seg_vec.yaml \
    "$@" \
    --data.workers_per_gpu 0 \
    2>&1 | tee "$log_file"
}

echo -e "\n请选择训练策略:"
echo "1) 分阶段训练(推荐): 先冻结其他任务,只训练矢量地图head"
echo "2) 一步训练: 三任务同时训练"
read -r -p "请选择 [1/2]: " strategy

# LOAD_FROM is a flat "flag value" string; split it into words once so it can
# be passed as separate arguments without unquoted expansion (and globbing).
read -r -a load_from_args <<< "${LOAD_FROM:-}"

if [[ "$strategy" == "1" ]]; then
  echo -e "\n========== 阶段1: 训练矢量地图head(冻结其他) =========="
  if ! run_training logs/three_tasks_stage1.log \
    "${load_from_args[@]}" \
    --cfg-options \
    model.freeze_encoder=true \
    max_epochs=6 \
    optimizer.lr=2.0e-4; then
    echo "❌ 阶段1训练失败, 详见 logs/three_tasks_stage1.log" >&2
    exit 1
  fi

  echo -e "\n========== 阶段2: 三任务联合fine-tune =========="
  STAGE1_MODEL="runs/three_tasks/epoch_6.pth"
  if [[ ! -f "$STAGE1_MODEL" ]]; then
    echo "⚠️ 阶段1模型不存在,使用最新checkpoint"
    # Pick the most recently modified *.pth via the glob instead of parsing
    # ls output; an unmatched glob is skipped by the -f test.
    STAGE1_MODEL=""
    for ckpt in runs/three_tasks/*.pth; do
      [[ -f "$ckpt" ]] || continue
      if [[ -z "$STAGE1_MODEL" || "$ckpt" -nt "$STAGE1_MODEL" ]]; then
        STAGE1_MODEL=$ckpt
      fi
    done
    # An empty value would leave --load_from without its argument.
    if [[ -z "$STAGE1_MODEL" ]]; then
      echo "❌ 未找到阶段1 checkpoint: runs/three_tasks/*.pth" >&2
      exit 1
    fi
  fi

  if ! run_training logs/three_tasks_stage2.log \
    --load_from "$STAGE1_MODEL" \
    --cfg-options \
    max_epochs=12 \
    optimizer.lr=1.0e-4; then
    echo "❌ 阶段2训练失败, 详见 logs/three_tasks_stage2.log" >&2
    exit 1
  fi
else
  echo -e "\n========== 三任务同时训练 =========="
  if ! run_training logs/three_tasks_full.log \
    "${load_from_args[@]}" \
    --cfg-options \
    max_epochs=20 \
    optimizer.lr=1.5e-4; then
    echo "❌ 训练失败, 详见 logs/three_tasks_full.log" >&2
    exit 1
  fi
fi

# Final status message and a pointer to the log files written by the runs.
echo -e "\n✅ 三任务训练完成!"
echo "检查日志: logs/three_tasks_*.log"