221 lines
7.4 KiB
Python
221 lines
7.4 KiB
Python
import torch
|
||
from torch import nn as nn
|
||
from torch.autograd import Function
|
||
from typing import Tuple
|
||
|
||
from ..ball_query import ball_query
|
||
from ..knn import knn
|
||
from . import group_points_ext
|
||
|
||
|
||
class QueryAndGroup(nn.Module):
    """Gather a neighbourhood of points around every center.

    Neighbours are selected with a ball query bounded by ``min_radius`` and
    ``max_radius``; when ``max_radius`` is None, kNN is used instead.

    Args:
        max_radius (float | None): Outer radius of the ball query. None
            switches the neighbour search to kNN.
        sample_num (int): Number of neighbour indices gathered per center.
        min_radius (float): Inner radius of the ball query. Default: 0.
        use_xyz (bool): Whether the relative xyz offsets are part of the
            grouped features. Default: True.
        return_grouped_xyz (bool): Whether to also return the grouped
            (absolute) xyz. Default: False.
        normalize_xyz (bool): Whether to divide the relative offsets by
            ``max_radius``. Not allowed when ``max_radius`` is None.
            Default: False.
        uniform_sample (bool): Whether to rebuild each group from its unique
            indices plus random re-draws among them. Default: False.
        return_unique_cnt (bool): Whether to also return the number of unique
            indices per group. Requires ``uniform_sample``. Default: False.
        return_grouped_idx (bool): Whether to also return the neighbour
            indices. Default: False.
    """

    def __init__(
        self,
        max_radius,
        sample_num,
        min_radius=0,
        use_xyz=True,
        return_grouped_xyz=False,
        normalize_xyz=False,
        uniform_sample=False,
        return_unique_cnt=False,
        return_grouped_idx=False,
    ):
        super().__init__()
        # Neighbour-search configuration.
        self.max_radius = max_radius
        self.min_radius = min_radius
        self.sample_num = sample_num
        # Feature-assembly configuration.
        self.use_xyz = use_xyz
        self.normalize_xyz = normalize_xyz
        self.uniform_sample = uniform_sample
        # Which optional extras forward() returns.
        self.return_grouped_xyz = return_grouped_xyz
        self.return_unique_cnt = return_unique_cnt
        self.return_grouped_idx = return_grouped_idx

        # The unique count is only computed inside the uniform-sampling pass.
        assert (
            self.uniform_sample or not self.return_unique_cnt
        ), "uniform_sample should be True when returning the count of unique samples"
        # Normalisation divides by max_radius, so it needs an actual radius.
        assert (
            self.max_radius is not None or not self.normalize_xyz
        ), "can not normalize grouped xyz when max_radius is None"

    def forward(self, points_xyz, center_xyz, features=None):
        """Group points (and optionally their features) around the centers.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            center_xyz (Tensor): (B, npoint, 3) Centroids.
            features (Tensor): (B, C, N) Descriptors of the features.

        Return:
            Tensor: (B, 3 + C, npoint, sample_num) Grouped feature. When any
            of ``return_grouped_xyz``, ``return_unique_cnt`` or
            ``return_grouped_idx`` is set, a tuple is returned instead: the
            grouped feature first, then the requested extras in that order.
        """
        # Neighbour indices, shape [B, npoint, sample_num].
        if self.max_radius is not None:
            idx = ball_query(
                self.min_radius, self.max_radius, self.sample_num, points_xyz, center_xyz
            )
        else:
            # No radius given: fall back to kNN and move its output into the
            # same layout the ball query produces.
            idx = knn(self.sample_num, points_xyz, center_xyz, False)
            idx = idx.transpose(1, 2).contiguous()

        if self.uniform_sample:
            num_batches, num_regions = idx.shape[0], idx.shape[1]
            unique_cnt = torch.zeros((num_batches, num_regions))
            for b in range(num_batches):
                for r in range(num_regions):
                    distinct = torch.unique(idx[b, r, :])
                    n_distinct = distinct.shape[0]
                    unique_cnt[b, r] = n_distinct
                    # Fill the remaining slots with random picks among the
                    # distinct indices so the group keeps sample_num entries.
                    refill = torch.randint(
                        0, n_distinct, (self.sample_num - n_distinct,), dtype=torch.long
                    )
                    idx[b, r, :] = torch.cat((distinct, distinct[refill]))

        # (B, 3, npoint, sample_num)
        grouped_xyz = grouping_operation(points_xyz.transpose(1, 2).contiguous(), idx)
        # Offsets of every grouped point relative to its center.
        centers = center_xyz.transpose(1, 2).unsqueeze(-1)
        grouped_xyz_diff = grouped_xyz - centers
        if self.normalize_xyz:
            grouped_xyz_diff /= self.max_radius

        if features is None:
            assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
            new_features = grouped_xyz_diff
        else:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                # (B, C + 3, npoint, sample_num)
                new_features = torch.cat([grouped_xyz_diff, grouped_features], dim=1)
            else:
                new_features = grouped_features

        extras = []
        if self.return_grouped_xyz:
            extras.append(grouped_xyz)
        if self.return_unique_cnt:
            extras.append(unique_cnt)
        if self.return_grouped_idx:
            extras.append(idx)

        if not extras:
            return new_features
        return (new_features, *extras)
