138 lines
5.6 KiB
Python
138 lines
5.6 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
import math
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
from mmcv.cnn import ConvModule
|
|
from mmcv.runner import BaseModule, auto_fp16
|
|
from mmdet.models.builder import NECKS
|
|
|
|
from typing import List, Optional, Dict, Any
|
|
|
|
|
|
# Public API of this module: only the FPN neck is exported.
__all__ = [
    "DetectronFPN",
]
|
|
|
|
|
|
@NECKS.register_module()
class DetectronFPN(BaseModule):
    """Feature Pyramid Network neck (:paper:`FPN`), ported from Detectron2.

    Builds pyramid features on top of a set of backbone feature maps. Each
    selected input level gets a 1x1 lateral conv that projects it to
    ``out_channels`` and a 3x3 output conv applied after top-down fusion.
    Submodules are registered as ``fpn_lateral{stage}`` / ``fpn_output{stage}``,
    where ``stage`` counts up from ``start_level`` in high-to-low resolution
    order; these names determine the state-dict keys.
    """

    _fuse_type: str

    def __init__(
        self,
        in_indices: List[int],
        out_indices: List[int],
        in_channels: List[int],
        out_channels: int,
        start_level: int,
        conv_cfg: Optional[Dict[str, Any]] = dict(type="Conv2d"),
        norm_cfg: Optional[Dict[str, Any]] = dict(type="BN2d"),
        act_cfg: Optional[Dict[str, Any]] = None,
        fuse_type: str = "sum",
    ):
        """
        Args:
            in_indices (list[int]): positions, in the input given to
                :meth:`forward`, of the backbone feature maps to use, ordered
                from high to low resolution. Must have the same length as
                ``in_channels``.
            out_indices (list[int]): which pyramid levels :meth:`forward`
                returns, as positions into the per-level outputs
                (0 = highest resolution).
            in_channels (list[int]): channel count of each selected input map,
                in the same order as ``in_indices``.
            out_channels (int): number of channels of every FPN output.
            start_level (int): stage number used to name the first pair of
                submodules; it only affects module (state-dict) names.
            conv_cfg (dict, optional): conv config forwarded to
                :class:`ConvModule`.
            norm_cfg (dict, optional): norm config forwarded to
                :class:`ConvModule`. ``None`` disables normalization and turns
                on the bias of the lateral 1x1 convs.
            act_cfg (dict, optional): activation config forwarded to
                :class:`ConvModule`.
            fuse_type (str): how top-down and lateral features are fused:
                "sum" (default) adds them element-wise, "avg" takes their
                element-wise mean.

        Raises:
            AssertionError: if ``fuse_type`` is not "sum" or "avg", if
                ``in_channels`` is empty, or if ``in_indices`` and
                ``in_channels`` differ in length.
        """
        super(DetectronFPN, self).__init__()

        # Validate up front, before any submodule is built.
        assert fuse_type in {"avg", "sum"}, (
            "fuse_type must be 'avg' or 'sum', got {!r}".format(fuse_type)
        )
        assert len(in_channels) > 0, "in_channels must not be empty"
        assert len(in_indices) == len(in_channels), (
            "in_indices and in_channels must have the same length, "
            "got {} and {}".format(len(in_indices), len(in_channels))
        )

        # The lateral conv only carries a bias when no norm layer is used.
        use_bias = norm_cfg is None

        lateral_convs = []
        output_convs = []
        for stage, in_channel in enumerate(in_channels, start=start_level):
            # Each ConvModule receives its own copy of the configs so the
            # default dicts in the signature are never shared between modules
            # or instances.
            lateral_conv = ConvModule(
                in_channel,
                out_channels,
                kernel_size=1,
                bias=use_bias,
                conv_cfg=self._copy_cfg(conv_cfg),
                norm_cfg=self._copy_cfg(norm_cfg),
                act_cfg=self._copy_cfg(act_cfg),
            )
            output_conv = ConvModule(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                conv_cfg=self._copy_cfg(conv_cfg),
                norm_cfg=self._copy_cfg(norm_cfg),
                act_cfg=self._copy_cfg(act_cfg),
            )
            # Detectron2-style names; these determine the state-dict keys.
            self.add_module("fpn_lateral{}".format(stage), lateral_conv)
            self.add_module("fpn_output{}".format(stage), output_conv)
            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)

        # Keep convs in top-down order (low to high resolution) so forward can
        # walk them front to back. These are plain lists; the modules are
        # already registered as children through add_module above.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]
        self.in_indices = tuple(in_indices)
        self.out_indices = tuple(out_indices)
        self._fuse_type = fuse_type

    @staticmethod
    def _copy_cfg(cfg):
        """Return a shallow copy of a config dict; ``None`` passes through."""
        return None if cfg is None else cfg.copy()

    def forward(self, bottom_up_features):
        """Run the top-down pathway with lateral connections.

        Args:
            bottom_up_features: indexable collection of backbone feature maps
                (e.g. a tuple/list of tensors), indexed with the entries of
                ``in_indices``; the selected maps go from high to low
                resolution. NOTE(review): presumably NCHW tensors from the
                backbone — confirm against the caller.

        Returns:
            list[Tensor]: the FPN outputs selected by ``out_indices`` (in
            ascending index order), where index 0 is the highest-resolution
            level.
        """
        # In top-down order, lateral_convs[k] pairs with in_indices[-k - 1].
        top_down_indices = self.in_indices[::-1]

        # Seed the pathway with the lowest-resolution input.
        prev_features = self.lateral_convs[0](
            bottom_up_features[top_down_indices[0]]
        )
        results = [self.output_convs[0](prev_features)]

        for in_index, lateral_conv, output_conv in zip(
            top_down_indices[1:], self.lateral_convs[1:], self.output_convs[1:]
        ):
            lateral_features = lateral_conv(bottom_up_features[in_index])
            # Upsample to the lateral map's exact spatial size rather than a
            # fixed 2x factor. The result is identical for exact 2x pyramids.
            # It also handles levels whose sizes are not exact multiples of
            # two (e.g. a 51-pixel level under a 26-pixel one, where 2x would
            # give 52 and break the addition below).
            top_down_features = F.interpolate(
                prev_features, size=lateral_features.shape[2:], mode="nearest"
            )
            prev_features = lateral_features + top_down_features
            if self._fuse_type == "avg":
                prev_features /= 2
            results.append(output_conv(prev_features))

        # results was built low-to-high resolution; flip it so index 0 is the
        # highest-resolution level, matching out_indices.
        results.reverse()
        return [results[i] for i in sorted(self.out_indices)]
|