bev-project/mmdet3d/models/vtransforms/aware_bevdepth.py

698 lines
24 KiB
Python
Raw Normal View History

2023-07-08 10:53:36 +08:00
from typing import Tuple
from mmcv.cnn import build_conv_layer
from mmcv.runner import force_fp32
from torch import nn
import torch.nn.functional as F
from torch.cuda.amp.autocast_mode import autocast
from mmdet3d.models.builder import VTRANSFORMS
from mmdet.models.backbones.resnet import BasicBlock
from .base import BaseTransform, BaseDepthTransform
import torch
# Public API of this module: both registered view transforms defined below.
__all__ = ["AwareBEVDepth", "AwareDBEVDepth"]
class DepthRefinement(nn.Module):
    """Pixel cloud feature extraction.

    Three stages run in sequence: a 3x3 conv mapping ``in_channels`` to
    ``mid_channels``, a residual stage of two 3x3 convs at ``mid_channels``,
    and a final biased 3x3 conv producing ``out_channels``. Every conv uses
    stride 1 and padding 1, so the last two input dims keep their size.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super(DepthRefinement, self).__init__()

        def conv_bn_relu(n_in, n_out):
            # Bias is omitted because the following BatchNorm has its own shift.
            return [
                nn.Conv2d(n_in,
                          n_out,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=False),
                nn.BatchNorm2d(n_out),
                nn.ReLU(inplace=True),
            ]

        self.reduce_conv = nn.Sequential(
            *conv_bn_relu(in_channels, mid_channels))
        self.conv = nn.Sequential(
            *conv_bn_relu(mid_channels, mid_channels),
            *conv_bn_relu(mid_channels, mid_channels),
        )
        # Output projection: plain conv with bias, no norm or activation.
        self.out_conv = nn.Sequential(
            nn.Conv2d(mid_channels,
                      out_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True),
        )

    @autocast(False)
    def forward(self, x):
        """Refine ``x``; autocast is disabled for the duration of the call."""
        reduced = self.reduce_conv(x)
        # Residual connection around the two-conv stage.
        refined = self.conv(reduced) + reduced
        return self.out_conv(refined)
class _ASPPModule(nn.Module):
    """One ASPP branch: dilated conv, then normalisation, then ReLU.

    ``BatchNorm`` is a callable that takes the channel count and returns the
    normalisation layer.
    """

    def __init__(self, inplanes, planes, kernel_size, padding, dilation,
                 BatchNorm):
        super(_ASPPModule, self).__init__()
        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.atrous_conv = nn.Conv2d(inplanes, planes, **conv_kwargs)
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()
        self._init_weight()

    def forward(self, x):
        return self.relu(self.bn(self.atrous_conv(x)))

    def _init_weight(self):
        """Kaiming-normal init for convs; unit weight / zero bias for BN2d."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
                continue
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
class ASPP(nn.Module):
    """Atrous spatial pyramid pooling.

    Four parallel ``_ASPPModule`` branches (a 1x1 conv and three 3x3 convs
    with dilations 6, 12 and 18) plus a globally pooled branch are
    concatenated along the channel dim, fused by a 1x1 conv + norm + ReLU
    and passed through dropout (p=0.5).
    """

    def __init__(self, inplanes, mid_channels=256, BatchNorm=nn.BatchNorm2d):
        super(ASPP, self).__init__()
        rates = (1, 6, 12, 18)

        def make_branch(kernel_size, padding, rate):
            return _ASPPModule(inplanes,
                               mid_channels,
                               kernel_size,
                               padding=padding,
                               dilation=rate,
                               BatchNorm=BatchNorm)

        # For the 3x3 branches padding equals dilation, which keeps the
        # spatial size unchanged.
        self.aspp1 = make_branch(1, 0, rates[0])
        self.aspp2 = make_branch(3, rates[1], rates[1])
        self.aspp3 = make_branch(3, rates[2], rates[2])
        self.aspp4 = make_branch(3, rates[3], rates[3])

        # Image-level branch: pool to 1x1, then project.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, mid_channels, 1, stride=1, bias=False),
            BatchNorm(mid_channels),
            nn.ReLU(),
        )

        # Fusion over the five concatenated branches.
        self.conv1 = nn.Conv2d(int(mid_channels * 5),
                               mid_channels,
                               1,
                               bias=False)
        self.bn1 = BatchNorm(mid_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        branch_outs = [
            self.aspp1(x),
            self.aspp2(x),
            self.aspp3(x),
            self.aspp4(x),
        ]
        # Upsample the pooled 1x1 map back to the branch resolution.
        pooled = F.interpolate(self.global_avg_pool(x),
                               size=branch_outs[-1].size()[2:],
                               mode='bilinear',
                               align_corners=True)
        fused = torch.cat(branch_outs + [pooled], dim=1)
        fused = self.relu(self.bn1(self.conv1(fused)))
        return self.dropout(fused)

    def _init_weight(self):
        """Kaiming-normal init for convs; unit weight / zero bias for BN2d.

        ``self.modules()`` also walks into the branch sub-modules, so their
        layers are initialised again here.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
                continue
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: Size of the last input dimension.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the output; falls back to ``in_features``
            when falsy.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability used after both linear layers.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.ReLU,
                 drop=0.0):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        stages = (self.fc1, self.act, self.drop1, self.fc2, self.drop2)
        for stage in stages:
            x = stage(x)
        return x
class SELayer(nn.Module):
    """Channel gating driven by an external descriptor.

    ``x_se`` goes through 1x1 conv -> activation -> 1x1 conv -> gate, and the
    result multiplies ``x`` element-wise (with broadcasting). Both convs keep
    the channel count at ``channels``.
    """

    def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid):
        super().__init__()
        self.conv_reduce = nn.Conv2d(channels, channels, 1, bias=True)
        self.act1 = act_layer()
        self.conv_expand = nn.Conv2d(channels, channels, 1, bias=True)
        self.gate = gate_layer()

    def forward(self, x, x_se):
        excitation = self.conv_expand(self.act1(self.conv_reduce(x_se)))
        return x * self.gate(excitation)
class DepthNet(nn.Module):
    """Camera-aware depth and context head.

    Image features are reduced by a 3x3 conv and then split into two
    branches, each gated by an ``SELayer`` whose excitation comes from an
    ``Mlp`` over a 27-value vector of camera/augmentation parameters:

    * context branch: SE gate -> 1x1 conv to ``context_channels``
    * depth branch: SE gate -> 3 ``BasicBlock`` -> ``ASPP`` + conv + BN
      -> 1x1 conv + BN to ``depth_channels``

    ``forward`` returns both concatenated along dim 1, depth first.

    Args:
        in_channels: Channels of the incoming feature map.
        mid_channels: Width used inside both branches.
        context_channels: Channels of the context output.
        depth_channels: Channels of the depth output.
    """

    def __init__(self, in_channels, mid_channels, context_channels,
                 depth_channels):
        super(DepthNet, self).__init__()
        self.reduce_conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      mid_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )
        self.context_conv = nn.Conv2d(mid_channels,
                                      context_channels,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0)
        # 27 = 15 scalars picked from intrinsics / ida / bda matrices plus the
        # 12 entries of the flattened 3x4 sensor2ego block (see
        # ``_build_mlp_input``).
        self.bn = nn.BatchNorm1d(27)
        self.depth_mlp = Mlp(27, mid_channels, mid_channels)
        self.depth_se = SELayer(mid_channels)  # NOTE: add camera-aware
        self.context_mlp = Mlp(27, mid_channels, mid_channels)
        self.context_se = SELayer(mid_channels)  # NOTE: add camera-aware
        self.depth_conv_1 = nn.Sequential(
            BasicBlock(mid_channels, mid_channels),
            BasicBlock(mid_channels, mid_channels),
            BasicBlock(mid_channels, mid_channels),
        )
        self.depth_conv_2 = nn.Sequential(
            ASPP(mid_channels, mid_channels),
            build_conv_layer(cfg=dict(
                type='Conv2d',
                in_channels=mid_channels,
                out_channels=mid_channels,
                kernel_size=3,
                padding=1,
            )),
            nn.BatchNorm2d(mid_channels),
        )
        self.depth_conv_3 = nn.Sequential(
            nn.Conv2d(mid_channels,
                      depth_channels,
                      kernel_size=1,
                      stride=1,
                      padding=0),
            nn.BatchNorm2d(depth_channels),
        )
        self.export = False

    def export_mode(self):
        """Switch to export mode, where a cached ``mlp_input`` is reused."""
        self.export = True

    def _build_mlp_input(self, mats_dict):
        """Assemble and batch-normalise the 27-value camera descriptor.

        Reads ``intrin_mats``, ``ida_mats``, ``sensor2ego_mats`` and
        ``bda_mat`` from ``mats_dict``. The first two leading dims of
        ``intrin_mats`` are used as batch size and number of cameras;
        ``bda_mat`` is viewed as ``(batch, 1, 4, 4)`` and repeated per camera.

        Returns:
            Tensor of shape ``(batch * num_cams, 27)``.
        """
        intrins = mats_dict['intrin_mats'][..., :3, :3]
        batch_size, num_cams = intrins.shape[0], intrins.shape[1]
        ida = mats_dict['ida_mats']
        sensor2ego = mats_dict['sensor2ego_mats'][..., :3, :]
        bda = mats_dict['bda_mat'].view(batch_size, 1, 4, 4).repeat(
            1, num_cams, 1, 1)

        scalars = torch.stack(
            [
                intrins[..., 0, 0],
                intrins[..., 1, 1],
                intrins[..., 0, 2],
                intrins[..., 1, 2],
                ida[..., 0, 0],
                ida[..., 0, 1],
                ida[..., 0, 3],
                ida[..., 1, 0],
                ida[..., 1, 1],
                ida[..., 1, 3],
                bda[..., 0, 0],
                bda[..., 0, 1],
                bda[..., 1, 0],
                bda[..., 1, 1],
                bda[..., 2, 2],
            ],
            dim=-1,
        )
        mlp_input = torch.cat(
            [scalars, sensor2ego.view(batch_size, num_cams, -1)], -1)
        return self.bn(mlp_input.reshape(-1, mlp_input.shape[-1]))

    @force_fp32()
    def forward(self, x, mats_dict):
        """Compute depth and context maps for ``x``.

        Args:
            x: Feature map whose batch dim must match ``batch * num_cams``
                derived from ``mats_dict`` (the SE excitation has one row per
                camera and is broadcast over the spatial dims).
            mats_dict: Mapping with the camera matrices described in
                ``_build_mlp_input``.

        Returns:
            ``torch.cat([depth, context], dim=1)``.
        """
        # If exporting, reuse the cached MLP input, since it's based on
        # intrinsics and data augmentation, which are constant at inference
        # time. Otherwise recompute it on every call.
        if self.export and hasattr(self, 'mlp_input'):
            mlp_input = self.mlp_input
        else:
            mlp_input = self._build_mlp_input(mats_dict)
            # Cache a detached copy only: keeping the graph-attached tensor on
            # the module would hold the previous step's autograd graph alive
            # and leave a non-leaf tensor attribute that breaks
            # ``copy.deepcopy``. Gradients still flow through the local.
            self.mlp_input = mlp_input.detach()

        x = self.reduce_conv(x)

        context_se = self.context_mlp(mlp_input)[..., None, None]
        context = self.context_se(x, context_se)
        context = self.context_conv(context)

        depth_se = self.depth_mlp(mlp_input)[..., None, None]
        depth = self.depth_se(x, depth_se)
        depth = self.depth_conv_1(depth)
        depth = self.depth_conv_2(depth)
        depth = self.depth_conv_3(depth)

        return torch.cat([depth, context], dim=1)
@VTRANSFORMS.register_module()
class AwareBEVDepth(BaseTransform):
    """Camera-aware BEVDepth view transform.

    A ``DepthNet`` predicts, per image-feature pixel, ``self.D`` depth scores
    and ``self.C`` context channels. The softmaxed depth is multiplied with
    the context (outer product), optionally refined by ``DepthRefinement``,
    and handed to the base class. An optional depth loss compares the
    predicted distribution against down-sampled ground-truth depth maps.

    ``self.C``, ``self.D``, ``self.frustum`` and ``self.dbound`` are read but
    not defined here; presumably ``BaseTransform`` sets them - verify against
    the base class.

    Args:
        in_channels, out_channels, image_size, feature_size, xbound, ybound,
            zbound, dbound, use_points: Forwarded unchanged to the base class.
        downsample: ``1`` for no BEV down-sampling (``nn.Identity``); the only
            other accepted value is ``2``, which adds a stride-2 conv stack.
        bevdepth_downsample: Factor by which ground-truth depth maps are
            min-pooled in ``get_downsampled_gt_depth``.
        bevdepth_refine: Whether to apply ``DepthRefinement``.
        depth_loss_factor: Multiplier applied to the depth loss.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        image_size: Tuple[int, int],
        feature_size: Tuple[int, int],
        xbound: Tuple[float, float, float],
        ybound: Tuple[float, float, float],
        zbound: Tuple[float, float, float],
        dbound: Tuple[float, float, float],
        use_points='lidar',
        downsample: int = 1,
        bevdepth_downsample: int = 16,
        bevdepth_refine: bool = True,
        depth_loss_factor: float = 3.0,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            image_size=image_size,
            feature_size=feature_size,
            xbound=xbound,
            ybound=ybound,
            zbound=zbound,
            dbound=dbound,
            use_points=use_points,
        )
        self.depth_loss_factor = depth_loss_factor
        self.downsample_factor = bevdepth_downsample
        self.bevdepth_refine = bevdepth_refine
        if self.bevdepth_refine:
            self.refinement = DepthRefinement(self.C, self.C, self.C)

        self.depth_channels = self.frustum.shape[0]

        mid_channels = in_channels
        self.depthnet = DepthNet(
            in_channels,
            mid_channels,
            self.C,
            self.D,
        )

        if downsample > 1:
            assert downsample == 2, downsample
            self.downsample = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    3,
                    stride=downsample,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            )
        else:
            self.downsample = nn.Identity()

    def export_mode(self):
        """Enable export mode on the base class and on the depth net."""
        super().export_mode()
        self.depthnet.export_mode()

    @force_fp32()
    def get_cam_feats(self, x, mats_dict):
        """Lift image features into per-depth-bin features.

        Args:
            x: Tensor of shape ``(B, N, C, fH, fW)``.
            mats_dict: Camera matrices passed through to ``DepthNet``.

        Returns:
            Tuple ``(feats, depth)`` where ``feats`` has shape
            ``(B, N, D, fH, fW, C)`` and ``depth`` is the softmaxed depth
            distribution of shape ``(B * N, D, fH, fW)``.
        """
        B, N, C, fH, fW = x.shape

        x = x.view(B * N, C, fH, fW)
        x = self.depthnet(x, mats_dict)

        # First D channels are depth scores, the next C channels are context.
        depth = x[:, : self.D].softmax(dim=1)
        x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2)

        if self.bevdepth_refine:
            # [n, c, d, h, w] -> [n, h, c, w, d]
            x = x.permute(0, 3, 1, 4, 2).contiguous()
            n, h, c, w, num_bins = x.shape
            # Fold image rows into the batch so refinement runs over (w, d).
            x = x.view(-1, c, w, num_bins)
            x = self.refinement(x)
            # Back to [n, c, d, h, w].
            x = x.view(n, h, c, w, num_bins).permute(
                0, 2, 4, 1, 3).contiguous().float()

        x = x.view(B, N, self.C, self.D, fH, fW)
        x = x.permute(0, 1, 3, 4, 5, 2)
        return x, depth

    def get_depth_loss(self, depth_labels, depth_preds):
        """Binary cross-entropy between predicted and ground-truth depth bins.

        Args:
            depth_labels: Ground-truth depth maps, ``(B, N, H, W)``; a 5-D
                input is reduced to its index-0 slice along dim 1.
            depth_preds: Depth probabilities of shape ``(B * N, D, h, w)``.

        Returns:
            Scalar loss, summed over foreground pixels, divided by their
            count (at least 1) and scaled by ``depth_loss_factor``.
        """
        if depth_labels.dim() == 5:
            # only key-frame will calculate depth loss
            depth_labels = depth_labels[:, 0, ...]

        depth_labels = self.get_downsampled_gt_depth(depth_labels)
        depth_preds = depth_preds.permute(0, 2, 3, 1).contiguous().view(
            -1, self.depth_channels)
        # Pixels with no valid ground-truth bin have an all-zero one-hot row.
        fg_mask = torch.max(depth_labels, dim=1).values > 0.0
        # Tensor-side clamp instead of Python ``max`` on a tensor.
        num_fg = fg_mask.sum().clamp(min=1)

        with autocast(enabled=False):
            # Disabling autocast does not upcast existing tensors, so cast
            # explicitly to keep BCE in fp32.
            bce = F.binary_cross_entropy(
                depth_preds[fg_mask].float(),
                depth_labels[fg_mask].float(),
                reduction='none',
            )
            depth_loss = bce.sum() / num_fg

        return self.depth_loss_factor * depth_loss

    def get_downsampled_gt_depth(self, gt_depths):
        """Min-pool ground-truth depth and one-hot encode it into depth bins.

        Input:
            gt_depths: [B, N, H, W]
        Output:
            gt_depths: [B*N*h*w, d]

        ``H`` and ``W`` must be divisible by ``self.downsample_factor``.
        Zero depths are treated as missing; cells with no valid depth, or
        whose depth falls outside the bin range, yield an all-zero row.
        """
        B, N, H, W = gt_depths.shape
        factor = self.downsample_factor
        out_h, out_w = H // factor, W // factor

        # Split each spatial dim into (cells, within-cell offset).
        gt_depths = gt_depths.view(B * N, out_h, factor, out_w, factor, 1)
        gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous()
        gt_depths = gt_depths.view(-1, factor * factor)

        # Replace missing (zero) depths by a large value so ``min`` skips them.
        gt_depths_tmp = torch.where(gt_depths == 0.0,
                                    torch.full_like(gt_depths, 1e5),
                                    gt_depths)
        gt_depths = torch.min(gt_depths_tmp, dim=-1).values
        gt_depths = gt_depths.view(B * N, out_h, out_w)

        # Metric depth -> bin index, shifted by one so index 0 means "none".
        gt_depths = (gt_depths -
                     (self.dbound[0] - self.dbound[2])) / self.dbound[2]
        gt_depths = torch.where(
            (gt_depths < self.depth_channels + 1) & (gt_depths >= 0.0),
            gt_depths, torch.zeros_like(gt_depths))

        # Drop the "none" column after one-hot encoding.
        gt_depths = F.one_hot(gt_depths.long(),
                              num_classes=self.depth_channels + 1).view(
                                  -1, self.depth_channels + 1)[:, 1:]
        return gt_depths.float()

    def forward(self, *args, **kwargs):
        """Run the base transform, down-sample, and optionally add depth loss.

        The base ``forward`` result is indexed with ``[0]`` for the BEV
        feature and ``[-1]`` for the depth prediction. When the keyword
        ``depth_loss`` is truthy, ``gt_depths`` must also be given and
        ``(x, depth_loss)`` is returned; otherwise only ``x``.
        """
        x = super().forward(*args, **kwargs)
        x, depth_pred = x[0], x[-1]
        x = self.downsample(x)
        if kwargs.get('depth_loss', False):
            depth_loss = self.get_depth_loss(kwargs['gt_depths'], depth_pred)
            return x, depth_loss
        else:
            return x
