from mmdet.core.bbox.builder import BBOX_ASSIGNERS
from mmdet.core.bbox.assigners import AssignResult, BaseAssigner
from mmdet.core.bbox.match_costs import build_match_cost
from mmdet.core.bbox.match_costs.builder import MATCH_COST
from mmdet.core.bbox.iou_calculators import build_iou_calculator
import torch

try:
    from scipy.optimize import linear_sum_assignment
except ImportError:
    linear_sum_assignment = None


@MATCH_COST.register_module()
class BBoxBEVL1Cost(object):
    """L1 matching cost between box centres in the bird's-eye-view plane.

    The first two columns of each box are normalised by the point-cloud
    extent before the pairwise L1 distance is taken.

    Args:
        weight: Scalar multiplied onto the returned cost matrix.
    """

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, bboxes, gt_bboxes, train_cfg):
        """Return the weighted pairwise BEV L1 cost.

        Args:
            bboxes: Predicted boxes; only columns 0 and 1 are read.
            gt_bboxes: Ground-truth boxes; only columns 0 and 1 are read.
            train_cfg: Mapping providing ``'point_cloud_range'``. Entries
                ``[0:2]`` are used as the origin and ``[3:5]`` as the upper
                bound (presumably ``x_min, y_min, z_min, x_max, y_max,
                z_max`` -- confirm against the config).

        Returns:
            Tensor of shape ``[len(bboxes), len(gt_bboxes)]``.
        """
        point_cloud_range = train_cfg['point_cloud_range']
        pc_start = bboxes.new_tensor(point_cloud_range[0:2])
        pc_range = bboxes.new_tensor(point_cloud_range[3:5]) - pc_start
        # normalize the box center to [0, 1]
        normalized_bboxes_xy = (bboxes[:, :2] - pc_start) / pc_range
        normalized_gt_bboxes_xy = (gt_bboxes[:, :2] - pc_start) / pc_range
        reg_cost = torch.cdist(normalized_bboxes_xy, normalized_gt_bboxes_xy, p=1)
        return reg_cost * self.weight


@MATCH_COST.register_module()
class IoU3DCost(object):
    """Matching cost equal to the negated IoU, scaled by ``weight``.

    Args:
        weight: Scalar multiplied onto the returned cost.
    """

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, iou):
        """Return ``-iou * weight`` so that a higher IoU means a lower cost."""
        iou_cost = -iou
        return iou_cost * self.weight


@BBOX_ASSIGNERS.register_module()
class HeuristicAssigner3D(BaseAssigner):
    """Assign each ground-truth box to its nearest prediction in BEV.

    Args:
        dist_thre: Maximum BEV centre distance (in meters) for a match.
        iou_calculator: Config passed to ``build_iou_calculator``.
    """

    def __init__(self,
                 dist_thre=100,
                 iou_calculator=dict(type='BboxOverlaps3D')):
        self.dist_thre = dist_thre  # distance in meter
        self.iou_calculator = build_iou_calculator(iou_calculator)

    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None, query_labels=None):
        """Match every ground-truth box to its closest predicted box.

        Args:
            bboxes: Predicted boxes; columns 0 and 1 are used as the centre.
            gt_bboxes: Ground-truth boxes; columns 0 and 1 are used as the
                centre.
            gt_bboxes_ignore: Unused.
            gt_labels: Per-ground-truth labels. Although it defaults to
                ``None``, it is indexed whenever a match is made, so it has
                to be provided when ground-truth boxes exist.
            query_labels: Optional per-prediction labels. When given, pairs
                whose labels differ have ``dist_thre`` added to their
                distance, which pushes them to or beyond the threshold.

        Returns:
            AssignResult: 1-based ground-truth indices (0 = negative), the
            IoU of each matched pair and the assigned labels (-1 where
            unmatched).
        """
        dist_thre = self.dist_thre
        num_gts, num_bboxes = len(gt_bboxes), len(bboxes)

        assigned_gt_inds = bboxes.new_zeros(num_bboxes)
        assigned_gt_vals = bboxes.new_full((num_bboxes,), 10000)
        assigned_gt_labels = bboxes.new_full((num_bboxes,), -1)
        max_overlaps = bboxes.new_zeros(num_bboxes)

        if num_gts == 0 or num_bboxes == 0:
            # Nothing can be matched; with no predictions the min-reduction
            # below would fail on a zero-sized dimension.
            return AssignResult(
                num_gts, assigned_gt_inds.long(), max_overlaps, labels=assigned_gt_labels
            )

        # [num_gts, num_bboxes]
        bev_dist = torch.norm(
            bboxes[:, 0:2][None, :, :] - gt_bboxes[:, 0:2][:, None, :], dim=-1)
        if query_labels is not None:
            # only match the gt box and query with same category
            not_same_class = (query_labels[None] != gt_labels[:, None])
            bev_dist += not_same_class * dist_thre

        # for each gt box, assign it to the nearest pred box
        nearest_values, nearest_indices = bev_dist.min(1)  # [num_gts]

        # Each gt is matched only to its single nearest prediction (not to
        # every prediction within the radius).
        for idx_gts in range(num_gts):
            idx_pred = nearest_indices[idx_gts]
            dist = nearest_values[idx_gts]
            # If this prediction is already taken, keep the closer gt; the
            # strict "<" lets the earlier gt win ties.
            if dist <= dist_thre and dist < assigned_gt_vals[idx_pred]:
                assigned_gt_vals[idx_pred] = dist
                # for AssignResult, 0 is negative, -1 is ignore,
                # 1-based indices are positive
                assigned_gt_inds[idx_pred] = idx_gts + 1
                assigned_gt_labels[idx_pred] = gt_labels[idx_gts]

        matched_indices = torch.where(assigned_gt_inds > 0)
        matched_gt_bboxes = gt_bboxes[assigned_gt_inds[matched_indices].long() - 1]
        # The diagonal of the pairwise IoU matrix holds the IoU of each
        # matched (gt, prediction) pair.
        matched_iou = self.iou_calculator(matched_gt_bboxes, bboxes[matched_indices]).diag()
        max_overlaps[matched_indices] = matched_iou

        return AssignResult(
            num_gts, assigned_gt_inds.long(), max_overlaps, labels=assigned_gt_labels
        )


