#!/bin/bash
# BEVFusion detailed environment check script (see README.md).
#
# Prints one line per check and keeps running totals of passed and failed
# checks, which are reported in the summary at the end of the script.

echo "=========================================="
echo "BEVFusion 环境依赖检查"
echo "参考:README.md Prerequisites"
echo "=========================================="
echo ""

# Counters updated by the check helpers below.
PASS=0
FAIL=0

# Check helper functions

#######################################
# Report a passed check and count it.
# Globals:   PASS (incremented)
# Arguments: $1 - message describing what passed
# Outputs:   writes " ✓ <message>" to stdout
# Returns:   0
#######################################
check_pass() {
  echo " ✓ $1"
  # Plain assignment instead of ((PASS++)): the arithmetic command returns
  # status 1 when the pre-increment value is 0, which made the first call
  # look like a failure to callers (and would abort under `set -e`).
  PASS=$((PASS + 1))
}

#######################################
# Report a failed check and count it.
# Globals:   FAIL (incremented)
# Arguments: $1 - message describing what failed
# Outputs:   writes " ✗ <message>" to stdout
# Returns:   0
#######################################
check_fail() {
  echo " ✗ $1"
  # Plain assignment instead of ((FAIL++)): the arithmetic command returns
  # status 1 when the pre-increment value is 0, so the first recorded
  # failure would itself return a non-zero status.
  FAIL=$((FAIL + 1))
}

#######################################
# Report a warning. Warnings are printed only; they change neither the
# PASS nor the FAIL counter.
# Arguments: $1 - warning message
# Outputs:   writes " ⚠ <message>" to stdout
#######################################
check_warn() {
  echo " ⚠ $1"
}

# 1. Python version check
echo "[1/12] Python 版本 (需要 >= 3.8, < 3.9)"
if ! command -v python3 > /dev/null 2>&1; then
  # Without this guard the "command not found" text would be parsed as a
  # version string and the comparison below would run on garbage.
  check_fail "python3 未找到"
