324 lines
11 KiB
Python
324 lines
11 KiB
Python
import copy
|
|
import os
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
|
|
from mmdet3d.core.bbox import box_np_ops
|
|
from mmdet.datasets import PIPELINES
|
|
|
|
from ..builder import OBJECTSAMPLERS
|
|
from .utils import box_collision_test
|
|
|
|
|
|
class BatchSampler:
    """Class for sampling specific category of ground truths.

    Indices into ``sampled_list`` are handed out sequentially; once the
    request reaches the end of the list the remaining indices are returned
    (possibly fewer than requested) and the sampler starts over.

    Args:
        sampled_list (list[dict]): List of samples.
        name (str | None): The category of samples. Default: None.
        epoch (int | None): Sampling epoch. Stored only; not used by the
            methods of this class. Default: None.
        shuffle (bool): Whether to shuffle indices. Default: True.
        drop_reminder (bool): Drop-remainder flag. Stored only; not used by
            the methods of this class. Default: False.
    """

    def __init__(
        self, sampled_list, name=None, epoch=None, shuffle=True, drop_reminder=False
    ):
        self._sampled_list = sampled_list
        self._example_num = len(sampled_list)
        self._indices = np.arange(self._example_num)
        if shuffle:
            np.random.shuffle(self._indices)
        self._idx = 0
        self._name = name
        self._shuffle = shuffle
        self._epoch = epoch
        self._epoch_counter = 0
        self._drop_reminder = drop_reminder

    def _sample(self, num):
        """Sample specific number of ground truths and return indices.

        Args:
            num (int): Sampled number.

        Returns:
            np.ndarray: Indices of sampled ground truths. Contains fewer
                than ``num`` entries when the end of the list is reached.
        """
        end = self._idx + num
        if end >= self._example_num:
            # Hand out whatever is left, then start over.
            ret = self._indices[self._idx :].copy()
            self._reset()
        else:
            # Copy so that a later in-place shuffle cannot alter the result.
            ret = self._indices[self._idx : end].copy()
            self._idx = end
        return ret

    def _reset(self):
        """Reset the index of batchsampler to zero.

        Re-shuffles the indices when shuffling is enabled. A sampler created
        without a name can be reset as well, so wrapping around never raises.
        """
        if self._shuffle:
            np.random.shuffle(self._indices)
        self._idx = 0

    def sample(self, num):
        """Sample specific number of ground truths.

        Args:
            num (int): Sampled number.

        Returns:
            list[dict]: Sampled ground truths.
        """
        return [self._sampled_list[i] for i in self._sample(num)]
