Add BEVFusion-R (#443)

This commit is contained in:
Kevin Shao 2023-07-07 21:53:36 -05:00 committed by GitHub
parent db75150717
commit d0152cf97c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 2484 additions and 322 deletions

5
.gitignore vendored
View File

@ -127,3 +127,8 @@ dmypy.json
# Pyre type checker
.pyre/
# Models and data
models/*
data/*
runs/*

View File

@ -183,6 +183,9 @@ train_pipeline:
- lidar2image
- img_aug_matrix
- lidar_aug_matrix
-
type: GTDepth
keyframe_only: true
test_pipeline:
-
@ -256,6 +259,9 @@ test_pipeline:
- lidar2image
- img_aug_matrix
- lidar_aug_matrix
-
type: GTDepth
keyframe_only: true
data:
samples_per_gpu: 4

View File

@ -1,3 +1,5 @@
gt_paste_stop_epoch: 15
model:
heads:
object:

View File

@ -0,0 +1,72 @@
data:
train:
dataset:
ann_file: nuscenes_radar/nuscenes_radar_infos_train_radar.pkl
val:
ann_file: nuscenes_radar/nuscenes_radar_infos_val_radar.pkl
test:
ann_file: nuscenes_radar/nuscenes_radar_infos_val_radar.pkl
augment2d:
resize: [[0.38, 0.55], [0.48, 0.48]]
augment3d:
scale: [0.9, 1.1]
rotate: [0, 0]
translate: 0.5
model:
encoders:
lidar: null
camera:
vtransform:
type: LSSTransform
image_size: ${image_size}
xbound: [-51.2, 51.2, 0.8]
ybound: [-51.2, 51.2, 0.8]
zbound: [-10.0, 10.0, 20.0]
dbound: [1.0, 60.0, 1.0]
radar:
voxelize_reduce: false
voxelize:
max_num_points: 20
point_cloud_range: ${point_cloud_range}
voxel_size: ${radar_voxel_size}
max_voxels: [30000, 60000]
backbone:
type: RadarEncoder
pts_voxel_encoder:
type: RadarFeatureNet
in_channels: 45
feat_channels: [128, 128, 128, 64]
with_distance: false
point_cloud_range: ${point_cloud_range}
voxel_size: ${radar_voxel_size}
norm_cfg:
type: BN1d
eps: 1.0e-3
momentum: 0.01
pts_middle_encoder:
type: PointPillarsScatter
in_channels: 64
output_shape: [128, 128]
pts_bev_encoder: null
heads:
object:
test_cfg:
nms_type:
- circle
- rotate
- rotate
- circle
- rotate
- rotate
nms_scale:
- [1.0]
- [1.0, 1.0]
- [1.0, 1.0]
- [1.0]
- [1.0, 1.0]
- [2.5, 4.0]

View File

@ -0,0 +1,83 @@
image_size: [256, 704]
model:
encoders:
camera:
backbone:
type: ResNet
depth: 50
num_stages: 4
out_indices: [0, 1, 2, 3]
norm_cfg:
type: BN2d
requires_grad: true
norm_eval: false
init_cfg:
type: Pretrained
checkpoint: torchvision://resnet50
neck:
type: SECONDFPN
in_channels: [256, 512, 1024, 2048]
out_channels: [128, 128, 128, 128]
upsample_strides: [0.25, 0.5, 1, 2]
vtransform:
type: LSSTransform
in_channels: 512
out_channels: 64
image_size: ${image_size}
feature_size: ${[image_size[0] // 16, image_size[1] // 16]}
xbound: [-51.2, 51.2, 0.8]
ybound: [-51.2, 51.2, 0.8]
zbound: [-10.0, 10.0, 20.0]
dbound: [1.0, 60.0, 1.0]
downsample: 1
decoder:
backbone:
type: GeneralizedResNet
in_channels: 64
blocks:
- [2, 128, 2]
- [2, 256, 2]
- [2, 512, 1]
neck:
type: LSSFPN
in_indices: [-1, 0]
in_channels: [512, 128]
out_channels: 256
scale_factor: 2
heads:
object:
train_cfg:
code_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
fuser:
type: ConvFuser
in_channels: [64, 64]
out_channels: 64
optimizer:
paramwise_cfg:
custom_keys:
absolute_pos_embed:
decay_mult: 0
relative_position_bias_table:
decay_mult: 0
# encoders.camera.backbone:
# lr_mult: 0.1
# lr_config:
# policy: cyclic
# target_ratio: 5.0
# cyclic_times: 1
# step_ratio_up: 0.4
# momentum_config:
# policy: cyclic
# cyclic_times: 1
# step_ratio_up: 0.4
data:
samples_per_gpu: 4

View File

@ -0,0 +1,13 @@
model:
encoders:
camera:
vtransform:
type: AwareDBEVDepth
bevdepth_downsample: 16
bevdepth_refine: false
depth_loss_factor: 3.0
use_points: radar
depth_input: one-hot
height_expand: true
add_depth_features: true

View File

@ -0,0 +1,16 @@
model:
encoders:
camera:
vtransform:
type: AwareBEVDepth
bevdepth_downsample: 16
bevdepth_refine: false
depth_loss_factor: 3.0
in_channels: 512
out_channels: 64
feature_size: ${[image_size[0] // 16, image_size[1] // 16]}
xbound: [-51.2, 51.2, 0.8]
ybound: [-51.2, 51.2, 0.8]
zbound: [-10.0, 10.0, 20.0]
dbound: [1.0, 60.0, 1.0]
downsample: 1

View File

@ -0,0 +1,45 @@
model:
encoders:
camera:
backbone:
type: ResNet
depth: 50
num_stages: 4
out_indices: [0, 1, 2, 3]
norm_cfg:
type: BN2d
requires_grad: true
norm_eval: false
init_cfg:
type: Pretrained
checkpoint: torchvision://resnet50
neck:
type: SECONDFPN
in_channels: [256, 512, 1024, 2048]
out_channels: [128, 128, 128, 128]
upsample_strides: [0.25, 0.5, 1, 2]
vtransform:
type: LSSTransform
in_channels: 512
out_channels: 64
image_size: ${image_size}
feature_size: ${[image_size[0] // 16, image_size[1] // 16]}
xbound: [-51.2, 51.2, 0.8]
ybound: [-51.2, 51.2, 0.8]
zbound: [-10.0, 10.0, 20.0]
dbound: [1.0, 60.0, 1.0]
downsample: 1
decoder:
backbone:
type: GeneralizedResNet
in_channels: 64
blocks:
- [2, 128, 2]
- [2, 256, 2]
- [2, 512, 1]
neck:
type: LSSFPN
in_indices: [-1, 0]
in_channels: [512, 128]
out_channels: 256
scale_factor: 2

View File

@ -82,4 +82,4 @@ momentum_config:
step_ratio_up: 0.4
data:
samples_per_gpu: 6
samples_per_gpu: 4

View File

@ -1,3 +1,4 @@
augment3d:
scale: [0.95, 1.05]
rotate: [-0.3925, 0.3925]
@ -31,5 +32,3 @@ model:
- [1.0]
- [1.0, 1.0]
- [2.5, 4.0]
lr_config: null

View File

@ -1,3 +1,13 @@
# data:
# train:
# dataset:
# ann_file: nuscenes_radar/nuscenes_radar_infos_train_radar.pkl
# val:
# ann_file: nuscenes_radar/nuscenes_radar_infos_val_radar.pkl
# test:
# ann_file: nuscenes_radar/nuscenes_radar_infos_val_radar.pkl
model:
decoder:
backbone:
@ -21,6 +31,12 @@ optimizer:
type: AdamW
lr: 2.0e-4
weight_decay: 0.01
paramwise_cfg:
custom_keys:
absolute_pos_embed:
decay_mult: 0
relative_position_bias_table:
decay_mult: 0
optimizer_config:
grad_clip:
@ -32,4 +48,4 @@ lr_config:
warmup: linear
warmup_iters: 500
warmup_ratio: 0.33333333
min_lr_ratio: 1.0e-3
min_lr_ratio: 1.0e-3

View File

@ -1,6 +1,236 @@
radar_sweeps: 6
radar_max_points: 2500
radar_use_dims: [0, 1, 2, 5, 8, 9, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56]
radar_compensate_velocity: true
radar_filtering: none
radar_voxel_size: [0.8, 0.8, 8]
image_size: [256, 704]
radar_jitter: 0
radar_normalize: false
model:
type: BEVFusion
encoders: null
fuser: null
heads:
map: null
train_pipeline:
-
type: LoadMultiViewImageFromFiles
to_float32: true
-
type: LoadPointsFromFile
coord_type: LIDAR
load_dim: ${load_dim}
use_dim: ${use_dim}
reduce_beams: ${reduce_beams}
load_augmented: ${load_augmented}
-
type: LoadPointsFromMultiSweeps
sweeps_num: 0
load_dim: ${load_dim}
use_dim: ${use_dim}
reduce_beams: ${reduce_beams}
pad_empty_sweeps: true
remove_close: true
load_augmented: ${load_augmented}
-
type: LoadRadarPointsMultiSweeps
load_dim: 18
sweeps_num: ${radar_sweeps}
use_dim: ${radar_use_dims}
max_num: ${radar_max_points}
compensate_velocity: ${radar_compensate_velocity}
filtering: ${radar_filtering}
normalize: ${radar_normalize}
-
type: LoadAnnotations3D
with_bbox_3d: true
with_label_3d: true
with_attr_label: False
-
type: ObjectPaste
stop_epoch: ${gt_paste_stop_epoch}
db_sampler:
dataset_root: ${dataset_root}
info_path: ${'data/nuscenes/' + "nuscenes_dbinfos_train.pkl"}
rate: 1.0
prepare:
filter_by_difficulty: [-1]
filter_by_min_points:
car: 5
truck: 5
bus: 5
trailer: 5
construction_vehicle: 5
traffic_cone: 5
barrier: 5
motorcycle: 5
bicycle: 5
pedestrian: 5
classes: ${object_classes}
sample_groups:
car: 2
truck: 3
construction_vehicle: 7
bus: 4
trailer: 6
barrier: 2
motorcycle: 6
bicycle: 6
pedestrian: 2
traffic_cone: 2
points_loader:
type: LoadPointsFromFile
coord_type: LIDAR
load_dim: ${load_dim}
use_dim: ${use_dim}
reduce_beams: ${reduce_beams}
-
type: ImageAug3D
final_dim: ${image_size}
resize_lim: ${augment2d.resize[0]}
bot_pct_lim: [0.0, 0.0]
rot_lim: ${augment2d.rotate}
rand_flip: true
is_train: true
-
type: GlobalRotScaleTrans
resize_lim: ${augment3d.scale}
rot_lim: ${augment3d.rotate}
trans_lim: ${augment3d.translate}
is_train: true
-
type: RandomFlip3D
-
type: PointsRangeFilter
point_cloud_range: ${point_cloud_range}
-
type: ObjectRangeFilter
point_cloud_range: ${point_cloud_range}
-
type: ObjectNameFilter
classes: ${object_classes}
-
type: ImageNormalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
-
type: GridMask
use_h: true
use_w: true
max_epoch: ${max_epochs}
rotate: 1
offset: false
ratio: 0.5
mode: 1
prob: ${augment2d.gridmask.prob}
fixed_prob: ${augment2d.gridmask.fixed_prob}
-
type: PointShuffle
-
type: DefaultFormatBundle3D
classes: ${object_classes}
-
type: Collect3D
keys:
- img
- points
- radar
- gt_bboxes_3d
- gt_labels_3d
meta_keys:
- camera_intrinsics
- camera2ego
- lidar2ego
- lidar2camera
- lidar2image
- camera2lidar
- img_aug_matrix
- lidar_aug_matrix
-
type: GTDepth
keyframe_only: true
test_pipeline:
-
type: LoadMultiViewImageFromFiles
to_float32: true
-
type: LoadPointsFromFile
coord_type: LIDAR
load_dim: ${load_dim}
use_dim: ${use_dim}
reduce_beams: ${reduce_beams}
load_augmented: ${load_augmented}
-
type: LoadPointsFromMultiSweeps
sweeps_num: 9
load_dim: ${load_dim}
use_dim: ${use_dim}
reduce_beams: ${reduce_beams}
pad_empty_sweeps: true
remove_close: true
load_augmented: ${load_augmented}
-
type: LoadRadarPointsMultiSweeps
load_dim: 18
sweeps_num: ${radar_sweeps}
use_dim: ${radar_use_dims}
max_num: ${radar_max_points}
compensate_velocity: ${radar_compensate_velocity}
filtering: ${radar_filtering}
normalize: ${radar_normalize}
-
type: LoadAnnotations3D
with_bbox_3d: true
with_label_3d: true
with_attr_label: False
-
type: ImageAug3D
final_dim: ${image_size}
resize_lim: ${augment2d.resize[1]}
bot_pct_lim: [0.0, 0.0]
rot_lim: [0.0, 0.0]
rand_flip: false
is_train: false
-
type: GlobalRotScaleTrans
resize_lim: [1.0, 1.0]
rot_lim: [0.0, 0.0]
trans_lim: 0.0
is_train: false
-
type: PointsRangeFilter
point_cloud_range: ${point_cloud_range}
-
type: ImageNormalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
-
type: DefaultFormatBundle3D
classes: ${object_classes}
-
type: Collect3D
keys:
- img
- points
- radar
- gt_bboxes_3d
- gt_labels_3d
#- gt_masks_bev
meta_keys:
- camera_intrinsics
- camera2ego
- lidar2ego
- lidar2camera
- lidar2image
- camera2lidar
- img_aug_matrix
- lidar_aug_matrix
-
type: GTDepth
keyframe_only: true

View File

@ -2,8 +2,9 @@ from .base_points import BasePoints
from .cam_points import CameraPoints
from .depth_points import DepthPoints
from .lidar_points import LiDARPoints
from .radar_points import RadarPoints
__all__ = ["BasePoints", "CameraPoints", "DepthPoints", "LiDARPoints"]
__all__ = ["BasePoints", "CameraPoints", "DepthPoints", "LiDARPoints", "RadarPoints"]
def get_points_type(points_type):

View File

@ -0,0 +1,124 @@
from .base_points import BasePoints
import torch
class RadarPoints(BasePoints):
    """Radar points, treated as living in the LiDAR coordinate frame.

    Besides the ``(x, y, z)`` position stored in columns 0-2, the geometric
    transforms of this class also update columns 3 and 4 as a planar
    ``(vx, vy)`` vector.
    NOTE(review): that layout is assumed by the code below; confirm that the
    loader really places the velocity in columns 3:5.

    Args:
        tensor (torch.Tensor | np.ndarray | list): a N x points_dim matrix.
        points_dim (int): Number of the dimension of a point.
            Each row is (x, y, z, ...). Default to 3.
        attribute_dims (dict): Dictionary to indicate the meaning of extra
            dimension. Default to None.

    Attributes:
        tensor (torch.Tensor): Float matrix of N x points_dim.
        points_dim (int): Integer indicating the dimension of a point.
            Each row is (x, y, z, ...).
        attribute_dims (dict): Dictionary to indicate the meaning of extra
            dimension. Default to None.
        rotation_axis (int): Default rotation axis for points rotation.
    """

    def __init__(self, tensor, points_dim=3, attribute_dims=None):
        super(RadarPoints, self).__init__(
            tensor, points_dim=points_dim, attribute_dims=attribute_dims
        )
        # Rotate around the z axis unless the caller asks otherwise.
        self.rotation_axis = 2

    def flip(self, bev_direction="horizontal"):
        """Flip the points (and columns 3/4) in BEV along the given direction.

        Args:
            bev_direction (str): ``"horizontal"`` negates columns 1 and 4,
                ``"vertical"`` negates columns 0 and 3. Any other value is a
                no-op.
        """
        if bev_direction == "horizontal":
            self.tensor[:, 1] = -self.tensor[:, 1]
            self.tensor[:, 4] = -self.tensor[:, 4]
        elif bev_direction == "vertical":
            self.tensor[:, 0] = -self.tensor[:, 0]
            self.tensor[:, 3] = -self.tensor[:, 3]

    def jitter(self, amount):
        """Add zero-mean Gaussian noise, scaled by ``amount``, to x/y/z.

        Args:
            amount (float): Multiplier applied to unit-variance noise.
        """
        # Sample on the same device/dtype as the points so the in-place add
        # below also works for tensors that are not on the default device.
        jitter_noise = torch.randn(
            self.tensor.shape[0],
            3,
            dtype=self.tensor.dtype,
            device=self.tensor.device,
        )
        jitter_noise *= amount
        self.tensor[:, :3] += jitter_noise

    def scale(self, scale_factor):
        """Scale the positions and columns 3:5 by the same factor.

        Args:
            scale_factor (float): Scale factor to scale the points.
        """
        self.tensor[:, :3] *= scale_factor
        self.tensor[:, 3:5] *= scale_factor

    def rotate(self, rotation, axis=None):
        """Rotate points with the given rotation matrix or angle.

        Args:
            rotation (float, np.ndarray, torch.Tensor): Rotation matrix
                or angle.
            axis (int): Axis to rotate at. Defaults to None, in which case
                ``self.rotation_axis`` is used.

        Returns:
            torch.Tensor: The matrix the points were right-multiplied with.
        """
        if not isinstance(rotation, torch.Tensor):
            rotation = self.tensor.new_tensor(rotation)
        assert (
            rotation.shape == torch.Size([3, 3]) or rotation.numel() == 1
        ), f"invalid rotation shape {rotation.shape}"

        if axis is None:
            axis = self.rotation_axis

        if rotation.numel() == 1:
            rot_sin = torch.sin(rotation)
            rot_cos = torch.cos(rotation)
            if axis == 1:
                rot_mat_T = rotation.new_tensor(
                    [[rot_cos, 0, -rot_sin], [0, 1, 0], [rot_sin, 0, rot_cos]]
                )
            elif axis == 2 or axis == -1:
                rot_mat_T = rotation.new_tensor(
                    [[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]]
                )
            elif axis == 0:
                # NOTE(review): this is not the textbook x-axis rotation
                # matrix; kept unchanged here — confirm against BasePoints.
                rot_mat_T = rotation.new_tensor(
                    [[0, rot_cos, -rot_sin], [0, rot_sin, rot_cos], [1, 0, 0]]
                )
            else:
                raise ValueError("axis should in range")
            rot_mat_T = rot_mat_T.T
        elif rotation.numel() == 9:
            rot_mat_T = rotation
        else:
            raise NotImplementedError

        self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T
        # Columns 3:5 are rotated with the top-left 2x2 block, i.e. as a
        # planar vector; this only matches the position rotation for a
        # rotation about the z axis.
        self.tensor[:, 3:5] = self.tensor[:, 3:5] @ rot_mat_T[:2, :2]

        return rot_mat_T

    def in_range_bev(self, point_range):
        """Check whether the points are in the given range.

        Args:
            point_range (list | torch.Tensor): The range of point
                in order of (x_min, y_min, x_max, y_max).

        Returns:
            torch.Tensor: Boolean flags, one per point; the bounds are
                exclusive on both sides.
        """
        in_range_flags = (
            (self.tensor[:, 0] > point_range[0])
            & (self.tensor[:, 1] > point_range[1])
            & (self.tensor[:, 0] < point_range[2])
            & (self.tensor[:, 1] < point_range[3])
        )
        return in_range_flags

    def convert_to(self, dst, rt_mat=None):
        """Convert self to ``dst`` mode.

        Args:
            dst (:obj:`CoordMode`): The target Point mode.
            rt_mat (np.ndarray | torch.Tensor): The rotation and translation
                matrix between different coordinates. Defaults to None.
                The conversion from `src` coordinates to `dst` coordinates
                usually comes along the change of sensors, e.g., from camera
                to LiDAR. This requires a transformation matrix.

        Returns:
            :obj:`BasePoints`: The converted point of the same type
                in the `dst` mode.
        """
        # Imported lazily, as in the original code.
        from mmdet3d.core.bbox import Coord3DMode

        # The source mode is always LIDAR for this class.
        return Coord3DMode.convert_point(
            point=self, src=Coord3DMode.LIDAR, dst=dst, rt_mat=rt_mat
        )

View File

@ -215,9 +215,15 @@ class NuScenesDataset(Custom3DDataset):
lidar_path=info["lidar_path"],
sweeps=info["sweeps"],
timestamp=info["timestamp"],
location=info["location"],
location=info.get('location', None),
radar=info.get('radars', None),
)
if data['location'] is None:
data.pop('location')
if data['radar'] is None:
data.pop('radar')
# ego to global transform
ego2global = np.eye(4).astype(np.float32)
ego2global[:3, :3] = Quaternion(info["ego2global_rotation"]).rotation_matrix
@ -253,7 +259,7 @@ class NuScenesDataset(Custom3DDataset):
# camera intrinsics
camera_intrinsics = np.eye(4).astype(np.float32)
camera_intrinsics[:3, :3] = camera_info["camera_intrinsics"]
camera_intrinsics[:3, :3] = camera_info["cam_intrinsic"]
data["camera_intrinsics"].append(camera_intrinsics)
# lidar to image transform

View File

@ -52,6 +52,9 @@ class DefaultFormatBundle3D:
assert isinstance(results["points"], BasePoints)
results["points"] = DC(results["points"].tensor)
if "radar" in results:
results["radar"] = DC(results["radar"].tensor)
for key in ["voxels", "coors", "voxel_centers", "num_points"]:
if key not in results:
continue

View File

@ -3,15 +3,18 @@ from typing import Any, Dict, Tuple
import mmcv
import numpy as np
from mmdet3d.core.points import RadarPoints
from nuscenes.utils.data_classes import RadarPointCloud
from nuscenes.map_expansion.map_api import NuScenesMap
from nuscenes.map_expansion.map_api import locations as LOCATIONS
from PIL import Image
from mmdet3d.core.points import BasePoints, get_points_type
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import LoadAnnotations
import torch
from .loading_utils import load_augmented_point_cloud, reduce_LiDAR_beams
@ -180,6 +183,7 @@ class LoadPointsFromMultiSweeps:
cloud arrays.
"""
points = results["points"]
points = points[:, self.use_dim]
points.tensor[:, 4] = 0
sweep_points_list = [points]
ts = results["timestamp"] / 1e6
@ -216,6 +220,7 @@ class LoadPointsFromMultiSweeps:
if self.remove_close:
points_sweep = self._remove_close(points_sweep)
points_sweep = points_sweep[:, self.use_dim]
sweep_ts = sweep["timestamp"] / 1e6
points_sweep[:, :3] = (
points_sweep[:, :3] @ sweep["sensor2lidar_rotation"].T
@ -226,7 +231,7 @@ class LoadPointsFromMultiSweeps:
sweep_points_list.append(points_sweep)
points = points.cat(sweep_points_list)
points = points[:, self.use_dim]
results["points"] = points
return results
@ -556,3 +561,234 @@ class LoadAnnotations3D(LoadAnnotations):
results = self._load_attr_labels(results)
return results
@PIPELINES.register_module()
class NormalizePointFeatures:
    """Pipeline step that squashes column 3 of the point tensor with tanh."""

    def __call__(self, results):
        """Apply ``tanh`` in place to ``results["points"].tensor[:, 3]``.

        Args:
            results (dict): Pipeline dict holding a points object under
                ``"points"`` that exposes a ``tensor`` attribute.

        Returns:
            dict: The same dict, with the points entry re-assigned.
        """
        cloud = results["points"]
        squashed = torch.tanh(cloud.tensor[:, 3])
        cloud.tensor[:, 3] = squashed
        results["points"] = cloud
        return results
@PIPELINES.register_module()
class LoadRadarPointsMultiSweeps(object):
    """Load radar points from multiple sweeps of every radar sensor.

    The sweeps of all sensors found in ``results['radar']`` are moved into
    the LiDAR frame, concatenated, extended with categorical encodings and
    finally wrapped into a :obj:`RadarPoints` object.

    Args:
        load_dim (int): Number of channels of a raw radar point.
            Defaults to 18.
        use_dim (list[int]): Columns kept from the re-ordered and encoded
            point matrix. Defaults to [0, 1, 2, 3, 4].
        sweeps_num (int): Maximum number of sweeps used per sensor.
            Defaults to 3.
        file_client_args (dict): Config dict of file clients. It is copied
            and stored; the code in this class does not use it otherwise.
            Defaults to dict(backend='disk').
        max_num (int): Target number of points for ``_pad_or_drop``.
            Defaults to 300.
        pc_range (list[float]): Stored on the instance; not used by the
            methods of this class.
        compensate_velocity (bool): Whether to shift x/y of older sweeps by
            ``compensated velocity * time difference``. Defaults to False.
        normalize_dims (list[tuple]): ``(column, min, max)`` triples used by
            ``normalize_feats``.
        filtering (str): ``'default'`` or ``'none'``; selects the state
            lists handed to ``RadarPointCloud.from_file``.
        normalize (bool): Whether to run ``normalize_feats``.
            Defaults to False.
        test_mode (bool): Stored on the instance; not used by the methods
            of this class. Defaults to False.
    """

    def __init__(self,
                 load_dim=18,
                 use_dim=[0, 1, 2, 3, 4],
                 sweeps_num=3,
                 file_client_args=dict(backend='disk'),
                 max_num=300,
                 pc_range=[-51.2, -51.2, -5.0, 51.2, 51.2, 3.0],
                 compensate_velocity=False,
                 normalize_dims=[(3, 0, 50), (4, -100, 100), (5, -100, 100)],
                 filtering='default',
                 normalize=False,
                 test_mode=False):
        self.load_dim = load_dim
        self.use_dim = use_dim
        self.sweeps_num = sweeps_num
        self.file_client_args = file_client_args.copy()
        self.file_client = None
        self.max_num = max_num
        self.test_mode = test_mode
        self.pc_range = pc_range
        self.compensate_velocity = compensate_velocity
        self.normalize_dims = normalize_dims
        self.filtering = filtering
        self.normalize = normalize

        # (source column, encoding type, number of appended columns)
        self.encoding = [
            (3, 'one-hot', 8),  # dynprop
            (11, 'one-hot', 5),  # ambig_state
            (14, 'one-hot', 18),  # invalid_state
            (15, 'ordinal', 7),  # pdh
            # binary feature: 1 if the point has the states that the
            # 'default' filtering in ``_load_points`` keeps
            (0, 'nusc-filter', 1),
        ]

    def perform_encodings(self, points, encoding):
        """Append encoded versions of categorical columns to ``points``.

        Args:
            points (np.ndarray): [N, C] point matrix.
            encoding (list[tuple] | None): ``(column, type, width)`` triples
                with type in {'one-hot', 'ordinal', 'nusc-filter'}. When
                None, ``self.encoding`` is used.

        Returns:
            np.ndarray: [N, C + sum(width)] matrix; new columns are appended
                in the order of ``encoding``.
        """
        if encoding is None:
            encoding = self.encoding

        for idx, encoding_type, encoding_dims in encoding:
            assert encoding_type in ['one-hot', 'ordinal', 'nusc-filter']

            feat = points[:, idx]
            encoded = np.zeros((points.shape[0], encoding_dims))

            if encoding_type == 'one-hot':
                encoded[np.arange(feat.shape[0]), np.rint(feat).astype(int)] = 1
            elif encoding_type == 'ordinal':
                # Column i is 1 when the rounded value exceeds i.
                for i in range(encoding_dims):
                    encoded[:, i] = (np.rint(feat) > i).astype(int)
            else:  # 'nusc-filter'
                mask1 = (points[:, 14] == 0)
                mask2 = (points[:, 3] < 7)
                mask3 = (points[:, 11] == 3)
                encoded[mask1 & mask2 & mask3, 0] = 1

            points = np.concatenate([points, encoded], axis=1)
        return points

    def _load_points(self, pts_filename):
        """Private function to load point clouds data.

        Args:
            pts_filename (str): Filename of point clouds data.

        Returns:
            np.ndarray: An array containing point clouds data.
                [N, 18]

        Raises:
            KeyError: If ``self.filtering`` is neither 'default' nor 'none'.
        """
        invalid_states, dynprop_states, ambig_states = {
            'default': ([0], range(7), [3]),
            'none': (range(18), range(8), range(5)),
        }[self.filtering]

        radar_obj = RadarPointCloud.from_file(
            pts_filename,
            invalid_states, dynprop_states, ambig_states
        )

        # [18, N]
        points = radar_obj.points

        return points.transpose().astype(np.float32)

    def _pad_or_drop(self, points):
        """Bring ``points`` to exactly ``self.max_num`` rows.

        Too many points are randomly sub-sampled, too few are padded with
        zero rows.

        Args:
            points (np.ndarray): [N, C] point matrix.

        Returns:
            tuple[np.ndarray, np.ndarray]: The [max_num, C] points and a
                [max_num, 1] mask that is 1 for real points and 0 for
                padding.
        """
        num_points = points.shape[0]

        if num_points == self.max_num:
            masks = np.ones((num_points, 1), dtype=points.dtype)
            return points, masks

        if num_points > self.max_num:
            points = np.random.permutation(points)[:self.max_num, :]
            masks = np.ones((self.max_num, 1), dtype=points.dtype)
            return points, masks

        zeros = np.zeros((self.max_num - num_points, points.shape[1]),
                         dtype=points.dtype)
        masks = np.ones((num_points, 1), dtype=points.dtype)
        points = np.concatenate((points, zeros), axis=0)
        masks = np.concatenate((masks, zeros.copy()[:, [0]]), axis=0)
        return points, masks

    def normalize_feats(self, points, normalize_dims):
        """Min-max normalize selected columns of ``points`` in place.

        Args:
            points (np.ndarray): [N, C] point matrix.
            normalize_dims (list[tuple]): ``(column, min, max)`` triples.

        Returns:
            np.ndarray: The same array, modified in place.
        """
        for dim, lower, upper in normalize_dims:
            points[:, dim] -= lower
            points[:, dim] /= (upper - lower)
        return points

    def __call__(self, results):
        """Call function to load multi-sweep radar point clouds from files.

        Args:
            results (dict): Result dict whose ``'radar'`` entry maps each
                sensor to a list of sweep dicts (``data_path``,
                ``timestamp``, ``sensor2lidar_rotation``,
                ``sensor2lidar_translation``).

        Returns:
            dict: The result dict; ``results['radar']`` is replaced by a
                :obj:`RadarPoints` holding the merged multi-sweep points.
        """
        radars_dict = results['radar']

        points_sweep_list = []
        for key, sweeps in radars_dict.items():
            if len(sweeps) == 0:
                # A sensor without any sweep contributes no points.
                continue

            num_used = min(len(sweeps), self.sweeps_num)

            # Time differences are measured against the first sweep.
            ts = sweeps[0]['timestamp'] * 1e-6

            for sweep in sweeps[:num_used]:
                points_sweep = self._load_points(sweep['data_path'])
                points_sweep = np.copy(points_sweep).reshape(-1, self.load_dim)

                timestamp = sweep['timestamp'] * 1e-6
                time_diff = ts - timestamp
                time_diff = np.ones((points_sweep.shape[0], 1)) * time_diff

                # velocity compensated by the ego motion in sensor frame
                velo_comp = points_sweep[:, 8:10]
                velo_comp = np.concatenate(
                    (velo_comp, np.zeros((velo_comp.shape[0], 1))), 1)
                velo_comp = velo_comp @ sweep['sensor2lidar_rotation'].T
                velo_comp = velo_comp[:, :2]

                # velocity in sensor frame
                velo = points_sweep[:, 6:8]
                velo = np.concatenate(
                    (velo, np.zeros((velo.shape[0], 1))), 1)
                velo = velo @ sweep['sensor2lidar_rotation'].T
                velo = velo[:, :2]

                points_sweep[:, :3] = points_sweep[:, :3] @ sweep[
                    'sensor2lidar_rotation'].T
                points_sweep[:, :3] += sweep['sensor2lidar_translation']

                if self.compensate_velocity:
                    points_sweep[:, :2] += velo_comp * time_diff

                points_sweep_ = np.concatenate(
                    [points_sweep[:, :6], velo,
                     velo_comp, points_sweep[:, 10:],
                     time_diff], axis=1)

                # current format is x y z dyn_prop id rcs vx vy vx_comp vy_comp is_quality_valid ambig_state x_rms y_rms invalid_state pdh0 vx_rms vy_rms timestamp
                points_sweep_list.append(points_sweep_)

        if points_sweep_list:
            points = np.concatenate(points_sweep_list, axis=0)
        else:
            # No sweep available at all: keep the usual column layout
            # (load_dim raw channels + time difference) with zero rows.
            points = np.zeros((0, self.load_dim + 1))

        points = self.perform_encodings(points, self.encoding)
        points = points[:, self.use_dim]

        if self.normalize:
            points = self.normalize_feats(points, self.normalize_dims)

        points = RadarPoints(
            points, points_dim=points.shape[-1], attribute_dims=None
        )

        results['radar'] = points
        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        return f'{self.__class__.__name__}(sweeps_num={self.sweeps_num})'

View File

@ -22,6 +22,78 @@ from ..builder import OBJECTSAMPLERS
from .utils import noise_per_object_v3_
@PIPELINES.register_module()
class GTDepth:
    """Rasterize the point cloud into one sparse depth map per camera.

    The result is stored under ``data['depths']`` with one
    ``img.shape[-2] x img.shape[-1]`` map per entry of the first image
    dimension; pixels that receive no point stay 0.

    Args:
        keyframe_only (bool): If True, only points whose column 4 equals 0
            are projected. Defaults to False.
    """

    def __init__(self, keyframe_only=False):
        self.keyframe_only = keyframe_only

    def __call__(self, data):
        """Project the points into every image and write ``data['depths']``.

        Args:
            data (dict): Collected sample; the entries ``img``, ``points``,
                ``img_aug_matrix``, ``lidar_aug_matrix`` and ``lidar2image``
                are read through their ``.data`` attribute.

        Returns:
            dict: ``data`` with the additional key ``depths``.
        """
        img_aug_matrix = data['img_aug_matrix'].data
        bev_aug_matrix = data['lidar_aug_matrix'].data
        lidar2image = data['lidar2image'].data

        points = data['points'].data
        img = data['img'].data

        if self.keyframe_only:
            points = points[points[:, 4] == 0]

        depth = torch.zeros(img.shape[0], *img.shape[-2:])

        # Work on a copy: without boolean masking above, ``points[:, :3]``
        # is a view, and the in-place subtraction below would otherwise
        # overwrite the coordinates stored in ``data['points']``.
        cur_coords = points[:, :3].clone()

        # inverse aug
        cur_coords -= bev_aug_matrix[:3, 3]
        cur_coords = torch.inverse(bev_aug_matrix[:3, :3]).matmul(
            cur_coords.transpose(1, 0)
        )
        # lidar2image
        cur_coords = lidar2image[:, :3, :3].matmul(cur_coords)
        cur_coords += lidar2image[:, :3, 3].reshape(-1, 3, 1)
        # get 2d coords; the depth written to the map is the clamped value
        cur_coords[:, 2, :] = torch.clamp(cur_coords[:, 2, :], 1e-5, 1e5)
        dist = cur_coords[:, 2, :].clone()
        cur_coords[:, :2, :] /= cur_coords[:, 2:3, :]

        # imgaug
        cur_coords = img_aug_matrix[:, :3, :3].matmul(cur_coords)
        cur_coords += img_aug_matrix[:, :3, 3].reshape(-1, 3, 1)
        cur_coords = cur_coords[:, :2, :].transpose(1, 2)

        # swap to (row, col) order so the pair can index the depth map
        cur_coords = cur_coords[..., [1, 0]]

        on_img = (
            (cur_coords[..., 0] < img.shape[2])
            & (cur_coords[..., 0] >= 0)
            & (cur_coords[..., 1] < img.shape[3])
            & (cur_coords[..., 1] >= 0)
        )
        for c in range(on_img.shape[0]):
            masked_coords = cur_coords[c, on_img[c]].long()
            masked_dist = dist[c, on_img[c]]
            depth[c, masked_coords[:, 0], masked_coords[:, 1]] = masked_dist

        data['depths'] = depth
        return data
@PIPELINES.register_module()
class ImageAug3D:
def __init__(
@ -142,6 +214,11 @@ class GlobalRotScaleTrans:
data["points"].translate(translation)
data["points"].scale(scale)
if "radar" in data:
data["radar"].rotate(-theta)
data["radar"].translate(translation)
data["radar"].scale(scale)
gt_boxes = data["gt_bboxes_3d"]
rotation = rotation @ gt_boxes.rotate(theta).numpy()
gt_boxes.translate(translation)
@ -254,6 +331,8 @@ class RandomFlip3D:
rotation = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) @ rotation
if "points" in data:
data["points"].flip("horizontal")
if "radar" in data:
data["radar"].flip("horizontal")
if "gt_bboxes_3d" in data:
data["gt_bboxes_3d"].flip("horizontal")
if "gt_masks_bev" in data:
@ -263,6 +342,8 @@ class RandomFlip3D:
rotation = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) @ rotation
if "points" in data:
data["points"].flip("vertical")
if "radar" in data:
data["radar"].flip("vertical")
if "gt_bboxes_3d" in data:
data["gt_bboxes_3d"].flip("vertical")
if "gt_masks_bev" in data:
@ -522,6 +603,14 @@ class PointsRangeFilter:
points_mask = points.in_range_3d(self.pcd_range)
clean_points = points[points_mask]
data["points"] = clean_points
if "radar" in data:
radar = data["radar"]
# radar_mask = radar.in_range_3d(self.pcd_range)
radar_mask = radar.in_range_bev([-55.0, -55.0, 55.0, 55.0])
clean_radar = radar[radar_mask]
data["radar"] = clean_radar
return data

View File

@ -5,4 +5,5 @@ from .second import *
from .sparse_encoder import *
from .pillar_encoder import *
from .vovnet import *
from .dla import *
from .dla import *
from .radar_encoder import *

View File

@ -0,0 +1,230 @@
from torch import nn
from typing import Any, Dict
from functools import cached_property
import torch
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.cnn.resnet import make_res_layer, BasicBlock
from torch import nn
from torch.nn import functional as F
from mmdet3d.models.builder import build_backbone
from mmdet.models import BACKBONES
from torchvision.utils import save_image
from mmdet3d.ops import feature_decorator
from mmcv.cnn.bricks.non_local import NonLocal2d
from flash_attn.flash_attention import FlashMHA
__all__ = ["RadarFeatureNet", "RadarEncoder"]
def get_paddings_indicator(actual_num, max_num, axis=0):
    """Build a boolean mask that marks the filled slots of a padded tensor.

    Args:
        actual_num (torch.Tensor): Number of valid entries per row.
        max_num (int): Padded length, i.e. number of slots per row.
        axis (int): Axis of ``actual_num`` after which the slot axis is
            inserted. Defaults to 0.

    Returns:
        torch.Tensor: Boolean tensor with an extra axis of length
            ``max_num`` at position ``axis + 1``; an entry is True when its
            slot index is smaller than the corresponding ``actual_num``.
    """
    counts = actual_num.unsqueeze(axis + 1)

    # Shape that broadcasts the slot indices along the new axis only.
    view_shape = [1 for _ in range(counts.dim())]
    view_shape[axis + 1] = -1

    slot_ids = torch.arange(max_num, dtype=torch.int, device=counts.device)
    slot_ids = slot_ids.view(view_shape)

    # e.g. counts [[3], [4], [2]] vs slot_ids [[0, 1, 2, 3, 4]]
    return counts.int() > slot_ids
class RFNLayer(nn.Module):
    def __init__(self, in_channels, out_channels, norm_cfg=None, last_layer=False):
        """
        Radar Feature Net Layer: linear -> norm -> ReLU (-> max over points).

        :param in_channels: <int>. Number of input channels.
        :param out_channels: <int>. Number of output channels.
        :param norm_cfg: <dict>. Norm layer config; BN1d with eps=1e-3 and momentum=0.01 when None.
        :param last_layer: <bool>. If last_layer, the output is max-reduced over dim 1.
        """

        super().__init__()
        self.name = "RFNLayer"
        self.last_vfe = last_layer

        self.units = out_channels

        if norm_cfg is None:
            norm_cfg = dict(type="BN1d", eps=1e-3, momentum=0.01)
        self.norm_cfg = norm_cfg

        self.linear = nn.Linear(in_channels, self.units, bias=False)
        self.norm = build_norm_layer(self.norm_cfg, self.units)[1]

    def forward(self, inputs):
        """
        Apply the layer to a 3-D input whose last dim has ``in_channels`` entries.

        :param inputs: <torch.Tensor>. Input features; channels are in the last dim.
        :return: <torch.Tensor>. Activated features; when ``last_layer`` was set, dim 1 is reduced to
            size 1 by a max.
        """
        x = self.linear(inputs)

        # The norm runs with cuDNN switched off. The previous global setting is restored afterwards
        # (also on error) instead of being forced to True.
        cudnn_was_enabled = torch.backends.cudnn.enabled
        torch.backends.cudnn.enabled = False
        try:
            # The norm layer expects channels in dim 1, hence the permutes.
            x = self.norm(x.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
        finally:
            torch.backends.cudnn.enabled = cudnn_was_enabled

        x = F.relu(x)

        if self.last_vfe:
            x_max = torch.max(x, dim=1, keepdim=True)[0]
            return x_max
        else:
            return x
@BACKBONES.register_module()
class RadarFeatureNet(nn.Module):
    def __init__(
        self,
        in_channels=4,
        feat_channels=(64,),
        with_distance=False,
        voxel_size=(0.2, 0.2, 4),
        point_cloud_range=(0, -40, -3, 70.4, 40, 1),
        norm_cfg=None,
    ):
        """
        Radar Feature Net.
        Decorates the points of every pillar with their x/y offset to the pillar center and runs them
        through a stack of RFNLayers; the last layer max-pools over the points of a pillar.

        :param in_channels: <int>. Number of input features per point.
        :param feat_channels: (<int>: N). Number of features in each of the N RFNLayers.
        :param with_distance: <bool>. Stored on the instance; not used by the code of this class.
        :param voxel_size: (<float>: 3). Size of voxels, only utilize x and y size.
        :param point_cloud_range: (<float>: 6). Point cloud range (x_min, y_min, z_min, x_max, y_max, z_max).
        :param norm_cfg: <dict>. Norm config forwarded to every RFNLayer.
        """

        super().__init__()
        self.name = "RadarFeatureNet"
        assert len(feat_channels) > 0

        self.in_channels = in_channels
        # Two extra channels for the x/y offset to the pillar center.
        in_channels += 2
        self._with_distance = with_distance
        self.export_onnx = False

        # Create the RFNLayers; only the final one reduces over the points.
        feat_channels = [in_channels] + list(feat_channels)
        num_layers = len(feat_channels) - 1
        rfn_layers = []
        for i in range(num_layers):
            rfn_layers.append(
                RFNLayer(
                    feat_channels[i],
                    feat_channels[i + 1],
                    norm_cfg=norm_cfg,
                    last_layer=(i == num_layers - 1),
                )
            )
        self.rfn_layers = nn.ModuleList(rfn_layers)

        # Need pillar (voxel) size and x/y offset in order to calculate pillar offset
        self.vx = voxel_size[0]
        self.vy = voxel_size[1]
        self.x_offset = self.vx / 2 + point_cloud_range[0]
        self.y_offset = self.vy / 2 + point_cloud_range[1]
        self.pc_range = point_cloud_range

    def forward(self, features, num_voxels, coors):
        """
        Encode the points of every pillar into a single feature vector.

        :param features: <torch.Tensor>. Padded points per pillar; indexed as [pillar, point, channel] with
            x, y, z in channels 0-2. NOTE: in the non-ONNX path x, y, z are normalized in place, i.e. the
            caller's tensor is modified.
        :param num_voxels: <torch.Tensor>. Number of valid points in each pillar.
        :param coors: <torch.Tensor>. Pillar coordinates; columns 1 and 2 are used as x and y index.
        :return: <torch.Tensor>. Pillar features with the (size-1) point dim removed.
        """
        if not self.export_onnx:
            dtype = features.dtype

            # Offset of every point to the center of its pillar (uses the raw, un-normalized x/y).
            f_center = torch.zeros_like(features[:, :, :2])
            f_center[:, :, 0] = features[:, :, 0] - (
                coors[:, 1].to(dtype).unsqueeze(1) * self.vx + self.x_offset
            )
            f_center[:, :, 1] = features[:, :, 1] - (
                coors[:, 2].to(dtype).unsqueeze(1) * self.vy + self.y_offset
            )

            # normalize x,y,z to [0, 1]
            features[:, :, 0:1] = (features[:, :, 0:1] - self.pc_range[0]) / (self.pc_range[3] - self.pc_range[0])
            features[:, :, 1:2] = (features[:, :, 1:2] - self.pc_range[1]) / (self.pc_range[4] - self.pc_range[1])
            features[:, :, 2:3] = (features[:, :, 2:3] - self.pc_range[2]) / (self.pc_range[5] - self.pc_range[2])

            # Combine together feature decorations
            features_ls = [features, f_center]
            features = torch.cat(features_ls, dim=-1)

            # The feature decorations were calculated without regard to whether pillar was empty. Need to ensure that
            # empty pillars remain set to zeros.
            voxel_count = features.shape[1]
            mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
            mask = torch.unsqueeze(mask, -1).type_as(features)
            features *= mask
            features = torch.nan_to_num(features)
        else:
            features = feature_decorator(features, num_voxels, coors, self.vx, self.vy, self.x_offset, self.y_offset, True, False, True)

        # Forward pass through RFNLayers
        for rfn in self.rfn_layers:
            features = rfn(features)

        # Drop only the point dim that the last layer reduced to 1; a bare squeeze() would also remove
        # the pillar dim when a single pillar is present.
        return features.squeeze(1)
@BACKBONES.register_module()
class RadarEncoder(nn.Module):
    """Radar branch: voxel encoder -> (transformer) -> scatter -> (post/BEV).

    ``pts_voxel_encoder`` and ``pts_middle_encoder`` are mandatory configs;
    the remaining stages are optional and skipped when their config is
    ``None``. Every stage is built through ``build_backbone``.
    """

    def __init__(
        self,
        pts_voxel_encoder: Dict[str, Any],
        pts_middle_encoder: Dict[str, Any],
        pts_transformer_encoder=None,
        pts_bev_encoder=None,
        post_scatter=None,
        **kwargs,
    ):
        super().__init__()

        def build_optional(cfg):
            # Optional stages are represented by None when not configured.
            return None if cfg is None else build_backbone(cfg)

        self.pts_voxel_encoder = build_backbone(pts_voxel_encoder)
        self.pts_middle_encoder = build_backbone(pts_middle_encoder)
        self.pts_transformer_encoder = build_optional(pts_transformer_encoder)
        self.pts_bev_encoder = build_optional(pts_bev_encoder)
        self.post_scatter = build_optional(post_scatter)

    def forward(self, feats, coords, batch_size, sizes, img_features=None):
        """Run the configured stages in order and return the final feature.

        ``img_features`` is only consumed by the optional ``post_scatter``
        stage.
        """
        out = self.pts_voxel_encoder(feats, sizes, coords)

        transformer = self.pts_transformer_encoder
        if transformer is not None:
            out = transformer(out, sizes, coords, batch_size)

        out = self.pts_middle_encoder(out, coords, batch_size)

        post = self.post_scatter
        if post is not None:
            out = post(out, img_features)

        bev = self.pts_bev_encoder
        if bev is not None:
            out = bev(out)
        return out

    def visualize_pillars(self, feats, coords, sizes):
        """Debug helper: dump per-pillar point counts to ``sample_canvas``.

        The counts in ``sizes`` are written into a flat 128 x 128 canvas at
        the position given by columns 1 and 2 of ``coords`` and the canvas is
        saved with ``torch.save``. ``feats`` is not used.
        """
        nx, ny = 128, 128
        canvas = sizes.new_zeros(nx * ny)
        flat_index = (coords[:, 1] * ny + coords[:, 2]).to(torch.long)
        canvas[flat_index] = sizes
        torch.save(canvas, 'sample_canvas')

View File

@ -55,6 +55,19 @@ class BEVFusion(Base3DFusionModel):
)
self.voxelize_reduce = encoders["lidar"].get("voxelize_reduce", True)
if encoders.get("radar") is not None:
if encoders["radar"]["voxelize"].get("max_num_points", -1) > 0:
voxelize_module = Voxelization(**encoders["radar"]["voxelize"])
else:
voxelize_module = DynamicScatter(**encoders["radar"]["voxelize"])
self.encoders["radar"] = nn.ModuleDict(
{
"voxelize": voxelize_module,
"backbone": build_backbone(encoders["radar"]["backbone"]),
}
)
self.voxelize_reduce = encoders["radar"].get("voxelize_reduce", True)
if fuser is not None:
self.fuser = build_fuser(fuser)
else:
@ -79,6 +92,10 @@ class BEVFusion(Base3DFusionModel):
if heads[name] is not None:
self.loss_scale[name] = 1.0
# If the camera's vtransform is a BEVDepth version, then we're using depth loss.
self.use_depth_loss = ((encoders.get('camera', {}) or {}).get('vtransform', {}) or {}).get('type', '') in ['BEVDepth', 'AwareBEVDepth', 'DBEVDepth', 'AwareDBEVDepth']
self.init_weights()
def init_weights(self) -> None:
@ -89,6 +106,7 @@ class BEVFusion(Base3DFusionModel):
self,
x,
points,
radar_points,
camera2ego,
lidar2ego,
lidar2camera,
@ -98,6 +116,7 @@ class BEVFusion(Base3DFusionModel):
img_aug_matrix,
lidar_aug_matrix,
img_metas,
gt_depths=None,
) -> torch.Tensor:
B, N, C, H, W = x.size()
x = x.view(B * N, C, H, W)
@ -114,6 +133,7 @@ class BEVFusion(Base3DFusionModel):
x = self.encoders["camera"]["vtransform"](
x,
points,
radar_points,
camera2ego,
lidar2ego,
lidar2camera,
@ -123,21 +143,35 @@ class BEVFusion(Base3DFusionModel):
img_aug_matrix,
lidar_aug_matrix,
img_metas,
depth_loss=self.use_depth_loss,
gt_depths=gt_depths,
)
return x
def extract_lidar_features(self, x) -> torch.Tensor:
feats, coords, sizes = self.voxelize(x)
def extract_features(self, x, sensor) -> torch.Tensor:
feats, coords, sizes = self.voxelize(x, sensor)
batch_size = coords[-1, 0] + 1
x = self.encoders["lidar"]["backbone"](feats, coords, batch_size, sizes=sizes)
x = self.encoders[sensor]["backbone"](feats, coords, batch_size, sizes=sizes)
return x
# def extract_lidar_features(self, x) -> torch.Tensor:
# feats, coords, sizes = self.voxelize(x)
# batch_size = coords[-1, 0] + 1
# x = self.encoders["lidar"]["backbone"](feats, coords, batch_size, sizes=sizes)
# return x
# def extract_radar_features(self, x) -> torch.Tensor:
# feats, coords, sizes = self.radar_voxelize(x)
# batch_size = coords[-1, 0] + 1
# x = self.encoders["radar"]["backbone"](feats, coords, batch_size, sizes=sizes)
# return x
@torch.no_grad()
@force_fp32()
def voxelize(self, points):
def voxelize(self, points, sensor):
feats, coords, sizes = [], [], []
for k, res in enumerate(points):
ret = self.encoders["lidar"]["voxelize"](res)
ret = self.encoders[sensor]["voxelize"](res)
if len(ret) == 3:
# hard voxelize
f, c, n = ret
@ -162,6 +196,36 @@ class BEVFusion(Base3DFusionModel):
return feats, coords, sizes
# @torch.no_grad()
# @force_fp32()
# def radar_voxelize(self, points):
# feats, coords, sizes = [], [], []
# for k, res in enumerate(points):
# ret = self.encoders["radar"]["voxelize"](res)
# if len(ret) == 3:
# # hard voxelize
# f, c, n = ret
# else:
# assert len(ret) == 2
# f, c = ret
# n = None
# feats.append(f)
# coords.append(F.pad(c, (1, 0), mode="constant", value=k))
# if n is not None:
# sizes.append(n)
# feats = torch.cat(feats, dim=0)
# coords = torch.cat(coords, dim=0)
# if len(sizes) > 0:
# sizes = torch.cat(sizes, dim=0)
# if self.voxelize_reduce:
# feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
# -1, 1
# )
# feats = feats.contiguous()
# return feats, coords, sizes
@auto_fp16(apply_to=("img", "points"))
def forward(
self,
@ -176,6 +240,8 @@ class BEVFusion(Base3DFusionModel):
img_aug_matrix,
lidar_aug_matrix,
metas,
depths,
radar=None,
gt_masks_bev=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
@ -196,6 +262,8 @@ class BEVFusion(Base3DFusionModel):
img_aug_matrix,
lidar_aug_matrix,
metas,
depths,
radar,
gt_masks_bev,
gt_bboxes_3d,
gt_labels_3d,
@ -217,12 +285,15 @@ class BEVFusion(Base3DFusionModel):
img_aug_matrix,
lidar_aug_matrix,
metas,
depths=None,
radar=None,
gt_masks_bev=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
**kwargs,
):
features = []
auxiliary_losses = {}
for sensor in (
self.encoders if self.training else list(self.encoders.keys())[::-1]
):
@ -230,6 +301,7 @@ class BEVFusion(Base3DFusionModel):
feature = self.extract_camera_features(
img,
points,
radar,
camera2ego,
lidar2ego,
lidar2camera,
@ -239,11 +311,17 @@ class BEVFusion(Base3DFusionModel):
img_aug_matrix,
lidar_aug_matrix,
metas,
gt_depths=depths,
)
if self.use_depth_loss:
feature, auxiliary_losses['depth'] = feature[0], feature[-1]
elif sensor == "lidar":
feature = self.extract_lidar_features(points)
feature = self.extract_features(points, sensor)
elif sensor == "radar":
feature = self.extract_features(radar, sensor)
else:
raise ValueError(f"unsupported sensor: {sensor}")
features.append(feature)
if not self.training:
@ -276,6 +354,11 @@ class BEVFusion(Base3DFusionModel):
outputs[f"loss/{type}/{name}"] = val * self.loss_scale[type]
else:
outputs[f"stats/{type}/{name}"] = val
if self.use_depth_loss:
if 'depth' in auxiliary_losses:
outputs["loss/depth"] = auxiliary_losses['depth']
else:
raise ValueError('Use depth loss is true, but depth loss not found')
return outputs
else:
outputs = [{} for _ in range(batch_size)]
@ -303,3 +386,4 @@ class BEVFusion(Base3DFusionModel):
else:
raise ValueError(f"unsupported head: {type}")
return outputs

View File

@ -1,2 +1,3 @@
from .lss import *
from .depth_lss import *
from .aware_bevdepth import *

View File

@ -0,0 +1,698 @@
from typing import Tuple
from mmcv.cnn import build_conv_layer
from mmcv.runner import force_fp32
from torch import nn
import torch.nn.functional as F
from torch.cuda.amp.autocast_mode import autocast
from mmdet3d.models.builder import VTRANSFORMS
from mmdet.models.backbones.resnet import BasicBlock
from .base import BaseTransform, BaseDepthTransform
import torch
# Both view transforms defined in this module are part of its public API.
__all__ = ["AwareBEVDepth", "AwareDBEVDepth"]
class DepthRefinement(nn.Module):
    """Residual stack of 3x3 convolutions used to refine a feature map.

    Layout: a conv-BN-ReLU reduction to ``mid_channels``, two further
    conv-BN-ReLU stages wrapped in a skip connection, and a final biased
    3x3 convolution to ``out_channels``. Spatial size is preserved.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super(DepthRefinement, self).__init__()

        def conv_bn_relu(c_in, c_out):
            # One 3x3 / stride 1 / pad 1 stage without conv bias.
            return [
                nn.Conv2d(
                    c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False
                ),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        self.reduce_conv = nn.Sequential(*conv_bn_relu(in_channels, mid_channels))
        self.conv = nn.Sequential(
            *conv_bn_relu(mid_channels, mid_channels),
            *conv_bn_relu(mid_channels, mid_channels),
        )
        # Output projection: biased conv, no normalisation or activation.
        self.out_conv = nn.Sequential(
            nn.Conv2d(
                mid_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            ),
        )

    @autocast(False)
    def forward(self, x):
        """Apply reduction, residual refinement and the output projection."""
        reduced = self.reduce_conv(x)
        refined = self.conv(reduced) + reduced
        return self.out_conv(refined)
class _ASPPModule(nn.Module):
    """One ASPP branch: dilated convolution -> normalisation -> ReLU."""

    def __init__(self, inplanes, planes, kernel_size, padding, dilation,
                 BatchNorm):
        super(_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()
        self._init_weight()

    def forward(self, x):
        """Return ``relu(bn(atrous_conv(x)))``."""
        return self.relu(self.bn(self.atrous_conv(x)))

    def _init_weight(self):
        """Kaiming-normal init for convs; unit weight / zero bias for BN2d."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
            elif isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
class ASPP(nn.Module):
    """Atrous spatial pyramid pooling.

    Four ``_ASPPModule`` branches (a 1x1 branch plus 3x3 branches with
    dilations 6, 12 and 18) and a global-average-pool branch are
    concatenated along channels, fused with a 1x1 convolution, normalised,
    activated and passed through dropout (p=0.5).
    """

    def __init__(self, inplanes, mid_channels=256, BatchNorm=nn.BatchNorm2d):
        super(ASPP, self).__init__()
        dilations = [1, 6, 12, 18]
        kernel_sizes = [1, 3, 3, 3]

        # Registered as aspp1..aspp4. The 1x1 branch needs no padding; the
        # 3x3 branches pad by their dilation to keep the spatial size.
        for index, (ksize, rate) in enumerate(zip(kernel_sizes, dilations), 1):
            branch = _ASPPModule(
                inplanes,
                mid_channels,
                ksize,
                padding=0 if ksize == 1 else rate,
                dilation=rate,
                BatchNorm=BatchNorm,
            )
            setattr(self, f"aspp{index}", branch)

        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, mid_channels, 1, stride=1, bias=False),
            BatchNorm(mid_channels),
            nn.ReLU(),
        )
        # Fuses the five concatenated branches back to mid_channels.
        self.conv1 = nn.Conv2d(
            int(mid_channels * 5), mid_channels, 1, bias=False
        )
        self.bn1 = BatchNorm(mid_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        """Run all branches on ``x`` and return the fused feature map."""
        branches = [
            self.aspp1(x),
            self.aspp2(x),
            self.aspp3(x),
            self.aspp4(x),
        ]
        # Broadcast the pooled 1x1 descriptor back to the branch resolution.
        pooled = F.interpolate(
            self.global_avg_pool(x),
            size=branches[-1].size()[2:],
            mode='bilinear',
            align_corners=True,
        )
        fused = torch.cat((*branches, pooled), dim=1)
        fused = self.relu(self.bn1(self.conv1(fused)))
        return self.dropout(fused)

    def _init_weight(self):
        """Kaiming-normal init for convs; unit weight / zero bias for BN2d."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
            elif isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are left unset (or are otherwise falsy).
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.ReLU,
                 drop=0.0):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """Apply both linear layers with activation and dropout in between."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))
class SELayer(nn.Module):
    """Squeeze-and-excitation style gate driven by an external descriptor.

    The descriptor ``x_se`` goes through 1x1 conv -> activation -> 1x1 conv
    -> gate, and the result multiplies ``x``.
    """

    def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid):
        super().__init__()
        self.conv_reduce = nn.Conv2d(channels, channels, 1, bias=True)
        self.act1 = act_layer()
        self.conv_expand = nn.Conv2d(channels, channels, 1, bias=True)
        self.gate = gate_layer()

    def forward(self, x, x_se):
        """Return ``x`` scaled by the gate computed from ``x_se``."""
        excitation = self.conv_expand(self.act1(self.conv_reduce(x_se)))
        return x * self.gate(excitation)
class DepthNet(nn.Module):
    """Camera-aware depth / context head.

    Image features are reduced with a 3x3 conv and then split into two
    branches, each gated by an ``SELayer`` whose descriptor comes from an MLP
    over a 27-dimensional vector of camera parameters (intrinsics, image and
    BEV augmentation entries, and the flattened top three rows of
    sensor2ego). The depth branch outputs ``depth_channels`` logits and the
    context branch ``context_channels`` features; ``forward`` returns them
    concatenated along the channel axis, depth first.
    """

    def __init__(self, in_channels, mid_channels, context_channels,
                 depth_channels):
        super(DepthNet, self).__init__()
        self.reduce_conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      mid_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )
        self.context_conv = nn.Conv2d(mid_channels,
                                      context_channels,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0)
        # 27 = 15 hand-picked matrix entries + 12 flattened sensor2ego values
        # (see forward()).
        self.bn = nn.BatchNorm1d(27)
        self.depth_mlp = Mlp(27, mid_channels, mid_channels)
        self.depth_se = SELayer(mid_channels)  # NOTE: add camera-aware
        self.context_mlp = Mlp(27, mid_channels, mid_channels)
        self.context_se = SELayer(mid_channels)  # NOTE: add camera-aware
        self.depth_conv_1 = nn.Sequential(
            BasicBlock(mid_channels, mid_channels),
            BasicBlock(mid_channels, mid_channels),
            BasicBlock(mid_channels, mid_channels),
        )
        self.depth_conv_2 = nn.Sequential(
            ASPP(mid_channels, mid_channels),
            build_conv_layer(cfg=dict(
                type='Conv2d',
                in_channels=mid_channels,
                out_channels=mid_channels,
                kernel_size=3,
                padding=1,
            )),
            nn.BatchNorm2d(mid_channels),
        )
        self.depth_conv_3 = nn.Sequential(
            nn.Conv2d(mid_channels,
                      depth_channels,
                      kernel_size=1,
                      stride=1,
                      padding=0),
            nn.BatchNorm2d(depth_channels),
        )
        self.export = False

    def export_mode(self):
        """Switch to export mode, where the camera descriptor is cached.

        Any descriptor left on the module is discarded so that the first
        export-mode forward recomputes it from the matrices it is given.
        """
        self.export = True
        if hasattr(self, 'mlp_input'):
            del self.mlp_input

    @force_fp32()
    def forward(self, x, mats_dict):
        """Predict depth logits and context features.

        Args:
            x: Image features, one entry per (batch, camera) pair on dim 0.
            mats_dict: Dict with keys ``'intrin_mats'``, ``'ida_mats'``,
                ``'sensor2ego_mats'`` and ``'bda_mat'``.

        Returns:
            Tensor with the depth logits followed by the context features
            along dim 1.
        """
        # If exporting, reuse the cached MLP input: it only depends on the
        # intrinsics and data augmentation, which are constant at inference
        # time. Outside export mode it is recomputed on every call and kept
        # local, so neither a stale value nor the previous iteration's
        # autograd graph stays attached to the module.
        if self.export and hasattr(self, 'mlp_input'):
            mlp_input = self.mlp_input
        else:
            intrins = mats_dict['intrin_mats'][:, ..., :3, :3]
            batch_size = intrins.shape[0]
            num_cams = intrins.shape[1]
            ida = mats_dict['ida_mats'][:, ...]
            sensor2ego = mats_dict['sensor2ego_mats'][:, ..., :3, :]
            bda = mats_dict['bda_mat'].view(batch_size, 1, 4, 4).repeat(
                1, num_cams, 1, 1)
            mlp_input = torch.cat(
                [
                    torch.stack(
                        [
                            intrins[:, ..., 0, 0],
                            intrins[:, ..., 1, 1],
                            intrins[:, ..., 0, 2],
                            intrins[:, ..., 1, 2],
                            ida[:, ..., 0, 0],
                            ida[:, ..., 0, 1],
                            ida[:, ..., 0, 3],
                            ida[:, ..., 1, 0],
                            ida[:, ..., 1, 1],
                            ida[:, ..., 1, 3],
                            bda[:, ..., 0, 0],
                            bda[:, ..., 0, 1],
                            bda[:, ..., 1, 0],
                            bda[:, ..., 1, 1],
                            bda[:, ..., 2, 2],
                        ],
                        dim=-1,
                    ),
                    sensor2ego.view(batch_size, num_cams, -1),
                ],
                -1,
            )
            mlp_input = self.bn(mlp_input.reshape(-1, mlp_input.shape[-1]))
            if self.export:
                self.mlp_input = mlp_input

        x = self.reduce_conv(x)

        # Context branch, gated by the camera descriptor.
        context_se = self.context_mlp(mlp_input)[..., None, None]
        context = self.context_se(x, context_se)
        context = self.context_conv(context)

        # Depth branch, gated by the camera descriptor.
        depth_se = self.depth_mlp(mlp_input)[..., None, None]
        depth = self.depth_se(x, depth_se)
        depth = self.depth_conv_1(depth)
        depth = self.depth_conv_2(depth)
        depth = self.depth_conv_3(depth)
        return torch.cat([depth, context], dim=1)
@VTRANSFORMS.register_module()
class AwareBEVDepth(BaseTransform):
    """Camera-to-BEV view transform with a camera-aware ``DepthNet`` head.

    ``get_cam_feats`` predicts a per-pixel depth distribution and context
    features, lifts them into a frustum volume and optionally refines that
    volume with ``DepthRefinement``. ``forward`` can additionally return a
    depth loss computed against ground-truth depth maps.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        image_size: Tuple[int, int],
        feature_size: Tuple[int, int],
        xbound: Tuple[float, float, float],
        ybound: Tuple[float, float, float],
        zbound: Tuple[float, float, float],
        dbound: Tuple[float, float, float],
        use_points = 'lidar',
        downsample: int = 1,
        bevdepth_downsample: int = 16,
        bevdepth_refine: bool = True,
        depth_loss_factor: float = 3.0,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            image_size=image_size,
            feature_size=feature_size,
            xbound=xbound,
            ybound=ybound,
            zbound=zbound,
            dbound=dbound,
            use_points=use_points,
        )
        self.depth_loss_factor = depth_loss_factor
        self.downsample_factor = bevdepth_downsample
        self.bevdepth_refine = bevdepth_refine
        if self.bevdepth_refine:
            self.refinement = DepthRefinement(self.C, self.C, self.C)

        self.depth_channels = self.frustum.shape[0]

        # DepthNet keeps the input width for its hidden channels.
        mid_channels = in_channels
        self.depthnet = DepthNet(in_channels, mid_channels, self.C, self.D)

        # Optional stride-2 reduction applied to the pooled BEV feature.
        if downsample <= 1:
            self.downsample = nn.Identity()
        else:
            assert downsample == 2, downsample
            self.downsample = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    3,
                    stride=downsample,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            )

    def export_mode(self):
        """Enable export mode on the base transform and on the depth net."""
        super().export_mode()
        self.depthnet.export_mode()

    @force_fp32()
    def get_cam_feats(self, x, mats_dict):
        """Lift image features into a frustum volume.

        Returns:
            ``(volume, depth)``: the volume viewed as
            ``[B, N, D, fH, fW, C]`` and the softmaxed depth distribution
            over the first ``self.D`` output channels of the depth net.
        """
        B, N, C, fH, fW = x.shape

        net_out = self.depthnet(x.view(B * N, C, fH, fW), mats_dict)
        depth = net_out[:, : self.D].softmax(dim=1)
        context = net_out[:, self.D : (self.D + self.C)]

        # Outer product of depth probabilities and context features.
        volume = depth.unsqueeze(1) * context.unsqueeze(2)

        if self.bevdepth_refine:
            # [n, c, d, h, w] -> [n, h, c, w, d]
            volume = volume.permute(0, 3, 1, 4, 2).contiguous()
            n, h, c, w, d = volume.shape
            volume = self.refinement(volume.view(-1, c, w, d))
            volume = (
                volume.view(n, h, c, w, d)
                .permute(0, 2, 4, 1, 3)
                .contiguous()
                .float()
            )

        volume = volume.view(B, N, self.C, self.D, fH, fW)
        return volume.permute(0, 1, 3, 4, 5, 2), depth

    def get_depth_loss(self, depth_labels, depth_preds):
        """Binary cross-entropy between predicted and one-hot GT depth bins.

        Only positions whose label has at least one active bin contribute;
        the summed loss is divided by the number of such positions (at least
        1) and scaled by ``depth_loss_factor``.
        """
        if len(depth_labels.shape) == 5:
            # only key-frame will calculate depth loss
            depth_labels = depth_labels[:, 0, ...]
        targets = self.get_downsampled_gt_depth(depth_labels)
        preds = depth_preds.permute(0, 2, 3, 1).contiguous().view(
            -1, self.depth_channels)
        fg_mask = torch.max(targets, dim=1).values > 0.0

        with autocast(enabled=False):
            depth_loss = (F.binary_cross_entropy(
                preds[fg_mask],
                targets[fg_mask],
                reduction='none',
            ).sum() / max(1.0, fg_mask.sum()))

        return self.depth_loss_factor * depth_loss

    def get_downsampled_gt_depth(self, gt_depths):
        """
        Input:
            gt_depths: [B, N, H, W]
        Output:
            gt_depths: [B*N*h*w, d]

        Each ``downsample_factor`` x ``downsample_factor`` patch is reduced
        to its smallest non-zero depth, converted to a bin index using
        ``dbound`` and one-hot encoded. Patches without a valid depth, or
        whose depth falls outside the bins, yield an all-zero row.
        """
        B, N, H, W = gt_depths.shape
        factor = self.downsample_factor
        out_h, out_w = H // factor, W // factor

        # Gather every patch into the last dimension.
        patches = gt_depths.view(B * N, out_h, factor, out_w, factor, 1)
        patches = patches.permute(0, 1, 3, 5, 2, 4).contiguous()
        patches = patches.view(-1, factor * factor)

        # Zeros mark missing depth: replace them by a large value so the min
        # picks a real measurement whenever one exists.
        filled = torch.where(patches == 0.0,
                             1e5 * torch.ones_like(patches),
                             patches)
        min_depth = torch.min(filled, dim=-1).values
        min_depth = min_depth.view(B * N, out_h, out_w)

        # Depth -> bin index, shifted by one so that index 0 means "no bin".
        bins = (min_depth -
                (self.dbound[0] - self.dbound[2])) / self.dbound[2]
        bins = torch.where(
            (bins < self.depth_channels + 1) & (bins >= 0.0),
            bins, torch.zeros_like(bins))

        one_hot = F.one_hot(bins.long(),
                            num_classes=self.depth_channels + 1).view(
                                -1, self.depth_channels + 1)[:, 1:]
        return one_hot.float()

    def forward(self, *args, **kwargs):
        """Run the base transform, then the optional downsample block.

        Returns ``(x, depth_loss)`` when the ``depth_loss`` keyword is truthy
        (``gt_depths`` must then be supplied as a keyword), otherwise ``x``.
        """
        outputs = super().forward(*args, **kwargs)
        bev, depth_pred = outputs[0], outputs[-1]
        bev = self.downsample(bev)

        if not kwargs.get('depth_loss', False):
            return bev

        depth_loss = self.get_depth_loss(kwargs['gt_depths'], depth_pred)
        return bev, depth_loss
@VTRANSFORMS.register_module()
class AwareDBEVDepth(BaseDepthTransform):
    """Depth-input variant of the camera-aware BEVDepth view transform.

    A projected depth map ``d`` is encoded by ``dtransform`` into 64
    channels, concatenated with the image features and fed to ``DepthNet``.
    The rest (frustum lifting, optional refinement, optional depth loss)
    follows the same steps as the image-only variant.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        image_size: Tuple[int, int],
        feature_size: Tuple[int, int],
        xbound: Tuple[float, float, float],
        ybound: Tuple[float, float, float],
        zbound: Tuple[float, float, float],
        dbound: Tuple[float, float, float],
        use_points = 'lidar',
        depth_input = 'scalar',
        height_expand = False,
        downsample: int = 1,
        bevdepth_downsample: int = 16,
        bevdepth_refine: bool = True,
        depth_loss_factor: float = 3.0,
        add_depth_features = False,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            image_size=image_size,
            feature_size=feature_size,
            xbound=xbound,
            ybound=ybound,
            zbound=zbound,
            dbound=dbound,
            use_points=use_points,
            depth_input=depth_input,
            height_expand=height_expand,
            add_depth_features=add_depth_features,
        )
        self.depth_loss_factor = depth_loss_factor
        self.downsample_factor = bevdepth_downsample
        self.bevdepth_refine = bevdepth_refine
        if self.bevdepth_refine:
            self.refinement = DepthRefinement(self.C, self.C, self.C)

        self.depth_channels = self.frustum.shape[0]

        # The depth net sees the image features plus the 64 channels that
        # dtransform produces (they are concatenated in get_cam_feats).
        mid_channels = in_channels
        self.depthnet = DepthNet(
            in_channels + 64,
            mid_channels,
            self.C,
            self.D,
        )

        is_scalar = depth_input == 'scalar'
        dtransform_in_channels = 1 if is_scalar else self.D
        if self.add_depth_features:
            dtransform_in_channels += 45

        # The two input modes differ only in the width of the 1x1 stem.
        stem_channels = 8 if is_scalar else 32
        self.dtransform = nn.Sequential(
            nn.Conv2d(dtransform_in_channels, stem_channels, 1),
            nn.BatchNorm2d(stem_channels),
            nn.ReLU(True),
            nn.Conv2d(stem_channels, 32, 5, stride=4, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 64, 5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )

        # Optional stride-2 reduction applied to the pooled BEV feature.
        if downsample <= 1:
            self.downsample = nn.Identity()
        else:
            assert downsample == 2, downsample
            self.downsample = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    3,
                    stride=downsample,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            )

    @force_fp32()
    def get_cam_feats(self, x, d, mats_dict):
        """Lift image + depth-map features into a frustum volume.

        Returns:
            ``(volume, depth)``: the volume viewed as
            ``[B, N, D, fH, fW, C]`` and the softmaxed depth distribution
            over the first ``self.D`` output channels of the depth net.
        """
        B, N, C, fH, fW = x.shape

        depth_feats = self.dtransform(d.view(B * N, *d.shape[2:]))
        image_feats = x.view(B * N, C, fH, fW)
        net_out = self.depthnet(
            torch.cat([depth_feats, image_feats], dim=1), mats_dict
        )

        depth = net_out[:, : self.D].softmax(dim=1)
        context = net_out[:, self.D : (self.D + self.C)]

        # Outer product of depth probabilities and context features.
        volume = depth.unsqueeze(1) * context.unsqueeze(2)

        if self.bevdepth_refine:
            # [n, c, d, h, w] -> [n, h, c, w, d]
            volume = volume.permute(0, 3, 1, 4, 2).contiguous()
            n, h, c, w, bins = volume.shape
            volume = self.refinement(volume.view(-1, c, w, bins))
            volume = (
                volume.view(n, h, c, w, bins)
                .permute(0, 2, 4, 1, 3)
                .contiguous()
                .float()
            )

        volume = volume.view(B, N, self.C, self.D, fH, fW)
        return volume.permute(0, 1, 3, 4, 5, 2), depth

    def export_mode(self):
        """Enable export mode on the base transform and on the depth net."""
        super().export_mode()
        self.depthnet.export_mode()

    def get_depth_loss(self, depth_labels, depth_preds):
        """Binary cross-entropy between predicted and one-hot GT depth bins.

        Only positions whose label has at least one active bin contribute;
        the summed loss is divided by the number of such positions (at least
        1) and scaled by ``depth_loss_factor``.
        """
        targets = self.get_downsampled_gt_depth(depth_labels)
        preds = depth_preds.permute(0, 2, 3, 1).contiguous().view(
            -1, self.depth_channels)
        fg_mask = torch.max(targets, dim=1).values > 0.0

        with autocast(enabled=False):
            depth_loss = (F.binary_cross_entropy(
                preds[fg_mask],
                targets[fg_mask],
                reduction='none',
            ).sum() / max(1.0, fg_mask.sum()))

        return self.depth_loss_factor * depth_loss

    def get_downsampled_gt_depth(self, gt_depths):
        """
        Input:
            gt_depths: [B, N, H, W]
        Output:
            gt_depths: [B*N*h*w, d]

        Each ``downsample_factor`` x ``downsample_factor`` patch is reduced
        to its smallest non-zero depth, converted to a bin index using
        ``dbound`` and one-hot encoded. Patches without a valid depth, or
        whose depth falls outside the bins, yield an all-zero row.
        """
        B, N, H, W = gt_depths.shape
        factor = self.downsample_factor
        out_h, out_w = H // factor, W // factor

        # Gather every patch into the last dimension.
        patches = gt_depths.view(B * N, out_h, factor, out_w, factor, 1)
        patches = patches.permute(0, 1, 3, 5, 2, 4).contiguous()
        patches = patches.view(-1, factor * factor)

        # Zeros mark missing depth: replace them by a large value so the min
        # picks a real measurement whenever one exists.
        filled = torch.where(patches == 0.0,
                             1e5 * torch.ones_like(patches),
                             patches)
        min_depth = torch.min(filled, dim=-1).values
        min_depth = min_depth.view(B * N, out_h, out_w)

        # Depth -> bin index, shifted by one so that index 0 means "no bin".
        bins = (min_depth -
                (self.dbound[0] - self.dbound[2])) / self.dbound[2]
        bins = torch.where(
            (bins < self.depth_channels + 1) & (bins >= 0.0),
            bins, torch.zeros_like(bins))

        one_hot = F.one_hot(bins.long(),
                            num_classes=self.depth_channels + 1).view(
                                -1, self.depth_channels + 1)[:, 1:]
        return one_hot.float()

    def forward(self, *args, **kwargs):
        """Run the base transform, then the optional downsample block.

        Returns ``(x, depth_loss)`` when the ``depth_loss`` keyword is truthy
        (``gt_depths`` must then be supplied as a keyword), otherwise ``x``.
        """
        outputs = super().forward(*args, **kwargs)
        bev, depth_pred = outputs[0], outputs[-1]
        bev = self.downsample(bev)

        if not kwargs.get('depth_loss', False):
            return bev

        depth_loss = self.get_depth_loss(kwargs['gt_depths'], depth_pred)
        return bev, depth_loss

View File

@ -8,6 +8,9 @@ from mmdet3d.ops import bev_pool
__all__ = ["BaseTransform", "BaseDepthTransform"]
def boolmask2idx(mask):
    """Return the indices of the non-zero entries of ``mask`` as a list.

    Utility used as a workaround for ONNX not supporting 'nonzero': the
    indices are materialised as a plain Python list.
    """
    nonzero_positions = torch.nonzero(mask)
    return nonzero_positions.squeeze(1).tolist()
def gen_dx_bx(xbound, ybound, zbound):
dx = torch.Tensor([row[2] for row in [xbound, ybound, zbound]])
@ -29,6 +32,10 @@ class BaseTransform(nn.Module):
ybound: Tuple[float, float, float],
zbound: Tuple[float, float, float],
dbound: Tuple[float, float, float],
use_points='lidar',
depth_input='scalar',
height_expand=True,
add_depth_features=True,
) -> None:
super().__init__()
self.in_channels = in_channels
@ -38,6 +45,12 @@ class BaseTransform(nn.Module):
self.ybound = ybound
self.zbound = zbound
self.dbound = dbound
self.use_points = use_points
assert use_points in ['radar', 'lidar']
self.depth_input=depth_input
assert depth_input in ['scalar', 'one-hot']
self.height_expand = height_expand
self.add_depth_features = add_depth_features
dx, bx, nx = gen_dx_bx(self.xbound, self.ybound, self.zbound)
self.dx = nn.Parameter(dx, requires_grad=False)
@ -167,6 +180,7 @@ class BaseTransform(nn.Module):
self,
img,
points,
radar,
camera2ego,
lidar2ego,
lidar2camera,
@ -199,10 +213,26 @@ class BaseTransform(nn.Module):
extra_rots=extra_rots,
extra_trans=extra_trans,
)
mats_dict = {
'intrin_mats': camera_intrinsics,
'ida_mats': img_aug_matrix,
'bda_mat': lidar_aug_matrix,
'sensor2ego_mats': camera2ego,
}
x = self.get_cam_feats(img, mats_dict)
x = self.get_cam_feats(img)
use_depth = False
if type(x) == tuple:
x, depth = x
use_depth = True
x = self.bev_pool(geom, x)
return x
if use_depth:
return x, depth
else:
return x
class BaseDepthTransform(BaseTransform):
@ -211,6 +241,7 @@ class BaseDepthTransform(BaseTransform):
self,
img,
points,
radar,
sensor2ego,
lidar2ego,
lidar2camera,
@ -232,12 +263,22 @@ class BaseDepthTransform(BaseTransform):
camera2lidar_rots = camera2lidar[..., :3, :3]
camera2lidar_trans = camera2lidar[..., :3, 3]
# print(img.shape, self.image_size, self.feature_size)
if self.use_points == 'radar':
points = radar
if self.height_expand:
for b in range(len(points)):
points_repeated = points[b].repeat_interleave(8, dim=0)
points_repeated[:, 2] = torch.arange(0.25, 2.25, 0.25).repeat(points[b].shape[0])
points[b] = points_repeated
batch_size = len(points)
depth = torch.zeros(batch_size, img.shape[1], 1, *self.image_size).to(
points[0].device
)
depth_in_channels = 1 if self.depth_input=='scalar' else self.D
if self.add_depth_features:
depth_in_channels += points[0].shape[1]
depth = torch.zeros(batch_size, img.shape[1], depth_in_channels, *self.image_size, device=points[0].device)
for b in range(batch_size):
cur_coords = points[b][:, :3]
@ -275,7 +316,17 @@ class BaseDepthTransform(BaseTransform):
for c in range(on_img.shape[0]):
masked_coords = cur_coords[c, on_img[c]].long()
masked_dist = dist[c, on_img[c]]
depth[b, c, 0, masked_coords[:, 0], masked_coords[:, 1]] = masked_dist
if self.depth_input == 'scalar':
depth[b, c, 0, masked_coords[:, 0], masked_coords[:, 1]] = masked_dist
elif self.depth_input == 'one-hot':
# Clamp depths that are too big to D
# These can arise when the point range filter is different from the dbound.
masked_dist = torch.clamp(masked_dist, max=self.D-1)
depth[b, c, masked_dist.long(), masked_coords[:, 0], masked_coords[:, 1]] = 1.0
if self.add_depth_features:
depth[b, c, -points[b].shape[-1]:, masked_coords[:, 0], masked_coords[:, 1]] = points[b][boolmask2idx(on_img[c])].transpose(0,1)
extra_rots = lidar_aug_matrix[..., :3, :3]
extra_trans = lidar_aug_matrix[..., :3, 3]
@ -289,6 +340,23 @@ class BaseDepthTransform(BaseTransform):
extra_trans=extra_trans,
)
x = self.get_cam_feats(img, depth)
mats_dict = {
'intrin_mats': intrins,
'ida_mats': img_aug_matrix,
'bda_mat': lidar_aug_matrix,
'sensor2ego_mats': sensor2ego,
}
x = self.get_cam_feats(img, depth, mats_dict)
use_depth = False
if type(x) == tuple:
x, depth = x
use_depth = True
x = self.bev_pool(geom, x)
return x
if use_depth:
return x, depth
else:
return x

View File

@ -9,6 +9,7 @@ from mmcv.ops import (
)
from .ball_query import ball_query
from .feature_decorator import feature_decorator
from .furthest_point_sample import (
Points_Sampler,
furthest_point_sample,
@ -89,4 +90,5 @@ __all__ = [
"PAConvCUDASAModule",
"PAConvCUDASAModuleMSG",
"bev_pool",
"feature_decorator",
]

View File

@ -0,0 +1 @@
from .feature_decorator import feature_decorator

View File

@ -0,0 +1,21 @@
import torch
from mmdet3d.ops.feature_decorator import feature_decorator_ext
__all__ = ["feature_decorator"]
def feature_decorator(features, num_voxels, coords, vx, vy, x_offset, y_offset, normalize_coords, use_cluster, use_center):
    """Call the ``feature_decorator_ext::feature_decorator_forward`` operator.

    Note that the operator receives ``coords`` before ``num_voxels``, i.e.
    in the opposite order to this wrapper's signature; the remaining
    arguments are forwarded unchanged.
    """
    forward_op = torch.ops.feature_decorator_ext.feature_decorator_forward
    return forward_op(
        features,
        coords,
        num_voxels,
        vx,
        vy,
        x_offset,
        y_offset,
        normalize_coords,
        use_cluster,
        use_center,
    )
if __name__ == '__main__':
    # Manual smoke test for the compiled extension (requires a CUDA device).
    A = torch.ones((2, 20, 5), dtype=torch.float32).cuda()  # features
    B = torch.ones(2, dtype=torch.int32).cuda()  # num_voxels
    C = torch.ones((2, 4), dtype=torch.int32).cuda()  # coords

    # Both entry points take ten arguments; the original three-argument
    # calls raised TypeError before reaching the extension. The geometry
    # values below are arbitrary and only exercise the call path.
    vx, vy = 0.8, 0.8
    x_offset, y_offset = vx / 2 - 51.2, vy / 2 - 51.2

    # Python wrapper (dispatches through torch.ops).
    D = feature_decorator(A, B, C, vx, vy, x_offset, y_offset, True, False, True)

    # Direct pybind entry point; like the wrapper, it is given coords before
    # num_voxels, and the three flags are plain ints in the C++ signature.
    D = feature_decorator_ext.feature_decorator_forward(
        A, C, B, vx, vy, x_offset, y_offset, 1, 0, 1
    )
    print(D.shape)

View File

@ -0,0 +1,45 @@
#include <torch/torch.h>
// CUDA function declarations
// void feature_decorator(int b, int d, int h, int w, int n, int c, int n_intervals, const float* x,
// const int* geom_feats, const int* interval_starts, const int* interval_lengths, float* out);
void feature_decorator(float* out);
// Allocates the decorated feature tensor for `_x` and launches the CUDA
// implementation on it.
//
//   _x : float32 CUDA tensor of rank 3.
//   _y, _z : int32 tensors (the Python wrapper passes coords and num_voxels).
//   vx, vy, x_offset, y_offset, normalize_coords : accepted but not used by
//       the current implementation.
//   use_cluster / use_center : when > 0, add 3 / 2 channels to the last
//       dimension of the output.
//
// Returns a zero-initialised tensor of shape {n, c, a + decorate_dims} on
// the same device as `_x`, after the kernel has written into it.
at::Tensor feature_decorator_forward(
    const at::Tensor _x,
    const at::Tensor _y,
    const at::Tensor _z,
    const double vx, const double vy, const double x_offset, const double y_offset,
    int normalize_coords, int use_cluster, int use_center
) {
  // Validate up front: the kernel receives a raw device pointer and cannot
  // report misuse itself.
  TORCH_CHECK(_x.dim() == 3,
              "feature_decorator_forward: expected a 3-D feature tensor, got ",
              _x.dim(), " dimension(s)");
  TORCH_CHECK(_x.is_cuda(),
              "feature_decorator_forward: features must be a CUDA tensor");
  TORCH_CHECK(_x.scalar_type() == at::kFloat,
              "feature_decorator_forward: features must be float32");
  TORCH_CHECK(_y.scalar_type() == at::kInt && _z.scalar_type() == at::kInt,
              "feature_decorator_forward: index tensors must be int32");

  int n = _x.size(0);
  int c = _x.size(1);
  int a = _x.size(2);

  auto options = torch::TensorOptions().dtype(_x.dtype()).device(_x.device());

  // Number of channels appended to every point.
  int decorate_dims = 0;
  if (use_cluster > 0) {
    decorate_dims += 3;
  }
  if (use_center > 0) {
    decorate_dims += 2;
  }

  at::Tensor _out = torch::zeros({n, c, a + decorate_dims}, options);

  // The current kernel unconditionally writes elements 0 and 15 of the
  // output buffer; refuse outputs that are too small for that instead of
  // writing out of bounds. Revisit this guard when the kernel changes.
  TORCH_CHECK(_out.numel() >= 16,
              "feature_decorator_forward: output has ", _out.numel(),
              " element(s), but the kernel writes index 15");

  feature_decorator(_out.data_ptr<float>());

  return _out;
}
// Python binding: exposes feature_decorator_forward as a function of the
// compiled extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("feature_decorator_forward", &feature_decorator_forward,
        "feature_decorator_forward");
}

// Operator registration: makes the same function reachable under the name
// "feature_decorator_ext::feature_decorator_forward" (the Python wrapper
// calls it through torch.ops).
static auto registry =
    torch::RegisterOperators("feature_decorator_ext::feature_decorator_forward", &feature_decorator_forward);

View File

@ -0,0 +1,49 @@
#include <stdio.h>
#include <stdlib.h>
// __global__ void feature_decorator_kernel(int b, int d, int h, int w, int n, int c, int n_intervals,
// const float *__restrict__ x,
// const int *__restrict__ geom_feats,
// const int *__restrict__ interval_starts,
// const int *__restrict__ interval_lengths,
// float* __restrict__ out) {
// int idx = blockIdx.x * blockDim.x + threadIdx.x;
// int index = idx / c;
// int cur_c = idx % c;
// if (index >= n_intervals) return;
// int interval_start = interval_starts[index];
// int interval_length = interval_lengths[index];
// const int* cur_geom_feats = geom_feats + interval_start * 4;
// const float* cur_x = x + interval_start * c + cur_c;
// float* cur_out = out + cur_geom_feats[3] * d * h * w * c +
// cur_geom_feats[2] * h * w * c + cur_geom_feats[0] * w * c +
// cur_geom_feats[1] * c + cur_c;
// float psum = 0;
// for(int i = 0; i < interval_length; i++){
// psum += cur_x[i * c];
// }
// *cur_out = psum;
// }
// Writes two fixed values into the output buffer: out[0] = 5 and out[15] = 6.
//
// Parameters:
//   out  device pointer to a float buffer; it must hold at least 16 elements
//        because index 15 is written unconditionally by thread 0.
//
// NOTE(review): this looks like a debugging stub standing in for the real
// decoration logic - confirm before relying on its output.
__global__ void feature_decorator_kernel(float* __restrict__ out) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Only the first thread of the grid performs the writes, so a launch with
  // more than one thread does not issue redundant stores to the same slots.
  if (idx != 0) {
    return;
  }
  out[0] = 5.0f;
  out[15] = 6.0f;
}
// void feature_decorator(int b, int d, int h, int w, int n, int c, int n_intervals, const float* x,
// const int* geom_feats, const int* interval_starts, const int* interval_lengths, float* out) {
// feature_decorator_kernel<<<(int)ceil(((double)n_intervals * c / 256)), 256>>>(
// b, d, h, w, n, c, n_intervals, x, geom_feats, interval_starts, interval_lengths, out
// );
// }
// Host-side launcher for feature_decorator_kernel.
//
// Parameters:
//   out  device pointer handed straight to the kernel; a null pointer is
//        ignored (nothing is launched).
//
// The kernel is launched with a single block of a single thread. Launch
// failures are reported on stderr instead of being silently dropped.
void feature_decorator(float* out) {
  // Nothing to write into; launching would dereference a null device pointer.
  if (out == nullptr) {
    return;
  }
  feature_decorator_kernel<<<1, 1>>>(out);
  // A kernel launch returns no status itself; query it explicitly so a bad
  // configuration or a sticky earlier error does not go unnoticed.
  const cudaError_t status = cudaGetLastError();
  if (status != cudaSuccess) {
    fprintf(stderr, "feature_decorator_kernel launch failed: %s\n",
            cudaGetErrorString(status));
  }
}

View File

@ -1,51 +1,34 @@
import os
from collections import OrderedDict
from os import path as osp
from typing import List, Tuple, Union
import mmcv
import numpy as np
import os
from collections import OrderedDict
from nuscenes.nuscenes import NuScenes
from nuscenes.utils.geometry_utils import view_points
from os import path as osp
from pyquaternion import Quaternion
from shapely.geometry import MultiPoint, box
from typing import List, Tuple, Union
from mmdet3d.core.bbox.box_np_ops import points_cam2img
from mmdet3d.datasets import NuScenesDataset
nus_categories = (
"car",
"truck",
"trailer",
"bus",
"construction_vehicle",
"bicycle",
"motorcycle",
"pedestrian",
"traffic_cone",
"barrier",
)
nus_categories = ('car', 'truck', 'trailer', 'bus', 'construction_vehicle',
'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone',
'barrier')
nus_attributes = (
"cycle.with_rider",
"cycle.without_rider",
"pedestrian.moving",
"pedestrian.standing",
"pedestrian.sitting_lying_down",
"vehicle.moving",
"vehicle.parked",
"vehicle.stopped",
"None",
)
nus_attributes = ('cycle.with_rider', 'cycle.without_rider',
'pedestrian.moving', 'pedestrian.standing',
'pedestrian.sitting_lying_down', 'vehicle.moving',
'vehicle.parked', 'vehicle.stopped', 'None')
def create_nuscenes_infos(
root_path, info_prefix, version="v1.0-trainval", max_sweeps=10
):
def create_nuscenes_infos(root_path,
info_prefix,
version='v1.0-trainval',
max_sweeps=10,
max_radar_sweeps=10):
"""Create info file of nuscene dataset.
Given the raw data, generate its related info file in pkl format.
Args:
root_path (str): Path of the data root.
info_prefix (str): Prefix of the info file to be generated.
@ -53,100 +36,96 @@ def create_nuscenes_infos(
Default: 'v1.0-trainval'
max_sweeps (int): Max number of sweeps.
Default: 10
max_radar_sweeps (int): Max number of radar sweeps.
Default: 10
"""
from nuscenes.nuscenes import NuScenes
nusc = NuScenes(version=version, dataroot=root_path, verbose=True)
from nuscenes.utils import splits
available_vers = ["v1.0-trainval", "v1.0-test", "v1.0-mini"]
available_vers = ['v1.0-trainval', 'v1.0-test', 'v1.0-mini']
assert version in available_vers
if version == "v1.0-trainval":
if version == 'v1.0-trainval':
train_scenes = splits.train
val_scenes = splits.val
elif version == "v1.0-test":
elif version == 'v1.0-test':
train_scenes = splits.test
val_scenes = []
elif version == "v1.0-mini":
elif version == 'v1.0-mini':
train_scenes = splits.mini_train
val_scenes = splits.mini_val
else:
raise ValueError("unknown")
raise ValueError('unknown')
# filter existing scenes.
available_scenes = get_available_scenes(nusc)
available_scene_names = [s["name"] for s in available_scenes]
train_scenes = list(filter(lambda x: x in available_scene_names, train_scenes))
available_scene_names = [s['name'] for s in available_scenes]
train_scenes = list(
filter(lambda x: x in available_scene_names, train_scenes))
val_scenes = list(filter(lambda x: x in available_scene_names, val_scenes))
train_scenes = set(
[
available_scenes[available_scene_names.index(s)]["token"]
for s in train_scenes
]
)
val_scenes = set(
[available_scenes[available_scene_names.index(s)]["token"] for s in val_scenes]
)
train_scenes = set([
available_scenes[available_scene_names.index(s)]['token']
for s in train_scenes
])
val_scenes = set([
available_scenes[available_scene_names.index(s)]['token']
for s in val_scenes
])
test = "test" in version
test = 'test' in version
if test:
print("test scene: {}".format(len(train_scenes)))
print('test scene: {}'.format(len(train_scenes)))
else:
print(
"train scene: {}, val scene: {}".format(len(train_scenes), len(val_scenes))
)
print('train scene: {}, val scene: {}'.format(
len(train_scenes), len(val_scenes)))
train_nusc_infos, val_nusc_infos = _fill_trainval_infos(
nusc, train_scenes, val_scenes, test, max_sweeps=max_sweeps
)
nusc, train_scenes, val_scenes, test, max_sweeps=max_sweeps, max_radar_sweeps=max_radar_sweeps)
metadata = dict(version=version)
if test:
print("test sample: {}".format(len(train_nusc_infos)))
print('test sample: {}'.format(len(train_nusc_infos)))
data = dict(infos=train_nusc_infos, metadata=metadata)
info_path = osp.join(root_path, "{}_infos_test.pkl".format(info_prefix))
info_path = osp.join(root_path,
'{}_infos_test_radar.pkl'.format(info_prefix))
mmcv.dump(data, info_path)
else:
print(
"train sample: {}, val sample: {}".format(
len(train_nusc_infos), len(val_nusc_infos)
)
)
print(info_prefix)
print('train sample: {}, val sample: {}'.format(
len(train_nusc_infos), len(val_nusc_infos)))
data = dict(infos=train_nusc_infos, metadata=metadata)
info_path = osp.join(root_path, "{}_infos_train.pkl".format(info_prefix))
info_path = osp.join(info_prefix,
'{}_infos_train_radar.pkl'.format(info_prefix))
mmcv.dump(data, info_path)
data["infos"] = val_nusc_infos
info_val_path = osp.join(root_path, "{}_infos_val.pkl".format(info_prefix))
data['infos'] = val_nusc_infos
info_val_path = osp.join(info_prefix,
'{}_infos_val_radar.pkl'.format(info_prefix))
mmcv.dump(data, info_val_path)
def get_available_scenes(nusc):
"""Get available scenes from the input nuscenes class.
Given the raw data, get the information of available scenes for
further info generation.
Args:
nusc (class): Dataset class in the nuScenes dataset.
Returns:
available_scenes (list[dict]): List of basic information for the
available scenes.
"""
available_scenes = []
print("total scene num: {}".format(len(nusc.scene)))
print('total scene num: {}'.format(len(nusc.scene)))
for scene in nusc.scene:
scene_token = scene["token"]
scene_rec = nusc.get("scene", scene_token)
sample_rec = nusc.get("sample", scene_rec["first_sample_token"])
sd_rec = nusc.get("sample_data", sample_rec["data"]["LIDAR_TOP"])
scene_token = scene['token']
scene_rec = nusc.get('scene', scene_token)
sample_rec = nusc.get('sample', scene_rec['first_sample_token'])
sd_rec = nusc.get('sample_data', sample_rec['data']['LIDAR_TOP'])
has_more_frames = True
scene_not_exist = False
while has_more_frames:
lidar_path, boxes, _ = nusc.get_sample_data(sd_rec["token"])
lidar_path, boxes, _ = nusc.get_sample_data(sd_rec['token'])
lidar_path = str(lidar_path)
if os.getcwd() in lidar_path:
# path from lyftdataset is absolute path
lidar_path = lidar_path.split(f"{os.getcwd()}/")[-1]
lidar_path = lidar_path.split(f'{os.getcwd()}/')[-1]
# relative path
if not mmcv.is_filepath(lidar_path):
scene_not_exist = True
@ -156,13 +135,17 @@ def get_available_scenes(nusc):
if scene_not_exist:
continue
available_scenes.append(scene)
print("exist scene num: {}".format(len(available_scenes)))
print('exist scene num: {}'.format(len(available_scenes)))
return available_scenes
def _fill_trainval_infos(nusc, train_scenes, val_scenes, test=False, max_sweeps=10):
def _fill_trainval_infos(nusc,
train_scenes,
val_scenes,
test=False,
max_sweeps=10,
max_radar_sweeps=10):
"""Generate the train/val infos from the raw data.
Args:
nusc (:obj:`NuScenes`): Dataset class in the nuScenes dataset.
train_scenes (list[str]): Basic information of training scenes.
@ -170,101 +153,126 @@ def _fill_trainval_infos(nusc, train_scenes, val_scenes, test=False, max_sweeps=
test (bool): Whether use the test mode. In the test mode, no
annotations can be accessed. Default: False.
max_sweeps (int): Max number of sweeps. Default: 10.
max_radar_sweeps (int): Max number of radar sweeps. Default: 10.
Returns:
tuple[list[dict]]: Information of training set and validation set
that will be saved to the info file.
"""
train_nusc_infos = []
val_nusc_infos = []
token2idx = {}
i_ = 0
for sample in mmcv.track_iter_progress(nusc.sample):
lidar_token = sample["data"]["LIDAR_TOP"]
sd_rec = nusc.get("sample_data", sample["data"]["LIDAR_TOP"])
cs_record = nusc.get("calibrated_sensor", sd_rec["calibrated_sensor_token"])
pose_record = nusc.get("ego_pose", sd_rec["ego_pose_token"])
location = nusc.get(
"log", nusc.get("scene", sample["scene_token"])["log_token"]
)["location"]
# i_ += 1
# if i_ > 6:
# break
lidar_token = sample['data']['LIDAR_TOP']
sd_rec = nusc.get('sample_data', sample['data']['LIDAR_TOP'])
cs_record = nusc.get('calibrated_sensor',
sd_rec['calibrated_sensor_token'])
pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token'])
lidar_path, boxes, _ = nusc.get_sample_data(lidar_token)
mmcv.check_file_exist(lidar_path)
info = {
"lidar_path": lidar_path,
"token": sample["token"],
"sweeps": [],
"cams": dict(),
"lidar2ego_translation": cs_record["translation"],
"lidar2ego_rotation": cs_record["rotation"],
"ego2global_translation": pose_record["translation"],
"ego2global_rotation": pose_record["rotation"],
"timestamp": sample["timestamp"],
"location": location,
'lidar_path': lidar_path,
'token': sample['token'],
'sweeps': [],
'cams': dict(),
'radars': dict(),
'lidar2ego_translation': cs_record['translation'],
'lidar2ego_rotation': cs_record['rotation'],
'ego2global_translation': pose_record['translation'],
'ego2global_rotation': pose_record['rotation'],
'timestamp': sample['timestamp'],
'prev_token': sample['prev']
}
l2e_r = info["lidar2ego_rotation"]
l2e_t = info["lidar2ego_translation"]
e2g_r = info["ego2global_rotation"]
e2g_t = info["ego2global_translation"]
l2e_r = info['lidar2ego_rotation']
l2e_t = info['lidar2ego_translation']
e2g_r = info['ego2global_rotation']
e2g_t = info['ego2global_translation']
l2e_r_mat = Quaternion(l2e_r).rotation_matrix
e2g_r_mat = Quaternion(e2g_r).rotation_matrix
# obtain 6 image's information per frame
camera_types = [
"CAM_FRONT",
"CAM_FRONT_RIGHT",
"CAM_FRONT_LEFT",
"CAM_BACK",
"CAM_BACK_LEFT",
"CAM_BACK_RIGHT",
'CAM_FRONT',
'CAM_FRONT_RIGHT',
'CAM_FRONT_LEFT',
'CAM_BACK',
'CAM_BACK_LEFT',
'CAM_BACK_RIGHT',
]
for cam in camera_types:
cam_token = sample["data"][cam]
cam_path, _, camera_intrinsics = nusc.get_sample_data(cam_token)
cam_info = obtain_sensor2top(
nusc, cam_token, l2e_t, l2e_r_mat, e2g_t, e2g_r_mat, cam
)
cam_info.update(camera_intrinsics=camera_intrinsics)
info["cams"].update({cam: cam_info})
cam_token = sample['data'][cam]
cam_path, _, cam_intrinsic = nusc.get_sample_data(cam_token)
cam_info = obtain_sensor2top(nusc, cam_token, l2e_t, l2e_r_mat,
e2g_t, e2g_r_mat, cam)
cam_info.update(cam_intrinsic=cam_intrinsic)
info['cams'].update({cam: cam_info})
radar_names = ['RADAR_FRONT', 'RADAR_FRONT_LEFT', 'RADAR_FRONT_RIGHT', 'RADAR_BACK_LEFT', 'RADAR_BACK_RIGHT']
for radar_name in radar_names:
radar_token = sample['data'][radar_name]
radar_rec = nusc.get('sample_data', radar_token)
sweeps = []
while len(sweeps) < max_radar_sweeps:
if not radar_rec['prev'] == '':
radar_path, _, radar_intrin = nusc.get_sample_data(radar_token)
radar_info = obtain_sensor2top(nusc, radar_token, l2e_t, l2e_r_mat,
e2g_t, e2g_r_mat, radar_name)
sweeps.append(radar_info)
radar_token = radar_rec['prev']
radar_rec = nusc.get('sample_data', radar_token)
else:
radar_path, _, radar_intrin = nusc.get_sample_data(radar_token)
radar_info = obtain_sensor2top(nusc, radar_token, l2e_t, l2e_r_mat,
e2g_t, e2g_r_mat, radar_name)
sweeps.append(radar_info)
info['radars'].update({radar_name: sweeps})
# obtain sweeps for a single key-frame
sd_rec = nusc.get("sample_data", sample["data"]["LIDAR_TOP"])
sd_rec = nusc.get('sample_data', sample['data']['LIDAR_TOP'])
sweeps = []
while len(sweeps) < max_sweeps:
if not sd_rec["prev"] == "":
sweep = obtain_sensor2top(
nusc, sd_rec["prev"], l2e_t, l2e_r_mat, e2g_t, e2g_r_mat, "lidar"
)
if not sd_rec['prev'] == '':
sweep = obtain_sensor2top(nusc, sd_rec['prev'], l2e_t,
l2e_r_mat, e2g_t, e2g_r_mat, 'lidar')
sweeps.append(sweep)
sd_rec = nusc.get("sample_data", sd_rec["prev"])
sd_rec = nusc.get('sample_data', sd_rec['prev'])
else:
break
info["sweeps"] = sweeps
info['sweeps'] = sweeps
# obtain annotation
if not test:
annotations = [
nusc.get("sample_annotation", token) for token in sample["anns"]
nusc.get('sample_annotation', token)
for token in sample['anns']
]
locs = np.array([b.center for b in boxes]).reshape(-1, 3)
dims = np.array([b.wlh for b in boxes]).reshape(-1, 3)
rots = np.array([b.orientation.yaw_pitch_roll[0] for b in boxes]).reshape(
-1, 1
)
rots = np.array([b.orientation.yaw_pitch_roll[0]
for b in boxes]).reshape(-1, 1)
velocity = np.array(
[nusc.box_velocity(token)[:2] for token in sample["anns"]]
)
[nusc.box_velocity(token)[:2] for token in sample['anns']])
valid_flag = np.array(
[
(anno["num_lidar_pts"] + anno["num_radar_pts"]) > 0
for anno in annotations
],
dtype=bool,
).reshape(-1)
[(anno['num_lidar_pts'] + anno['num_radar_pts']) > 0
for anno in annotations],
dtype=bool).reshape(-1)
# convert velo from global to lidar
for i in range(len(boxes)):
velo = np.array([*velocity[i], 0.0])
velo = velo @ np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T
velo = velo @ np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(
l2e_r_mat).T
velocity[i] = velo[:2]
names = [b.name for b in boxes]
@ -275,28 +283,52 @@ def _fill_trainval_infos(nusc, train_scenes, val_scenes, test=False, max_sweeps=
# we need to convert rot to SECOND format.
gt_boxes = np.concatenate([locs, dims, -rots - np.pi / 2], axis=1)
assert len(gt_boxes) == len(
annotations
), f"{len(gt_boxes)}, {len(annotations)}"
info["gt_boxes"] = gt_boxes
info["gt_names"] = names
info["gt_velocity"] = velocity.reshape(-1, 2)
info["num_lidar_pts"] = np.array([a["num_lidar_pts"] for a in annotations])
info["num_radar_pts"] = np.array([a["num_radar_pts"] for a in annotations])
info["valid_flag"] = valid_flag
annotations), f'{len(gt_boxes)}, {len(annotations)}'
info['gt_boxes'] = gt_boxes
info['gt_names'] = names
info['gt_velocity'] = velocity.reshape(-1, 2)
info['num_lidar_pts'] = np.array(
[a['num_lidar_pts'] for a in annotations])
info['num_radar_pts'] = np.array(
[a['num_radar_pts'] for a in annotations])
info['valid_flag'] = valid_flag
if sample["scene_token"] in train_scenes:
if sample['scene_token'] in train_scenes:
train_nusc_infos.append(info)
token2idx[info['token']] = ('train', len(train_nusc_infos) - 1)
else:
val_nusc_infos.append(info)
token2idx[info['token']] = ('val', len(val_nusc_infos) - 1)
for info in train_nusc_infos:
prev_token = info['prev_token']
if prev_token == '':
info['prev'] = -1
else:
prev_set, prev_idx = token2idx[prev_token]
assert prev_set == 'train'
info['prev'] = prev_idx
for info in val_nusc_infos:
prev_token = info['prev_token']
if prev_token == '':
info['prev'] = -1
else:
prev_set, prev_idx = token2idx[prev_token]
assert prev_set == 'val'
info['prev'] = prev_idx
return train_nusc_infos, val_nusc_infos
def obtain_sensor2top(
nusc, sensor_token, l2e_t, l2e_r_mat, e2g_t, e2g_r_mat, sensor_type="lidar"
):
def obtain_sensor2top(nusc,
sensor_token,
l2e_t,
l2e_r_mat,
e2g_t,
e2g_r_mat,
sensor_type='lidar'):
"""Obtain the info with RT matric from general sensor to Top LiDAR.
Args:
nusc (class): Dataset class in the nuScenes dataset.
sensor_token (str): Sample data token corresponding to the
@ -308,53 +340,48 @@ def obtain_sensor2top(
e2g_r_mat (np.ndarray): Rotation matrix from ego to global
in shape (3, 3).
sensor_type (str): Sensor to calibrate. Default: 'lidar'.
Returns:
sweep (dict): Sweep information after transformation.
"""
sd_rec = nusc.get("sample_data", sensor_token)
cs_record = nusc.get("calibrated_sensor", sd_rec["calibrated_sensor_token"])
pose_record = nusc.get("ego_pose", sd_rec["ego_pose_token"])
data_path = str(nusc.get_sample_data_path(sd_rec["token"]))
sd_rec = nusc.get('sample_data', sensor_token)
cs_record = nusc.get('calibrated_sensor',
sd_rec['calibrated_sensor_token'])
pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token'])
data_path = str(nusc.get_sample_data_path(sd_rec['token']))
if os.getcwd() in data_path: # path from lyftdataset is absolute path
data_path = data_path.split(f"{os.getcwd()}/")[-1] # relative path
data_path = data_path.split(f'{os.getcwd()}/')[-1] # relative path
sweep = {
"data_path": data_path,
"type": sensor_type,
"sample_data_token": sd_rec["token"],
"sensor2ego_translation": cs_record["translation"],
"sensor2ego_rotation": cs_record["rotation"],
"ego2global_translation": pose_record["translation"],
"ego2global_rotation": pose_record["rotation"],
"timestamp": sd_rec["timestamp"],
'data_path': data_path,
'type': sensor_type,
'sample_data_token': sd_rec['token'],
'sensor2ego_translation': cs_record['translation'],
'sensor2ego_rotation': cs_record['rotation'],
'ego2global_translation': pose_record['translation'],
'ego2global_rotation': pose_record['rotation'],
'timestamp': sd_rec['timestamp']
}
l2e_r_s = sweep["sensor2ego_rotation"]
l2e_t_s = sweep["sensor2ego_translation"]
e2g_r_s = sweep["ego2global_rotation"]
e2g_t_s = sweep["ego2global_translation"]
l2e_r_s = sweep['sensor2ego_rotation']
l2e_t_s = sweep['sensor2ego_translation']
e2g_r_s = sweep['ego2global_rotation']
e2g_t_s = sweep['ego2global_translation']
# obtain the RT from sensor to Top LiDAR
# sweep->ego->global->ego'->lidar
l2e_r_s_mat = Quaternion(l2e_r_s).rotation_matrix
e2g_r_s_mat = Quaternion(e2g_r_s).rotation_matrix
R = (l2e_r_s_mat.T @ e2g_r_s_mat.T) @ (
np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T
)
np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T)
T = (l2e_t_s @ e2g_r_s_mat.T + e2g_t_s) @ (
np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T
)
T -= (
e2g_t @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T)
+ l2e_t @ np.linalg.inv(l2e_r_mat).T
)
sweep["sensor2lidar_rotation"] = R.T # points @ R.T + T
sweep["sensor2lidar_translation"] = T
np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T)
T -= e2g_t @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T
) + l2e_t @ np.linalg.inv(l2e_r_mat).T
sweep['sensor2lidar_rotation'] = R.T # points @ R.T + T
sweep['sensor2lidar_translation'] = T
return sweep
def export_2d_annotation(root_path, info_path, version, mono3d=True):
"""Export 2d annotation from the info file and raw data.
Args:
root_path (str): Root path of the raw data.
info_path (str): Path of the info file.
@ -363,14 +390,14 @@ def export_2d_annotation(root_path, info_path, version, mono3d=True):
"""
# get bbox annotations for camera
camera_types = [
"CAM_FRONT",
"CAM_FRONT_RIGHT",
"CAM_FRONT_LEFT",
"CAM_BACK",
"CAM_BACK_LEFT",
"CAM_BACK_RIGHT",
'CAM_FRONT',
'CAM_FRONT_RIGHT',
'CAM_FRONT_LEFT',
'CAM_BACK',
'CAM_BACK_LEFT',
'CAM_BACK_RIGHT',
]
nusc_infos = mmcv.load(info_path)["infos"]
nusc_infos = mmcv.load(info_path)['infos']
nusc = NuScenes(version=version, dataroot=root_path, verbose=True)
# info_2d_list = []
cat2Ids = [
@ -381,97 +408,100 @@ def export_2d_annotation(root_path, info_path, version, mono3d=True):
coco_2d_dict = dict(annotations=[], images=[], categories=cat2Ids)
for info in mmcv.track_iter_progress(nusc_infos):
for cam in camera_types:
cam_info = info["cams"][cam]
cam_info = info['cams'][cam]
coco_infos = get_2d_boxes(
nusc,
cam_info["sample_data_token"],
visibilities=["", "1", "2", "3", "4"],
mono3d=mono3d,
)
(height, width, _) = mmcv.imread(cam_info["data_path"]).shape
coco_2d_dict["images"].append(
cam_info['sample_data_token'],
visibilities=['', '1', '2', '3', '4'],
mono3d=mono3d)
(height, width, _) = mmcv.imread(cam_info['data_path']).shape
coco_2d_dict['images'].append(
dict(
file_name=cam_info["data_path"].split("data/nuscenes/")[-1],
id=cam_info["sample_data_token"],
token=info["token"],
cam2ego_rotation=cam_info["sensor2ego_rotation"],
cam2ego_translation=cam_info["sensor2ego_translation"],
ego2global_rotation=info["ego2global_rotation"],
ego2global_translation=info["ego2global_translation"],
camera_intrinsics=cam_info["camera_intrinsics"],
file_name=cam_info['data_path'].split('data/nuscenes/')
[-1],
id=cam_info['sample_data_token'],
token=info['token'],
cam2ego_rotation=cam_info['sensor2ego_rotation'],
cam2ego_translation=cam_info['sensor2ego_translation'],
ego2global_rotation=info['ego2global_rotation'],
ego2global_translation=info['ego2global_translation'],
cam_intrinsic=cam_info['cam_intrinsic'],
width=width,
height=height,
)
)
height=height))
for coco_info in coco_infos:
if coco_info is None:
continue
# add an empty key for coco format
coco_info["segmentation"] = []
coco_info["id"] = coco_ann_id
coco_2d_dict["annotations"].append(coco_info)
coco_info['segmentation'] = []
coco_info['id'] = coco_ann_id
coco_2d_dict['annotations'].append(coco_info)
coco_ann_id += 1
if mono3d:
json_prefix = f"{info_path[:-4]}_mono3d"
json_prefix = f'{info_path[:-4]}_mono3d'
else:
json_prefix = f"{info_path[:-4]}"
mmcv.dump(coco_2d_dict, f"{json_prefix}.coco.json")
json_prefix = f'{info_path[:-4]}'
mmcv.dump(coco_2d_dict, f'{json_prefix}.coco.json')
def get_2d_boxes(nusc, sample_data_token: str, visibilities: List[str], mono3d=True):
def get_2d_boxes(nusc,
sample_data_token: str,
visibilities: List[str],
mono3d=True):
"""Get the 2D annotation records for a given `sample_data_token`.
Args:
sample_data_token (str): Sample data token belonging to a camera \
keyframe.
visibilities (list[str]): Visibility filter.
mono3d (bool): Whether to get boxes with mono3d annotation.
Return:
list[dict]: List of 2D annotation record that belongs to the input
`sample_data_token`.
"""
# Get the sample data and the sample corresponding to that sample data.
sd_rec = nusc.get("sample_data", sample_data_token)
sd_rec = nusc.get('sample_data', sample_data_token)
assert sd_rec["sensor_modality"] == "camera", (
"Error: get_2d_boxes only works" " for camera sample_data!"
)
if not sd_rec["is_key_frame"]:
raise ValueError("The 2D re-projections are available only for keyframes.")
assert sd_rec[
'sensor_modality'] == 'camera', 'Error: get_2d_boxes only works' \
' for camera sample_data!'
if not sd_rec['is_key_frame']:
raise ValueError(
'The 2D re-projections are available only for keyframes.')
s_rec = nusc.get("sample", sd_rec["sample_token"])
s_rec = nusc.get('sample', sd_rec['sample_token'])
# Get the calibrated sensor and ego pose
# record to get the transformation matrices.
cs_rec = nusc.get("calibrated_sensor", sd_rec["calibrated_sensor_token"])
pose_rec = nusc.get("ego_pose", sd_rec["ego_pose_token"])
camera_intrinsic = np.array(cs_rec["camera_intrinsic"])
cs_rec = nusc.get('calibrated_sensor', sd_rec['calibrated_sensor_token'])
pose_rec = nusc.get('ego_pose', sd_rec['ego_pose_token'])
camera_intrinsic = np.array(cs_rec['camera_intrinsic'])
# Get all the annotations with the specified visibilities.
ann_recs = [nusc.get("sample_annotation", token) for token in s_rec["anns"]]
ann_recs = [
ann_rec for ann_rec in ann_recs if (ann_rec["visibility_token"] in visibilities)
nusc.get('sample_annotation', token) for token in s_rec['anns']
]
ann_recs = [
ann_rec for ann_rec in ann_recs
if (ann_rec['visibility_token'] in visibilities)
]
repro_recs = []
for ann_rec in ann_recs:
# Augment sample_annotation with token information.
ann_rec["sample_annotation_token"] = ann_rec["token"]
ann_rec["sample_data_token"] = sample_data_token
ann_rec['sample_annotation_token'] = ann_rec['token']
ann_rec['sample_data_token'] = sample_data_token
# Get the box in global coordinates.
box = nusc.get_box(ann_rec["token"])
box = nusc.get_box(ann_rec['token'])
# Move them to the ego-pose frame.
box.translate(-np.array(pose_rec["translation"]))
box.rotate(Quaternion(pose_rec["rotation"]).inverse)
box.translate(-np.array(pose_rec['translation']))
box.rotate(Quaternion(pose_rec['rotation']).inverse)
# Move them to the calibrated sensor frame.
box.translate(-np.array(cs_rec["translation"]))
box.rotate(Quaternion(cs_rec["rotation"]).inverse)
box.translate(-np.array(cs_rec['translation']))
box.rotate(Quaternion(cs_rec['rotation']).inverse)
# Filter out the corners that are not in front of the calibrated
# sensor.
@ -480,9 +510,8 @@ def get_2d_boxes(nusc, sample_data_token: str, visibilities: List[str], mono3d=T
corners_3d = corners_3d[:, in_front]
# Project 3d box to 2d.
corner_coords = (
view_points(corners_3d, camera_intrinsic, True).T[:, :2].tolist()
)
corner_coords = view_points(corners_3d, camera_intrinsic,
True).T[:, :2].tolist()
# Keep only corners that fall within the image.
final_coords = post_process_coords(corner_coords)
@ -495,49 +524,44 @@ def get_2d_boxes(nusc, sample_data_token: str, visibilities: List[str], mono3d=T
min_x, min_y, max_x, max_y = final_coords
# Generate dictionary record to be included in the .json file.
repro_rec = generate_record(
ann_rec, min_x, min_y, max_x, max_y, sample_data_token, sd_rec["filename"]
)
repro_rec = generate_record(ann_rec, min_x, min_y, max_x, max_y,
sample_data_token, sd_rec['filename'])
# If mono3d=True, add 3D annotations in camera coordinates
if mono3d and (repro_rec is not None):
loc = box.center.tolist()
dim = box.wlh
dim[[0, 1, 2]] = dim[[1, 2, 0]] # convert wlh to our lhw
dim = dim.tolist()
rot = box.orientation.yaw_pitch_roll[0]
rot = [-rot] # convert the rot to our cam coordinate
dim = box.wlh.tolist()
rot = [box.orientation.yaw_pitch_roll[0]]
global_velo2d = nusc.box_velocity(box.token)[:2]
global_velo3d = np.array([*global_velo2d, 0.0])
e2g_r_mat = Quaternion(pose_rec["rotation"]).rotation_matrix
c2e_r_mat = Quaternion(cs_rec["rotation"]).rotation_matrix
cam_velo3d = (
global_velo3d @ np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(c2e_r_mat).T
)
e2g_r_mat = Quaternion(pose_rec['rotation']).rotation_matrix
c2e_r_mat = Quaternion(cs_rec['rotation']).rotation_matrix
cam_velo3d = global_velo3d @ np.linalg.inv(
e2g_r_mat).T @ np.linalg.inv(c2e_r_mat).T
velo = cam_velo3d[0::2].tolist()
repro_rec["bbox_cam3d"] = loc + dim + rot
repro_rec["velo_cam3d"] = velo
repro_rec['bbox_cam3d'] = loc + dim + rot
repro_rec['velo_cam3d'] = velo
center3d = np.array(loc).reshape([1, 3])
center2d = points_cam2img(center3d, camera_intrinsic, with_depth=True)
repro_rec["center2d"] = center2d.squeeze().tolist()
center2d = points_cam2img(
center3d, camera_intrinsic, with_depth=True)
repro_rec['center2d'] = center2d.squeeze().tolist()
# normalized center2D + depth
# if samples with depth < 0 will be removed
if repro_rec["center2d"][2] <= 0:
if repro_rec['center2d'][2] <= 0:
continue
ann_token = nusc.get("sample_annotation", box.token)["attribute_tokens"]
ann_token = nusc.get('sample_annotation',
box.token)['attribute_tokens']
if len(ann_token) == 0:
attr_name = "None"
attr_name = 'None'
else:
attr_name = nusc.get("attribute", ann_token[0])["name"]
attr_name = nusc.get('attribute', ann_token[0])['name']
attr_id = nus_attributes.index(attr_name)
repro_rec["attribute_name"] = attr_name
repro_rec["attribute_id"] = attr_id
repro_rec['attribute_name'] = attr_name
repro_rec['attribute_id'] = attr_id
repro_recs.append(repro_rec)
@ -549,12 +573,10 @@ def post_process_coords(
) -> Union[Tuple[float, float, float, float], None]:
"""Get the intersection of the convex hull of the reprojected bbox corners
and the image canvas, return None if no intersection.
Args:
corner_coords (list[int]): Corner coordinates of reprojected
bounding box.
imsize (tuple[int]): Size of the image canvas.
Return:
tuple [float]: Intersection of the convex hull of the 2D box
corners and the image canvas.
@ -565,8 +587,7 @@ def post_process_coords(
if polygon_from_2d_box.intersects(img_canvas):
img_intersection = polygon_from_2d_box.intersection(img_canvas)
intersection_coords = np.array(
[coord for coord in img_intersection.exterior.coords]
)
[coord for coord in img_intersection.exterior.coords])
min_x = min(intersection_coords[:, 0])
min_y = min(intersection_coords[:, 1])
@ -578,18 +599,10 @@ def post_process_coords(
return None
def generate_record(
ann_rec: dict,
x1: float,
y1: float,
x2: float,
y2: float,
sample_data_token: str,
filename: str,
) -> OrderedDict:
def generate_record(ann_rec: dict, x1: float, y1: float, x2: float, y2: float,
sample_data_token: str, filename: str) -> OrderedDict:
"""Generate one 2D annotation record given various informations on top of
the 2D bounding box coordinates.
Args:
ann_rec (dict): Original 3d annotation record.
x1 (float): Minimum value of the x coordinate.
@ -599,7 +612,6 @@ def generate_record(
sample_data_token (str): Sample data token.
filename (str):The corresponding image file where the annotation
is present.
Returns:
dict: A sample 2D annotation record.
- file_name (str): file name
@ -611,39 +623,43 @@ def generate_record(
- iscrowd (int): whether the area is crowd
"""
repro_rec = OrderedDict()
repro_rec["sample_data_token"] = sample_data_token
repro_rec['sample_data_token'] = sample_data_token
coco_rec = dict()
relevant_keys = [
"attribute_tokens",
"category_name",
"instance_token",
"next",
"num_lidar_pts",
"num_radar_pts",
"prev",
"sample_annotation_token",
"sample_data_token",
"visibility_token",
'attribute_tokens',
'category_name',
'instance_token',
'next',
'num_lidar_pts',
'num_radar_pts',
'prev',
'sample_annotation_token',
'sample_data_token',
'visibility_token',
]
for key, value in ann_rec.items():
if key in relevant_keys:
repro_rec[key] = value
repro_rec["bbox_corners"] = [x1, y1, x2, y2]
repro_rec["filename"] = filename
repro_rec['bbox_corners'] = [x1, y1, x2, y2]
repro_rec['filename'] = filename
coco_rec["file_name"] = filename
coco_rec["image_id"] = sample_data_token
coco_rec["area"] = (y2 - y1) * (x2 - x1)
coco_rec['file_name'] = filename
coco_rec['image_id'] = sample_data_token
coco_rec['area'] = (y2 - y1) * (x2 - x1)
if repro_rec["category_name"] not in NuScenesDataset.NameMapping:
if repro_rec['category_name'] not in NuScenesDataset.NameMapping:
return None
cat_name = NuScenesDataset.NameMapping[repro_rec["category_name"]]
coco_rec["category_name"] = cat_name
coco_rec["category_id"] = nus_categories.index(cat_name)
coco_rec["bbox"] = [x1, y1, x2 - x1, y2 - y1]
coco_rec["iscrowd"] = 0
cat_name = NuScenesDataset.NameMapping[repro_rec['category_name']]
coco_rec['category_name'] = cat_name
coco_rec['category_id'] = nus_categories.index(cat_name)
coco_rec['bbox'] = [x1, y1, x2 - x1, y2 - y1]
coco_rec['iscrowd'] = 0
return coco_rec
if __name__ == '__main__':
    # Script entry point: generate the info files for the dataset rooted at
    # 'data/nuscenes/', leaving version and sweep counts at their defaults.
    create_nuscenes_infos(
        root_path='data/nuscenes/',
        info_prefix='radar_nuscenes_5sweeps',
    )