bev-project/archive/docs_old/代码修改说明_Batch2支持_20251102.md

3.7 KiB
Raw Blame History

代码修改支持Batch Size > 1

修改时间: 2025-11-02 12:00 UTC
目的: 移除batch size限制支持batch=2训练


🔧 修改内容

文件: mmdet3d/models/fusion_models/bevfusion.py

位置: 第334-339行

原代码:

if self.fuser is not None:
    x = self.fuser(features)
else:
    assert len(features) == 1, features  # ← 限制batch=1
    x = features[0]

修改后:

if self.fuser is not None:
    x = self.fuser(features)
else:
    # 修改支持batch>1使用torch.cat处理多个features
    if len(features) == 1:
        x = features[0]
    else:
        # 如果有多个features沿batch维度拼接
        x = torch.cat(features, dim=0)

⚠️ 风险评估

可能的问题

  1. 特征拼接正确性:

    • 假设: 多个features可以沿batch维度拼接
    • 风险: 如果features形状不一致会出错
    • 缓解: batch内样本应该shape一致
  2. 内存占用:

    • Batch=2会增加显存占用
    • 预期: 24-26GB/GPU (仍在32GB范围内)
  3. 数值稳定性:

    • 大batch可能影响收敛
    • 已调整学习率: 2e-5 → 4e-5

配置更新

配置文件

# multitask_BEV2X_phase4a_stage1_fp16.yaml

data:
  samples_per_gpu: 2  # 启用batch=2
  workers_per_gpu: 0

optimizer:
  lr: 4.0e-5  # 2倍学习率

启动脚本

# RESTART_PHASE4A_STAGE1_FP16.sh
# 已更新为batch=2配置

🧪 测试验证

启动后需要验证

1. 训练正常启动5分钟内

# 查看日志,确认无错误
tail -100 $(ls -t phase4a_stage1_fp16_batch2*.log | head -1)

# 应该看到正常的训练日志无AssertionError

2. 显存占用正常5分钟内

# 检查显存
nvidia-smi

# 预期: 24-26GB/GPU
# 如果>30GB → 可能OOM风险

3. 速度提升10分钟后

# 查看iteration速度
tail -50 $(ls -t phase4a_stage1_fp16_batch2*.log | head -1) | grep "time:"

# 预期: ~1.4-1.6s/iter
# 如果仍是2.6s → 优化未生效

4. Loss稳定30分钟后

# 查看Loss值
tail -100 $(ls -t phase4a_stage1_fp16_batch2*.log | head -1) | grep "loss:" | tail -10

# Loss应该在2.5-2.8范围
# 注意是否有NaN或异常波动

🎯 预期效果

如果成功

指标 FP32 FP16+Batch2 改进
训练速度 2.65s/iter ~1.5s/iter +43%
显存 29GB ~25GB 节省4GB
完成时间 9天 ~5天 节省4天

如果失败

可能的失败情况:

  1. OOM (显存不足)
  2. Loss发散或NaN
  3. 其他运行时错误

回退方案:

# 恢复batch=1配置
# 修改配置文件中 samples_per_gpu: 1
# 修改 lr: 2.0e-5

📝 监控清单

立即检查启动后5分钟

  • 训练进程正常运行
  • 无AssertionError或其他错误
  • 显存占用在24-26GB范围
  • GPU利用率100%

短期检查30分钟-1小时

  • Loss值正常2.5-2.8
  • 无NaN或Inf
  • 梯度范数正常10-20
  • iteration速度约1.5s

中期检查3小时

  • Loss持续下降
  • 显存稳定
  • 无错误或警告
  • 速度保持稳定

🔄 如果需要回退

Batch=1配置

data:
  samples_per_gpu: 1

optimizer:
  lr: 2.0e-5

恢复代码(可选)

# 如果torch.cat导致问题可以恢复原代码
assert len(features) == 1, features
x = features[0]

状态: 代码已修改,配置已更新
下一步: 清理进程并启动训练
风险: 中等(需要验证)

建议: 密切监控前1小时的训练状态