349 lines
12 KiB
Python
349 lines
12 KiB
Python
|
|
import torch
|
||
|
|
from torch import nn as nn
|
||
|
|
|
||
|
|
from mmdet3d.ops import PAConv, PAConvCUDA
|
||
|
|
from .builder import SA_MODULES
|
||
|
|
from .point_sa_module import BasePointSAModule
|
||
|
|
|
||
|
|
|
||
|
|
@SA_MODULES.register_module()
class PAConvSAModuleMSG(BasePointSAModule):
    r"""Point set abstraction module with multi-scale grouping (MSG) used in
    PAConv networks.

    Replace the MLPs in `PointSAModuleMSG` with PAConv layers.
    See the `paper <https://arxiv.org/abs/2103.14635>`_ for more details.

    Args:
        num_point, radii, sample_nums, fps_mod, fps_sample_range_list,
            dilated_group, use_xyz, pool_mod, normalize_xyz: Forwarded
            unchanged to ``BasePointSAModule.__init__``; see that class for
            their semantics.
        mlp_channels (list[list[int]]): Channel sizes of the PAConv stack of
            each scale. When ``use_xyz`` is True the first entry of every
            scale is increased by 3 (in place on ``self.mlp_channels``).
        paconv_num_kernels (list[list[int]]): Number of kernel weights in the
            weight banks of each layer's PAConv. Scale ``k`` must provide
            ``len(mlp_channels[k]) - 1`` entries.
        norm_cfg (dict, optional): Normalization config handed to every
            PAConv layer. Defaults to dict(type='BN2d', momentum=0.1).
        bias (bool | str, optional): Stored as ``scorenet_cfg['bias']``;
            in PAConv, bias only exists in ScoreNet. Defaults to 'auto'.
        paconv_kernel_input (str, optional): Input features to be multiplied
            with kernel weights. Can be 'identity' or 'w_neighbor'.
            Defaults to 'w_neighbor'.
        scorenet_input (str, optional): Type of the input to ScoreNet.
            Defaults to 'w_neighbor_dist'. Can be the following values:

            - 'identity': Use xyz coordinates as input.
            - 'w_neighbor': Use xyz coordinates and the difference with center
                points as input.
            - 'w_neighbor_dist': Use xyz coordinates, the difference with
                center points and the Euclidean distance as input.

        scorenet_cfg (dict, optional): Config of the ScoreNet module, which
            may contain the following keys and values:

            - mlp_channels (List[int]): Hidden units of MLPs.
            - score_norm (str): Normalization function of output scores.
                Can be 'softmax', 'sigmoid' or 'identity'.
            - temp_factor (float): Temperature factor to scale the output
                scores before softmax.
            - last_bn (bool): Whether to use BN on the last output of mlps.

            The dict is copied before ``bias`` is added, so neither the
            caller's dict nor the shared default is modified.
    """

    def __init__(
        self,
        num_point,
        radii,
        sample_nums,
        mlp_channels,
        paconv_num_kernels,
        fps_mod=["D-FPS"],
        fps_sample_range_list=[-1],
        dilated_group=False,
        norm_cfg=dict(type="BN2d", momentum=0.1),
        use_xyz=True,
        pool_mod="max",
        normalize_xyz=False,
        bias="auto",
        paconv_kernel_input="w_neighbor",
        scorenet_input="w_neighbor_dist",
        scorenet_cfg=dict(
            mlp_channels=[16, 16, 16], score_norm="softmax", temp_factor=1.0, last_bn=False
        ),
    ):
        super(PAConvSAModuleMSG, self).__init__(
            num_point=num_point,
            radii=radii,
            sample_nums=sample_nums,
            mlp_channels=mlp_channels,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            dilated_group=dilated_group,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            normalize_xyz=normalize_xyz,
            grouper_return_grouped_xyz=True,
        )

        # every scale needs one kernel count per PAConv layer
        assert len(paconv_num_kernels) == len(mlp_channels)
        for scale_idx in range(len(mlp_channels)):
            assert (
                len(paconv_num_kernels[scale_idx]) == len(mlp_channels[scale_idx]) - 1
            ), "PAConv number of kernel weights wrong"

        # in PAConv, bias only exists in ScoreNet
        # Work on a shallow copy: the default `scorenet_cfg` is a single dict
        # shared by every call, and a caller-supplied dict should not gain a
        # 'bias' key as a side effect of building this module.
        scorenet_cfg = scorenet_cfg.copy()
        scorenet_cfg["bias"] = bias

        for scale_idx in range(len(self.mlp_channels)):
            mlp_channel = self.mlp_channels[scale_idx]
            if use_xyz:
                # NOTE: in-place update of the list stored on
                # `self.mlp_channels`, kept as in the original implementation.
                mlp_channel[0] += 3

            num_kernels = paconv_num_kernels[scale_idx]

            mlp = nn.Sequential()
            # a dedicated index avoids shadowing the per-scale loop variable
            for layer_idx in range(len(mlp_channel) - 1):
                mlp.add_module(
                    f"layer{layer_idx}",
                    PAConv(
                        mlp_channel[layer_idx],
                        mlp_channel[layer_idx + 1],
                        num_kernels[layer_idx],
                        norm_cfg=norm_cfg,
                        kernel_input=paconv_kernel_input,
                        scorenet_input=scorenet_input,
                        scorenet_cfg=scorenet_cfg,
                    ),
                )
            self.mlps.append(mlp)
