# Source: bev-project/mmdet3d/models/heads/bbox/centerpoint.py
# (885 lines, 33 KiB, Python; last changed 2022-06-03 12:21:18 +08:00)
import copy
import torch
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule, force_fp32
from torch import nn
from mmdet3d.core import circle_nms, draw_heatmap_gaussian, gaussian_radius, xywhr2xyxyr
from mmdet3d.models import builder
from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu
from mmdet.core import build_bbox_coder, multi_apply
def clip_sigmoid(x: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """Apply a sigmoid to ``x`` and clamp the result to ``[eps, 1 - eps]``.

    The clamping keeps probabilities strictly away from 0 and 1, presumably
    so that log-based losses computed on them (e.g. focal loss) stay finite.

    ``x`` itself is not modified; a new tensor is returned.

    Args:
        x (torch.Tensor): Input logits of any shape.
        eps (float): Lower bound of the output; the upper bound is
            ``1 - eps``. Default: 1e-4.

    Returns:
        torch.Tensor: Clamped sigmoid of ``x`` with the same shape as ``x``.
    """
    # Out-of-place sigmoid: the former ``x.sigmoid_()`` silently overwrote
    # the caller's logits tensor.
    return torch.clamp(x.sigmoid(), min=eps, max=1 - eps)
@HEADS.register_module()
class SeparateHead(BaseModule):
    """SeparateHead for CenterHead.

    Builds one small convolutional branch per entry of ``heads``. All
    branches read the same input feature map, and their outputs are returned
    in a dict keyed by head name.

    Args:
        in_channels (int): Input channels for conv_layer.
        heads (dict): Conv information. Each value is a
            ``(out_channels, num_conv)`` pair. The branch stacks
            ``num_conv - 1`` ConvModules with ``head_conv`` output channels,
            followed by one plain conv layer that produces ``out_channels``
            channels.
        head_conv (int): Output channels of the intermediate ConvModules.
            Default: 64.
        final_kernel (int): Kernel size used for every conv layer of each
            branch (padding is ``final_kernel // 2``). Default: 1.
        init_bias (float): Initial bias of the last conv of the ``heatmap``
            branch. Default: -2.19.
        conv_cfg (dict): Config of conv layer.
            Default: dict(type='Conv2d')
        norm_cfg (dict): Config of norm layer.
            Default: dict(type='BN2d').
        bias (str): Type of bias of the intermediate ConvModules.
            Default: 'auto'.
        init_cfg (dict, optional): Must be None. Kaiming initialization of
            every Conv2d is used instead. Default: None.
        **kwargs: Ignored. This lets configs shared with other head types
            (e.g. ``num_cls``) be passed through.
    """

    def __init__(
        self,
        in_channels,
        heads,
        head_conv=64,
        final_kernel=1,
        init_bias=-2.19,
        conv_cfg=dict(type="Conv2d"),
        norm_cfg=dict(type="BN2d"),
        bias="auto",
        init_cfg=None,
        **kwargs,
    ):
        assert (
            init_cfg is None
        ), "To prevent abnormal initialization behavior, init_cfg is not allowed to be set"
        super(SeparateHead, self).__init__(init_cfg=init_cfg)
        self.heads = heads
        self.init_bias = init_bias
        for head in self.heads:
            classes, num_conv = self.heads[head]

            conv_layers = []
            c_in = in_channels
            for _ in range(num_conv - 1):
                conv_layers.append(
                    ConvModule(
                        c_in,
                        head_conv,
                        kernel_size=final_kernel,
                        stride=1,
                        padding=final_kernel // 2,
                        bias=bias,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                    )
                )
                c_in = head_conv

            # The output conv consumes whatever the previous layer produced:
            # ``head_conv`` channels after at least one ConvModule, or the raw
            # ``in_channels`` when the branch is a single conv
            # (``num_conv == 1``). Hard-coding ``head_conv`` here broke that
            # case whenever ``in_channels != head_conv``.
            conv_layers.append(
                build_conv_layer(
                    conv_cfg,
                    c_in,
                    classes,
                    kernel_size=final_kernel,
                    stride=1,
                    padding=final_kernel // 2,
                    bias=True,
                )
            )

            self.__setattr__(head, nn.Sequential(*conv_layers))

        # Always taken because of the assert above: Kaiming-initialize
        # every Conv2d.
        if init_cfg is None:
            self.init_cfg = dict(type="Kaiming", layer="Conv2d")

    def init_weights(self):
        """Initialize weights.

        Runs the base initialization, then sets the bias of the last conv of
        the ``heatmap`` branch (if present) to ``self.init_bias``.
        """
        super().init_weights()
        for head in self.heads:
            if head == "heatmap":
                self.__getattr__(head)[-1].bias.data.fill_(self.init_bias)

    def forward(self, x):
        """Forward function for SepHead.

        Args:
            x (torch.Tensor): Input feature map of shape [B, in_channels, H, W]
                (e.g. [B, 512, 128, 128]).

        Returns:
            dict[str: torch.Tensor]: One entry per key of ``self.heads``. Each
                value is that branch's output with ``out_channels`` channels.
                For CenterHead these are typically:

                -reg (torch.Tensor): 2D regression value with the \
                    shape of [B, 2, H, W].
                -height (torch.Tensor): Height value with the \
                    shape of [B, 1, H, W].
                -dim (torch.Tensor): Size value with the shape \
                    of [B, 3, H, W].
                -rot (torch.Tensor): Rotation value with the \
                    shape of [B, 2, H, W].
                -vel (torch.Tensor): Velocity value with the \
                    shape of [B, 2, H, W].
                -heatmap (torch.Tensor): Heatmap with the shape of \
                    [B, N, H, W].
        """
        ret_dict = dict()
        for head in self.heads:
            ret_dict[head] = self.__getattr__(head)(x)
        return ret_dict
@HEADS.register_module()
class DCNSeparateHead(BaseModule):
    r"""DCNSeparateHead for CenterHead.

    .. code-block:: none
            /-----> DCN for heatmap task -----> heatmap task.
    feature
            \-----> DCN for regression tasks -----> regression tasks

    Args:
        in_channels (int): Input channels for conv_layer. The layers built
            from ``dcn_config`` must also output this many channels, because
            both ``cls_head`` and ``task_head`` consume their output.
        num_cls (int): Number of heatmap channels (classes) predicted by
            ``cls_head``.
        heads (dict): Conv information for the regression heads, in the same
            format as :class:`SeparateHead`. A ``heatmap`` entry, if present,
            is ignored (the heatmap comes from ``cls_head``). The caller's
            dict is not modified.
        dcn_config (dict): Config passed to ``build_conv_layer`` for each of
            the two feature-adaptation layers.
        head_conv (int): Hidden channels of ``cls_head`` and of the
            regression heads. Default: 64.
        final_kernel (int): Kernel size for the regression-head convs.
            Default: 1.
        init_bias (float): Initial bias of the last ``cls_head`` conv.
            Default: -2.19.
        conv_cfg (dict): Config of conv layer used by ``cls_head``.
            Default: dict(type='Conv2d')
        norm_cfg (dict): Config of norm layer used by ``cls_head``.
            Default: dict(type='BN2d').
        bias (str): Type of bias. Default: 'auto'.
        init_cfg (dict, optional): Must be None. Default: None.
        **kwargs: Ignored.

    Note:
        ``conv_cfg`` and ``norm_cfg`` are not forwarded to the regression
        :class:`SeparateHead`, which therefore uses its own defaults.
    """  # noqa: W605

    def __init__(
        self,
        in_channels,
        num_cls,
        heads,
        dcn_config,
        head_conv=64,
        final_kernel=1,
        init_bias=-2.19,
        conv_cfg=dict(type="Conv2d"),
        norm_cfg=dict(type="BN2d"),
        bias="auto",
        init_cfg=None,
        **kwargs,
    ):
        assert init_cfg is None, (
            "To prevent abnormal initialization "
            "behavior, init_cfg is not allowed to be set"
        )
        super(DCNSeparateHead, self).__init__(init_cfg=init_cfg)

        # The heatmap is predicted by ``cls_head`` below, so drop it from the
        # regression heads. Build a new dict rather than calling
        # ``heads.pop`` so the caller's dict is left untouched.
        heads = {name: spec for name, spec in heads.items() if name != "heatmap"}

        # feature adaptation with dcn
        # use separate features for classification / regression
        self.feature_adapt_cls = build_conv_layer(dcn_config)
        self.feature_adapt_reg = build_conv_layer(dcn_config)

        # heatmap prediction head
        self.cls_head = nn.Sequential(
            ConvModule(
                in_channels,
                head_conv,
                kernel_size=3,
                padding=1,
                conv_cfg=conv_cfg,
                bias=bias,
                norm_cfg=norm_cfg,
            ),
            build_conv_layer(
                conv_cfg,
                head_conv,
                num_cls,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=bias,
            ),
        )
        self.init_bias = init_bias

        # other regression target
        self.task_head = SeparateHead(
            in_channels,
            heads,
            head_conv=head_conv,
            final_kernel=final_kernel,
            bias=bias,
        )

        # Always taken because of the assert above.
        if init_cfg is None:
            self.init_cfg = dict(type="Kaiming", layer="Conv2d")

    def init_weights(self):
        """Initialize weights.

        Runs the base initialization, then sets the bias of the last
        ``cls_head`` conv to ``self.init_bias``.
        """
        super().init_weights()
        self.cls_head[-1].bias.data.fill_(self.init_bias)

    def forward(self, x):
        """Forward function for DCNSepHead.

        Args:
            x (torch.Tensor): Input feature map of shape [B, in_channels, H, W]
                (e.g. [B, 512, 128, 128]).

        Returns:
            dict[str: torch.Tensor]: The regression outputs of ``task_head``
                (one entry per regression head) plus ``heatmap`` from
                ``cls_head``. For CenterHead these are typically:

                -reg (torch.Tensor): 2D regression value with the \
                    shape of [B, 2, H, W].
                -height (torch.Tensor): Height value with the \
                    shape of [B, 1, H, W].
                -dim (torch.Tensor): Size value with the shape \
                    of [B, 3, H, W].
                -rot (torch.Tensor): Rotation value with the \
                    shape of [B, 2, H, W].
                -vel (torch.Tensor): Velocity value with the \
                    shape of [B, 2, H, W].
                -heatmap (torch.Tensor): Heatmap with the shape of \
                    [B, N, H, W].
        """
        center_feat = self.feature_adapt_cls(x)
        reg_feat = self.feature_adapt_reg(x)

        cls_score = self.cls_head(center_feat)
        ret = self.task_head(reg_feat)
        ret["heatmap"] = cls_score

        return ret
@HEADS.register_module()
class CenterHead(BaseModule):
    """CenterHead for CenterPoint.

    A shared 3x3 ConvModule is applied to the input feature map, and its
    output is fed to one separate head per task. Each separate head predicts
    a ``heatmap`` with one channel per class of its task, plus the
    regression maps listed in ``common_heads``.

    Args:
        in_channels (list[int] | int): Channels of the input feature map.
            The value is passed unchanged to the shared ConvModule as its
            input channel count. Default: [128].
        tasks (list[list[str]]): Class names grouped per task. One separate
            head is built per task, with ``len(task)`` heatmap channels.
            Must be provided. Default: None.
        train_cfg (dict): Training config. Keys read by this class:
            ``max_objs``, ``dense_reg``, ``grid_size``, ``point_cloud_range``,
            ``voxel_size``, ``out_size_factor``, ``gaussian_overlap``,
            ``min_radius`` and, optionally, ``code_weights``.
        test_cfg (dict): Test config. Keys read by this class: ``nms_type``,
            ``post_center_limit_range``, ``score_threshold``, ``nms_thr``,
            ``pre_max_size``, ``post_max_size``, ``min_radius`` (circle NMS
            only) and, optionally, ``nms_scale``.
        bbox_coder (dict): Config of the bbox coder used to decode
            predictions.
        common_heads (dict): Conv information for the regression heads shared
            by every task. Default: dict().
        loss_cls (dict): Config of classification loss function.
            Default: dict(type='GaussianFocalLoss', reduction='mean').
        loss_bbox (dict): Config of regression loss function.
            Default: dict(type='L1Loss', reduction='none', loss_weight=0.25).
        separate_head (dict): Config of separate head. A copy is made per
            task, so the dict passed in is never modified. Default: dict(
            type='SeparateHead', init_bias=-2.19, final_kernel=3)
        share_conv_channel (int): Output channels for share_conv_layer.
            Default: 64.
        num_heatmap_convs (int): Number of conv layers for heatmap conv layer.
            Default: 2.
        conv_cfg (dict): Config of conv layer.
            Default: dict(type='Conv2d')
        norm_cfg (dict): Config of norm layer.
            Default: dict(type='BN2d').
        bias (str): Type of bias. Default: 'auto'.
        norm_bbox (bool): Whether box sizes are regressed in log space.
            Targets use ``log`` and decoding uses ``exp``. Default: True.
        init_cfg (dict, optional): Must be None. Default: None.
    """

    def __init__(
        self,
        in_channels=[128],
        tasks=None,
        train_cfg=None,
        test_cfg=None,
        bbox_coder=None,
        common_heads=dict(),
        loss_cls=dict(type="GaussianFocalLoss", reduction="mean"),
        loss_bbox=dict(type="L1Loss", reduction="none", loss_weight=0.25),
        separate_head=dict(type="SeparateHead", init_bias=-2.19, final_kernel=3),
        share_conv_channel=64,
        num_heatmap_convs=2,
        conv_cfg=dict(type="Conv2d"),
        norm_cfg=dict(type="BN2d"),
        bias="auto",
        norm_bbox=True,
        init_cfg=None,
    ):
        assert init_cfg is None, (
            "To prevent abnormal initialization "
            "behavior, init_cfg is not allowed to be set"
        )
        super(CenterHead, self).__init__(init_cfg=init_cfg)

        num_classes = [len(t) for t in tasks]
        self.class_names = list(tasks)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.norm_bbox = norm_bbox

        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_anchor_per_locs = list(num_classes)
        self.fp16_enabled = False

        # a shared convolution
        self.shared_conv = ConvModule(
            in_channels,
            share_conv_channel,
            kernel_size=3,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=bias,
        )

        self.task_heads = nn.ModuleList()
        for num_cls in num_classes:
            heads = copy.deepcopy(common_heads)
            heads.update(dict(heatmap=(num_cls, num_heatmap_convs)))
            # Work on a private copy: updating ``separate_head`` in place
            # would mutate both the caller's config and the shared default.
            head_cfg = copy.deepcopy(separate_head)
            head_cfg.update(
                in_channels=share_conv_channel, heads=heads, num_cls=num_cls
            )
            self.task_heads.append(builder.build_head(head_cfg))

    def forward_single(self, x):
        """Forward function for CenterPoint.

        Args:
            x (torch.Tensor): Input feature map with the shape of
                [B, C, H, W] (e.g. [B, 512, 128, 128]).

        Returns:
            list[dict]: Output results for tasks, one dict per task head.
        """
        ret_dicts = []

        x = self.shared_conv(x)

        for task in self.task_heads:
            ret_dicts.append(task(x))

        return ret_dicts

    def forward(self, feats, metas):
        """Forward pass.

        Args:
            feats (list[torch.Tensor] | torch.Tensor): Multi-level features,
                e.g., features produced by FPN. A single tensor is treated
                as a one-level list.
            metas (list[dict]): Unused here.

        Returns:
            tuple(list[dict]): Output results for tasks. They are indexed as
                ``result[task_id][level]`` by ``loss`` and ``get_bboxes``.
        """
        if isinstance(feats, torch.Tensor):
            feats = [feats]
        return multi_apply(self.forward_single, feats)

    def _gather_feat(self, feat, ind, mask=None):
        """Gather feature map.

        Given feature map and index, return indexed feature map.

        Args:
            feat (torch.tensor): Feature map with the shape of [B, H*W, C].
            ind (torch.Tensor): Flat spatial index of the ground truth boxes
                with the shape of [B, max_obj].
            mask (torch.Tensor): Mask of the feature map with the shape
                of [B, max_obj]. Default: None.

        Returns:
            torch.Tensor: Gathered features with the shape [B, max_obj, C].
                If ``mask`` is given, only the masked rows are kept and the
                result has the shape [num_masked, C].
        """
        dim = feat.size(2)
        ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
        feat = feat.gather(1, ind)
        if mask is not None:
            mask = mask.unsqueeze(2).expand_as(feat)
            feat = feat[mask]
            feat = feat.view(-1, dim)
        return feat

    def get_targets(self, gt_bboxes_3d, gt_labels_3d):
        """Generate targets.

        How each output is transformed:

            Each nested list is transposed so that all same-index elements in
            each sub-list (1, ..., N) become the new sub-lists.
                [ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
                ==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]

            The new transposed nested list is converted into a list of N
            tensors generated by stacking the tensors in the new sub-lists.
                [ tensor0, tensor1, tensor2, ... ]

        Args:
            gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
                truth gt boxes, one entry per sample.
            gt_labels_3d (list[torch.Tensor]): Labels of boxes.

        Returns:
            tuple[list[torch.Tensor]]: Tuple of target including \
                the following results in order, each a list with one \
                batch-stacked tensor per task.

                - list[torch.Tensor]: Heatmap scores.
                - list[torch.Tensor]: Ground truth boxes.
                - list[torch.Tensor]: Indexes indicating the \
                    position of the valid boxes.
                - list[torch.Tensor]: Masks indicating which \
                    boxes are valid.
        """
        heatmaps, anno_boxes, inds, masks = multi_apply(
            self.get_targets_single, gt_bboxes_3d, gt_labels_3d
        )

        def _stack_per_task(per_sample):
            # [[s0_t0, s0_t1, ...], [s1_t0, ...], ...]
            #   -> [stack(s*_t0), stack(s*_t1), ...]
            return [torch.stack(per_task) for per_task in zip(*per_sample)]

        return (
            _stack_per_task(heatmaps),
            _stack_per_task(anno_boxes),
            _stack_per_task(inds),
            _stack_per_task(masks),
        )

    def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
        """Generate training targets for a single sample.

        Args:
            gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth gt boxes.
            gt_labels_3d (torch.Tensor): Labels of boxes.

        Returns:
            tuple[list[torch.Tensor]]: Tuple of target including \
                the following results in order, one entry per task.

                - list[torch.Tensor]: Heatmap scores of shape \
                    [num_task_classes, X, Y].
                - list[torch.Tensor]: Ground truth boxes of shape \
                    [max_objs, 10]: sub-cell (dx, dy), z, 3 sizes (log if \
                    ``norm_bbox``), sin(rot), cos(rot), vx, vy.
                - list[torch.Tensor]: Indexes indicating the position \
                    of the valid boxes (int64, [max_objs]).
                - list[torch.Tensor]: Masks indicating which boxes \
                    are valid (uint8, [max_objs]).
        """
        device = gt_labels_3d.device
        # Boxes as [gravity center (3 values), box params from column 3 on].
        gt_bboxes_3d = torch.cat(
            (gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]), dim=1
        ).to(device)
        max_objs = self.train_cfg["max_objs"] * self.train_cfg["dense_reg"]
        grid_size = torch.tensor(self.train_cfg["grid_size"])
        pc_range = torch.tensor(self.train_cfg["point_cloud_range"])
        voxel_size = torch.tensor(self.train_cfg["voxel_size"])
        out_size_factor = self.train_cfg["out_size_factor"]

        feature_map_size = torch.div(
            grid_size[:2],
            out_size_factor,
            rounding_mode="trunc",
        )

        # reorganize the gt_dict by tasks
        task_masks = []
        flag = 0
        for class_name in self.class_names:
            task_masks.append(
                [
                    torch.where(gt_labels_3d == class_name.index(i) + flag)
                    for i in class_name
                ]
            )
            flag += len(class_name)

        task_boxes = []
        task_classes = []
        flag2 = 0
        for mask in task_masks:
            task_box = []
            task_class = []
            for m in mask:
                task_box.append(gt_bboxes_3d[m])
                # 0 is background for each task, so we need to add 1 here.
                task_class.append(gt_labels_3d[m] + 1 - flag2)
            task_boxes.append(torch.cat(task_box, axis=0).to(device))
            task_classes.append(torch.cat(task_class).long().to(device))
            flag2 += len(mask)

        draw_gaussian = draw_heatmap_gaussian
        heatmaps, anno_boxes, inds, masks = [], [], [], []

        for idx in range(len(self.task_heads)):
            # The flat index below is ``x * feature_map_size[1] + y``, i.e. a
            # row-major [X, Y] layout with x < feature_map_size[0]. Allocate
            # the heatmap as [C, X, Y] to match. The former [C, Y, X]
            # allocation only agreed with it for square grids.
            heatmap = gt_bboxes_3d.new_zeros(
                (len(self.class_names[idx]), feature_map_size[0], feature_map_size[1])
            )

            anno_box = gt_bboxes_3d.new_zeros((max_objs, 10), dtype=torch.float32)

            ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64)
            mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8)

            num_objs = min(task_boxes[idx].shape[0], max_objs)

            for k in range(num_objs):
                cls_id = task_classes[idx][k] - 1

                width = task_boxes[idx][k][3]
                length = task_boxes[idx][k][4]
                width = width / voxel_size[0] / out_size_factor
                length = length / voxel_size[1] / out_size_factor

                if width > 0 and length > 0:
                    radius = gaussian_radius(
                        (length, width), min_overlap=self.train_cfg["gaussian_overlap"]
                    )
                    radius = max(self.train_cfg["min_radius"], int(radius))

                    # be really careful for the coordinate system of
                    # your box annotation.
                    x, y, z = (
                        task_boxes[idx][k][0],
                        task_boxes[idx][k][1],
                        task_boxes[idx][k][2],
                    )

                    coor_x = (x - pc_range[0]) / voxel_size[0] / out_size_factor
                    coor_y = (y - pc_range[1]) / voxel_size[1] / out_size_factor

                    center = torch.tensor(
                        [coor_x, coor_y], dtype=torch.float32, device=device
                    )
                    center_int = center.to(torch.int32)

                    # throw out not in range objects to avoid out of array
                    # area when creating the heatmap
                    if not (
                        0 <= center_int[0] < feature_map_size[0]
                        and 0 <= center_int[1] < feature_map_size[1]
                    ):
                        continue

                    # Passed as (y, x) so that row = x cell, column = y cell.
                    draw_gaussian(heatmap[cls_id], center_int[[1, 0]], radius)

                    new_idx = k
                    x, y = center_int[0], center_int[1]

                    assert (
                        x * feature_map_size[1] + y
                        < feature_map_size[0] * feature_map_size[1]
                    )

                    ind[new_idx] = x * feature_map_size[1] + y
                    mask[new_idx] = 1

                    # TODO: support other outdoor dataset
                    vx, vy = task_boxes[idx][k][7:]
                    rot = task_boxes[idx][k][6]
                    box_dim = task_boxes[idx][k][3:6]
                    if self.norm_bbox:
                        box_dim = box_dim.log()
                    anno_box[new_idx] = torch.cat(
                        [
                            center - torch.tensor([x, y], device=device),
                            z.unsqueeze(0),
                            box_dim,
                            torch.sin(rot).unsqueeze(0),
                            torch.cos(rot).unsqueeze(0),
                            vx.unsqueeze(0),
                            vy.unsqueeze(0),
                        ]
                    )

            heatmaps.append(heatmap)
            anno_boxes.append(anno_box)
            masks.append(mask)
            inds.append(ind)
        return heatmaps, anno_boxes, inds, masks

    # ``apply_to`` must be a tuple; ``("preds_dicts")`` is just a string and
    # only worked through substring matching.
    @force_fp32(apply_to=("preds_dicts",))
    def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
        """Loss function for CenterHead.

        Args:
            gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
                truth gt boxes.
            gt_labels_3d (list[torch.Tensor]): Labels of boxes.
            preds_dicts (tuple[list[dict]]): Output of forward function,
                indexed as ``preds_dicts[task_id][0][head_name]``. Note that
                the ``heatmap`` entry is replaced by its clipped sigmoid and
                an ``anno_box`` entry is added.

        Returns:
            dict[str:torch.Tensor]: Loss of heatmap and bbox of each task,
                under keys ``heatmap/task{i}`` and ``bbox/task{i}``.
        """
        heatmaps, anno_boxes, inds, masks = self.get_targets(gt_bboxes_3d, gt_labels_3d)
        code_weights = self.train_cfg.get("code_weights", None)
        loss_dict = dict()
        for task_id, preds_dict in enumerate(preds_dicts):
            # heatmap focal loss
            preds_dict[0]["heatmap"] = clip_sigmoid(preds_dict[0]["heatmap"])
            num_pos = heatmaps[task_id].eq(1).float().sum().item()
            loss_heatmap = self.loss_cls(
                preds_dict[0]["heatmap"], heatmaps[task_id], avg_factor=max(num_pos, 1)
            )
            target_box = anno_boxes[task_id]
            # reconstruct the anno_box from multiple reg heads
            preds_dict[0]["anno_box"] = torch.cat(
                (
                    preds_dict[0]["reg"],
                    preds_dict[0]["height"],
                    preds_dict[0]["dim"],
                    preds_dict[0]["rot"],
                    preds_dict[0]["vel"],
                ),
                dim=1,
            )

            # Regression loss for dimension, offset, height, rotation
            ind = inds[task_id]
            num = masks[task_id].float().sum()
            pred = preds_dict[0]["anno_box"].permute(0, 2, 3, 1).contiguous()
            pred = pred.view(pred.size(0), -1, pred.size(3))
            pred = self._gather_feat(pred, ind)
            mask = masks[task_id].unsqueeze(2).expand_as(target_box).float()
            # NaN target entries contribute nothing to the loss.
            isnotnan = (~torch.isnan(target_box)).float()
            mask *= isnotnan

            # Without configured code weights every target entry counts equally.
            if code_weights is None:
                bbox_weights = mask
            else:
                bbox_weights = mask * mask.new_tensor(code_weights)
            loss_bbox = self.loss_bbox(
                pred, target_box, bbox_weights, avg_factor=(num + 1e-4)
            )
            loss_dict[f"heatmap/task{task_id}"] = loss_heatmap
            loss_dict[f"bbox/task{task_id}"] = loss_bbox
        return loss_dict

    @force_fp32(apply_to=("preds_dicts",))
    def get_bboxes(self, preds_dicts, metas, img=None, rescale=False):
        """Generate bboxes from bbox head predictions.

        Args:
            preds_dicts (tuple[list[dict]]): Prediction results, indexed as
                ``preds_dicts[task_id][0][head_name]``.
            metas (list[dict]): Point cloud and image's meta info. Each entry
                must provide ``box_type_3d``.

        Returns:
            list[list]: One ``[bboxes, scores, labels]`` entry per sample,
                merged over all tasks. Labels are offset by the number of
                classes of the preceding tasks. The z of the boxes is shifted
                down by half of column 5, which undoes the gravity-center
                convention used for the targets.
        """
        if not isinstance(self.test_cfg["nms_type"], list):
            nms_types = [self.test_cfg["nms_type"] for _ in range(len(preds_dicts))]
        else:
            nms_types = self.test_cfg["nms_type"]

        # Per-task, per-class scale applied to the BEV boxes before rotated NMS.
        if "nms_scale" in self.test_cfg:
            if not isinstance(self.test_cfg["nms_scale"], list):
                nms_scales = [
                    [
                        self.test_cfg["nms_scale"]
                        for _ in range(self.num_classes[task_id])
                    ]
                    for task_id in range(len(preds_dicts))
                ]
            else:
                nms_scales = self.test_cfg["nms_scale"]
        else:
            nms_scales = [
                [1.0 for _ in range(self.num_classes[task_id])]
                for task_id in range(len(preds_dicts))
            ]

        rets = []
        for task_id, preds_dict in enumerate(preds_dicts):
            num_class_with_bg = self.num_classes[task_id]
            batch_size = preds_dict[0]["heatmap"].shape[0]
            batch_heatmap = preds_dict[0]["heatmap"].sigmoid()

            batch_reg = preds_dict[0]["reg"]
            batch_hei = preds_dict[0]["height"]

            if self.norm_bbox:
                batch_dim = torch.exp(preds_dict[0]["dim"])
            else:
                batch_dim = preds_dict[0]["dim"]

            batch_rots = preds_dict[0]["rot"][:, 0].unsqueeze(1)
            batch_rotc = preds_dict[0]["rot"][:, 1].unsqueeze(1)

            if "vel" in preds_dict[0]:
                batch_vel = preds_dict[0]["vel"]
            else:
                batch_vel = None
            temp = self.bbox_coder.decode(
                batch_heatmap,
                batch_rots,
                batch_rotc,
                batch_hei,
                batch_dim,
                batch_vel,
                reg=batch_reg,
                task_id=task_id,
            )
            batch_reg_preds = [box["bboxes"] for box in temp]
            batch_cls_preds = [box["scores"] for box in temp]
            batch_cls_labels = [box["labels"] for box in temp]
            if nms_types[task_id] == "circle":
                ret_task = []
                for i in range(batch_size):
                    boxes3d = temp[i]["bboxes"]
                    scores = temp[i]["scores"]
                    labels = temp[i]["labels"]
                    centers = boxes3d[:, [0, 1]]
                    boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
                    keep = torch.tensor(
                        circle_nms(
                            boxes.detach().cpu().numpy(),
                            self.test_cfg["min_radius"][task_id],
                            post_max_size=self.test_cfg["post_max_size"],
                        ),
                        dtype=torch.long,
                        device=boxes.device,
                    )

                    boxes3d = boxes3d[keep]
                    scores = scores[keep]
                    labels = labels[keep]
                    ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
                    ret_task.append(ret)
                rets.append(ret_task)
            else:
                rets.append(
                    self.get_task_detections(
                        num_class_with_bg,
                        batch_cls_preds,
                        batch_reg_preds,
                        batch_cls_labels,
                        metas,
                        nms_scales[task_id],
                    )
                )

        # Merge branches results
        num_samples = len(rets[0])

        ret_list = []
        for i in range(num_samples):
            for k in rets[0][i].keys():
                if k == "bboxes":
                    bboxes = torch.cat([ret[i][k] for ret in rets])
                    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
                    bboxes = metas[i]["box_type_3d"](bboxes, self.bbox_coder.code_size)
                elif k == "scores":
                    scores = torch.cat([ret[i][k] for ret in rets])
                elif k == "labels":
                    # Turn per-task label ids into global ids.
                    flag = 0
                    for j, num_class in enumerate(self.num_classes):
                        rets[j][i][k] += flag
                        flag += num_class
                    labels = torch.cat([ret[i][k].int() for ret in rets])
            ret_list.append([bboxes, scores, labels])
        return ret_list

    def get_task_detections(
        self,
        num_class_with_bg,
        batch_cls_preds,
        batch_reg_preds,
        batch_cls_labels,
        metas,
        nms_scale=1.0,
    ):
        """Rotate nms for each task.

        Args:
            num_class_with_bg (int): Number of classes for the current task.
            batch_cls_preds (list[torch.Tensor]): Prediction score with the
                shape of [N].
            batch_reg_preds (list[torch.Tensor]): Prediction bbox with the
                shape of [N, 9].
            batch_cls_labels (list[torch.Tensor]): Prediction label with the
                shape of [N].
            metas (list[dict]): Meta information of each sample.
            nms_scale (float | list[float]): Scale applied to BEV box columns
                2 and 3 before NMS, either one value per class of the task or
                a single value used for every class. Default: 1.0.

        Returns:
            list[dict[str: torch.Tensor]]: contains the following keys:

                -bboxes (torch.Tensor): Prediction bboxes after nms with the \
                    shape of [N, 9].
                -scores (torch.Tensor): Prediction scores after nms with the \
                    shape of [N].
                -labels (torch.Tensor): Prediction labels after nms with the \
                    shape of [N].
        """
        # A scalar scale applies to every class of the task.
        if isinstance(nms_scale, (int, float)):
            nms_scale = [nms_scale] * num_class_with_bg

        predictions_dicts = []
        post_center_range = self.test_cfg["post_center_limit_range"]
        # An empty (or missing) range disables center filtering. Normalize it
        # to None so the filtering branch below never compares boxes against
        # an empty list.
        if post_center_range is not None and len(post_center_range) > 0:
            post_center_range = torch.tensor(
                post_center_range,
                dtype=batch_reg_preds[0].dtype,
                device=batch_reg_preds[0].device,
            )
        else:
            post_center_range = None

        for i, (box_preds, cls_preds, cls_labels) in enumerate(
            zip(batch_reg_preds, batch_cls_preds, batch_cls_labels)
        ):
            # Apply NMS in birdeye view

            # Only squeeze [N, 1] scores: on 1-D [N] scores ``squeeze(-1)``
            # would turn a single prediction (N == 1) into a 0-d tensor.
            top_scores = cls_preds.squeeze(-1) if cls_preds.dim() > 1 else cls_preds
            # get highest score per prediction, than apply nms
            # to remove overlapped box.
            if num_class_with_bg == 1:
                top_labels = torch.zeros(
                    cls_preds.shape[0], device=cls_preds.device, dtype=torch.long
                )
            else:
                top_labels = cls_labels.long()

            if self.test_cfg["score_threshold"] > 0.0:
                thresh = torch.tensor(
                    [self.test_cfg["score_threshold"]], device=cls_preds.device
                ).type_as(cls_preds)
                top_scores_keep = top_scores >= thresh
                top_scores = top_scores.masked_select(top_scores_keep)

            if top_scores.shape[0] != 0:
                if self.test_cfg["score_threshold"] > 0.0:
                    box_preds = box_preds[top_scores_keep]
                    top_labels = top_labels[top_scores_keep]
                bev_box = metas[i]["box_type_3d"](
                    box_preds[:, :], self.bbox_coder.code_size
                ).bev
                # Scale BEV columns 2 and 3 (presumably the BEV box size) per
                # class before NMS.
                for cls, scale in enumerate(nms_scale):
                    cur_bev_box = bev_box[top_labels == cls]
                    cur_bev_box[:, [2, 3]] *= scale
                    bev_box[top_labels == cls] = cur_bev_box
                boxes_for_nms = xywhr2xyxyr(bev_box)
                # the nms in 3d detection just remove overlap boxes.
                selected = nms_gpu(
                    boxes_for_nms,
                    top_scores,
                    thresh=self.test_cfg["nms_thr"],
                    pre_maxsize=self.test_cfg["pre_max_size"],
                    post_max_size=self.test_cfg["post_max_size"],
                )
            else:
                selected = []

            selected_boxes = box_preds[selected]
            selected_labels = top_labels[selected]
            selected_scores = top_scores[selected]

            # finally generate predictions.
            if selected_boxes.shape[0] != 0:
                if post_center_range is not None:
                    mask = (selected_boxes[:, :3] >= post_center_range[:3]).all(1)
                    mask &= (selected_boxes[:, :3] <= post_center_range[3:]).all(1)
                    predictions_dict = dict(
                        bboxes=selected_boxes[mask],
                        scores=selected_scores[mask],
                        labels=selected_labels[mask],
                    )
                else:
                    predictions_dict = dict(
                        bboxes=selected_boxes,
                        scores=selected_scores,
                        labels=selected_labels,
                    )
            else:
                dtype = batch_reg_preds[0].dtype
                device = batch_reg_preds[0].device
                predictions_dict = dict(
                    bboxes=torch.zeros(
                        [0, self.bbox_coder.code_size], dtype=dtype, device=device
                    ),
                    scores=torch.zeros([0], dtype=dtype, device=device),
                    labels=torch.zeros([0], dtype=top_labels.dtype, device=device),
                )

            predictions_dicts.append(predictions_dict)
        return predictions_dicts