|
||
|
||
|
||
class GroupAll(nn.Module):
    """Treat the whole point set as one single group.

    Args:
        use_xyz (bool): Whether the xyz coordinates are concatenated in front
            of the features.
    """

    def __init__(self, use_xyz: bool = True):
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
        """Build the single-group feature tensor.

        Args:
            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            new_xyz (Tensor): Ignored.
            features (Tensor): (B, C, N) features to group.

        Return:
            Tensor: (B, C + 3, 1, N) Grouped feature. Only the xyz part is
            returned when ``features`` is None, and only the feature part
            when ``use_xyz`` is False and features are given.
        """
        # (B, N, 3) -> (B, 3, 1, N): a singleton "group" axis is inserted.
        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is None:
            return grouped_xyz

        # (B, C, N) -> (B, C, 1, N)
        grouped_features = features.unsqueeze(2)
        if not self.use_xyz:
            return grouped_features

        # (B, 3 + C, 1, N)
        return torch.cat([grouped_xyz, grouped_features], dim=1)
|
||
|
||
|
||
class GroupingOperation(Function):
    """Grouping Operation.

    Group feature with given index. Both passes are delegated to the compiled
    ``group_points_ext`` extension.
    """

    @staticmethod
    def forward(ctx, features: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        """forward.

        Args:
            features (Tensor): (B, C, N) tensor of features to group.
                Must be contiguous.
            indices (Tensor): (B, npoint, nsample) the indices of
                features to group with. Must be contiguous.

        Returns:
            Tensor: (B, C, npoint, nsample) Grouped features.
        """
        assert features.is_contiguous()
        assert indices.is_contiguous()

        B, nfeatures, nsample = indices.size()
        _, C, N = features.size()
        # Allocate the (uninitialised) float32 output on the same device as the
        # input instead of on whatever CUDA device happens to be current, so
        # the extension never receives buffers living on different GPUs.
        output = torch.empty(
            (B, C, nfeatures, nsample), dtype=torch.float32, device=features.device
        )

        group_points_ext.forward(B, C, N, nfeatures, nsample, features, indices, output)

        # Indices are not differentiable, so they are stashed on ctx directly
        # together with N, which backward needs to size the gradient.
        ctx.for_backwards = (indices, N)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """backward.

        Args:
            grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
                of the output from forward.

        Returns:
            Tuple[Tensor, None]: (B, C, N) gradient of the features, and None
            for the indices, which receive no gradient.
        """
        idx, N = ctx.for_backwards

        B, C, npoint, nsample = grad_out.size()
        # Zero-initialised accumulator on the gradient's own device (see the
        # device note in forward).
        grad_features = torch.zeros((B, C, N), dtype=torch.float32, device=grad_out.device)

        # The extension gets a contiguous tensor detached from autograd.
        grad_out_data = grad_out.detach().contiguous()
        group_points_ext.backward(B, C, N, npoint, nsample, grad_out_data, idx, grad_features)
        return grad_features, None
|
||
|
||
|
||
# Functional entry point: calling grouping_operation(features, indices) runs
# GroupingOperation through the autograd Function's ``apply``.
grouping_operation = GroupingOperation.apply
|