bev-project/scripts/train_three_tasks.sh

79 lines
2.6 KiB
Bash
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/bin/bash
# Three-task training script: 3D detection + BEV segmentation + vector map.
#
# Usage: train_three_tasks.sh [1|2]
#   1  staged training (recommended): train only the vector-map head with the
#      other tasks frozen, then fine-tune all three tasks jointly
#   2  single-stage joint training of all three tasks
# Without an argument the strategy is asked for interactively.
#
# Optional environment overrides (defaults reproduce the original settings):
#   BEVFUSION_ROOT    repository root            (default /workspace/bevfusion)
#   PRETRAINED_MODEL  checkpoint to start from   (default runs/run-326653dc-74184412/epoch_20.pth)
#   NUM_GPUS          torchpack dist-run -np     (default 8)
#   WORKERS_PER_GPU   data.workers_per_gpu       (default 0)
#
# NOTE(review): config overrides are passed as `--key value` pairs, the same
# style already used for `--data.workers_per_gpu` and `--load_from`. This
# assumes the stock BEVFusion tools/train.py, which hands unknown arguments to
# torchpack's config update (the `runs/run-<hash>-<id>` checkpoint path is
# torchpack's auto run-dir layout). An mmcv-style `--cfg-options` flag is not
# understood there. Re-check if tools/train.py has been customized.
set -euo pipefail

export PATH="/opt/conda/bin:$PATH"

readonly CONFIG=configs/nuscenes/three_tasks/bevfusion_det_seg_vec.yaml
readonly VECTOR_MAP_DATA=data/nuscenes/vector_maps_bevfusion.pkl
readonly STAGE1_RUN_DIR=runs/three_tasks
readonly NUM_GPUS=${NUM_GPUS:-8}
readonly WORKERS_PER_GPU=${WORKERS_PER_GPU:-0}
readonly PRETRAINED_MODEL=${PRETRAINED_MODEL:-runs/run-326653dc-74184412/epoch_20.pth}
readonly BEVFUSION_ROOT=${BEVFUSION_ROOT:-/workspace/bevfusion}

# Print a message to stderr and abort the script.
die() {
  printf '%s\n' "$*" >&2
  exit 1
}

# Print the most recently modified *.pth file in directory $1.
# Returns 1 (printing nothing) when the directory holds no checkpoint.
latest_checkpoint() {
  local dir=$1 newest="" ckpt
  for ckpt in "$dir"/*.pth; do
    [[ -e "$ckpt" ]] || continue   # an unmatched glob stays literal; skip it
    if [[ -z "$newest" || "$ckpt" -nt "$newest" ]]; then
      newest=$ckpt
    fi
  done
  [[ -n "$newest" ]] || return 1
  printf '%s\n' "$newest"
}

# Run one distributed training job and mirror its output into a log file.
# Arguments: $1 - log file; $2 - run directory for checkpoints;
#            remaining args - extra config overrides for tools/train.py.
# Returns:   the exit status of the training command itself (not tee's), so a
#            crashed run is reported as a failure.
# Callers use `run_train ... || die ...`, which also keeps `set -e` from
# aborting before the status is returned.
run_train() {
  local log_file=$1 run_dir=$2
  shift 2
  torchpack dist-run -np "$NUM_GPUS" python tools/train.py "$CONFIG" \
    --run-dir "$run_dir" \
    "$@" \
    --data.workers_per_gpu "$WORKERS_PER_GPU" \
    2>&1 | tee "$log_file"
  return "${PIPESTATUS[0]}"
}

cd "$BEVFUSION_ROOT" || die "❌ 无法进入目录: $BEVFUSION_ROOT"

echo "=========================================="
echo "BEVFusion 三任务训练"
echo "任务: 3D检测 + BEV分割 + 矢量地图"
echo "=========================================="

# The vector-map annotations must be extracted before training can start.
if [[ ! -f "$VECTOR_MAP_DATA" ]]; then
  echo "❌ 矢量地图数据不存在!" >&2
  echo "请先运行: bash scripts/extract_vector_map.sh" >&2
  exit 1
fi

# Start from the existing checkpoint when it exists; otherwise train from
# scratch. An array keeps the path intact even if it contains spaces.
load_from_args=()
if [[ -f "$PRETRAINED_MODEL" ]]; then
  echo "✅ 加载预训练模型: $PRETRAINED_MODEL"
  load_from_args=(--load_from "$PRETRAINED_MODEL")
else
  echo "⚠️ 预训练模型不存在: $PRETRAINED_MODEL"
  echo "将从头开始训练(不推荐)"
fi

# tee cannot create log files inside a missing directory.
mkdir -p logs

# The strategy can be given as $1 for non-interactive runs (e.g. under nohup).
strategy=${1:-}
if [[ -z "$strategy" ]]; then
  echo -e "\n请选择训练策略:"
  echo "1) 分阶段训练(推荐): 先冻结其他任务只训练矢量地图head"
  echo "2) 一步训练: 三任务同时训练"
  # `|| true`: on EOF (no terminal) fall through to validation below instead
  # of letting `set -e` exit silently.
  read -r -p "请选择 [1/2]: " strategy || true
fi

case "$strategy" in
  1)
    echo -e "\n========== 阶段1: 训练矢量地图head冻结其他 =========="
    # NOTE(review): assumes the model config honors a `freeze_encoder` flag.
    # `--run-dir` pins the output to $STAGE1_RUN_DIR so stage 2 can find it.
    run_train logs/three_tasks_stage1.log "$STAGE1_RUN_DIR" \
      ${load_from_args[@]+"${load_from_args[@]}"} \
      --model.freeze_encoder True \
      --max_epochs 6 \
      --optimizer.lr 2.0e-4 \
      || die "❌ 阶段1训练失败, 请检查日志: logs/three_tasks_stage1.log"

    echo -e "\n========== 阶段2: 三任务联合fine-tune =========="
    stage1_model=$STAGE1_RUN_DIR/epoch_6.pth
    if [[ ! -f "$stage1_model" ]]; then
      echo "⚠️ 阶段1模型不存在使用最新checkpoint"
      stage1_model=$(latest_checkpoint "$STAGE1_RUN_DIR") \
        || die "❌ 未找到阶段1的checkpoint: $STAGE1_RUN_DIR/*.pth"
    fi
    run_train logs/three_tasks_stage2.log runs/three_tasks_stage2 \
      --load_from "$stage1_model" \
      --max_epochs 12 \
      --optimizer.lr 1.0e-4 \
      || die "❌ 阶段2训练失败, 请检查日志: logs/three_tasks_stage2.log"
    ;;
  2)
    echo -e "\n========== 三任务同时训练 =========="
    run_train logs/three_tasks_full.log runs/three_tasks_full \
      ${load_from_args[@]+"${load_from_args[@]}"} \
      --max_epochs 20 \
      --optimizer.lr 1.5e-4 \
      || die "❌ 三任务训练失败, 请检查日志: logs/three_tasks_full.log"
    ;;
  *)
    # Refuse instead of silently launching the 20-epoch joint run on a typo.
    die "❌ 无效的选择: '$strategy' (请输入 1 或 2)"
    ;;
esac

echo -e "\n✅ 三任务训练完成!"
echo "检查日志: logs/three_tasks_*.log"