import torch
from mmcv.cnn import ConvModule
from torch import nn as nn
from torch.nn import functional as F

from mmdet3d.ops import GroupAll, PAConv, Points_Sampler, QueryAndGroup, gather_points

from .builder import SA_MODULES


class BasePointSAModule(nn.Module):
    """Base module for point set abstraction module used in PointNets.

    Args:
        num_point (int): Number of points.
        radii (list[float]): List of radius in each ball query.
        sample_nums (list[int]): Number of samples in each ball query.
        mlp_channels (list[list[int]]): Channel specification of the
            pointnet (shared MLP) applied before the global pooling,
            for each scale.
        fps_mod (list[str]): Type of FPS method, valid modes are
            ['F-FPS', 'D-FPS', 'FS']. Default: ['D-FPS'].
            F-FPS: using feature distances for FPS.
            D-FPS: using Euclidean distances of points for FPS.
            FS: using F-FPS and D-FPS simultaneously.
        fps_sample_range_list (list[int]): Range of points to apply FPS.
            Default: [-1].
        dilated_group (bool): Whether to use dilated ball query.
            Default: False.
        use_xyz (bool): Whether to use xyz coordinates as features.
            Default: True.
        pool_mod (str): Type of pooling method, 'max' or 'avg'.
            Default: 'max'.
        normalize_xyz (bool): Whether to normalize local XYZ with radius.
            Default: False.
        grouper_return_grouped_xyz (bool): Whether to return grouped xyz
            in `QueryAndGroup`. Defaults to False.
        grouper_return_grouped_idx (bool): Whether to return grouped idx
            in `QueryAndGroup`. Defaults to False.
    """

    def __init__(
        self,
        num_point,
        radii,
        sample_nums,
        mlp_channels,
        fps_mod=["D-FPS"],
        fps_sample_range_list=[-1],
        dilated_group=False,
        use_xyz=True,
        pool_mod="max",
        normalize_xyz=False,
        grouper_return_grouped_xyz=False,
        grouper_return_grouped_idx=False,
    ):
        super(BasePointSAModule, self).__init__()

        assert len(radii) == len(sample_nums) == len(mlp_channels)
        assert pool_mod in ["max", "avg"]
        assert isinstance(fps_mod, list) or isinstance(fps_mod, tuple)
        assert isinstance(fps_sample_range_list, list) or isinstance(fps_sample_range_list, tuple)
        assert len(fps_mod) == len(fps_sample_range_list)

        if isinstance(mlp_channels, tuple):
            mlp_channels = list(map(list, mlp_channels))
        self.mlp_channels = mlp_channels

        if isinstance(num_point, int):
            self.num_point = [num_point]
        elif isinstance(num_point, list) or isinstance(num_point, tuple):
            self.num_point = num_point
        else:
            raise NotImplementedError("Error type of num_point!")

        self.pool_mod = pool_mod
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        self.fps_mod_list = fps_mod
        self.fps_sample_range_list = fps_sample_range_list

        self.points_sampler = Points_Sampler(
            self.num_point, self.fps_mod_list, self.fps_sample_range_list
        )

        for i in range(len(radii)):
            radius = radii[i]
            sample_num = sample_nums[i]
            if num_point is not None:
                if dilated_group and i != 0:
                    min_radius = radii[i - 1]
                else:
                    min_radius = 0
                grouper = QueryAndGroup(
                    radius,
                    sample_num,
                    min_radius=min_radius,
                    use_xyz=use_xyz,
                    normalize_xyz=normalize_xyz,
                    return_grouped_xyz=grouper_return_grouped_xyz,
                    return_grouped_idx=grouper_return_grouped_idx,
                )
            else:
                grouper = GroupAll(use_xyz)
            self.groupers.append(grouper)

    def _sample_points(self, points_xyz, features, indices, target_xyz):
        """Perform point sampling based on inputs.

        If `indices` is specified, directly sample the corresponding points.
        Else if `target_xyz` is specified, use it as the sampled points.
        Otherwise, sample points using `self.points_sampler`.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            features (Tensor): (B, C, N) features of each point.
                Default: None.
            indices (Tensor): (B, num_point) Index of the features.
                Default: None.
            target_xyz (Tensor): (B, M, 3) new_xyz coordinates of the outputs.

        Returns:
            Tensor: (B, num_point, 3) sampled xyz coordinates of points.
            Tensor: (B, num_point) sampled points' index.
        """
        xyz_flipped = points_xyz.transpose(1, 2).contiguous()
        if indices is not None:
            assert indices.shape[1] == self.num_point[0]
            new_xyz = (
                gather_points(xyz_flipped, indices).transpose(1, 2).contiguous()
                if self.num_point is not None
                else None
            )
        elif target_xyz is not None:
            new_xyz = target_xyz.contiguous()
        else:
            indices = self.points_sampler(points_xyz, features)
            new_xyz = (
                gather_points(xyz_flipped, indices).transpose(1, 2).contiguous()
                if self.num_point is not None
                else None
            )

        return new_xyz, indices

    def _pool_features(self, features):
        """Perform feature aggregation using pooling operation.

        Args:
            features (torch.Tensor): (B, C, N, K)
                Features of locally grouped points before pooling.

        Returns:
            torch.Tensor: (B, C, N)
                Pooled features aggregating local information.
        """
        if self.pool_mod == "max":
            # (B, C, N, 1)
            new_features = F.max_pool2d(features, kernel_size=[1, features.size(3)])
        elif self.pool_mod == "avg":
            # (B, C, N, 1)
            new_features = F.avg_pool2d(features, kernel_size=[1, features.size(3)])
        else:
            raise NotImplementedError

        return new_features.squeeze(-1).contiguous()

    def forward(
        self,
        points_xyz,
        features=None,
        indices=None,
        target_xyz=None,
    ):
        """Forward pass.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            features (Tensor): (B, C, N) features of each point.
                Default: None.
            indices (Tensor): (B, num_point) Index of the features.
                Default: None.
            target_xyz (Tensor): (B, M, 3) new_xyz coordinates of the outputs.

        Returns:
            Tensor: (B, M, 3) where M is the number of sampled points.
                New features xyz.
            Tensor: (B, sum_k(mlps[k][-1]), M) where M is the number
                of sampled points. New feature descriptors.
            Tensor: (B, M) where M is the number of sampled points.
                Index of the features.
        """
        new_features_list = []

        # sample points, (B, num_point, 3), (B, num_point)
        new_xyz, indices = self._sample_points(points_xyz, features, indices, target_xyz)

        for i in range(len(self.groupers)):
            # grouped_results may contain:
            # - grouped_features: (B, C, num_point, nsample)
            # - grouped_xyz: (B, 3, num_point, nsample)
            # - grouped_idx: (B, num_point, nsample)
            grouped_results = self.groupers[i](points_xyz, new_xyz, features)

            # (B, mlp[-1], num_point, nsample)
            new_features = self.mlps[i](grouped_results)

            # this is a bit hacky because PAConv outputs two values;
            # we take the first one as the feature
            if isinstance(self.mlps[i][0], PAConv):
                assert isinstance(new_features, tuple)
                new_features = new_features[0]

            # (B, mlp[-1], num_point)
            new_features = self._pool_features(new_features)
            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=1), indices


@SA_MODULES.register_module()
class PointSAModuleMSG(BasePointSAModule):
    """Point set abstraction module with multi-scale grouping (MSG) used in
    PointNets.

    Args:
        num_point (int): Number of points.
        radii (list[float]): List of radius in each ball query.
        sample_nums (list[int]): Number of samples in each ball query.
        mlp_channels (list[list[int]]): Channel specification of the
            pointnet (shared MLP) applied before the global pooling,
            for each scale.
        fps_mod (list[str]): Type of FPS method, valid modes are
            ['F-FPS', 'D-FPS', 'FS']. Default: ['D-FPS'].
            F-FPS: using feature distances for FPS.
            D-FPS: using Euclidean distances of points for FPS.
            FS: using F-FPS and D-FPS simultaneously.
        fps_sample_range_list (list[int]): Range of points to apply FPS.
            Default: [-1].
        dilated_group (bool): Whether to use dilated ball query.
            Default: False.
        norm_cfg (dict): Type of normalization method.
            Default: dict(type='BN2d').
        use_xyz (bool): Whether to use xyz coordinates as features.
            Default: True.
        pool_mod (str): Type of pooling method, 'max' or 'avg'.
            Default: 'max'.
        normalize_xyz (bool): Whether to normalize local XYZ with radius.
            Default: False.
        bias (bool | str): If specified as `auto`, it will be decided by
            `norm_cfg`. Bias will be set as True if `norm_cfg` is None,
            otherwise False. Default: 'auto'.
    """

    def __init__(
        self,
        num_point,
        radii,
        sample_nums,
        mlp_channels,
        fps_mod=["D-FPS"],
        fps_sample_range_list=[-1],
        dilated_group=False,
        norm_cfg=dict(type="BN2d"),
        use_xyz=True,
        pool_mod="max",
        normalize_xyz=False,
        bias="auto",
    ):
        super(PointSAModuleMSG, self).__init__(
            num_point=num_point,
            radii=radii,
            sample_nums=sample_nums,
            mlp_channels=mlp_channels,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            dilated_group=dilated_group,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            normalize_xyz=normalize_xyz,
        )

        for i in range(len(self.mlp_channels)):
            mlp_channel = self.mlp_channels[i]
            if use_xyz:
                mlp_channel[0] += 3

            mlp = nn.Sequential()
            for i in range(len(mlp_channel) - 1):
                mlp.add_module(
                    f"layer{i}",
                    ConvModule(
                        mlp_channel[i],
                        mlp_channel[i + 1],
                        kernel_size=(1, 1),
                        stride=(1, 1),
                        conv_cfg=dict(type="Conv2d"),
                        norm_cfg=norm_cfg,
                        bias=bias,
                    ),
                )
            self.mlps.append(mlp)


@SA_MODULES.register_module()
class PointSAModule(PointSAModuleMSG):
    """Point set abstraction module with single-scale grouping (SSG) used in
    PointNets.

    Args:
        mlp_channels (list[int]): Channel specification of the pointnet
            (shared MLP) applied before the global pooling.
        num_point (int): Number of points.
            Default: None.
        radius (float): Radius to group with.
            Default: None.
        num_sample (int): Number of samples in each ball query.
            Default: None.
        norm_cfg (dict): Type of normalization method.
            Default: dict(type='BN2d').
        use_xyz (bool): Whether to use xyz coordinates as features.
            Default: True.
        pool_mod (str): Type of pooling method, 'max' or 'avg'.
            Default: 'max'.
        fps_mod (list[str]): Type of FPS method, valid modes are
            ['F-FPS', 'D-FPS', 'FS']. Default: ['D-FPS'].
        fps_sample_range_list (list[int]): Range of points to apply FPS.
            Default: [-1].
        normalize_xyz (bool): Whether to normalize local XYZ with radius.
            Default: False.
    """

    def __init__(
        self,
        mlp_channels,
        num_point=None,
        radius=None,
        num_sample=None,
        norm_cfg=dict(type="BN2d"),
        use_xyz=True,
        pool_mod="max",
        fps_mod=["D-FPS"],
        fps_sample_range_list=[-1],
        normalize_xyz=False,
    ):
        super(PointSAModule, self).__init__(
            mlp_channels=[mlp_channels],
            num_point=num_point,
            radii=[radius],
            sample_nums=[num_sample],
            norm_cfg=norm_cfg,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            normalize_xyz=normalize_xyz,
        )
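

# The snippet below is an illustrative usage sketch, not part of the original
# module. It assumes a GPU build of mmdet3d (the sampling/grouping ops are
# CUDA kernels) and only demonstrates the expected tensor shapes; the example
# channel/point numbers are arbitrary.
if __name__ == "__main__":
    B, N, C = 2, 1024, 4
    xyz = torch.rand(B, N, 3).cuda()  # (B, N, 3) point coordinates
    features = torch.rand(B, C, N).cuda()  # (B, C, N) per-point features

    # Single-scale grouping: one radius, one shared MLP.
    # With use_xyz=True, the first MLP channel is increased by 3 internally.
    ssg = PointSAModule(
        mlp_channels=[C, 32, 64],
        num_point=256,
        radius=0.4,
        num_sample=32,
    ).cuda()
    new_xyz, new_features, indices = ssg(xyz, features)
    # new_xyz: (B, 256, 3), new_features: (B, 64, 256), indices: (B, 256)

    # Multi-scale grouping: per-scale radii and MLPs; the pooled features of
    # all scales are concatenated along the channel dimension.
    msg = PointSAModuleMSG(
        num_point=256,
        radii=[0.2, 0.4],
        sample_nums=[16, 32],
        mlp_channels=[[C, 32, 64], [C, 64, 128]],
    ).cuda()
    new_xyz, new_features, indices = msg(xyz, features)
    # new_features: (B, 64 + 128, 256)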