bev-project/mmdet3d/ops/pointnet_modules/point_fp_module.py

80 lines
2.6 KiB
Python
Raw Permalink Normal View History

2022-06-03 12:21:18 +08:00
import copy
from typing import List

import torch
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn

from mmdet3d.ops import three_interpolate, three_nn
class PointFPModule(BaseModule):
    """Point feature propagation module used in PointNets.

    Propagate the features from one set of points (``source``) to another
    (``target``) by inverse-distance interpolation, optionally concatenate
    the target's own features, and refine the result with a stack of 1x1
    ``ConvModule`` layers.

    Args:
        mlp_channels (list[int]): List of mlp channels. Consecutive pairs
            give the (in, out) channels of each 1x1 conv layer, so the first
            entry must match the channel count of the features fed to the
            MLPs (C2 + C1, or C2 when no target features are given).
        norm_cfg (dict): Type of normalization method.
            Default: dict(type='BN2d').
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModule``. Default: None.
    """

    def __init__(
        self,
        mlp_channels: List[int],
        norm_cfg: dict = dict(type="BN2d"),
        init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)
        # Flag consulted by mmcv's ``force_fp32`` decorator on ``forward``.
        self.fp16_enabled = False
        self.mlps = nn.Sequential()
        # Pair each channel count with the next one: (c0, c1), (c1, c2), ...
        channel_pairs = zip(mlp_channels[:-1], mlp_channels[1:])
        for i, (in_channels, out_channels) in enumerate(channel_pairs):
            self.mlps.add_module(
                f"layer{i}",
                ConvModule(
                    in_channels,
                    out_channels,
                    kernel_size=(1, 1),
                    stride=(1, 1),
                    conv_cfg=dict(type="Conv2d"),
                    # The default ``norm_cfg`` dict is created once at
                    # definition time and shared by every instance; give each
                    # layer its own copy so a builder that mutates its cfg
                    # cannot corrupt the default or sibling layers.
                    # deepcopy(None) is None, so "no norm" still works.
                    norm_cfg=copy.deepcopy(norm_cfg),
                ),
            )

    @force_fp32()
    def forward(
        self,
        target: torch.Tensor,
        source: torch.Tensor,
        target_feats: torch.Tensor,
        source_feats: torch.Tensor,
    ) -> torch.Tensor:
        """Propagate ``source_feats`` onto the ``target`` points.

        Args:
            target (Tensor): (B, n, 3) tensor of the xyz positions of
                the target features.
            source (Tensor | None): (B, m, 3) tensor of the xyz positions
                of the source features. If None, ``source_feats`` is
                broadcast with ``expand`` to every target point instead of
                being interpolated, so its last dimension must be 1 (or n).
            target_feats (Tensor | None): (B, C1, n) tensor of the features
                to be propagated to. If None, only the propagated source
                features are fed to the MLPs.
            source_feats (Tensor): (B, C2, m) tensor of features
                to be propagated.

        Returns:
            Tensor: (B, M, n) tensor of the target features, where
            M = mlp_channels[-1] (when at least one MLP layer exists).
        """
        if source is not None:
            dist, idx = three_nn(target, source)
            # Inverse-distance weights over the neighbours found by
            # ``three_nn``, normalised to sum to 1 per target point; the
            # epsilon avoids dividing by zero for coincident points.
            dist_reciprocal = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_reciprocal, dim=2, keepdim=True)
            weight = dist_reciprocal / norm
            interpolated_feats = three_interpolate(source_feats, idx, weight)
        else:
            interpolated_feats = source_feats.expand(
                *source_feats.size()[0:2], target.size(1)
            )

        if target_feats is not None:
            # (B, C2 + C1, n)
            new_features = torch.cat([interpolated_feats, target_feats], dim=1)
        else:
            new_features = interpolated_feats

        # The MLPs are 2D 1x1 convs: add a trailing singleton axis for them
        # and drop it again afterwards.
        new_features = new_features.unsqueeze(-1)
        new_features = self.mlps(new_features)
        return new_features.squeeze(-1)