bev-project/scripts/utils/check_env_detailed.sh

277 lines
7.8 KiB
Bash
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/bin/bash
# BEVFusion detailed environment check (see README.md "Prerequisites").
# Runs a sequence of numbered checks, prints one line per result, and
# tallies outcomes in PASS / FAIL for the summary printed at the end.
printf '%s\n' \
  "==========================================" \
  "BEVFusion 环境依赖检查" \
  "参考README.md Prerequisites" \
  "==========================================" \
  ""
# Global tallies, incremented by check_pass / check_fail.
PASS=0
FAIL=0
# Result helpers.
# Each prints one status line prefixed with a marker (✓ pass, ✗ fail,
# ⚠ warn) so the three outcomes are distinguishable in the output.
# check_pass / check_fail bump the global PASS / FAIL tallies; warnings are
# printed but not counted, so they never affect the exit status.
#
# Arithmetic assignment is used instead of ((X++)): ((X++)) evaluates to the
# old value and therefore returns status 1 when X was 0, which makes the
# first call look like a failure (and would abort under `set -e`).
#
# Arguments: $1 - message to print.
check_pass() {
  echo "✓ $1"
  PASS=$((PASS + 1))
}
check_fail() {
  echo "✗ $1"
  FAIL=$((FAIL + 1))
}
check_warn() {
  echo "⚠ $1"
}
# 1. Python version check.
# Python 3.8.x passes; any other version is only a warning. A missing
# python3 is a failure, since every later check runs through it.
echo "[1/12] Python 版本 (需要 >= 3.8, < 3.9)"
if command -v python3 >/dev/null 2>&1; then
  # `python3 --version` prints e.g. "Python 3.8.13"; keep the version token.
  PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
  IFS=. read -r PYTHON_MAJOR PYTHON_MINOR _ <<<"$PYTHON_VERSION"
  # String comparison avoids "integer expression expected" errors when the
  # version string is empty or not purely numeric.
  if [ "$PYTHON_MAJOR" = "3" ] && [ "$PYTHON_MINOR" = "8" ]; then
    check_pass "Python $PYTHON_VERSION"
  else
    check_warn "Python $PYTHON_VERSION (推荐3.8.x)"
  fi
else
  check_fail "python3 未找到"
fi
# 2. PyTorch version and CUDA check.
# Versions in 1.9 <= torch <= 1.10.2 pass. Installed-but-out-of-range
# versions warn, the same policy as the mmcv/mmdet checks. A failed import
# is a failure. CUDA must also be usable from this torch build.
echo ""
echo "[2/12] PyTorch 版本 (需要 >= 1.9, <= 1.10.2)"
if PYTORCH_VERSION=$(python3 -c "import torch; print(torch.__version__)" 2>/dev/null); then
  # Drop a local build tag such as "+cu113" before comparing.
  PYTORCH_BASE=${PYTORCH_VERSION%%+*}
  # `sort -V -C` (GNU coreutils) succeeds only if its input lines are
  # already in version order, i.e. 1.9 <= PYTORCH_BASE <= 1.10.2.
  if printf '%s\n' 1.9 "$PYTORCH_BASE" 1.10.2 | sort -V -C 2>/dev/null; then
    check_pass "PyTorch $PYTORCH_VERSION"
  else
    check_warn "PyTorch $PYTORCH_VERSION (推荐 1.9 ~ 1.10.2)"
  fi
  # CUDA availability as reported by torch.
  CUDA_AVAILABLE=$(python3 -c "import torch; print(torch.cuda.is_available())" 2>/dev/null)
  if [ "$CUDA_AVAILABLE" = "True" ]; then
    GPU_COUNT=$(python3 -c "import torch; print(torch.cuda.device_count())" 2>/dev/null)
    check_pass "CUDA 可用GPU数量: $GPU_COUNT"
  else
    check_fail "CUDA 不可用"
  fi
else
  check_fail "PyTorch 未安装"
fi
# 3. torchpack (required).
# Only a successful import matters; the version is informational. getattr
# keeps a build without __version__ from being misreported as missing.
echo ""
echo "[3/12] torchpack (必需)"
if TORCHPACK_VERSION=$(python3 -c "import torchpack; print(getattr(torchpack, '__version__', 'unknown'))" 2>/dev/null); then
  check_pass "torchpack $TORCHPACK_VERSION"
else
  check_fail "torchpack 未安装"
  echo " 安装: pip install torchpack"