@VTRANSFORMS.register_module()
class AwareDBEVDepth(BaseDepthTransform):
    """Camera-aware BEVDepth view transform with an extra depth-map input.

    Same pipeline as a BEVDepth-style lift (depth distribution x context,
    optional ``DepthRefinement``), except that an input depth tensor is first
    encoded by ``self.dtransform`` into 64 channels and concatenated with the
    image features before ``DepthNet``.

    ``self.C``, ``self.D``, ``self.frustum``, ``self.dbound`` and
    ``self.add_depth_features`` are read but not defined here; presumably
    ``BaseDepthTransform`` sets them - verify against the base class.

    Args:
        in_channels, out_channels, image_size, feature_size, xbound, ybound,
            zbound, dbound, use_points, depth_input, height_expand,
            add_depth_features: Forwarded unchanged to the base class.
        depth_input: Additionally selects the depth encoder input width:
            ``'scalar'`` -> 1 channel, anything else -> ``self.D`` channels.
        downsample: ``1`` for no BEV down-sampling (``nn.Identity``); the only
            other accepted value is ``2``, which adds a stride-2 conv stack.
        bevdepth_downsample: Factor by which ground-truth depth maps are
            min-pooled in ``get_downsampled_gt_depth``.
        bevdepth_refine: Whether to apply ``DepthRefinement``.
        depth_loss_factor: Multiplier applied to the depth loss.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        image_size: Tuple[int, int],
        feature_size: Tuple[int, int],
        xbound: Tuple[float, float, float],
        ybound: Tuple[float, float, float],
        zbound: Tuple[float, float, float],
        dbound: Tuple[float, float, float],
        use_points='lidar',
        depth_input='scalar',
        height_expand=False,
        downsample: int = 1,
        bevdepth_downsample: int = 16,
        bevdepth_refine: bool = True,
        depth_loss_factor: float = 3.0,
        add_depth_features=False,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            image_size=image_size,
            feature_size=feature_size,
            xbound=xbound,
            ybound=ybound,
            zbound=zbound,
            dbound=dbound,
            use_points=use_points,
            depth_input=depth_input,
            height_expand=height_expand,
            add_depth_features=add_depth_features,
        )
        self.depth_loss_factor = depth_loss_factor
        self.downsample_factor = bevdepth_downsample
        self.bevdepth_refine = bevdepth_refine
        if self.bevdepth_refine:
            self.refinement = DepthRefinement(self.C, self.C, self.C)

        self.depth_channels = self.frustum.shape[0]

        # Output width of ``self.dtransform``; its features are concatenated
        # with the image features in ``get_cam_feats``.
        dtransform_out_channels = 64

        mid_channels = in_channels
        self.depthnet = DepthNet(
            in_channels + dtransform_out_channels,
            mid_channels,
            self.C,
            self.D,
        )

        dtransform_in_channels = 1 if depth_input == 'scalar' else self.D
        if self.add_depth_features:
            # NOTE(review): 45 extra input channels - meaning is defined by
            # the base class / data pipeline; confirm there.
            dtransform_in_channels += 45

        # Both input modes share the same encoder; only the width of the
        # 1x1 stem differs. Strides 4 * 2 * 2 give a total reduction of 16.
        stem_channels = 8 if depth_input == 'scalar' else 32
        self.dtransform = nn.Sequential(
            nn.Conv2d(dtransform_in_channels, stem_channels, 1),
            nn.BatchNorm2d(stem_channels),
            nn.ReLU(True),
            nn.Conv2d(stem_channels, 32, 5, stride=4, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, dtransform_out_channels, 5, stride=2, padding=2),
            nn.BatchNorm2d(dtransform_out_channels),
            nn.ReLU(True),
        )

        if downsample > 1:
            assert downsample == 2, downsample
            self.downsample = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    3,
                    stride=downsample,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            )
        else:
            self.downsample = nn.Identity()

    @force_fp32()
    def get_cam_feats(self, x, d, mats_dict):
        """Lift image + depth-input features into per-depth-bin features.

        Args:
            x: Image features of shape ``(B, N, C, fH, fW)``.
            d: Depth input whose first two dims are ``(B, N)``; after
                ``self.dtransform`` its spatial size must equal ``(fH, fW)``
                so it can be concatenated with ``x``.
            mats_dict: Camera matrices passed through to ``DepthNet``.

        Returns:
            Tuple ``(feats, depth)`` where ``feats`` has shape
            ``(B, N, D, fH, fW, C)`` and ``depth`` is the softmaxed depth
            distribution of shape ``(B * N, D, fH, fW)``.
        """
        B, N, C, fH, fW = x.shape

        d = d.view(B * N, *d.shape[2:])
        x = x.view(B * N, C, fH, fW)

        d = self.dtransform(d)
        x = torch.cat([d, x], dim=1)
        x = self.depthnet(x, mats_dict)

        # First D channels are depth scores, the next C channels are context.
        depth = x[:, : self.D].softmax(dim=1)
        x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2)

        if self.bevdepth_refine:
            # [n, c, d, h, w] -> [n, h, c, w, d]
            x = x.permute(0, 3, 1, 4, 2).contiguous()
            n, h, c, w, num_bins = x.shape
            # Fold image rows into the batch so refinement runs over (w, d).
            x = x.view(-1, c, w, num_bins)
            x = self.refinement(x)
            # Back to [n, c, d, h, w].
            x = x.view(n, h, c, w, num_bins).permute(
                0, 2, 4, 1, 3).contiguous().float()

        # Here, x.shape is [num_cams, num_channels, depth_bins,
        # downsampled_height, downsampled_width]
        x = x.view(B, N, self.C, self.D, fH, fW)
        x = x.permute(0, 1, 3, 4, 5, 2)
        return x, depth

    def export_mode(self):
        """Enable export mode on the base class and on the depth net."""
        super().export_mode()
        self.depthnet.export_mode()

    def get_depth_loss(self, depth_labels, depth_preds):
        """Binary cross-entropy between predicted and ground-truth depth bins.

        Args:
            depth_labels: Ground-truth depth maps of shape ``(B, N, H, W)``.
                Unlike the key-frame handling that is disabled in this class,
                no 5-D input is accepted here.
            depth_preds: Depth probabilities of shape ``(B * N, D, h, w)``.

        Returns:
            Scalar loss, summed over foreground pixels, divided by their
            count (at least 1) and scaled by ``depth_loss_factor``.
        """
        depth_labels = self.get_downsampled_gt_depth(depth_labels)
        depth_preds = depth_preds.permute(0, 2, 3, 1).contiguous().view(
            -1, self.depth_channels)
        # Pixels with no valid ground-truth bin have an all-zero one-hot row.
        fg_mask = torch.max(depth_labels, dim=1).values > 0.0
        # Tensor-side clamp instead of Python ``max`` on a tensor.
        num_fg = fg_mask.sum().clamp(min=1)

        with autocast(enabled=False):
            # Disabling autocast does not upcast existing tensors, so cast
            # explicitly to keep BCE in fp32.
            bce = F.binary_cross_entropy(
                depth_preds[fg_mask].float(),
                depth_labels[fg_mask].float(),
                reduction='none',
            )
            depth_loss = bce.sum() / num_fg

        return self.depth_loss_factor * depth_loss

    def get_downsampled_gt_depth(self, gt_depths):
        """Min-pool ground-truth depth and one-hot encode it into depth bins.

        Input:
            gt_depths: [B, N, H, W]
        Output:
            gt_depths: [B*N*h*w, d]

        ``H`` and ``W`` must be divisible by ``self.downsample_factor``.
        Zero depths are treated as missing; cells with no valid depth, or
        whose depth falls outside the bin range, yield an all-zero row.
        """
        B, N, H, W = gt_depths.shape
        factor = self.downsample_factor
        out_h, out_w = H // factor, W // factor

        # Split each spatial dim into (cells, within-cell offset).
        gt_depths = gt_depths.view(B * N, out_h, factor, out_w, factor, 1)
        gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous()
        gt_depths = gt_depths.view(-1, factor * factor)

        # Replace missing (zero) depths by a large value so ``min`` skips them.
        gt_depths_tmp = torch.where(gt_depths == 0.0,
                                    torch.full_like(gt_depths, 1e5),
                                    gt_depths)
        gt_depths = torch.min(gt_depths_tmp, dim=-1).values
        gt_depths = gt_depths.view(B * N, out_h, out_w)

        # Metric depth -> bin index, shifted by one so index 0 means "none".
        gt_depths = (gt_depths -
                     (self.dbound[0] - self.dbound[2])) / self.dbound[2]
        gt_depths = torch.where(
            (gt_depths < self.depth_channels + 1) & (gt_depths >= 0.0),
            gt_depths, torch.zeros_like(gt_depths))

        # Drop the "none" column after one-hot encoding.
        gt_depths = F.one_hot(gt_depths.long(),
                              num_classes=self.depth_channels + 1).view(
                                  -1, self.depth_channels + 1)[:, 1:]
        return gt_depths.float()

    def forward(self, *args, **kwargs):
        """Run the base transform, down-sample, and optionally add depth loss.

        The base ``forward`` result is indexed with ``[0]`` for the BEV
        feature and ``[-1]`` for the depth prediction. When the keyword
        ``depth_loss`` is truthy, ``gt_depths`` must also be given and
        ``(x, depth_loss)`` is returned; otherwise only ``x``.
        """
        x = super().forward(*args, **kwargs)
        x, depth_pred = x[0], x[-1]
        x = self.downsample(x)
        if kwargs.get('depth_loss', False):
            depth_loss = self.get_depth_loss(kwargs['gt_depths'], depth_pred)
            return x, depth_loss
        else:
            return x