@BBOX_ASSIGNERS.register_module()
class HungarianAssigner3D(BaseAssigner):
    """One-to-one assignment of predictions to ground truth via Hungarian matching.

    Args:
        cls_cost: Config of the classification cost.
        reg_cost: Config of the regression cost. The built cost is called as
            ``reg_cost(bboxes, gt_bboxes, train_cfg)``.
        iou_cost: Config of the IoU cost.
        iou_calculator: Config passed to ``build_iou_calculator``.
    """

    def __init__(self,
                 cls_cost=dict(type='ClassificationCost', weight=1.),
                 reg_cost=dict(type='BBoxBEVL1Cost', weight=1.0),
                 iou_cost=dict(type='IoU3DCost', weight=1.0),
                 iou_calculator=dict(type='BboxOverlaps3D')):
        self.cls_cost = build_match_cost(cls_cost)
        self.reg_cost = build_match_cost(reg_cost)
        self.iou_cost = build_match_cost(iou_cost)
        self.iou_calculator = build_iou_calculator(iou_calculator)

    def assign(self, bboxes, gt_bboxes, gt_labels, cls_pred, train_cfg):
        """Compute the minimum-cost matching between predictions and ground truth.

        Args:
            bboxes: Predicted boxes, one row per prediction.
            gt_bboxes: Ground-truth boxes, one row per object.
            gt_labels: Label of each ground-truth box.
            cls_pred: Classification scores; ``cls_pred[0].T`` is handed to
                the classification cost (presumably
                ``[1, num_classes, num_bboxes]`` -- confirm with the caller).
            train_cfg: Training config forwarded to the regression cost.

        Returns:
            AssignResult: 1-based ground-truth indices (0 = background), the
            IoU of each matched pair and the assigned labels (-1 where
            unmatched).

        Raises:
            ImportError: If scipy is not installed.
        """
        num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)

        # 1. assign -1 by default
        assigned_gt_inds = bboxes.new_full((num_bboxes,), -1, dtype=torch.long)
        assigned_labels = bboxes.new_full((num_bboxes,), -1, dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)

        # 2. compute the weighted costs
        # see mmdetection/mmdet/core/bbox/match_costs/match_cost.py
        cls_cost = self.cls_cost(cls_pred[0].T, gt_labels)
        # train_cfg has to be forwarded: the default BBoxBEVL1Cost reads the
        # point-cloud range from it and takes it as a required argument.
        reg_cost = self.reg_cost(bboxes, gt_bboxes, train_cfg)
        iou = self.iou_calculator(bboxes, gt_bboxes)
        iou_cost = self.iou_cost(iou)

        # weighted sum of above three costs
        cost = cls_cost + reg_cost + iou_cost

        # 3. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')

        # Final safety check: the solver cannot handle NaN/inf entries.
        if torch.isnan(cost).any() or torch.isinf(cost).any():
            print("Critical: Cost matrix still contains invalid values after processing!")
            print(f"Final cost stats - min: {cost.min():.4f}, max: {cost.max():.4f}")
            # If the matrix is still invalid, fall back to a simplified
            # assignment. NOTE: this pairs prediction i with ground truth i
            # for the first min(num_gts, num_bboxes) entries; it does not
            # look at the costs.
            print("Fallback: Using greedy assignment instead of Hungarian matching")
            num_pairs = min(num_gts, num_bboxes)
            matched_row_inds = torch.arange(num_pairs, device=bboxes.device)
            matched_col_inds = torch.arange(num_pairs, device=bboxes.device)
        else:
            # Rows index predictions, columns index ground-truth boxes.
            matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
            matched_row_inds = torch.from_numpy(matched_row_inds).to(bboxes.device)
            matched_col_inds = torch.from_numpy(matched_col_inds).to(bboxes.device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]

        # Report the IoU with the matched ground truth only; unmatched
        # predictions keep 0 rather than their best IoU over all gts.
        max_overlaps = iou.new_zeros(iou.size(0))
        max_overlaps[matched_row_inds] = iou[matched_row_inds, matched_col_inds]
        return AssignResult(
            num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)