|
||
|
|
|
||
|
|
|
||
|
|
@SA_MODULES.register_module()
class PAConvSAModule(PAConvSAModuleMSG):
    r"""Single-scale grouping (SSG) point set abstraction module for PAConv
    networks.

    This is `PAConvSAModuleMSG` restricted to one scale: the MLPs of
    `PointSAModule` are replaced with PAConv layers. See the `paper
    <https://arxiv.org/abs/2103.14635>`_ for more details.

    ``mlp_channels``, ``paconv_num_kernels``, ``radius`` and ``num_sample``
    describe the single scale and are each wrapped in a one-element list
    before being handed to the multi-scale parent. All remaining arguments
    are forwarded as they are. ``bias`` and ``dilated_group`` are not exposed
    here, so the parent's defaults apply.
    """

    def __init__(
        self,
        mlp_channels,
        paconv_num_kernels,
        num_point=None,
        radius=None,
        num_sample=None,
        norm_cfg=dict(type="BN2d", momentum=0.1),
        use_xyz=True,
        pool_mod="max",
        fps_mod=["D-FPS"],
        fps_sample_range_list=[-1],
        normalize_xyz=False,
        paconv_kernel_input="w_neighbor",
        scorenet_input="w_neighbor_dist",
        scorenet_cfg=dict(
            mlp_channels=[16, 16, 16], score_norm="softmax", temp_factor=1.0, last_bn=False
        ),
    ):
        # per-scale settings: the MSG parent expects one entry per scale
        single_scale = dict(
            mlp_channels=[mlp_channels],
            paconv_num_kernels=[paconv_num_kernels],
            radii=[radius],
            sample_nums=[num_sample],
        )
        # settings that are independent of the number of scales
        shared = dict(
            num_point=num_point,
            norm_cfg=norm_cfg,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            normalize_xyz=normalize_xyz,
            paconv_kernel_input=paconv_kernel_input,
            scorenet_input=scorenet_input,
            scorenet_cfg=scorenet_cfg,
        )
        super().__init__(**single_scale, **shared)
