from typing import Any, Dict import mmcv import numpy as np import torch import torchvision from mmcv import is_tuple_of from mmcv.utils import build_from_cfg from numpy import random from PIL import Image from mmdet3d.core import VoxelGenerator from mmdet3d.core.bbox import ( CameraInstance3DBoxes, DepthInstance3DBoxes, LiDARInstance3DBoxes, box_np_ops, ) from mmdet.datasets.builder import PIPELINES from ..builder import OBJECTSAMPLERS from .utils import noise_per_object_v3_ @PIPELINES.register_module() class GTDepth: def __init__(self, keyframe_only=False): self.keyframe_only = keyframe_only def __call__(self, data): sensor2ego = data['camera2ego'].data cam_intrinsic = data['camera_intrinsics'].data img_aug_matrix = data['img_aug_matrix'].data bev_aug_matrix = data['lidar_aug_matrix'].data lidar2ego = data['lidar2ego'].data camera2lidar = data['camera2lidar'].data lidar2image = data['lidar2image'].data rots = sensor2ego[..., :3, :3] trans = sensor2ego[..., :3, 3] intrins = cam_intrinsic[..., :3, :3] post_rots = img_aug_matrix[..., :3, :3] post_trans = img_aug_matrix[..., :3, 3] lidar2ego_rots = lidar2ego[..., :3, :3] lidar2ego_trans = lidar2ego[..., :3, 3] camera2lidar_rots = camera2lidar[..., :3, :3] camera2lidar_trans = camera2lidar[..., :3, 3] points = data['points'].data img = data['img'].data if self.keyframe_only: points = points[points[:, 4] == 0] batch_size = len(points) depth = torch.zeros(img.shape[0], *img.shape[-2:]) #.to(points[0].device) # for b in range(batch_size): cur_coords = points[:, :3] # inverse aug cur_coords -= bev_aug_matrix[:3, 3] cur_coords = torch.inverse(bev_aug_matrix[:3, :3]).matmul( cur_coords.transpose(1, 0) ) # lidar2image cur_coords = lidar2image[:, :3, :3].matmul(cur_coords) cur_coords += lidar2image[:, :3, 3].reshape(-1, 3, 1) # get 2d coords dist = cur_coords[:, 2, :] cur_coords[:, 2, :] = torch.clamp(cur_coords[:, 2, :], 1e-5, 1e5) cur_coords[:, :2, :] /= cur_coords[:, 2:3, :] # imgaug cur_coords = img_aug_matrix[:, :3, :3].matmul(cur_coords) cur_coords += img_aug_matrix[:, :3, 3].reshape(-1, 3, 1) cur_coords = cur_coords[:, :2, :].transpose(1, 2) # normalize coords for grid sample cur_coords = cur_coords[..., [1, 0]] on_img = ( (cur_coords[..., 0] < img.shape[2]) & (cur_coords[..., 0] >= 0) & (cur_coords[..., 1] < img.shape[3]) & (cur_coords[..., 1] >= 0) ) for c in range(on_img.shape[0]): masked_coords = cur_coords[c, on_img[c]].long() masked_dist = dist[c, on_img[c]] depth[c, masked_coords[:, 0], masked_coords[:, 1]] = masked_dist data['depths'] = depth return data @PIPELINES.register_module() class ImageAug3D: def __init__( self, final_dim, resize_lim, bot_pct_lim, rot_lim, rand_flip, is_train ): self.final_dim = final_dim self.resize_lim = resize_lim self.bot_pct_lim = bot_pct_lim self.rand_flip = rand_flip self.rot_lim = rot_lim self.is_train = is_train def sample_augmentation(self, results): W, H = results["ori_shape"] fH, fW = self.final_dim if self.is_train: resize = np.random.uniform(*self.resize_lim) resize_dims = (int(W * resize), int(H * resize)) newW, newH = resize_dims crop_h = int((1 - np.random.uniform(*self.bot_pct_lim)) * newH) - fH crop_w = int(np.random.uniform(0, max(0, newW - fW))) crop = (crop_w, crop_h, crop_w + fW, crop_h + fH) flip = False if self.rand_flip and np.random.choice([0, 1]): flip = True rotate = np.random.uniform(*self.rot_lim) else: resize = np.mean(self.resize_lim) resize_dims = (int(W * resize), int(H * resize)) newW, newH = resize_dims crop_h = int((1 - np.mean(self.bot_pct_lim)) * newH) - fH crop_w = int(max(0, newW - fW) / 2) crop = (crop_w, crop_h, crop_w + fW, crop_h + fH) flip = False rotate = 0 return resize, resize_dims, crop, flip, rotate def img_transform( self, img, rotation, translation, resize, resize_dims, crop, flip, rotate ): # adjust image img = img.resize(resize_dims) img = img.crop(crop) if flip: img = img.transpose(method=Image.FLIP_LEFT_RIGHT) img = img.rotate(rotate) # post-homography transformation rotation *= resize translation -= torch.Tensor(crop[:2]) if flip: A = torch.Tensor([[-1, 0], [0, 1]]) b = torch.Tensor([crop[2] - crop[0], 0]) rotation = A.matmul(rotation) translation = A.matmul(translation) + b theta = rotate / 180 * np.pi A = torch.Tensor( [ [np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)], ] ) b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2 b = A.matmul(-b) + b rotation = A.matmul(rotation) translation = A.matmul(translation) + b return img, rotation, translation def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: imgs = data["img"] new_imgs = [] transforms = [] for img in imgs: resize, resize_dims, crop, flip, rotate = self.sample_augmentation(data) post_rot = torch.eye(2) post_tran = torch.zeros(2) new_img, rotation, translation = self.img_transform( img, post_rot, post_tran, resize=resize, resize_dims=resize_dims, crop=crop, flip=flip, rotate=rotate, ) transform = torch.eye(4) transform[:2, :2] = rotation transform[:2, 3] = translation new_imgs.append(new_img) transforms.append(transform.numpy()) data["img"] = new_imgs # update the calibration matrices data["img_aug_matrix"] = transforms return data @PIPELINES.register_module() class GlobalRotScaleTrans: def __init__(self, resize_lim, rot_lim, trans_lim, is_train): self.resize_lim = resize_lim self.rot_lim = rot_lim self.trans_lim = trans_lim self.is_train = is_train def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: transform = np.eye(4).astype(np.float32) if self.is_train: scale = random.uniform(*self.resize_lim) theta = random.uniform(*self.rot_lim) translation = np.array([random.normal(0, self.trans_lim) for i in range(3)]) rotation = np.eye(3) if "points" in data: data["points"].rotate(-theta) data["points"].translate(translation) data["points"].scale(scale) if "radar" in data and hasattr(data["radar"], 'rotate'): data["radar"].rotate(-theta) data["radar"].translate(translation) data["radar"].scale(scale) gt_boxes = data["gt_bboxes_3d"] rotation = rotation @ gt_boxes.rotate(theta).numpy() gt_boxes.translate(translation) gt_boxes.scale(scale) data["gt_bboxes_3d"] = gt_boxes transform[:3, :3] = rotation.T * scale transform[:3, 3] = translation * scale data["lidar_aug_matrix"] = transform return data @PIPELINES.register_module() class GridMask: def __init__( self, use_h, use_w, max_epoch, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0, fixed_prob=False, ): self.use_h = use_h self.use_w = use_w self.rotate = rotate self.offset = offset self.ratio = ratio self.mode = mode self.st_prob = prob self.prob = prob self.epoch = None self.max_epoch = max_epoch self.fixed_prob = fixed_prob def set_epoch(self, epoch): self.epoch = epoch if not self.fixed_prob: self.set_prob(self.epoch, self.max_epoch) def set_prob(self, epoch, max_epoch): self.prob = self.st_prob * self.epoch / self.max_epoch def __call__(self, results): if np.random.rand() > self.prob: return results imgs = results["img"] h = imgs[0].shape[0] w = imgs[0].shape[1] self.d1 = 2 self.d2 = min(h, w) hh = int(1.5 * h) ww = int(1.5 * w) d = np.random.randint(self.d1, self.d2) if self.ratio == 1: self.l = np.random.randint(1, d) else: self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1) mask = np.ones((hh, ww), np.float32) st_h = np.random.randint(d) st_w = np.random.randint(d) if self.use_h: for i in range(hh // d): s = d * i + st_h t = min(s + self.l, hh) mask[s:t, :] *= 0 if self.use_w: for i in range(ww // d): s = d * i + st_w t = min(s + self.l, ww) mask[:, s:t] *= 0 r = np.random.randint(self.rotate) mask = Image.fromarray(np.uint8(mask)) mask = mask.rotate(r) mask = np.asarray(mask) mask = mask[ (hh - h) // 2 : (hh - h) // 2 + h, (ww - w) // 2 : (ww - w) // 2 + w ] mask = mask.astype(np.float32) mask = mask[:, :, None] if self.mode == 1: mask = 1 - mask # mask = mask.expand_as(imgs[0]) if self.offset: offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float() offset = (1 - mask) * offset imgs = [x * mask + offset for x in imgs] else: imgs = [x * mask for x in imgs] results.update(img=imgs) return results @PIPELINES.register_module() class RandomFlip3D: def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: flip_horizontal = random.choice([0, 1]) flip_vertical = random.choice([0, 1]) rotation = np.eye(3) if flip_horizontal: rotation = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) @ rotation if "points" in data: data["points"].flip("horizontal") if "radar" in data and hasattr(data["radar"], "flip"): data["radar"].flip("horizontal") if "gt_bboxes_3d" in data: data["gt_bboxes_3d"].flip("horizontal") if "gt_masks_bev" in data: data["gt_masks_bev"] = data["gt_masks_bev"][:, :, ::-1].copy() if flip_vertical: rotation = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) @ rotation if "points" in data: data["points"].flip("vertical") if "radar" in data and hasattr(data["radar"], "flip"): data["radar"].flip("vertical") if "gt_bboxes_3d" in data: data["gt_bboxes_3d"].flip("vertical") if "gt_masks_bev" in data: data["gt_masks_bev"] = data["gt_masks_bev"][:, ::-1, :].copy() data["lidar_aug_matrix"][:3, :] = rotation @ data["lidar_aug_matrix"][:3, :] return data @PIPELINES.register_module() class ObjectPaste: """Sample GT objects to the data. Args: db_sampler (dict): Config dict of the database sampler. sample_2d (bool): Whether to also paste 2D image patch to the images This should be true when applying multi-modality cut-and-paste. Defaults to False. """ def __init__(self, db_sampler, sample_2d=False, stop_epoch=None): self.sampler_cfg = db_sampler self.sample_2d = sample_2d if "type" not in db_sampler.keys(): db_sampler["type"] = "DataBaseSampler" self.db_sampler = build_from_cfg(db_sampler, OBJECTSAMPLERS) self.epoch = -1 self.stop_epoch = stop_epoch def set_epoch(self, epoch): self.epoch = epoch @staticmethod def remove_points_in_boxes(points, boxes): """Remove the points in the sampled bounding boxes. Args: points (:obj:`BasePoints`): Input point cloud array. boxes (np.ndarray): Sampled ground truth boxes. Returns: np.ndarray: Points with those in the boxes removed. """ masks = box_np_ops.points_in_rbbox(points.coord.numpy(), boxes) points = points[np.logical_not(masks.any(-1))] return points def __call__(self, data): """Call function to sample ground truth objects to the data. Args: data (dict): Result dict from loading pipeline. Returns: dict: Results after object sampling augmentation, \ 'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated \ in the result dict. """ if self.stop_epoch is not None and self.epoch >= self.stop_epoch: return data gt_bboxes_3d = data["gt_bboxes_3d"] gt_labels_3d = data["gt_labels_3d"] # change to float for blending operation points = data["points"] if self.sample_2d: img = data["img"] gt_bboxes_2d = data["gt_bboxes"] # Assume for now 3D & 2D bboxes are the same sampled_dict = self.db_sampler.sample_all( gt_bboxes_3d.tensor.numpy(), gt_labels_3d, gt_bboxes_2d=gt_bboxes_2d, img=img, ) else: sampled_dict = self.db_sampler.sample_all( gt_bboxes_3d.tensor.numpy(), gt_labels_3d, img=None ) if sampled_dict is not None: sampled_gt_bboxes_3d = sampled_dict["gt_bboxes_3d"] sampled_points = sampled_dict["points"] sampled_gt_labels = sampled_dict["gt_labels_3d"] gt_labels_3d = np.concatenate([gt_labels_3d, sampled_gt_labels], axis=0) gt_bboxes_3d = gt_bboxes_3d.new_box( np.concatenate([gt_bboxes_3d.tensor.numpy(), sampled_gt_bboxes_3d]) ) points = self.remove_points_in_boxes(points, sampled_gt_bboxes_3d) # check the points dimension points = points.cat([sampled_points, points]) if self.sample_2d: sampled_gt_bboxes_2d = sampled_dict["gt_bboxes_2d"] gt_bboxes_2d = np.concatenate( [gt_bboxes_2d, sampled_gt_bboxes_2d] ).astype(np.float32) data["gt_bboxes"] = gt_bboxes_2d data["img"] = sampled_dict["img"] data["gt_bboxes_3d"] = gt_bboxes_3d data["gt_labels_3d"] = gt_labels_3d.astype(np.long) data["points"] = points return data @PIPELINES.register_module() class ObjectNoise: """Apply noise to each GT objects in the scene. Args: translation_std (list[float], optional): Standard deviation of the distribution where translation noise are sampled from. Defaults to [0.25, 0.25, 0.25]. global_rot_range (list[float], optional): Global rotation to the scene. Defaults to [0.0, 0.0]. rot_range (list[float], optional): Object rotation range. Defaults to [-0.15707963267, 0.15707963267]. num_try (int, optional): Number of times to try if the noise applied is invalid. Defaults to 100. """ def __init__( self, translation_std=[0.25, 0.25, 0.25], global_rot_range=[0.0, 0.0], rot_range=[-0.15707963267, 0.15707963267], num_try=100, ): self.translation_std = translation_std self.global_rot_range = global_rot_range self.rot_range = rot_range self.num_try = num_try def __call__(self, data): """Call function to apply noise to each ground truth in the scene. Args: data (dict): Result dict from loading pipeline. Returns: dict: Results after adding noise to each object, \ 'points', 'gt_bboxes_3d' keys are updated in the result dict. """ gt_bboxes_3d = data["gt_bboxes_3d"] points = data["points"] # TODO: check this inplace function numpy_box = gt_bboxes_3d.tensor.numpy() numpy_points = points.tensor.numpy() noise_per_object_v3_( numpy_box, numpy_points, rotation_perturb=self.rot_range, center_noise_std=self.translation_std, global_random_rot_range=self.global_rot_range, num_try=self.num_try, ) data["gt_bboxes_3d"] = gt_bboxes_3d.new_box(numpy_box) data["points"] = points.new_point(numpy_points) return data @PIPELINES.register_module() class FrameDropout: def __init__(self, prob: float = 0.5, time_dim: int = -1) -> None: self.prob = prob self.time_dim = time_dim def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: offsets = [] for offset in torch.unique(data["points"].tensor[:, self.time_dim]): if offset == 0 or random.random() > self.prob: offsets.append(offset) offsets = torch.tensor(offsets) points = data["points"].tensor indices = torch.isin(points[:, self.time_dim], offsets) data["points"].tensor = points[indices] return data @PIPELINES.register_module() class PointShuffle: def __call__(self, data): data["points"].shuffle() return data @PIPELINES.register_module() class ObjectRangeFilter: """Filter objects by the range. Args: point_cloud_range (list[float]): Point cloud range. """ def __init__(self, point_cloud_range): self.pcd_range = np.array(point_cloud_range, dtype=np.float32) def __call__(self, data): """Call function to filter objects by the range. Args: data (dict): Result dict from loading pipeline. Returns: dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' \ keys are updated in the result dict. """ # Check points instance type and initialise bev_range if isinstance( data["gt_bboxes_3d"], (LiDARInstance3DBoxes, DepthInstance3DBoxes) ): bev_range = self.pcd_range[[0, 1, 3, 4]] elif isinstance(data["gt_bboxes_3d"], CameraInstance3DBoxes): bev_range = self.pcd_range[[0, 2, 3, 5]] gt_bboxes_3d = data["gt_bboxes_3d"] gt_labels_3d = data["gt_labels_3d"] mask = gt_bboxes_3d.in_range_bev(bev_range) gt_bboxes_3d = gt_bboxes_3d[mask] # mask is a torch tensor but gt_labels_3d is still numpy array # using mask to index gt_labels_3d will cause bug when # len(gt_labels_3d) == 1, where mask=1 will be interpreted # as gt_labels_3d[1] and cause out of index error gt_labels_3d = gt_labels_3d[mask.numpy().astype(np.bool)] # limit rad to [-pi, pi] gt_bboxes_3d.limit_yaw(offset=0.5, period=2 * np.pi) data["gt_bboxes_3d"] = gt_bboxes_3d data["gt_labels_3d"] = gt_labels_3d return data def __repr__(self): """str: Return a string that describes the module.""" repr_str = self.__class__.__name__ repr_str += f"(point_cloud_range={self.pcd_range.tolist()})" return repr_str @PIPELINES.register_module() class PointsRangeFilter: """Filter points by the range. Args: point_cloud_range (list[float]): Point cloud range. """ def __init__(self, point_cloud_range): self.pcd_range = np.array(point_cloud_range, dtype=np.float32) def __call__(self, data): """Call function to filter points by the range. Args: data (dict): Result dict from loading pipeline. Returns: dict: Results after filtering, 'points', 'pts_instance_mask' \ and 'pts_semantic_mask' keys are updated in the result dict. """ points = data["points"] points_mask = points.in_range_3d(self.pcd_range) clean_points = points[points_mask] data["points"] = clean_points if "radar" in data and hasattr(data["radar"], "in_range_bev"): radar = data["radar"] # radar_mask = radar.in_range_3d(self.pcd_range) radar_mask = radar.in_range_bev([-55.0, -55.0, 55.0, 55.0]) clean_radar = radar[radar_mask] data["radar"] = clean_radar return data @PIPELINES.register_module() class ObjectNameFilter: """Filter GT objects by their names. Args: classes (list[str]): List of class names to be kept for training. """ def __init__(self, classes): self.classes = classes self.labels = list(range(len(self.classes))) def __call__(self, data): gt_labels_3d = data["gt_labels_3d"] gt_bboxes_mask = np.array( [n in self.labels for n in gt_labels_3d], dtype=np.bool_ ) data["gt_bboxes_3d"] = data["gt_bboxes_3d"][gt_bboxes_mask] data["gt_labels_3d"] = data["gt_labels_3d"][gt_bboxes_mask] return data @PIPELINES.register_module() class PointSample: """Point sample. Sampling data to a certain number. Args: num_points (int): Number of points to be sampled. sample_range (float, optional): The range where to sample points. If not None, the points with depth larger than `sample_range` are prior to be sampled. Defaults to None. replace (bool, optional): Whether the sampling is with or without replacement. Defaults to False. """ def __init__(self, num_points, sample_range=None, replace=False): self.num_points = num_points self.sample_range = sample_range self.replace = replace def _points_random_sampling( self, points, num_samples, sample_range=None, replace=False, return_choices=False, ): """Points random sampling. Sample points to a certain number. Args: points (np.ndarray | :obj:`BasePoints`): 3D Points. num_samples (int): Number of samples to be sampled. sample_range (float, optional): Indicating the range where the points will be sampled. Defaults to None. replace (bool, optional): Sampling with or without replacement. Defaults to None. return_choices (bool, optional): Whether return choice. Defaults to False. Returns: tuple[np.ndarray] | np.ndarray: - points (np.ndarray | :obj:`BasePoints`): 3D Points. - choices (np.ndarray, optional): The generated random samples. """ if not replace: replace = points.shape[0] < num_samples point_range = range(len(points)) if sample_range is not None and not replace: # Only sampling the near points when len(points) >= num_samples depth = np.linalg.norm(points.tensor, axis=1) far_inds = np.where(depth > sample_range)[0] near_inds = np.where(depth <= sample_range)[0] # in case there are too many far points if len(far_inds) > num_samples: far_inds = np.random.choice(far_inds, num_samples, replace=False) point_range = near_inds num_samples -= len(far_inds) choices = np.random.choice(point_range, num_samples, replace=replace) if sample_range is not None and not replace: choices = np.concatenate((far_inds, choices)) # Shuffle points after sampling np.random.shuffle(choices) if return_choices: return points[choices], choices else: return points[choices] def __call__(self, data): """Call function to sample points to in indoor scenes. Args: data (dict): Result dict from loading pipeline. Returns: dict: Results after sampling, 'points', 'pts_instance_mask' \ and 'pts_semantic_mask' keys are updated in the result dict. """ points = data["points"] # Points in Camera coord can provide the depth information. # TODO: Need to suport distance-based sampling for other coord system. if self.sample_range is not None: from mmdet3d.core.points import CameraPoints assert isinstance( points, CameraPoints ), "Sampling based on distance is only appliable for CAMERA coord" points, choices = self._points_random_sampling( points, self.num_points, self.sample_range, self.replace, return_choices=True, ) data["points"] = points return data def __repr__(self): """str: Return a string that describes the module.""" repr_str = self.__class__.__name__ repr_str += f"(num_points={self.num_points}," repr_str += f" sample_range={self.sample_range}," repr_str += f" replace={self.replace})" return repr_str @PIPELINES.register_module() class BackgroundPointsFilter: """Filter background points near the bounding box. Args: bbox_enlarge_range (tuple[float], float): Bbox enlarge range. """ def __init__(self, bbox_enlarge_range): assert ( is_tuple_of(bbox_enlarge_range, float) and len(bbox_enlarge_range) == 3 ) or isinstance( bbox_enlarge_range, float ), f"Invalid arguments bbox_enlarge_range {bbox_enlarge_range}" if isinstance(bbox_enlarge_range, float): bbox_enlarge_range = [bbox_enlarge_range] * 3 self.bbox_enlarge_range = np.array(bbox_enlarge_range, dtype=np.float32)[ np.newaxis, : ] def __call__(self, data): """Call function to filter points by the range. Args: data (dict): Result dict from loading pipeline. Returns: dict: Results after filtering, 'points', 'pts_instance_mask' \ and 'pts_semantic_mask' keys are updated in the result dict. """ points = data["points"] gt_bboxes_3d = data["gt_bboxes_3d"] # avoid groundtruth being modified gt_bboxes_3d_np = gt_bboxes_3d.tensor.clone().numpy() gt_bboxes_3d_np[:, :3] = gt_bboxes_3d.gravity_center.clone().numpy() enlarged_gt_bboxes_3d = gt_bboxes_3d_np.copy() enlarged_gt_bboxes_3d[:, 3:6] += self.bbox_enlarge_range points_numpy = points.tensor.clone().numpy() foreground_masks = box_np_ops.points_in_rbbox( points_numpy, gt_bboxes_3d_np, origin=(0.5, 0.5, 0.5) ) enlarge_foreground_masks = box_np_ops.points_in_rbbox( points_numpy, enlarged_gt_bboxes_3d, origin=(0.5, 0.5, 0.5) ) foreground_masks = foreground_masks.max(1) enlarge_foreground_masks = enlarge_foreground_masks.max(1) valid_masks = ~np.logical_and(~foreground_masks, enlarge_foreground_masks) data["points"] = points[valid_masks] return data def __repr__(self): """str: Return a string that describes the module.""" repr_str = self.__class__.__name__ repr_str += f"(bbox_enlarge_range={self.bbox_enlarge_range.tolist()})" return repr_str @PIPELINES.register_module() class VoxelBasedPointSampler: """Voxel based point sampler. Apply voxel sampling to multiple sweep points. Args: cur_sweep_cfg (dict): Config for sampling current points. prev_sweep_cfg (dict): Config for sampling previous points. time_dim (int): Index that indicate the time dimention for input points. """ def __init__(self, cur_sweep_cfg, prev_sweep_cfg=None, time_dim=3): self.cur_voxel_generator = VoxelGenerator(**cur_sweep_cfg) self.cur_voxel_num = self.cur_voxel_generator._max_voxels self.time_dim = time_dim if prev_sweep_cfg is not None: assert prev_sweep_cfg["max_num_points"] == cur_sweep_cfg["max_num_points"] self.prev_voxel_generator = VoxelGenerator(**prev_sweep_cfg) self.prev_voxel_num = self.prev_voxel_generator._max_voxels else: self.prev_voxel_generator = None self.prev_voxel_num = 0 def _sample_points(self, points, sampler, point_dim): """Sample points for each points subset. Args: points (np.ndarray): Points subset to be sampled. sampler (VoxelGenerator): Voxel based sampler for each points subset. point_dim (int): The dimention of each points Returns: np.ndarray: Sampled points. """ voxels, coors, num_points_per_voxel = sampler.generate(points) if voxels.shape[0] < sampler._max_voxels: padding_points = np.zeros( [ sampler._max_voxels - voxels.shape[0], sampler._max_num_points, point_dim, ], dtype=points.dtype, ) padding_points[:] = voxels[0] sample_points = np.concatenate([voxels, padding_points], axis=0) else: sample_points = voxels return sample_points def __call__(self, results): """Call function to sample points from multiple sweeps. Args: data (dict): Result dict from loading pipeline. Returns: dict: Results after sampling, 'points', 'pts_instance_mask' \ and 'pts_semantic_mask' keys are updated in the result dict. """ points = results["points"] original_dim = points.shape[1] # TODO: process instance and semantic mask while _max_num_points # is larger than 1 # Extend points with seg and mask fields map_fields2dim = [] start_dim = original_dim points_numpy = points.tensor.numpy() extra_channel = [points_numpy] for idx, key in enumerate(results["pts_mask_fields"]): map_fields2dim.append((key, idx + start_dim)) extra_channel.append(results[key][..., None]) start_dim += len(results["pts_mask_fields"]) for idx, key in enumerate(results["pts_seg_fields"]): map_fields2dim.append((key, idx + start_dim)) extra_channel.append(results[key][..., None]) points_numpy = np.concatenate(extra_channel, axis=-1) # Split points into two part, current sweep points and # previous sweeps points. # TODO: support different sampling methods for next sweeps points # and previous sweeps points. cur_points_flag = points_numpy[:, self.time_dim] == 0 cur_sweep_points = points_numpy[cur_points_flag] prev_sweeps_points = points_numpy[~cur_points_flag] if prev_sweeps_points.shape[0] == 0: prev_sweeps_points = cur_sweep_points # Shuffle points before sampling np.random.shuffle(cur_sweep_points) np.random.shuffle(prev_sweeps_points) cur_sweep_points = self._sample_points( cur_sweep_points, self.cur_voxel_generator, points_numpy.shape[1] ) if self.prev_voxel_generator is not None: prev_sweeps_points = self._sample_points( prev_sweeps_points, self.prev_voxel_generator, points_numpy.shape[1] ) points_numpy = np.concatenate([cur_sweep_points, prev_sweeps_points], 0) else: points_numpy = cur_sweep_points if self.cur_voxel_generator._max_num_points == 1: points_numpy = points_numpy.squeeze(1) results["points"] = points.new_point(points_numpy[..., :original_dim]) # Restore the correspoinding seg and mask fields for key, dim_index in map_fields2dim: results[key] = points_numpy[..., dim_index] return results def __repr__(self): """str: Return a string that describes the module.""" def _auto_indent(repr_str, indent): repr_str = repr_str.split("\n") repr_str = [" " * indent + t + "\n" for t in repr_str] repr_str = "".join(repr_str)[:-1] return repr_str repr_str = self.__class__.__name__ indent = 4 repr_str += "(\n" repr_str += " " * indent + f"num_cur_sweep={self.cur_voxel_num},\n" repr_str += " " * indent + f"num_prev_sweep={self.prev_voxel_num},\n" repr_str += " " * indent + f"time_dim={self.time_dim},\n" repr_str += " " * indent + "cur_voxel_generator=\n" repr_str += f"{_auto_indent(repr(self.cur_voxel_generator), 8)},\n" repr_str += " " * indent + "prev_voxel_generator=\n" repr_str += f"{_auto_indent(repr(self.prev_voxel_generator), 8)})" return repr_str @PIPELINES.register_module() class ImagePad: """Pad the multi-view image. There are two padding modes: (1) pad to a fixed size and (2) pad to the minimum size that is divisible by some number. Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor", Args: size (tuple, optional): Fixed padding size. size_divisor (int, optional): The divisor of padded size. pad_val (float, optional): Padding value, 0 by default. """ def __init__(self, size=None, size_divisor=None, pad_val=0): self.size = size self.size_divisor = size_divisor self.pad_val = pad_val # only one of size and size_divisor should be valid assert size is not None or size_divisor is not None assert size is None or size_divisor is None def _pad_img(self, results): """Pad images according to ``self.size``.""" if self.size is not None: padded_img = [ mmcv.impad(img, shape=self.size, pad_val=self.pad_val) for img in results["img"] ] elif self.size_divisor is not None: padded_img = [ mmcv.impad_to_multiple(img, self.size_divisor, pad_val=self.pad_val) for img in results["img"] ] results["img"] = padded_img results["img_shape"] = [img.shape for img in padded_img] results["pad_shape"] = [img.shape for img in padded_img] results["pad_fixed_size"] = self.size results["pad_size_divisor"] = self.size_divisor def __call__(self, results): """Call function to pad images, masks, semantic segmentation maps. Args: results (dict): Result dict from loading pipeline. Returns: dict: Updated result dict. """ self._pad_img(results) return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += f"(size={self.size}, " repr_str += f"size_divisor={self.size_divisor}, " repr_str += f"pad_val={self.pad_val})" return repr_str @PIPELINES.register_module() class ImageNormalize: def __init__(self, mean, std): self.mean = mean self.std = std self.compose = torchvision.transforms.Compose( [ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=mean, std=std), ] ) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: data["img"] = [self.compose(img) for img in data["img"]] data["img_norm_cfg"] = dict(mean=self.mean, std=self.std) return data @PIPELINES.register_module() class ImageDistort: """Apply photometric distortion to image sequentially, every transformation is applied with a probability of 0.5. The position of random contrast is in second or second to last. 1. random brightness 2. random contrast (mode 0) 3. convert color from BGR to HSV 4. random saturation 5. random hue 6. convert color from HSV to BGR 7. random contrast (mode 1) 8. randomly swap channels Args: brightness_delta (int): delta of brightness. contrast_range (tuple): range of contrast. saturation_range (tuple): range of saturation. hue_delta (int): delta of hue. """ def __init__( self, brightness_delta=32, contrast_range=(0.5, 1.5), saturation_range=(0.5, 1.5), hue_delta=18, ): self.brightness_delta = brightness_delta self.contrast_lower, self.contrast_upper = contrast_range self.saturation_lower, self.saturation_upper = saturation_range self.hue_delta = hue_delta def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: imgs = data["img"] new_imgs = [] for img in imgs: assert img.dtype == np.float32, ( "PhotoMetricDistortion needs the input image of dtype np.float32," ' please set "to_float32=True" in "LoadImageFromFile" pipeline' ) # random brightness if random.randint(2): delta = random.uniform(-self.brightness_delta, self.brightness_delta) img += delta # mode == 0 --> do random contrast first # mode == 1 --> do random contrast last mode = random.randint(2) if mode == 1: if random.randint(2): alpha = random.uniform(self.contrast_lower, self.contrast_upper) img *= alpha # convert color from BGR to HSV img = mmcv.bgr2hsv(img) # random saturation if random.randint(2): img[..., 1] *= random.uniform( self.saturation_lower, self.saturation_upper ) # random hue if random.randint(2): img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta) img[..., 0][img[..., 0] > 360] -= 360 img[..., 0][img[..., 0] < 0] += 360 # convert color from HSV to BGR img = mmcv.hsv2bgr(img) # random contrast if mode == 0: if random.randint(2): alpha = random.uniform(self.contrast_lower, self.contrast_upper) img *= alpha # randomly swap channels if random.randint(2): img = img[..., random.permutation(3)] new_imgs.append(img) data["img"] = new_imgs return data