import copy import torch from mmcv.cnn import ConvModule, build_conv_layer from mmcv.runner import BaseModule, force_fp32 from torch import nn from mmdet3d.core import circle_nms, draw_heatmap_gaussian, gaussian_radius, xywhr2xyxyr from mmdet3d.models import builder from mmdet3d.models.builder import HEADS, build_loss from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu from mmdet.core import build_bbox_coder, multi_apply def clip_sigmoid(x: torch.Tensor, eps: float = 1e-4) -> torch.Tensor: return torch.clamp(x.sigmoid_(), min=eps, max=1 - eps) @HEADS.register_module() class SeparateHead(BaseModule): """SeparateHead for CenterHead. Args: in_channels (int): Input channels for conv_layer. heads (dict): Conv information. head_conv (int): Output channels. Default: 64. final_kernal (int): Kernal size for the last conv layer. Deafult: 1. init_bias (float): Initial bias. Default: -2.19. conv_cfg (dict): Config of conv layer. Default: dict(type='Conv2d') norm_cfg (dict): Config of norm layer. Default: dict(type='BN2d'). bias (str): Type of bias. Default: 'auto'. """ def __init__( self, in_channels, heads, head_conv=64, final_kernel=1, init_bias=-2.19, conv_cfg=dict(type="Conv2d"), norm_cfg=dict(type="BN2d"), bias="auto", init_cfg=None, **kwargs, ): assert ( init_cfg is None ), "To prevent abnormal initialization behavior, init_cfg is not allowed to be set" super(SeparateHead, self).__init__(init_cfg=init_cfg) self.heads = heads self.init_bias = init_bias for head in self.heads: classes, num_conv = self.heads[head] conv_layers = [] c_in = in_channels for i in range(num_conv - 1): conv_layers.append( ConvModule( c_in, head_conv, kernel_size=final_kernel, stride=1, padding=final_kernel // 2, bias=bias, conv_cfg=conv_cfg, norm_cfg=norm_cfg, ) ) c_in = head_conv conv_layers.append( build_conv_layer( conv_cfg, head_conv, classes, kernel_size=final_kernel, stride=1, padding=final_kernel // 2, bias=True, ) ) conv_layers = nn.Sequential(*conv_layers) self.__setattr__(head, conv_layers) if init_cfg is None: self.init_cfg = dict(type="Kaiming", layer="Conv2d") def init_weights(self): """Initialize weights.""" super().init_weights() for head in self.heads: if head == "heatmap": self.__getattr__(head)[-1].bias.data.fill_(self.init_bias) def forward(self, x): """Forward function for SepHead. Args: x (torch.Tensor): Input feature map with the shape of [B, 512, 128, 128]. Returns: dict[str: torch.Tensor]: contains the following keys: -reg (torch.Tensor): 2D regression value with the \ shape of [B, 2, H, W]. -height (torch.Tensor): Height value with the \ shape of [B, 1, H, W]. -dim (torch.Tensor): Size value with the shape \ of [B, 3, H, W]. -rot (torch.Tensor): Rotation value with the \ shape of [B, 2, H, W]. -vel (torch.Tensor): Velocity value with the \ shape of [B, 2, H, W]. -heatmap (torch.Tensor): Heatmap with the shape of \ [B, N, H, W]. """ ret_dict = dict() for head in self.heads: ret_dict[head] = self.__getattr__(head)(x) return ret_dict @HEADS.register_module() class DCNSeparateHead(BaseModule): r"""DCNSeparateHead for CenterHead. .. code-block:: none /-----> DCN for heatmap task -----> heatmap task. feature \-----> DCN for regression tasks -----> regression tasks Args: in_channels (int): Input channels for conv_layer. heads (dict): Conv information. dcn_config (dict): Config of dcn layer. num_cls (int): Output channels. Default: 64. final_kernal (int): Kernal size for the last conv layer. Deafult: 1. init_bias (float): Initial bias. Default: -2.19. conv_cfg (dict): Config of conv layer. Default: dict(type='Conv2d') norm_cfg (dict): Config of norm layer. Default: dict(type='BN2d'). bias (str): Type of bias. Default: 'auto'. """ # noqa: W605 def __init__( self, in_channels, num_cls, heads, dcn_config, head_conv=64, final_kernel=1, init_bias=-2.19, conv_cfg=dict(type="Conv2d"), norm_cfg=dict(type="BN2d"), bias="auto", init_cfg=None, **kwargs, ): assert init_cfg is None, ( "To prevent abnormal initialization " "behavior, init_cfg is not allowed to be set" ) super(DCNSeparateHead, self).__init__(init_cfg=init_cfg) if "heatmap" in heads: heads.pop("heatmap") # feature adaptation with dcn # use separate features for classification / regression self.feature_adapt_cls = build_conv_layer(dcn_config) self.feature_adapt_reg = build_conv_layer(dcn_config) # heatmap prediction head cls_head = [ ConvModule( in_channels, head_conv, kernel_size=3, padding=1, conv_cfg=conv_cfg, bias=bias, norm_cfg=norm_cfg, ), build_conv_layer( conv_cfg, head_conv, num_cls, kernel_size=3, stride=1, padding=1, bias=bias, ), ] self.cls_head = nn.Sequential(*cls_head) self.init_bias = init_bias # other regression target self.task_head = SeparateHead( in_channels, heads, head_conv=head_conv, final_kernel=final_kernel, bias=bias, ) if init_cfg is None: self.init_cfg = dict(type="Kaiming", layer="Conv2d") def init_weights(self): """Initialize weights.""" super().init_weights() self.cls_head[-1].bias.data.fill_(self.init_bias) def forward(self, x): """Forward function for DCNSepHead. Args: x (torch.Tensor): Input feature map with the shape of [B, 512, 128, 128]. Returns: dict[str: torch.Tensor]: contains the following keys: -reg (torch.Tensor): 2D regression value with the \ shape of [B, 2, H, W]. -height (torch.Tensor): Height value with the \ shape of [B, 1, H, W]. -dim (torch.Tensor): Size value with the shape \ of [B, 3, H, W]. -rot (torch.Tensor): Rotation value with the \ shape of [B, 2, H, W]. -vel (torch.Tensor): Velocity value with the \ shape of [B, 2, H, W]. -heatmap (torch.Tensor): Heatmap with the shape of \ [B, N, H, W]. """ center_feat = self.feature_adapt_cls(x) reg_feat = self.feature_adapt_reg(x) cls_score = self.cls_head(center_feat) ret = self.task_head(reg_feat) ret["heatmap"] = cls_score return ret @HEADS.register_module() class CenterHead(BaseModule): """CenterHead for CenterPoint. Args: mode (str): Mode of the head. Default: '3d'. in_channels (list[int] | int): Channels of the input feature map. Default: [128]. tasks (list[dict]): Task information including class number and class names. Default: None. dataset (str): Name of the dataset. Default: 'nuscenes'. weight (float): Weight for location loss. Default: 0.25. code_weights (list[int]): Code weights for location loss. Default: []. common_heads (dict): Conv information for common heads. Default: dict(). loss_cls (dict): Config of classification loss function. Default: dict(type='GaussianFocalLoss', reduction='mean'). loss_bbox (dict): Config of regression loss function. Default: dict(type='L1Loss', reduction='none'). separate_head (dict): Config of separate head. Default: dict( type='SeparateHead', init_bias=-2.19, final_kernel=3) share_conv_channel (int): Output channels for share_conv_layer. Default: 64. num_heatmap_convs (int): Number of conv layers for heatmap conv layer. Default: 2. conv_cfg (dict): Config of conv layer. Default: dict(type='Conv2d') norm_cfg (dict): Config of norm layer. Default: dict(type='BN2d'). bias (str): Type of bias. Default: 'auto'. """ def __init__( self, in_channels=[128], tasks=None, train_cfg=None, test_cfg=None, bbox_coder=None, common_heads=dict(), loss_cls=dict(type="GaussianFocalLoss", reduction="mean"), loss_bbox=dict(type="L1Loss", reduction="none", loss_weight=0.25), separate_head=dict(type="SeparateHead", init_bias=-2.19, final_kernel=3), share_conv_channel=64, num_heatmap_convs=2, conv_cfg=dict(type="Conv2d"), norm_cfg=dict(type="BN2d"), bias="auto", norm_bbox=True, init_cfg=None, ): assert init_cfg is None, ( "To prevent abnormal initialization " "behavior, init_cfg is not allowed to be set" ) super(CenterHead, self).__init__(init_cfg=init_cfg) num_classes = [len(t) for t in tasks] self.class_names = [t for t in tasks] self.train_cfg = train_cfg self.test_cfg = test_cfg self.in_channels = in_channels self.num_classes = num_classes self.norm_bbox = norm_bbox self.loss_cls = build_loss(loss_cls) self.loss_bbox = build_loss(loss_bbox) self.bbox_coder = build_bbox_coder(bbox_coder) self.num_anchor_per_locs = [n for n in num_classes] self.fp16_enabled = False # a shared convolution self.shared_conv = ConvModule( in_channels, share_conv_channel, kernel_size=3, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, bias=bias, ) self.task_heads = nn.ModuleList() for num_cls in num_classes: heads = copy.deepcopy(common_heads) heads.update(dict(heatmap=(num_cls, num_heatmap_convs))) separate_head.update( in_channels=share_conv_channel, heads=heads, num_cls=num_cls ) self.task_heads.append(builder.build_head(separate_head)) def forward_single(self, x): """Forward function for CenterPoint. Args: x (torch.Tensor): Input feature map with the shape of [B, 512, 128, 128]. Returns: list[dict]: Output results for tasks. """ ret_dicts = [] x = self.shared_conv(x) for task in self.task_heads: ret_dicts.append(task(x)) return ret_dicts def forward(self, feats, metas): """Forward pass. Args: feats (list[torch.Tensor]): Multi-level features, e.g., features produced by FPN. Returns: tuple(list[dict]): Output results for tasks. """ if isinstance(feats, torch.Tensor): feats = [feats] return multi_apply(self.forward_single, feats) def _gather_feat(self, feat, ind, mask=None): """Gather feature map. Given feature map and index, return indexed feature map. Args: feat (torch.tensor): Feature map with the shape of [B, H*W, 10]. ind (torch.Tensor): Index of the ground truth boxes with the shape of [B, max_obj]. mask (torch.Tensor): Mask of the feature map with the shape of [B, max_obj]. Default: None. Returns: torch.Tensor: Feature map after gathering with the shape of [B, max_obj, 10]. """ dim = feat.size(2) ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) feat = feat.gather(1, ind) if mask is not None: mask = mask.unsqueeze(2).expand_as(feat) feat = feat[mask] feat = feat.view(-1, dim) return feat def get_targets(self, gt_bboxes_3d, gt_labels_3d): """Generate targets. How each output is transformed: Each nested list is transposed so that all same-index elements in each sub-list (1, ..., N) become the new sub-lists. [ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ] ==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ] The new transposed nested list is converted into a list of N tensors generated by concatenating tensors in the new sub-lists. [ tensor0, tensor1, tensor2, ... ] Args: gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground truth gt boxes. gt_labels_3d (list[torch.Tensor]): Labels of boxes. Returns: Returns: tuple[list[torch.Tensor]]: Tuple of target including \ the following results in order. - list[torch.Tensor]: Heatmap scores. - list[torch.Tensor]: Ground truth boxes. - list[torch.Tensor]: Indexes indicating the \ position of the valid boxes. - list[torch.Tensor]: Masks indicating which \ boxes are valid. """ heatmaps, anno_boxes, inds, masks = multi_apply( self.get_targets_single, gt_bboxes_3d, gt_labels_3d ) # Transpose heatmaps heatmaps = list(map(list, zip(*heatmaps))) heatmaps = [torch.stack(hms_) for hms_ in heatmaps] # Transpose anno_boxes anno_boxes = list(map(list, zip(*anno_boxes))) anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes] # Transpose inds inds = list(map(list, zip(*inds))) inds = [torch.stack(inds_) for inds_ in inds] # Transpose inds masks = list(map(list, zip(*masks))) masks = [torch.stack(masks_) for masks_ in masks] return heatmaps, anno_boxes, inds, masks def get_targets_single(self, gt_bboxes_3d, gt_labels_3d): """Generate training targets for a single sample. Args: gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth gt boxes. gt_labels_3d (torch.Tensor): Labels of boxes. Returns: tuple[list[torch.Tensor]]: Tuple of target including \ the following results in order. - list[torch.Tensor]: Heatmap scores. - list[torch.Tensor]: Ground truth boxes. - list[torch.Tensor]: Indexes indicating the position \ of the valid boxes. - list[torch.Tensor]: Masks indicating which boxes \ are valid. """ device = gt_labels_3d.device gt_bboxes_3d = torch.cat( (gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]), dim=1 ).to(device) max_objs = self.train_cfg["max_objs"] * self.train_cfg["dense_reg"] grid_size = torch.tensor(self.train_cfg["grid_size"]) pc_range = torch.tensor(self.train_cfg["point_cloud_range"]) voxel_size = torch.tensor(self.train_cfg["voxel_size"]) feature_map_size = torch.div( grid_size[:2], self.train_cfg["out_size_factor"], rounding_mode="trunc", ) # reorganize the gt_dict by tasks task_masks = [] flag = 0 for class_name in self.class_names: task_masks.append( [ torch.where(gt_labels_3d == class_name.index(i) + flag) for i in class_name ] ) flag += len(class_name) task_boxes = [] task_classes = [] flag2 = 0 for idx, mask in enumerate(task_masks): task_box = [] task_class = [] for m in mask: task_box.append(gt_bboxes_3d[m]) # 0 is background for each task, so we need to add 1 here. task_class.append(gt_labels_3d[m] + 1 - flag2) task_boxes.append(torch.cat(task_box, axis=0).to(device)) task_classes.append(torch.cat(task_class).long().to(device)) flag2 += len(mask) draw_gaussian = draw_heatmap_gaussian heatmaps, anno_boxes, inds, masks = [], [], [], [] for idx, task_head in enumerate(self.task_heads): heatmap = gt_bboxes_3d.new_zeros( (len(self.class_names[idx]), feature_map_size[1], feature_map_size[0]) ) anno_box = gt_bboxes_3d.new_zeros((max_objs, 10), dtype=torch.float32) ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64) mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8) num_objs = min(task_boxes[idx].shape[0], max_objs) for k in range(num_objs): cls_id = task_classes[idx][k] - 1 width = task_boxes[idx][k][3] length = task_boxes[idx][k][4] width = width / voxel_size[0] / self.train_cfg["out_size_factor"] length = length / voxel_size[1] / self.train_cfg["out_size_factor"] if width > 0 and length > 0: radius = gaussian_radius( (length, width), min_overlap=self.train_cfg["gaussian_overlap"] ) radius = max(self.train_cfg["min_radius"], int(radius)) # be really careful for the coordinate system of # your box annotation. x, y, z = ( task_boxes[idx][k][0], task_boxes[idx][k][1], task_boxes[idx][k][2], ) coor_x = ( (x - pc_range[0]) / voxel_size[0] / self.train_cfg["out_size_factor"] ) coor_y = ( (y - pc_range[1]) / voxel_size[1] / self.train_cfg["out_size_factor"] ) center = torch.tensor( [coor_x, coor_y], dtype=torch.float32, device=device ) center_int = center.to(torch.int32) # throw out not in range objects to avoid out of array # area when creating the heatmap if not ( 0 <= center_int[0] < feature_map_size[0] and 0 <= center_int[1] < feature_map_size[1] ): continue draw_gaussian(heatmap[cls_id], center_int[[1, 0]], radius) new_idx = k x, y = center_int[0], center_int[1] assert ( x * feature_map_size[1] + y < feature_map_size[0] * feature_map_size[1] ) ind[new_idx] = x * feature_map_size[1] + y mask[new_idx] = 1 # TODO: support other outdoor dataset vx, vy = task_boxes[idx][k][7:] rot = task_boxes[idx][k][6] box_dim = task_boxes[idx][k][3:6] if self.norm_bbox: box_dim = box_dim.log() anno_box[new_idx] = torch.cat( [ center - torch.tensor([x, y], device=device), z.unsqueeze(0), box_dim, torch.sin(rot).unsqueeze(0), torch.cos(rot).unsqueeze(0), vx.unsqueeze(0), vy.unsqueeze(0), ] ) heatmaps.append(heatmap) anno_boxes.append(anno_box) masks.append(mask) inds.append(ind) return heatmaps, anno_boxes, inds, masks @force_fp32(apply_to=("preds_dicts")) def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs): """Loss function for CenterHead. Args: gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground truth gt boxes. gt_labels_3d (list[torch.Tensor]): Labels of boxes. preds_dicts (dict): Output of forward function. Returns: dict[str:torch.Tensor]: Loss of heatmap and bbox of each task. """ heatmaps, anno_boxes, inds, masks = self.get_targets(gt_bboxes_3d, gt_labels_3d) loss_dict = dict() for task_id, preds_dict in enumerate(preds_dicts): # heatmap focal loss preds_dict[0]["heatmap"] = clip_sigmoid(preds_dict[0]["heatmap"]) num_pos = heatmaps[task_id].eq(1).float().sum().item() loss_heatmap = self.loss_cls( preds_dict[0]["heatmap"], heatmaps[task_id], avg_factor=max(num_pos, 1) ) target_box = anno_boxes[task_id] # reconstruct the anno_box from multiple reg heads preds_dict[0]["anno_box"] = torch.cat( ( preds_dict[0]["reg"], preds_dict[0]["height"], preds_dict[0]["dim"], preds_dict[0]["rot"], preds_dict[0]["vel"], ), dim=1, ) # Regression loss for dimension, offset, height, rotation ind = inds[task_id] num = masks[task_id].float().sum() pred = preds_dict[0]["anno_box"].permute(0, 2, 3, 1).contiguous() pred = pred.view(pred.size(0), -1, pred.size(3)) pred = self._gather_feat(pred, ind) mask = masks[task_id].unsqueeze(2).expand_as(target_box).float() isnotnan = (~torch.isnan(target_box)).float() mask *= isnotnan code_weights = self.train_cfg.get("code_weights", None) bbox_weights = mask * mask.new_tensor(code_weights) loss_bbox = self.loss_bbox( pred, target_box, bbox_weights, avg_factor=(num + 1e-4) ) loss_dict[f"heatmap/task{task_id}"] = loss_heatmap loss_dict[f"bbox/task{task_id}"] = loss_bbox return loss_dict @force_fp32(apply_to=("preds_dicts")) def get_bboxes(self, preds_dicts, metas, img=None, rescale=False): """Generate bboxes from bbox head predictions. Args: preds_dicts (tuple[list[dict]]): Prediction results. metas (list[dict]): Point cloud and image's meta info. Returns: list[dict]: Decoded bbox, scores and labels after nms. """ if not isinstance(self.test_cfg["nms_type"], list): nms_types = [self.test_cfg["nms_type"] for _ in range(len(preds_dicts))] else: nms_types = self.test_cfg["nms_type"] if "nms_scale" in self.test_cfg: if not isinstance(self.test_cfg["nms_scale"], list): nms_scales = [ [ self.test_cfg["nms_scale"] for _ in range(self.num_classes[task_id]) ] for task_id in range(len(preds_dicts)) ] else: nms_scales = self.test_cfg["nms_scale"] else: nms_scales = [ [1.0 for _ in range(self.num_classes[task_id])] for task_id in range(len(preds_dicts)) ] rets = [] for task_id, preds_dict in enumerate(preds_dicts): num_class_with_bg = self.num_classes[task_id] batch_size = preds_dict[0]["heatmap"].shape[0] batch_heatmap = preds_dict[0]["heatmap"].sigmoid() batch_reg = preds_dict[0]["reg"] batch_hei = preds_dict[0]["height"] if self.norm_bbox: batch_dim = torch.exp(preds_dict[0]["dim"]) else: batch_dim = preds_dict[0]["dim"] batch_rots = preds_dict[0]["rot"][:, 0].unsqueeze(1) batch_rotc = preds_dict[0]["rot"][:, 1].unsqueeze(1) if "vel" in preds_dict[0]: batch_vel = preds_dict[0]["vel"] else: batch_vel = None temp = self.bbox_coder.decode( batch_heatmap, batch_rots, batch_rotc, batch_hei, batch_dim, batch_vel, reg=batch_reg, task_id=task_id, ) batch_reg_preds = [box["bboxes"] for box in temp] batch_cls_preds = [box["scores"] for box in temp] batch_cls_labels = [box["labels"] for box in temp] if nms_types[task_id] == "circle": ret_task = [] for i in range(batch_size): boxes3d = temp[i]["bboxes"] scores = temp[i]["scores"] labels = temp[i]["labels"] centers = boxes3d[:, [0, 1]] boxes = torch.cat([centers, scores.view(-1, 1)], dim=1) keep = torch.tensor( circle_nms( boxes.detach().cpu().numpy(), self.test_cfg["min_radius"][task_id], post_max_size=self.test_cfg["post_max_size"], ), dtype=torch.long, device=boxes.device, ) boxes3d = boxes3d[keep] scores = scores[keep] labels = labels[keep] ret = dict(bboxes=boxes3d, scores=scores, labels=labels) ret_task.append(ret) rets.append(ret_task) else: rets.append( self.get_task_detections( num_class_with_bg, batch_cls_preds, batch_reg_preds, batch_cls_labels, metas, nms_scales[task_id], ) ) # Merge branches results num_samples = len(rets[0]) ret_list = [] for i in range(num_samples): for k in rets[0][i].keys(): if k == "bboxes": bboxes = torch.cat([ret[i][k] for ret in rets]) bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5 bboxes = metas[i]["box_type_3d"](bboxes, self.bbox_coder.code_size) elif k == "scores": scores = torch.cat([ret[i][k] for ret in rets]) elif k == "labels": flag = 0 for j, num_class in enumerate(self.num_classes): rets[j][i][k] += flag flag += num_class labels = torch.cat([ret[i][k].int() for ret in rets]) ret_list.append([bboxes, scores, labels]) return ret_list def get_task_detections( self, num_class_with_bg, batch_cls_preds, batch_reg_preds, batch_cls_labels, metas, nms_scale=1.0, ): """Rotate nms for each task. Args: num_class_with_bg (int): Number of classes for the current task. batch_cls_preds (list[torch.Tensor]): Prediction score with the shape of [N]. batch_reg_preds (list[torch.Tensor]): Prediction bbox with the shape of [N, 9]. batch_cls_labels (list[torch.Tensor]): Prediction label with the shape of [N]. metas (list[dict]): Meta information of each sample. Returns: list[dict[str: torch.Tensor]]: contains the following keys: -bboxes (torch.Tensor): Prediction bboxes after nms with the \ shape of [N, 9]. -scores (torch.Tensor): Prediction scores after nms with the \ shape of [N]. -labels (torch.Tensor): Prediction labels after nms with the \ shape of [N]. """ predictions_dicts = [] post_center_range = self.test_cfg["post_center_limit_range"] if len(post_center_range) > 0: post_center_range = torch.tensor( post_center_range, dtype=batch_reg_preds[0].dtype, device=batch_reg_preds[0].device, ) for i, (box_preds, cls_preds, cls_labels) in enumerate( zip(batch_reg_preds, batch_cls_preds, batch_cls_labels) ): # Apply NMS in birdeye view # get highest score per prediction, than apply nms # to remove overlapped box. if num_class_with_bg == 1: top_scores = cls_preds.squeeze(-1) top_labels = torch.zeros( cls_preds.shape[0], device=cls_preds.device, dtype=torch.long ) else: top_labels = cls_labels.long() top_scores = cls_preds.squeeze(-1) if self.test_cfg["score_threshold"] > 0.0: thresh = torch.tensor( [self.test_cfg["score_threshold"]], device=cls_preds.device ).type_as(cls_preds) top_scores_keep = top_scores >= thresh top_scores = top_scores.masked_select(top_scores_keep) if top_scores.shape[0] != 0: if self.test_cfg["score_threshold"] > 0.0: box_preds = box_preds[top_scores_keep] top_labels = top_labels[top_scores_keep] bev_box = metas[i]["box_type_3d"]( box_preds[:, :], self.bbox_coder.code_size ).bev for cls, scale in enumerate(nms_scale): cur_bev_box = bev_box[top_labels == cls] cur_bev_box[:, [2, 3]] *= scale bev_box[top_labels == cls] = cur_bev_box boxes_for_nms = xywhr2xyxyr(bev_box) # the nms in 3d detection just remove overlap boxes. selected = nms_gpu( boxes_for_nms, top_scores, thresh=self.test_cfg["nms_thr"], pre_maxsize=self.test_cfg["pre_max_size"], post_max_size=self.test_cfg["post_max_size"], ) else: selected = [] # if selected is not None: selected_boxes = box_preds[selected] selected_labels = top_labels[selected] selected_scores = top_scores[selected] # finally generate predictions. if selected_boxes.shape[0] != 0: box_preds = selected_boxes scores = selected_scores label_preds = selected_labels final_box_preds = box_preds final_scores = scores final_labels = label_preds if post_center_range is not None: mask = (final_box_preds[:, :3] >= post_center_range[:3]).all(1) mask &= (final_box_preds[:, :3] <= post_center_range[3:]).all(1) predictions_dict = dict( bboxes=final_box_preds[mask], scores=final_scores[mask], labels=final_labels[mask], ) else: predictions_dict = dict( bboxes=final_box_preds, scores=final_scores, labels=final_labels ) else: dtype = batch_reg_preds[0].dtype device = batch_reg_preds[0].device predictions_dict = dict( bboxes=torch.zeros( [0, self.bbox_coder.code_size], dtype=dtype, device=device ), scores=torch.zeros([0], dtype=dtype, device=device), labels=torch.zeros([0], dtype=top_labels.dtype, device=device), ) predictions_dicts.append(predictions_dict) return predictions_dicts