385 lines
16 KiB
Python
385 lines
16 KiB
Python
import warnings
|
|
|
|
import numba
|
|
import numpy as np
|
|
from numba import errors
|
|
|
|
from mmdet3d.core.bbox import box_np_ops
|
|
|
|
warnings.filterwarnings("ignore", category=errors.NumbaPerformanceWarning)
|
|
|
|
|
|
@numba.njit
|
|
def _rotation_box2d_jit_(corners, angle, rot_mat_T):
|
|
"""Rotate 2D boxes.
|
|
|
|
Args:
|
|
corners (np.ndarray): Corners of boxes.
|
|
angle (float): Rotation angle.
|
|
rot_mat_T (np.ndarray): Transposed rotation matrix.
|
|
"""
|
|
rot_sin = np.sin(angle)
|
|
rot_cos = np.cos(angle)
|
|
rot_mat_T[0, 0] = rot_cos
|
|
rot_mat_T[0, 1] = -rot_sin
|
|
rot_mat_T[1, 0] = rot_sin
|
|
rot_mat_T[1, 1] = rot_cos
|
|
corners[:] = corners @ rot_mat_T
|
|
|
|
|
|
@numba.jit(nopython=True)
|
|
def box_collision_test(boxes, qboxes, clockwise=True):
|
|
"""Box collision test.
|
|
|
|
Args:
|
|
boxes (np.ndarray): Corners of current boxes.
|
|
qboxes (np.ndarray): Boxes to be avoid colliding.
|
|
clockwise (bool): Whether the corners are in clockwise order.
|
|
Default: True.
|
|
"""
|
|
N = boxes.shape[0]
|
|
K = qboxes.shape[0]
|
|
ret = np.zeros((N, K), dtype=np.bool_)
|
|
slices = np.array([1, 2, 3, 0])
|
|
lines_boxes = np.stack((boxes, boxes[:, slices, :]), axis=2) # [N, 4, 2(line), 2(xy)]
|
|
lines_qboxes = np.stack((qboxes, qboxes[:, slices, :]), axis=2)
|
|
# vec = np.zeros((2,), dtype=boxes.dtype)
|
|
boxes_standup = box_np_ops.corner_to_standup_nd_jit(boxes)
|
|
qboxes_standup = box_np_ops.corner_to_standup_nd_jit(qboxes)
|
|
for i in range(N):
|
|
for j in range(K):
|
|
# calculate standup first
|
|
iw = min(boxes_standup[i, 2], qboxes_standup[j, 2]) - max(
|
|
boxes_standup[i, 0], qboxes_standup[j, 0]
|
|
)
|
|
if iw > 0:
|
|
ih = min(boxes_standup[i, 3], qboxes_standup[j, 3]) - max(
|
|
boxes_standup[i, 1], qboxes_standup[j, 1]
|
|
)
|
|
if ih > 0:
|
|
for k in range(4):
|
|
for box_l in range(4):
|
|
A = lines_boxes[i, k, 0]
|
|
B = lines_boxes[i, k, 1]
|
|
C = lines_qboxes[j, box_l, 0]
|
|
D = lines_qboxes[j, box_l, 1]
|
|
acd = (D[1] - A[1]) * (C[0] - A[0]) > (C[1] - A[1]) * (D[0] - A[0])
|
|
bcd = (D[1] - B[1]) * (C[0] - B[0]) > (C[1] - B[1]) * (D[0] - B[0])
|
|
if acd != bcd:
|
|
abc = (C[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (C[0] - A[0])
|
|
abd = (D[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (D[0] - A[0])
|
|
if abc != abd:
|
|
ret[i, j] = True # collision.
|
|
break
|
|
if ret[i, j] is True:
|
|
break
|
|
if ret[i, j] is False:
|
|
# now check complete overlap.
|
|
# box overlap qbox:
|
|
box_overlap_qbox = True
|
|
for box_l in range(4): # point l in qboxes
|
|
for k in range(4): # corner k in boxes
|
|
vec = boxes[i, k] - boxes[i, (k + 1) % 4]
|
|
if clockwise:
|
|
vec = -vec
|
|
cross = vec[1] * (boxes[i, k, 0] - qboxes[j, box_l, 0])
|
|
cross -= vec[0] * (boxes[i, k, 1] - qboxes[j, box_l, 1])
|
|
if cross >= 0:
|
|
box_overlap_qbox = False
|
|
break
|
|
if box_overlap_qbox is False:
|
|
break
|
|
|
|
if box_overlap_qbox is False:
|
|
qbox_overlap_box = True
|
|
for box_l in range(4): # point box_l in boxes
|
|
for k in range(4): # corner k in qboxes
|
|
vec = qboxes[j, k] - qboxes[j, (k + 1) % 4]
|
|
if clockwise:
|
|
vec = -vec
|
|
cross = vec[1] * (qboxes[j, k, 0] - boxes[i, box_l, 0])
|
|
cross -= vec[0] * (qboxes[j, k, 1] - boxes[i, box_l, 1])
|
|
if cross >= 0: #
|
|
qbox_overlap_box = False
|
|
break
|
|
if qbox_overlap_box is False:
|
|
break
|
|
if qbox_overlap_box:
|
|
ret[i, j] = True # collision.
|
|
else:
|
|
ret[i, j] = True # collision.
|
|
return ret
|
|
|
|
|
|
@numba.njit
|
|
def noise_per_box(boxes, valid_mask, loc_noises, rot_noises):
|
|
"""Add noise to every box (only on the horizontal plane).
|
|
|
|
Args:
|
|
boxes (np.ndarray): Input boxes with shape (N, 5).
|
|
valid_mask (np.ndarray): Mask to indicate which boxes are valid
|
|
with shape (N).
|
|
loc_noises (np.ndarray): Location noises with shape (N, M, 3).
|
|
rot_noises (np.ndarray): Rotation noises with shape (N, M).
|
|
|
|
Returns:
|
|
np.ndarray: Mask to indicate whether the noise is
|
|
added successfully (pass the collision test).
|
|
"""
|
|
num_boxes = boxes.shape[0]
|
|
num_tests = loc_noises.shape[1]
|
|
box_corners = box_np_ops.box2d_to_corner_jit(boxes)
|
|
current_corners = np.zeros((4, 2), dtype=boxes.dtype)
|
|
rot_mat_T = np.zeros((2, 2), dtype=boxes.dtype)
|
|
success_mask = -np.ones((num_boxes,), dtype=np.int64)
|
|
# print(valid_mask)
|
|
for i in range(num_boxes):
|
|
if valid_mask[i]:
|
|
for j in range(num_tests):
|
|
current_corners[:] = box_corners[i]
|
|
current_corners -= boxes[i, :2]
|
|
_rotation_box2d_jit_(current_corners, rot_noises[i, j], rot_mat_T)
|
|
current_corners += boxes[i, :2] + loc_noises[i, j, :2]
|
|
coll_mat = box_collision_test(current_corners.reshape(1, 4, 2), box_corners)
|
|
coll_mat[0, i] = False
|
|
# print(coll_mat)
|
|
if not coll_mat.any():
|
|
success_mask[i] = j
|
|
box_corners[i] = current_corners
|
|
break
|
|
return success_mask
|
|
|
|
|
|
@numba.njit
|
|
def noise_per_box_v2_(boxes, valid_mask, loc_noises, rot_noises, global_rot_noises):
|
|
"""Add noise to every box (only on the horizontal plane). Version 2 used
|
|
when enable global rotations.
|
|
|
|
Args:
|
|
boxes (np.ndarray): Input boxes with shape (N, 5).
|
|
valid_mask (np.ndarray): Mask to indicate which boxes are valid
|
|
with shape (N).
|
|
loc_noises (np.ndarray): Location noises with shape (N, M, 3).
|
|
rot_noises (np.ndarray): Rotation noises with shape (N, M).
|
|
|
|
Returns:
|
|
np.ndarray: Mask to indicate whether the noise is
|
|
added successfully (pass the collision test).
|
|
"""
|
|
num_boxes = boxes.shape[0]
|
|
num_tests = loc_noises.shape[1]
|
|
box_corners = box_np_ops.box2d_to_corner_jit(boxes)
|
|
current_corners = np.zeros((4, 2), dtype=boxes.dtype)
|
|
current_box = np.zeros((1, 5), dtype=boxes.dtype)
|
|
rot_mat_T = np.zeros((2, 2), dtype=boxes.dtype)
|
|
dst_pos = np.zeros((2,), dtype=boxes.dtype)
|
|
success_mask = -np.ones((num_boxes,), dtype=np.int64)
|
|
corners_norm = np.zeros((4, 2), dtype=boxes.dtype)
|
|
corners_norm[1, 1] = 1.0
|
|
corners_norm[2] = 1.0
|
|
corners_norm[3, 0] = 1.0
|
|
corners_norm -= np.array([0.5, 0.5], dtype=boxes.dtype)
|
|
corners_norm = corners_norm.reshape(4, 2)
|
|
for i in range(num_boxes):
|
|
if valid_mask[i]:
|
|
for j in range(num_tests):
|
|
current_box[0, :] = boxes[i]
|
|
current_radius = np.sqrt(boxes[i, 0] ** 2 + boxes[i, 1] ** 2)
|
|
current_grot = np.arctan2(boxes[i, 0], boxes[i, 1])
|
|
dst_grot = current_grot + global_rot_noises[i, j]
|
|
dst_pos[0] = current_radius * np.sin(dst_grot)
|
|
dst_pos[1] = current_radius * np.cos(dst_grot)
|
|
current_box[0, :2] = dst_pos
|
|
current_box[0, -1] += dst_grot - current_grot
|
|
|
|
rot_sin = np.sin(current_box[0, -1])
|
|
rot_cos = np.cos(current_box[0, -1])
|
|
rot_mat_T[0, 0] = rot_cos
|
|
rot_mat_T[0, 1] = -rot_sin
|
|
rot_mat_T[1, 0] = rot_sin
|
|
rot_mat_T[1, 1] = rot_cos
|
|
current_corners[:] = (
|
|
current_box[0, 2:4] * corners_norm @ rot_mat_T + current_box[0, :2]
|
|
)
|
|
current_corners -= current_box[0, :2]
|
|
_rotation_box2d_jit_(current_corners, rot_noises[i, j], rot_mat_T)
|
|
current_corners += current_box[0, :2] + loc_noises[i, j, :2]
|
|
coll_mat = box_collision_test(current_corners.reshape(1, 4, 2), box_corners)
|
|
coll_mat[0, i] = False
|
|
if not coll_mat.any():
|
|
success_mask[i] = j
|
|
box_corners[i] = current_corners
|
|
loc_noises[i, j, :2] += dst_pos - boxes[i, :2]
|
|
rot_noises[i, j] += dst_grot - current_grot
|
|
break
|
|
return success_mask
|
|
|
|
|
|
def _select_transform(transform, indices):
|
|
"""Select transform.
|
|
|
|
Args:
|
|
transform (np.ndarray): Transforms to select from.
|
|
indices (np.ndarray): Mask to indicate which transform to select.
|
|
|
|
Returns:
|
|
np.ndarray: Selected transforms.
|
|
"""
|
|
result = np.zeros((transform.shape[0], *transform.shape[2:]), dtype=transform.dtype)
|
|
for i in range(transform.shape[0]):
|
|
if indices[i] != -1:
|
|
result[i] = transform[i, indices[i]]
|
|
return result
|
|
|
|
|
|
@numba.njit
|
|
def _rotation_matrix_3d_(rot_mat_T, angle, axis):
|
|
"""Get the 3D rotation matrix.
|
|
|
|
Args:
|
|
rot_mat_T (np.ndarray): Transposed rotation matrix.
|
|
angle (float): Rotation angle.
|
|
axis (int): Rotation axis.
|
|
"""
|
|
rot_sin = np.sin(angle)
|
|
rot_cos = np.cos(angle)
|
|
rot_mat_T[:] = np.eye(3)
|
|
if axis == 1:
|
|
rot_mat_T[0, 0] = rot_cos
|
|
rot_mat_T[0, 2] = -rot_sin
|
|
rot_mat_T[2, 0] = rot_sin
|
|
rot_mat_T[2, 2] = rot_cos
|
|
elif axis == 2 or axis == -1:
|
|
rot_mat_T[0, 0] = rot_cos
|
|
rot_mat_T[0, 1] = -rot_sin
|
|
rot_mat_T[1, 0] = rot_sin
|
|
rot_mat_T[1, 1] = rot_cos
|
|
elif axis == 0:
|
|
rot_mat_T[1, 1] = rot_cos
|
|
rot_mat_T[1, 2] = -rot_sin
|
|
rot_mat_T[2, 1] = rot_sin
|
|
rot_mat_T[2, 2] = rot_cos
|
|
|
|
|
|
@numba.njit
|
|
def points_transform_(points, centers, point_masks, loc_transform, rot_transform, valid_mask):
|
|
"""Apply transforms to points and box centers.
|
|
|
|
Args:
|
|
points (np.ndarray): Input points.
|
|
centers (np.ndarray): Input box centers.
|
|
point_masks (np.ndarray): Mask to indicate which points need
|
|
to be transformed.
|
|
loc_transform (np.ndarray): Location transform to be applied.
|
|
rot_transform (np.ndarray): Rotation transform to be applied.
|
|
valid_mask (np.ndarray): Mask to indicate which boxes are valid.
|
|
"""
|
|
num_box = centers.shape[0]
|
|
num_points = points.shape[0]
|
|
rot_mat_T = np.zeros((num_box, 3, 3), dtype=points.dtype)
|
|
for i in range(num_box):
|
|
_rotation_matrix_3d_(rot_mat_T[i], rot_transform[i], 2)
|
|
for i in range(num_points):
|
|
for j in range(num_box):
|
|
if valid_mask[j]:
|
|
if point_masks[i, j] == 1:
|
|
points[i, :3] -= centers[j, :3]
|
|
points[i : i + 1, :3] = points[i : i + 1, :3] @ rot_mat_T[j]
|
|
points[i, :3] += centers[j, :3]
|
|
points[i, :3] += loc_transform[j]
|
|
break # only apply first box's transform
|
|
|
|
|
|
@numba.njit
|
|
def box3d_transform_(boxes, loc_transform, rot_transform, valid_mask):
|
|
"""Transform 3D boxes.
|
|
|
|
Args:
|
|
boxes (np.ndarray): 3D boxes to be transformed.
|
|
loc_transform (np.ndarray): Location transform to be applied.
|
|
rot_transform (np.ndarray): Rotation transform to be applied.
|
|
valid_mask (np.ndarray | None): Mask to indicate which boxes are valid.
|
|
"""
|
|
num_box = boxes.shape[0]
|
|
for i in range(num_box):
|
|
if valid_mask[i]:
|
|
boxes[i, :3] += loc_transform[i]
|
|
boxes[i, 6] += rot_transform[i]
|
|
|
|
|
|
def noise_per_object_v3_(
|
|
gt_boxes,
|
|
points=None,
|
|
valid_mask=None,
|
|
rotation_perturb=np.pi / 4,
|
|
center_noise_std=1.0,
|
|
global_random_rot_range=np.pi / 4,
|
|
num_try=100,
|
|
):
|
|
"""Random rotate or remove each groundtruth independently. use kitti viewer
|
|
to test this function points_transform_
|
|
|
|
Args:
|
|
gt_boxes (np.ndarray): Ground truth boxes with shape (N, 7).
|
|
points (np.ndarray | None): Input point cloud with shape (M, 4).
|
|
Default: None.
|
|
valid_mask (np.ndarray | None): Mask to indicate which boxes are valid.
|
|
Default: None.
|
|
rotation_perturb (float): Rotation perturbation. Default: pi / 4.
|
|
center_noise_std (float): Center noise standard deviation.
|
|
Default: 1.0.
|
|
global_random_rot_range (float): Global random rotation range.
|
|
Default: pi/4.
|
|
num_try (int): Number of try. Default: 100.
|
|
"""
|
|
num_boxes = gt_boxes.shape[0]
|
|
if not isinstance(rotation_perturb, (list, tuple, np.ndarray)):
|
|
rotation_perturb = [-rotation_perturb, rotation_perturb]
|
|
if not isinstance(global_random_rot_range, (list, tuple, np.ndarray)):
|
|
global_random_rot_range = [-global_random_rot_range, global_random_rot_range]
|
|
enable_grot = np.abs(global_random_rot_range[0] - global_random_rot_range[1]) >= 1e-3
|
|
|
|
if not isinstance(center_noise_std, (list, tuple, np.ndarray)):
|
|
center_noise_std = [center_noise_std, center_noise_std, center_noise_std]
|
|
if valid_mask is None:
|
|
valid_mask = np.ones((num_boxes,), dtype=np.bool_)
|
|
center_noise_std = np.array(center_noise_std, dtype=gt_boxes.dtype)
|
|
|
|
loc_noises = np.random.normal(scale=center_noise_std, size=[num_boxes, num_try, 3])
|
|
rot_noises = np.random.uniform(
|
|
rotation_perturb[0], rotation_perturb[1], size=[num_boxes, num_try]
|
|
)
|
|
gt_grots = np.arctan2(gt_boxes[:, 0], gt_boxes[:, 1])
|
|
grot_lowers = global_random_rot_range[0] - gt_grots
|
|
grot_uppers = global_random_rot_range[1] - gt_grots
|
|
global_rot_noises = np.random.uniform(
|
|
grot_lowers[..., np.newaxis], grot_uppers[..., np.newaxis], size=[num_boxes, num_try]
|
|
)
|
|
|
|
origin = (0.5, 0.5, 0)
|
|
gt_box_corners = box_np_ops.center_to_corner_box3d(
|
|
gt_boxes[:, :3], gt_boxes[:, 3:6], gt_boxes[:, 6], origin=origin, axis=2
|
|
)
|
|
|
|
# TODO: rewrite this noise box function?
|
|
if not enable_grot:
|
|
selected_noise = noise_per_box(
|
|
gt_boxes[:, [0, 1, 3, 4, 6]], valid_mask, loc_noises, rot_noises
|
|
)
|
|
else:
|
|
selected_noise = noise_per_box_v2_(
|
|
gt_boxes[:, [0, 1, 3, 4, 6]], valid_mask, loc_noises, rot_noises, global_rot_noises
|
|
)
|
|
|
|
loc_transforms = _select_transform(loc_noises, selected_noise)
|
|
rot_transforms = _select_transform(rot_noises, selected_noise)
|
|
surfaces = box_np_ops.corner_to_surfaces_3d_jit(gt_box_corners)
|
|
if points is not None:
|
|
# TODO: replace this points_in_convex function by my tools?
|
|
point_masks = box_np_ops.points_in_convex_polygon_3d_jit(points[:, :3], surfaces)
|
|
points_transform_(
|
|
points, gt_boxes[:, :3], point_masks, loc_transforms, rot_transforms, valid_mask
|
|
)
|
|
|
|
box3d_transform_(gt_boxes, loc_transforms, rot_transforms, valid_mask)
|