80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
|
|
import torch
|
||
|
|
from mmcv.cnn import ConvModule
|
||
|
|
from mmcv.runner import BaseModule, force_fp32
|
||
|
|
from torch import nn as nn
|
||
|
|
from typing import List
|
||
|
|
|
||
|
|
from mmdet3d.ops import three_interpolate, three_nn
|
||
|
|
|
||
|
|
|
||
|
|
class PointFPModule(BaseModule):
    """Feature propagation (FP) layer for PointNet-style networks.

    Carries features defined on a source point set over to a target point
    set, optionally concatenates the target's own features, and refines the
    result with a stack of 1x1 ``ConvModule`` layers.

    Args:
        mlp_channels (list[int]): Channel sizes of the MLP; each pair of
            consecutive entries gives the (in, out) channels of one layer.
        norm_cfg (dict): Normalization config forwarded unchanged to every
            ``ConvModule``. Default: dict(type='BN2d').
        init_cfg (dict, optional): Initialization config handed to
            ``BaseModule``. Default: None.
    """

    def __init__(self, mlp_channels: List[int], norm_cfg: dict = dict(type="BN2d"), init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.fp16_enabled = False

        # One ConvModule per adjacent channel pair, registered as
        # "layer0", "layer1", ... in order.
        self.mlps = nn.Sequential()
        channel_pairs = zip(mlp_channels[:-1], mlp_channels[1:])
        for i, (in_channels, out_channels) in enumerate(channel_pairs):
            layer = ConvModule(
                in_channels,
                out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                conv_cfg=dict(type="Conv2d"),
                norm_cfg=norm_cfg,
            )
            self.mlps.add_module(f"layer{i}", layer)

    @force_fp32()
    def forward(
        self,
        target: torch.Tensor,
        source: torch.Tensor,
        target_feats: torch.Tensor,
        source_feats: torch.Tensor,
    ) -> torch.Tensor:
        """Propagate ``source_feats`` onto the target points.

        Args:
            target (Tensor): (B, n, 3) xyz positions of the target points.
            source (Tensor): (B, m, 3) xyz positions of the source points.
                May be None, in which case ``source_feats`` is expanded
                along the point axis instead of being interpolated.
            target_feats (Tensor): (B, C1, n) features already attached to
                the target points. May be None, in which case nothing is
                concatenated.
            source_feats (Tensor): (B, C2, m) features to be propagated.

        Return:
            Tensor: (B, M, N) M = mlp[-1], tensor of the target features.
        """
        if source is None:
            # No source coordinates available: expand the source features
            # to the number of target points rather than interpolating.
            propagated = source_feats.expand(*source_feats.shape[:2], target.shape[1])
        else:
            # Inverse-distance weighting over the neighbours found by
            # three_nn; the weights are normalised to sum to 1 along dim 2.
            dist, idx = three_nn(target, source)
            inv_dist = 1.0 / (dist + 1e-8)  # epsilon guards against a zero distance
            weight = inv_dist / inv_dist.sum(dim=2, keepdim=True)
            propagated = three_interpolate(source_feats, idx, weight)

        if target_feats is None:
            fused = propagated
        else:
            fused = torch.cat([propagated, target_feats], dim=1)  # (B, C2 + C1, n)

        # The MLP is made of 2-D convs, so add a trailing singleton dim for
        # the pass through it and drop that dim again on the way out.
        return self.mlps(fused.unsqueeze(-1)).squeeze(-1)
|