bev-project/scripts/check_env_detailed.sh

277 lines
7.8 KiB
Bash
Raw Normal View History

#!/bin/bash
# Detailed BEVFusion environment check, modelled on the Prerequisites section
# of README.md. Each numbered section below prints its result lines, and the
# summary at the end reports the PASS/FAIL totals.
printf '%s\n' \
  "==========================================" \
  "BEVFusion 环境依赖检查" \
  "参考README.md Prerequisites" \
  "=========================================="
printf '\n'
# Result counters, incremented by check_pass / check_fail.
PASS=0
FAIL=0
# Result helpers used by every check below.
#   check_pass MSG - print MSG marked as passed and increment PASS
#   check_fail MSG - print MSG marked as failed and increment FAIL
#   check_warn MSG - print MSG marked as a warning (not counted)
# Each line carries a marker so passes, failures and warnings can be told
# apart in the output.
check_pass() {
  printf '✓ %s\n' "$1"
  # Plain assignment instead of ((PASS++)): the post-increment expression
  # evaluates to 0 on the first call and so returns status 1, which would
  # abort the script under `set -e`.
  PASS=$((PASS + 1))
}
check_fail() {
  printf '✗ %s\n' "$1"
  FAIL=$((FAIL + 1))
}
check_warn() {
  printf '⚠ %s\n' "$1"
}
# 1. Python version check: the header asks for >= 3.8, < 3.9 (i.e. 3.8.x).
echo "[1/12] Python 版本 (需要 >= 3.8, < 3.9)"
if ! command -v python3 &> /dev/null; then
  # Most later checks invoke python3, so a missing interpreter is a failure
  # rather than a warning.
  check_fail "python3 未安装"
else
  PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
  PYTHON_MAJOR=$(cut -d. -f1 <<< "$PYTHON_VERSION")
  PYTHON_MINOR=$(cut -d. -f2 <<< "$PYTHON_VERSION")
  # String comparison: unlike -eq it cannot error on non-numeric output.
  if [[ "$PYTHON_MAJOR" == "3" && "$PYTHON_MINOR" == "8" ]]; then
    check_pass "Python $PYTHON_VERSION"
  else
    check_warn "Python $PYTHON_VERSION (推荐3.8.x)"
  fi
fi
# 2. PyTorch version and CUDA check. The header requires 1.9 <= torch <= 1.10.2;
# an out-of-range version is a warning, matching the other pinned packages.
echo ""
echo "[2/12] PyTorch 版本 (需要 >= 1.9, <= 1.10.2)"
if PYTORCH_VERSION=$(python3 -c "import torch; print(torch.__version__)" 2>/dev/null); then
  # Strip a local build tag such as "+cu113", then encode major.minor.patch as
  # one integer (1.10.2 -> 11002) so the range test is a single comparison.
  # Unparseable versions keep TORCH_KEY=-1 and fall into the warning branch.
  TORCH_KEY=-1
  if [[ "${PYTORCH_VERSION%%+*}" =~ ^([0-9]+)\.([0-9]+)(\.([0-9]+))? ]]; then
    TORCH_KEY=$((10#${BASH_REMATCH[1]} * 10000 + 10#${BASH_REMATCH[2]} * 100 + 10#${BASH_REMATCH[4]:-0}))
  fi
  if ((TORCH_KEY >= 10900 && TORCH_KEY <= 11002)); then
    check_pass "PyTorch $PYTORCH_VERSION"
  else
    check_warn "PyTorch $PYTORCH_VERSION (推荐 >= 1.9, <= 1.10.2)"
  fi
  # CUDA must be visible to PyTorch; otherwise this counts as a failure.
  CUDA_AVAILABLE=$(python3 -c "import torch; print(torch.cuda.is_available())" 2>/dev/null)
  if [ "$CUDA_AVAILABLE" = "True" ]; then
    GPU_COUNT=$(python3 -c "import torch; print(torch.cuda.device_count())" 2>/dev/null)
    check_pass "CUDA 可用GPU数量: $GPU_COUNT"
  else
    check_fail "CUDA 不可用"
  fi
else
  check_fail "PyTorch 未安装"
fi
# 3. torchpack: required. Report its version when the import succeeds,
# otherwise count a failure and show the install command.
echo ""
echo "[3/12] torchpack (必需)"
if TORCHPACK_VERSION=$(python3 -c "import torchpack; print(torchpack.__version__)" 2>/dev/null); then
  check_pass "torchpack $TORCHPACK_VERSION"
else
  check_fail "torchpack 未安装"
  echo " 安装: pip install torchpack"
fi
# 4. mmcv check. The header pins 1.4.0; other versions only warn.
# mmcv and mmcv-full both ship the same `mmcv` package and must not be
# installed side by side, so only mmcv-full is suggested for installation.
echo ""
echo "[4/12] mmcv (需要 = 1.4.0)"
if MMCV_VERSION=$(python3 -c "import mmcv; print(mmcv.__version__)" 2>/dev/null); then
  if [ "$MMCV_VERSION" = "1.4.0" ]; then
    check_pass "mmcv $MMCV_VERSION"
  else
    check_warn "mmcv $MMCV_VERSION (推荐1.4.0)"
  fi
  # mmcv.ops needs the compiled extension that only the full build provides;
  # a lite `mmcv` install imports fine above but fails here.
  if ! python3 -c "import mmcv.ops" 2>/dev/null; then
    check_warn "mmcv.ops 不可用 (需要 mmcv-full==1.4.0)"
  fi
else
  check_fail "mmcv 未安装"
  echo " 安装: pip install mmcv-full==1.4.0"
fi
# 5. mmdetection: must import; the header pins 2.20.0 and any other
# version is only a warning.
echo ""
echo "[5/12] mmdetection (需要 = 2.20.0)"
if MMDET_VERSION=$(python3 -c "import mmdet; print(mmdet.__version__)" 2>/dev/null); then
  case "$MMDET_VERSION" in
    2.20.0) check_pass "mmdetection $MMDET_VERSION" ;;
    *) check_warn "mmdetection $MMDET_VERSION (推荐2.20.0)" ;;
  esac
