# bev-project/mmdet3d/models/vtransforms/lss.py

from typing import Tuple

from mmcv.runner import force_fp32
from torch import nn

from mmdet3d.models.builder import VTRANSFORMS

from .base import BaseTransform

__all__ = ["LSSTransform"]


@VTRANSFORMS.register_module()
class LSSTransform(BaseTransform):
    """Lift-Splat-Shoot (LSS) view transform.

    Lifts per-camera image features into the BEV grid defined by
    ``xbound``/``ybound``/``zbound``, predicting a categorical depth
    distribution over ``dbound`` with a 1x1 depth network. An optional
    stride-2 convolutional block downsamples the resulting BEV feature map.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        image_size: Tuple[int, int],
        feature_size: Tuple[int, int],
        xbound: Tuple[float, float, float],
        ybound: Tuple[float, float, float],
        zbound: Tuple[float, float, float],
        dbound: Tuple[float, float, float],
        downsample: int = 1,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            image_size=image_size,
            feature_size=feature_size,
            xbound=xbound,
            ybound=ybound,
            zbound=zbound,
            dbound=dbound,
        )
        # A single 1x1 conv predicts D depth bins plus C context channels per pixel.
        self.depthnet = nn.Conv2d(in_channels, self.D + self.C, 1)
        if downsample > 1:
            # Only a factor-of-2 BEV downsample is supported.
            assert downsample == 2, downsample
            self.downsample = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    3,
                    stride=downsample,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            )
        else:
            self.downsample = nn.Identity()

    @force_fp32()
    def get_cam_feats(self, x):
        B, N, C, fH, fW = x.shape

        x = x.view(B * N, C, fH, fW)

        x = self.depthnet(x)
        # Split the prediction into a per-pixel depth distribution (D bins)
        # and C context channels, then take their outer product (the "lift").
        depth = x[:, : self.D].softmax(dim=1)
        x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2)

        x = x.view(B, N, self.C, self.D, fH, fW)
        x = x.permute(0, 1, 3, 4, 5, 2)
        return x

    def forward(self, *args, **kwargs):
        x = super().forward(*args, **kwargs)
        x = self.downsample(x)
        return x
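

# The snippet below is a minimal usage sketch and is not part of the original
# module. The concrete values (channel counts, image/feature sizes, and the
# (start, stop, step) bounds) are illustrative assumptions; in practice they
# come from the model config that instantiates this transform.
if __name__ == "__main__":
    cfg = dict(
        type="LSSTransform",
        in_channels=256,
        out_channels=80,
        image_size=(256, 704),
        feature_size=(32, 88),
        xbound=(-54.0, 54.0, 0.3),
        ybound=(-54.0, 54.0, 0.3),
        zbound=(-10.0, 10.0, 20.0),
        dbound=(1.0, 60.0, 0.5),
        downsample=2,
    )
    # Build through the VTRANSFORMS registry, the same mechanism a detector
    # config would use to construct this module.
    vtransform = VTRANSFORMS.build(cfg)
    print(vtransform)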