import copy

import numpy as np
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import force_fp32
from torch import nn

from mmdet3d.core import (
    PseudoSampler,
    circle_nms,
    draw_heatmap_gaussian,
    gaussian_radius,
    xywhr2xyxyr,
)
from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.models.utils import FFN, PositionEmbeddingLearned, TransformerDecoderLayer
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu
from mmdet.core import (
    AssignResult,
    build_assigner,
    build_bbox_coder,
    build_sampler,
    multi_apply,
)

__all__ = ["TransFusionHead"]


def clip_sigmoid(x, eps=1e-4):
    """Apply sigmoid and clamp the result to [eps, 1 - eps] so the
    focal-loss log terms stay finite. Note that ``sigmoid_`` is in-place."""
    y = torch.clamp(x.sigmoid_(), min=eps, max=1 - eps)
    return y


@HEADS.register_module()
class TransFusionHead(nn.Module):
    def __init__(
        self,
        num_proposals=128,
        auxiliary=True,
        in_channels=128 * 3,
        hidden_channel=128,
        num_classes=4,
        # config for Transformer
        num_decoder_layers=3,
        num_heads=8,
        nms_kernel_size=1,
        ffn_channel=256,
        dropout=0.1,
        bn_momentum=0.1,
        activation="relu",
        # config for FFN
        common_heads=dict(),
        num_heatmap_convs=2,
        conv_cfg=dict(type="Conv1d"),
        norm_cfg=dict(type="BN1d"),
        bias="auto",
        # loss
        loss_cls=dict(type="GaussianFocalLoss", reduction="mean"),
        loss_iou=dict(
            type="VarifocalLoss", use_sigmoid=True, iou_weighted=True, reduction="mean"
        ),
        loss_bbox=dict(type="L1Loss", reduction="mean"),
        loss_heatmap=dict(type="GaussianFocalLoss", reduction="mean"),
        # others
        train_cfg=None,
        test_cfg=None,
        bbox_coder=None,
    ):
        super(TransFusionHead, self).__init__()

        self.fp16_enabled = False

        self.num_classes = num_classes
        self.num_proposals = num_proposals
        self.auxiliary = auxiliary
        self.in_channels = in_channels
        self.num_heads = num_heads
        self.num_decoder_layers = num_decoder_layers
        self.bn_momentum = bn_momentum
        self.nms_kernel_size = nms_kernel_size
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.use_sigmoid_cls = loss_cls.get("use_sigmoid", False)
        if not self.use_sigmoid_cls:
            self.num_classes += 1
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.loss_iou = build_loss(loss_iou)
        self.loss_heatmap = build_loss(loss_heatmap)

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.sampling = False

        # a shared convolution
        self.shared_conv = build_conv_layer(
            dict(type="Conv2d"),
            in_channels,
            hidden_channel,
            kernel_size=3,
            padding=1,
            bias=bias,
        )

        layers = []
        layers.append(
            ConvModule(
                hidden_channel,
                hidden_channel,
                kernel_size=3,
                padding=1,
                bias=bias,
                conv_cfg=dict(type="Conv2d"),
                norm_cfg=dict(type="BN2d"),
            )
        )
        layers.append(
            build_conv_layer(
                dict(type="Conv2d"),
                hidden_channel,
                num_classes,
                kernel_size=3,
                padding=1,
                bias=bias,
            )
        )
        self.heatmap_head = nn.Sequential(*layers)
        self.class_encoding = nn.Conv1d(num_classes, hidden_channel, 1)

        # transformer decoder layers for object query with LiDAR feature
        self.decoder = nn.ModuleList()
        for i in range(self.num_decoder_layers):
            self.decoder.append(
                TransformerDecoderLayer(
                    hidden_channel,
                    num_heads,
                    ffn_channel,
                    dropout,
                    activation,
                    self_posembed=PositionEmbeddingLearned(2, hidden_channel),
                    cross_posembed=PositionEmbeddingLearned(2, hidden_channel),
                )
            )

        # Prediction Head
        self.prediction_heads = nn.ModuleList()
        for i in range(self.num_decoder_layers):
            heads = copy.deepcopy(common_heads)
            heads.update(dict(heatmap=(self.num_classes, num_heatmap_convs)))
            self.prediction_heads.append(
                FFN(
                    hidden_channel,
                    heads,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=bias,
                )
            )

        self.init_weights()
        self._init_assigner_sampler()

        # Position Embedding for Cross-Attention, which is re-used during training
        x_size = self.test_cfg["grid_size"][0] // self.test_cfg["out_size_factor"]
        y_size = self.test_cfg["grid_size"][1] // self.test_cfg["out_size_factor"]
        self.bev_pos = self.create_2D_grid(x_size, y_size)

        self.img_feat_pos = None
        self.img_feat_collapsed_pos = None

    def create_2D_grid(self, x_size, y_size):
        meshgrid = [[0, x_size - 1, x_size], [0, y_size - 1, y_size]]
        # NOTE: modified
        batch_x, batch_y = torch.meshgrid(
            *[torch.linspace(it[0], it[1], it[2]) for it in meshgrid]
        )
        batch_x = batch_x + 0.5
        batch_y = batch_y + 0.5
        coord_base = torch.cat([batch_x[None], batch_y[None]], dim=0)[None]
        coord_base = coord_base.view(1, 2, -1).permute(0, 2, 1)
        return coord_base
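
    # Worked example (sketch): create_2D_grid(2, 2) returns a [1, 4, 2] tensor
    # of BEV cell-center (x, y) coordinates,
    #   [[[0.5, 0.5], [0.5, 1.5], [1.5, 0.5], [1.5, 1.5]]],
    # i.e. pixel indices shifted by 0.5, flattened with x as the outer and
    # y as the inner axis.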

    def init_weights(self):
        # initialize transformer
        for m in self.decoder.parameters():
            if m.dim() > 1:
                nn.init.xavier_uniform_(m)
        if hasattr(self, "query"):
            nn.init.xavier_normal_(self.query)
        self.init_bn_momentum()

    def init_bn_momentum(self):
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.momentum = self.bn_momentum

    def _init_assigner_sampler(self):
        """Initialize the target assigner and sampler of the head."""
        if self.train_cfg is None:
            return

        if self.sampling:
            self.bbox_sampler = build_sampler(self.train_cfg.sampler)
        else:
            self.bbox_sampler = PseudoSampler()
        if isinstance(self.train_cfg.assigner, dict):
            self.bbox_assigner = build_assigner(self.train_cfg.assigner)
        elif isinstance(self.train_cfg.assigner, list):
            self.bbox_assigner = [
                build_assigner(res) for res in self.train_cfg.assigner
            ]

    def forward_single(self, inputs, img_inputs, metas):
        """Forward function for a single feature level.

        Args:
            inputs (torch.Tensor): Input BEV feature map with the shape of
                [B, C, H, W], e.g. [B, 512, 128, 128].

        Returns:
            list[dict]: Output results for tasks.
        """
        batch_size = inputs.shape[0]
        lidar_feat = self.shared_conv(inputs)

        #################################
        # image to BEV
        #################################
        lidar_feat_flatten = lidar_feat.view(
            batch_size, lidar_feat.shape[1], -1
        )  # [BS, C, H*W]
        h, w = inputs.shape[2], inputs.shape[3]
        expected_num = h * w
        if self.bev_pos.shape[1] != expected_num:
            self.bev_pos = self.create_2D_grid(w, h).to(inputs.device)
        elif self.bev_pos.device != inputs.device:
            self.bev_pos = self.bev_pos.to(inputs.device)
        bev_pos = self.bev_pos.repeat(batch_size, 1, 1)

        #################################
        # image guided query initialization
        #################################
        dense_heatmap = self.heatmap_head(lidar_feat)
        dense_heatmap_img = None
        heatmap = dense_heatmap.detach().sigmoid()
        padding = self.nms_kernel_size // 2
        local_max = torch.zeros_like(heatmap)
        # equivalent to an NMS radius of voxel_size * out_size_factor * kernel_size
        local_max_inner = F.max_pool2d(
            heatmap, kernel_size=self.nms_kernel_size, stride=1, padding=0
        )
        if padding > 0:
            local_max[:, :, padding:(-padding), padding:(-padding)] = local_max_inner
        else:
            # nms_kernel_size == 1: pooling is the identity, so every cell is
            # its own local maximum (a `0:-0` slice would select nothing)
            local_max = local_max_inner
        # for Pedestrian & Traffic_cone in nuScenes
        if self.test_cfg["dataset"] == "nuScenes":
            local_max[:, 8] = F.max_pool2d(
                heatmap[:, 8], kernel_size=1, stride=1, padding=0
            )
            local_max[:, 9] = F.max_pool2d(
                heatmap[:, 9], kernel_size=1, stride=1, padding=0
            )
        elif self.test_cfg["dataset"] == "Waymo":  # for Pedestrian & Cyclist in Waymo
            local_max[:, 1] = F.max_pool2d(
                heatmap[:, 1], kernel_size=1, stride=1, padding=0
            )
            local_max[:, 2] = F.max_pool2d(
                heatmap[:, 2], kernel_size=1, stride=1, padding=0
            )
        heatmap = heatmap * (heatmap == local_max)
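        # Keeping only cells that equal their neighborhood maximum acts as a
        # cheap, grid-aligned NMS: with kernel size 3, a row [0.2, 0.9, 0.3]
        # keeps only the 0.9 peak. Classes overridden with kernel size 1
        # above (e.g. pedestrians) are effectively exempt from suppression.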
        heatmap = heatmap.view(batch_size, heatmap.shape[1], -1)

        # top #num_proposals among all classes
        top_proposals = heatmap.view(batch_size, -1).argsort(dim=-1, descending=True)[
            ..., : self.num_proposals
        ]
        top_proposals_class = top_proposals // heatmap.shape[-1]
        top_proposals_index = top_proposals % heatmap.shape[-1]
        query_feat = lidar_feat_flatten.gather(
            index=top_proposals_index[:, None, :].expand(
                -1, lidar_feat_flatten.shape[1], -1
            ),
            dim=-1,
        )
        self.query_labels = top_proposals_class

        # add category embedding
        one_hot = F.one_hot(top_proposals_class, num_classes=self.num_classes).permute(
            0, 2, 1
        )
        query_cat_encoding = self.class_encoding(one_hot.float())
        query_feat += query_cat_encoding
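        # Each flat index in `top_proposals` encodes a (class, BEV cell)
        # pair: class = idx // (H * W), cell = idx % (H * W). Queries are
        # therefore initialized from the strongest class-aware heatmap peaks,
        # with a learned class encoding added to the gathered BEV features.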

        query_pos = bev_pos.gather(
            index=top_proposals_index[:, None, :]
            .permute(0, 2, 1)
            .expand(-1, -1, bev_pos.shape[-1]),
            dim=1,
        )

        #################################
        # transformer decoder layer (LiDAR feature as K,V)
        #################################
        ret_dicts = []
        for i in range(self.num_decoder_layers):
            prefix = "last_" if (i == self.num_decoder_layers - 1) else f"{i}head_"

            # Transformer Decoder Layer
            # query: [B, C, num_proposals]; query_pos: [B, num_proposals, 2]
            query_feat = self.decoder[i](
                query_feat, lidar_feat_flatten, query_pos, bev_pos
            )

            # Prediction
            res_layer = self.prediction_heads[i](query_feat)
            res_layer["center"] = res_layer["center"] + query_pos.permute(0, 2, 1)
            first_res_layer = res_layer
            ret_dicts.append(res_layer)

            # for next level positional embedding
            query_pos = res_layer["center"].detach().clone().permute(0, 2, 1)

        #################################
        # transformer decoder layer (img feature as K,V)
        #################################
        ret_dicts[0]["query_heatmap_score"] = heatmap.gather(
            index=top_proposals_index[:, None, :].expand(-1, self.num_classes, -1),
            dim=-1,
        )  # [bs, num_classes, num_proposals]
        ret_dicts[0]["dense_heatmap"] = dense_heatmap

        if self.auxiliary is False:
            # only return the results of the last decoder layer
            return [ret_dicts[-1]]

        # return all layers' results for auxiliary supervision
        new_res = {}
        for key in ret_dicts[0].keys():
            if key not in ["dense_heatmap", "dense_heatmap_old", "query_heatmap_score"]:
                new_res[key] = torch.cat(
                    [ret_dict[key] for ret_dict in ret_dicts], dim=-1
                )
            else:
                new_res[key] = ret_dicts[0][key]
        return [new_res]

    def forward(self, feats, metas):
        """Forward pass.

        Args:
            feats (list[torch.Tensor]): Multi-level features, e.g.
                features produced by FPN.

        Returns:
            tuple[list[dict]]: Output results, indexed first by level and
                then by layer.
        """
        if isinstance(feats, torch.Tensor):
            feats = [feats]
        res = multi_apply(self.forward_single, feats, [None], [metas])
        assert len(res) == 1, "only support one level features."
        return res
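
    # Shape sketch (illustrative): for a single BEV level of [B, C, H, W],
    # forward() returns [[res_dict]] where, with auxiliary supervision on,
    # each head output (e.g. res_dict["center"], assuming a `center` entry in
    # `common_heads`) stacks all decoder layers along the last dimension:
    # [B, 2, num_proposals * num_decoder_layers].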

    def get_targets(self, gt_bboxes_3d, gt_labels_3d, preds_dict):
        """Generate training targets.

        Args:
            gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground truth gt
                boxes.
            gt_labels_3d (list[torch.Tensor]): Labels of boxes.
            preds_dict (tuple[dict]): Prediction results, indexed by layer
                (default 1).

        Returns:
            tuple[torch.Tensor]: Tuple of target including \
                the following results in order.

                - torch.Tensor: classification target. [BS, num_proposals]
                - torch.Tensor: classification weights (mask) [BS, num_proposals]
                - torch.Tensor: regression target. [BS, num_proposals, 8]
                - torch.Tensor: regression weights. [BS, num_proposals, 8]
                - torch.Tensor: iou target. [BS, num_proposals]
                - int: number of positive proposals
                - float: mean iou of matched proposals
                - torch.Tensor: dense heatmap target. [BS, num_classes, H, W]
        """
        # change preds_dict into list of dict (index by batch_id)
        # preds_dict[0]['center'].shape [bs, 3, num_proposal]
        list_of_pred_dict = []
        for batch_idx in range(len(gt_bboxes_3d)):
            pred_dict = {}
            for key in preds_dict[0].keys():
                pred_dict[key] = preds_dict[0][key][batch_idx : batch_idx + 1]
            list_of_pred_dict.append(pred_dict)

        assert len(gt_bboxes_3d) == len(list_of_pred_dict)

        res_tuple = multi_apply(
            self.get_targets_single,
            gt_bboxes_3d,
            gt_labels_3d,
            list_of_pred_dict,
            np.arange(len(gt_labels_3d)),
        )
        labels = torch.cat(res_tuple[0], dim=0)
        label_weights = torch.cat(res_tuple[1], dim=0)
        bbox_targets = torch.cat(res_tuple[2], dim=0)
        bbox_weights = torch.cat(res_tuple[3], dim=0)
        ious = torch.cat(res_tuple[4], dim=0)
        num_pos = np.sum(res_tuple[5])
        matched_ious = np.mean(res_tuple[6])
        heatmap = torch.cat(res_tuple[7], dim=0)
        return (
            labels,
            label_weights,
            bbox_targets,
            bbox_weights,
            ious,
            num_pos,
            matched_ious,
            heatmap,
        )

    def get_targets_single(self, gt_bboxes_3d, gt_labels_3d, preds_dict, batch_idx):
        """Generate training targets for a single sample.

        Args:
            gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth gt boxes.
            gt_labels_3d (torch.Tensor): Labels of boxes.
            preds_dict (dict): dict of prediction result for a single sample

        Returns:
            tuple[torch.Tensor]: Tuple of target including \
                the following results in order.

                - torch.Tensor: classification target. [1, num_proposals]
                - torch.Tensor: classification weights (mask) [1, num_proposals]
                - torch.Tensor: regression target. [1, num_proposals, 8]
                - torch.Tensor: regression weights. [1, num_proposals, 8]
                - torch.Tensor: iou target. [1, num_proposals]
                - int: number of positive proposals
                - float: mean iou of positive proposals
                - torch.Tensor: dense heatmap target. [1, num_classes, H, W]
        """
        num_proposals = preds_dict["center"].shape[-1]

        # get pred boxes; be careful not to change the network outputs
        score = copy.deepcopy(preds_dict["heatmap"].detach())
        center = copy.deepcopy(preds_dict["center"].detach())
        height = copy.deepcopy(preds_dict["height"].detach())
        dim = copy.deepcopy(preds_dict["dim"].detach())
        rot = copy.deepcopy(preds_dict["rot"].detach())
        if "vel" in preds_dict.keys():
            vel = copy.deepcopy(preds_dict["vel"].detach())
        else:
            vel = None

        boxes_dict = self.bbox_coder.decode(
            score, rot, dim, center, height, vel
        )  # decode the prediction to real world metric bbox
        bboxes_tensor = boxes_dict[0]["bboxes"]
        gt_bboxes_tensor = gt_bboxes_3d.tensor.to(score.device)
        # each layer should do label assignment separately.
        if self.auxiliary:
            num_layer = self.num_decoder_layers
        else:
            num_layer = 1

        assign_result_list = []
        for idx_layer in range(num_layer):
            bboxes_tensor_layer = bboxes_tensor[
                self.num_proposals * idx_layer : self.num_proposals * (idx_layer + 1), :
            ]
            score_layer = score[
                ...,
                self.num_proposals * idx_layer : self.num_proposals * (idx_layer + 1),
            ]

            if self.train_cfg.assigner.type == "HungarianAssigner3D":
                assign_result = self.bbox_assigner.assign(
                    bboxes_tensor_layer,
                    gt_bboxes_tensor,
                    gt_labels_3d,
                    score_layer,
                    self.train_cfg,
                )
            elif self.train_cfg.assigner.type == "HeuristicAssigner":
                assign_result = self.bbox_assigner.assign(
                    bboxes_tensor_layer,
                    gt_bboxes_tensor,
                    None,
                    gt_labels_3d,
                    self.query_labels[batch_idx],
                )
            else:
                raise NotImplementedError
            assign_result_list.append(assign_result)

        # combine assign results of all layers
        assign_result_ensemble = AssignResult(
            num_gts=sum([res.num_gts for res in assign_result_list]),
            gt_inds=torch.cat([res.gt_inds for res in assign_result_list]),
            max_overlaps=torch.cat([res.max_overlaps for res in assign_result_list]),
            labels=torch.cat([res.labels for res in assign_result_list]),
        )
        sampling_result = self.bbox_sampler.sample(
            assign_result_ensemble, bboxes_tensor, gt_bboxes_tensor
        )
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        assert len(pos_inds) + len(neg_inds) == num_proposals
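        # With the default PseudoSampler, every proposal is either matched to
        # a ground-truth box (pos_inds) or treated as background (neg_inds),
        # so the two index sets partition all proposals.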

        # create target for loss computation
        bbox_targets = torch.zeros([num_proposals, self.bbox_coder.code_size]).to(
            center.device
        )
        bbox_weights = torch.zeros([num_proposals, self.bbox_coder.code_size]).to(
            center.device
        )
        ious = assign_result_ensemble.max_overlaps
        ious = torch.clamp(ious, min=0.0, max=1.0)
        labels = bboxes_tensor.new_zeros(num_proposals, dtype=torch.long)
        label_weights = bboxes_tensor.new_zeros(num_proposals, dtype=torch.long)

        if gt_labels_3d is not None:  # default label is -1
            labels += self.num_classes

        # both pos and neg have classification loss; only pos has regression and iou loss
        if len(pos_inds) > 0:
            pos_bbox_targets = self.bbox_coder.encode(sampling_result.pos_gt_bboxes)

            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0

            if gt_labels_3d is None:
                labels[pos_inds] = 1
            else:
                labels[pos_inds] = gt_labels_3d[sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight

        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # compute dense heatmap targets
        device = labels.device
        gt_bboxes_3d = torch.cat(
            [gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]], dim=1
        ).to(device)
        grid_size = torch.tensor(self.train_cfg["grid_size"])
        pc_range = torch.tensor(self.train_cfg["point_cloud_range"])
        voxel_size = torch.tensor(self.train_cfg["voxel_size"])
        feature_map_size = (
            grid_size[:2] // self.train_cfg["out_size_factor"]
        )  # [x_len, y_len]
        heatmap = gt_bboxes_3d.new_zeros(
            self.num_classes, feature_map_size[1], feature_map_size[0]
        )
        for idx in range(len(gt_bboxes_3d)):
            width = gt_bboxes_3d[idx][3]
            length = gt_bboxes_3d[idx][4]
            width = width / voxel_size[0] / self.train_cfg["out_size_factor"]
            length = length / voxel_size[1] / self.train_cfg["out_size_factor"]
            if width > 0 and length > 0:
                radius = gaussian_radius(
                    (length, width), min_overlap=self.train_cfg["gaussian_overlap"]
                )
                radius = max(self.train_cfg["min_radius"], int(radius))
                x, y = gt_bboxes_3d[idx][0], gt_bboxes_3d[idx][1]

                coor_x = (
                    (x - pc_range[0])
                    / voxel_size[0]
                    / self.train_cfg["out_size_factor"]
                )
                coor_y = (
                    (y - pc_range[1])
                    / voxel_size[1]
                    / self.train_cfg["out_size_factor"]
                )

                center = torch.tensor(
                    [coor_x, coor_y], dtype=torch.float32, device=device
                )
                center_int = center.to(torch.int32)

                # original:
                # draw_heatmap_gaussian(heatmap[gt_labels_3d[idx]], center_int, radius)
                # NOTE: fix
                draw_heatmap_gaussian(
                    heatmap[gt_labels_3d[idx]], center_int[[1, 0]], radius
                )
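                # Hedged note: draw_heatmap_gaussian expects an (x, y) center
                # for a [y, x] map; the [[1, 0]] swap (the "NOTE: fix" above)
                # suggests the BEV maps in this codebase are transposed
                # relative to that convention, so the swap keeps the target
                # aligned with the predicted dense heatmap.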

        mean_iou = ious[pos_inds].sum() / max(len(pos_inds), 1)
        return (
            labels[None],
            label_weights[None],
            bbox_targets[None],
            bbox_weights[None],
            ious[None],
            int(pos_inds.shape[0]),
            float(mean_iou),
            heatmap[None],
        )

    @force_fp32(apply_to=("preds_dicts",))
    def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
        """Loss function for TransFusionHead.

        Args:
            gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
                truth gt boxes.
            gt_labels_3d (list[torch.Tensor]): Labels of boxes.
            preds_dicts (list[list[dict]]): Output of forward function.

        Returns:
            dict[str, torch.Tensor]: Loss of heatmap and bbox of each task.
        """
        (
            labels,
            label_weights,
            bbox_targets,
            bbox_weights,
            ious,
            num_pos,
            matched_ious,
            heatmap,
        ) = self.get_targets(gt_bboxes_3d, gt_labels_3d, preds_dicts[0])
        if hasattr(self, "on_the_image_mask"):
            label_weights = label_weights * self.on_the_image_mask
            bbox_weights = bbox_weights * self.on_the_image_mask[:, :, None]
            num_pos = bbox_weights.max(-1).values.sum()
        preds_dict = preds_dicts[0][0]
        loss_dict = dict()

        # compute heatmap loss
        loss_heatmap = self.loss_heatmap(
            clip_sigmoid(preds_dict["dense_heatmap"]),
            heatmap,
            avg_factor=max(heatmap.eq(1).float().sum().item(), 1),
        )
        loss_dict["loss_heatmap"] = loss_heatmap

        # compute loss for each layer
        for idx_layer in range(self.num_decoder_layers if self.auxiliary else 1):
            if idx_layer == self.num_decoder_layers - 1 or (
                idx_layer == 0 and self.auxiliary is False
            ):
                prefix = "layer_-1"
            else:
                prefix = f"layer_{idx_layer}"

            layer_labels = labels[
                ...,
                idx_layer * self.num_proposals : (idx_layer + 1) * self.num_proposals,
            ].reshape(-1)
            layer_label_weights = label_weights[
                ...,
                idx_layer * self.num_proposals : (idx_layer + 1) * self.num_proposals,
            ].reshape(-1)
            layer_score = preds_dict["heatmap"][
                ...,
                idx_layer * self.num_proposals : (idx_layer + 1) * self.num_proposals,
            ]
            layer_cls_score = layer_score.permute(0, 2, 1).reshape(-1, self.num_classes)
            layer_loss_cls = self.loss_cls(
                layer_cls_score,
                layer_labels,
                layer_label_weights,
                avg_factor=max(num_pos, 1),
            )

            layer_center = preds_dict["center"][
                ...,
                idx_layer * self.num_proposals : (idx_layer + 1) * self.num_proposals,
            ]
            layer_height = preds_dict["height"][
                ...,
                idx_layer * self.num_proposals : (idx_layer + 1) * self.num_proposals,
            ]
            layer_rot = preds_dict["rot"][
                ...,
                idx_layer * self.num_proposals : (idx_layer + 1) * self.num_proposals,
            ]
            layer_dim = preds_dict["dim"][
                ...,
                idx_layer * self.num_proposals : (idx_layer + 1) * self.num_proposals,
            ]
            preds = torch.cat(
                [layer_center, layer_height, layer_dim, layer_rot], dim=1
            ).permute(
                0, 2, 1
            )  # [BS, num_proposals, code_size]
            if "vel" in preds_dict.keys():
                layer_vel = preds_dict["vel"][
                    ...,
                    idx_layer
                    * self.num_proposals : (idx_layer + 1)
                    * self.num_proposals,
                ]
                preds = torch.cat(
                    [layer_center, layer_height, layer_dim, layer_rot, layer_vel], dim=1
                ).permute(
                    0, 2, 1
                )  # [BS, num_proposals, code_size]
            # NOTE: `code_weights` must be provided in train_cfg;
            # new_tensor(None) below would fail otherwise.
            code_weights = self.train_cfg.get("code_weights", None)
            layer_bbox_weights = bbox_weights[
                :,
                idx_layer * self.num_proposals : (idx_layer + 1) * self.num_proposals,
                :,
            ]
            layer_reg_weights = layer_bbox_weights * layer_bbox_weights.new_tensor(
                code_weights
            )
            layer_bbox_targets = bbox_targets[
                :,
                idx_layer * self.num_proposals : (idx_layer + 1) * self.num_proposals,
                :,
            ]
            layer_loss_bbox = self.loss_bbox(
                preds, layer_bbox_targets, layer_reg_weights, avg_factor=max(num_pos, 1)
            )

            # layer_iou = preds_dict['iou'][..., idx_layer*self.num_proposals:(idx_layer+1)*self.num_proposals].squeeze(1)
            # layer_iou_target = ious[..., idx_layer*self.num_proposals:(idx_layer+1)*self.num_proposals]
            # layer_loss_iou = self.loss_iou(layer_iou, layer_iou_target, layer_bbox_weights.max(-1).values, avg_factor=max(num_pos, 1))

            loss_dict[f"{prefix}_loss_cls"] = layer_loss_cls
            loss_dict[f"{prefix}_loss_bbox"] = layer_loss_bbox
            # loss_dict[f'{prefix}_loss_iou'] = layer_loss_iou

        loss_dict["matched_ious"] = layer_loss_cls.new_tensor(matched_ious)

        return loss_dict

    def get_bboxes(self, preds_dicts, metas, img=None, rescale=False, for_roi=False):
        """Generate bboxes from bbox head predictions.

        Args:
            preds_dicts (tuple[list[dict]]): Prediction results.

        Returns:
            list[list[dict]]: Decoded bbox, scores and labels for each layer
                and each batch.
        """
        rets = []
        for layer_id, preds_dict in enumerate(preds_dicts):
            batch_size = preds_dict[0]["heatmap"].shape[0]
            batch_score = preds_dict[0]["heatmap"][..., -self.num_proposals :].sigmoid()
            # if self.loss_iou.loss_weight != 0:
            #     batch_score = torch.sqrt(batch_score * preds_dict[0]['iou'][..., -self.num_proposals:].sigmoid())
            one_hot = F.one_hot(
                self.query_labels, num_classes=self.num_classes
            ).permute(0, 2, 1)
            # fuse the classification score with each query's initial heatmap
            # score, restricted to its initial class via the one-hot mask
            batch_score = batch_score * preds_dict[0]["query_heatmap_score"] * one_hot

            batch_center = preds_dict[0]["center"][..., -self.num_proposals :]
            batch_height = preds_dict[0]["height"][..., -self.num_proposals :]
            batch_dim = preds_dict[0]["dim"][..., -self.num_proposals :]
            batch_rot = preds_dict[0]["rot"][..., -self.num_proposals :]
            batch_vel = None
            if "vel" in preds_dict[0]:
                batch_vel = preds_dict[0]["vel"][..., -self.num_proposals :]

            temp = self.bbox_coder.decode(
                batch_score,
                batch_rot,
                batch_dim,
                batch_center,
                batch_height,
                batch_vel,
                filter=True,
            )

            if self.test_cfg["dataset"] == "nuScenes":
                self.tasks = [
                    dict(
                        num_class=8,
                        class_names=[],
                        indices=[0, 1, 2, 3, 4, 5, 6, 7],
                        radius=-1,
                    ),
                    dict(
                        num_class=1,
                        class_names=["pedestrian"],
                        indices=[8],
                        radius=0.175,
                    ),
                    dict(
                        num_class=1,
                        class_names=["traffic_cone"],
                        indices=[9],
                        radius=0.175,
                    ),
                ]
            elif self.test_cfg["dataset"] == "Waymo":
                self.tasks = [
                    dict(num_class=1, class_names=["Car"], indices=[0], radius=0.7),
                    dict(
                        num_class=1, class_names=["Pedestrian"], indices=[1], radius=0.7
                    ),
                    dict(num_class=1, class_names=["Cyclist"], indices=[2], radius=0.7),
                ]

            ret_layer = []
            for i in range(batch_size):
                boxes3d = temp[i]["bboxes"]
                scores = temp[i]["scores"]
                labels = temp[i]["labels"]
                # adopt circle nms for different categories
                if self.test_cfg["nms_type"] is not None:
                    keep_mask = torch.zeros_like(scores)
                    for task in self.tasks:
                        task_mask = torch.zeros_like(scores)
                        for cls_idx in task["indices"]:
                            task_mask += labels == cls_idx
                        task_mask = task_mask.bool()
                        if task["radius"] > 0:
                            if self.test_cfg["nms_type"] == "circle":
                                boxes_for_nms = torch.cat(
                                    [
                                        boxes3d[task_mask][:, :2],
                                        scores[:, None][task_mask],
                                    ],
                                    dim=1,
                                )
                                task_keep_indices = torch.tensor(
                                    circle_nms(
                                        boxes_for_nms.detach().cpu().numpy(),
                                        task["radius"],
                                    )
                                )
                            else:
                                boxes_for_nms = xywhr2xyxyr(
                                    metas[i]["box_type_3d"](
                                        boxes3d[task_mask][:, :7], 7
                                    ).bev
                                )
                                top_scores = scores[task_mask]
                                task_keep_indices = nms_gpu(
                                    boxes_for_nms,
                                    top_scores,
                                    thresh=task["radius"],
                                    pre_maxsize=self.test_cfg["pre_maxsize"],
                                    post_max_size=self.test_cfg["post_maxsize"],
                                )
                        else:
                            task_keep_indices = torch.arange(task_mask.sum())
                        if task_keep_indices.shape[0] != 0:
                            keep_indices = torch.where(task_mask != 0)[0][
                                task_keep_indices
                            ]
                            keep_mask[keep_indices] = 1
                    keep_mask = keep_mask.bool()
                    ret = dict(
                        bboxes=boxes3d[keep_mask],
                        scores=scores[keep_mask],
                        labels=labels[keep_mask],
                    )
                else:  # no nms
                    ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
                ret_layer.append(ret)
            rets.append(ret_layer)
        assert len(rets) == 1
        res = [
            [
                metas[0]["box_type_3d"](
                    rets[0][0]["bboxes"], box_dim=rets[0][0]["bboxes"].shape[-1]
                ),
                rets[0][0]["scores"],
                rets[0][0]["labels"].int(),
            ]
        ]
        return res
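

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The config values below are assumptions modeled on typical nuScenes
# settings; real values come from the TransFusion/mmdet3d config files, and
# `bbox_coder`/`common_heads` must match the registered implementations.
# ---------------------------------------------------------------------------
# head = TransFusionHead(
#     num_proposals=200,
#     in_channels=256,
#     hidden_channel=128,
#     num_classes=10,
#     nms_kernel_size=3,
#     common_heads=dict(
#         center=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)
#     ),
#     bbox_coder=dict(
#         type="TransFusionBBoxCoder",
#         pc_range=[-54.0, -54.0],
#         voxel_size=[0.075, 0.075],
#         out_size_factor=8,
#         post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
#         score_threshold=0.0,
#         code_size=10,
#     ),
#     test_cfg=dict(
#         dataset="nuScenes",
#         grid_size=[1440, 1440, 41],
#         out_size_factor=8,
#         nms_type=None,
#     ),
# )
# bev_feats = torch.randn(2, 256, 180, 180)  # [B, C, H, W]
# preds = head([bev_feats], metas=[dict()])  # list of per-layer result dicts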