else
  check_fail "mmdetection 未安装"
  echo " 安装: pip install mmdet==2.20.0"
fi
# 6. OpenMPI and mpi4py, listed as torchpack dependencies. Absence of either
# is only a warning.
echo ""
echo "[6/12] OpenMPI 和 mpi4py (torchpack需要)"
if ! command -v mpirun &> /dev/null; then
  check_warn "OpenMPI 未安装(可选)"
else
  # Keep only the first line of `mpirun --version`.
  OMPI_VERSION=$(mpirun --version 2>&1)
  OMPI_VERSION=${OMPI_VERSION%%$'\n'*}
  check_pass "OpenMPI: $OMPI_VERSION"
fi
if MPI4PY_VERSION=$(python3 -c "import mpi4py; print(mpi4py.__version__)" 2>/dev/null); then
  check_pass "mpi4py $MPI4PY_VERSION"
else
  check_warn "mpi4py 未安装(可选)"
fi
# 7. Pillow check. The header pins 8.4.0; other versions only warn.
echo ""
echo "[7/12] Pillow (需要 = 8.4.0)"
if PILLOW_VERSION=$(python3 -c "from PIL import Image; import PIL; print(PIL.__version__)" 2>/dev/null); then
  if [ "$PILLOW_VERSION" = "8.4.0" ]; then
    check_pass "Pillow $PILLOW_VERSION"
  else
    check_warn "Pillow $PILLOW_VERSION (推荐8.4.0)"
  fi
else
  check_fail "Pillow 未安装"
  # Show the install command, like the other required-package checks do.
  echo " 安装: pip install Pillow==8.4.0"
fi
# 8. nuscenes-devkit: required. This is an import check only; no version pin.
echo ""
echo "[8/12] nuscenes-devkit (必需)"
if ! python3 -c "from nuscenes import NuScenes" 2>/dev/null; then
  check_fail "nuscenes-devkit 未安装"
  echo " 安装: pip install nuscenes-devkit"
else
  check_pass "nuscenes-devkit 已安装"
fi
# 9. tqdm: required. The check passes if the module imports.
echo ""
echo "[9/12] tqdm (必需)"
if python3 -c "import tqdm" 2>/dev/null; then
  check_pass "tqdm 已安装"
else
  check_fail "tqdm 未安装"
  echo " 安装: pip install tqdm"
fi
# 10. Custom CUDA operator check. These imports only succeed once the
# project's extensions have been built with `python setup.py develop`.
echo ""
echo "[10/12] 自定义CUDA算子 (需要运行 python setup.py develop)"
# Upstream BEVFusion exports the BEV pooling op as `bev_pool`; `bev_pool_v2`
# is the BEVDet-style name. Accept either, so a correctly built tree is not
# reported as missing the op.
if python3 -c "from mmdet3d.ops import bev_pool" 2>/dev/null \
  || python3 -c "from mmdet3d.ops import bev_pool_v2" 2>/dev/null; then
  check_pass "BEV Pool 算子已编译"
else
  check_fail "BEV Pool 算子未编译"
  echo " 运行: cd /workspace/bevfusion && python3 setup.py develop"
fi
if python3 -c "from mmdet3d.ops import Voxelization" 2>/dev/null; then
  check_pass "Voxelization 算子已编译"
else
  check_fail "Voxelization 算子未编译"
