#!/bin/bash
# BEVFusion detailed environment check script (see README.md Prerequisites).
#
# Runs a sequence of numbered checks and keeps a running tally in PASS / FAIL.
# Warnings (check_warn) are informational only and never affect either count.

echo "=========================================="
echo "BEVFusion 环境依赖检查"
echo "参考:README.md Prerequisites"
echo "=========================================="
echo ""

PASS=0
FAIL=0

# Print a passing check and increment PASS.
# An explicit assignment is used instead of ((PASS++)): the post-increment
# form evaluates to the old value, so it returns status 1 when PASS is 0.
check_pass() {
  printf ' ✓ %s\n' "$1"
  PASS=$((PASS + 1))
}

# Print a failing check and increment FAIL.
check_fail() {
  printf ' ✗ %s\n' "$1"
  FAIL=$((FAIL + 1))
}

# Print a non-fatal warning; counters are left untouched.
check_warn() {
  printf ' ⚠ %s\n' "$1"
}

# Convert a dotted version string such as "1.10.2+cu113" into a comparable
# integer: major*1000000 + minor*1000 + patch.
# A "+local" suffix and any non-numeric tail of a component (e.g. "0a0") are
# ignored. Missing components count as 0. The 10# prefix avoids octal parsing.
version_key() {
  local major minor patch
  IFS=. read -r major minor patch _ <<<"${1%%+*}"
  major=${major%%[!0-9]*}
  minor=${minor%%[!0-9]*}
  patch=${patch%%[!0-9]*}
  echo $((10#${major:-0} * 1000000 + 10#${minor:-0} * 1000 + 10#${patch:-0}))
}

# Succeed when version $1 lies in the inclusive range [$2, $3].
version_in_range() {
  local key
  key=$(version_key "$1")
  (( key >= $(version_key "$2") && key <= $(version_key "$3") ))
}

# 1. Python version: 3.8.x is expected.
echo "[1/12] Python 版本 (需要 >= 3.8, < 3.9)"
if command -v python3 >/dev/null 2>&1; then
  PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
  IFS=. read -r PYTHON_MAJOR PYTHON_MINOR _ <<<"$PYTHON_VERSION"
  # String comparison: never errors on empty or non-numeric fields.
  if [[ "$PYTHON_MAJOR" == 3 && "$PYTHON_MINOR" == 8 ]]; then
    check_pass "Python $PYTHON_VERSION"
  else
    check_warn "Python $PYTHON_VERSION (推荐3.8.x)"
  fi
else
  check_fail "python3 未找到"
fi

# 2. PyTorch version (1.9 .. 1.10.2) and CUDA availability.
echo ""
echo "[2/12] PyTorch 版本 (需要 >= 1.9, <= 1.10.2)"
if PYTORCH_VERSION=$(python3 -c "import torch; print(torch.__version__)" 2>/dev/null); then
  if version_in_range "$PYTORCH_VERSION" 1.9 1.10.2; then
    check_pass "PyTorch $PYTORCH_VERSION"
  else
    check_warn "PyTorch $PYTORCH_VERSION (推荐1.9 ~ 1.10.2)"
  fi

  # One interpreter start for both values; prints e.g. "True 8".
  read -r CUDA_AVAILABLE GPU_COUNT < <(python3 -c "import torch; print(torch.cuda.is_available(), torch.cuda.device_count())" 2>/dev/null)
  if [[ "$CUDA_AVAILABLE" == "True" ]]; then
    check_pass "CUDA 可用,GPU数量: $GPU_COUNT"
  else
    check_fail "CUDA 不可用"
  fi
else
  check_fail "PyTorch 未安装"
fi

# 3. torchpack (required).
echo ""
echo "[3/12] torchpack (必需)"
if TORCHPACK_VERSION=$(python3 -c "import torchpack; print(torchpack.__version__)" 2>/dev/null); then
  check_pass "torchpack $TORCHPACK_VERSION"
else
  check_fail "torchpack 未安装"
  echo " 安装: pip install torchpack"
fi

# 4. mmcv: pinned to 1.4.0; other versions only warn.
echo ""
echo "[4/12] mmcv (需要 = 1.4.0)"
if MMCV_VERSION=$(python3 -c "import mmcv; print(mmcv.__version__)" 2>/dev/null); then
  if [[ "$MMCV_VERSION" == "1.4.0" ]]; then
    check_pass "mmcv $MMCV_VERSION"
  else
    check_warn "mmcv $MMCV_VERSION (推荐1.4.0)"
  fi
else
  check_fail "mmcv 未安装"
  echo " 安装: pip install mmcv==1.4.0 mmcv-full==1.4.0"
fi
# 5. mmdetection: pinned to 2.20.0; other versions only warn.
echo ""
echo "[5/12] mmdetection (需要 = 2.20.0)"
if MMDET_VERSION=$(python3 -c "import mmdet; print(mmdet.__version__)" 2>/dev/null); then
  if [[ "$MMDET_VERSION" == "2.20.0" ]]; then
    check_pass "mmdetection $MMDET_VERSION"
  else
    check_warn "mmdetection $MMDET_VERSION (推荐2.20.0)"
  fi
else
  check_fail "mmdetection 未安装"
  echo " 安装: pip install mmdet==2.20.0"
fi

# 6. OpenMPI and mpi4py (listed as torchpack requirements); absence only warns.
echo ""
echo "[6/12] OpenMPI 和 mpi4py (torchpack需要)"
if command -v mpirun &>/dev/null; then
  OMPI_VERSION=$(mpirun --version 2>&1 | head -n 1)
  check_pass "OpenMPI: $OMPI_VERSION"
else
  check_warn "OpenMPI 未安装(可选)"
fi

if MPI4PY_VERSION=$(python3 -c "import mpi4py; print(mpi4py.__version__)" 2>/dev/null); then
  check_pass "mpi4py $MPI4PY_VERSION"
else
  check_warn "mpi4py 未安装(可选)"
fi

# 7. Pillow: pinned to 8.4.0; other versions only warn.
echo ""
echo "[7/12] Pillow (需要 = 8.4.0)"
if PILLOW_VERSION=$(python3 -c "from PIL import Image; import PIL; print(PIL.__version__)" 2>/dev/null); then
  if [[ "$PILLOW_VERSION" == "8.4.0" ]]; then
    check_pass "Pillow $PILLOW_VERSION"
  else
    check_warn "Pillow $PILLOW_VERSION (推荐8.4.0)"
  fi
else
  check_fail "Pillow 未安装"
fi

# 8. nuscenes-devkit (required).
echo ""
echo "[8/12] nuscenes-devkit (必需)"
if python3 -c "from nuscenes import NuScenes" 2>/dev/null; then
  check_pass "nuscenes-devkit 已安装"
else
  check_fail "nuscenes-devkit 未安装"
  echo " 安装: pip install nuscenes-devkit"
fi

# 9. tqdm (required).
echo ""
echo "[9/12] tqdm (必需)"
if python3 -c "import tqdm" 2>/dev/null; then
  check_pass "tqdm 已安装"
else
  check_fail "tqdm 未安装"
  echo " 安装: pip install tqdm"
fi

# 10. Custom CUDA ops: importable only after `python setup.py develop`.
# The build hint assumes the repository lives at /workspace/bevfusion.
echo ""
echo "[10/12] 自定义CUDA算子 (需要运行 python setup.py develop)"
# NOTE(review): upstream mit-han-lab BEVFusion appears to export `bev_pool`
# (not `bev_pool_v2`) from mmdet3d.ops; confirm the name for this fork.
if python3 -c "from mmdet3d.ops import bev_pool_v2" 2>/dev/null; then
  check_pass "BEV Pool 算子已编译"
else
  check_fail "BEV Pool 算子未编译"
  echo " 运行: cd /workspace/bevfusion && python3 setup.py develop"
fi

if python3 -c "from mmdet3d.ops import Voxelization" 2>/dev/null; then
  check_pass "Voxelization 算子已编译"
else
  check_fail "Voxelization 算子未编译"
fi

if python3 -c "from mmdet3d.ops.spconv import SparseConv3d" 2>/dev/null; then
  check_pass "Sparse Conv 算子已编译"
else
  check_fail "Sparse Conv 算子未编译"
fi

# 11. nuScenes dataset. The default path is relative to the current
# directory; set BEVFUSION_DATA_ROOT to check another location.
echo ""
echo "[11/12] nuScenes 数据集"
DATA_ROOT="${BEVFUSION_DATA_ROOT:-data/nuscenes}"
if [[ -d "$DATA_ROOT" ]]; then
  check_pass "数据集目录存在"

  # Key files: the generated info pickles.
  if [[ -f "$DATA_ROOT/nuscenes_infos_train.pkl" ]]; then
    check_pass "训练数据info文件"
  else
    check_fail "nuscenes_infos_train.pkl 缺失"
  fi

  if [[ -f "$DATA_ROOT/nuscenes_infos_val.pkl" ]]; then
    check_pass "验证数据info文件"
  else
    check_fail "nuscenes_infos_val.pkl 缺失"
  fi

  if [[ -d "$DATA_ROOT/samples" ]]; then
    # -L follows symlinks, matching the -d test above. Without it, a
    # symlinked samples/ directory (or symlinked files) counts as 0.
    SAMPLE_COUNT=$(find -L "$DATA_ROOT/samples" -type f | wc -l)
    check_pass "samples目录 ($SAMPLE_COUNT 文件)"
  else
    check_fail "samples目录缺失"
  fi
else
  check_fail "数据集目录不存在: $DATA_ROOT"
  echo " 参考README.md的Data Preparation部分"
fi

# 12. Pretrained checkpoints. The default path is relative to the current
# directory; set BEVFUSION_PRETRAIN_DIR to check another location. Missing
# files only warn.
echo ""
echo "[12/12] 预训练模型"
PRETRAIN_DIR="${BEVFUSION_PRETRAIN_DIR:-pretrained}"
if [[ -d "$PRETRAIN_DIR" ]]; then
  check_pass "预训练模型目录存在"

  for ckpt in swint-nuimages-pretrained.pth lidar-only-det.pth; do
    if [[ -f "$PRETRAIN_DIR/$ckpt" ]]; then
      # -L reports the target's size when the checkpoint is a symlink.
      SIZE=$(du -hL "$PRETRAIN_DIR/$ckpt" | cut -f1)
      check_pass "$ckpt ($SIZE)"
    else
      check_warn "$ckpt 缺失"
    fi
  done
else
  check_warn "预训练模型目录不存在"
  echo " 运行: ./tools/download_pretrained.sh"
fi

# Summary: exit 0 when no check failed (warnings allowed), 1 otherwise.
echo ""
echo "=========================================="
echo "检查完成"
echo "通过: $PASS 项"
echo "失败: $FAIL 项"
echo "=========================================="

if (( FAIL == 0 )); then
  echo ""
  echo "✓ 环境配置完整,可以开始训练!"
  echo ""
  echo "训练命令示例:"
  echo " # 3D检测"
  echo " torchpack dist-run -np 8 python3 tools/train.py \\"
  echo " configs/nuscenes/det/transfusion/secfpn/camera+lidar/swint_v0p075/convfuser.yaml \\"
  echo " --model.encoders.camera.backbone.init_cfg.checkpoint pretrained/swint-nuimages-pretrained.pth \\"
  echo " --load_from pretrained/lidar-only-det.pth"
  echo ""
  echo " # BEV分割"
  echo " torchpack dist-run -np 8 python3 tools/train.py \\"
  echo " configs/nuscenes/seg/fusion-bev256d2-lss.yaml \\"
  echo " --model.encoders.camera.backbone.init_cfg.checkpoint pretrained/swint-nuimages-pretrained.pth"
  exit 0
else
  echo ""
  echo "⚠ 发现 $FAIL 个问题,请先解决后再训练"
  echo ""
  echo "常见解决方案:"
  echo "1. 安装缺失的包:"
  echo " pip install torchpack mmcv==1.4.0 mmcv-full==1.4.0 mmdet==2.20.0"
  echo " pip install nuscenes-devkit tqdm Pillow==8.4.0"
  echo ""
  echo "2. 编译自定义算子:"
  echo " cd /workspace/bevfusion"
  echo " python3 setup.py develop"
  echo ""
  echo "3. 准备数据集:"
  echo " 参考README.md的Data Preparation部分"
  exit 1
fi