[ROCm] add support for ROCm/HIP (#509)
* [ROCm] add support for ROCm/HIP
  - do not use THC.h, THH.h, or THCState
  - rename files to avoid collisions at build time due to hipify
  - add HIP host/device definitions in tensorview.h
  - use HIP atomicAdd() in scatter_points_cuda.cu
  - add WITH_ROCM everywhere WITH_CUDA was used
  - update setup.py for the renamed files and ROCm/HIP torch
* enforce C++17, since PyTorch requires it
This commit is contained in:
parent
5b08cc8cee
commit
326653dc06
|
|
@ -132,3 +132,7 @@ dmypy.json
|
|||
models/*
|
||||
data/*
|
||||
runs/*
|
||||
|
||||
# torch hipify generated files
|
||||
*_hip.*
|
||||
*.hip
|
||||
|
|
|
|||
|
|
@ -1,16 +1,14 @@
|
|||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query.cpp
|
||||
|
||||
#include <THC/THC.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/serialize/tensor.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
extern THCState *state;
|
||||
|
||||
// Raises a TORCH_CHECK error naming the argument unless x reports is_cuda().
// Uses Tensor::is_cuda() directly instead of the deprecated x.type() accessor.
#define CHECK_CUDA(x) \
|
||||
TORCH_CHECK(x.is_cuda(), #x, " must be a CUDAtensor ")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
|
|
@ -2,14 +2,11 @@
|
|||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/sampling.cpp
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <THC/THC.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/serialize/tensor.h>
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <vector>
|
||||
|
||||
extern THCState *state;
|
||||
|
||||
int furthest_point_sampling_wrapper(int b, int n, int m,
|
||||
at::Tensor points_tensor,
|
||||
at::Tensor temp_tensor,
|
||||
|
|
@ -1,12 +1,10 @@
|
|||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <THC/THC.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/serialize/tensor.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
extern THCState *state;
|
||||
|
||||
int gather_points_wrapper(int b, int c, int n, int npoints,
|
||||
at::Tensor points_tensor, at::Tensor idx_tensor,
|
||||
at::Tensor out_tensor);
|
||||
|
|
@ -1,16 +1,14 @@
|
|||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points.cpp
|
||||
|
||||
#include <THC/THC.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/serialize/tensor.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
extern THCState *state;
|
||||
|
||||
int group_points_wrapper(int b, int c, int n, int npoints, int nsample,
|
||||
at::Tensor points_tensor, at::Tensor idx_tensor,
|
||||
at::Tensor out_tensor);
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate.cpp
|
||||
|
||||
#include <THC/THC.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <math.h>
|
||||
|
|
@ -9,11 +8,10 @@
|
|||
#include <stdlib.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/serialize/tensor.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
extern THCState *state;
|
||||
|
||||
void three_nn_wrapper(int b, int n, int m, at::Tensor unknown_tensor,
|
||||
at::Tensor known_tensor, at::Tensor dist2_tensor,
|
||||
at::Tensor idx_tensor);
|
||||
|
|
|
|||
|
|
@ -3,10 +3,8 @@
|
|||
#include <torch/serialize/tensor.h>
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
#include <THC/THC.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
extern THCState *state;
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x, " must be a CUDAtensor ")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
|
||||
|
|
@ -27,7 +27,7 @@
|
|||
|
||||
namespace tv {
|
||||
|
||||
#ifdef __NVCC__
|
||||
#if defined(__NVCC__) || defined(__HIP__)
|
||||
#define TV_HOST_DEVICE_INLINE __forceinline__ __device__ __host__
|
||||
#define TV_DEVICE_INLINE __forceinline__ __device__
|
||||
#define TV_HOST_DEVICE __device__ __host__
|
||||
|
|
|
|||
|
|
@ -75,6 +75,13 @@ __device__ __forceinline__ static void reduceAdd(double *address, double val) {
|
|||
atomicAdd(address, val);
|
||||
#endif
|
||||
}
|
||||
#elif defined(__HIP__)
|
||||
// HIP build: single-precision reduceAdd, forwards directly to atomicAdd(address, val).
__device__ __forceinline__ static void reduceAdd(float *address, float val) {
|
||||
atomicAdd(address, val);
|
||||
}
|
||||
// HIP build: double-precision reduceAdd, forwards directly to atomicAdd(address, val).
__device__ __forceinline__ static void reduceAdd(double *address, double val) {
|
||||
atomicAdd(address, val);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ std::vector<at::Tensor> dynamic_point_to_voxel_cpu(
|
|||
const at::Tensor &points, const at::Tensor &voxel_mapping,
|
||||
const std::vector<float> voxel_size, const std::vector<float> coors_range);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#if defined(WITH_CUDA) || defined(WITH_ROCM)
|
||||
int hard_voxelize_gpu(const at::Tensor &points, at::Tensor &voxels,
|
||||
at::Tensor &coors, at::Tensor &num_points_per_voxel,
|
||||
const std::vector<float> voxel_size,
|
||||
|
|
@ -62,7 +62,7 @@ inline int hard_voxelize(const at::Tensor &points, at::Tensor &voxels,
|
|||
const int max_points, const int max_voxels,
|
||||
const int NDim = 3, const bool deterministic = true) {
|
||||
if (points.device().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
#if defined(WITH_CUDA) || defined(WITH_ROCM)
|
||||
if (deterministic) {
|
||||
return hard_voxelize_gpu(points, voxels, coors, num_points_per_voxel,
|
||||
voxel_size, coors_range, max_points, max_voxels,
|
||||
|
|
@ -85,7 +85,7 @@ inline void dynamic_voxelize(const at::Tensor &points, at::Tensor &coors,
|
|||
const std::vector<float> coors_range,
|
||||
const int NDim = 3) {
|
||||
if (points.device().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
#if defined(WITH_CUDA) || defined(WITH_ROCM)
|
||||
return dynamic_voxelize_gpu(points, coors, voxel_size, coors_range, NDim);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
|
|
@ -109,7 +109,7 @@ inline std::vector<torch::Tensor> dynamic_point_to_voxel_forward(const torch::Te
|
|||
const torch::Tensor &coors,
|
||||
const std::string &reduce_type) {
|
||||
if (feats.device().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
#if defined(WITH_CUDA) || defined(WITH_ROCM)
|
||||
return dynamic_point_to_voxel_forward_gpu(feats, coors, convert_reduce_type(reduce_type));
|
||||
#else
|
||||
TORCH_CHECK(false, "Not compiled with GPU support");
|
||||
|
|
@ -127,7 +127,7 @@ inline void dynamic_point_to_voxel_backward(torch::Tensor &grad_feats,
|
|||
const torch::Tensor &reduce_count,
|
||||
const std::string &reduce_type) {
|
||||
if (grad_feats.device().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
#if defined(WITH_CUDA) || defined(WITH_ROCM)
|
||||
dynamic_point_to_voxel_backward_gpu(
|
||||
grad_feats, grad_reduced_feats, feats, reduced_feats, coors_idx, reduce_count,
|
||||
convert_reduce_type(reduce_type));
|
||||
|
|
|
|||
31
setup.py
31
setup.py
|
|
@ -12,7 +12,7 @@ def make_cuda_ext(
|
|||
define_macros = []
|
||||
extra_compile_args = {"cxx": [] + extra_args}
|
||||
|
||||
if torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1":
|
||||
if (torch.cuda.is_available() and torch.version.cuda is not None) or os.getenv("FORCE_CUDA", "0") == "1":
|
||||
define_macros += [("WITH_CUDA", None)]
|
||||
extension = CUDAExtension
|
||||
extra_compile_args["nvcc"] = extra_args + [
|
||||
|
|
@ -25,6 +25,15 @@ def make_cuda_ext(
|
|||
"-gencode=arch=compute_86,code=sm_86",
|
||||
]
|
||||
sources += sources_cuda
|
||||
elif (torch.cuda.is_available() and torch.version.hip is not None) or os.getenv("FORCE_ROCM", "0") == 1:
|
||||
define_macros += [("WITH_ROCM", None)]
|
||||
extension = CUDAExtension
|
||||
extra_compile_args["hipcc"] = extra_args + [
|
||||
"-D__HIP_NO_HALF_OPERATORS__",
|
||||
"-D__HIP_NO_HALF_CONVERSIONS__",
|
||||
"-D__HIP_NO_HALF2_OPERATORS__",
|
||||
]
|
||||
sources += sources_cuda
|
||||
else:
|
||||
print("Compiling {} without CUDA".format(name))
|
||||
extension = CppExtension
|
||||
|
|
@ -66,20 +75,20 @@ if __name__ == "__main__":
|
|||
],
|
||||
sources=[
|
||||
"src/all.cc",
|
||||
"src/reordering.cc",
|
||||
"src/reordering_cpu.cc",
|
||||
"src/reordering_cuda.cu",
|
||||
"src/indice.cc",
|
||||
"src/indice_cpu.cc",
|
||||
"src/indice_cuda.cu",
|
||||
"src/maxpool.cc",
|
||||
"src/maxpool_cpu.cc",
|
||||
"src/maxpool_cuda.cu",
|
||||
],
|
||||
extra_args=["-w", "-std=c++14"],
|
||||
extra_args=["-w", "-std=c++17"],
|
||||
),
|
||||
make_cuda_ext(
|
||||
name="bev_pool_ext",
|
||||
module="mmdet3d.ops.bev_pool",
|
||||
sources=[
|
||||
"src/bev_pool.cpp",
|
||||
"src/bev_pool_cpu.cpp",
|
||||
"src/bev_pool_cuda.cu",
|
||||
],
|
||||
),
|
||||
|
|
@ -117,13 +126,13 @@ if __name__ == "__main__":
|
|||
make_cuda_ext(
|
||||
name="ball_query_ext",
|
||||
module="mmdet3d.ops.ball_query",
|
||||
sources=["src/ball_query.cpp"],
|
||||
sources=["src/ball_query_cpu.cpp"],
|
||||
sources_cuda=["src/ball_query_cuda.cu"],
|
||||
),
|
||||
make_cuda_ext(
|
||||
name="knn_ext",
|
||||
module="mmdet3d.ops.knn",
|
||||
sources=["src/knn.cpp"],
|
||||
sources=["src/knn_cpu.cpp"],
|
||||
sources_cuda=["src/knn_cuda.cu"],
|
||||
),
|
||||
make_cuda_ext(
|
||||
|
|
@ -135,7 +144,7 @@ if __name__ == "__main__":
|
|||
make_cuda_ext(
|
||||
name="group_points_ext",
|
||||
module="mmdet3d.ops.group_points",
|
||||
sources=["src/group_points.cpp"],
|
||||
sources=["src/group_points_cpu.cpp"],
|
||||
sources_cuda=["src/group_points_cuda.cu"],
|
||||
),
|
||||
make_cuda_ext(
|
||||
|
|
@ -147,13 +156,13 @@ if __name__ == "__main__":
|
|||
make_cuda_ext(
|
||||
name="furthest_point_sample_ext",
|
||||
module="mmdet3d.ops.furthest_point_sample",
|
||||
sources=["src/furthest_point_sample.cpp"],
|
||||
sources=["src/furthest_point_sample_cpu.cpp"],
|
||||
sources_cuda=["src/furthest_point_sample_cuda.cu"],
|
||||
),
|
||||
make_cuda_ext(
|
||||
name="gather_points_ext",
|
||||
module="mmdet3d.ops.gather_points",
|
||||
sources=["src/gather_points.cpp"],
|
||||
sources=["src/gather_points_cpu.cpp"],
|
||||
sources_cuda=["src/gather_points_cuda.cu"],
|
||||
),
|
||||
],
|
||||
|
|
|
|||
Loading…
Reference in New Issue