import torch
from torch import nn as nn
from torch.autograd import Function
from typing import Optional, Tuple

from ..ball_query import ball_query
from ..knn import knn
from . import group_points_ext


class QueryAndGroup(nn.Module):
    """Query and Group.

    For every center point, selects neighbouring points with a ball query
    bounded by ``min_radius`` and ``max_radius`` (or with kNN when
    ``max_radius`` is None) and gathers their coordinates and features.

    Args:
        max_radius (float | None): The maximum radius of the balls.
            If None is given, we will use kNN sampling instead of ball query.
        sample_num (int): Maximum number of features to gather in the ball.
        min_radius (float): The minimum radius of the balls.
        use_xyz (bool): Whether to use xyz.
            Default: True.
        return_grouped_xyz (bool): Whether to return grouped xyz.
            Default: False.
        normalize_xyz (bool): Whether to normalize xyz.
            Default: False.
        uniform_sample (bool): Whether to sample uniformly.
            Default: False
        return_unique_cnt (bool): Whether to return the count of
            unique samples.
            Default: False.
        return_grouped_idx (bool): Whether to return grouped idx.
            Default: False.
    """

    def __init__(
        self,
        max_radius,
        sample_num,
        min_radius=0,
        use_xyz=True,
        return_grouped_xyz=False,
        normalize_xyz=False,
        uniform_sample=False,
        return_unique_cnt=False,
        return_grouped_idx=False,
    ):
        super(QueryAndGroup, self).__init__()
        self.max_radius = max_radius
        self.min_radius = min_radius
        self.sample_num = sample_num
        self.use_xyz = use_xyz
        self.return_grouped_xyz = return_grouped_xyz
        self.normalize_xyz = normalize_xyz
        self.uniform_sample = uniform_sample
        self.return_unique_cnt = return_unique_cnt
        self.return_grouped_idx = return_grouped_idx
        # unique_cnt is only computed inside the uniform_sample branch of
        # forward(), so it cannot be returned without it.
        if self.return_unique_cnt:
            assert self.uniform_sample, (
                "uniform_sample should be True when "
                "returning the count of unique samples"
            )
        # Normalization divides by max_radius, which kNN mode does not have.
        if self.max_radius is None:
            assert not self.normalize_xyz, "can not normalize grouped xyz when max_radius is None"

    def forward(self, points_xyz, center_xyz, features=None):
        """forward.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            center_xyz (Tensor): (B, npoint, 3) Centroids.
            features (Tensor): (B, C, N) Descriptors of the features.

        Return:
            Tensor: (B, 3 + C, npoint, sample_num) Grouped feature
            ((B, C, npoint, sample_num) when ``use_xyz`` is False, and
            (B, 3, npoint, sample_num) when ``features`` is None). If any of
            ``return_grouped_xyz`` / ``return_unique_cnt`` /
            ``return_grouped_idx`` is set, a tuple is returned instead, with
            the requested extras appended in that order.
        """
        # if self.max_radius is None, we will perform kNN instead of ball query
        # idx is of shape [B, npoint, sample_num]
        if self.max_radius is None:
            idx = knn(self.sample_num, points_xyz, center_xyz, False)
            idx = idx.transpose(1, 2).contiguous()
        else:
            idx = ball_query(
                self.min_radius, self.max_radius, self.sample_num, points_xyz, center_xyz
            )

        if self.uniform_sample:
            # Per-region count of distinct neighbour indices (CPU tensor).
            unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
            for i_batch in range(idx.shape[0]):
                for i_region in range(idx.shape[1]):
                    unique_ind = torch.unique(idx[i_batch, i_region, :])
                    num_unique = unique_ind.shape[0]
                    unique_cnt[i_batch, i_region] = num_unique
                    # Fill the remaining slots by re-drawing random entries
                    # from the distinct indices rather than repeating one.
                    sample_ind = torch.randint(
                        0, num_unique, (self.sample_num - num_unique,), dtype=torch.long
                    )
                    all_ind = torch.cat((unique_ind, unique_ind[sample_ind]))
                    idx[i_batch, i_region, :] = all_ind

        xyz_trans = points_xyz.transpose(1, 2).contiguous()
        # (B, 3, npoint, sample_num)
        grouped_xyz = grouping_operation(xyz_trans, idx)
        # relative offsets of each grouped point from its center
        grouped_xyz_diff = grouped_xyz - center_xyz.transpose(1, 2).unsqueeze(-1)
        if self.normalize_xyz:
            grouped_xyz_diff /= self.max_radius

        if features is not None:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                # (B, C + 3, npoint, sample_num)
                new_features = torch.cat([grouped_xyz_diff, grouped_features], dim=1)
            else:
                new_features = grouped_features
        else:
            assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
            new_features = grouped_xyz_diff

        ret = [new_features]
        if self.return_grouped_xyz:
            ret.append(grouped_xyz)
        if self.return_unique_cnt:
            ret.append(unique_cnt)
        if self.return_grouped_idx:
            ret.append(idx)
        if len(ret) == 1:
            return ret[0]
        else:
            return tuple(ret)


