from typing import Any, Dict

import torch
from mmcv.runner import auto_fp16, force_fp32
from torch import nn
from torch.nn import functional as F

from mmdet3d.models.builder import (
    build_backbone,
    build_fuser,
    build_head,
    build_neck,
    build_vtransform,
)
from mmdet3d.ops import Voxelization, DynamicScatter
from mmdet3d.models import FUSIONMODELS

from .base import Base3DFusionModel

__all__ = ["BEVFusion"]
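
# Usage sketch (illustrative only): the class is registered in FUSIONMODELS, so it
# is normally built from a config dict rather than instantiated directly. The keys
# below mirror the constructor arguments; the concrete sub-module configs and their
# "type" values are placeholders that depend on what is registered in this codebase.
#
#   model = FUSIONMODELS.build(
#       dict(
#           type="BEVFusion",
#           encoders=dict(camera=..., lidar=..., radar=...),
#           fuser=dict(type=...),
#           decoder=dict(backbone=..., neck=...),
#           heads=dict(object=..., map=...),
#       )
#   )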


@FUSIONMODELS.register_module()
class BEVFusion(Base3DFusionModel):
    def __init__(
        self,
        encoders: Dict[str, Any],
        fuser: Dict[str, Any],
        decoder: Dict[str, Any],
        heads: Dict[str, Any],
        **kwargs,
    ) -> None:
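        """Build the per-sensor encoders, the BEV fuser, the decoder, and the heads.

        Args:
            encoders: Per-sensor configs. Supported keys are "camera" (with
                "backbone", "neck", "vtransform"), "lidar", and "radar" (each
                with "voxelize", "backbone", and optionally "voxelize_reduce").
            fuser: Config for the BEV feature fuser, or None for single-sensor setups.
            decoder: Config with "backbone" and "neck" applied to the fused BEV features.
            heads: Configs for the task heads ("object" and/or "map"); entries set to
                None are skipped.
            **kwargs: May contain "loss_scale", a per-head loss weight dict.
        """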
        super().__init__()

        self.encoders = nn.ModuleDict()
        if encoders.get("camera") is not None:
            self.encoders["camera"] = nn.ModuleDict(
                {
                    "backbone": build_backbone(encoders["camera"]["backbone"]),
                    "neck": build_neck(encoders["camera"]["neck"]),
                    "vtransform": build_vtransform(encoders["camera"]["vtransform"]),
                }
            )
        if encoders.get("lidar") is not None:
            if encoders["lidar"]["voxelize"].get("max_num_points", -1) > 0:
                voxelize_module = Voxelization(**encoders["lidar"]["voxelize"])
            else:
                voxelize_module = DynamicScatter(**encoders["lidar"]["voxelize"])
            self.encoders["lidar"] = nn.ModuleDict(
                {
                    "voxelize": voxelize_module,
                    "backbone": build_backbone(encoders["lidar"]["backbone"]),
                }
            )
            self.voxelize_reduce = encoders["lidar"].get("voxelize_reduce", True)

        if encoders.get("radar") is not None:
            if encoders["radar"]["voxelize"].get("max_num_points", -1) > 0:
                voxelize_module = Voxelization(**encoders["radar"]["voxelize"])
            else:
                voxelize_module = DynamicScatter(**encoders["radar"]["voxelize"])
            self.encoders["radar"] = nn.ModuleDict(
                {
                    "voxelize": voxelize_module,
                    "backbone": build_backbone(encoders["radar"]["backbone"]),
                }
            )
            self.voxelize_reduce = encoders["radar"].get("voxelize_reduce", True)

        if fuser is not None:
            self.fuser = build_fuser(fuser)
        else:
            self.fuser = None

        self.decoder = nn.ModuleDict(
            {
                "backbone": build_backbone(decoder["backbone"]),
                "neck": build_neck(decoder["neck"]),
            }
        )
        self.heads = nn.ModuleDict()
        for name in heads:
            if heads[name] is not None:
                self.heads[name] = build_head(heads[name])

        if "loss_scale" in kwargs:
            self.loss_scale = kwargs["loss_scale"]
        else:
            self.loss_scale = dict()
            for name in heads:
                if heads[name] is not None:
                    self.loss_scale[name] = 1.0

        # If the camera vtransform is a BEVDepth variant, enable the auxiliary depth loss.
        self.use_depth_loss = (
            (encoders.get("camera", {}) or {}).get("vtransform", {}) or {}
        ).get("type", "") in ["BEVDepth", "AwareBEVDepth", "DBEVDepth", "AwareDBEVDepth"]

        self.init_weights()

    def init_weights(self) -> None:
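        """Initialize the camera backbone weights if a camera encoder is configured."""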
        if "camera" in self.encoders:
            self.encoders["camera"]["backbone"].init_weights()

    def extract_camera_features(
        self,
        x,
        points,
        radar_points,
        camera2ego,
        lidar2ego,
        lidar2camera,
        lidar2image,
        camera_intrinsics,
        camera2lidar,
        img_aug_matrix,
        lidar_aug_matrix,
        img_metas,
        gt_depths=None,
    ) -> torch.Tensor:
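        """Encode multi-view images and lift them to BEV via the view transform.

        `x` has shape (B, N, C, H, W) with N cameras per sample. The backbone and
        neck run on the flattened (B * N) view dimension, and the vtransform
        projects the per-view features into the BEV plane using the provided
        calibration and augmentation matrices. When the depth loss is enabled,
        the vtransform also returns a depth loss term alongside the BEV feature.
        """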
        B, N, C, H, W = x.size()
        x = x.view(B * N, C, H, W)

        x = self.encoders["camera"]["backbone"](x)
        x = self.encoders["camera"]["neck"](x)

        if not isinstance(x, torch.Tensor):
            x = x[0]

        BN, C, H, W = x.size()
        x = x.view(B, int(BN / B), C, H, W)

        x = self.encoders["camera"]["vtransform"](
            x,
            points,
            radar_points,
            camera2ego,
            lidar2ego,
            lidar2camera,
            lidar2image,
            camera_intrinsics,
            camera2lidar,
            img_aug_matrix,
            lidar_aug_matrix,
            img_metas,
            depth_loss=self.use_depth_loss,
            gt_depths=gt_depths,
        )
        return x

    def extract_features(self, x, sensor) -> torch.Tensor:
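        """Voxelize a point cloud and encode it with the given sensor's backbone."""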
        feats, coords, sizes = self.voxelize(x, sensor)
        batch_size = coords[-1, 0] + 1
        x = self.encoders[sensor]["backbone"](feats, coords, batch_size, sizes=sizes)
        return x

    # def extract_lidar_features(self, x) -> torch.Tensor:
    #     feats, coords, sizes = self.voxelize(x)
    #     batch_size = coords[-1, 0] + 1
    #     x = self.encoders["lidar"]["backbone"](feats, coords, batch_size, sizes=sizes)
    #     return x

    # def extract_radar_features(self, x) -> torch.Tensor:
    #     feats, coords, sizes = self.radar_voxelize(x)
    #     batch_size = coords[-1, 0] + 1
    #     x = self.encoders["radar"]["backbone"](feats, coords, batch_size, sizes=sizes)
    #     return x

    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points, sensor):
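        """Voxelize a batch of point clouds for the given sensor.

        Each sample's voxel coordinates are left-padded with the batch index so
        all samples can be concatenated into flat `feats` / `coords` tensors.
        For hard voxelization, per-voxel point counts (`sizes`) are also returned
        and, when `voxelize_reduce` is set, the per-voxel point features are
        mean-pooled.
        """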
        feats, coords, sizes = [], [], []
        for k, res in enumerate(points):
            ret = self.encoders[sensor]["voxelize"](res)
            if len(ret) == 3:
                # hard voxelize
                f, c, n = ret
            else:
                assert len(ret) == 2
                f, c = ret
                n = None
            feats.append(f)
            coords.append(F.pad(c, (1, 0), mode="constant", value=k))
            if n is not None:
                sizes.append(n)

        feats = torch.cat(feats, dim=0)
        coords = torch.cat(coords, dim=0)
        if len(sizes) > 0:
            sizes = torch.cat(sizes, dim=0)
            if self.voxelize_reduce:
                feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
                    -1, 1
                )
                feats = feats.contiguous()

        return feats, coords, sizes

    # @torch.no_grad()
    # @force_fp32()
    # def radar_voxelize(self, points):
    #     feats, coords, sizes = [], [], []
    #     for k, res in enumerate(points):
    #         ret = self.encoders["radar"]["voxelize"](res)
    #         if len(ret) == 3:
    #             # hard voxelize
    #             f, c, n = ret
    #         else:
    #             assert len(ret) == 2
    #             f, c = ret
    #             n = None
    #         feats.append(f)
    #         coords.append(F.pad(c, (1, 0), mode="constant", value=k))
    #         if n is not None:
    #             sizes.append(n)

    #     feats = torch.cat(feats, dim=0)
    #     coords = torch.cat(coords, dim=0)
    #     if len(sizes) > 0:
    #         sizes = torch.cat(sizes, dim=0)
    #         if self.voxelize_reduce:
    #             feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
    #                 -1, 1
    #             )
    #             feats = feats.contiguous()

    #     return feats, coords, sizes

    @auto_fp16(apply_to=("img", "points"))
    def forward(
        self,
        img,
        points,
        camera2ego,
        lidar2ego,
        lidar2camera,
        lidar2image,
        camera_intrinsics,
        camera2lidar,
        img_aug_matrix,
        lidar_aug_matrix,
        metas,
        depths,
        radar=None,
        gt_masks_bev=None,
        gt_bboxes_3d=None,
        gt_labels_3d=None,
        **kwargs,
    ):
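        """Entry point; delegates to `forward_single`.

        Test-time augmentation (`img` passed as a list) is not implemented.
        """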
        if isinstance(img, list):
            raise NotImplementedError
        else:
            outputs = self.forward_single(
                img,
                points,
                camera2ego,
                lidar2ego,
                lidar2camera,
                lidar2image,
                camera_intrinsics,
                camera2lidar,
                img_aug_matrix,
                lidar_aug_matrix,
                metas,
                depths,
                radar,
                gt_masks_bev,
                gt_bboxes_3d,
                gt_labels_3d,
                **kwargs,
            )
            return outputs

    @auto_fp16(apply_to=("img", "points"))
    def forward_single(
        self,
        img,
        points,
        camera2ego,
        lidar2ego,
        lidar2camera,
        lidar2image,
        camera_intrinsics,
        camera2lidar,
        img_aug_matrix,
        lidar_aug_matrix,
        metas,
        depths=None,
        radar=None,
        gt_masks_bev=None,
        gt_bboxes_3d=None,
        gt_labels_3d=None,
        **kwargs,
    ):
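        """Run the full pipeline on one batch.

        Per-sensor features are extracted (camera last at inference time to
        reduce peak memory), fused into a single BEV feature map, decoded, and
        passed to the task heads. During training this returns a dict of
        weighted losses keyed "loss/<head>/<name>" (plus "loss/depth" when the
        depth loss is enabled); at inference it returns a list with one dict of
        predictions per sample.
        """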
        features = []
        auxiliary_losses = {}
        for sensor in (
            self.encoders if self.training else list(self.encoders.keys())[::-1]
        ):
            if sensor == "camera":
                feature = self.extract_camera_features(
                    img,
                    points,
                    radar,
                    camera2ego,
                    lidar2ego,
                    lidar2camera,
                    lidar2image,
                    camera_intrinsics,
                    camera2lidar,
                    img_aug_matrix,
                    lidar_aug_matrix,
                    metas,
                    gt_depths=depths,
                )
                if self.use_depth_loss:
                    feature, auxiliary_losses["depth"] = feature[0], feature[-1]
            elif sensor == "lidar":
                feature = self.extract_features(points, sensor)
            elif sensor == "radar":
                feature = self.extract_features(radar, sensor)
            else:
                raise ValueError(f"unsupported sensor: {sensor}")

            features.append(feature)

        if not self.training:
            # avoid OOM
            features = features[::-1]

        if self.fuser is not None:
            x = self.fuser(features)
        else:
            assert len(features) == 1, features
            x = features[0]

        batch_size = x.shape[0]

        x = self.decoder["backbone"](x)
        x = self.decoder["neck"](x)

        if self.training:
            outputs = {}
            for type, head in self.heads.items():
                if type == "object":
                    pred_dict = head(x, metas)
                    losses = head.loss(gt_bboxes_3d, gt_labels_3d, pred_dict)
                elif type == "map":
                    losses = head(x, gt_masks_bev)
                else:
                    raise ValueError(f"unsupported head: {type}")
                for name, val in losses.items():
                    if val.requires_grad:
                        outputs[f"loss/{type}/{name}"] = val * self.loss_scale[type]
                    else:
                        outputs[f"stats/{type}/{name}"] = val
            if self.use_depth_loss:
                if "depth" in auxiliary_losses:
                    outputs["loss/depth"] = auxiliary_losses["depth"]
                else:
                    raise ValueError(
                        "use_depth_loss is enabled, but no depth loss was returned"
                    )
            return outputs
        else:
            outputs = [{} for _ in range(batch_size)]
            for type, head in self.heads.items():
                if type == "object":
                    pred_dict = head(x, metas)
                    bboxes = head.get_bboxes(pred_dict, metas)
                    for k, (boxes, scores, labels) in enumerate(bboxes):
                        outputs[k].update(
                            {
                                "boxes_3d": boxes.to("cpu"),
                                "scores_3d": scores.cpu(),
                                "labels_3d": labels.cpu(),
                            }
                        )
                elif type == "map":
                    logits = head(x)
                    for k in range(batch_size):
                        outputs[k].update(
                            {
                                "masks_bev": logits[k].cpu(),
                                "gt_masks_bev": gt_masks_bev[k].cpu(),
                            }
                        )
                else:
                    raise ValueError(f"unsupported head: {type}")
            return outputs