# Copyright 2021 Toyota Research Institute. All rights reserved.
# Adapted from:
# https://github.com/ucbdrive/dla/blob/master/dla.py
from collections import OrderedDict

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmdet.models.builder import BACKBONES

__all__ = ["DLA"]


class BasicBlock(nn.Module):
    """Residual basic block: two 3x3 convolutions with an identity shortcut."""

    def __init__(
        self, inplanes, planes, stride=1, dilation=1,
        conv_cfg=dict(type="Conv2d"), norm_cfg=dict(type="BN2d"),
        act_cfg=None
    ):
        super(BasicBlock, self).__init__()
        self.conv1 = ConvModule(
            inplanes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=norm_cfg is None,
            dilation=dilation,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.conv2 = ConvModule(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=dilation,
            bias=norm_cfg is None,
            dilation=dilation,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x

        out = self.conv1(x)
        out = F.relu_(out)

        out = self.conv2(out)

        out = out + residual
        out = F.relu_(out)

        return out


class Bottleneck(nn.Module):
    """Residual bottleneck block (1x1 -> 3x3 -> 1x1).

    Note: unlike torchvision's ResNet bottleneck, DLA shrinks the inner
    width to ``planes // expansion`` rather than expanding the output.
    """

    expansion = 2

    def __init__(
        self, inplanes, planes, stride=1, dilation=1,
        conv_cfg=dict(type="Conv2d"), norm_cfg=dict(type="BN2d"),
        act_cfg=None
    ):
        super(Bottleneck, self).__init__()
        expansion = Bottleneck.expansion
        bottle_planes = planes // expansion
        self.conv1 = ConvModule(
            inplanes,
            bottle_planes,
            kernel_size=1,
            bias=norm_cfg is None,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.conv2 = ConvModule(
            bottle_planes,
            bottle_planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=norm_cfg is None,
            dilation=dilation,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.conv3 = ConvModule(
            bottle_planes,
            planes,
            kernel_size=1,
            bias=norm_cfg is None,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x

        out = self.conv1(x)
        out = F.relu_(out)

        out = self.conv2(out)
        out = F.relu_(out)

        out = self.conv3(out)

        out = out + residual
        out = F.relu_(out)

        return out


class Root(nn.Module):
    """Aggregation node: concatenates its children and fuses them with a conv."""

    def __init__(
        self, in_channels, out_channels, kernel_size, residual,
        conv_cfg=dict(type="Conv2d"), norm_cfg=dict(type="BN2d"),
        act_cfg=None
    ):
        super(Root, self).__init__()
        self.conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            bias=norm_cfg is None,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg
        )
        self.residual = residual

    def forward(self, *x):
        children = x
        y = self.conv(torch.cat(x, 1))
        if self.residual:
            y = y + children[0]
        y = F.relu_(y)

        return y


class Tree(nn.Module):
    """Recursive hierarchical stage of DLA: two subtrees joined by a Root."""

    def __init__(
        self,
        levels,
        block,
        in_channels,
        out_channels,
        stride=1,
        level_root=False,
        root_dim=0,
        root_kernel_size=1,
        dilation=1,
        root_residual=False,
        conv_cfg=dict(type="Conv2d"),
        norm_cfg=dict(type="BN2d"),
        act_cfg=None
    ):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels
        if levels == 1:
            self.tree1 = block(
                in_channels, out_channels, stride, dilation=dilation,
                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
            )
            self.tree2 = block(
                out_channels, out_channels, 1, dilation=dilation,
                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
            )
        else:
            self.tree1 = Tree(
                levels - 1,
                block,
                in_channels,
                out_channels,
                stride,
                root_dim=0,
                root_kernel_size=root_kernel_size,
                dilation=dilation,
                root_residual=root_residual,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg
            )
            self.tree2 = Tree(
                levels - 1,
                block,
                out_channels,
                out_channels,
                root_dim=root_dim + out_channels,
                root_kernel_size=root_kernel_size,
                dilation=dilation,
                root_residual=root_residual,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg
            )
        if levels == 1:
            self.root = Root(
                root_dim, out_channels, root_kernel_size, root_residual,
                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
            )
        self.level_root = level_root
        self.root_dim = root_dim
        self.downsample = None
        self.project = None
        self.levels = levels
        if stride > 1:
            self.downsample = nn.MaxPool2d(stride, stride=stride)
        # (dennis.park) If 'self.tree1' is a Tree (not BasicBlock), then the output of project is not used.
        # if in_channels != out_channels:
        if in_channels != out_channels and not isinstance(self.tree1, Tree):
            self.project = ConvModule(
                in_channels, out_channels, kernel_size=1, stride=1,
                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
            )

    def forward(self, x, residual=None, children=None):
        children = [] if children is None else children
        bottom = self.downsample(x) if self.downsample else x
        # (dennis.park) If 'self.tree1' is a 'Tree', then 'residual' is not used.
        residual = self.project(bottom) if self.project is not None else bottom
        if self.level_root:
            children.append(bottom)
        x1 = self.tree1(x, residual)
        if self.levels == 1:
            x2 = self.tree2(x1)
            y = self.root(x2, x1, *children)
        else:
            children.append(x1)
            y = self.tree2(x1, children=children)
        return y


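# A minimal sketch of how a stage unfolds (illustration only, not part of the
# original file): Tree(levels=2, block=BasicBlock, in_channels=64,
# out_channels=128, stride=2) builds tree1 = Tree(levels=1, ...), whose Root
# fuses its two BasicBlocks, and tree2 = Tree(levels=1, ...,
# root_dim=root_dim + out_channels), whose Root additionally concatenates
# tree1's output (plus the max-pooled input when level_root=True) before the
# root convolution, so root_dim tracks the concatenated channel count.
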
@BACKBONES.register_module()
class DLA(BaseModule):
    """Deep Layer Aggregation backbone (https://arxiv.org/abs/1707.06484)."""

    def __init__(
        self,
        levels,
        channels,
        block=BasicBlock,
        residual_root=False,
        norm_eval=False,
        out_features=None,
        conv_cfg=dict(type="Conv2d"),
        norm_cfg=dict(type="BN2d"),
        act_cfg=None
    ):
        super(DLA, self).__init__()
        self.channels = channels
        self.base_layer = ConvModule(
            3,
            channels[0],
            kernel_size=7,
            stride=1,
            padding=3,
            bias=norm_cfg is None,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type="ReLU")
        )
        self.level0 = self._make_conv_level(
            channels[0], channels[0], levels[0], conv_cfg=conv_cfg,
            norm_cfg=norm_cfg, act_cfg=act_cfg
        )
        self.level1 = self._make_conv_level(
            channels[0], channels[1], levels[1], stride=2, conv_cfg=conv_cfg,
            norm_cfg=norm_cfg, act_cfg=act_cfg
        )
        self.level2 = Tree(
            levels[2], block, channels[1], channels[2], 2, level_root=False, root_residual=residual_root,
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
        )
        self.level3 = Tree(
            levels[3], block, channels[2], channels[3], 2, level_root=True, root_residual=residual_root,
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
        )
        self.level4 = Tree(
            levels[4], block, channels[3], channels[4], 2, level_root=True, root_residual=residual_root,
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
        )
        self.level5 = Tree(
            levels[5], block, channels[4], channels[5], 2, level_root=True, root_residual=residual_root,
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
        )
        self.norm_eval = norm_eval

        if out_features is None:
            out_features = ["level5"]
        self._out_features = out_features
        assert len(self._out_features)
        children = [x[0] for x in self.named_children()]
        for out_feature in self._out_features:
            assert out_feature in children, "Available children: {}".format(", ".join(children))

        out_feature_channels, out_feature_strides = {}, {}
        for lvl in range(6):
            name = f"level{lvl}"
            out_feature_channels[name] = channels[lvl]
            out_feature_strides[name] = 2 ** lvl

        self._out_feature_channels = {name: out_feature_channels[name] for name in self._out_features}
        self._out_feature_strides = {name: out_feature_strides[name] for name in self._out_features}

    @property
    def size_divisibility(self):
        return 32

    def _make_conv_level(self, inplanes, planes, convs, conv_cfg, norm_cfg, act_cfg=None, stride=1, dilation=1):
        modules = []
        for i in range(convs):
            modules.append(
                ConvModule(
                    inplanes,
                    planes,
                    kernel_size=3,
                    stride=stride if i == 0 else 1,
                    padding=dilation,
                    bias=norm_cfg is None,
                    dilation=dilation,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type="ReLU")
                )
            )
            inplanes = planes
        return nn.Sequential(*modules)

    def forward(self, x):
        assert x.dim() == 4, f"DLA takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        x = self.base_layer(x)
        for i in range(6):
            name = f"level{i}"
            x = self._modules[name](x)
            if name in self._out_features:
                outputs[name] = x
        return outputs

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super(DLA, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() only has an effect on BatchNorm layers
                if isinstance(m, _BatchNorm):
                    m.eval()