|
|
|
|
|
|
@OBJECTSAMPLERS.register_module()
class DataBaseSampler:
    """Class for sampling data from the ground truth database.

    Args:
        info_path (str): Path of groundtruth database info.
        dataset_root (str): Path of groundtruth database.
        rate (float): Rate of actual sampled over maximum sampled number.
        prepare (dict): Name of preparation functions and the input value.
            Each key must be the name of a method of this class (e.g.
            ``filter_by_difficulty``, ``filter_by_min_points``).
        sample_groups (dict): Sampled classes and numbers.
        classes (list[str]): List of classes. Default: None. Note that the
            constructor enumerates this value to build the label maps, so a
            sequence has to be supplied in practice.
        points_loader(dict): Config of points loader. Default: dict(
            type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4,
            use_dim=[0,1,2,3])
    """

    def __init__(
        self,
        info_path,
        dataset_root,
        rate,
        prepare,
        sample_groups,
        classes=None,
        points_loader=dict(
            type="LoadPointsFromFile",
            coord_type="LIDAR",
            load_dim=4,
            use_dim=[0, 1, 2, 3],
        ),
    ):
        super().__init__()
        self.dataset_root = dataset_root
        self.info_path = info_path
        self.rate = rate
        self.prepare = prepare
        self.classes = classes
        self.cat2label = {name: i for i, name in enumerate(classes)}
        self.label2cat = {i: name for i, name in enumerate(classes)}
        self.points_loader = mmcv.build_from_cfg(points_loader, PIPELINES)

        db_infos = mmcv.load(info_path)

        # filter database infos
        from mmdet3d.utils import get_root_logger

        logger = get_root_logger()
        for k, v in db_infos.items():
            logger.info(f"load {len(v)} {k} database infos")
        # Each entry of ``prepare`` names a filter method of this class.
        for prep_func, val in prepare.items():
            db_infos = getattr(self, prep_func)(db_infos, val)
        logger.info("After filter database:")
        for k, v in db_infos.items():
            logger.info(f"load {len(v)} {k} database infos")

        self.db_infos = db_infos

        # load sample groups
        # TODO: more elegant way to load sample groups
        self.sample_groups = [
            {name: int(num)} for name, num in sample_groups.items()
        ]

        self.group_db_infos = self.db_infos  # just use db_infos
        self.sample_classes = []
        self.sample_max_nums = []
        for group_info in self.sample_groups:
            self.sample_classes.extend(group_info.keys())
            self.sample_max_nums.extend(group_info.values())

        # One shuffling sampler per category present in the database.
        self.sampler_dict = {
            k: BatchSampler(v, k, shuffle=True)
            for k, v in self.group_db_infos.items()
        }
        # TODO: No group_sampling currently

    @staticmethod
    def filter_by_difficulty(db_infos, removed_difficulty):
        """Filter ground truths by difficulties.

        Args:
            db_infos (dict): Info of groundtruth database.
            removed_difficulty (list): Difficulties that are not qualified.

        Returns:
            dict: Info of database after filtering (a new dict; the input
                is not modified).
        """
        return {
            key: [
                info for info in dinfos if info["difficulty"] not in removed_difficulty
            ]
            for key, dinfos in db_infos.items()
        }

    @staticmethod
    def filter_by_min_points(db_infos, min_gt_points_dict):
        """Filter ground truths by number of points in the bbox.

        Args:
            db_infos (dict): Info of groundtruth database. Updated in place.
            min_gt_points_dict (dict): Different number of minimum points
                needed for different categories of ground truths.

        Returns:
            dict: Info of database after filtering (the same object that
                was passed in).
        """
        for name, min_num in min_gt_points_dict.items():
            min_num = int(min_num)
            # Thresholds <= 0 keep everything, so the category is untouched.
            if min_num > 0:
                db_infos[name] = [
                    info
                    for info in db_infos[name]
                    if info["num_points_in_gt"] >= min_num
                ]
        return db_infos

    def sample_all(self, gt_bboxes, gt_labels, img=None):
        """Sampling all categories of bboxes.

        Args:
            gt_bboxes (np.ndarray): Ground truth bounding boxes.
            gt_labels (np.ndarray): Ground truth labels of boxes.
            img: Unused by this method.

        Returns:
            dict | None: Dict of sampled 'pseudo ground truths', or None
                when nothing was sampled.

                - gt_labels_3d (np.ndarray): ground truths labels \
                    of sampled objects.
                - gt_bboxes_3d (np.ndarray): sampled ground truth 3D \
                    bounding boxes
                - points: sampled points, concatenated with the ``cat`` \
                    method of the loaded points object
                - group_ids (np.ndarray): ids of sampled ground truths
        """
        labels = np.asarray(gt_labels)
        sample_num_per_class = []
        for class_name, max_sample_num in zip(
            self.sample_classes, self.sample_max_nums
        ):
            class_label = self.cat2label[class_name]
            # Top the class up to its maximum, scaled by the sampling rate.
            sampled_num = int(
                max_sample_num - np.count_nonzero(labels == class_label)
            )
            sampled_num = np.round(self.rate * sampled_num).astype(np.int64)
            sample_num_per_class.append(sampled_num)

        sampled = []
        sampled_gt_bboxes = []
        avoid_coll_boxes = gt_bboxes

        for class_name, sampled_num in zip(self.sample_classes, sample_num_per_class):
            if sampled_num <= 0:
                continue
            sampled_cls = self.sample_class_v2(
                class_name, sampled_num, avoid_coll_boxes
            )
            if not sampled_cls:
                continue

            sampled += sampled_cls
            sampled_gt_box = np.stack([s["box3d_lidar"] for s in sampled_cls], axis=0)
            sampled_gt_bboxes.append(sampled_gt_box)
            # Later classes must also avoid the boxes that were just added.
            avoid_coll_boxes = np.concatenate(
                [avoid_coll_boxes, sampled_gt_box], axis=0
            )

        if not sampled:
            return None

        sampled_gt_bboxes = np.concatenate(sampled_gt_bboxes, axis=0)

        s_points_list = []
        for info in sampled:
            file_path = (
                os.path.join(self.dataset_root, info["path"])
                if self.dataset_root
                else info["path"]
            )
            results = dict(lidar_path=file_path)
            s_points = self.points_loader(results)["points"]
            # Shift the loaded points by the first three box values.
            s_points.translate(info["box3d_lidar"][:3])
            s_points_list.append(s_points)

        # ``np.long`` is not available on every NumPy release; use an
        # explicit 64-bit integer dtype instead.
        gt_labels = np.array(
            [self.cat2label[s["name"]] for s in sampled], dtype=np.int64
        )
        return {
            "gt_labels_3d": gt_labels,
            "gt_bboxes_3d": sampled_gt_bboxes,
            "points": s_points_list[0].cat(s_points_list),
            "group_ids": np.arange(
                gt_bboxes.shape[0], gt_bboxes.shape[0] + len(sampled)
            ),
        }

    def sample_class_v2(self, name, num, gt_bboxes):
        """Sampling specific categories of bounding boxes.

        Args:
            name (str): Class of objects to be sampled.
            num (int): Number of sampled bboxes.
            gt_bboxes (np.ndarray): Ground truth boxes.

        Returns:
            list[dict]: Valid samples after collision test. Empty when the
                sampler of this class yields no candidates.
        """
        sampled = self.sampler_dict[name].sample(num)
        if not sampled:
            # Nothing to stack or test (e.g. the category was filtered out).
            return []
        sampled = copy.deepcopy(sampled)
        num_gt = gt_bboxes.shape[0]
        num_sampled = len(sampled)
        gt_bboxes_bv = box_np_ops.center_to_corner_box2d(
            gt_bboxes[:, 0:2], gt_bboxes[:, 3:5], gt_bboxes[:, 6]
        )

        sp_boxes = np.stack([i["box3d_lidar"] for i in sampled], axis=0)
        boxes = np.concatenate([gt_bboxes, sp_boxes], axis=0).copy()

        sp_boxes_new = boxes[num_gt:]
        sp_boxes_bv = box_np_ops.center_to_corner_box2d(
            sp_boxes_new[:, 0:2], sp_boxes_new[:, 3:5], sp_boxes_new[:, 6]
        )

        total_bv = np.concatenate([gt_bboxes_bv, sp_boxes_bv], axis=0)
        coll_mat = box_collision_test(total_bv, total_bv)
        # A box is not considered to collide with itself.
        diag = np.arange(total_bv.shape[0])
        coll_mat[diag, diag] = False

        valid_samples = []
        for offset, candidate in enumerate(sampled):
            row = num_gt + offset
            if coll_mat[row].any():
                # Rejected: clear its row and column so that it no longer
                # blocks the candidates checked after it.
                coll_mat[row] = False
                coll_mat[:, row] = False
            else:
                valid_samples.append(candidate)
        return valid_samples
|