# bev-project/mmdet3d/models/backbones/dla.py

# Copyright 2021 Toyota Research Institute. All rights reserved.
# Adapted from:
# https://github.com/ucbdrive/dla/blob/master/dla.py
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.models.builder import BACKBONES
from mmcv.runner import BaseModule
from mmcv.cnn import ConvModule
__all__ = ["DLA"]


class BasicBlock(nn.Module):
    """Basic DLA residual block: two 3x3 convolutions plus a residual
    connection; ReLU is applied in-place in ``forward``."""

    def __init__(
        self, inplanes, planes, stride=1, dilation=1,
        conv_cfg=dict(type="Conv2d"), norm_cfg=dict(type="BN2d"),
        act_cfg=None
    ):
        super(BasicBlock, self).__init__()
        self.conv1 = ConvModule(
            inplanes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=norm_cfg is None,
            dilation=dilation,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.conv2 = ConvModule(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=dilation,
            bias=norm_cfg is None,
            dilation=dilation,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x
        out = self.conv1(x)
        out = F.relu_(out)
        out = self.conv2(out)
        out = out + residual
        out = F.relu_(out)
        return out


class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand, plus a residual
    connection. The 3x3 convolution runs on ``planes // expansion`` channels."""

    expansion = 2

    def __init__(
        self, inplanes, planes, stride=1, dilation=1,
        conv_cfg=dict(type="Conv2d"), norm_cfg=dict(type="BN2d"),
        act_cfg=None
    ):
        super(Bottleneck, self).__init__()
        expansion = Bottleneck.expansion
        bottle_planes = planes // expansion
        self.conv1 = ConvModule(
            inplanes,
            bottle_planes,
            kernel_size=1,
            bias=norm_cfg is None,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.conv2 = ConvModule(
            bottle_planes,
            bottle_planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=norm_cfg is None,
            dilation=dilation,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.conv3 = ConvModule(
            bottle_planes,
            planes,
            kernel_size=1,
            bias=norm_cfg is None,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x
        out = self.conv1(x)
        out = F.relu_(out)
        out = self.conv2(out)
        out = F.relu_(out)
        out = self.conv3(out)
        out = out + residual
        out = F.relu_(out)
        return out


class Root(nn.Module):
    """Aggregation node: concatenates its inputs along the channel dimension,
    applies a single convolution, and optionally adds the first input back as
    a residual before the final ReLU."""

    def __init__(
        self, in_channels, out_channels, kernel_size, residual,
        conv_cfg=dict(type="Conv2d"), norm_cfg=dict(type="BN2d"),
        act_cfg=None
    ):
        super(Root, self).__init__()
        self.conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            bias=norm_cfg is None,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.residual = residual

    def forward(self, *x):
        children = x
        y = self.conv(torch.cat(x, 1))
        if self.residual:
            y = y + children[0]
        y = F.relu_(y)
        return y
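

# Illustrative sketch (not from the original file): a Root over two 64-channel
# feature maps concatenates them to 128 channels and projects back to 64 with
# a 1x1 convolution:
#   root = Root(in_channels=128, out_channels=64, kernel_size=1, residual=False)
#   y = root(x2, x1)   # x1, x2: (N, 64, H, W) -> y: (N, 64, H, W)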


class Tree(nn.Module):
    """Recursive DLA tree. A tree of depth ``levels`` holds two sub-trees (or
    two ``block`` instances at depth 1) whose outputs are fused by a ``Root``
    node; ``level_root`` additionally feeds the downsampled stage input into
    the root aggregation."""

    def __init__(
        self,
        levels,
        block,
        in_channels,
        out_channels,
        stride=1,
        level_root=False,
        root_dim=0,
        root_kernel_size=1,
        dilation=1,
        root_residual=False,
        conv_cfg=dict(type="Conv2d"),
        norm_cfg=dict(type="BN2d"),
        act_cfg=None
    ):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels
        if levels == 1:
            self.tree1 = block(
                in_channels, out_channels, stride, dilation=dilation,
                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
            )
            self.tree2 = block(
                out_channels, out_channels, 1, dilation=dilation,
                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
            )
        else:
            self.tree1 = Tree(
                levels - 1,
                block,
                in_channels,
                out_channels,
                stride,
                root_dim=0,
                root_kernel_size=root_kernel_size,
                dilation=dilation,
                root_residual=root_residual,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg
            )
            self.tree2 = Tree(
                levels - 1,
                block,
                out_channels,
                out_channels,
                root_dim=root_dim + out_channels,
                root_kernel_size=root_kernel_size,
                dilation=dilation,
                root_residual=root_residual,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg
            )
        if levels == 1:
            self.root = Root(
                root_dim, out_channels, root_kernel_size, root_residual,
                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
            )
        self.level_root = level_root
        self.root_dim = root_dim
        self.downsample = None
        self.project = None
        self.levels = levels
        if stride > 1:
            self.downsample = nn.MaxPool2d(stride, stride=stride)
        # (dennis.park) If 'self.tree1' is a Tree (not BasicBlock), then the output of project is not used.
        # if in_channels != out_channels:
        if in_channels != out_channels and not isinstance(self.tree1, Tree):
            self.project = ConvModule(
                in_channels, out_channels, kernel_size=1, stride=1,
                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
            )

    def forward(self, x, residual=None, children=None):
        children = [] if children is None else children
        bottom = self.downsample(x) if self.downsample else x
        # (dennis.park) If 'self.tree1' is a 'Tree', then 'residual' is not used.
        residual = self.project(bottom) if self.project is not None else bottom
        if self.level_root:
            children.append(bottom)
        x1 = self.tree1(x, residual)
        if self.levels == 1:
            x2 = self.tree2(x1)
            y = self.root(x2, x1, *children)
        else:
            children.append(x1)
            y = self.tree2(x1, children=children)
        return y
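

# Illustrative sketch (not from the original file): a depth-2 Tree used as a
# stride-2 stage, assuming BasicBlock and the default BN config:
#   stage = Tree(levels=2, block=BasicBlock, in_channels=64, out_channels=128,
#                stride=2, level_root=True)
#   y = stage(torch.randn(1, 64, 56, 56))   # -> (1, 128, 28, 28)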


@BACKBONES.register_module()
class DLA(BaseModule):
    """Deep Layer Aggregation (DLA) backbone.

    Args:
        levels (list[int]): Depth of each of the six levels (number of convs
            for level0/level1, tree depth for level2-level5).
        channels (list[int]): Output channels of each level.
        block (type): Residual block class (``BasicBlock`` or ``Bottleneck``).
        residual_root (bool): Whether root nodes add a residual connection.
        norm_eval (bool): Whether to keep norm layers in eval mode during
            training.
        out_features (list[str] | None): Names of the levels returned by
            ``forward``; defaults to ``['level5']``.
    """

    def __init__(
        self,
        levels,
        channels,
        block=BasicBlock,
        residual_root=False,
        norm_eval=False,
        out_features=None,
        conv_cfg=dict(type="Conv2d"),
        norm_cfg=dict(type="BN2d"),
        act_cfg=None
    ):
        super(DLA, self).__init__()
        self.channels = channels
        self.base_layer = ConvModule(
            3,
            channels[0],
            kernel_size=7,
            stride=1,
            padding=3,
            bias=norm_cfg is None,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type="ReLU")
        )
        self.level0 = self._make_conv_level(
            channels[0], channels[0], levels[0], conv_cfg=conv_cfg,
            norm_cfg=norm_cfg, act_cfg=act_cfg
        )
        self.level1 = self._make_conv_level(
            channels[0], channels[1], levels[1], stride=2, conv_cfg=conv_cfg,
            norm_cfg=norm_cfg, act_cfg=act_cfg
        )
        self.level2 = Tree(
            levels[2], block, channels[1], channels[2], 2, level_root=False,
            root_residual=residual_root, conv_cfg=conv_cfg, norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.level3 = Tree(
            levels[3], block, channels[2], channels[3], 2, level_root=True,
            root_residual=residual_root, conv_cfg=conv_cfg, norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.level4 = Tree(
            levels[4], block, channels[3], channels[4], 2, level_root=True,
            root_residual=residual_root, conv_cfg=conv_cfg, norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.level5 = Tree(
            levels[5], block, channels[4], channels[5], 2, level_root=True,
            root_residual=residual_root, conv_cfg=conv_cfg, norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.norm_eval = norm_eval

        if out_features is None:
            out_features = ['level5']
        self._out_features = out_features
        assert len(self._out_features)
        children = [x[0] for x in self.named_children()]
        for out_feature in self._out_features:
            assert out_feature in children, "Available children: {}".format(", ".join(children))

        out_feature_channels, out_feature_strides = {}, {}
        for lvl in range(6):
            name = f"level{lvl}"
            out_feature_channels[name] = channels[lvl]
            out_feature_strides[name] = 2 ** lvl
        self._out_feature_channels = {name: out_feature_channels[name] for name in self._out_features}
        self._out_feature_strides = {name: out_feature_strides[name] for name in self._out_features}

    @property
    def size_divisibility(self):
        return 32

    def _make_conv_level(self, inplanes, planes, convs, conv_cfg, norm_cfg, act_cfg=None, stride=1, dilation=1):
        """Stack ``convs`` 3x3 ConvModules; only the first one applies ``stride``."""
        modules = []
        for i in range(convs):
            modules.append(
                ConvModule(
                    inplanes,
                    planes,
                    kernel_size=3,
                    stride=stride if i == 0 else 1,
                    padding=dilation,
                    bias=norm_cfg is None,
                    dilation=dilation,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type="ReLU")
                )
            )
            inplanes = planes
        return nn.Sequential(*modules)

    def forward(self, x):
        assert x.dim() == 4, f"DLA takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        x = self.base_layer(x)
        for i in range(6):
            name = f"level{i}"
            x = self._modules[name](x)
            if name in self._out_features:
                outputs[name] = x
        return outputs

    def train(self, mode=True):
        """Convert the model into training mode while keeping normalization
        layers frozen when ``norm_eval`` is True."""
        super(DLA, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() only has an effect on BatchNorm layers here
                if isinstance(m, _BatchNorm):
                    m.eval()
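

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the upstream file),
    # assuming DLA-34-style levels/channels. In a typical mmdet3d setup the
    # backbone would instead be built from a config via the BACKBONES registry.
    model = DLA(
        levels=[1, 1, 1, 2, 2, 1],
        channels=[16, 32, 64, 128, 256, 512],
        block=BasicBlock,
        out_features=["level3", "level4", "level5"],
    )
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    for name, feat in feats.items():
        print(name, tuple(feat.shape))
    # Expected for a 224x224 input:
    #   level3 (1, 128, 28, 28)
    #   level4 (1, 256, 14, 14)
    #   level5 (1, 512, 7, 7)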