From 326653dc06e0938edf1aae7d01efcd158ba83de5 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Tue, 30 Jul 2024 21:04:51 -0700 Subject: [PATCH] [ROCm] add support for ROCm/HIP (#509) * [ROCm] add support for ROCm/HIP - do not use THC.h or THH.h or THCState - rename files to avoid collisions at build time due to hipify - HIP host device definitions in tensorview.h - use HIP atomicAdd() in scatter_points_cuda.cu - add WITH_ROCM anywhere WITH_CUDA was used - update setup.py for renamed files and ROCm/HIP torch * enforce c++17 since pytorch requires it --- .gitignore | 4 +++ .../{ball_query.cpp => ball_query_cpu.cpp} | 4 +-- .../src/{bev_pool.cpp => bev_pool_cpu.cpp} | 0 ...mple.cpp => furthest_point_sample_cpu.cpp} | 5 +-- ...ather_points.cpp => gather_points_cpu.cpp} | 4 +-- ...{group_points.cpp => group_points_cpu.cpp} | 4 +-- mmdet3d/ops/interpolate/src/interpolate.cpp | 4 +-- mmdet3d/ops/knn/src/{knn.cpp => knn_cpu.cpp} | 4 +-- .../spconv/include/tensorview/tensorview.h | 2 +- .../spconv/src/{indice.cc => indice_cpu.cc} | 0 .../spconv/src/{maxpool.cc => maxpool_cpu.cc} | 0 .../src/{reordering.cc => reordering_cpu.cc} | 0 mmdet3d/ops/voxel/src/scatter_points_cuda.cu | 7 +++++ mmdet3d/ops/voxel/src/voxelization.h | 10 +++--- setup.py | 31 ++++++++++++------- 15 files changed, 43 insertions(+), 36 deletions(-) rename mmdet3d/ops/ball_query/src/{ball_query.cpp => ball_query_cpu.cpp} (97%) rename mmdet3d/ops/bev_pool/src/{bev_pool.cpp => bev_pool_cpu.cpp} (100%) rename mmdet3d/ops/furthest_point_sample/src/{furthest_point_sample.cpp => furthest_point_sample_cpu.cpp} (98%) rename mmdet3d/ops/gather_points/src/{gather_points.cpp => gather_points_cpu.cpp} (98%) rename mmdet3d/ops/group_points/src/{group_points.cpp => group_points_cpu.cpp} (98%) rename mmdet3d/ops/knn/src/{knn.cpp => knn_cpu.cpp} (96%) rename mmdet3d/ops/spconv/src/{indice.cc => indice_cpu.cc} (100%) rename mmdet3d/ops/spconv/src/{maxpool.cc => maxpool_cpu.cc} (100%) rename 
mmdet3d/ops/spconv/src/{reordering.cc => reordering_cpu.cc} (100%) diff --git a/.gitignore b/.gitignore index 7525a451..76a892ff 100644 --- a/.gitignore +++ b/.gitignore @@ -132,3 +132,7 @@ dmypy.json models/* data/* runs/* + +# torch hipify generated files +*_hip.* +*.hip diff --git a/mmdet3d/ops/ball_query/src/ball_query.cpp b/mmdet3d/ops/ball_query/src/ball_query_cpu.cpp similarity index 97% rename from mmdet3d/ops/ball_query/src/ball_query.cpp rename to mmdet3d/ops/ball_query/src/ball_query_cpu.cpp index 6e9d5f75..28ff0391 100644 --- a/mmdet3d/ops/ball_query/src/ball_query.cpp +++ b/mmdet3d/ops/ball_query/src/ball_query_cpu.cpp @@ -1,16 +1,14 @@ // Modified from // https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query.cpp -#include #include #include #include #include +#include #include -extern THCState *state; - #define CHECK_CUDA(x) \ TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") #define CHECK_CONTIGUOUS(x) \ diff --git a/mmdet3d/ops/bev_pool/src/bev_pool.cpp b/mmdet3d/ops/bev_pool/src/bev_pool_cpu.cpp similarity index 100% rename from mmdet3d/ops/bev_pool/src/bev_pool.cpp rename to mmdet3d/ops/bev_pool/src/bev_pool_cpu.cpp diff --git a/mmdet3d/ops/furthest_point_sample/src/furthest_point_sample.cpp b/mmdet3d/ops/furthest_point_sample/src/furthest_point_sample_cpu.cpp similarity index 98% rename from mmdet3d/ops/furthest_point_sample/src/furthest_point_sample.cpp rename to mmdet3d/ops/furthest_point_sample/src/furthest_point_sample_cpu.cpp index be058902..c67d0e60 100644 --- a/mmdet3d/ops/furthest_point_sample/src/furthest_point_sample.cpp +++ b/mmdet3d/ops/furthest_point_sample/src/furthest_point_sample_cpu.cpp @@ -2,14 +2,11 @@ // https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/sampling.cpp #include -#include #include #include - +#include #include -extern THCState *state; - int furthest_point_sampling_wrapper(int b, int n, int m, at::Tensor points_tensor, at::Tensor temp_tensor, diff 
--git a/mmdet3d/ops/gather_points/src/gather_points.cpp b/mmdet3d/ops/gather_points/src/gather_points_cpu.cpp similarity index 98% rename from mmdet3d/ops/gather_points/src/gather_points.cpp rename to mmdet3d/ops/gather_points/src/gather_points_cpu.cpp index 01a3e404..ba861845 100644 --- a/mmdet3d/ops/gather_points/src/gather_points.cpp +++ b/mmdet3d/ops/gather_points/src/gather_points_cpu.cpp @@ -1,12 +1,10 @@ #include -#include #include #include +#include #include -extern THCState *state; - int gather_points_wrapper(int b, int c, int n, int npoints, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); diff --git a/mmdet3d/ops/group_points/src/group_points.cpp b/mmdet3d/ops/group_points/src/group_points_cpu.cpp similarity index 98% rename from mmdet3d/ops/group_points/src/group_points.cpp rename to mmdet3d/ops/group_points/src/group_points_cpu.cpp index 3cd6fc84..ca9aeb31 100644 --- a/mmdet3d/ops/group_points/src/group_points.cpp +++ b/mmdet3d/ops/group_points/src/group_points_cpu.cpp @@ -1,16 +1,14 @@ // Modified from // https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points.cpp -#include #include #include #include #include +#include #include -extern THCState *state; - int group_points_wrapper(int b, int c, int n, int npoints, int nsample, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); diff --git a/mmdet3d/ops/interpolate/src/interpolate.cpp b/mmdet3d/ops/interpolate/src/interpolate.cpp index 6382579b..797258ae 100644 --- a/mmdet3d/ops/interpolate/src/interpolate.cpp +++ b/mmdet3d/ops/interpolate/src/interpolate.cpp @@ -1,7 +1,6 @@ // Modified from // https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate.cpp -#include #include #include #include @@ -9,11 +8,10 @@ #include #include #include +#include #include -extern THCState *state; - void three_nn_wrapper(int b, int n, int m, at::Tensor unknown_tensor, at::Tensor known_tensor, at::Tensor dist2_tensor, 
at::Tensor idx_tensor); diff --git a/mmdet3d/ops/knn/src/knn.cpp b/mmdet3d/ops/knn/src/knn_cpu.cpp similarity index 96% rename from mmdet3d/ops/knn/src/knn.cpp rename to mmdet3d/ops/knn/src/knn_cpu.cpp index 84ddf4b8..bdf0685e 100644 --- a/mmdet3d/ops/knn/src/knn.cpp +++ b/mmdet3d/ops/knn/src/knn_cpu.cpp @@ -3,10 +3,8 @@ #include #include #include -#include #include - -extern THCState *state; +#include #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x, " must be a CUDAtensor ") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") diff --git a/mmdet3d/ops/spconv/include/tensorview/tensorview.h b/mmdet3d/ops/spconv/include/tensorview/tensorview.h index e4cdf352..5d549948 100644 --- a/mmdet3d/ops/spconv/include/tensorview/tensorview.h +++ b/mmdet3d/ops/spconv/include/tensorview/tensorview.h @@ -27,7 +27,7 @@ namespace tv { -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIP__) #define TV_HOST_DEVICE_INLINE __forceinline__ __device__ __host__ #define TV_DEVICE_INLINE __forceinline__ __device__ #define TV_HOST_DEVICE __device__ __host__ diff --git a/mmdet3d/ops/spconv/src/indice.cc b/mmdet3d/ops/spconv/src/indice_cpu.cc similarity index 100% rename from mmdet3d/ops/spconv/src/indice.cc rename to mmdet3d/ops/spconv/src/indice_cpu.cc diff --git a/mmdet3d/ops/spconv/src/maxpool.cc b/mmdet3d/ops/spconv/src/maxpool_cpu.cc similarity index 100% rename from mmdet3d/ops/spconv/src/maxpool.cc rename to mmdet3d/ops/spconv/src/maxpool_cpu.cc diff --git a/mmdet3d/ops/spconv/src/reordering.cc b/mmdet3d/ops/spconv/src/reordering_cpu.cc similarity index 100% rename from mmdet3d/ops/spconv/src/reordering.cc rename to mmdet3d/ops/spconv/src/reordering_cpu.cc diff --git a/mmdet3d/ops/voxel/src/scatter_points_cuda.cu b/mmdet3d/ops/voxel/src/scatter_points_cuda.cu index 2ed18690..fab15723 100644 --- a/mmdet3d/ops/voxel/src/scatter_points_cuda.cu +++ b/mmdet3d/ops/voxel/src/scatter_points_cuda.cu @@ -75,6 +75,13 @@ __device__ __forceinline__ static 
void reduceAdd(double *address, double val) { atomicAdd(address, val); #endif } +#elif defined(__HIP__) +__device__ __forceinline__ static void reduceAdd(float *address, float val) { + atomicAdd(address, val); +} +__device__ __forceinline__ static void reduceAdd(double *address, double val) { + atomicAdd(address, val); +} #endif template diff --git a/mmdet3d/ops/voxel/src/voxelization.h b/mmdet3d/ops/voxel/src/voxelization.h index 765b30a5..8b5f722a 100644 --- a/mmdet3d/ops/voxel/src/voxelization.h +++ b/mmdet3d/ops/voxel/src/voxelization.h @@ -21,7 +21,7 @@ std::vector dynamic_point_to_voxel_cpu( const at::Tensor &points, const at::Tensor &voxel_mapping, const std::vector voxel_size, const std::vector coors_range); -#ifdef WITH_CUDA +#if defined(WITH_CUDA) || defined(WITH_ROCM) int hard_voxelize_gpu(const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, at::Tensor &num_points_per_voxel, const std::vector voxel_size, @@ -62,7 +62,7 @@ inline int hard_voxelize(const at::Tensor &points, at::Tensor &voxels, const int max_points, const int max_voxels, const int NDim = 3, const bool deterministic = true) { if (points.device().is_cuda()) { -#ifdef WITH_CUDA +#if defined(WITH_CUDA) || defined(WITH_ROCM) if (deterministic) { return hard_voxelize_gpu(points, voxels, coors, num_points_per_voxel, voxel_size, coors_range, max_points, max_voxels, @@ -85,7 +85,7 @@ inline void dynamic_voxelize(const at::Tensor &points, at::Tensor &coors, const std::vector coors_range, const int NDim = 3) { if (points.device().is_cuda()) { -#ifdef WITH_CUDA +#if defined(WITH_CUDA) || defined(WITH_ROCM) return dynamic_voxelize_gpu(points, coors, voxel_size, coors_range, NDim); #else AT_ERROR("Not compiled with GPU support"); @@ -109,7 +109,7 @@ inline std::vector dynamic_point_to_voxel_forward(const torch::Te const torch::Tensor &coors, const std::string &reduce_type) { if (feats.device().is_cuda()) { -#ifdef WITH_CUDA +#if defined(WITH_CUDA) || defined(WITH_ROCM) return 
dynamic_point_to_voxel_forward_gpu(feats, coors, convert_reduce_type(reduce_type)); #else TORCH_CHECK(false, "Not compiled with GPU support"); @@ -127,7 +127,7 @@ inline void dynamic_point_to_voxel_backward(torch::Tensor &grad_feats, const torch::Tensor &reduce_count, const std::string &reduce_type) { if (grad_feats.device().is_cuda()) { -#ifdef WITH_CUDA +#if defined(WITH_CUDA) || defined(WITH_ROCM) dynamic_point_to_voxel_backward_gpu( grad_feats, grad_reduced_feats, feats, reduced_feats, coors_idx, reduce_count, convert_reduce_type(reduce_type)); diff --git a/setup.py b/setup.py index 01187c6e..256e4224 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ def make_cuda_ext( define_macros = [] extra_compile_args = {"cxx": [] + extra_args} - if torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1": + if (torch.cuda.is_available() and torch.version.cuda is not None) or os.getenv("FORCE_CUDA", "0") == "1": define_macros += [("WITH_CUDA", None)] extension = CUDAExtension extra_compile_args["nvcc"] = extra_args + [ @@ -25,6 +25,15 @@ def make_cuda_ext( "-gencode=arch=compute_86,code=sm_86", ] sources += sources_cuda + elif (torch.cuda.is_available() and torch.version.hip is not None) or os.getenv("FORCE_ROCM", "0") == "1": + define_macros += [("WITH_ROCM", None)] + extension = CUDAExtension + extra_compile_args["hipcc"] = extra_args + [ + "-D__HIP_NO_HALF_OPERATORS__", + "-D__HIP_NO_HALF_CONVERSIONS__", + "-D__HIP_NO_HALF2_OPERATORS__", + ] + sources += sources_cuda else: print("Compiling {} without CUDA".format(name)) extension = CppExtension @@ -66,20 +75,20 @@ if __name__ == "__main__": ], sources=[ "src/all.cc", - "src/reordering.cc", + "src/reordering_cpu.cc", "src/reordering_cuda.cu", - "src/indice.cc", + "src/indice_cpu.cc", "src/indice_cuda.cu", - "src/maxpool.cc", + "src/maxpool_cpu.cc", "src/maxpool_cuda.cu", ], - extra_args=["-w", "-std=c++14"], + extra_args=["-w", "-std=c++17"], ), make_cuda_ext( name="bev_pool_ext", 
module="mmdet3d.ops.bev_pool", sources=[ - "src/bev_pool.cpp", + "src/bev_pool_cpu.cpp", "src/bev_pool_cuda.cu", ], ), @@ -117,13 +126,13 @@ if __name__ == "__main__": make_cuda_ext( name="ball_query_ext", module="mmdet3d.ops.ball_query", - sources=["src/ball_query.cpp"], + sources=["src/ball_query_cpu.cpp"], sources_cuda=["src/ball_query_cuda.cu"], ), make_cuda_ext( name="knn_ext", module="mmdet3d.ops.knn", - sources=["src/knn.cpp"], + sources=["src/knn_cpu.cpp"], sources_cuda=["src/knn_cuda.cu"], ), make_cuda_ext( @@ -135,7 +144,7 @@ if __name__ == "__main__": make_cuda_ext( name="group_points_ext", module="mmdet3d.ops.group_points", - sources=["src/group_points.cpp"], + sources=["src/group_points_cpu.cpp"], sources_cuda=["src/group_points_cuda.cu"], ), make_cuda_ext( @@ -147,13 +156,13 @@ if __name__ == "__main__": make_cuda_ext( name="furthest_point_sample_ext", module="mmdet3d.ops.furthest_point_sample", - sources=["src/furthest_point_sample.cpp"], + sources=["src/furthest_point_sample_cpu.cpp"], sources_cuda=["src/furthest_point_sample_cuda.cu"], ), make_cuda_ext( name="gather_points_ext", module="mmdet3d.ops.gather_points", - sources=["src/gather_points.cpp"], + sources=["src/gather_points_cpu.cpp"], sources_cuda=["src/gather_points_cuda.cu"], ), ],