# Copyright (c) Facebook, Inc. and its affiliates.
import math
import torch
import torch.nn.functional as F
from torch import nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16
from mmdet.models.builder import NECKS
from typing import List, Optional, Dict, Any

__all__ = ["DetectronFPN"]


@NECKS.register_module()
class DetectronFPN(BaseModule):
    """Feature Pyramid Network (:paper:`FPN`) laid out like Detectron2's FPN.

    Each selected input feature map is projected to ``out_channels`` by a 1x1
    lateral conv. It is then fused (summed or averaged) with the nearest-
    upsampled fused map of the next coarser level and passed through a 3x3
    output conv.

    Submodules are registered as ``fpn_lateral{n}`` / ``fpn_output{n}``, where
    ``n`` counts up from ``start_level``. This is presumably done so parameter
    names line up with Detectron2-style checkpoints; confirm against the
    weights in use.
    """

    _fuse_type: str

    def __init__(
        self,
        in_indices: List[int],
        out_indices: List[int],
        in_channels: List[int],
        out_channels: int,
        start_level: int,
        conv_cfg: Optional[Dict[str, Any]] = dict(type="Conv2d"),
        norm_cfg: Optional[Dict[str, Any]] = dict(type="BN2d"),
        act_cfg: Optional[Dict[str, Any]] = None,
        fuse_type: Optional[str] = "sum"
    ):
        """
        Args:
            in_indices (list[int]): positions, in the sequence passed to
                ``forward``, of the feature maps to build the pyramid from,
                ordered from high to low resolution. ``in_indices[i]`` is
                paired with ``in_channels[i]``; the two lists are expected to
                have the same length (not checked here).
            out_indices (list[int]): which pyramid levels to return, where
                level 0 is the highest-resolution output. They are returned in
                ascending order regardless of the order given here.
            in_channels (list[int]): channel count of each selected input map,
                ordered from high to low resolution.
            out_channels (int): number of channels of every pyramid output.
            start_level (int): number used for the first level in submodule
                names (``fpn_lateral{start_level}``, ...). It has no effect on
                the computation.
            conv_cfg (dict | None): conv config forwarded to every
                ``ConvModule``.
            norm_cfg (dict | None): norm config forwarded to every
                ``ConvModule``. When ``None``, the lateral convs get a bias.
            act_cfg (dict | None): activation config forwarded to every
                ``ConvModule``.
            fuse_type (str): how a lateral map is fused with the upsampled
                coarser map. ``"sum"`` (default) adds them element-wise;
                ``"avg"`` takes their element-wise mean.

        Raises:
            AssertionError: if ``fuse_type`` is not ``"sum"`` or ``"avg"``.
        """
        super(DetectronFPN, self).__init__()

        # NOTE(review): the dict defaults above are shared between calls. This
        # block only passes them through; ConvModule is assumed not to
        # mutate them.
        lateral_convs = []
        output_convs = []
        # With no norm layer after the 1x1 lateral conv, the conv needs its
        # own bias term.
        use_bias = norm_cfg is None

        for stage, in_channel in enumerate(in_channels, start=start_level):
            lateral_conv = ConvModule(
                in_channel,
                out_channels,
                kernel_size=1,
                bias=use_bias,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg
            )
            output_conv = ConvModule(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg
            )
            # Registering under explicit names is what makes the parameters
            # part of this module; the lists below only hold references.
            self.add_module("fpn_lateral{}".format(stage), lateral_conv)
            self.add_module("fpn_output{}".format(stage), output_conv)
            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)

        # Store the convs in top-down order (lowest resolution first) so that
        # forward can walk them front to back.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]
        self.in_indices = tuple(in_indices)
        self.out_indices = tuple(out_indices)
        assert fuse_type in {"avg", "sum"}
        self._fuse_type = fuse_type

    def forward(self, bottom_up_features):
        """Run the top-down pathway and return the selected pyramid levels.

        Args:
            bottom_up_features: indexable collection of feature-map tensors,
                for example the list or tuple a backbone returns. Entries are
                picked with ``in_indices``. Assumed to be ``(N, C, *spatial)``
                tensors whose resolution decreases along ``in_indices``
                (TODO confirm).

        Returns:
            list[Tensor]: the pyramid levels listed in ``sorted(out_indices)``,
            ordered from high to low resolution. Each level has
            ``out_channels`` channels.
        """
        # The coarsest level has no top-down input; it is only projected.
        prev_features = self.lateral_convs[0](
            bottom_up_features[self.in_indices[-1]]
        )
        results = [self.output_convs[0](prev_features)]

        # The convs live in plain Python lists (not nn.ModuleList), so slicing
        # off the already-processed coarsest level is fine.
        for level, (lateral_conv, output_conv) in enumerate(
            zip(self.lateral_convs[1:], self.output_convs[1:]), start=1
        ):
            features = bottom_up_features[self.in_indices[-level - 1]]
            lateral_features = lateral_conv(features)
            # Upsample to the lateral map's exact spatial size instead of by a
            # fixed factor of 2. For exact 2x pyramids the result is identical;
            # otherwise (e.g. a coarser map derived from an odd-sized one) a
            # fixed factor would produce a shape mismatch in the fusion below.
            top_down_features = F.interpolate(
                prev_features, size=lateral_features.shape[2:], mode="nearest"
            )
            prev_features = lateral_features + top_down_features
            if self._fuse_type == "avg":
                prev_features /= 2
            results.append(output_conv(prev_features))

        # results was built from low to high resolution; flip it so that index
        # 0 is the highest-resolution level, as out_indices expects.
        results.reverse()
        return [results[x] for x in sorted(self.out_indices)]