import warnings
from logging import warning  # noqa: F401  (pre-existing import, retained; unused here)

import numpy as np
import torch


def limit_period(val, offset=0.5, period=np.pi):
    """Limit the value into a period for periodic function.

    Args:
        val (torch.Tensor): The value to be converted.
        offset (float, optional): Offset to set the value range.
            Defaults to 0.5.
        period (float, optional): Period of the value. Defaults to np.pi.

    Returns:
        torch.Tensor: Value wrapped into the half-open range
            [-offset * period, (1 - offset) * period).
    """
    return val - torch.floor(val / period + offset) * period


def rotation_3d_in_axis(points, angles, axis=0):
    """Rotate points by angles according to axis.

    With ``c = cos(angle)`` and ``s = sin(angle)`` the per-batch result is:

    - ``axis == 0``: ``y' = c*y - s*z``, ``z' = s*y + c*z`` (x unchanged)
    - ``axis == 1``: ``x' = c*x + s*z``, ``z' = -s*x + c*z`` (y unchanged)
    - ``axis == 2`` or ``-1``: ``x' = c*x + s*y``, ``y' = -s*x + c*y``
      (z unchanged)

    Args:
        points (torch.Tensor): Points of shape (N, M, 3).
        angles (torch.Tensor): Vector of angles in shape (N,)
        axis (int, optional): The axis to be rotated. Defaults to 0.

    Raises:
        ValueError: when the axis is not one of 0, 1, 2 (or -1).

    Returns:
        torch.Tensor: Rotated points in shape (N, M, 3)
    """
    rot_sin = torch.sin(angles)
    rot_cos = torch.cos(angles)
    ones = torch.ones_like(rot_cos)
    zeros = torch.zeros_like(rot_cos)

    # Each branch builds a (3, 3, N) stack that is right-multiplied onto the
    # points, i.e. out[n] = points[n] @ rot_mat_T[:, :, n].
    if axis == 1:
        rot_mat_T = torch.stack(
            [
                torch.stack([rot_cos, zeros, -rot_sin]),
                torch.stack([zeros, ones, zeros]),
                torch.stack([rot_sin, zeros, rot_cos]),
            ]
        )
    elif axis == 2 or axis == -1:
        rot_mat_T = torch.stack(
            [
                torch.stack([rot_cos, -rot_sin, zeros]),
                torch.stack([rot_sin, rot_cos, zeros]),
                torch.stack([zeros, zeros, ones]),
            ]
        )
    elif axis == 0:
        # The x coordinate must stay fixed and a zero angle must be the
        # identity; the previous matrix permuted the coordinates instead.
        rot_mat_T = torch.stack(
            [
                torch.stack([ones, zeros, zeros]),
                torch.stack([zeros, rot_cos, rot_sin]),
                torch.stack([zeros, -rot_sin, rot_cos]),
            ]
        )
    else:
        raise ValueError(f"axis should in range [0, 1, 2], got {axis}")

    return torch.einsum("aij,jka->aik", (points, rot_mat_T))


def xywhr2xyxyr(boxes_xywhr):
    """Convert a rotated boxes in XYWHR format to XYXYR format.

    Columns 0-1 are read as the box center, columns 2-3 as width and height,
    and column 4 (the rotation) is copied through unchanged.

    Args:
        boxes_xywhr (torch.Tensor): Rotated boxes in XYWHR format.

    Returns:
        torch.Tensor: Converted boxes in XYXYR format, same shape as the
            input.
    """
    boxes = torch.zeros_like(boxes_xywhr)
    half_w = boxes_xywhr[:, 2] / 2
    half_h = boxes_xywhr[:, 3] / 2

    boxes[:, 0] = boxes_xywhr[:, 0] - half_w
    boxes[:, 1] = boxes_xywhr[:, 1] - half_h
    boxes[:, 2] = boxes_xywhr[:, 0] + half_w
    boxes[:, 3] = boxes_xywhr[:, 1] + half_h
    boxes[:, 4] = boxes_xywhr[:, 4]
    return boxes


def get_box_type(box_type):
    """Get the type and mode of box structure.

    Args:
        box_type (str): The type of box structure.
            The valid value are "LiDAR", "Camera", or "Depth"
            (matched case-insensitively).

    Raises:
        ValueError: when ``box_type`` is none of the supported names.

    Returns:
        tuple: Box type and box mode.
    """
    # Imported lazily inside the function, as in the original code.
    from .box_3d_mode import (
        Box3DMode,
        CameraInstance3DBoxes,
        DepthInstance3DBoxes,
        LiDARInstance3DBoxes,
    )

    box_type_lower = box_type.lower()
    if box_type_lower == "lidar":
        box_type_3d = LiDARInstance3DBoxes
        box_mode_3d = Box3DMode.LIDAR
    elif box_type_lower == "camera":
        box_type_3d = CameraInstance3DBoxes
        box_mode_3d = Box3DMode.CAM
    elif box_type_lower == "depth":
        box_type_3d = DepthInstance3DBoxes
        box_mode_3d = Box3DMode.DEPTH
    else:
        raise ValueError(
            'Only "box_type" of "camera", "lidar", "depth"'
            f" are supported, got {box_type}"
        )

    return box_type_3d, box_mode_3d