fi
if python3 -c "from mmdet3d.ops.spconv import SparseConv3d" 2>/dev/null; then
  check_pass "Sparse Conv 算子已编译"
else
  check_fail "Sparse Conv 算子未编译"
fi
# 11. nuScenes dataset layout. The root defaults to data/nuscenes relative to
# the current directory and can be overridden via NUSCENES_DATA_ROOT.
echo ""
echo "[11/12] nuScenes 数据集"
DATA_ROOT="${NUSCENES_DATA_ROOT:-data/nuscenes}"
if [ -d "$DATA_ROOT" ]; then
  check_pass "数据集目录存在"
  # Info .pkl files (presumably generated by the README's Data Preparation step).
  if [ -f "$DATA_ROOT/nuscenes_infos_train.pkl" ]; then
    check_pass "训练数据info文件"
  else
    check_fail "nuscenes_infos_train.pkl 缺失"
  fi
  if [ -f "$DATA_ROOT/nuscenes_infos_val.pkl" ]; then
    check_pass "验证数据info文件"
  else
    check_fail "nuscenes_infos_val.pkl 缺失"
  fi
  if [ -d "$DATA_ROOT/samples" ]; then
    SAMPLE_COUNT=$(find "$DATA_ROOT/samples" -type f | wc -l)
    # An existing but empty samples/ directory is not usable, so fail it
    # instead of reporting "0 files" as a pass.
    if ((SAMPLE_COUNT > 0)); then
      check_pass "samples目录 ($SAMPLE_COUNT 文件)"
    else
      check_fail "samples目录为空"
    fi
  else
    check_fail "samples目录缺失"
  fi
else
  check_fail "数据集目录不存在: $DATA_ROOT"
  echo " 参考README.md的Data Preparation部分"
fi
# 12. Pretrained checkpoints under ./pretrained. A missing directory or file
# is only a warning, never a failure.
echo ""
echo "[12/12] 预训练模型"
PRETRAIN_DIR="pretrained"
if [ ! -d "$PRETRAIN_DIR" ]; then
  check_warn "预训练模型目录不存在"
  echo " 运行: ./tools/download_pretrained.sh"
else
  check_pass "预训练模型目录存在"
  # These are the checkpoints used by the example training commands in the summary.
  for MODEL_FILE in swint-nuimages-pretrained.pth lidar-only-det.pth; do
    if [ -f "$PRETRAIN_DIR/$MODEL_FILE" ]; then
      SIZE=$(du -h "$PRETRAIN_DIR/$MODEL_FILE" | cut -f1)
      check_pass "$MODEL_FILE ($SIZE)"
    else
      check_warn "$MODEL_FILE 缺失"
    fi
  done
fi
# Summary: print the PASS/FAIL counters, then either example training commands
# (exit 0) or remediation hints (exit 1).
echo ""
echo "=========================================="
echo "检查完成"
echo "通过: $PASS"
echo "失败: $FAIL"
echo "=========================================="
if [ "$FAIL" -eq 0 ]; then
  echo ""
  echo "✓ 环境配置完整,可以开始训练!"
  echo ""
  echo "训练命令示例:"
  echo " # 3D检测"
  echo " torchpack dist-run -np 8 python3 tools/train.py \\"
  echo " configs/nuscenes/det/transfusion/secfpn/camera+lidar/swint_v0p075/convfuser.yaml \\"
  echo " --model.encoders.camera.backbone.init_cfg.checkpoint pretrained/swint-nuimages-pretrained.pth \\"
  echo " --load_from pretrained/lidar-only-det.pth"
  echo ""
  echo " # BEV分割"
  echo " torchpack dist-run -np 8 python3 tools/train.py \\"
  echo " configs/nuscenes/seg/fusion-bev256d2-lss.yaml \\"
  echo " --model.encoders.camera.backbone.init_cfg.checkpoint pretrained/swint-nuimages-pretrained.pth"
  exit 0
else
  echo ""
  echo "⚠ 发现 $FAIL 个问题,请先解决后再训练"
  echo ""
  echo "常见解决方案:"
  echo "1. 安装缺失的包:"
  # mmcv and mmcv-full both provide the `mmcv` package and must not be
  # installed together, so only the full build (with compiled ops) is suggested.
  echo " pip install torchpack mmcv-full==1.4.0 mmdet==2.20.0"
  echo " pip install nuscenes-devkit tqdm Pillow==8.4.0"
  echo " (如已安装 mmcv,请先执行 pip uninstall mmcv)"
  echo ""
  echo "2. 编译自定义算子:"
  echo " cd /workspace/bevfusion"
  echo " python3 setup.py develop"
  echo ""
  echo "3. 准备数据集:"
  echo " 参考README.md的Data Preparation部分"
  exit 1
fi