from typing import Any, Dict

import torch
from mmcv.runner import auto_fp16, force_fp32
from torch import nn
from torch.nn import functional as F

from mmdet3d.models.builder import (
    build_backbone,
    build_fuser,
    build_head,
    build_neck,
    build_vtransform,
)
from mmdet3d.ops import Voxelization, DynamicScatter
from mmdet3d.models import FUSIONMODELS

from .base import Base3DFusionModel

__all__ = ["BEVFusion"]


@FUSIONMODELS.register_module()
class BEVFusion(Base3DFusionModel):
    """Camera/LiDAR/radar fusion model: each enabled sensor is encoded into a
    BEV feature map, the maps are fused, and task-specific heads decode the
    fused representation."""

    def __init__(
        self,
        encoders: Dict[str, Any],
        fuser: Dict[str, Any],
        decoder: Dict[str, Any],
        heads: Dict[str, Any],
        **kwargs,
    ) -> None:
        super().__init__()

        self.encoders = nn.ModuleDict()
        if encoders.get("camera") is not None:
            self.encoders["camera"] = nn.ModuleDict(
                {
                    "backbone": build_backbone(encoders["camera"]["backbone"]),
                    "neck": build_neck(encoders["camera"]["neck"]),
                    "vtransform": build_vtransform(encoders["camera"]["vtransform"]),
                }
            )
        if encoders.get("lidar") is not None:
            if encoders["lidar"]["voxelize"].get("max_num_points", -1) > 0:
                voxelize_module = Voxelization(**encoders["lidar"]["voxelize"])
            else:
                voxelize_module = DynamicScatter(**encoders["lidar"]["voxelize"])
            self.encoders["lidar"] = nn.ModuleDict(
                {
                    "voxelize": voxelize_module,
                    "backbone": build_backbone(encoders["lidar"]["backbone"]),
                }
            )
            self.voxelize_reduce = encoders["lidar"].get("voxelize_reduce", True)
        if encoders.get("radar") is not None:
            if encoders["radar"]["voxelize"].get("max_num_points", -1) > 0:
                voxelize_module = Voxelization(**encoders["radar"]["voxelize"])
            else:
                voxelize_module = DynamicScatter(**encoders["radar"]["voxelize"])
            self.encoders["radar"] = nn.ModuleDict(
                {
                    "voxelize": voxelize_module,
                    "backbone": build_backbone(encoders["radar"]["backbone"]),
                }
            )
            self.voxelize_reduce = encoders["radar"].get("voxelize_reduce", True)

        if fuser is not None:
            self.fuser = build_fuser(fuser)
        else:
            self.fuser = None

        self.decoder = nn.ModuleDict(
            {
                "backbone": build_backbone(decoder["backbone"]),
                "neck": build_neck(decoder["neck"]),
            }
        )
        self.heads = nn.ModuleDict()
        for name in heads:
            if heads[name] is not None:
                self.heads[name] = build_head(heads[name])

        if "loss_scale" in kwargs:
            self.loss_scale = kwargs["loss_scale"]
        else:
            self.loss_scale = dict()
            for name in heads:
                if heads[name] is not None:
                    self.loss_scale[name] = 1.0

        # If the camera's vtransform is a BEVDepth variant, an auxiliary depth loss is used.
        self.use_depth_loss = (
            (encoders.get("camera", {}) or {}).get("vtransform", {}) or {}
        ).get("type", "") in ["BEVDepth", "AwareBEVDepth", "DBEVDepth", "AwareDBEVDepth"]

        self.init_weights()

    def init_weights(self) -> None:
        if "camera" in self.encoders:
            self.encoders["camera"]["backbone"].init_weights()

    def extract_camera_features(
        self,
        x,
        points,
        radar_points,
        camera2ego,
        lidar2ego,
        lidar2camera,
        lidar2image,
        camera_intrinsics,
        camera2lidar,
        img_aug_matrix,
        lidar_aug_matrix,
        img_metas,
        gt_depths=None,
    ) -> torch.Tensor:
        """Encode multi-view images and lift them to BEV via the view transform."""
        B, N, C, H, W = x.size()
        x = x.view(B * N, C, H, W)

        x = self.encoders["camera"]["backbone"](x)
        x = self.encoders["camera"]["neck"](x)

        if not isinstance(x, torch.Tensor):
            x = x[0]

        BN, C, H, W = x.size()
        x = x.view(B, int(BN / B), C, H, W)

        x = self.encoders["camera"]["vtransform"](
            x,
            points,
            radar_points,
            camera2ego,
            lidar2ego,
            lidar2camera,
            lidar2image,
            camera_intrinsics,
            camera2lidar,
            img_aug_matrix,
            lidar_aug_matrix,
            img_metas,
            depth_loss=self.use_depth_loss,
            gt_depths=gt_depths,
        )
        return x

    def extract_features(self, x, sensor) -> torch.Tensor:
        """Voxelize the point cloud of ``sensor`` and encode it with that sensor's backbone."""
        feats, coords, sizes = self.voxelize(x, sensor)
        batch_size = coords[-1, 0] + 1
        x = self.encoders[sensor]["backbone"](feats, coords, batch_size, sizes=sizes)
        return x

    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points, sensor):
        """Voxelize a batch of point clouds, prepending the batch index to the voxel
        coordinates; with hard voxelization, optionally mean-reduce the points in
        each voxel."""
        feats, coords, sizes = [], [], []
        for k, res in enumerate(points):
            ret = self.encoders[sensor]["voxelize"](res)
            if len(ret) == 3:
                # hard voxelize
                f, c, n = ret
            else:
                assert len(ret) == 2
                f, c = ret
                n = None
            feats.append(f)
            coords.append(F.pad(c, (1, 0), mode="constant", value=k))
            if n is not None:
                sizes.append(n)

        feats = torch.cat(feats, dim=0)
        coords = torch.cat(coords, dim=0)
        if len(sizes) > 0:
            sizes = torch.cat(sizes, dim=0)
            if self.voxelize_reduce:
                feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
                    -1, 1
                )
                feats = feats.contiguous()

        return feats, coords, sizes

    @auto_fp16(apply_to=("img", "points"))
    def forward(
        self,
        img,
        points,
        camera2ego,
        lidar2ego,
        lidar2camera,
        lidar2image,
        camera_intrinsics,
        camera2lidar,
        img_aug_matrix,
        lidar_aug_matrix,
        metas,
        depths,
        radar=None,
        gt_masks_bev=None,
        gt_bboxes_3d=None,
        gt_labels_3d=None,
        **kwargs,
    ):
        if isinstance(img, list):
            raise NotImplementedError
        else:
            outputs = self.forward_single(
                img,
                points,
                camera2ego,
                lidar2ego,
                lidar2camera,
                lidar2image,
                camera_intrinsics,
                camera2lidar,
                img_aug_matrix,
                lidar_aug_matrix,
                metas,
                depths,
                radar,
                gt_masks_bev,
                gt_bboxes_3d,
                gt_labels_3d,
                **kwargs,
            )
            return outputs

    @auto_fp16(apply_to=("img", "points"))
    def forward_single(
        self,
        img,
        points,
        camera2ego,
        lidar2ego,
        lidar2camera,
        lidar2image,
        camera_intrinsics,
        camera2lidar,
        img_aug_matrix,
        lidar_aug_matrix,
        metas,
        depths=None,
        radar=None,
        gt_masks_bev=None,
        gt_bboxes_3d=None,
        gt_labels_3d=None,
        **kwargs,
    ):
        """Extract per-sensor BEV features, fuse and decode them, then return the
        training losses or per-sample predictions."""
        features = []
        auxiliary_losses = {}
        for sensor in (
            self.encoders if self.training else list(self.encoders.keys())[::-1]
        ):
            if sensor == "camera":
                feature = self.extract_camera_features(
                    img,
                    points,
                    radar,
                    camera2ego,
                    lidar2ego,
                    lidar2camera,
                    lidar2image,
                    camera_intrinsics,
                    camera2lidar,
                    img_aug_matrix,
                    lidar_aug_matrix,
                    metas,
                    gt_depths=depths,
                )
                if self.use_depth_loss:
                    feature, auxiliary_losses["depth"] = feature[0], feature[-1]
            elif sensor == "lidar":
                feature = self.extract_features(points, sensor)
            elif sensor == "radar":
                feature = self.extract_features(radar, sensor)
            else:
                raise ValueError(f"unsupported sensor: {sensor}")

            features.append(feature)

        if not self.training:
            # avoid OOM
            features = features[::-1]

        if self.fuser is not None:
            x = self.fuser(features)
        else:
            assert len(features) == 1, features
            x = features[0]

        batch_size = x.shape[0]

        x = self.decoder["backbone"](x)
        x = self.decoder["neck"](x)

        if self.training:
            outputs = {}
            for type, head in self.heads.items():
                if type == "object":
                    pred_dict = head(x, metas)
                    losses = head.loss(gt_bboxes_3d, gt_labels_3d, pred_dict)
                elif type == "map":
                    losses = head(x, gt_masks_bev)
                else:
                    raise ValueError(f"unsupported head: {type}")
                for name, val in losses.items():
                    if val.requires_grad:
                        outputs[f"loss/{type}/{name}"] = val * self.loss_scale[type]
                    else:
                        outputs[f"stats/{type}/{name}"] = val
            if self.use_depth_loss:
                if "depth" in auxiliary_losses:
                    outputs["loss/depth"] = auxiliary_losses["depth"]
                else:
                    raise ValueError(
                        "use_depth_loss is True, but the vtransform returned no depth loss"
                    )
            return outputs
        else:
            outputs = [{} for _ in range(batch_size)]
            for type, head in self.heads.items():
                if type == "object":
                    pred_dict = head(x, metas)
                    bboxes = head.get_bboxes(pred_dict, metas)
                    for k, (boxes, scores, labels) in enumerate(bboxes):
                        outputs[k].update(
                            {
                                "boxes_3d": boxes.to("cpu"),
                                "scores_3d": scores.cpu(),
                                "labels_3d": labels.cpu(),
                            }
                        )
                elif type == "map":
                    logits = head(x)
                    for k in range(batch_size):
                        outputs[k].update(
                            {
                                "masks_bev": logits[k].cpu(),
                                "gt_masks_bev": gt_masks_bev[k].cpu(),
                            }
                        )
                else:
                    raise ValueError(f"unsupported head: {type}")
            return outputs
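

# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not a shipped config): the ``encoders``,
# ``fuser``, ``decoder``, and ``heads`` arguments of ``BEVFusion`` are plain
# config dicts resolved through the mmdet3d build functions imported above.
# Every ``type`` value below is a placeholder and must name a module that is
# actually registered in the target codebase.
#
#   model = dict(
#       type="BEVFusion",
#       encoders=dict(
#           camera=dict(
#               backbone=dict(type=...),    # image backbone config
#               neck=dict(type=...),        # image neck config
#               vtransform=dict(type=...),  # a *BEVDepth* type enables the auxiliary depth loss
#           ),
#           lidar=dict(
#               # max_num_points > 0 selects hard Voxelization, otherwise DynamicScatter
#               voxelize=dict(max_num_points=10),
#               backbone=dict(type=...),
#           ),
#           radar=None,  # sensors set to None are simply skipped
#       ),
#       fuser=dict(type=...),
#       decoder=dict(backbone=dict(type=...), neck=dict(type=...)),
#       heads=dict(object=dict(type=...), map=None),
#   )
# ---------------------------------------------------------------------------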