bev-project/mmdet3d/models/backbones/sparse_encoder.py

219 lines
8.1 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import auto_fp16
from torch import nn as nn
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops import spconv as spconv
from mmdet.models import BACKBONES
@BACKBONES.register_module()
class SparseEncoder(nn.Module):
    r"""Sparse encoder for SECOND and Part-A2.

    The network is ``conv_input`` -> ``encoder_layers`` (one sparse
    sequential per entry of ``encoder_channels``) -> ``conv_out``. The sparse
    result is densified and its last spatial axis is folded into the channel
    axis, so ``forward`` returns a 4D feature map.

    Args:
        in_channels (int): The number of input channels.
        sparse_shape (list[int]): The sparse shape of input tensor.
        order (list[str], optional): Order of conv module.
            Defaults to ('conv', 'norm', 'act').
        norm_cfg (dict, optional): Config of normalization layer. Defaults to
            dict(type='BN1d', eps=1e-3, momentum=0.01).
        base_channels (int, optional): Out channels for conv_input layer.
            Defaults to 16.
        output_channels (int, optional): Out channels for conv_out layer.
            Defaults to 128.
        encoder_channels (tuple[tuple[int]], optional):
            Convolutional channels of each encode block.
            Defaults to ((16, ), (32, 32, 32), (64, 64, 64), (64, 64, 64)).
        encoder_paddings (tuple[tuple[int]], optional):
            Paddings of each encode block.
            Defaults to ((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)).
        block_type (str, optional): Type of the block to use, either
            'conv_module' or 'basicblock'. Defaults to 'conv_module'.
    """

    def __init__(
        self,
        in_channels,
        sparse_shape,
        order=("conv", "norm", "act"),
        norm_cfg=dict(type="BN1d", eps=1e-3, momentum=0.01),
        base_channels=16,
        output_channels=128,
        encoder_channels=((16,), (32, 32, 32), (64, 64, 64), (64, 64, 64)),
        encoder_paddings=((1,), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)),
        block_type="conv_module",
    ):
        super().__init__()
        assert block_type in ["conv_module", "basicblock"]
        assert isinstance(order, (list, tuple)) and len(order) == 3
        assert set(order) == {"conv", "norm", "act"}

        self.sparse_shape = sparse_shape
        self.in_channels = in_channels
        self.order = order
        self.base_channels = base_channels
        self.output_channels = output_channels
        self.encoder_channels = encoder_channels
        self.encoder_paddings = encoder_paddings
        self.stage_num = len(self.encoder_channels)
        self.fp16_enabled = False
        # Spconv init all weight on its own

        conv_input_kwargs = dict(
            norm_cfg=norm_cfg,
            padding=1,
            indice_key="subm1",
            conv_type="SubMConv3d",
        )
        if self.order[0] != "conv":
            # Pre-activate: the input layer is a bare convolution. In the
            # post-activate case ``order`` is not passed at all, so
            # make_sparse_convmodule applies its own default order.
            conv_input_kwargs["order"] = ("conv",)
        self.conv_input = make_sparse_convmodule(
            in_channels, self.base_channels, 3, **conv_input_kwargs
        )

        encoder_out_channels = self.make_encoder_layers(
            make_sparse_convmodule,
            norm_cfg,
            self.base_channels,
            block_type=block_type,
        )

        # Strided conv acting only on the last spatial axis (kernel 3,
        # stride 2, no padding); the other two axes keep their size.
        self.conv_out = make_sparse_convmodule(
            encoder_out_channels,
            self.output_channels,
            kernel_size=(1, 1, 3),
            stride=(1, 1, 2),
            norm_cfg=norm_cfg,
            padding=0,
            indice_key="spconv_down2",
            conv_type="SparseConv3d",
        )

    @auto_fp16(apply_to=("voxel_features",))
    def forward(self, voxel_features, coors, batch_size, **kwargs):
        """Forward of SparseEncoder.

        Args:
            voxel_features (torch.float32): Voxel features in shape (N, C).
            coors (torch.int32): Coordinates in shape (N, 4); cast to int32
                before being handed to ``spconv.SparseConvTensor`` together
                with ``self.sparse_shape``.
                NOTE(review): this docstring previously listed the columns
                as (batch_idx, z_idx, y_idx, x_idx); confirm the axis order
                against the voxelization that feeds this module.
            batch_size (int): Batch size.

        Returns:
            torch.Tensor: Dense feature map in shape (N, C * D, H, W), where
            the densified ``conv_out`` result is unpacked as
            (N, C, H, W, D) and its last axis is folded into the channels.
        """
        coors = coors.int()
        input_sp_tensor = spconv.SparseConvTensor(
            voxel_features, coors, self.sparse_shape, batch_size
        )
        x = self.conv_input(input_sp_tensor)

        # Only the final stage output is consumed, so intermediate stage
        # outputs are not kept around.
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)

        # for detection head
        # [200, 176, 5] -> [200, 176, 2]
        out = self.conv_out(x)
        spatial_features = out.dense()

        N, C, H, W, D = spatial_features.shape
        # Move the last spatial axis next to the channel axis so the two can
        # be merged by a single view.
        spatial_features = spatial_features.permute(0, 1, 4, 2, 3).contiguous()
        spatial_features = spatial_features.view(N, C * D, H, W)
        return spatial_features

    def make_encoder_layers(
        self,
        make_block,
        norm_cfg,
        in_channels,
        block_type="conv_module",
        conv_cfg=dict(type="SubMConv3d"),
    ):
        """Make encoder layers using sparse convs.

        Builds ``self.encoder_layers`` with one ``encoder_layer{i}`` stage
        per entry of ``self.encoder_channels``.

        * 'conv_module': every stage except the first starts with a stride-2
          ``SparseConv3d``; all other blocks are ``SubMConv3d`` modules.
        * 'basicblock': every stage except the last ends with a stride-2
          ``SparseConv3d``; all other blocks are ``SparseBasicBlock``.

        Args:
            make_block (callable): A bound function to build blocks.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of encoder input channels.
            block_type (str, optional): Type of the block to use.
                Defaults to 'conv_module'.
            conv_cfg (dict, optional): Config of conv layer. Defaults to
                dict(type='SubMConv3d').

        Returns:
            int: The number of encoder output channels. Equals the given
            ``in_channels`` when no block is built at all.
        """
        assert block_type in ["conv_module", "basicblock"]
        self.encoder_layers = spconv.SparseSequential()

        use_basicblock = block_type == "basicblock"
        is_conv_module = block_type == "conv_module"
        last_stage = len(self.encoder_channels) - 1

        for i, blocks in enumerate(self.encoder_channels):
            stage_channels = tuple(blocks)
            stage_paddings = tuple(self.encoder_paddings[i])
            blocks_list = []
            for j, out_channels in enumerate(stage_channels):
                padding = stage_paddings[j]

                if use_basicblock:
                    # Downsample at the end of every stage but the last.
                    downsample = (
                        j == len(stage_channels) - 1 and i != last_stage
                    )
                else:
                    # each stage started with a spconv layer
                    # except the first stage
                    downsample = is_conv_module and i != 0 and j == 0

                if downsample:
                    block = make_block(
                        in_channels,
                        out_channels,
                        3,
                        norm_cfg=norm_cfg,
                        stride=2,
                        padding=padding,
                        indice_key=f"spconv{i + 1}",
                        conv_type="SparseConv3d",
                    )
                elif use_basicblock:
                    # The basic block is built with out_channels as both its
                    # input and output width.
                    block = SparseBasicBlock(
                        out_channels,
                        out_channels,
                        norm_cfg=norm_cfg,
                        conv_cfg=conv_cfg,
                    )
                else:
                    block = make_block(
                        in_channels,
                        out_channels,
                        3,
                        norm_cfg=norm_cfg,
                        padding=padding,
                        indice_key=f"subm{i + 1}",
                        conv_type="SubMConv3d",
                    )
                blocks_list.append(block)
                in_channels = out_channels

            stage_name = f"encoder_layer{i + 1}"
            stage_layers = spconv.SparseSequential(*blocks_list)
            self.encoder_layers.add_module(stage_name, stage_layers)

        # in_channels tracks the width of the most recently built block and,
        # unlike the loop variable, is bound even when nothing was built.
        return in_channels