from typing import Tuple

from mmcv.runner import force_fp32
from torch import nn

from mmdet3d.models.builder import VTRANSFORMS

from .base import BaseTransform

__all__ = ["LSSTransform"]

@VTRANSFORMS.register_module()
class LSSTransform(BaseTransform):
    """Lift-Splat-Shoot camera-to-BEV view transform.

    Extends :class:`BaseTransform` with a 1x1 "depthnet" convolution that
    predicts, for every image-feature pixel, a categorical depth
    distribution over ``self.D`` bins plus ``self.C`` context channels,
    and an optional 2x BEV-space downsampling head.

    ``self.D`` and ``self.C`` are presumably set by ``BaseTransform``
    from ``dbound`` / ``out_channels`` — confirm against the base class.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        image_size: Tuple[int, int],
        feature_size: Tuple[int, int],
        xbound: Tuple[float, float, float],
        ybound: Tuple[float, float, float],
        zbound: Tuple[float, float, float],
        dbound: Tuple[float, float, float],
        downsample: int = 1,
    ) -> None:
        """Build the transform.

        Args:
            in_channels: Channels of the incoming image features.
            out_channels: Channels of the produced BEV features.
            image_size: Input image (height, width).
            feature_size: Image-feature map (height, width).
            xbound: (min, max, step) of the BEV grid along x.
            ybound: (min, max, step) of the BEV grid along y.
            zbound: (min, max, step) of the BEV grid along z.
            dbound: (min, max, step) of the depth bins.
            downsample: BEV downsampling factor; only 1 (identity)
                and 2 are supported.

        Raises:
            ValueError: If ``downsample`` is greater than 1 but not 2.
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            image_size=image_size,
            feature_size=feature_size,
            xbound=xbound,
            ybound=ybound,
            zbound=zbound,
            dbound=dbound,
        )
        # One 1x1 conv jointly predicts depth logits (first D channels)
        # and context features (next C channels) per pixel.
        self.depthnet = nn.Conv2d(in_channels, self.D + self.C, 1)
        if downsample > 1:
            # Explicit raise instead of `assert`: asserts are stripped
            # under `python -O`, which would silently build a conv with
            # an unsupported stride and change the output resolution.
            if downsample != 2:
                raise ValueError(
                    f"downsample must be 1 or 2, got {downsample}"
                )
            # Conv-BN-ReLU stack whose middle conv (stride=2) performs
            # the actual spatial reduction.
            self.downsample = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    3,
                    stride=downsample,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            )
        else:
            self.downsample = nn.Identity()

    @force_fp32()
    def get_cam_feats(self, x):
        """Lift per-camera image features into a depth-distributed volume.

        Args:
            x: Image features of shape ``(B, N, C_in, fH, fW)`` where
                ``B`` is batch size and ``N`` the number of cameras.

        Returns:
            Tensor of shape ``(B, N, D, fH, fW, C)``: context features
            weighted by the predicted per-pixel depth distribution.
        """
        B, N, C, fH, fW = x.shape

        # Fold cameras into the batch dim so depthnet runs once.
        x = x.view(B * N, C, fH, fW)

        x = self.depthnet(x)
        # First D channels -> categorical depth distribution per pixel.
        depth = x[:, : self.D].softmax(dim=1)
        # Outer product: (B*N, 1, D, fH, fW) * (B*N, C, 1, fH, fW)
        # -> (B*N, C, D, fH, fW) depth-weighted context features.
        x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2)

        x = x.view(B, N, self.C, self.D, fH, fW)
        # Channels-last layout expected by the splat step downstream.
        x = x.permute(0, 1, 3, 4, 5, 2)
        return x

    def forward(self, *args, **kwargs):
        """Run the base view transform, then the BEV downsampling head."""
        x = super().forward(*args, **kwargs)
        x = self.downsample(x)
        return x
|