import torch
from torch.autograd import Function

from . import ball_query_ext


class BallQuery(Function):
    """Ball Query.

    Find nearby points in spherical space.
    """

    @staticmethod
    def forward(
        ctx,
        min_radius: float,
        max_radius: float,
        sample_num: int,
        xyz: torch.Tensor,
        center_xyz: torch.Tensor,
    ) -> torch.Tensor:
        """Collect indices of the points that fall inside each query ball.

        Args:
            min_radius (float): minimum radius of the balls.
            max_radius (float): maximum radius of the balls.
            sample_num (int): maximum number of features in the balls.
            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            center_xyz (Tensor): (B, npoint, 3) centers of the ball query.

        Returns:
            Tensor: (B, npoint, sample_num) int32 tensor with the indices of
                the features that form the query balls.

        Raises:
            AssertionError: if ``xyz`` or ``center_xyz`` is not contiguous,
                or if ``min_radius`` is not strictly less than
                ``max_radius``.
        """
        # The extension is handed the tensors directly, so both inputs must
        # be laid out contiguously in memory.
        assert center_xyz.is_contiguous()
        assert xyz.is_contiguous()
        assert min_radius < max_radius

        B, N, _ = xyz.size()
        npoint = center_xyz.size(1)

        # Allocate the output on the same device as the input points. The
        # legacy ``torch.cuda.IntTensor`` constructor always used the
        # *current* CUDA device, which mismatches inputs that live on a
        # different GPU.
        idx = xyz.new_zeros((B, npoint, sample_num), dtype=torch.int32)

        # ``idx`` is filled in place by the extension.
        ball_query_ext.ball_query_wrapper(
            B, N, npoint, min_radius, max_radius, sample_num, center_xyz, xyz, idx
        )

        # Integer indices carry no gradient.
        ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        """Return no gradients; the op's only output is non-differentiable.

        Autograd expects one entry per non-ctx argument of ``forward``
        (min_radius, max_radius, sample_num, xyz, center_xyz), hence five
        ``None`` values.
        """
        return None, None, None, None, None


ball_query = BallQuery.apply