# bev-project/mmdet3d/datasets/pipelines/transforms_3d.py

from typing import Any, Dict
import mmcv
import numpy as np
import torch
import torchvision
from mmcv import is_tuple_of
from mmcv.utils import build_from_cfg
from numpy import random
from PIL import Image
from mmdet3d.core import VoxelGenerator
from mmdet3d.core.bbox import (
CameraInstance3DBoxes,
DepthInstance3DBoxes,
LiDARInstance3DBoxes,
box_np_ops,
)
from mmdet.datasets.builder import PIPELINES
from ..builder import OBJECTSAMPLERS
from .utils import noise_per_object_v3_
@PIPELINES.register_module()
class GTDepth:
    """Generate sparse ground-truth depth maps by projecting the LiDAR points
    into every camera view; the result is stored under the ``depths`` key.

    Args:
        keyframe_only (bool): If True, only points from the keyframe sweep
            (time offset 0) are projected. Defaults to False.
    """

    def __init__(self, keyframe_only=False):
        self.keyframe_only = keyframe_only
def __call__(self, data):
sensor2ego = data['camera2ego'].data
cam_intrinsic = data['camera_intrinsics'].data
img_aug_matrix = data['img_aug_matrix'].data
bev_aug_matrix = data['lidar_aug_matrix'].data
lidar2ego = data['lidar2ego'].data
camera2lidar = data['camera2lidar'].data
lidar2image = data['lidar2image'].data
rots = sensor2ego[..., :3, :3]
trans = sensor2ego[..., :3, 3]
intrins = cam_intrinsic[..., :3, :3]
post_rots = img_aug_matrix[..., :3, :3]
post_trans = img_aug_matrix[..., :3, 3]
lidar2ego_rots = lidar2ego[..., :3, :3]
lidar2ego_trans = lidar2ego[..., :3, 3]
camera2lidar_rots = camera2lidar[..., :3, :3]
camera2lidar_trans = camera2lidar[..., :3, 3]
points = data['points'].data
img = data['img'].data
if self.keyframe_only:
points = points[points[:, 4] == 0]
batch_size = len(points)
depth = torch.zeros(img.shape[0], *img.shape[-2:]) #.to(points[0].device)
# for b in range(batch_size):
cur_coords = points[:, :3]
# inverse aug
cur_coords -= bev_aug_matrix[:3, 3]
cur_coords = torch.inverse(bev_aug_matrix[:3, :3]).matmul(
cur_coords.transpose(1, 0)
)
# lidar2image
cur_coords = lidar2image[:, :3, :3].matmul(cur_coords)
cur_coords += lidar2image[:, :3, 3].reshape(-1, 3, 1)
# get 2d coords
dist = cur_coords[:, 2, :]
cur_coords[:, 2, :] = torch.clamp(cur_coords[:, 2, :], 1e-5, 1e5)
cur_coords[:, :2, :] /= cur_coords[:, 2:3, :]
# imgaug
cur_coords = img_aug_matrix[:, :3, :3].matmul(cur_coords)
cur_coords += img_aug_matrix[:, :3, 3].reshape(-1, 3, 1)
cur_coords = cur_coords[:, :2, :].transpose(1, 2)
# normalize coords for grid sample
cur_coords = cur_coords[..., [1, 0]]
on_img = (
(cur_coords[..., 0] < img.shape[2])
& (cur_coords[..., 0] >= 0)
& (cur_coords[..., 1] < img.shape[3])
& (cur_coords[..., 1] >= 0)
)
for c in range(on_img.shape[0]):
masked_coords = cur_coords[c, on_img[c]].long()
masked_dist = dist[c, on_img[c]]
depth[c, masked_coords[:, 0], masked_coords[:, 1]] = masked_dist
data['depths'] = depth
return data
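
# ``GTDepth`` reads the ``.data`` attribute of the calibration and augmentation
# entries ('camera2ego', 'camera_intrinsics', 'img_aug_matrix',
# 'lidar_aug_matrix', 'lidar2image', ...), so it must run after those keys have
# been produced. A hedged sketch of its pipeline entry (the ``keyframe_only``
# value is an illustrative assumption):
#
#     dict(type="GTDepth", keyframe_only=True)
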
@PIPELINES.register_module()
class ImageAug3D:
    """Resize, crop, flip and rotate the multi-view images and record the
    corresponding image augmentation matrices in ``img_aug_matrix``.

    Args:
        final_dim (tuple[int]): Target (H, W) of the augmented images.
        resize_lim (tuple[float]): Range of the random resize ratio.
        bot_pct_lim (tuple[float]): Range of the random bottom-crop fraction.
        rot_lim (tuple[float]): Range of the random rotation in degrees.
        rand_flip (bool): Whether to randomly flip the images horizontally.
        is_train (bool): If False, a deterministic center crop with the mean
            resize/crop parameters and no flip or rotation is used.
    """

    def __init__(
        self, final_dim, resize_lim, bot_pct_lim, rot_lim, rand_flip, is_train
    ):
        self.final_dim = final_dim
        self.resize_lim = resize_lim
        self.bot_pct_lim = bot_pct_lim
        self.rand_flip = rand_flip
        self.rot_lim = rot_lim
        self.is_train = is_train
def sample_augmentation(self, results):
W, H = results["ori_shape"]
fH, fW = self.final_dim
if self.is_train:
resize = np.random.uniform(*self.resize_lim)
resize_dims = (int(W * resize), int(H * resize))
newW, newH = resize_dims
crop_h = int((1 - np.random.uniform(*self.bot_pct_lim)) * newH) - fH
crop_w = int(np.random.uniform(0, max(0, newW - fW)))
crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
flip = False
if self.rand_flip and np.random.choice([0, 1]):
flip = True
rotate = np.random.uniform(*self.rot_lim)
else:
resize = np.mean(self.resize_lim)
resize_dims = (int(W * resize), int(H * resize))
newW, newH = resize_dims
crop_h = int((1 - np.mean(self.bot_pct_lim)) * newH) - fH
crop_w = int(max(0, newW - fW) / 2)
crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
flip = False
rotate = 0
return resize, resize_dims, crop, flip, rotate
def img_transform(
self, img, rotation, translation, resize, resize_dims, crop, flip, rotate
):
# adjust image
img = img.resize(resize_dims)
img = img.crop(crop)
if flip:
img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
img = img.rotate(rotate)
# post-homography transformation
rotation *= resize
translation -= torch.Tensor(crop[:2])
if flip:
A = torch.Tensor([[-1, 0], [0, 1]])
b = torch.Tensor([crop[2] - crop[0], 0])
rotation = A.matmul(rotation)
translation = A.matmul(translation) + b
theta = rotate / 180 * np.pi
A = torch.Tensor(
[
[np.cos(theta), np.sin(theta)],
[-np.sin(theta), np.cos(theta)],
]
)
b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2
b = A.matmul(-b) + b
rotation = A.matmul(rotation)
translation = A.matmul(translation) + b
return img, rotation, translation
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
imgs = data["img"]
new_imgs = []
transforms = []
for img in imgs:
resize, resize_dims, crop, flip, rotate = self.sample_augmentation(data)
post_rot = torch.eye(2)
post_tran = torch.zeros(2)
new_img, rotation, translation = self.img_transform(
img,
post_rot,
post_tran,
resize=resize,
resize_dims=resize_dims,
crop=crop,
flip=flip,
rotate=rotate,
)
transform = torch.eye(4)
transform[:2, :2] = rotation
transform[:2, 3] = translation
new_imgs.append(new_img)
transforms.append(transform.numpy())
data["img"] = new_imgs
# update the calibration matrices
data["img_aug_matrix"] = transforms
return data
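
# A sketch of a typical ``ImageAug3D`` config entry (the numeric values are
# illustrative assumptions, not taken from this repo's configs):
#
#     dict(
#         type="ImageAug3D",
#         final_dim=(256, 704),
#         resize_lim=(0.38, 0.55),
#         bot_pct_lim=(0.0, 0.0),
#         rot_lim=(-5.4, 5.4),
#         rand_flip=True,
#         is_train=True,
#     )
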
@PIPELINES.register_module()
class GlobalRotScaleTrans:
    """Apply a random global rotation, scaling and translation to the points,
    radar points and 3D boxes, and record the transform in ``lidar_aug_matrix``.

    Args:
        resize_lim (tuple[float]): Range of the random scaling factor.
        rot_lim (tuple[float]): Range of the random rotation angle in radians.
        trans_lim (float): Standard deviation of the Gaussian translation
            noise applied along each axis.
        is_train (bool): Whether to apply the augmentation; if False, the
            identity transform is recorded.
    """

    def __init__(self, resize_lim, rot_lim, trans_lim, is_train):
        self.resize_lim = resize_lim
        self.rot_lim = rot_lim
        self.trans_lim = trans_lim
        self.is_train = is_train
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
transform = np.eye(4).astype(np.float32)
if self.is_train:
scale = random.uniform(*self.resize_lim)
theta = random.uniform(*self.rot_lim)
translation = np.array([random.normal(0, self.trans_lim) for i in range(3)])
rotation = np.eye(3)
if "points" in data:
data["points"].rotate(-theta)
data["points"].translate(translation)
data["points"].scale(scale)
if "radar" in data:
data["radar"].rotate(-theta)
data["radar"].translate(translation)
data["radar"].scale(scale)
gt_boxes = data["gt_bboxes_3d"]
rotation = rotation @ gt_boxes.rotate(theta).numpy()
gt_boxes.translate(translation)
gt_boxes.scale(scale)
data["gt_bboxes_3d"] = gt_boxes
transform[:3, :3] = rotation.T * scale
transform[:3, 3] = translation * scale
data["lidar_aug_matrix"] = transform
return data
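
# A sketch of a typical ``GlobalRotScaleTrans`` config entry (rotation is in
# radians; the values are illustrative assumptions):
#
#     dict(
#         type="GlobalRotScaleTrans",
#         resize_lim=(0.9, 1.1),
#         rot_lim=(-0.78539816, 0.78539816),
#         trans_lim=0.5,
#         is_train=True,
#     )
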
@PIPELINES.register_module()
class GridMask:
    """Apply GridMask augmentation to the multi-view images: regularly spaced
    bands of the image are masked to zero (with ``mode=1`` the mask is
    inverted, so only those bands are kept). Unless ``fixed_prob`` is True,
    the application probability is ramped linearly from 0 to ``prob`` over
    ``max_epoch`` epochs via ``set_epoch``.
    """

def __init__(
self,
use_h,
use_w,
max_epoch,
rotate=1,
offset=False,
ratio=0.5,
mode=0,
prob=1.0,
fixed_prob=False,
):
self.use_h = use_h
self.use_w = use_w
self.rotate = rotate
self.offset = offset
self.ratio = ratio
self.mode = mode
self.st_prob = prob
self.prob = prob
self.epoch = None
self.max_epoch = max_epoch
self.fixed_prob = fixed_prob
def set_epoch(self, epoch):
self.epoch = epoch
if not self.fixed_prob:
self.set_prob(self.epoch, self.max_epoch)
def set_prob(self, epoch, max_epoch):
self.prob = self.st_prob * self.epoch / self.max_epoch
def __call__(self, results):
if np.random.rand() > self.prob:
return results
imgs = results["img"]
h = imgs[0].shape[0]
w = imgs[0].shape[1]
self.d1 = 2
self.d2 = min(h, w)
hh = int(1.5 * h)
ww = int(1.5 * w)
d = np.random.randint(self.d1, self.d2)
if self.ratio == 1:
self.l = np.random.randint(1, d)
else:
self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
mask = np.ones((hh, ww), np.float32)
st_h = np.random.randint(d)
st_w = np.random.randint(d)
if self.use_h:
for i in range(hh // d):
s = d * i + st_h
t = min(s + self.l, hh)
mask[s:t, :] *= 0
if self.use_w:
for i in range(ww // d):
s = d * i + st_w
t = min(s + self.l, ww)
mask[:, s:t] *= 0
r = np.random.randint(self.rotate)
mask = Image.fromarray(np.uint8(mask))
mask = mask.rotate(r)
mask = np.asarray(mask)
mask = mask[
(hh - h) // 2 : (hh - h) // 2 + h, (ww - w) // 2 : (ww - w) // 2 + w
]
mask = mask.astype(np.float32)
mask = mask[:, :, None]
if self.mode == 1:
mask = 1 - mask
# mask = mask.expand_as(imgs[0])
if self.offset:
offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float()
offset = (1 - mask) * offset
imgs = [x * mask + offset for x in imgs]
else:
imgs = [x * mask for x in imgs]
results.update(img=imgs)
return results
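
# Because ``GridMask`` ramps its probability with the epoch, something
# external (e.g. a custom training hook) is expected to call ``set_epoch``
# once per epoch. A sketch of a config entry (values are illustrative
# assumptions):
#
#     dict(
#         type="GridMask",
#         use_h=True,
#         use_w=True,
#         max_epoch=24,
#         rotate=1,
#         offset=False,
#         ratio=0.5,
#         mode=1,
#         prob=0.7,
#     )
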
@PIPELINES.register_module()
class RandomFlip3D:
    """Randomly flip the points, radar points, 3D boxes and BEV masks along
    the horizontal and/or vertical BEV axes, updating ``lidar_aug_matrix``
    accordingly.
    """

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
flip_horizontal = random.choice([0, 1])
flip_vertical = random.choice([0, 1])
rotation = np.eye(3)
if flip_horizontal:
rotation = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) @ rotation
if "points" in data:
data["points"].flip("horizontal")
if "radar" in data:
data["radar"].flip("horizontal")
if "gt_bboxes_3d" in data:
data["gt_bboxes_3d"].flip("horizontal")
if "gt_masks_bev" in data:
data["gt_masks_bev"] = data["gt_masks_bev"][:, :, ::-1].copy()
if flip_vertical:
rotation = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) @ rotation
if "points" in data:
data["points"].flip("vertical")
if "radar" in data:
data["radar"].flip("vertical")
if "gt_bboxes_3d" in data:
data["gt_bboxes_3d"].flip("vertical")
if "gt_masks_bev" in data:
data["gt_masks_bev"] = data["gt_masks_bev"][:, ::-1, :].copy()
data["lidar_aug_matrix"][:3, :] = rotation @ data["lidar_aug_matrix"][:3, :]
return data
@PIPELINES.register_module()
class ObjectPaste:
"""Sample GT objects to the data.
Args:
db_sampler (dict): Config dict of the database sampler.
sample_2d (bool): Whether to also paste 2D image patch to the images
This should be true when applying multi-modality cut-and-paste.
Defaults to False.
"""
def __init__(self, db_sampler, sample_2d=False, stop_epoch=None):
self.sampler_cfg = db_sampler
self.sample_2d = sample_2d
if "type" not in db_sampler.keys():
db_sampler["type"] = "DataBaseSampler"
self.db_sampler = build_from_cfg(db_sampler, OBJECTSAMPLERS)
self.epoch = -1
self.stop_epoch = stop_epoch
def set_epoch(self, epoch):
self.epoch = epoch
@staticmethod
def remove_points_in_boxes(points, boxes):
"""Remove the points in the sampled bounding boxes.
Args:
points (:obj:`BasePoints`): Input point cloud array.
boxes (np.ndarray): Sampled ground truth boxes.
Returns:
np.ndarray: Points with those in the boxes removed.
"""
masks = box_np_ops.points_in_rbbox(points.coord.numpy(), boxes)
points = points[np.logical_not(masks.any(-1))]
return points
def __call__(self, data):
"""Call function to sample ground truth objects to the data.
Args:
data (dict): Result dict from loading pipeline.
Returns:
dict: Results after object sampling augmentation, \
'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated \
in the result dict.
"""
if self.stop_epoch is not None and self.epoch >= self.stop_epoch:
return data
gt_bboxes_3d = data["gt_bboxes_3d"]
gt_labels_3d = data["gt_labels_3d"]
# change to float for blending operation
points = data["points"]
if self.sample_2d:
img = data["img"]
gt_bboxes_2d = data["gt_bboxes"]
# Assume for now 3D & 2D bboxes are the same
sampled_dict = self.db_sampler.sample_all(
gt_bboxes_3d.tensor.numpy(),
gt_labels_3d,
gt_bboxes_2d=gt_bboxes_2d,
img=img,
)
else:
sampled_dict = self.db_sampler.sample_all(
gt_bboxes_3d.tensor.numpy(), gt_labels_3d, img=None
)
if sampled_dict is not None:
sampled_gt_bboxes_3d = sampled_dict["gt_bboxes_3d"]
sampled_points = sampled_dict["points"]
sampled_gt_labels = sampled_dict["gt_labels_3d"]
gt_labels_3d = np.concatenate([gt_labels_3d, sampled_gt_labels], axis=0)
gt_bboxes_3d = gt_bboxes_3d.new_box(
np.concatenate([gt_bboxes_3d.tensor.numpy(), sampled_gt_bboxes_3d])
)
points = self.remove_points_in_boxes(points, sampled_gt_bboxes_3d)
# check the points dimension
points = points.cat([sampled_points, points])
if self.sample_2d:
sampled_gt_bboxes_2d = sampled_dict["gt_bboxes_2d"]
gt_bboxes_2d = np.concatenate(
[gt_bboxes_2d, sampled_gt_bboxes_2d]
).astype(np.float32)
data["gt_bboxes"] = gt_bboxes_2d
data["img"] = sampled_dict["img"]
data["gt_bboxes_3d"] = gt_bboxes_3d
data["gt_labels_3d"] = gt_labels_3d.astype(np.long)
data["points"] = points
return data
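
# A sketch of an ``ObjectPaste`` config entry. The database sampler settings
# are illustrative assumptions and must match the ground-truth database that
# was actually generated for the dataset:
#
#     dict(
#         type="ObjectPaste",
#         db_sampler=dict(
#             data_root="data/nuscenes/",
#             info_path="data/nuscenes/nuscenes_dbinfos_train.pkl",
#             rate=1.0,
#             prepare=dict(
#                 filter_by_difficulty=[-1],
#                 filter_by_min_points=dict(car=5, pedestrian=5),
#             ),
#             classes=["car", "truck", "pedestrian"],
#             sample_groups=dict(car=2, truck=3, pedestrian=2),
#         ),
#     )
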
@PIPELINES.register_module()
class ObjectNoise:
"""Apply noise to each GT objects in the scene.
Args:
translation_std (list[float], optional): Standard deviation of the
distribution where translation noise are sampled from.
Defaults to [0.25, 0.25, 0.25].
global_rot_range (list[float], optional): Global rotation to the scene.
Defaults to [0.0, 0.0].
rot_range (list[float], optional): Object rotation range.
Defaults to [-0.15707963267, 0.15707963267].
num_try (int, optional): Number of times to try if the noise applied is
invalid. Defaults to 100.
"""
def __init__(
self,
translation_std=[0.25, 0.25, 0.25],
global_rot_range=[0.0, 0.0],
rot_range=[-0.15707963267, 0.15707963267],
num_try=100,
):
self.translation_std = translation_std
self.global_rot_range = global_rot_range
self.rot_range = rot_range
self.num_try = num_try
def __call__(self, data):
"""Call function to apply noise to each ground truth in the scene.
Args:
data (dict): Result dict from loading pipeline.
Returns:
dict: Results after adding noise to each object, \
'points', 'gt_bboxes_3d' keys are updated in the result dict.
"""
gt_bboxes_3d = data["gt_bboxes_3d"]
points = data["points"]
# TODO: check this inplace function
numpy_box = gt_bboxes_3d.tensor.numpy()
numpy_points = points.tensor.numpy()
noise_per_object_v3_(
numpy_box,
numpy_points,
rotation_perturb=self.rot_range,
center_noise_std=self.translation_std,
global_random_rot_range=self.global_rot_range,
num_try=self.num_try,
)
data["gt_bboxes_3d"] = gt_bboxes_3d.new_box(numpy_box)
data["points"] = points.new_point(numpy_points)
return data
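
# A sketch of an ``ObjectNoise`` config entry (values are illustrative
# assumptions; rotations are in radians):
#
#     dict(
#         type="ObjectNoise",
#         num_try=100,
#         translation_std=[0.5, 0.5, 0.5],
#         global_rot_range=[0.0, 0.0],
#         rot_range=[-0.15707963267, 0.15707963267],
#     )
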
@PIPELINES.register_module()
class FrameDropout:
    """Randomly drop entire non-keyframe sweeps from the point cloud.

    Args:
        prob (float): Probability of dropping each non-keyframe sweep.
            Defaults to 0.5.
        time_dim (int): Index of the point feature holding the time offset of
            its sweep. Defaults to -1.
    """

    def __init__(self, prob: float = 0.5, time_dim: int = -1) -> None:
        self.prob = prob
        self.time_dim = time_dim
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
offsets = []
for offset in torch.unique(data["points"].tensor[:, self.time_dim]):
if offset == 0 or random.random() > self.prob:
offsets.append(offset)
offsets = torch.tensor(offsets)
points = data["points"].tensor
indices = torch.isin(points[:, self.time_dim], offsets)
data["points"].tensor = points[indices]
return data
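
# ``FrameDropout`` always keeps the keyframe sweep (time offset 0) and drops
# each other sweep independently with probability ``prob``. A sketch of a
# config entry (values are illustrative assumptions):
#
#     dict(type="FrameDropout", prob=0.5, time_dim=-1)
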
@PIPELINES.register_module()
class PointShuffle:
    """Randomly shuffle the order of the input points."""

def __call__(self, data):
data["points"].shuffle()
return data
@PIPELINES.register_module()
class ObjectRangeFilter:
"""Filter objects by the range.
Args:
point_cloud_range (list[float]): Point cloud range.
"""
def __init__(self, point_cloud_range):
self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
def __call__(self, data):
"""Call function to filter objects by the range.
Args:
data (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' \
keys are updated in the result dict.
"""
# Check points instance type and initialise bev_range
if isinstance(
data["gt_bboxes_3d"], (LiDARInstance3DBoxes, DepthInstance3DBoxes)
):
bev_range = self.pcd_range[[0, 1, 3, 4]]
elif isinstance(data["gt_bboxes_3d"], CameraInstance3DBoxes):
bev_range = self.pcd_range[[0, 2, 3, 5]]
gt_bboxes_3d = data["gt_bboxes_3d"]
gt_labels_3d = data["gt_labels_3d"]
mask = gt_bboxes_3d.in_range_bev(bev_range)
gt_bboxes_3d = gt_bboxes_3d[mask]
        # mask is a torch tensor while gt_labels_3d is still a numpy array;
        # indexing gt_labels_3d with the raw mask is buggy when
        # len(gt_labels_3d) == 1, because mask=1 would be interpreted as
        # gt_labels_3d[1] and raise an out-of-index error
        gt_labels_3d = gt_labels_3d[mask.numpy().astype(bool)]
# limit rad to [-pi, pi]
gt_bboxes_3d.limit_yaw(offset=0.5, period=2 * np.pi)
data["gt_bboxes_3d"] = gt_bboxes_3d
data["gt_labels_3d"] = gt_labels_3d
return data
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f"(point_cloud_range={self.pcd_range.tolist()})"
return repr_str
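
# A sketch of an ``ObjectRangeFilter`` config entry; the range should match
# the detector's point cloud range (values are illustrative assumptions):
#
#     dict(
#         type="ObjectRangeFilter",
#         point_cloud_range=[-54.0, -54.0, -5.0, 54.0, 54.0, 3.0],
#     )
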
@PIPELINES.register_module()
class PointsRangeFilter:
"""Filter points by the range.
Args:
point_cloud_range (list[float]): Point cloud range.
"""
def __init__(self, point_cloud_range):
self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
def __call__(self, data):
"""Call function to filter points by the range.
Args:
data (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = data["points"]
points_mask = points.in_range_3d(self.pcd_range)
clean_points = points[points_mask]
data["points"] = clean_points
if "radar" in data:
radar = data["radar"]
# radar_mask = radar.in_range_3d(self.pcd_range)
radar_mask = radar.in_range_bev([-55.0, -55.0, 55.0, 55.0])
clean_radar = radar[radar_mask]
data["radar"] = clean_radar
return data
@PIPELINES.register_module()
class ObjectNameFilter:
"""Filter GT objects by their names.
Args:
classes (list[str]): List of class names to be kept for training.
"""
def __init__(self, classes):
self.classes = classes
self.labels = list(range(len(self.classes)))
def __call__(self, data):
gt_labels_3d = data["gt_labels_3d"]
gt_bboxes_mask = np.array(
[n in self.labels for n in gt_labels_3d], dtype=np.bool_
)
data["gt_bboxes_3d"] = data["gt_bboxes_3d"][gt_bboxes_mask]
data["gt_labels_3d"] = data["gt_labels_3d"][gt_bboxes_mask]
return data
@PIPELINES.register_module()
class PointSample:
"""Point sample.
Sampling data to a certain number.
Args:
num_points (int): Number of points to be sampled.
        sample_range (float, optional): Depth threshold used to split points.
            If not None, points with depth larger than ``sample_range`` are
            sampled with priority. Defaults to None.
        replace (bool, optional): Whether the sampling is with or without
            replacement. Defaults to False.
"""
def __init__(self, num_points, sample_range=None, replace=False):
self.num_points = num_points
self.sample_range = sample_range
self.replace = replace
def _points_random_sampling(
self,
points,
num_samples,
sample_range=None,
replace=False,
return_choices=False,
):
"""Points random sampling.
Sample points to a certain number.
Args:
points (np.ndarray | :obj:`BasePoints`): 3D Points.
num_samples (int): Number of samples to be sampled.
sample_range (float, optional): Indicating the range where the
points will be sampled. Defaults to None.
            replace (bool, optional): Sampling with or without replacement.
                Defaults to False.
return_choices (bool, optional): Whether return choice.
Defaults to False.
Returns:
tuple[np.ndarray] | np.ndarray:
- points (np.ndarray | :obj:`BasePoints`): 3D Points.
- choices (np.ndarray, optional): The generated random samples.
"""
if not replace:
replace = points.shape[0] < num_samples
point_range = range(len(points))
if sample_range is not None and not replace:
# Only sampling the near points when len(points) >= num_samples
depth = np.linalg.norm(points.tensor, axis=1)
far_inds = np.where(depth > sample_range)[0]
near_inds = np.where(depth <= sample_range)[0]
# in case there are too many far points
if len(far_inds) > num_samples:
far_inds = np.random.choice(far_inds, num_samples, replace=False)
point_range = near_inds
num_samples -= len(far_inds)
choices = np.random.choice(point_range, num_samples, replace=replace)
if sample_range is not None and not replace:
choices = np.concatenate((far_inds, choices))
# Shuffle points after sampling
np.random.shuffle(choices)
if return_choices:
return points[choices], choices
else:
return points[choices]
def __call__(self, data):
"""Call function to sample points to in indoor scenes.
Args:
data (dict): Result dict from loading pipeline.
Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = data["points"]
# Points in Camera coord can provide the depth information.
        # TODO: Need to support distance-based sampling for other coordinate systems.
if self.sample_range is not None:
from mmdet3d.core.points import CameraPoints
assert isinstance(
points, CameraPoints
), "Sampling based on distance is only appliable for CAMERA coord"
points, choices = self._points_random_sampling(
points,
self.num_points,
self.sample_range,
self.replace,
return_choices=True,
)
data["points"] = points
return data
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f"(num_points={self.num_points},"
repr_str += f" sample_range={self.sample_range},"
repr_str += f" replace={self.replace})"
return repr_str
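
# A sketch of a ``PointSample`` config entry (values are illustrative
# assumptions). Note that ``sample_range`` may only be used with points in the
# camera coordinate system:
#
#     dict(type="PointSample", num_points=65536, sample_range=None)
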
@PIPELINES.register_module()
class BackgroundPointsFilter:
"""Filter background points near the bounding box.
Args:
bbox_enlarge_range (tuple[float], float): Bbox enlarge range.
"""
def __init__(self, bbox_enlarge_range):
assert (
is_tuple_of(bbox_enlarge_range, float) and len(bbox_enlarge_range) == 3
) or isinstance(
bbox_enlarge_range, float
), f"Invalid arguments bbox_enlarge_range {bbox_enlarge_range}"
if isinstance(bbox_enlarge_range, float):
bbox_enlarge_range = [bbox_enlarge_range] * 3
self.bbox_enlarge_range = np.array(bbox_enlarge_range, dtype=np.float32)[
np.newaxis, :
]
def __call__(self, data):
"""Call function to filter points by the range.
Args:
data (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = data["points"]
gt_bboxes_3d = data["gt_bboxes_3d"]
# avoid groundtruth being modified
gt_bboxes_3d_np = gt_bboxes_3d.tensor.clone().numpy()
gt_bboxes_3d_np[:, :3] = gt_bboxes_3d.gravity_center.clone().numpy()
enlarged_gt_bboxes_3d = gt_bboxes_3d_np.copy()
enlarged_gt_bboxes_3d[:, 3:6] += self.bbox_enlarge_range
points_numpy = points.tensor.clone().numpy()
foreground_masks = box_np_ops.points_in_rbbox(
points_numpy, gt_bboxes_3d_np, origin=(0.5, 0.5, 0.5)
)
enlarge_foreground_masks = box_np_ops.points_in_rbbox(
points_numpy, enlarged_gt_bboxes_3d, origin=(0.5, 0.5, 0.5)
)
foreground_masks = foreground_masks.max(1)
enlarge_foreground_masks = enlarge_foreground_masks.max(1)
valid_masks = ~np.logical_and(~foreground_masks, enlarge_foreground_masks)
data["points"] = points[valid_masks]
return data
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f"(bbox_enlarge_range={self.bbox_enlarge_range.tolist()})"
return repr_str
@PIPELINES.register_module()
class VoxelBasedPointSampler:
"""Voxel based point sampler.
Apply voxel sampling to multiple sweep points.
Args:
cur_sweep_cfg (dict): Config for sampling current points.
prev_sweep_cfg (dict): Config for sampling previous points.
        time_dim (int): Index of the time dimension
            for the input points.
"""
def __init__(self, cur_sweep_cfg, prev_sweep_cfg=None, time_dim=3):
self.cur_voxel_generator = VoxelGenerator(**cur_sweep_cfg)
self.cur_voxel_num = self.cur_voxel_generator._max_voxels
self.time_dim = time_dim
if prev_sweep_cfg is not None:
assert prev_sweep_cfg["max_num_points"] == cur_sweep_cfg["max_num_points"]
self.prev_voxel_generator = VoxelGenerator(**prev_sweep_cfg)
self.prev_voxel_num = self.prev_voxel_generator._max_voxels
else:
self.prev_voxel_generator = None
self.prev_voxel_num = 0
def _sample_points(self, points, sampler, point_dim):
"""Sample points for each points subset.
Args:
points (np.ndarray): Points subset to be sampled.
sampler (VoxelGenerator): Voxel based sampler for
each points subset.
            point_dim (int): The dimension of each point.
Returns:
np.ndarray: Sampled points.
"""
voxels, coors, num_points_per_voxel = sampler.generate(points)
if voxels.shape[0] < sampler._max_voxels:
padding_points = np.zeros(
[
sampler._max_voxels - voxels.shape[0],
sampler._max_num_points,
point_dim,
],
dtype=points.dtype,
)
padding_points[:] = voxels[0]
sample_points = np.concatenate([voxels, padding_points], axis=0)
else:
sample_points = voxels
return sample_points
def __call__(self, results):
"""Call function to sample points from multiple sweeps.
Args:
            results (dict): Result dict from loading pipeline.
Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = results["points"]
original_dim = points.shape[1]
# TODO: process instance and semantic mask while _max_num_points
# is larger than 1
# Extend points with seg and mask fields
map_fields2dim = []
start_dim = original_dim
points_numpy = points.tensor.numpy()
extra_channel = [points_numpy]
for idx, key in enumerate(results["pts_mask_fields"]):
map_fields2dim.append((key, idx + start_dim))
extra_channel.append(results[key][..., None])
start_dim += len(results["pts_mask_fields"])
for idx, key in enumerate(results["pts_seg_fields"]):
map_fields2dim.append((key, idx + start_dim))
extra_channel.append(results[key][..., None])
points_numpy = np.concatenate(extra_channel, axis=-1)
# Split points into two part, current sweep points and
# previous sweeps points.
# TODO: support different sampling methods for next sweeps points
# and previous sweeps points.
cur_points_flag = points_numpy[:, self.time_dim] == 0
cur_sweep_points = points_numpy[cur_points_flag]
prev_sweeps_points = points_numpy[~cur_points_flag]
if prev_sweeps_points.shape[0] == 0:
prev_sweeps_points = cur_sweep_points
# Shuffle points before sampling
np.random.shuffle(cur_sweep_points)
np.random.shuffle(prev_sweeps_points)
cur_sweep_points = self._sample_points(
cur_sweep_points, self.cur_voxel_generator, points_numpy.shape[1]
)
if self.prev_voxel_generator is not None:
prev_sweeps_points = self._sample_points(
prev_sweeps_points, self.prev_voxel_generator, points_numpy.shape[1]
)
points_numpy = np.concatenate([cur_sweep_points, prev_sweeps_points], 0)
else:
points_numpy = cur_sweep_points
if self.cur_voxel_generator._max_num_points == 1:
points_numpy = points_numpy.squeeze(1)
results["points"] = points.new_point(points_numpy[..., :original_dim])
        # Restore the corresponding seg and mask fields
for key, dim_index in map_fields2dim:
results[key] = points_numpy[..., dim_index]
return results
def __repr__(self):
"""str: Return a string that describes the module."""
def _auto_indent(repr_str, indent):
repr_str = repr_str.split("\n")
repr_str = [" " * indent + t + "\n" for t in repr_str]
repr_str = "".join(repr_str)[:-1]
return repr_str
repr_str = self.__class__.__name__
indent = 4
repr_str += "(\n"
repr_str += " " * indent + f"num_cur_sweep={self.cur_voxel_num},\n"
repr_str += " " * indent + f"num_prev_sweep={self.prev_voxel_num},\n"
repr_str += " " * indent + f"time_dim={self.time_dim},\n"
repr_str += " " * indent + "cur_voxel_generator=\n"
repr_str += f"{_auto_indent(repr(self.cur_voxel_generator), 8)},\n"
repr_str += " " * indent + "prev_voxel_generator=\n"
repr_str += f"{_auto_indent(repr(self.prev_voxel_generator), 8)})"
return repr_str
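
# A sketch of a ``VoxelBasedPointSampler`` config entry; the voxel sizes,
# ranges and voxel counts are illustrative assumptions:
#
#     dict(
#         type="VoxelBasedPointSampler",
#         cur_sweep_cfg=dict(
#             voxel_size=[0.1, 0.1, 0.1],
#             point_cloud_range=[-50, -50, -5, 50, 50, 3],
#             max_num_points=1,
#             max_voxels=65536,
#         ),
#         prev_sweep_cfg=dict(
#             voxel_size=[0.1, 0.1, 0.1],
#             point_cloud_range=[-50, -50, -5, 50, 50, 3],
#             max_num_points=1,
#             max_voxels=65536,
#         ),
#         time_dim=3,
#     )
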
@PIPELINES.register_module()
class ImagePad:
"""Pad the multi-view image.
There are two padding modes: (1) pad to a fixed size and (2) pad to the
minimum size that is divisible by some number.
Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor",
Args:
size (tuple, optional): Fixed padding size.
size_divisor (int, optional): The divisor of padded size.
pad_val (float, optional): Padding value, 0 by default.
"""
def __init__(self, size=None, size_divisor=None, pad_val=0):
self.size = size
self.size_divisor = size_divisor
self.pad_val = pad_val
# only one of size and size_divisor should be valid
assert size is not None or size_divisor is not None
assert size is None or size_divisor is None
def _pad_img(self, results):
"""Pad images according to ``self.size``."""
if self.size is not None:
padded_img = [
mmcv.impad(img, shape=self.size, pad_val=self.pad_val)
for img in results["img"]
]
elif self.size_divisor is not None:
padded_img = [
mmcv.impad_to_multiple(img, self.size_divisor, pad_val=self.pad_val)
for img in results["img"]
]
results["img"] = padded_img
results["img_shape"] = [img.shape for img in padded_img]
results["pad_shape"] = [img.shape for img in padded_img]
results["pad_fixed_size"] = self.size
results["pad_size_divisor"] = self.size_divisor
def __call__(self, results):
"""Call function to pad images, masks, semantic segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Updated result dict.
"""
self._pad_img(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(size={self.size}, "
repr_str += f"size_divisor={self.size_divisor}, "
repr_str += f"pad_val={self.pad_val})"
return repr_str
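
# A sketch of an ``ImagePad`` config entry: padding each view to a multiple of
# 32 keeps the shapes compatible with typical FPN strides (the divisor is an
# illustrative assumption):
#
#     dict(type="ImagePad", size_divisor=32)
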
@PIPELINES.register_module()
class ImageNormalize:
    """Convert each image to a CHW tensor and normalize it with the given
    per-channel mean and std.
    """

def __init__(self, mean, std):
self.mean = mean
self.std = std
self.compose = torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=mean, std=std),
]
)
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
data["img"] = [self.compose(img) for img in data["img"]]
data["img_norm_cfg"] = dict(mean=self.mean, std=self.std)
return data
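
# ``ImageNormalize`` first applies ``ToTensor`` (HWC -> CHW; uint8/PIL inputs
# are scaled to [0, 1]), so ``mean`` and ``std`` should be given in the [0, 1]
# range. A sketch using the common ImageNet statistics:
#
#     dict(
#         type="ImageNormalize",
#         mean=[0.485, 0.456, 0.406],
#         std=[0.229, 0.224, 0.225],
#     )
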
@PIPELINES.register_module()
class ImageDistort:
"""Apply photometric distortion to image sequentially, every transformation
is applied with a probability of 0.5. The position of random contrast is in
second or second to last.
1. random brightness
2. random contrast (mode 0)
3. convert color from BGR to HSV
4. random saturation
5. random hue
6. convert color from HSV to BGR
7. random contrast (mode 1)
8. randomly swap channels
Args:
brightness_delta (int): delta of brightness.
contrast_range (tuple): range of contrast.
saturation_range (tuple): range of saturation.
hue_delta (int): delta of hue.
"""
def __init__(
self,
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18,
):
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
imgs = data["img"]
new_imgs = []
for img in imgs:
assert img.dtype == np.float32, (
"PhotoMetricDistortion needs the input image of dtype np.float32,"
' please set "to_float32=True" in "LoadImageFromFile" pipeline'
)
# random brightness
if random.randint(2):
delta = random.uniform(-self.brightness_delta, self.brightness_delta)
img += delta
# mode == 0 --> do random contrast first
# mode == 1 --> do random contrast last
mode = random.randint(2)
if mode == 1:
if random.randint(2):
alpha = random.uniform(self.contrast_lower, self.contrast_upper)
img *= alpha
# convert color from BGR to HSV
img = mmcv.bgr2hsv(img)
# random saturation
if random.randint(2):
img[..., 1] *= random.uniform(
self.saturation_lower, self.saturation_upper
)
# random hue
if random.randint(2):
img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360
# convert color from HSV to BGR
img = mmcv.hsv2bgr(img)
# random contrast
if mode == 0:
if random.randint(2):
alpha = random.uniform(self.contrast_lower, self.contrast_upper)
img *= alpha
# randomly swap channels
if random.randint(2):
img = img[..., random.permutation(3)]
new_imgs.append(img)
data["img"] = new_imgs
return data
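
# A sketch of an ``ImageDistort`` config entry (the defaults are shown; the
# transform expects float32 BGR images, e.g. loaded with ``to_float32=True``):
#
#     dict(
#         type="ImageDistort",
#         brightness_delta=32,
#         contrast_range=(0.5, 1.5),
#         saturation_range=(0.5, 1.5),
#         hue_delta=18,
#     )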