bev-project/mmdet3d/models/backbones/resnet.py

41 lines
1022 B
Python
Raw Permalink Normal View History

2022-06-03 12:21:18 +08:00
from typing import List, Tuple
import torch
from mmcv.cnn.resnet import BasicBlock, make_res_layer
from torch import nn
from mmdet.models import BACKBONES
__all__ = ["GeneralizedResNet"]
@BACKBONES.register_module()
class GeneralizedResNet(nn.ModuleList):
def __init__(
self,
in_channels: int,
blocks: List[Tuple[int, int, int]],
) -> None:
super().__init__()
self.in_channels = in_channels
self.blocks = blocks
for num_blocks, out_channels, stride in self.blocks:
blocks = make_res_layer(
BasicBlock,
in_channels,
out_channels,
num_blocks,
stride=stride,
dilation=1,
)
in_channels = out_channels
self.append(blocks)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
outputs = []
for module in self:
x = module(x)
outputs.append(x)
return outputs