def points_cam2img(points_3d, proj_mat, with_depth=False):
    """Project points from camera coordinates to image coordinates.

    Args:
        points_3d (torch.Tensor): Points in shape (..., 3), typically (N, 3).
        proj_mat (torch.Tensor): Transformation matrix between coordinates.
            Must be 2-D with shape 3x3, 3x4 or 4x4.
        with_depth (bool, optional): Whether to keep depth in the output.
            Defaults to False.

    Raises:
        AssertionError: when ``proj_mat`` is not 2-D or has an unsupported
            shape.

    Returns:
        torch.Tensor: Points in image coordinates with shape (..., 2), or
            (..., 3) with the depth appended when ``with_depth`` is True.
    """
    # Shape of the homogeneous "ones" column: same leading dims, last dim 1.
    # Built from the tensor shape directly so a single 1-D point also works.
    points_shape = list(points_3d.shape)
    points_shape[-1] = 1

    assert len(proj_mat.shape) == 2, (
        "The dimension of the projection"
        f" matrix should be 2 instead of {len(proj_mat.shape)}."
    )
    d1, d2 = proj_mat.shape[:2]
    assert (d1 == 3 and d2 == 3) or (d1 == 3 and d2 == 4) or (d1 == 4 and d2 == 4), (
        "The shape of the projection matrix"
        f" ({d1}*{d2}) is not supported."
    )
    if d1 == 3:
        # Embed a 3x3 / 3x4 matrix into a 4x4 identity so the homogeneous
        # multiplication below is uniform for every supported shape.
        proj_mat_expanded = torch.eye(4, device=proj_mat.device, dtype=proj_mat.dtype)
        proj_mat_expanded[:d1, :d2] = proj_mat
        proj_mat = proj_mat_expanded

    # previous implementation use new_zeros, new_one yields better results
    points_4 = torch.cat([points_3d, points_3d.new_ones(points_shape)], dim=-1)
    point_2d = torch.matmul(points_4, proj_mat.t())
    # Perspective divide by the third (depth) component.
    point_2d_res = point_2d[..., :2] / point_2d[..., 2:3]

    if with_depth:
        return torch.cat([point_2d_res, point_2d[..., 2:3]], dim=-1)
    return point_2d_res


def mono_cam_box2vis(cam_box):
    """This is a post-processing function on the bboxes from Mono-3D task. If
    we want to perform projection visualization, we need to:

        1. rotate the box along x-axis for np.pi / 2 (roll)
        2. change orientation from local yaw to global yaw
        3. convert yaw by (-yaw - np.pi / 2)

    After applying this function, we can project and draw it on 2D images.

    A warning is emitted on every call because the function is slated for
    deprecation. The input box is not modified.

    Args:
        cam_box (:obj:`CameraInstance3DBoxes`): 3D bbox in camera coordinate
            system before conversion. Could be gt bbox loaded from dataset or
            network prediction output.

    Returns:
        :obj:`CameraInstance3DBoxes`: Box after conversion.
    """
    # `logging.warning` is a function and has no `.warn`; the stdlib
    # `warnings` module is what emits this kind of notice.
    warnings.warn(
        "DeprecationWarning: The hack of yaw and dimension in the "
        "monocular 3D detection on nuScenes has been removed. The "
        "function mono_cam_box2vis will be deprecated.",
        stacklevel=2,
    )
    from . import CameraInstance3DBoxes

    assert isinstance(
        cam_box, CameraInstance3DBoxes
    ), "input bbox should be CameraInstance3DBoxes!"

    loc = cam_box.gravity_center
    # NOTE(review): dims / yaw may be views into cam_box.tensor; work on a
    # copy / out-of-place so the caller's box is never altered.
    dim = cam_box.dims.clone()
    yaw = cam_box.yaw
    feats = cam_box.tensor[:, 7:]
    # rotate along x-axis for np.pi / 2
    # see also here: https://github.com/open-mmlab/mmdetection3d/blob/master/mmdet3d/datasets/nuscenes_mono_dataset.py#L557  # noqa
    dim[:, [1, 2]] = dim[:, [2, 1]]
    # change local yaw to global yaw for visualization
    # refer to https://github.com/open-mmlab/mmdetection3d/blob/master/mmdet3d/datasets/nuscenes_mono_dataset.py#L164-L166  # noqa
    yaw = yaw + torch.atan2(loc[:, 0], loc[:, 2])
    # convert yaw by (-yaw - np.pi / 2)
    # this is because mono 3D box class such as `NuScenesBox` has different
    # definition of rotation with our `CameraInstance3DBoxes`
    yaw = -yaw - np.pi / 2
    box_tensor = torch.cat([loc, dim, yaw[:, None], feats], dim=1)
    return CameraInstance3DBoxes(
        box_tensor, box_dim=box_tensor.shape[-1], origin=(0.5, 0.5, 0.5)
    )


def get_proj_mat_by_coord_type(img_meta, coord_type):
    """Fetch the projection matrix for a coordinate type from ``img_meta``.

    The coordinate type selects the key that is read from ``img_meta``:
    'LIDAR' -> "lidar2image", 'DEPTH' -> "depth2img", 'CAMERA' -> "cam2img".

    Args:
        img_meta (dict): Meta info.
        coord_type (str): 'DEPTH' or 'CAMERA' or 'LIDAR'.
            Can be case-insensitive.

    Raises:
        AssertionError: when ``coord_type`` is not one of the three names.
        KeyError: when ``img_meta`` lacks the selected key.

    Returns:
        torch.Tensor: transformation matrix (whatever ``img_meta`` stores
            under the selected key).
    """
    coord_type = coord_type.upper()
    mapping = {"LIDAR": "lidar2image", "DEPTH": "depth2img", "CAMERA": "cam2img"}
    assert coord_type in mapping, (
        f"coord_type should be one of {sorted(mapping)}, got {coord_type}"
    )
    return img_meta[mapping[coord_type]]