100 lines
3.4 KiB
Python
100 lines
3.4 KiB
Python
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
|
import numpy as np
|
||
|
|
import torch
|
||
|
|
from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
|
||
|
|
from mmcv.runner import BaseModule, auto_fp16
|
||
|
|
from torch import nn as nn
|
||
|
|
|
||
|
|
from mmdet.models import NECKS
|
||
|
|
|
||
|
|
# Names exported by ``from <module> import *``: only the neck class.
__all__ = [
    "SECONDFPN",
]
|
||
|
|
|
||
|
|
|
||
|
|
@NECKS.register_module()
class SECONDFPN(BaseModule):
    """FPN used in SECOND/PointPillars/PartA2/MVXNet.

    Every input level is passed through its own "deblock" (an upsample or
    strided-conv layer, followed by a norm layer and a ReLU). The per-level
    outputs are then concatenated along dimension 1.

    Args:
        in_channels (list[int], optional): Input channels of the multi-scale
            feature maps, one entry per level. ``None`` means
            ``[128, 128, 256]``.
        out_channels (list[int], optional): Output channels of each level's
            deblock. ``None`` means ``[256, 256, 256]``.
        upsample_strides (list[int | float], optional): Per-level stride.
            If the stride is ``> 1``, or ``== 1`` with
            ``use_conv_for_no_stride`` disabled, an upsample layer is built
            from ``upsample_cfg``. Its kernel size and stride equal the given
            stride.
            Otherwise (stride ``< 1``, or ``== 1`` with
            ``use_conv_for_no_stride`` enabled), a conv layer is built from
            ``conv_cfg``. Its kernel size and stride equal
            ``round(1 / stride)``.
            Every stride must be positive. ``None`` means ``[1, 2, 4]``.
        norm_cfg (dict): Config dict of normalization layers.
        upsample_cfg (dict): Config dict of upsample layers.
        conv_cfg (dict): Config dict of conv layers.
        use_conv_for_no_stride (bool): Whether to use conv when stride is 1.
        init_cfg (dict | list[dict], optional): Initialization config. When
            ``None``, a default is installed: Kaiming init for
            ``ConvTranspose2d`` layers and constant 1.0 for
            ``NaiveSyncBatchNorm2d`` layers.

    Raises:
        AssertionError: If the three per-level lists differ in length.
        ValueError: If any stride is not positive.
    """

    def __init__(
        self,
        in_channels=None,
        out_channels=None,
        upsample_strides=None,
        norm_cfg=dict(type="BN", eps=1e-3, momentum=0.01),
        upsample_cfg=dict(type="deconv", bias=False),
        conv_cfg=dict(type="Conv2d", bias=False),
        use_conv_for_no_stride=False,
        init_cfg=None,
    ):
        # if for GroupNorm,
        # cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
        super(SECONDFPN, self).__init__(init_cfg=init_cfg)

        # The list arguments default to ``None`` and are resolved to fresh
        # lists here. ``in_channels``/``out_channels`` are stored on the
        # instance, so a shared mutable default list would be aliased by
        # every default-constructed neck.
        # The dict configs are only forwarded to the layer builders and never
        # stored or mutated in this class, so their defaults are kept. This
        # also means an explicit ``conv_cfg=None`` is still forwarded as-is.
        if in_channels is None:
            in_channels = [128, 128, 256]
        if out_channels is None:
            out_channels = [256, 256, 256]
        if upsample_strides is None:
            upsample_strides = [1, 2, 4]

        assert len(out_channels) == len(upsample_strides) == len(in_channels), (
            "in_channels, out_channels and upsample_strides must have the "
            "same length"
        )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.fp16_enabled = False

        deblocks = []
        for in_channel, out_channel, stride in zip(
            in_channels, out_channels, upsample_strides
        ):
            if stride <= 0:
                raise ValueError(f"upsample stride must be positive, got {stride}")
            if stride > 1 or (stride == 1 and not use_conv_for_no_stride):
                # Upsampling path: kernel size == stride.
                upsample_layer = build_upsample_layer(
                    upsample_cfg,
                    in_channels=in_channel,
                    out_channels=out_channel,
                    kernel_size=stride,
                    stride=stride,
                )
            else:
                # Downsampling (or explicit stride-1 conv) path: a fractional
                # stride s is realised as a conv with kernel/stride round(1/s).
                conv_stride = int(np.round(1 / stride))
                upsample_layer = build_conv_layer(
                    conv_cfg,
                    in_channels=in_channel,
                    out_channels=out_channel,
                    kernel_size=conv_stride,
                    stride=conv_stride,
                )

            deblocks.append(
                nn.Sequential(
                    upsample_layer,
                    build_norm_layer(norm_cfg, out_channel)[1],
                    nn.ReLU(inplace=True),
                )
            )
        self.deblocks = nn.ModuleList(deblocks)

        if init_cfg is None:
            self.init_cfg = [
                dict(type="Kaiming", layer="ConvTranspose2d"),
                dict(type="Constant", layer="NaiveSyncBatchNorm2d", val=1.0),
            ]

    @auto_fp16()
    def forward(self, x):
        """Forward function.

        Args:
            x (list[torch.Tensor] | tuple[torch.Tensor]): Multi-level input
                feature maps, one per level (presumably 4D ``(N, C, H, W)``
                tensors). Level ``i`` is fed to ``self.deblocks[i]``, so it
                must have ``in_channels[i]`` channels.

        Returns:
            list[torch.Tensor]: A single-element list holding the deblock
            outputs concatenated along dimension 1. With one level, the list
            holds that level's output without concatenation. All deblock
            outputs must match in every dimension except 1 for the
            concatenation to succeed.
        """
        assert len(x) == len(self.in_channels), (
            f"expected {len(self.in_channels)} feature levels, got {len(x)}"
        )
        ups = [deblock(feat) for deblock, feat in zip(self.deblocks, x)]

        # Skip torch.cat for a single level to avoid a needless copy.
        out = torch.cat(ups, dim=1) if len(ups) > 1 else ups[0]
        return [out]
|