[Major] Add device guard.
This commit is contained in:
parent
cb6cd789d1
commit
36d34b73eb
|
|
@ -1,4 +1,5 @@
|
|||
#include <torch/torch.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
// CUDA function declarations
|
||||
void bev_pool(int b, int d, int h, int w, int n, int c, int n_intervals, const float* x,
|
||||
|
|
@ -28,6 +29,7 @@ at::Tensor bev_pool_forward(
|
|||
int n = _x.size(0);
|
||||
int c = _x.size(1);
|
||||
int n_intervals = _interval_lengths.size(0);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(_x));
|
||||
const float* x = _x.data_ptr<float>();
|
||||
const int* geom_feats = _geom_feats.data_ptr<int>();
|
||||
const int* interval_lengths = _interval_lengths.data_ptr<int>();
|
||||
|
|
@ -65,6 +67,7 @@ at::Tensor bev_pool_backward(
|
|||
int n = _geom_feats.size(0);
|
||||
int c = _out_grad.size(4);
|
||||
int n_intervals = _interval_lengths.size(0);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(_out_grad));
|
||||
const float* out_grad = _out_grad.data_ptr<float>();
|
||||
const int* geom_feats = _geom_feats.data_ptr<int>();
|
||||
const int* interval_lengths = _interval_lengths.data_ptr<int>();
|
||||
|
|
|
|||
Loading…
Reference in New Issue