fi
# 4. mmcv.
# 1.4.0 passes; other installed versions warn. The install hint targets
# mmcv-full (the build with compiled ops) only. The lite "mmcv" package
# provides the same `mmcv` module, and installing both in one environment
# makes them overwrite each other.
echo ""
echo "[4/12] mmcv (需要 = 1.4.0)"
if MMCV_VERSION=$(python3 -c "import mmcv; print(mmcv.__version__)" 2>/dev/null); then
  if [ "$MMCV_VERSION" = "1.4.0" ]; then
    check_pass "mmcv $MMCV_VERSION"
  else
    check_warn "mmcv $MMCV_VERSION (推荐1.4.0)"
  fi
else
  check_fail "mmcv 未安装"
  echo " 安装: pip install mmcv-full==1.4.0"
fi
# 5. mmdetection.
# Exactly 2.20.0 passes, any other installed version warns, and a failed
# import is a failure with an install hint.
echo ""
echo "[5/12] mmdetection (需要 = 2.20.0)"
if MMDET_VERSION=$(python3 -c "import mmdet; print(mmdet.__version__)" 2>/dev/null); then
  case "$MMDET_VERSION" in
    2.20.0) check_pass "mmdetection $MMDET_VERSION" ;;
    *) check_warn "mmdetection $MMDET_VERSION (推荐2.20.0)" ;;
  esac
else
  check_fail "mmdetection 未安装"
  echo " 安装: pip install mmdet==2.20.0"
fi
# 6. OpenMPI launcher and mpi4py bindings.
# Both are reported as optional: when absent they only warn.
echo ""
echo "[6/12] OpenMPI 和 mpi4py (torchpack需要)"
if ! command -v mpirun >/dev/null 2>&1; then
  check_warn "OpenMPI 未安装(可选)"
else
  # First line of `mpirun --version` carries the implementation/version.
  OMPI_VERSION=$(mpirun --version 2>&1 | head -n 1)
  check_pass "OpenMPI: $OMPI_VERSION"
fi
if ! MPI4PY_VERSION=$(python3 -c "import mpi4py; print(mpi4py.__version__)" 2>/dev/null); then
  check_warn "mpi4py 未安装(可选)"
else
  check_pass "mpi4py $MPI4PY_VERSION"
fi
# 7. Pillow.
# 8.4.0 passes, any other installed version warns, and a failed import is
# a failure.
echo ""
echo "[7/12] Pillow (需要 = 8.4.0)"
if ! PILLOW_VERSION=$(python3 -c "from PIL import Image; import PIL; print(PIL.__version__)" 2>/dev/null); then
  check_fail "Pillow 未安装"
elif [ "$PILLOW_VERSION" = "8.4.0" ]; then
  check_pass "Pillow $PILLOW_VERSION"
else
  check_warn "Pillow $PILLOW_VERSION (推荐8.4.0)"
fi
# 8. nuscenes-devkit (required).
# Installed means `from nuscenes import NuScenes` succeeds.
echo ""
echo "[8/12] nuscenes-devkit (必需)"
if ! python3 -c "from nuscenes import NuScenes" 2>/dev/null; then
  check_fail "nuscenes-devkit 未安装"
  echo " 安装: pip install nuscenes-devkit"
else
  check_pass "nuscenes-devkit 已安装"
fi
# 9. tqdm (required): a plain import probe.
echo ""
echo "[9/12] tqdm (必需)"
if python3 -c "import tqdm" 2>/dev/null; then
  check_pass "tqdm 已安装"
else
  check_fail "tqdm 未安装"
  echo " 安装: pip install tqdm"
fi
# 10. Custom CUDA ops built by `python setup.py develop`.
# Each probe only tries to import the op from mmdet3d.ops; any failure
# (stderr is discarded) is reported as "not compiled".
echo ""
echo "[10/12] 自定义CUDA算子 (需要运行 python setup.py develop)"
# The BEV pooling op is exposed as `bev_pool` in the mit-han-lab BEVFusion
# tree (whose configs this script suggests), and as `bev_pool_v2` in
# BEVDet-style forks. Accept either, trying the originally checked name first.
# NOTE(review): confirm which name this project's mmdet3d.ops exports.
if python3 -c "from mmdet3d.ops import bev_pool_v2" 2>/dev/null \
  || python3 -c "from mmdet3d.ops import bev_pool" 2>/dev/null; then
  check_pass "BEV Pool 算子已编译"
else
  check_fail "BEV Pool 算子未编译"
  echo " 运行: cd /workspace/bevfusion && python3 setup.py develop"
fi
if python3 -c "from mmdet3d.ops import Voxelization" 2>/dev/null; then
  check_pass "Voxelization 算子已编译"
else
  check_fail "Voxelization 算子未编译"
fi
if python3 -c "from mmdet3d.ops.spconv import SparseConv3d" 2>/dev/null; then
  check_pass "Sparse Conv 算子已编译"