|
||
|
|
|
||
|
|
|
||
|
|
@SA_MODULES.register_module()
class PAConvCUDASAModuleMSG(BasePointSAModule):
    r"""Point set abstraction module with multi-scale grouping (MSG) used in
    PAConv networks.

    Replace the non CUDA version PAConv with CUDA implemented PAConv for
    efficient computation. See the `paper <https://arxiv.org/abs/2103.14635>`_
    for more details.

    Args:
        num_point, radii, sample_nums, fps_mod, fps_sample_range_list,
            dilated_group, pool_mod, normalize_xyz: Forwarded unchanged to
            ``BasePointSAModule.__init__``; see that class for their
            semantics.
        mlp_channels (list[list[int]]): Channel sizes of the PAConv stack of
            each scale. When ``use_xyz`` is True the first entry of every
            scale is increased by 3 (in place on ``self.mlp_channels``).
        paconv_num_kernels (list[list[int]]): Number of kernel weights of each
            PAConv layer. Scale ``k`` must provide
            ``len(mlp_channels[k]) - 1`` entries.
        norm_cfg (dict, optional): Normalization config handed to every
            PAConvCUDA layer. Defaults to dict(type='BN2d', momentum=0.1).
        use_xyz (bool, optional): Forwarded to the base class and also kept
            on ``self.use_xyz`` because xyz is concatenated manually in
            :meth:`forward`. Defaults to True.
        bias (bool | str, optional): Stored as ``scorenet_cfg['bias']``;
            in PAConv, bias only exists in ScoreNet. Defaults to 'auto'.
        paconv_kernel_input (str, optional): Passed to every PAConvCUDA layer
            as ``kernel_input``. Defaults to 'w_neighbor'.
        scorenet_input (str, optional): Passed to every PAConvCUDA layer.
            Defaults to 'w_neighbor_dist'.
        scorenet_cfg (dict, optional): Config passed to every PAConvCUDA
            layer. It is copied before ``bias`` is added, so neither the
            caller's dict nor the shared default is modified.
    """

    def __init__(
        self,
        num_point,
        radii,
        sample_nums,
        mlp_channels,
        paconv_num_kernels,
        fps_mod=["D-FPS"],
        fps_sample_range_list=[-1],
        dilated_group=False,
        norm_cfg=dict(type="BN2d", momentum=0.1),
        use_xyz=True,
        pool_mod="max",
        normalize_xyz=False,
        bias="auto",
        paconv_kernel_input="w_neighbor",
        scorenet_input="w_neighbor_dist",
        scorenet_cfg=dict(
            mlp_channels=[8, 16, 16], score_norm="softmax", temp_factor=1.0, last_bn=False
        ),
    ):
        super(PAConvCUDASAModuleMSG, self).__init__(
            num_point=num_point,
            radii=radii,
            sample_nums=sample_nums,
            mlp_channels=mlp_channels,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            dilated_group=dilated_group,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            normalize_xyz=normalize_xyz,
            grouper_return_grouped_xyz=True,
            grouper_return_grouped_idx=True,
        )

        # every scale needs one kernel count per PAConv layer
        assert len(paconv_num_kernels) == len(mlp_channels)
        for scale_idx in range(len(mlp_channels)):
            assert (
                len(paconv_num_kernels[scale_idx]) == len(mlp_channels[scale_idx]) - 1
            ), "PAConv number of kernel weights wrong"

        # in PAConv, bias only exists in ScoreNet
        # Work on a shallow copy: the default `scorenet_cfg` is a single dict
        # shared by every call, and a caller-supplied dict should not gain a
        # 'bias' key as a side effect of building this module.
        scorenet_cfg = scorenet_cfg.copy()
        scorenet_cfg["bias"] = bias

        # we need to manually concat xyz for CUDA implemented PAConv
        self.use_xyz = use_xyz

        for scale_idx in range(len(self.mlp_channels)):
            mlp_channel = self.mlp_channels[scale_idx]
            if use_xyz:
                # NOTE: in-place update of the list stored on
                # `self.mlp_channels`, kept as in the original implementation.
                mlp_channel[0] += 3

            num_kernels = paconv_num_kernels[scale_idx]

            # can't use `nn.Sequential` for PAConvCUDA because its input and
            # output have different shapes
            mlp = nn.ModuleList()
            # a dedicated index avoids shadowing the per-scale loop variable
            for layer_idx in range(len(mlp_channel) - 1):
                mlp.append(
                    PAConvCUDA(
                        mlp_channel[layer_idx],
                        mlp_channel[layer_idx + 1],
                        num_kernels[layer_idx],
                        norm_cfg=norm_cfg,
                        kernel_input=paconv_kernel_input,
                        scorenet_input=scorenet_input,
                        scorenet_cfg=scorenet_cfg,
                    )
                )
            self.mlps.append(mlp)

    def forward(
        self,
        points_xyz,
        features=None,
        indices=None,
        target_xyz=None,
    ):
        """forward.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            features (Tensor): (B, C, N) features of each point.
                Default: None. When None and ``self.use_xyz`` is True, the
                xyz coordinates alone are used as the input features of the
                first layer.
            indices (Tensor): (B, num_point) Index of the features.
                Default: None.
            target_xyz (Tensor): (B, M, 3) new_xyz coordinates of the outputs.

        Returns:
            Tensor: (B, M, 3) where M is the number of points.
                New features xyz.
            Tensor: (B, sum_k(mlps[k][-1]), M) where M is the number
                of points. New feature descriptors; the pooled per-scale
                outputs are concatenated along dim 1.
            Tensor: (B, M) where M is the number of points.
                Index of the features.
        """
        new_features_list = []

        # sample points, (B, num_point, 3), (B, num_point)
        new_xyz, indices = self._sample_points(points_xyz, features, indices, target_xyz)

        for scale_idx in range(len(self.groupers)):
            grouper = self.groupers[scale_idx]
            xyz = points_xyz
            new_features = features
            for layer_idx, paconv in enumerate(self.mlps[scale_idx]):
                # we don't use grouped_features here to avoid large GPU memory
                # _, (B, 3, num_point, nsample), (B, num_point, nsample)
                _, grouped_xyz, grouped_idx = grouper(xyz, new_xyz, new_features)

                # concat xyz as additional features
                if self.use_xyz and layer_idx == 0:
                    # (B, 3, N)
                    xyz_features = points_xyz.permute(0, 2, 1)
                    if new_features is None:
                        # `features` defaults to None; `torch.cat` cannot take
                        # None, so fall back to xyz-only input. `permute`
                        # returns a non-contiguous view, hence `contiguous()`
                        # to match what `torch.cat` would have produced.
                        new_features = xyz_features.contiguous()
                    else:
                        # (B, C+3, N)
                        new_features = torch.cat((xyz_features, new_features), dim=1)

                # (B, out_c, num_point, nsample)
                grouped_new_features = paconv((new_features, grouped_xyz, grouped_idx.long()))[0]

                # different from PointNet++ and non CUDA version of PAConv
                # CUDA version of PAConv needs to aggregate local features
                # every time after it passes through a Conv layer
                # in order to transform to valid input shape
                # (B, out_c, num_point)
                new_features = self._pool_features(grouped_new_features)

                # constrain the points to be grouped for next PAConv layer
                # because new_features only contains sampled centers now
                # (B, num_point, 3)
                xyz = new_xyz

            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=1), indices
