41 lines
1022 B
Python
41 lines
1022 B
Python
from typing import List, Tuple
|
|
|
|
import torch
|
|
from mmcv.cnn.resnet import BasicBlock, make_res_layer
|
|
from torch import nn
|
|
|
|
from mmdet.models import BACKBONES
|
|
|
|
__all__ = ["GeneralizedResNet"]
|
|
|
|
|
|
@BACKBONES.register_module()
|
|
class GeneralizedResNet(nn.ModuleList):
|
|
def __init__(
|
|
self,
|
|
in_channels: int,
|
|
blocks: List[Tuple[int, int, int]],
|
|
) -> None:
|
|
super().__init__()
|
|
self.in_channels = in_channels
|
|
self.blocks = blocks
|
|
|
|
for num_blocks, out_channels, stride in self.blocks:
|
|
blocks = make_res_layer(
|
|
BasicBlock,
|
|
in_channels,
|
|
out_channels,
|
|
num_blocks,
|
|
stride=stride,
|
|
dilation=1,
|
|
)
|
|
in_channels = out_channels
|
|
self.append(blocks)
|
|
|
|
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
|
outputs = []
|
|
for module in self:
|
|
x = module(x)
|
|
outputs.append(x)
|
|
return outputs
|