else
  PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
  # Split "MAJOR.MINOR.PATCH" with parameter expansion (no subshells).
  PYTHON_MAJOR=${PYTHON_VERSION%%.*}
  PYTHON_MINOR=${PYTHON_VERSION#*.}
  PYTHON_MINOR=${PYTHON_MINOR%%.*}

  # String comparison: a non-numeric component yields a warning instead of
  # an "integer expression expected" error from `[ ... -eq ... ]`.
  if [ "$PYTHON_MAJOR" = "3" ] && [ "$PYTHON_MINOR" = "8" ]; then
    check_pass "Python $PYTHON_VERSION"
  else
    check_warn "Python $PYTHON_VERSION (推荐3.8.x)"
  fi
fi

# 2. PyTorch version check
# NOTE: the version range shown in the heading is informational only; this
# section verifies that torch imports and that CUDA is usable.
echo ""
echo "[2/12] PyTorch 版本 (需要 >= 1.9, <= 1.10.2)"
# Branch directly on the import's exit status rather than inspecting $?
# afterwards, so no intervening command can clobber it.
if PYTORCH_VERSION=$(python3 -c "import torch; print(torch.__version__)" 2>/dev/null); then
  check_pass "PyTorch $PYTORCH_VERSION"

  # CUDA availability as reported by torch
  CUDA_AVAILABLE=$(python3 -c "import torch; print(torch.cuda.is_available())" 2>/dev/null)
  if [ "$CUDA_AVAILABLE" = "True" ]; then
    GPU_COUNT=$(python3 -c "import torch; print(torch.cuda.device_count())" 2>/dev/null)
    check_pass "CUDA 可用,GPU数量: $GPU_COUNT"
  else
    check_fail "CUDA 不可用"
  fi
else
  check_fail "PyTorch 未安装"
fi

# 3. torchpack check (required)
echo ""
echo "[3/12] torchpack (必需)"
# The assignment's exit status is that of the python3 import attempt.
if TORCHPACK_VERSION=$(python3 -c "import torchpack; print(torchpack.__version__)" 2>/dev/null); then
  check_pass "torchpack $TORCHPACK_VERSION"
else
  check_fail "torchpack 未安装"
  echo " 安装: pip install torchpack"
fi

# 4. mmcv check: a missing package is a failure, a version other than
# 1.4.0 is only a warning.
echo ""
echo "[4/12] mmcv (需要 = 1.4.0)"
if MMCV_VERSION=$(python3 -c "import mmcv; print(mmcv.__version__)" 2>/dev/null); then
  if [ "$MMCV_VERSION" = "1.4.0" ]; then
    check_pass "mmcv $MMCV_VERSION"
  else
    check_warn "mmcv $MMCV_VERSION (推荐1.4.0)"
  fi
else
  check_fail "mmcv 未安装"
  echo " 安装: pip install mmcv==1.4.0 mmcv-full==1.4.0"
fi

# 5. mmdetection check: a missing package is a failure, a version other
# than 2.20.0 is only a warning.
echo ""
echo "[5/12] mmdetection (需要 = 2.20.0)"
if MMDET_VERSION=$(python3 -c "import mmdet; print(mmdet.__version__)" 2>/dev/null); then
  if [ "$MMDET_VERSION" = "2.20.0" ]; then
    check_pass "mmdetection $MMDET_VERSION"
  else
    check_warn "mmdetection $MMDET_VERSION (推荐2.20.0)"
  fi
else
  check_fail "mmdetection 未安装"
  echo " 安装: pip install mmdet==2.20.0"
fi

# 6. OpenMPI and mpi4py check. Both are treated as optional: their absence
# produces a warning and does not increase the failure count.
echo ""
echo "[6/12] OpenMPI 和 mpi4py (torchpack需要)"
if command -v mpirun &> /dev/null; then
  # Only the first line of `mpirun --version` is shown.
  OMPI_VERSION=$(mpirun --version 2>&1 | head -n 1)
  check_pass "OpenMPI: $OMPI_VERSION"
else
  check_warn "OpenMPI 未安装(可选)"
fi

if MPI4PY_VERSION=$(python3 -c "import mpi4py; print(mpi4py.__version__)" 2>/dev/null); then
  check_pass "mpi4py $MPI4PY_VERSION"
else
  check_warn "mpi4py 未安装(可选)"
fi

# 7. Pillow check: a missing package is a failure, a version other than
# 8.4.0 is only a warning.
echo ""
echo "[7/12] Pillow (需要 = 8.4.0)"
if PILLOW_VERSION=$(python3 -c "from PIL import Image; import PIL; print(PIL.__version__)" 2>/dev/null); then
  if [ "$PILLOW_VERSION" = "8.4.0" ]; then
    check_pass "Pillow $PILLOW_VERSION"
  else
    check_warn "Pillow $PILLOW_VERSION (推荐8.4.0)"
  fi
else
  check_fail "Pillow 未安装"
fi

# 8. nuscenes-devkit check (required)
echo ""
echo "[8/12] nuscenes-devkit (必需)"
# Import errors are silenced on purpose; only the exit status matters.
if python3 -c "from nuscenes import NuScenes" 2>/dev/null; then
  check_pass "nuscenes-devkit 已安装"
else
  check_fail "nuscenes-devkit 未安装"
  echo " 安装: pip install nuscenes-devkit"
fi

# 9. tqdm check (required)
echo ""
echo "[9/12] tqdm (必需)"
# Import errors are silenced on purpose; only the exit status matters.
if python3 -c "import tqdm" 2>/dev/null; then
  check_pass "tqdm 已安装"
else
  check_fail "tqdm 未安装"
  echo " 安装: pip install tqdm"
fi

# 10. Custom CUDA operator check. Each operator is probed by importing it
# from mmdet3d.ops; an import failure is reported as "not compiled".
echo ""
echo "[10/12] 自定义CUDA算子 (需要运行 python setup.py develop)"
# NOTE(review): confirm that this checkout exports `bev_pool_v2` from
# mmdet3d.ops (some BEVFusion trees name the symbol `bev_pool`); otherwise
# this check fails even when the operators are built.
if python3 -c "from mmdet3d.ops import bev_pool_v2" 2>/dev/null; then
  check_pass "BEV Pool 算子已编译"
else
  check_fail "BEV Pool 算子未编译"
  echo " 运行: cd /workspace/bevfusion && python3 setup.py develop"
fi

if python3 -c "from mmdet3d.ops import Voxelization" 2>/dev/null; then
  check_pass "Voxelization 算子已编译"
else
  check_fail "Voxelization 算子未编译"
fi

if python3 -c "from mmdet3d.ops.spconv import SparseConv3d" 2>/dev/null; then
  check_pass "Sparse Conv 算子已编译"
else
  check_fail "Sparse Conv 算子未编译"
fi

# 11. nuScenes dataset check.
# DATA_ROOT is a relative path, so it is resolved against the directory the
# script is run from.
echo ""
echo "[11/12] nuScenes 数据集"
DATA_ROOT="data/nuscenes"
if [ -d "$DATA_ROOT" ]; then
  check_pass "数据集目录存在"

  # Key files expected inside the dataset directory
  if [ -f "$DATA_ROOT/nuscenes_infos_train.pkl" ]; then
    check_pass "训练数据info文件"
  else
    check_fail "nuscenes_infos_train.pkl 缺失"
  fi

  if [ -f "$DATA_ROOT/nuscenes_infos_val.pkl" ]; then
    check_pass "验证数据info文件"
  else
    check_fail "nuscenes_infos_val.pkl 缺失"
  fi

  if [ -d "$DATA_ROOT/samples" ]; then
    # Number of regular files anywhere under samples/ (informational only).
    SAMPLE_COUNT=$(find "$DATA_ROOT/samples" -type f | wc -l)
    check_pass "samples目录 ($SAMPLE_COUNT 文件)"
  else
    check_fail "samples目录缺失"
  fi
else
  check_fail "数据集目录不存在: $DATA_ROOT"
  echo " 参考README.md的Data Preparation部分"
fi

# 12. Pretrained model check.
# PRETRAIN_DIR is a relative path, so it is resolved against the directory
# the script is run from. Missing checkpoints only produce warnings.
echo ""
echo "[12/12] 预训练模型"
PRETRAIN_DIR="pretrained"
if [ -d "$PRETRAIN_DIR" ]; then
  check_pass "预训练模型目录存在"

  # Key checkpoints; the human-readable size is shown next to each one found.
  if [ -f "$PRETRAIN_DIR/swint-nuimages-pretrained.pth" ]; then
    SIZE=$(du -h -- "$PRETRAIN_DIR/swint-nuimages-pretrained.pth" | cut -f1)
    check_pass "swint-nuimages-pretrained.pth ($SIZE)"
  else
    check_warn "swint-nuimages-pretrained.pth 缺失"
  fi

  if [ -f "$PRETRAIN_DIR/lidar-only-det.pth" ]; then
    SIZE=$(du -h -- "$PRETRAIN_DIR/lidar-only-det.pth" | cut -f1)
    check_pass "lidar-only-det.pth ($SIZE)"
  else
    check_warn "lidar-only-det.pth 缺失"
  fi
else
  check_warn "预训练模型目录不存在"
  echo " 运行: ./tools/download_pretrained.sh"
fi

# Summary: print the totals, then either example training commands (no
# failures, exit status 0) or common remedies (at least one failure, exit
# status 1). Warnings do not affect the exit status.
echo ""
echo "=========================================="
echo "检查完成"
echo "通过: $PASS 项"
echo "失败: $FAIL 项"
echo "=========================================="

if [ "$FAIL" -eq 0 ]; then
  echo ""
  echo "✓ 环境配置完整,可以开始训练!"
  echo ""
  echo "训练命令示例:"
  echo " # 3D检测"
  # "\\" inside double quotes prints a single backslash (line continuation
  # in the suggested command).
  echo " torchpack dist-run -np 8 python3 tools/train.py \\"
  echo " configs/nuscenes/det/transfusion/secfpn/camera+lidar/swint_v0p075/convfuser.yaml \\"
  echo " --model.encoders.camera.backbone.init_cfg.checkpoint pretrained/swint-nuimages-pretrained.pth \\"
  echo " --load_from pretrained/lidar-only-det.pth"
  echo ""
  echo " # BEV分割"
  echo " torchpack dist-run -np 8 python3 tools/train.py \\"
  echo " configs/nuscenes/seg/fusion-bev256d2-lss.yaml \\"
  echo " --model.encoders.camera.backbone.init_cfg.checkpoint pretrained/swint-nuimages-pretrained.pth"
  exit 0
else
  echo ""
  echo "⚠ 发现 $FAIL 个问题,请先解决后再训练"
  echo ""
  echo "常见解决方案:"
  echo "1. 安装缺失的包:"
  echo " pip install torchpack mmcv==1.4.0 mmcv-full==1.4.0 mmdet==2.20.0"
  echo " pip install nuscenes-devkit tqdm Pillow==8.4.0"
  echo ""
  echo "2. 编译自定义算子:"
  echo " cd /workspace/bevfusion"
  echo " python3 setup.py develop"
  echo ""
  echo "3. 准备数据集:"
  echo " 参考README.md的Data Preparation部分"
  exit 1
fi