import tempfile
from os import path as osp

import mmcv
import numpy as np
from torch.utils.data import Dataset

from mmdet.datasets import DATASETS

from ..core.bbox import get_box_type
from .pipelines import Compose
from .utils import extract_result_dict


@DATASETS.register_module()
class Custom3DDataset(Dataset):
    """Customized 3D dataset.

    This is the base dataset of the SUNRGB-D, ScanNet, nuScenes, and KITTI
    datasets.

    Args:
        dataset_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        classes (tuple[str], optional): Classes used in the dataset.
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to None.
        box_type_3d (str, optional): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            in its original format and then convert it to `box_type_3d`.
            Defaults to 'LiDAR'. Available options include:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor datasets.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool, optional): Whether to filter empty GT.
            Defaults to True.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
    """

    def __init__(
        self,
        dataset_root,
        ann_file,
        pipeline=None,
        classes=None,
        modality=None,
        box_type_3d="LiDAR",
        filter_empty_gt=True,
        test_mode=False,
    ):
        super().__init__()
        self.dataset_root = dataset_root
        self.ann_file = ann_file
        self.test_mode = test_mode
        self.modality = modality
        self.filter_empty_gt = filter_empty_gt
        self.box_type_3d, self.box_mode_3d = get_box_type(box_type_3d)

        self.CLASSES = self.get_classes(classes)
        self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}
        self.data_infos = self.load_annotations(self.ann_file)

        if pipeline is not None:
            self.pipeline = Compose(pipeline)

        # set group flag for the sampler
        if not self.test_mode:
            self._set_group_flag()

        self.epoch = -1

    def set_epoch(self, epoch):
        """Propagate the current epoch to pipeline transforms that support it."""
        self.epoch = epoch
        if hasattr(self, "pipeline"):
            for transform in self.pipeline.transforms:
                if hasattr(transform, "set_epoch"):
                    transform.set_epoch(epoch)
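
    # A hedged usage sketch (not part of the original file): the paths,
    # classes, and pipeline entries below are hypothetical and only illustrate
    # how a dataset built on this base class is typically constructed.
    # Iterating in training mode additionally requires a subclass that
    # implements ``get_ann_info``.
    #
    #     dataset = Custom3DDataset(
    #         dataset_root="data/my_dataset",
    #         ann_file="data/my_dataset/infos_train.pkl",
    #         pipeline=[dict(type="LoadPointsFromFile", coord_type="LIDAR",
    #                        load_dim=4, use_dim=4)],
    #         classes=("car", "pedestrian"),
    #         box_type_3d="LiDAR")
    #     print(len(dataset), dataset.CLASSES)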

    def load_annotations(self, ann_file):
        """Load annotations from ann_file.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            list[dict]: List of annotations.
        """
        return mmcv.load(ann_file)

    def get_data_info(self, index):
        """Get data info according to the given index.

        Args:
            index (int): Index of the sample data to get.

        Returns:
            dict: Data information that will be passed to the data \
                preprocessing pipelines. It includes the following keys:

                - sample_idx (str): Sample index.
                - lidar_path (str): Filename of point clouds.
                - file_name (str): Filename of point clouds.
                - ann_info (dict): Annotation info.
        """
        info = self.data_infos[index]
        sample_idx = info["point_cloud"]["lidar_idx"]
        lidar_path = osp.join(self.dataset_root, info["pts_path"])

        input_dict = dict(
            lidar_path=lidar_path, sample_idx=sample_idx, file_name=lidar_path
        )

        if not self.test_mode:
            annos = self.get_ann_info(index)
            input_dict["ann_info"] = annos
            # drop samples whose labels are all -1, i.e. no valid GT
            if self.filter_empty_gt and ~(annos["gt_labels_3d"] != -1).any():
                return None
        return input_dict
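
    # Minimal shape of one ``data_infos`` entry consumed by ``get_data_info``
    # above (a hedged sketch; real info files carry more keys, and the
    # annotation layout is parsed by the subclass's ``get_ann_info``):
    #
    #     info = {
    #         "point_cloud": {"lidar_idx": "000000"},  # used as sample_idx
    #         "pts_path": "points/000000.bin",         # joined with dataset_root
    #     }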

    def pre_pipeline(self, results):
        """Initialization before data preparation.

        Args:
            results (dict): Dict before data preprocessing.

                - img_fields (list): Image fields.
                - bbox3d_fields (list): 3D bounding boxes fields.
                - pts_mask_fields (list): Mask fields of points.
                - pts_seg_fields (list): Mask fields of point segments.
                - bbox_fields (list): Fields of bounding boxes.
                - mask_fields (list): Fields of masks.
                - seg_fields (list): Segment fields.
                - box_type_3d (str): 3D box type.
                - box_mode_3d (str): 3D box mode.
        """
        results["img_fields"] = []
        results["bbox3d_fields"] = []
        results["pts_mask_fields"] = []
        results["pts_seg_fields"] = []
        results["bbox_fields"] = []
        results["mask_fields"] = []
        results["seg_fields"] = []
        results["box_type_3d"] = self.box_type_3d
        results["box_mode_3d"] = self.box_mode_3d

    def prepare_train_data(self, index):
        """Training data preparation.

        Args:
            index (int): Index for accessing the target data.

        Returns:
            dict: Training data dict of the corresponding index.
        """
        input_dict = self.get_data_info(index)
        if input_dict is None:
            return None
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        # the pipeline may drop every GT box, so re-check for empty GT
        if self.filter_empty_gt and (
            example is None or ~(example["gt_labels_3d"]._data != -1).any()
        ):
            return None
        return example
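
    # Illustration of the empty-GT check used above: ``~(labels != -1).any()``
    # is True only when every label equals -1, i.e. the sample has no valid
    # ground truth and should be filtered out.
    #
    #     import numpy as np
    #     labels = np.array([-1, -1, -1])
    #     assert ~(labels != -1).any()         # all ignored -> filtered
    #     labels = np.array([2, -1, 0])
    #     assert not ~(labels != -1).any()     # valid GT present -> kept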

    def prepare_test_data(self, index):
        """Prepare data for testing.

        Args:
            index (int): Index for accessing the target data.

        Returns:
            dict: Testing data dict of the corresponding index.
        """
        input_dict = self.get_data_info(index)
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        return example

    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of the current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                the default CLASSES defined by the builtin dataset. If classes
                is a string, take it as a file name; the file contains class
                names, one per line. If classes is a tuple or list, override
                the CLASSES defined by the dataset.

        Returns:
            list[str]: A list of class names.
        """
        if classes is None:
            return cls.CLASSES

        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f"Unsupported type {type(classes)} of classes.")

        return class_names
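
    # ``get_classes`` accepts three forms (illustration; the file name below
    # is hypothetical):
    #
    #     Custom3DDataset.get_classes(None)                # fall back to the CLASSES attribute
    #     Custom3DDataset.get_classes("my_classes.txt")    # read one class name per line
    #     Custom3DDataset.get_classes(["car", "cyclist"])  # use the sequence as-is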

    def format_results(self, outputs, pklfile_prefix=None, submission_prefix=None):
        """Format the results to a pkl file.

        Args:
            outputs (list[dict]): Testing results of the dataset.
            pklfile_prefix (str | None): The prefix of pkl files. It includes
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.
            submission_prefix (str | None): Not used in the base class.
                Default: None.

        Returns:
            tuple: (outputs, tmp_dir), outputs is the detection results, \
                tmp_dir is the temporary directory created for saving pkl \
                files when ``pklfile_prefix`` is not specified.
        """
        if pklfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            pklfile_prefix = osp.join(tmp_dir.name, "results")
        else:
            tmp_dir = None
        out = f"{pklfile_prefix}.pkl"
        mmcv.dump(outputs, out)
        return outputs, tmp_dir
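
    # Hedged usage sketch for ``format_results``: when ``pklfile_prefix`` is
    # omitted, the pkl file lands in a temporary directory, which the caller
    # should clean up once it is no longer needed.
    #
    #     results, tmp_dir = dataset.format_results(outputs)
    #     ...  # evaluate or inspect results
    #     if tmp_dir is not None:
    #         tmp_dir.cleanup()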

    def _extract_data(self, index, pipeline, key, load_annos=False):
        """Load data using the input pipeline and extract data according to key.

        Args:
            index (int): Index for accessing the target data.
            pipeline (:obj:`Compose`): Composed data loading pipeline.
            key (str | list[str]): A single data key or a list of data keys.
            load_annos (bool): Whether to load data annotations.
                If True, need to set self.test_mode as False before loading.

        Returns:
            np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor]:
                A single or a list of loaded data.
        """
        assert pipeline is not None, "data loading pipeline is not provided"
        # when we want to load ground-truth via pipeline (e.g. bbox, seg mask)
        # we need to set self.test_mode as False so that we have 'annos'
        if load_annos:
            original_test_mode = self.test_mode
            self.test_mode = False
        input_dict = self.get_data_info(index)
        self.pre_pipeline(input_dict)
        example = pipeline(input_dict)

        # extract data items according to keys
        if isinstance(key, str):
            data = extract_result_dict(example, key)
        else:
            data = [extract_result_dict(example, k) for k in key]
        if load_annos:
            self.test_mode = original_test_mode

        return data
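
    # Hedged sketch for ``_extract_data``: the pipeline and the "points" key
    # below are assumptions; any key produced by the given loading pipeline
    # can be extracted, and ``load_annos=True`` also exposes annotation keys.
    #
    #     load_pipeline = Compose([
    #         dict(type="LoadPointsFromFile", coord_type="LIDAR",
    #              load_dim=4, use_dim=4),
    #     ])
    #     points = dataset._extract_data(0, load_pipeline, "points")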

    def __len__(self):
        """Return the length of data infos.

        Returns:
            int: Length of data infos.
        """
        return len(self.data_infos)

    def _rand_another(self, idx):
        """Randomly get another item with the same flag.

        Args:
            idx (int): Index of the current item.

        Returns:
            int: Another index of an item with the same flag.
        """
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
        """Get item from infos according to the given index.

        Returns:
            dict: Data dictionary of the corresponding index.
        """
        if self.test_mode:
            return self.prepare_test_data(idx)
        while True:
            data = self.prepare_train_data(idx)
            if data is None:
                # sample was filtered out (e.g. empty GT); try another index
                idx = self._rand_another(idx)
                continue
            return data
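
    # ``__getitem__`` plus ``__len__`` make this a standard
    # ``torch.utils.data.Dataset``: in training mode, samples filtered out as
    # empty are replaced via ``_rand_another``. A hedged sketch of plain
    # PyTorch usage (real training normally goes through the framework's own
    # dataloader builder and collate function):
    #
    #     from torch.utils.data import DataLoader
    #     loader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x)
    #     batch = next(iter(loader))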

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with an aspect ratio greater than 1 are assigned to group 1,
        otherwise to group 0. In 3D datasets all samples are treated the
        same, so every flag is zero.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)