|
||
|
|
|
||
|
|
|
||
|
|
@SA_MODULES.register_module()
class PAConvCUDASAModule(PAConvCUDASAModuleMSG):
    r"""Single-scale grouping (SSG) point set abstraction module for PAConv
    networks, built on the CUDA implemented PAConv.

    This is `PAConvCUDASAModuleMSG` restricted to one scale. See the `paper
    <https://arxiv.org/abs/2103.14635>`_ for more details.

    ``mlp_channels``, ``paconv_num_kernels``, ``radius`` and ``num_sample``
    describe the single scale and are each wrapped in a one-element list
    before being handed to the multi-scale parent. All remaining arguments
    are forwarded as they are. ``bias`` and ``dilated_group`` are not exposed
    here, so the parent's defaults apply.
    """

    def __init__(
        self,
        mlp_channels,
        paconv_num_kernels,
        num_point=None,
        radius=None,
        num_sample=None,
        norm_cfg=dict(type="BN2d", momentum=0.1),
        use_xyz=True,
        pool_mod="max",
        fps_mod=["D-FPS"],
        fps_sample_range_list=[-1],
        normalize_xyz=False,
        paconv_kernel_input="w_neighbor",
        scorenet_input="w_neighbor_dist",
        scorenet_cfg=dict(
            mlp_channels=[8, 16, 16], score_norm="softmax", temp_factor=1.0, last_bn=False
        ),
    ):
        # per-scale settings: the MSG parent expects one entry per scale
        single_scale = dict(
            mlp_channels=[mlp_channels],
            paconv_num_kernels=[paconv_num_kernels],
            radii=[radius],
            sample_nums=[num_sample],
        )
        # settings that are independent of the number of scales
        shared = dict(
            num_point=num_point,
            norm_cfg=norm_cfg,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            normalize_xyz=normalize_xyz,
            paconv_kernel_input=paconv_kernel_input,
            scorenet_input=scorenet_input,
            scorenet_cfg=scorenet_cfg,
        )
        super().__init__(**single_scale, **shared)
|