bev-project/mmdet3d/models/necks/detectron_fpn.py

138 lines
5.6 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
import math
import torch
import torch.nn.functional as F
from torch import nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16
from mmdet.models.builder import NECKS
from typing import List, Optional, Dict, Any
# Names exported by a star-import of this module: only the neck class itself.
__all__ = ["DetectronFPN"]
@NECKS.register_module()
class DetectronFPN(BaseModule):
    """
    This module implements :paper:`FPN`.

    Every selected input feature map is projected by a 1x1 lateral ``ConvModule``.
    Starting from the last (presumably coarsest) input, the running feature map is
    upsampled by a factor of 2 with nearest-neighbour interpolation and fused with
    the next lateral projection; a 3x3 output ``ConvModule`` then produces the
    pyramid feature for that level.
    """

    _fuse_type: str

    def __init__(
        self,
        in_indices: List[int],
        out_indices: List[int],
        in_channels: List[int],
        out_channels: int,
        start_level: int,
        conv_cfg: Optional[Dict[str, Any]] = dict(type="Conv2d"),
        norm_cfg: Optional[Dict[str, Any]] = dict(type="BN2d"),
        act_cfg: Optional[Dict[str, Any]] = None,
        fuse_type: Optional[str] = "sum"
    ):
        """
        Args:
            in_indices (list[int]): indices used in ``forward`` to pick the input
                feature maps out of ``bottom_up_features``. They are matched with
                ``in_channels`` from the end: the last index feeds the conv built
                for the last entry of ``in_channels``, and so on backwards.
            out_indices (list[int]): positions, within the list of pyramid outputs
                (ordered like ``in_channels``), of the maps ``forward`` returns.
            in_channels (list[int]): channel count of each input feature map. One
                lateral conv and one output conv are created per entry.
            out_channels (int): number of channels of every pyramid output.
            start_level (int): number given to the first level; it is only used to
                name the registered submodules ``fpn_lateral<n>`` / ``fpn_output<n>``.
            conv_cfg (dict or None): forwarded to every ``ConvModule``.
            norm_cfg (dict or None): forwarded to every ``ConvModule``. The lateral
                convs are given an explicit bias only when this is ``None``.
            act_cfg (dict or None): forwarded to every ``ConvModule``.
            fuse_type (str): how the top-down and lateral features are fused. It
                can be "sum" (default), which adds them element-wise; or "avg",
                which takes the element-wise mean of the two.
        """
        super(DetectronFPN, self).__init__()

        laterals = []
        outputs = []
        lateral_has_bias = norm_cfg is None

        for offset, channels in enumerate(in_channels):
            level = start_level + offset

            # 1x1 projection of the backbone feature to the common channel width.
            lateral = ConvModule(
                channels,
                out_channels,
                kernel_size=1,
                bias=lateral_has_bias,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            )
            # 3x3 conv applied to the fused feature of this level.
            output = ConvModule(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            )

            # Registration under these names is what exposes the convs'
            # parameters on the module, so the names must stay stable.
            self.add_module(f"fpn_lateral{level}", lateral)
            self.add_module(f"fpn_output{level}", output)

            laterals.append(lateral)
            outputs.append(output)

        # Keep plain (non-registered) lists in top-down order, i.e. last input
        # level first, so that forward can walk them front to back.
        self.lateral_convs = list(reversed(laterals))
        self.output_convs = list(reversed(outputs))

        self.in_indices = tuple(in_indices)
        self.out_indices = tuple(out_indices)

        assert fuse_type in {"avg", "sum"}
        self._fuse_type = fuse_type

    def forward(self, bottom_up_features):
        """
        Args:
            bottom_up_features: container of feature map tensors that is indexed
                with the entries of ``in_indices`` (presumably the list/tuple of
                backbone outputs -- confirm against the caller). Each earlier
                selected level is expected to have twice the spatial size of the
                next one, because the coarser map is upsampled by a fixed factor
                of 2 before being added to it.

        Returns:
            list[Tensor]: the pyramid outputs selected by ``out_indices``, taken in
            ascending index order from the full list of outputs, which is ordered
            like ``in_channels``.
        """
        # Top of the pyramid: the last selected input has no coarser level above it.
        running = self.lateral_convs[0](bottom_up_features[self.in_indices[-1]])
        top_down_outputs = [self.output_convs[0](running)]

        for depth, (lateral_conv, output_conv) in enumerate(
            zip(self.lateral_convs, self.output_convs)
        ):
            # The first pair was consumed above; iterate over everything and skip
            # it rather than slicing the conv lists.
            if depth == 0:
                continue

            source = bottom_up_features[self.in_indices[-depth - 1]]
            upsampled = F.interpolate(running, scale_factor=2.0, mode="nearest")
            running = lateral_conv(source) + upsampled
            if self._fuse_type == "avg":
                running /= 2
            top_down_outputs.append(output_conv(running))

        # Outputs were collected from the last level to the first; flip them so
        # that position i corresponds to in_channels[i].
        pyramid = top_down_outputs[::-1]
        return [pyramid[position] for position in sorted(self.out_indices)]