class GroupAll(nn.Module):
    """Group All.

    Group xyz with feature.

    Args:
        use_xyz (bool): Whether to use xyz.
    """

    def __init__(self, use_xyz: bool = True):
        super().__init__()
        self.use_xyz = use_xyz

    def forward(
        self,
        xyz: torch.Tensor,
        new_xyz: torch.Tensor,
        features: Optional[torch.Tensor] = None,
    ):
        """forward.

        Args:
            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            new_xyz (Tensor): Ignored.
            features (Tensor): (B, C, N) features to group.

        Return:
            Tensor: (B, C + 3, 1, N) Grouped feature ((B, C, 1, N) when
            ``use_xyz`` is False, (B, 3, 1, N) when ``features`` is None).
        """
        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is not None:
            grouped_features = features.unsqueeze(2)
            if self.use_xyz:
                # (B, 3 + C, 1, N)
                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)
            else:
                new_features = grouped_features
        else:
            new_features = grouped_xyz

        return new_features


class GroupingOperation(Function):
    """Grouping Operation.

    Group feature with given index.
    """

    @staticmethod
    def forward(ctx, features: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        """forward.

        Args:
            features (Tensor): (B, C, N) tensor of features to group.
            indices (Tensor): (B, npoint, nsample) the indices of
                features to group with.

        Returns:
            Tensor: (B, C, npoint, nsample) Grouped features.
        """
        assert features.is_contiguous()
        assert indices.is_contiguous()

        B, nfeatures, nsample = indices.size()
        _, C, N = features.size()
        # Allocate on the device that holds `features`. The legacy
        # torch.cuda.FloatTensor constructor always used the *current* CUDA
        # device, which is wrong when the inputs live on another GPU.
        # float32 is kept to match the original allocation.
        output = torch.empty(
            (B, C, nfeatures, nsample), dtype=torch.float32, device=features.device
        )

        group_points_ext.forward(B, C, N, nfeatures, nsample, features, indices, output)

        # indices are integer and carry no gradient; N is needed to size the
        # gradient buffer in backward.
        ctx.for_backwards = (indices, N)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """backward.

        Args:
            grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
                of the output from forward.

        Returns:
            Tensor: (B, C, N) gradient of the features.
            None: no gradient for the indices.
        """
        idx, N = ctx.for_backwards

        B, C, npoint, nsample = grad_out.size()
        # Must start at zero: several grouped slots can point at the same
        # source feature, so the extension accumulates into this buffer.
        # Allocated on grad_out's device for the same reason as in forward.
        grad_features = torch.zeros((B, C, N), dtype=torch.float32, device=grad_out.device)

        grad_out_data = grad_out.data.contiguous()
        group_points_ext.backward(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
        return grad_features, None


grouping_operation = GroupingOperation.apply