bev-project/mmdet3d/models/heads/bbox/centerpoint.py

import copy
import torch
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule, force_fp32
from torch import nn
from mmdet3d.core import circle_nms, draw_heatmap_gaussian, gaussian_radius, xywhr2xyxyr
from mmdet3d.models import builder
from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu
from mmdet.core import build_bbox_coder, multi_apply
def clip_sigmoid(x: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """Sigmoid clamped to [eps, 1 - eps]; note ``sigmoid_`` runs in place."""
    return torch.clamp(x.sigmoid_(), min=eps, max=1 - eps)
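# A minimal sketch of why the clamp matters (illustrative values):
#   >>> x = torch.tensor([-100.0, 100.0])
#   >>> clip_sigmoid(x)  # tensor([1.0000e-04, 9.9990e-01])
# Without the clamp, log(pred) and log(1 - pred) in the Gaussian focal loss
# would hit -inf for saturated logits.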
@HEADS.register_module()
class SeparateHead(BaseModule):
"""SeparateHead for CenterHead.
Args:
in_channels (int): Input channels for conv_layer.
heads (dict): Conv information.
        head_conv (int): Output channels of the intermediate conv layers.
            Default: 64.
        final_kernel (int): Kernel size for the last conv layer.
            Default: 1.
init_bias (float): Initial bias. Default: -2.19.
conv_cfg (dict): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict): Config of norm layer.
Default: dict(type='BN2d').
bias (str): Type of bias. Default: 'auto'.
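
    Example:
        >>> # A minimal sketch; channel sizes and shapes are illustrative.
        >>> heads = dict(reg=(2, 2), height=(1, 2), heatmap=(2, 2))
        >>> head = SeparateHead(in_channels=64, heads=heads)
        >>> results = head(torch.rand(1, 64, 128, 128))
        >>> # results['reg'].shape == torch.Size([1, 2, 128, 128])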
"""
def __init__(
self,
in_channels,
heads,
head_conv=64,
final_kernel=1,
init_bias=-2.19,
conv_cfg=dict(type="Conv2d"),
norm_cfg=dict(type="BN2d"),
bias="auto",
init_cfg=None,
**kwargs,
):
assert (
init_cfg is None
), "To prevent abnormal initialization behavior, init_cfg is not allowed to be set"
super(SeparateHead, self).__init__(init_cfg=init_cfg)
self.heads = heads
self.init_bias = init_bias
for head in self.heads:
classes, num_conv = self.heads[head]
conv_layers = []
c_in = in_channels
for i in range(num_conv - 1):
conv_layers.append(
ConvModule(
c_in,
head_conv,
kernel_size=final_kernel,
stride=1,
padding=final_kernel // 2,
bias=bias,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
)
)
c_in = head_conv
conv_layers.append(
build_conv_layer(
conv_cfg,
head_conv,
classes,
kernel_size=final_kernel,
stride=1,
padding=final_kernel // 2,
bias=True,
)
)
conv_layers = nn.Sequential(*conv_layers)
self.__setattr__(head, conv_layers)
if init_cfg is None:
self.init_cfg = dict(type="Kaiming", layer="Conv2d")
def init_weights(self):
"""Initialize weights."""
super().init_weights()
for head in self.heads:
if head == "heatmap":
self.__getattr__(head)[-1].bias.data.fill_(self.init_bias)
def forward(self, x):
"""Forward function for SepHead.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
dict[str: torch.Tensor]: contains the following keys:
                -reg (torch.Tensor): 2D regression value with the \
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the \
shape of [B, 1, H, W].
-dim (torch.Tensor): Size value with the shape \
of [B, 3, H, W].
-rot (torch.Tensor): Rotation value with the \
shape of [B, 2, H, W].
-vel (torch.Tensor): Velocity value with the \
shape of [B, 2, H, W].
-heatmap (torch.Tensor): Heatmap with the shape of \
[B, N, H, W].
"""
ret_dict = dict()
for head in self.heads:
ret_dict[head] = self.__getattr__(head)(x)
return ret_dict
@HEADS.register_module()
class DCNSeparateHead(BaseModule):
r"""DCNSeparateHead for CenterHead.
.. code-block:: none
/-----> DCN for heatmap task -----> heatmap task.
feature
\-----> DCN for regression tasks -----> regression tasks
Args:
in_channels (int): Input channels for conv_layer.
heads (dict): Conv information.
dcn_config (dict): Config of dcn layer.
        num_cls (int): Number of output classes for the heatmap head.
        head_conv (int): Output channels of the intermediate conv layers.
            Default: 64.
        final_kernel (int): Kernel size for the last conv layer.
            Default: 1.
init_bias (float): Initial bias. Default: -2.19.
conv_cfg (dict): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict): Config of norm layer.
Default: dict(type='BN2d').
bias (str): Type of bias. Default: 'auto'.
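
    Example:
        >>> # Illustrative dcn_config; 'DCN' is mmcv's DeformConv2dPack and
        >>> # requires the DCN op to be available at runtime.
        >>> dcn_config = dict(
        ...     type='DCN', in_channels=64, out_channels=64, kernel_size=3,
        ...     padding=1)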
""" # noqa: W605
def __init__(
self,
in_channels,
num_cls,
heads,
dcn_config,
head_conv=64,
final_kernel=1,
init_bias=-2.19,
conv_cfg=dict(type="Conv2d"),
norm_cfg=dict(type="BN2d"),
bias="auto",
init_cfg=None,
**kwargs,
):
assert init_cfg is None, (
"To prevent abnormal initialization "
"behavior, init_cfg is not allowed to be set"
)
super(DCNSeparateHead, self).__init__(init_cfg=init_cfg)
if "heatmap" in heads:
heads.pop("heatmap")
# feature adaptation with dcn
# use separate features for classification / regression
self.feature_adapt_cls = build_conv_layer(dcn_config)
self.feature_adapt_reg = build_conv_layer(dcn_config)
# heatmap prediction head
cls_head = [
ConvModule(
in_channels,
head_conv,
kernel_size=3,
padding=1,
conv_cfg=conv_cfg,
bias=bias,
norm_cfg=norm_cfg,
),
build_conv_layer(
conv_cfg,
head_conv,
num_cls,
kernel_size=3,
stride=1,
padding=1,
bias=bias,
),
]
self.cls_head = nn.Sequential(*cls_head)
self.init_bias = init_bias
# other regression target
self.task_head = SeparateHead(
in_channels,
heads,
head_conv=head_conv,
final_kernel=final_kernel,
bias=bias,
)
if init_cfg is None:
self.init_cfg = dict(type="Kaiming", layer="Conv2d")
def init_weights(self):
"""Initialize weights."""
super().init_weights()
self.cls_head[-1].bias.data.fill_(self.init_bias)
def forward(self, x):
"""Forward function for DCNSepHead.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
dict[str: torch.Tensor]: contains the following keys:
                -reg (torch.Tensor): 2D regression value with the \
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the \
shape of [B, 1, H, W].
-dim (torch.Tensor): Size value with the shape \
of [B, 3, H, W].
-rot (torch.Tensor): Rotation value with the \
shape of [B, 2, H, W].
-vel (torch.Tensor): Velocity value with the \
shape of [B, 2, H, W].
-heatmap (torch.Tensor): Heatmap with the shape of \
[B, N, H, W].
"""
center_feat = self.feature_adapt_cls(x)
reg_feat = self.feature_adapt_reg(x)
cls_score = self.cls_head(center_feat)
ret = self.task_head(reg_feat)
ret["heatmap"] = cls_score
return ret
@HEADS.register_module()
class CenterHead(BaseModule):
"""CenterHead for CenterPoint.
Args:
        in_channels (list[int] | int): Channels of the input feature map.
            Default: [128].
        tasks (list[list[str]]): Class names grouped per task. Default: None.
        train_cfg (dict): Training config, e.g. target-assignment settings.
            Default: None.
        test_cfg (dict): Testing config, e.g. NMS settings. Default: None.
        bbox_coder (dict): Config of the bbox coder. Default: None.
common_heads (dict): Conv information for common heads.
Default: dict().
loss_cls (dict): Config of classification loss function.
Default: dict(type='GaussianFocalLoss', reduction='mean').
        loss_bbox (dict): Config of regression loss function.
            Default: dict(type='L1Loss', reduction='none', loss_weight=0.25).
separate_head (dict): Config of separate head. Default: dict(
type='SeparateHead', init_bias=-2.19, final_kernel=3)
share_conv_channel (int): Output channels for share_conv_layer.
Default: 64.
        num_heatmap_convs (int): Number of conv layers in the heatmap head.
            Default: 2.
conv_cfg (dict): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict): Config of norm layer.
Default: dict(type='BN2d').
        bias (str): Type of bias. Default: 'auto'.
        norm_bbox (bool): Whether box dimensions are log-encoded for
            regression. Default: True.
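
    Example:
        >>> # Sketch of the task format this variant expects: each task is a
        >>> # plain list of class names (class counts are taken via len()).
        >>> tasks = [['car'], ['truck', 'construction_vehicle']]
        >>> common_heads = dict(
        ...     reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2))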
"""
def __init__(
self,
in_channels=[128],
tasks=None,
train_cfg=None,
test_cfg=None,
bbox_coder=None,
common_heads=dict(),
loss_cls=dict(type="GaussianFocalLoss", reduction="mean"),
loss_bbox=dict(type="L1Loss", reduction="none", loss_weight=0.25),
separate_head=dict(type="SeparateHead", init_bias=-2.19, final_kernel=3),
share_conv_channel=64,
num_heatmap_convs=2,
conv_cfg=dict(type="Conv2d"),
norm_cfg=dict(type="BN2d"),
bias="auto",
norm_bbox=True,
init_cfg=None,
):
assert init_cfg is None, (
"To prevent abnormal initialization "
"behavior, init_cfg is not allowed to be set"
)
super(CenterHead, self).__init__(init_cfg=init_cfg)
        num_classes = [len(t) for t in tasks]
        self.class_names = list(tasks)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.in_channels = in_channels
self.num_classes = num_classes
self.norm_bbox = norm_bbox
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_anchor_per_locs = list(num_classes)
self.fp16_enabled = False
# a shared convolution
self.shared_conv = ConvModule(
in_channels,
share_conv_channel,
kernel_size=3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=bias,
)
self.task_heads = nn.ModuleList()
for num_cls in num_classes:
heads = copy.deepcopy(common_heads)
heads.update(dict(heatmap=(num_cls, num_heatmap_convs)))
separate_head.update(
in_channels=share_conv_channel, heads=heads, num_cls=num_cls
)
self.task_heads.append(builder.build_head(separate_head))
def forward_single(self, x):
"""Forward function for CenterPoint.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
list[dict]: Output results for tasks.
"""
ret_dicts = []
x = self.shared_conv(x)
for task in self.task_heads:
ret_dicts.append(task(x))
return ret_dicts
def forward(self, feats, metas):
"""Forward pass.
Args:
feats (list[torch.Tensor]): Multi-level features, e.g.,
features produced by FPN.
Returns:
tuple(list[dict]): Output results for tasks.
"""
if isinstance(feats, torch.Tensor):
feats = [feats]
return multi_apply(self.forward_single, feats)
def _gather_feat(self, feat, ind, mask=None):
"""Gather feature map.
Given feature map and index, return indexed feature map.
Args:
            feat (torch.Tensor): Feature map with the shape of [B, H*W, 10].
ind (torch.Tensor): Index of the ground truth boxes with the
shape of [B, max_obj].
mask (torch.Tensor): Mask of the feature map with the shape
of [B, max_obj]. Default: None.
Returns:
torch.Tensor: Feature map after gathering with the shape
of [B, max_obj, 10].
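
        Example:
            >>> # Shapes are illustrative; `head` is any CenterHead instance.
            >>> feat = torch.arange(12.0).view(1, 6, 2)  # [B=1, H*W=6, C=2]
            >>> ind = torch.tensor([[0, 3]])             # [B=1, max_obj=2]
            >>> head._gather_feat(feat, ind)
            tensor([[[0., 1.],
                     [6., 7.]]])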
"""
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
def get_targets(self, gt_bboxes_3d, gt_labels_3d):
"""Generate targets.
How each output is transformed:
Each nested list is transposed so that all same-index elements in
each sub-list (1, ..., N) become the new sub-lists.
[ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
The new transposed nested list is converted into a list of N
tensors generated by concatenating tensors in the new sub-lists.
[ tensor0, tensor1, tensor2, ... ]
Args:
            gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
                truth boxes.
gt_labels_3d (list[torch.Tensor]): Labels of boxes.
        Returns:
tuple[list[torch.Tensor]]: Tuple of target including \
the following results in order.
- list[torch.Tensor]: Heatmap scores.
- list[torch.Tensor]: Ground truth boxes.
- list[torch.Tensor]: Indexes indicating the \
position of the valid boxes.
- list[torch.Tensor]: Masks indicating which \
boxes are valid.
"""
heatmaps, anno_boxes, inds, masks = multi_apply(
self.get_targets_single, gt_bboxes_3d, gt_labels_3d
)
# Transpose heatmaps
heatmaps = list(map(list, zip(*heatmaps)))
heatmaps = [torch.stack(hms_) for hms_ in heatmaps]
# Transpose anno_boxes
anno_boxes = list(map(list, zip(*anno_boxes)))
anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes]
# Transpose inds
inds = list(map(list, zip(*inds)))
inds = [torch.stack(inds_) for inds_ in inds]
        # Transpose masks
masks = list(map(list, zip(*masks)))
masks = [torch.stack(masks_) for masks_ in masks]
return heatmaps, anno_boxes, inds, masks
def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
"""Generate training targets for a single sample.
Args:
            gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth boxes.
gt_labels_3d (torch.Tensor): Labels of boxes.
Returns:
tuple[list[torch.Tensor]]: Tuple of target including \
the following results in order.
- list[torch.Tensor]: Heatmap scores.
- list[torch.Tensor]: Ground truth boxes.
- list[torch.Tensor]: Indexes indicating the position \
of the valid boxes.
- list[torch.Tensor]: Masks indicating which boxes \
are valid.
"""
device = gt_labels_3d.device
gt_bboxes_3d = torch.cat(
(gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]), dim=1
).to(device)
max_objs = self.train_cfg["max_objs"] * self.train_cfg["dense_reg"]
grid_size = torch.tensor(self.train_cfg["grid_size"])
pc_range = torch.tensor(self.train_cfg["point_cloud_range"])
voxel_size = torch.tensor(self.train_cfg["voxel_size"])
feature_map_size = torch.div(
grid_size[:2],
self.train_cfg["out_size_factor"],
rounding_mode="trunc",
)
# reorganize the gt_dict by tasks
task_masks = []
flag = 0
for class_name in self.class_names:
task_masks.append(
[
torch.where(gt_labels_3d == class_name.index(i) + flag)
for i in class_name
]
)
flag += len(class_name)
task_boxes = []
task_classes = []
flag2 = 0
for idx, mask in enumerate(task_masks):
task_box = []
task_class = []
for m in mask:
task_box.append(gt_bboxes_3d[m])
# 0 is background for each task, so we need to add 1 here.
task_class.append(gt_labels_3d[m] + 1 - flag2)
            task_boxes.append(torch.cat(task_box, dim=0).to(device))
task_classes.append(torch.cat(task_class).long().to(device))
flag2 += len(mask)
draw_gaussian = draw_heatmap_gaussian
heatmaps, anno_boxes, inds, masks = [], [], [], []
for idx, task_head in enumerate(self.task_heads):
heatmap = gt_bboxes_3d.new_zeros(
(len(self.class_names[idx]), feature_map_size[1], feature_map_size[0])
)
anno_box = gt_bboxes_3d.new_zeros((max_objs, 10), dtype=torch.float32)
ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64)
mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8)
num_objs = min(task_boxes[idx].shape[0], max_objs)
for k in range(num_objs):
cls_id = task_classes[idx][k] - 1
width = task_boxes[idx][k][3]
length = task_boxes[idx][k][4]
width = width / voxel_size[0] / self.train_cfg["out_size_factor"]
length = length / voxel_size[1] / self.train_cfg["out_size_factor"]
if width > 0 and length > 0:
radius = gaussian_radius(
(length, width), min_overlap=self.train_cfg["gaussian_overlap"]
)
radius = max(self.train_cfg["min_radius"], int(radius))
# be really careful for the coordinate system of
# your box annotation.
x, y, z = (
task_boxes[idx][k][0],
task_boxes[idx][k][1],
task_boxes[idx][k][2],
)
coor_x = (
(x - pc_range[0])
/ voxel_size[0]
/ self.train_cfg["out_size_factor"]
)
coor_y = (
(y - pc_range[1])
/ voxel_size[1]
/ self.train_cfg["out_size_factor"]
)
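                    # Worked example with illustrative config values: for
                    # pc_range[0] = -51.2, voxel_size[0] = 0.1 and
                    # out_size_factor = 8, a box at x = 0.0 lands at
                    # coor_x = (0.0 + 51.2) / 0.1 / 8 = 64.0, the middle of
                    # a 128-cell-wide feature map.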
center = torch.tensor(
[coor_x, coor_y], dtype=torch.float32, device=device
)
center_int = center.to(torch.int32)
                    # discard objects outside the feature-map range to avoid
                    # out-of-bound access when drawing the heatmap
if not (
0 <= center_int[0] < feature_map_size[0]
and 0 <= center_int[1] < feature_map_size[1]
):
continue
draw_gaussian(heatmap[cls_id], center_int[[1, 0]], radius)
new_idx = k
x, y = center_int[0], center_int[1]
assert (
x * feature_map_size[1] + y
< feature_map_size[0] * feature_map_size[1]
)
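                    # Flatten the integer cell coordinates into a single
                    # offset; the loss later gathers predictions at these
                    # offsets from a flattened [B, H*W, C] view of the head
                    # outputs (see _gather_feat).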
ind[new_idx] = x * feature_map_size[1] + y
mask[new_idx] = 1
# TODO: support other outdoor dataset
vx, vy = task_boxes[idx][k][7:]
rot = task_boxes[idx][k][6]
box_dim = task_boxes[idx][k][3:6]
if self.norm_bbox:
box_dim = box_dim.log()
anno_box[new_idx] = torch.cat(
[
center - torch.tensor([x, y], device=device),
z.unsqueeze(0),
box_dim,
torch.sin(rot).unsqueeze(0),
torch.cos(rot).unsqueeze(0),
vx.unsqueeze(0),
vy.unsqueeze(0),
]
)
heatmaps.append(heatmap)
anno_boxes.append(anno_box)
masks.append(mask)
inds.append(ind)
return heatmaps, anno_boxes, inds, masks
    @force_fp32(apply_to=("preds_dicts",))
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
"""Loss function for CenterHead.
Args:
            gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
                truth boxes.
gt_labels_3d (list[torch.Tensor]): Labels of boxes.
preds_dicts (dict): Output of forward function.
Returns:
dict[str:torch.Tensor]: Loss of heatmap and bbox of each task.
"""
heatmaps, anno_boxes, inds, masks = self.get_targets(gt_bboxes_3d, gt_labels_3d)
loss_dict = dict()
for task_id, preds_dict in enumerate(preds_dicts):
# heatmap focal loss
preds_dict[0]["heatmap"] = clip_sigmoid(preds_dict[0]["heatmap"])
num_pos = heatmaps[task_id].eq(1).float().sum().item()
loss_heatmap = self.loss_cls(
preds_dict[0]["heatmap"], heatmaps[task_id], avg_factor=max(num_pos, 1)
)
target_box = anno_boxes[task_id]
# reconstruct the anno_box from multiple reg heads
preds_dict[0]["anno_box"] = torch.cat(
(
preds_dict[0]["reg"],
preds_dict[0]["height"],
preds_dict[0]["dim"],
preds_dict[0]["rot"],
preds_dict[0]["vel"],
),
dim=1,
)
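            # Channel layout of anno_box, matching get_targets_single:
            # [0:2] center offset, [2:3] height, [3:6] dim (log-encoded if
            # norm_bbox), [6:8] (sin, cos) of yaw, [8:10] velocity.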
# Regression loss for dimension, offset, height, rotation
ind = inds[task_id]
num = masks[task_id].float().sum()
pred = preds_dict[0]["anno_box"].permute(0, 2, 3, 1).contiguous()
pred = pred.view(pred.size(0), -1, pred.size(3))
pred = self._gather_feat(pred, ind)
mask = masks[task_id].unsqueeze(2).expand_as(target_box).float()
isnotnan = (~torch.isnan(target_box)).float()
mask *= isnotnan
code_weights = self.train_cfg.get("code_weights", None)
bbox_weights = mask * mask.new_tensor(code_weights)
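            # code_weights must be provided in train_cfg; a typical nuScenes
            # setting (illustrative) is [1.0] * 8 + [0.2, 0.2] to down-weight
            # the velocity channels.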
loss_bbox = self.loss_bbox(
pred, target_box, bbox_weights, avg_factor=(num + 1e-4)
)
loss_dict[f"heatmap/task{task_id}"] = loss_heatmap
loss_dict[f"bbox/task{task_id}"] = loss_bbox
return loss_dict
    @force_fp32(apply_to=("preds_dicts",))
def get_bboxes(self, preds_dicts, metas, img=None, rescale=False):
"""Generate bboxes from bbox head predictions.
Args:
preds_dicts (tuple[list[dict]]): Prediction results.
metas (list[dict]): Point cloud and image's meta info.
Returns:
list[dict]: Decoded bbox, scores and labels after nms.
"""
if not isinstance(self.test_cfg["nms_type"], list):
nms_types = [self.test_cfg["nms_type"] for _ in range(len(preds_dicts))]
else:
nms_types = self.test_cfg["nms_type"]
if "nms_scale" in self.test_cfg:
if not isinstance(self.test_cfg["nms_scale"], list):
nms_scales = [
[
self.test_cfg["nms_scale"]
for _ in range(self.num_classes[task_id])
]
for task_id in range(len(preds_dicts))
]
else:
nms_scales = self.test_cfg["nms_scale"]
else:
nms_scales = [
[1.0 for _ in range(self.num_classes[task_id])]
for task_id in range(len(preds_dicts))
]
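        # Illustrative test_cfg fragment for the two branches above:
        #   nms_type: a single str ('rotate' or 'circle') shared by all
        #   tasks, or a per-task list; nms_scale: a single float, or a
        #   per-task list of per-class factors applied to BEV box sizes
        #   before rotated NMS.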
rets = []
for task_id, preds_dict in enumerate(preds_dicts):
num_class_with_bg = self.num_classes[task_id]
batch_size = preds_dict[0]["heatmap"].shape[0]
batch_heatmap = preds_dict[0]["heatmap"].sigmoid()
batch_reg = preds_dict[0]["reg"]
batch_hei = preds_dict[0]["height"]
if self.norm_bbox:
batch_dim = torch.exp(preds_dict[0]["dim"])
else:
batch_dim = preds_dict[0]["dim"]
batch_rots = preds_dict[0]["rot"][:, 0].unsqueeze(1)
batch_rotc = preds_dict[0]["rot"][:, 1].unsqueeze(1)
if "vel" in preds_dict[0]:
batch_vel = preds_dict[0]["vel"]
else:
batch_vel = None
temp = self.bbox_coder.decode(
batch_heatmap,
batch_rots,
batch_rotc,
batch_hei,
batch_dim,
batch_vel,
reg=batch_reg,
task_id=task_id,
)
batch_reg_preds = [box["bboxes"] for box in temp]
batch_cls_preds = [box["scores"] for box in temp]
batch_cls_labels = [box["labels"] for box in temp]
if nms_types[task_id] == "circle":
ret_task = []
for i in range(batch_size):
boxes3d = temp[i]["bboxes"]
scores = temp[i]["scores"]
labels = temp[i]["labels"]
centers = boxes3d[:, [0, 1]]
boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
keep = torch.tensor(
circle_nms(
boxes.detach().cpu().numpy(),
self.test_cfg["min_radius"][task_id],
post_max_size=self.test_cfg["post_max_size"],
),
dtype=torch.long,
device=boxes.device,
)
boxes3d = boxes3d[keep]
scores = scores[keep]
labels = labels[keep]
ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
ret_task.append(ret)
rets.append(ret_task)
else:
rets.append(
self.get_task_detections(
num_class_with_bg,
batch_cls_preds,
batch_reg_preds,
batch_cls_labels,
metas,
nms_scales[task_id],
)
)
# Merge branches results
num_samples = len(rets[0])
ret_list = []
for i in range(num_samples):
for k in rets[0][i].keys():
if k == "bboxes":
bboxes = torch.cat([ret[i][k] for ret in rets])
                    # shift z from gravity center to the bottom center
                    # expected by box_type_3d
                    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
bboxes = metas[i]["box_type_3d"](bboxes, self.bbox_coder.code_size)
elif k == "scores":
scores = torch.cat([ret[i][k] for ret in rets])
elif k == "labels":
flag = 0
for j, num_class in enumerate(self.num_classes):
rets[j][i][k] += flag
flag += num_class
labels = torch.cat([ret[i][k].int() for ret in rets])
ret_list.append([bboxes, scores, labels])
return ret_list
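    # Note on get_bboxes output: each ret_list entry is [bboxes, scores,
    # labels], where bboxes is a metas[i]['box_type_3d'] instance (bottom-
    # center z convention) and labels already include per-task offsets.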
def get_task_detections(
self,
num_class_with_bg,
batch_cls_preds,
batch_reg_preds,
batch_cls_labels,
metas,
nms_scale=1.0,
):
"""Rotate nms for each task.
Args:
num_class_with_bg (int): Number of classes for the current task.
batch_cls_preds (list[torch.Tensor]): Prediction score with the
shape of [N].
batch_reg_preds (list[torch.Tensor]): Prediction bbox with the
shape of [N, 9].
batch_cls_labels (list[torch.Tensor]): Prediction label with the
shape of [N].
metas (list[dict]): Meta information of each sample.
Returns:
list[dict[str: torch.Tensor]]: contains the following keys:
-bboxes (torch.Tensor): Prediction bboxes after nms with the \
shape of [N, 9].
-scores (torch.Tensor): Prediction scores after nms with the \
shape of [N].
-labels (torch.Tensor): Prediction labels after nms with the \
shape of [N].
"""
predictions_dicts = []
post_center_range = self.test_cfg["post_center_limit_range"]
if len(post_center_range) > 0:
post_center_range = torch.tensor(
post_center_range,
dtype=batch_reg_preds[0].dtype,
device=batch_reg_preds[0].device,
)
for i, (box_preds, cls_preds, cls_labels) in enumerate(
zip(batch_reg_preds, batch_cls_preds, batch_cls_labels)
):
            # Apply NMS in bird's-eye view: take the highest score per
            # prediction, then apply NMS to remove overlapping boxes.
if num_class_with_bg == 1:
top_scores = cls_preds.squeeze(-1)
top_labels = torch.zeros(
cls_preds.shape[0], device=cls_preds.device, dtype=torch.long
)
else:
top_labels = cls_labels.long()
top_scores = cls_preds.squeeze(-1)
if self.test_cfg["score_threshold"] > 0.0:
thresh = torch.tensor(
[self.test_cfg["score_threshold"]], device=cls_preds.device
).type_as(cls_preds)
top_scores_keep = top_scores >= thresh
top_scores = top_scores.masked_select(top_scores_keep)
if top_scores.shape[0] != 0:
if self.test_cfg["score_threshold"] > 0.0:
box_preds = box_preds[top_scores_keep]
top_labels = top_labels[top_scores_keep]
bev_box = metas[i]["box_type_3d"](
box_preds[:, :], self.bbox_coder.code_size
).bev
for cls, scale in enumerate(nms_scale):
cur_bev_box = bev_box[top_labels == cls]
cur_bev_box[:, [2, 3]] *= scale
bev_box[top_labels == cls] = cur_bev_box
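                # e.g. nms_scale[cls] = 1.1 (illustrative) inflates the BEV
                # width/length of class `cls` by 10% before rotated NMS,
                # suppressing near-duplicates more aggressively.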
boxes_for_nms = xywhr2xyxyr(bev_box)
                # in 3D detection, NMS only needs to suppress boxes that
                # overlap in BEV.
selected = nms_gpu(
boxes_for_nms,
top_scores,
thresh=self.test_cfg["nms_thr"],
pre_maxsize=self.test_cfg["pre_max_size"],
post_max_size=self.test_cfg["post_max_size"],
)
else:
selected = []
selected_boxes = box_preds[selected]
selected_labels = top_labels[selected]
selected_scores = top_scores[selected]
# finally generate predictions.
if selected_boxes.shape[0] != 0:
box_preds = selected_boxes
scores = selected_scores
label_preds = selected_labels
final_box_preds = box_preds
final_scores = scores
final_labels = label_preds
if post_center_range is not None:
mask = (final_box_preds[:, :3] >= post_center_range[:3]).all(1)
mask &= (final_box_preds[:, :3] <= post_center_range[3:]).all(1)
predictions_dict = dict(
bboxes=final_box_preds[mask],
scores=final_scores[mask],
labels=final_labels[mask],
)
else:
predictions_dict = dict(
bboxes=final_box_preds, scores=final_scores, labels=final_labels
)
else:
dtype = batch_reg_preds[0].dtype
device = batch_reg_preds[0].device
predictions_dict = dict(
bboxes=torch.zeros(
[0, self.bbox_coder.code_size], dtype=dtype, device=device
),
scores=torch.zeros([0], dtype=dtype, device=device),
labels=torch.zeros([0], dtype=top_labels.dtype, device=device),
)
predictions_dicts.append(predictions_dict)
return predictions_dicts