else
  check_fail "Sparse Conv 算子未编译"
fi
# 11. nuScenes dataset layout.
# DATA_ROOT is relative to the current working directory. The check
# requires the train/val info .pkl files and a non-empty samples/ tree.
echo ""
echo "[11/12] nuScenes 数据集"
DATA_ROOT="data/nuscenes"
if [ -d "$DATA_ROOT" ]; then
  check_pass "数据集目录存在"
  # Key info files.
  if [ -f "$DATA_ROOT/nuscenes_infos_train.pkl" ]; then
    check_pass "训练数据info文件"
  else
    check_fail "nuscenes_infos_train.pkl 缺失"
  fi
  if [ -f "$DATA_ROOT/nuscenes_infos_val.pkl" ]; then
    check_pass "验证数据info文件"
  else
    check_fail "nuscenes_infos_val.pkl 缺失"
  fi
  if [ -d "$DATA_ROOT/samples" ]; then
    # -L: follow symlinks. Without it, find does not descend into a
    # symlinked samples/ directory (a common way to mount the dataset) and
    # would count 0 files.
    SAMPLE_COUNT=$(find -L "$DATA_ROOT/samples" -type f | wc -l)
    if [ "$SAMPLE_COUNT" -gt 0 ]; then
      check_pass "samples目录 ($SAMPLE_COUNT 文件)"
    else
      check_fail "samples目录为空"
    fi
  else
    check_fail "samples目录缺失"
  fi
else
  check_fail "数据集目录不存在: $DATA_ROOT"
  echo " 参考README.md的Data Preparation部分"
fi
# 12. Pretrained checkpoints (relative to the current working directory).
# Missing files are reported via check_warn, so they never count as
# failures or change the exit status.
echo ""
echo "[12/12] 预训练模型"
PRETRAIN_DIR="pretrained"
if [ -d "$PRETRAIN_DIR" ]; then
  check_pass "预训练模型目录存在"
  # Checkpoints referenced by the example training commands.
  for model_file in swint-nuimages-pretrained.pth lidar-only-det.pth; do
    if [ -f "$PRETRAIN_DIR/$model_file" ]; then
      # -L: report the checkpoint's own size even when it is a symlink
      # (otherwise du measures only the link).
      SIZE=$(du -hL "$PRETRAIN_DIR/$model_file" | cut -f1)
      check_pass "$model_file ($SIZE)"
    else
      check_warn "$model_file 缺失"
    fi
  done
else
  check_warn "预训练模型目录不存在"
  echo " 运行: ./tools/download_pretrained.sh"
fi
# Summary.
# Prints the PASS/FAIL tallies. With zero failures it prints example
# training commands and exits 0. Otherwise it prints common remedies and
# exits 1, so callers can test the exit status.
# The mmcv hint installs only mmcv-full: it and the lite "mmcv" package
# provide the same `mmcv` module and must not be installed together.
echo ""
echo "=========================================="
echo "检查完成"
echo "通过: $PASS"
echo "失败: $FAIL"
echo "=========================================="
if [ "$FAIL" -eq 0 ]; then
  echo ""
  echo "✓ 环境配置完整,可以开始训练!"
  echo ""
  echo "训练命令示例:"
  echo " # 3D检测"
  echo " torchpack dist-run -np 8 python3 tools/train.py \\"
  echo " configs/nuscenes/det/transfusion/secfpn/camera+lidar/swint_v0p075/convfuser.yaml \\"
  echo " --model.encoders.camera.backbone.init_cfg.checkpoint pretrained/swint-nuimages-pretrained.pth \\"
  echo " --load_from pretrained/lidar-only-det.pth"
  echo ""
  echo " # BEV分割"
  echo " torchpack dist-run -np 8 python3 tools/train.py \\"
  echo " configs/nuscenes/seg/fusion-bev256d2-lss.yaml \\"
  echo " --model.encoders.camera.backbone.init_cfg.checkpoint pretrained/swint-nuimages-pretrained.pth"
  exit 0
else
  echo ""
  echo "⚠ 发现 $FAIL 个问题,请先解决后再训练"
  echo ""
  echo "常见解决方案:"
  echo "1. 安装缺失的包:"
  echo " pip install torchpack mmcv-full==1.4.0 mmdet==2.20.0"
  echo " pip install nuscenes-devkit tqdm Pillow==8.4.0"
  echo ""
  echo "2. 编译自定义算子:"
  echo " cd /workspace/bevfusion"
  echo " python3 setup.py develop"
  echo ""
  echo "3. 准备数据集:"
  echo " 参考README.md的Data Preparation部分"
  exit 1
fi