import copy import os import mmcv import numpy as np from mmdet3d.core.bbox import box_np_ops from mmdet.datasets import PIPELINES from ..builder import OBJECTSAMPLERS from .utils import box_collision_test class BatchSampler: """Class for sampling specific category of ground truths. Args: sample_list (list[dict]): List of samples. name (str | None): The category of samples. Default: None. epoch (int | None): Sampling epoch. Default: None. shuffle (bool): Whether to shuffle indices. Default: False. drop_reminder (bool): Drop reminder. Default: False. """ def __init__( self, sampled_list, name=None, epoch=None, shuffle=True, drop_reminder=False ): self._sampled_list = sampled_list self._indices = np.arange(len(sampled_list)) if shuffle: np.random.shuffle(self._indices) self._idx = 0 self._example_num = len(sampled_list) self._name = name self._shuffle = shuffle self._epoch = epoch self._epoch_counter = 0 self._drop_reminder = drop_reminder def _sample(self, num): """Sample specific number of ground truths and return indices. Args: num (int): Sampled number. Returns: list[int]: Indices of sampled ground truths. """ if self._idx + num >= self._example_num: ret = self._indices[self._idx :].copy() self._reset() else: ret = self._indices[self._idx : self._idx + num] self._idx += num return ret def _reset(self): """Reset the index of batchsampler to zero.""" assert self._name is not None # print("reset", self._name) if self._shuffle: np.random.shuffle(self._indices) self._idx = 0 def sample(self, num): """Sample specific number of ground truths. Args: num (int): Sampled number. Returns: list[dict]: Sampled ground truths. """ indices = self._sample(num) return [self._sampled_list[i] for i in indices] @OBJECTSAMPLERS.register_module() class DataBaseSampler: """Class for sampling data from the ground truth database. Args: info_path (str): Path of groundtruth database info. dataset_root (str): Path of groundtruth database. rate (float): Rate of actual sampled over maximum sampled number. prepare (dict): Name of preparation functions and the input value. sample_groups (dict): Sampled classes and numbers. classes (list[str]): List of classes. Default: None. points_loader(dict): Config of points loader. Default: dict( type='LoadPointsFromFile', load_dim=4, use_dim=[0,1,2,3]) """ def __init__( self, info_path, dataset_root, rate, prepare, sample_groups, classes=None, points_loader=dict( type="LoadPointsFromFile", coord_type="LIDAR", load_dim=4, use_dim=[0, 1, 2, 3], ), ): super().__init__() self.dataset_root = dataset_root self.info_path = info_path self.rate = rate self.prepare = prepare self.classes = classes self.cat2label = {name: i for i, name in enumerate(classes)} self.label2cat = {i: name for i, name in enumerate(classes)} self.points_loader = mmcv.build_from_cfg(points_loader, PIPELINES) db_infos = mmcv.load(info_path) # filter database infos from mmdet3d.utils import get_root_logger logger = get_root_logger() for k, v in db_infos.items(): logger.info(f"load {len(v)} {k} database infos") for prep_func, val in prepare.items(): db_infos = getattr(self, prep_func)(db_infos, val) logger.info("After filter database:") for k, v in db_infos.items(): logger.info(f"load {len(v)} {k} database infos") self.db_infos = db_infos # load sample groups # TODO: more elegant way to load sample groups self.sample_groups = [] for name, num in sample_groups.items(): self.sample_groups.append({name: int(num)}) self.group_db_infos = self.db_infos # just use db_infos self.sample_classes = [] self.sample_max_nums = [] for group_info in self.sample_groups: self.sample_classes += list(group_info.keys()) self.sample_max_nums += list(group_info.values()) self.sampler_dict = {} for k, v in self.group_db_infos.items(): self.sampler_dict[k] = BatchSampler(v, k, shuffle=True) # TODO: No group_sampling currently @staticmethod def filter_by_difficulty(db_infos, removed_difficulty): """Filter ground truths by difficulties. Args: db_infos (dict): Info of groundtruth database. removed_difficulty (list): Difficulties that are not qualified. Returns: dict: Info of database after filtering. """ new_db_infos = {} for key, dinfos in db_infos.items(): new_db_infos[key] = [ info for info in dinfos if info["difficulty"] not in removed_difficulty ] return new_db_infos @staticmethod def filter_by_min_points(db_infos, min_gt_points_dict): """Filter ground truths by number of points in the bbox. Args: db_infos (dict): Info of groundtruth database. min_gt_points_dict (dict): Different number of minimum points needed for different categories of ground truths. Returns: dict: Info of database after filtering. """ for name, min_num in min_gt_points_dict.items(): min_num = int(min_num) if min_num > 0: filtered_infos = [] for info in db_infos[name]: if info["num_points_in_gt"] >= min_num: filtered_infos.append(info) db_infos[name] = filtered_infos return db_infos def sample_all(self, gt_bboxes, gt_labels, img=None): """Sampling all categories of bboxes. Args: gt_bboxes (np.ndarray): Ground truth bounding boxes. gt_labels (np.ndarray): Ground truth labels of boxes. Returns: dict: Dict of sampled 'pseudo ground truths'. - gt_labels_3d (np.ndarray): ground truths labels \ of sampled objects. - gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): \ sampled ground truth 3D bounding boxes - points (np.ndarray): sampled points - group_ids (np.ndarray): ids of sampled ground truths """ sampled_num_dict = {} sample_num_per_class = [] for class_name, max_sample_num in zip( self.sample_classes, self.sample_max_nums ): class_label = self.cat2label[class_name] # sampled_num = int(max_sample_num - # np.sum([n == class_name for n in gt_names])) sampled_num = int( max_sample_num - np.sum([n == class_label for n in gt_labels]) ) sampled_num = np.round(self.rate * sampled_num).astype(np.int64) sampled_num_dict[class_name] = sampled_num sample_num_per_class.append(sampled_num) sampled = [] sampled_gt_bboxes = [] avoid_coll_boxes = gt_bboxes for class_name, sampled_num in zip(self.sample_classes, sample_num_per_class): if sampled_num > 0: sampled_cls = self.sample_class_v2( class_name, sampled_num, avoid_coll_boxes ) sampled += sampled_cls if len(sampled_cls) > 0: if len(sampled_cls) == 1: sampled_gt_box = sampled_cls[0]["box3d_lidar"][np.newaxis, ...] else: sampled_gt_box = np.stack( [s["box3d_lidar"] for s in sampled_cls], axis=0 ) sampled_gt_bboxes += [sampled_gt_box] avoid_coll_boxes = np.concatenate( [avoid_coll_boxes, sampled_gt_box], axis=0 ) ret = None if len(sampled) > 0: sampled_gt_bboxes = np.concatenate(sampled_gt_bboxes, axis=0) # center = sampled_gt_bboxes[:, 0:3] # num_sampled = len(sampled) s_points_list = [] count = 0 for info in sampled: file_path = ( os.path.join(self.dataset_root, info["path"]) if self.dataset_root else info["path"] ) results = dict(lidar_path=file_path) s_points = self.points_loader(results)["points"] s_points.translate(info["box3d_lidar"][:3]) count += 1 s_points_list.append(s_points) gt_labels = np.array( [self.cat2label[s["name"]] for s in sampled], dtype=np.long ) ret = { "gt_labels_3d": gt_labels, "gt_bboxes_3d": sampled_gt_bboxes, "points": s_points_list[0].cat(s_points_list), "group_ids": np.arange( gt_bboxes.shape[0], gt_bboxes.shape[0] + len(sampled) ), } return ret def sample_class_v2(self, name, num, gt_bboxes): """Sampling specific categories of bounding boxes. Args: name (str): Class of objects to be sampled. num (int): Number of sampled bboxes. gt_bboxes (np.ndarray): Ground truth boxes. Returns: list[dict]: Valid samples after collision test. """ sampled = self.sampler_dict[name].sample(num) sampled = copy.deepcopy(sampled) num_gt = gt_bboxes.shape[0] num_sampled = len(sampled) gt_bboxes_bv = box_np_ops.center_to_corner_box2d( gt_bboxes[:, 0:2], gt_bboxes[:, 3:5], gt_bboxes[:, 6] ) sp_boxes = np.stack([i["box3d_lidar"] for i in sampled], axis=0) boxes = np.concatenate([gt_bboxes, sp_boxes], axis=0).copy() sp_boxes_new = boxes[gt_bboxes.shape[0] :] sp_boxes_bv = box_np_ops.center_to_corner_box2d( sp_boxes_new[:, 0:2], sp_boxes_new[:, 3:5], sp_boxes_new[:, 6] ) total_bv = np.concatenate([gt_bboxes_bv, sp_boxes_bv], axis=0) coll_mat = box_collision_test(total_bv, total_bv) diag = np.arange(total_bv.shape[0]) coll_mat[diag, diag] = False valid_samples = [] for i in range(num_gt, num_gt + num_sampled): if coll_mat[i].any(): coll_mat[i] = False coll_mat[:, i] = False else: valid_samples.append(sampled[i - num_gt]) return valid_samples