from torch.autograd import Function

from . import assign_score_withk_ext

# Maps the user-facing ``aggregate`` names to the integer codes passed to the
# extension's forward/backward wrappers.
_AGGREGATE_MODES = {"sum": 0, "avg": 1, "max": 2}


class AssignScoreWithK(Function):
    r"""Perform weighted sum to generate output features according to scores.

    Modified from `PAConv <https://github.com/CVMI-Lab/PAConv>`_.

    This is a memory-efficient CUDA implementation of the assign_scores
    operation: it first transforms all point features with the weight bank,
    then assembles neighbor features with ``knn_idx`` and performs a weighted
    sum with ``scores``.

    See the `paper <https://arxiv.org/abs/2103.14635>`_ appendix Sec. D for
    more detailed descriptions.

    Note:
        This implementation assumes using ``neighbor`` kernel input, which is
        (point_features - center_features, point_features).
        See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
        pointnet2/paconv.py#L128 for more details.
    """

    @staticmethod
    def forward(ctx, scores, point_features, center_features, knn_idx,
                aggregate="sum"):
        """Forward.

        Args:
            scores (torch.Tensor): (B, npoint, K, M), predicted scores to
                aggregate weight matrices in the weight bank.
                ``npoint`` is the number of sampled centers.
                ``K`` is the number of queried neighbors.
                ``M`` is the number of weight matrices in the weight bank.
            point_features (torch.Tensor): (B, N, M, out_dim)
                Pre-computed point features to be aggregated.
            center_features (torch.Tensor): (B, N, M, out_dim)
                Pre-computed center features to be aggregated.
            knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
                We assume the first idx in each row is the idx of the center.
            aggregate (str, optional): Aggregation method.
                Can be 'sum', 'avg' or 'max'. Defaults to 'sum'.

        Returns:
            torch.Tensor: (B, out_dim, npoint, K), the aggregated features.

        Raises:
            ValueError: If ``aggregate`` is not one of 'sum', 'avg', 'max'.
        """
        try:
            agg_mode = _AGGREGATE_MODES[aggregate]
        except KeyError:
            raise ValueError(
                f"aggregate must be one of {sorted(_AGGREGATE_MODES)}, "
                f"got {aggregate!r}") from None

        B, N, M, out_dim = point_features.size()
        _, npoint, K, _ = scores.size()

        # The extension writes its result into this pre-allocated buffer.
        output = point_features.new_zeros((B, out_dim, npoint, K))
        assign_score_withk_ext.assign_score_withk_forward_wrapper(
            B, N, npoint, M, K, out_dim, agg_mode,
            point_features.contiguous(),
            center_features.contiguous(),
            scores.contiguous(),
            knn_idx.contiguous(),
            output,
        )

        # ``output`` is not needed by backward, so it is deliberately not
        # saved: keeping it would hold an extra (B, out_dim, npoint, K) buffer
        # alive and would make in-place edits of the returned tensor trip
        # autograd's saved-tensor version check.
        ctx.save_for_backward(point_features, center_features, scores,
                              knn_idx)
        ctx.agg = agg_mode
        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Backward.

        Args:
            grad_out (torch.Tensor): (B, out_dim, npoint, K)

        Returns:
            grad_scores (torch.Tensor): (B, npoint, K, M)
            grad_point_features (torch.Tensor): (B, N, M, out_dim)
            grad_center_features (torch.Tensor): (B, N, M, out_dim)
            None: no gradient for ``knn_idx``.
            None: no gradient for ``aggregate``.
        """
        point_features, center_features, scores, knn_idx = ctx.saved_tensors
        agg_mode = ctx.agg

        B, N, M, out_dim = point_features.size()
        _, npoint, K, _ = scores.size()

        # The extension writes the gradients into these zero-initialised
        # buffers.
        grad_point_features = point_features.new_zeros(point_features.shape)
        grad_center_features = center_features.new_zeros(
            center_features.shape)
        grad_scores = scores.new_zeros(scores.shape)

        assign_score_withk_ext.assign_score_withk_backward_wrapper(
            B, N, npoint, M, K, out_dim, agg_mode,
            grad_out.contiguous(),
            point_features.contiguous(),
            center_features.contiguous(),
            scores.contiguous(),
            knn_idx.contiguous(),
            grad_point_features,
            grad_center_features,
            grad_scores,
        )

        # One entry per forward input: knn_idx and aggregate get no gradient.
        return (grad_scores, grad_point_features, grad_center_features, None,
                None)


assign_score_withk = AssignScoreWithK.apply