Add BEVFusion-R (#443)
This commit is contained in:
parent
db75150717
commit
d0152cf97c
|
|
@ -127,3 +127,8 @@ dmypy.json
|
|||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# Models and data
|
||||
models/*
|
||||
data/*
|
||||
runs/*
|
||||
|
|
|
|||
|
|
@ -183,6 +183,9 @@ train_pipeline:
|
|||
- lidar2image
|
||||
- img_aug_matrix
|
||||
- lidar_aug_matrix
|
||||
-
|
||||
type: GTDepth
|
||||
keyframe_only: true
|
||||
|
||||
test_pipeline:
|
||||
-
|
||||
|
|
@ -256,6 +259,9 @@ test_pipeline:
|
|||
- lidar2image
|
||||
- img_aug_matrix
|
||||
- lidar_aug_matrix
|
||||
-
|
||||
type: GTDepth
|
||||
keyframe_only: true
|
||||
|
||||
data:
|
||||
samples_per_gpu: 4
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
gt_paste_stop_epoch: 15
|
||||
|
||||
model:
|
||||
heads:
|
||||
object:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,72 @@
|
|||
|
||||
data:
|
||||
train:
|
||||
dataset:
|
||||
ann_file: nuscenes_radar/nuscenes_radar_infos_train_radar.pkl
|
||||
val:
|
||||
ann_file: nuscenes_radar/nuscenes_radar_infos_val_radar.pkl
|
||||
test:
|
||||
ann_file: nuscenes_radar/nuscenes_radar_infos_val_radar.pkl
|
||||
|
||||
augment2d:
|
||||
resize: [[0.38, 0.55], [0.48, 0.48]]
|
||||
|
||||
augment3d:
|
||||
scale: [0.9, 1.1]
|
||||
rotate: [0, 0]
|
||||
translate: 0.5
|
||||
|
||||
model:
|
||||
encoders:
|
||||
lidar: null
|
||||
camera:
|
||||
vtransform:
|
||||
type: LSSTransform
|
||||
image_size: ${image_size}
|
||||
xbound: [-51.2, 51.2, 0.8]
|
||||
ybound: [-51.2, 51.2, 0.8]
|
||||
zbound: [-10.0, 10.0, 20.0]
|
||||
dbound: [1.0, 60.0, 1.0]
|
||||
radar:
|
||||
voxelize_reduce: false
|
||||
voxelize:
|
||||
max_num_points: 20
|
||||
point_cloud_range: ${point_cloud_range}
|
||||
voxel_size: ${radar_voxel_size}
|
||||
max_voxels: [30000, 60000]
|
||||
backbone:
|
||||
type: RadarEncoder
|
||||
pts_voxel_encoder:
|
||||
type: RadarFeatureNet
|
||||
in_channels: 45
|
||||
feat_channels: [128, 128, 128, 64]
|
||||
with_distance: false
|
||||
point_cloud_range: ${point_cloud_range}
|
||||
voxel_size: ${radar_voxel_size}
|
||||
norm_cfg:
|
||||
type: BN1d
|
||||
eps: 1.0e-3
|
||||
momentum: 0.01
|
||||
pts_middle_encoder:
|
||||
type: PointPillarsScatter
|
||||
in_channels: 64
|
||||
output_shape: [128, 128]
|
||||
pts_bev_encoder: null
|
||||
heads:
|
||||
object:
|
||||
test_cfg:
|
||||
nms_type:
|
||||
- circle
|
||||
- rotate
|
||||
- rotate
|
||||
- circle
|
||||
- rotate
|
||||
- rotate
|
||||
nms_scale:
|
||||
- [1.0]
|
||||
- [1.0, 1.0]
|
||||
- [1.0, 1.0]
|
||||
- [1.0]
|
||||
- [1.0, 1.0]
|
||||
- [2.5, 4.0]
|
||||
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
|
||||
|
||||
image_size: [256, 704]
|
||||
|
||||
model:
|
||||
encoders:
|
||||
camera:
|
||||
backbone:
|
||||
type: ResNet
|
||||
depth: 50
|
||||
num_stages: 4
|
||||
out_indices: [0, 1, 2, 3]
|
||||
norm_cfg:
|
||||
type: BN2d
|
||||
requires_grad: true
|
||||
norm_eval: false
|
||||
init_cfg:
|
||||
type: Pretrained
|
||||
checkpoint: torchvision://resnet50
|
||||
neck:
|
||||
type: SECONDFPN
|
||||
in_channels: [256, 512, 1024, 2048]
|
||||
out_channels: [128, 128, 128, 128]
|
||||
upsample_strides: [0.25, 0.5, 1, 2]
|
||||
vtransform:
|
||||
type: LSSTransform
|
||||
in_channels: 512
|
||||
out_channels: 64
|
||||
image_size: ${image_size}
|
||||
feature_size: ${[image_size[0] // 16, image_size[1] // 16]}
|
||||
xbound: [-51.2, 51.2, 0.8]
|
||||
ybound: [-51.2, 51.2, 0.8]
|
||||
zbound: [-10.0, 10.0, 20.0]
|
||||
dbound: [1.0, 60.0, 1.0]
|
||||
downsample: 1
|
||||
decoder:
|
||||
backbone:
|
||||
type: GeneralizedResNet
|
||||
in_channels: 64
|
||||
blocks:
|
||||
- [2, 128, 2]
|
||||
- [2, 256, 2]
|
||||
- [2, 512, 1]
|
||||
neck:
|
||||
type: LSSFPN
|
||||
in_indices: [-1, 0]
|
||||
in_channels: [512, 128]
|
||||
out_channels: 256
|
||||
scale_factor: 2
|
||||
heads:
|
||||
object:
|
||||
train_cfg:
|
||||
code_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
fuser:
|
||||
type: ConvFuser
|
||||
in_channels: [64, 64]
|
||||
out_channels: 64
|
||||
|
||||
|
||||
optimizer:
|
||||
paramwise_cfg:
|
||||
custom_keys:
|
||||
absolute_pos_embed:
|
||||
decay_mult: 0
|
||||
relative_position_bias_table:
|
||||
decay_mult: 0
|
||||
# encoders.camera.backbone:
|
||||
# lr_mult: 0.1
|
||||
|
||||
|
||||
# lr_config:
|
||||
# policy: cyclic
|
||||
# target_ratio: 5.0
|
||||
# cyclic_times: 1
|
||||
# step_ratio_up: 0.4
|
||||
|
||||
# momentum_config:
|
||||
# policy: cyclic
|
||||
# cyclic_times: 1
|
||||
# step_ratio_up: 0.4
|
||||
|
||||
data:
|
||||
samples_per_gpu: 4
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
|
||||
model:
|
||||
encoders:
|
||||
camera:
|
||||
vtransform:
|
||||
type: AwareDBEVDepth
|
||||
bevdepth_downsample: 16
|
||||
bevdepth_refine: false
|
||||
depth_loss_factor: 3.0
|
||||
use_points: radar
|
||||
depth_input: one-hot
|
||||
height_expand: true
|
||||
add_depth_features: true
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
model:
|
||||
encoders:
|
||||
camera:
|
||||
vtransform:
|
||||
type: AwareBEVDepth
|
||||
bevdepth_downsample: 16
|
||||
bevdepth_refine: false
|
||||
depth_loss_factor: 3.0
|
||||
in_channels: 512
|
||||
out_channels: 64
|
||||
feature_size: ${[image_size[0] // 16, image_size[1] // 16]}
|
||||
xbound: [-51.2, 51.2, 0.8]
|
||||
ybound: [-51.2, 51.2, 0.8]
|
||||
zbound: [-10.0, 10.0, 20.0]
|
||||
dbound: [1.0, 60.0, 1.0]
|
||||
downsample: 1
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
model:
|
||||
encoders:
|
||||
camera:
|
||||
backbone:
|
||||
type: ResNet
|
||||
depth: 50
|
||||
num_stages: 4
|
||||
out_indices: [0, 1, 2, 3]
|
||||
norm_cfg:
|
||||
type: BN2d
|
||||
requires_grad: true
|
||||
norm_eval: false
|
||||
init_cfg:
|
||||
type: Pretrained
|
||||
checkpoint: torchvision://resnet50
|
||||
neck:
|
||||
type: SECONDFPN
|
||||
in_channels: [256, 512, 1024, 2048]
|
||||
out_channels: [128, 128, 128, 128]
|
||||
upsample_strides: [0.25, 0.5, 1, 2]
|
||||
vtransform:
|
||||
type: LSSTransform
|
||||
in_channels: 512
|
||||
out_channels: 64
|
||||
image_size: ${image_size}
|
||||
feature_size: ${[image_size[0] // 16, image_size[1] // 16]}
|
||||
xbound: [-51.2, 51.2, 0.8]
|
||||
ybound: [-51.2, 51.2, 0.8]
|
||||
zbound: [-10.0, 10.0, 20.0]
|
||||
dbound: [1.0, 60.0, 1.0]
|
||||
downsample: 1
|
||||
decoder:
|
||||
backbone:
|
||||
type: GeneralizedResNet
|
||||
in_channels: 64
|
||||
blocks:
|
||||
- [2, 128, 2]
|
||||
- [2, 256, 2]
|
||||
- [2, 512, 1]
|
||||
neck:
|
||||
type: LSSFPN
|
||||
in_indices: [-1, 0]
|
||||
in_channels: [512, 128]
|
||||
out_channels: 256
|
||||
scale_factor: 2
|
||||
|
|
@ -82,4 +82,4 @@ momentum_config:
|
|||
step_ratio_up: 0.4
|
||||
|
||||
data:
|
||||
samples_per_gpu: 6
|
||||
samples_per_gpu: 4
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
|
||||
augment3d:
|
||||
scale: [0.95, 1.05]
|
||||
rotate: [-0.3925, 0.3925]
|
||||
|
|
@ -31,5 +32,3 @@ model:
|
|||
- [1.0]
|
||||
- [1.0, 1.0]
|
||||
- [2.5, 4.0]
|
||||
|
||||
lr_config: null
|
||||
|
|
|
|||
|
|
@ -1,3 +1,13 @@
|
|||
|
||||
# data:
|
||||
# train:
|
||||
# dataset:
|
||||
# ann_file: nuscenes_radar/nuscenes_radar_infos_train_radar.pkl
|
||||
# val:
|
||||
# ann_file: nuscenes_radar/nuscenes_radar_infos_val_radar.pkl
|
||||
# test:
|
||||
# ann_file: nuscenes_radar/nuscenes_radar_infos_val_radar.pkl
|
||||
|
||||
model:
|
||||
decoder:
|
||||
backbone:
|
||||
|
|
@ -21,6 +31,12 @@ optimizer:
|
|||
type: AdamW
|
||||
lr: 2.0e-4
|
||||
weight_decay: 0.01
|
||||
paramwise_cfg:
|
||||
custom_keys:
|
||||
absolute_pos_embed:
|
||||
decay_mult: 0
|
||||
relative_position_bias_table:
|
||||
decay_mult: 0
|
||||
|
||||
optimizer_config:
|
||||
grad_clip:
|
||||
|
|
@ -32,4 +48,4 @@ lr_config:
|
|||
warmup: linear
|
||||
warmup_iters: 500
|
||||
warmup_ratio: 0.33333333
|
||||
min_lr_ratio: 1.0e-3
|
||||
min_lr_ratio: 1.0e-3
|
||||
|
|
|
|||
|
|
@ -1,6 +1,236 @@
|
|||
radar_sweeps: 6
|
||||
radar_max_points: 2500
|
||||
radar_use_dims: [0, 1, 2, 5, 8, 9, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56]
|
||||
radar_compensate_velocity: true
|
||||
radar_filtering: none
|
||||
|
||||
radar_voxel_size: [0.8, 0.8, 8]
|
||||
image_size: [256, 704]
|
||||
radar_jitter: 0
|
||||
radar_normalize: false
|
||||
|
||||
model:
|
||||
type: BEVFusion
|
||||
encoders: null
|
||||
fuser: null
|
||||
heads:
|
||||
map: null
|
||||
|
||||
|
||||
train_pipeline:
|
||||
-
|
||||
type: LoadMultiViewImageFromFiles
|
||||
to_float32: true
|
||||
-
|
||||
type: LoadPointsFromFile
|
||||
coord_type: LIDAR
|
||||
load_dim: ${load_dim}
|
||||
use_dim: ${use_dim}
|
||||
reduce_beams: ${reduce_beams}
|
||||
load_augmented: ${load_augmented}
|
||||
-
|
||||
type: LoadPointsFromMultiSweeps
|
||||
sweeps_num: 0
|
||||
load_dim: ${load_dim}
|
||||
use_dim: ${use_dim}
|
||||
reduce_beams: ${reduce_beams}
|
||||
pad_empty_sweeps: true
|
||||
remove_close: true
|
||||
load_augmented: ${load_augmented}
|
||||
-
|
||||
type: LoadRadarPointsMultiSweeps
|
||||
load_dim: 18
|
||||
sweeps_num: ${radar_sweeps}
|
||||
use_dim: ${radar_use_dims}
|
||||
max_num: ${radar_max_points}
|
||||
compensate_velocity: ${radar_compensate_velocity}
|
||||
filtering: ${radar_filtering}
|
||||
normalize: ${radar_normalize}
|
||||
-
|
||||
type: LoadAnnotations3D
|
||||
with_bbox_3d: true
|
||||
with_label_3d: true
|
||||
with_attr_label: False
|
||||
-
|
||||
type: ObjectPaste
|
||||
stop_epoch: ${gt_paste_stop_epoch}
|
||||
db_sampler:
|
||||
dataset_root: ${dataset_root}
|
||||
info_path: ${'data/nuscenes/' + "nuscenes_dbinfos_train.pkl"}
|
||||
rate: 1.0
|
||||
prepare:
|
||||
filter_by_difficulty: [-1]
|
||||
filter_by_min_points:
|
||||
car: 5
|
||||
truck: 5
|
||||
bus: 5
|
||||
trailer: 5
|
||||
construction_vehicle: 5
|
||||
traffic_cone: 5
|
||||
barrier: 5
|
||||
motorcycle: 5
|
||||
bicycle: 5
|
||||
pedestrian: 5
|
||||
classes: ${object_classes}
|
||||
sample_groups:
|
||||
car: 2
|
||||
truck: 3
|
||||
construction_vehicle: 7
|
||||
bus: 4
|
||||
trailer: 6
|
||||
barrier: 2
|
||||
motorcycle: 6
|
||||
bicycle: 6
|
||||
pedestrian: 2
|
||||
traffic_cone: 2
|
||||
points_loader:
|
||||
type: LoadPointsFromFile
|
||||
coord_type: LIDAR
|
||||
load_dim: ${load_dim}
|
||||
use_dim: ${use_dim}
|
||||
reduce_beams: ${reduce_beams}
|
||||
-
|
||||
type: ImageAug3D
|
||||
final_dim: ${image_size}
|
||||
resize_lim: ${augment2d.resize[0]}
|
||||
bot_pct_lim: [0.0, 0.0]
|
||||
rot_lim: ${augment2d.rotate}
|
||||
rand_flip: true
|
||||
is_train: true
|
||||
-
|
||||
type: GlobalRotScaleTrans
|
||||
resize_lim: ${augment3d.scale}
|
||||
rot_lim: ${augment3d.rotate}
|
||||
trans_lim: ${augment3d.translate}
|
||||
is_train: true
|
||||
-
|
||||
type: RandomFlip3D
|
||||
-
|
||||
type: PointsRangeFilter
|
||||
point_cloud_range: ${point_cloud_range}
|
||||
-
|
||||
type: ObjectRangeFilter
|
||||
point_cloud_range: ${point_cloud_range}
|
||||
-
|
||||
type: ObjectNameFilter
|
||||
classes: ${object_classes}
|
||||
-
|
||||
type: ImageNormalize
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
-
|
||||
type: GridMask
|
||||
use_h: true
|
||||
use_w: true
|
||||
max_epoch: ${max_epochs}
|
||||
rotate: 1
|
||||
offset: false
|
||||
ratio: 0.5
|
||||
mode: 1
|
||||
prob: ${augment2d.gridmask.prob}
|
||||
fixed_prob: ${augment2d.gridmask.fixed_prob}
|
||||
-
|
||||
type: PointShuffle
|
||||
-
|
||||
type: DefaultFormatBundle3D
|
||||
classes: ${object_classes}
|
||||
-
|
||||
type: Collect3D
|
||||
keys:
|
||||
- img
|
||||
- points
|
||||
- radar
|
||||
- gt_bboxes_3d
|
||||
- gt_labels_3d
|
||||
meta_keys:
|
||||
- camera_intrinsics
|
||||
- camera2ego
|
||||
- lidar2ego
|
||||
- lidar2camera
|
||||
- lidar2image
|
||||
- camera2lidar
|
||||
- img_aug_matrix
|
||||
- lidar_aug_matrix
|
||||
-
|
||||
type: GTDepth
|
||||
keyframe_only: true
|
||||
|
||||
test_pipeline:
|
||||
-
|
||||
type: LoadMultiViewImageFromFiles
|
||||
to_float32: true
|
||||
-
|
||||
type: LoadPointsFromFile
|
||||
coord_type: LIDAR
|
||||
load_dim: ${load_dim}
|
||||
use_dim: ${use_dim}
|
||||
reduce_beams: ${reduce_beams}
|
||||
load_augmented: ${load_augmented}
|
||||
-
|
||||
type: LoadPointsFromMultiSweeps
|
||||
sweeps_num: 9
|
||||
load_dim: ${load_dim}
|
||||
use_dim: ${use_dim}
|
||||
reduce_beams: ${reduce_beams}
|
||||
pad_empty_sweeps: true
|
||||
remove_close: true
|
||||
load_augmented: ${load_augmented}
|
||||
-
|
||||
type: LoadRadarPointsMultiSweeps
|
||||
load_dim: 18
|
||||
sweeps_num: ${radar_sweeps}
|
||||
use_dim: ${radar_use_dims}
|
||||
max_num: ${radar_max_points}
|
||||
compensate_velocity: ${radar_compensate_velocity}
|
||||
filtering: ${radar_filtering}
|
||||
normalize: ${radar_normalize}
|
||||
-
|
||||
type: LoadAnnotations3D
|
||||
with_bbox_3d: true
|
||||
with_label_3d: true
|
||||
with_attr_label: False
|
||||
-
|
||||
type: ImageAug3D
|
||||
final_dim: ${image_size}
|
||||
resize_lim: ${augment2d.resize[1]}
|
||||
bot_pct_lim: [0.0, 0.0]
|
||||
rot_lim: [0.0, 0.0]
|
||||
rand_flip: false
|
||||
is_train: false
|
||||
-
|
||||
type: GlobalRotScaleTrans
|
||||
resize_lim: [1.0, 1.0]
|
||||
rot_lim: [0.0, 0.0]
|
||||
trans_lim: 0.0
|
||||
is_train: false
|
||||
-
|
||||
type: PointsRangeFilter
|
||||
point_cloud_range: ${point_cloud_range}
|
||||
-
|
||||
type: ImageNormalize
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
-
|
||||
type: DefaultFormatBundle3D
|
||||
classes: ${object_classes}
|
||||
-
|
||||
type: Collect3D
|
||||
keys:
|
||||
- img
|
||||
- points
|
||||
- radar
|
||||
- gt_bboxes_3d
|
||||
- gt_labels_3d
|
||||
#- gt_masks_bev
|
||||
meta_keys:
|
||||
- camera_intrinsics
|
||||
- camera2ego
|
||||
- lidar2ego
|
||||
- lidar2camera
|
||||
- lidar2image
|
||||
- camera2lidar
|
||||
- img_aug_matrix
|
||||
- lidar_aug_matrix
|
||||
-
|
||||
type: GTDepth
|
||||
keyframe_only: true
|
||||
|
|
@ -2,8 +2,9 @@ from .base_points import BasePoints
|
|||
from .cam_points import CameraPoints
|
||||
from .depth_points import DepthPoints
|
||||
from .lidar_points import LiDARPoints
|
||||
from .radar_points import RadarPoints
|
||||
|
||||
__all__ = ["BasePoints", "CameraPoints", "DepthPoints", "LiDARPoints"]
|
||||
__all__ = ["BasePoints", "CameraPoints", "DepthPoints", "LiDARPoints", "RadarPoints"]
|
||||
|
||||
|
||||
def get_points_type(points_type):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,124 @@
|
|||
from .base_points import BasePoints
|
||||
import torch
|
||||
|
||||
class RadarPoints(BasePoints):
|
||||
"""Points of instances in LIDAR coordinates.
|
||||
Args:
|
||||
tensor (torch.Tensor | np.ndarray | list): a N x points_dim matrix.
|
||||
points_dim (int): Number of the dimension of a point.
|
||||
Each row is (x, y, z). Default to 3.
|
||||
attribute_dims (dict): Dictionary to indicate the meaning of extra
|
||||
dimension. Default to None.
|
||||
Attributes:
|
||||
tensor (torch.Tensor): Float matrix of N x points_dim.
|
||||
points_dim (int): Integer indicating the dimension of a point.
|
||||
Each row is (x, y, z, ...).
|
||||
attribute_dims (bool): Dictionary to indicate the meaning of extra
|
||||
dimension. Default to None.
|
||||
rotation_axis (int): Default rotation axis for points rotation.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor, points_dim=3, attribute_dims=None):
|
||||
super(RadarPoints, self).__init__(
|
||||
tensor, points_dim=points_dim, attribute_dims=attribute_dims
|
||||
)
|
||||
self.rotation_axis = 2
|
||||
|
||||
def flip(self, bev_direction="horizontal"):
|
||||
"""Flip the boxes in BEV along given BEV direction."""
|
||||
if bev_direction == "horizontal":
|
||||
self.tensor[:, 1] = -self.tensor[:, 1]
|
||||
self.tensor[:, 4] = -self.tensor[:, 4]
|
||||
elif bev_direction == "vertical":
|
||||
self.tensor[:, 0] = -self.tensor[:, 0]
|
||||
self.tensor[:, 3] = -self.tensor[:, 3]
|
||||
|
||||
def jitter(self, amount):
|
||||
jitter_noise = torch.randn(self.tensor.shape[0], 3)
|
||||
jitter_noise *= amount
|
||||
self.tensor[:, :3] += jitter_noise
|
||||
|
||||
def scale(self, scale_factor):
|
||||
"""Scale the points with horizontal and vertical scaling factors.
|
||||
Args:
|
||||
scale_factors (float): Scale factors to scale the points.
|
||||
"""
|
||||
self.tensor[:, :3] *= scale_factor
|
||||
self.tensor[:, 3:5] *= scale_factor
|
||||
|
||||
def rotate(self, rotation, axis=None):
|
||||
"""Rotate points with the given rotation matrix or angle.
|
||||
Args:
|
||||
rotation (float, np.ndarray, torch.Tensor): Rotation matrix
|
||||
or angle.
|
||||
axis (int): Axis to rotate at. Defaults to None.
|
||||
"""
|
||||
if not isinstance(rotation, torch.Tensor):
|
||||
rotation = self.tensor.new_tensor(rotation)
|
||||
assert (
|
||||
rotation.shape == torch.Size([3, 3]) or rotation.numel() == 1
|
||||
), f"invalid rotation shape {rotation.shape}"
|
||||
|
||||
if axis is None:
|
||||
axis = self.rotation_axis
|
||||
|
||||
if rotation.numel() == 1:
|
||||
rot_sin = torch.sin(rotation)
|
||||
rot_cos = torch.cos(rotation)
|
||||
if axis == 1:
|
||||
rot_mat_T = rotation.new_tensor(
|
||||
[[rot_cos, 0, -rot_sin], [0, 1, 0], [rot_sin, 0, rot_cos]]
|
||||
)
|
||||
elif axis == 2 or axis == -1:
|
||||
rot_mat_T = rotation.new_tensor(
|
||||
[[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]]
|
||||
)
|
||||
elif axis == 0:
|
||||
rot_mat_T = rotation.new_tensor(
|
||||
[[0, rot_cos, -rot_sin], [0, rot_sin, rot_cos], [1, 0, 0]]
|
||||
)
|
||||
else:
|
||||
raise ValueError("axis should in range")
|
||||
rot_mat_T = rot_mat_T.T
|
||||
elif rotation.numel() == 9:
|
||||
rot_mat_T = rotation
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T
|
||||
self.tensor[:, 3:5] = self.tensor[:, 3:5] @ rot_mat_T[:2, :2]
|
||||
|
||||
return rot_mat_T
|
||||
|
||||
def in_range_bev(self, point_range):
|
||||
"""Check whether the points are in the given range.
|
||||
Args:
|
||||
point_range (list | torch.Tensor): The range of point
|
||||
in order of (x_min, y_min, x_max, y_max).
|
||||
Returns:
|
||||
torch.Tensor: Indicating whether each point is inside \
|
||||
the reference range.
|
||||
"""
|
||||
in_range_flags = (
|
||||
(self.tensor[:, 0] > point_range[0])
|
||||
& (self.tensor[:, 1] > point_range[1])
|
||||
& (self.tensor[:, 0] < point_range[2])
|
||||
& (self.tensor[:, 1] < point_range[3])
|
||||
)
|
||||
return in_range_flags
|
||||
|
||||
def convert_to(self, dst, rt_mat=None):
|
||||
"""Convert self to ``dst`` mode.
|
||||
Args:
|
||||
dst (:obj:`CoordMode`): The target Point mode.
|
||||
rt_mat (np.ndarray | torch.Tensor): The rotation and translation
|
||||
matrix between different coordinates. Defaults to None.
|
||||
The conversion from `src` coordinates to `dst` coordinates
|
||||
usually comes along the change of sensors, e.g., from camera
|
||||
to LiDAR. This requires a transformation matrix.
|
||||
Returns:
|
||||
:obj:`BasePoints`: The converted point of the same type \
|
||||
in the `dst` mode.
|
||||
"""
|
||||
from mmdet3d.core.bbox import Coord3DMode
|
||||
|
||||
return Coord3DMode.convert_point(point=self, src=Coord3DMode.LIDAR, dst=dst, rt_mat=rt_mat)
|
||||
|
|
@ -215,9 +215,15 @@ class NuScenesDataset(Custom3DDataset):
|
|||
lidar_path=info["lidar_path"],
|
||||
sweeps=info["sweeps"],
|
||||
timestamp=info["timestamp"],
|
||||
location=info["location"],
|
||||
location=info.get('location', None),
|
||||
radar=info.get('radars', None),
|
||||
)
|
||||
|
||||
if data['location'] is None:
|
||||
data.pop('location')
|
||||
if data['radar'] is None:
|
||||
data.pop('radar')
|
||||
|
||||
# ego to global transform
|
||||
ego2global = np.eye(4).astype(np.float32)
|
||||
ego2global[:3, :3] = Quaternion(info["ego2global_rotation"]).rotation_matrix
|
||||
|
|
@ -253,7 +259,7 @@ class NuScenesDataset(Custom3DDataset):
|
|||
|
||||
# camera intrinsics
|
||||
camera_intrinsics = np.eye(4).astype(np.float32)
|
||||
camera_intrinsics[:3, :3] = camera_info["camera_intrinsics"]
|
||||
camera_intrinsics[:3, :3] = camera_info["cam_intrinsic"]
|
||||
data["camera_intrinsics"].append(camera_intrinsics)
|
||||
|
||||
# lidar to image transform
|
||||
|
|
|
|||
|
|
@ -52,6 +52,9 @@ class DefaultFormatBundle3D:
|
|||
assert isinstance(results["points"], BasePoints)
|
||||
results["points"] = DC(results["points"].tensor)
|
||||
|
||||
if "radar" in results:
|
||||
results["radar"] = DC(results["radar"].tensor)
|
||||
|
||||
for key in ["voxels", "coors", "voxel_centers", "num_points"]:
|
||||
if key not in results:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -3,15 +3,18 @@ from typing import Any, Dict, Tuple
|
|||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmdet3d.core.points import RadarPoints
|
||||
from nuscenes.utils.data_classes import RadarPointCloud
|
||||
from nuscenes.map_expansion.map_api import NuScenesMap
|
||||
from nuscenes.map_expansion.map_api import locations as LOCATIONS
|
||||
from PIL import Image
|
||||
|
||||
|
||||
from mmdet3d.core.points import BasePoints, get_points_type
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from mmdet.datasets.pipelines import LoadAnnotations
|
||||
|
||||
import torch
|
||||
|
||||
from .loading_utils import load_augmented_point_cloud, reduce_LiDAR_beams
|
||||
|
||||
|
||||
|
|
@ -180,6 +183,7 @@ class LoadPointsFromMultiSweeps:
|
|||
cloud arrays.
|
||||
"""
|
||||
points = results["points"]
|
||||
points = points[:, self.use_dim]
|
||||
points.tensor[:, 4] = 0
|
||||
sweep_points_list = [points]
|
||||
ts = results["timestamp"] / 1e6
|
||||
|
|
@ -216,6 +220,7 @@ class LoadPointsFromMultiSweeps:
|
|||
|
||||
if self.remove_close:
|
||||
points_sweep = self._remove_close(points_sweep)
|
||||
points_sweep = points_sweep[:, self.use_dim]
|
||||
sweep_ts = sweep["timestamp"] / 1e6
|
||||
points_sweep[:, :3] = (
|
||||
points_sweep[:, :3] @ sweep["sensor2lidar_rotation"].T
|
||||
|
|
@ -226,7 +231,7 @@ class LoadPointsFromMultiSweeps:
|
|||
sweep_points_list.append(points_sweep)
|
||||
|
||||
points = points.cat(sweep_points_list)
|
||||
points = points[:, self.use_dim]
|
||||
|
||||
results["points"] = points
|
||||
return results
|
||||
|
||||
|
|
@ -556,3 +561,234 @@ class LoadAnnotations3D(LoadAnnotations):
|
|||
results = self._load_attr_labels(results)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class NormalizePointFeatures:
|
||||
def __call__(self, results):
|
||||
points = results["points"]
|
||||
points.tensor[:, 3] = torch.tanh(points.tensor[:, 3])
|
||||
results["points"] = points
|
||||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class LoadRadarPointsMultiSweeps(object):
|
||||
"""Load radar points from multiple sweeps.
|
||||
This is usually used for nuScenes dataset to utilize previous sweeps.
|
||||
Args:
|
||||
sweeps_num (int): Number of sweeps. Defaults to 10.
|
||||
load_dim (int): Dimension number of the loaded points. Defaults to 5.
|
||||
use_dim (list[int]): Which dimension to use. Defaults to [0, 1, 2, 4].
|
||||
file_client_args (dict): Config dict of file clients, refer to
|
||||
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
|
||||
for more details. Defaults to dict(backend='disk').
|
||||
pad_empty_sweeps (bool): Whether to repeat keyframe when
|
||||
sweeps is empty. Defaults to False.
|
||||
remove_close (bool): Whether to remove close points.
|
||||
Defaults to False.
|
||||
test_mode (bool): If test_model=True used for testing, it will not
|
||||
randomly sample sweeps but select the nearest N frames.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
load_dim=18,
|
||||
use_dim=[0, 1, 2, 3, 4],
|
||||
sweeps_num=3,
|
||||
file_client_args=dict(backend='disk'),
|
||||
max_num=300,
|
||||
pc_range=[-51.2, -51.2, -5.0, 51.2, 51.2, 3.0],
|
||||
compensate_velocity=False,
|
||||
normalize_dims=[(3, 0, 50), (4, -100, 100), (5, -100, 100)],
|
||||
filtering='default',
|
||||
normalize=False,
|
||||
test_mode=False):
|
||||
self.load_dim = load_dim
|
||||
self.use_dim = use_dim
|
||||
self.sweeps_num = sweeps_num
|
||||
self.file_client_args = file_client_args.copy()
|
||||
self.file_client = None
|
||||
self.max_num = max_num
|
||||
self.test_mode = test_mode
|
||||
self.pc_range = pc_range
|
||||
self.compensate_velocity = compensate_velocity
|
||||
self.normalize_dims = normalize_dims
|
||||
self.filtering = filtering
|
||||
self.normalize = normalize
|
||||
|
||||
self.encoding = [
|
||||
(3, 'one-hot', 8), # dynprop
|
||||
(11, 'one-hot', 5), # ambig_state
|
||||
(14, 'one-hot', 18), # invalid_state
|
||||
(15, 'ordinal', 7), # pdh
|
||||
(0, 'nusc-filter', 1) # binary feature: 1 if nusc would have filtered it out
|
||||
]
|
||||
|
||||
|
||||
def perform_encodings(self, points, encoding):
|
||||
for idx, encoding_type, encoding_dims in self.encoding:
|
||||
|
||||
assert encoding_type in ['one-hot', 'ordinal', 'nusc-filter']
|
||||
|
||||
feat = points[:, idx]
|
||||
|
||||
if encoding_type == 'one-hot':
|
||||
encoding = np.zeros((points.shape[0], encoding_dims))
|
||||
encoding[np.arange(feat.shape[0]), np.rint(feat).astype(int)] = 1
|
||||
if encoding_type == 'ordinal':
|
||||
encoding = np.zeros((points.shape[0], encoding_dims))
|
||||
for i in range(encoding_dims):
|
||||
encoding[:, i] = (np.rint(feat) > i).astype(int)
|
||||
if encoding_type == 'nusc-filter':
|
||||
encoding = np.zeros((points.shape[0], encoding_dims))
|
||||
mask1 = (points[:, 14] == 0)
|
||||
mask2 = (points[:, 3] < 7)
|
||||
mask3 = (points[:, 11] == 3)
|
||||
|
||||
encoding[mask1 & mask2 & mask3, 0] = 1
|
||||
|
||||
|
||||
points = np.concatenate([points, encoding], axis=1)
|
||||
return points
|
||||
|
||||
def _load_points(self, pts_filename):
|
||||
"""Private function to load point clouds data.
|
||||
Args:
|
||||
pts_filename (str): Filename of point clouds data.
|
||||
Returns:
|
||||
np.ndarray: An array containing point clouds data.
|
||||
[N, 18]
|
||||
"""
|
||||
|
||||
invalid_states, dynprop_states, ambig_states = {
|
||||
'default': ([0], range(7), [3]),
|
||||
'none': (range(18), range(8), range(5)),
|
||||
}[self.filtering]
|
||||
|
||||
radar_obj = RadarPointCloud.from_file(
|
||||
pts_filename,
|
||||
invalid_states, dynprop_states, ambig_states
|
||||
)
|
||||
|
||||
#[18, N]
|
||||
points = radar_obj.points
|
||||
|
||||
return points.transpose().astype(np.float32)
|
||||
|
||||
|
||||
def _pad_or_drop(self, points):
|
||||
'''
|
||||
points: [N, 18]
|
||||
'''
|
||||
|
||||
num_points = points.shape[0]
|
||||
|
||||
if num_points == self.max_num:
|
||||
masks = np.ones((num_points, 1),
|
||||
dtype=points.dtype)
|
||||
|
||||
return points, masks
|
||||
|
||||
if num_points > self.max_num:
|
||||
points = np.random.permutation(points)[:self.max_num, :]
|
||||
masks = np.ones((self.max_num, 1),
|
||||
dtype=points.dtype)
|
||||
|
||||
return points, masks
|
||||
|
||||
if num_points < self.max_num:
|
||||
zeros = np.zeros((self.max_num - num_points, points.shape[1]),
|
||||
dtype=points.dtype)
|
||||
masks = np.ones((num_points, 1),
|
||||
dtype=points.dtype)
|
||||
|
||||
points = np.concatenate((points, zeros), axis=0)
|
||||
masks = np.concatenate((masks, zeros.copy()[:, [0]]), axis=0)
|
||||
|
||||
return points, masks
|
||||
|
||||
def normalize_feats(self, points, normalize_dims):
|
||||
for dim, min, max in normalize_dims:
|
||||
points[:, dim] -= min
|
||||
points[:, dim] /= (max-min)
|
||||
return points
|
||||
|
||||
def __call__(self, results):
|
||||
"""Call function to load multi-sweep point clouds from files.
|
||||
Args:
|
||||
results (dict): Result dict containing multi-sweep point cloud \
|
||||
filenames.
|
||||
Returns:
|
||||
dict: The result dict containing the multi-sweep points data. \
|
||||
Added key and value are described below.
|
||||
- points (np.ndarray | :obj:`BasePoints`): Multi-sweep point \
|
||||
cloud arrays.
|
||||
"""
|
||||
radars_dict = results['radar']
|
||||
|
||||
points_sweep_list = []
|
||||
for key, sweeps in radars_dict.items():
|
||||
if len(sweeps) < self.sweeps_num:
|
||||
idxes = list(range(len(sweeps)))
|
||||
else:
|
||||
idxes = list(range(self.sweeps_num))
|
||||
|
||||
ts = sweeps[0]['timestamp'] * 1e-6
|
||||
for idx in idxes:
|
||||
sweep = sweeps[idx]
|
||||
|
||||
points_sweep = self._load_points(sweep['data_path'])
|
||||
points_sweep = np.copy(points_sweep).reshape(-1, self.load_dim)
|
||||
|
||||
timestamp = sweep['timestamp'] * 1e-6
|
||||
time_diff = ts - timestamp
|
||||
time_diff = np.ones((points_sweep.shape[0], 1)) * time_diff
|
||||
|
||||
# velocity compensated by the ego motion in sensor frame
|
||||
velo_comp = points_sweep[:, 8:10]
|
||||
velo_comp = np.concatenate(
|
||||
(velo_comp, np.zeros((velo_comp.shape[0], 1))), 1)
|
||||
velo_comp = velo_comp @ sweep['sensor2lidar_rotation'].T
|
||||
velo_comp = velo_comp[:, :2]
|
||||
|
||||
# velocity in sensor frame
|
||||
velo = points_sweep[:, 6:8]
|
||||
velo = np.concatenate(
|
||||
(velo, np.zeros((velo.shape[0], 1))), 1)
|
||||
velo = velo @ sweep['sensor2lidar_rotation'].T
|
||||
velo = velo[:, :2]
|
||||
|
||||
points_sweep[:, :3] = points_sweep[:, :3] @ sweep[
|
||||
'sensor2lidar_rotation'].T
|
||||
points_sweep[:, :3] += sweep['sensor2lidar_translation']
|
||||
|
||||
if self.compensate_velocity:
|
||||
points_sweep[:, :2] += velo_comp * time_diff
|
||||
|
||||
points_sweep_ = np.concatenate(
|
||||
[points_sweep[:, :6], velo,
|
||||
velo_comp, points_sweep[:, 10:],
|
||||
time_diff], axis=1)
|
||||
|
||||
# current format is x y z dyn_prop id rcs vx vy vx_comp vy_comp is_quality_valid ambig_state x_rms y_rms invalid_state pdh0 vx_rms vy_rms timestamp
|
||||
points_sweep_list.append(points_sweep_)
|
||||
|
||||
points = np.concatenate(points_sweep_list, axis=0)
|
||||
points = self.perform_encodings(points, self.encoding)
|
||||
points = points[:, self.use_dim]
|
||||
|
||||
if self.normalize:
|
||||
points = self.normalize_feats(points, self.normalize_dims)
|
||||
|
||||
points = RadarPoints(
|
||||
points, points_dim=points.shape[-1], attribute_dims=None
|
||||
)
|
||||
|
||||
results['radar'] = points
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
"""str: Return a string that describes the module."""
|
||||
return f'{self.__class__.__name__}(sweeps_num={self.sweeps_num})'
|
||||
|
|
@ -22,6 +22,78 @@ from ..builder import OBJECTSAMPLERS
|
|||
from .utils import noise_per_object_v3_
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class GTDepth:
|
||||
def __init__(self, keyframe_only=False):
|
||||
self.keyframe_only = keyframe_only
|
||||
|
||||
def __call__(self, data):
|
||||
sensor2ego = data['camera2ego'].data
|
||||
cam_intrinsic = data['camera_intrinsics'].data
|
||||
img_aug_matrix = data['img_aug_matrix'].data
|
||||
bev_aug_matrix = data['lidar_aug_matrix'].data
|
||||
lidar2ego = data['lidar2ego'].data
|
||||
camera2lidar = data['camera2lidar'].data
|
||||
lidar2image = data['lidar2image'].data
|
||||
|
||||
rots = sensor2ego[..., :3, :3]
|
||||
trans = sensor2ego[..., :3, 3]
|
||||
intrins = cam_intrinsic[..., :3, :3]
|
||||
post_rots = img_aug_matrix[..., :3, :3]
|
||||
post_trans = img_aug_matrix[..., :3, 3]
|
||||
lidar2ego_rots = lidar2ego[..., :3, :3]
|
||||
lidar2ego_trans = lidar2ego[..., :3, 3]
|
||||
camera2lidar_rots = camera2lidar[..., :3, :3]
|
||||
camera2lidar_trans = camera2lidar[..., :3, 3]
|
||||
|
||||
points = data['points'].data
|
||||
img = data['img'].data
|
||||
|
||||
if self.keyframe_only:
|
||||
points = points[points[:, 4] == 0]
|
||||
|
||||
batch_size = len(points)
|
||||
depth = torch.zeros(img.shape[0], *img.shape[-2:]) #.to(points[0].device)
|
||||
|
||||
# for b in range(batch_size):
|
||||
cur_coords = points[:, :3]
|
||||
|
||||
# inverse aug
|
||||
cur_coords -= bev_aug_matrix[:3, 3]
|
||||
cur_coords = torch.inverse(bev_aug_matrix[:3, :3]).matmul(
|
||||
cur_coords.transpose(1, 0)
|
||||
)
|
||||
# lidar2image
|
||||
cur_coords = lidar2image[:, :3, :3].matmul(cur_coords)
|
||||
cur_coords += lidar2image[:, :3, 3].reshape(-1, 3, 1)
|
||||
# get 2d coords
|
||||
dist = cur_coords[:, 2, :]
|
||||
cur_coords[:, 2, :] = torch.clamp(cur_coords[:, 2, :], 1e-5, 1e5)
|
||||
cur_coords[:, :2, :] /= cur_coords[:, 2:3, :]
|
||||
|
||||
# imgaug
|
||||
cur_coords = img_aug_matrix[:, :3, :3].matmul(cur_coords)
|
||||
cur_coords += img_aug_matrix[:, :3, 3].reshape(-1, 3, 1)
|
||||
cur_coords = cur_coords[:, :2, :].transpose(1, 2)
|
||||
|
||||
# normalize coords for grid sample
|
||||
cur_coords = cur_coords[..., [1, 0]]
|
||||
|
||||
on_img = (
|
||||
(cur_coords[..., 0] < img.shape[2])
|
||||
& (cur_coords[..., 0] >= 0)
|
||||
& (cur_coords[..., 1] < img.shape[3])
|
||||
& (cur_coords[..., 1] >= 0)
|
||||
)
|
||||
for c in range(on_img.shape[0]):
|
||||
masked_coords = cur_coords[c, on_img[c]].long()
|
||||
masked_dist = dist[c, on_img[c]]
|
||||
depth[c, masked_coords[:, 0], masked_coords[:, 1]] = masked_dist
|
||||
|
||||
data['depths'] = depth
|
||||
return data
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class ImageAug3D:
|
||||
def __init__(
|
||||
|
|
@ -142,6 +214,11 @@ class GlobalRotScaleTrans:
|
|||
data["points"].translate(translation)
|
||||
data["points"].scale(scale)
|
||||
|
||||
if "radar" in data:
|
||||
data["radar"].rotate(-theta)
|
||||
data["radar"].translate(translation)
|
||||
data["radar"].scale(scale)
|
||||
|
||||
gt_boxes = data["gt_bboxes_3d"]
|
||||
rotation = rotation @ gt_boxes.rotate(theta).numpy()
|
||||
gt_boxes.translate(translation)
|
||||
|
|
@ -254,6 +331,8 @@ class RandomFlip3D:
|
|||
rotation = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) @ rotation
|
||||
if "points" in data:
|
||||
data["points"].flip("horizontal")
|
||||
if "radar" in data:
|
||||
data["radar"].flip("horizontal")
|
||||
if "gt_bboxes_3d" in data:
|
||||
data["gt_bboxes_3d"].flip("horizontal")
|
||||
if "gt_masks_bev" in data:
|
||||
|
|
@ -263,6 +342,8 @@ class RandomFlip3D:
|
|||
rotation = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) @ rotation
|
||||
if "points" in data:
|
||||
data["points"].flip("vertical")
|
||||
if "radar" in data:
|
||||
data["radar"].flip("vertical")
|
||||
if "gt_bboxes_3d" in data:
|
||||
data["gt_bboxes_3d"].flip("vertical")
|
||||
if "gt_masks_bev" in data:
|
||||
|
|
@ -522,6 +603,14 @@ class PointsRangeFilter:
|
|||
points_mask = points.in_range_3d(self.pcd_range)
|
||||
clean_points = points[points_mask]
|
||||
data["points"] = clean_points
|
||||
|
||||
if "radar" in data:
|
||||
radar = data["radar"]
|
||||
# radar_mask = radar.in_range_3d(self.pcd_range)
|
||||
radar_mask = radar.in_range_bev([-55.0, -55.0, 55.0, 55.0])
|
||||
clean_radar = radar[radar_mask]
|
||||
data["radar"] = clean_radar
|
||||
|
||||
return data
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,4 +5,5 @@ from .second import *
|
|||
from .sparse_encoder import *
|
||||
from .pillar_encoder import *
|
||||
from .vovnet import *
|
||||
from .dla import *
|
||||
from .dla import *
|
||||
from .radar_encoder import *
|
||||
|
|
@ -0,0 +1,230 @@
|
|||
from torch import nn
|
||||
|
||||
from typing import Any, Dict
|
||||
from functools import cached_property
|
||||
|
||||
import torch
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
from mmcv.cnn.resnet import make_res_layer, BasicBlock
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from mmdet3d.models.builder import build_backbone
|
||||
from mmdet.models import BACKBONES
|
||||
from torchvision.utils import save_image
|
||||
from mmdet3d.ops import feature_decorator
|
||||
from mmcv.cnn.bricks.non_local import NonLocal2d
|
||||
|
||||
from flash_attn.flash_attention import FlashMHA
|
||||
|
||||
|
||||
__all__ = ["RadarFeatureNet", "RadarEncoder"]
|
||||
|
||||
|
||||
def get_paddings_indicator(actual_num, max_num, axis=0):
|
||||
"""Create boolean mask by actually number of a padded tensor.
|
||||
Args:
|
||||
actual_num ([type]): [description]
|
||||
max_num ([type]): [description]
|
||||
Returns:
|
||||
[type]: [description]
|
||||
"""
|
||||
|
||||
actual_num = torch.unsqueeze(actual_num, axis + 1)
|
||||
# tiled_actual_num: [N, M, 1]
|
||||
max_num_shape = [1] * len(actual_num.shape)
|
||||
max_num_shape[axis + 1] = -1
|
||||
max_num = torch.arange(max_num, dtype=torch.int, device=actual_num.device).view(
|
||||
max_num_shape
|
||||
)
|
||||
# tiled_actual_num: [[3,3,3,3,3], [4,4,4,4,4], [2,2,2,2,2]]
|
||||
# tiled_max_num: [[0,1,2,3,4], [0,1,2,3,4], [0,1,2,3,4]]
|
||||
paddings_indicator = actual_num.int() > max_num
|
||||
# paddings_indicator shape: [batch_size, max_num]
|
||||
return paddings_indicator
|
||||
|
||||
|
||||
class RFNLayer(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, norm_cfg=None, last_layer=False):
|
||||
"""
|
||||
Pillar Feature Net Layer.
|
||||
The Pillar Feature Net could be composed of a series of these layers, but the PointPillars paper results only
|
||||
used a single PFNLayer. This layer performs a similar role as second.pytorch.voxelnet.VFELayer.
|
||||
:param in_channels: <int>. Number of input channels.
|
||||
:param out_channels: <int>. Number of output channels.
|
||||
:param last_layer: <bool>. If last_layer, there is no concatenation of features.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.name = "RFNLayer"
|
||||
self.last_vfe = last_layer
|
||||
|
||||
self.units = out_channels
|
||||
|
||||
if norm_cfg is None:
|
||||
norm_cfg = dict(type="BN1d", eps=1e-3, momentum=0.01)
|
||||
self.norm_cfg = norm_cfg
|
||||
|
||||
self.linear = nn.Linear(in_channels, self.units, bias=False)
|
||||
self.norm = build_norm_layer(self.norm_cfg, self.units)[1]
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self.linear(inputs)
|
||||
torch.backends.cudnn.enabled = False
|
||||
x = self.norm(x.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
|
||||
torch.backends.cudnn.enabled = True
|
||||
x = F.relu(x)
|
||||
|
||||
if self.last_vfe:
|
||||
x_max = torch.max(x, dim=1, keepdim=True)[0]
|
||||
return x_max
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class RadarFeatureNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
feat_channels=(64,),
|
||||
with_distance=False,
|
||||
voxel_size=(0.2, 0.2, 4),
|
||||
point_cloud_range=(0, -40, -3, 70.4, 40, 1),
|
||||
norm_cfg=None,
|
||||
):
|
||||
"""
|
||||
Pillar Feature Net.
|
||||
The network prepares the pillar features and performs forward pass through PFNLayers. This net performs a
|
||||
similar role to SECOND's second.pytorch.voxelnet.VoxelFeatureExtractor.
|
||||
:param num_input_features: <int>. Number of input features, either x, y, z or x, y, z, r.
|
||||
:param num_filters: (<int>: N). Number of features in each of the N PFNLayers.
|
||||
:param with_distance: <bool>. Whether to include Euclidean distance to points.
|
||||
:param voxel_size: (<float>: 3). Size of voxels, only utilize x and y size.
|
||||
:param pc_range: (<float>: 6). Point cloud range, only utilize x and y min.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.name = "RadarFeatureNet"
|
||||
assert len(feat_channels) > 0
|
||||
|
||||
self.in_channels = in_channels
|
||||
in_channels += 2
|
||||
# in_channels += 5
|
||||
self._with_distance = with_distance
|
||||
self.export_onnx = False
|
||||
|
||||
# Create PillarFeatureNet layers
|
||||
feat_channels = [in_channels] + list(feat_channels)
|
||||
rfn_layers = []
|
||||
for i in range(len(feat_channels) - 1):
|
||||
in_filters = feat_channels[i]
|
||||
out_filters = feat_channels[i + 1]
|
||||
if i < len(feat_channels) - 2:
|
||||
last_layer = False
|
||||
else:
|
||||
last_layer = True
|
||||
rfn_layers.append(
|
||||
RFNLayer(
|
||||
in_filters, out_filters, norm_cfg=norm_cfg, last_layer=last_layer
|
||||
)
|
||||
)
|
||||
self.rfn_layers = nn.ModuleList(rfn_layers)
|
||||
|
||||
# Need pillar (voxel) size and x/y offset in order to calculate pillar offset
|
||||
self.vx = voxel_size[0]
|
||||
self.vy = voxel_size[1]
|
||||
self.x_offset = self.vx / 2 + point_cloud_range[0]
|
||||
self.y_offset = self.vy / 2 + point_cloud_range[1]
|
||||
self.pc_range = point_cloud_range
|
||||
|
||||
def forward(self, features, num_voxels, coors):
|
||||
|
||||
if not self.export_onnx:
|
||||
dtype = features.dtype
|
||||
|
||||
# Find distance of x, y, and z from cluster center
|
||||
points_mean = features[:, :, :3].sum(dim=1, keepdim=True) / num_voxels.type_as(
|
||||
features
|
||||
).view(-1, 1, 1)
|
||||
f_cluster = features[:, :, :3] - points_mean
|
||||
|
||||
f_center = torch.zeros_like(features[:, :, :2])
|
||||
f_center[:, :, 0] = features[:, :, 0] - (
|
||||
coors[:, 1].to(dtype).unsqueeze(1) * self.vx + self.x_offset
|
||||
)
|
||||
f_center[:, :, 1] = features[:, :, 1] - (
|
||||
coors[:, 2].to(dtype).unsqueeze(1) * self.vy + self.y_offset
|
||||
)
|
||||
|
||||
# print(self.pc_range) [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
|
||||
# normalize x,y,z to [0, 1]
|
||||
features[:, :, 0:1] = (features[:, :, 0:1] - self.pc_range[0]) / (self.pc_range[3] - self.pc_range[0])
|
||||
features[:, :, 1:2] = (features[:, :, 1:2] - self.pc_range[1]) / (self.pc_range[4] - self.pc_range[1])
|
||||
features[:, :, 2:3] = (features[:, :, 2:3] - self.pc_range[2]) / (self.pc_range[5] - self.pc_range[2])
|
||||
|
||||
# Combine together feature decorations
|
||||
features_ls = [features, f_center]
|
||||
features = torch.cat(features_ls, dim=-1)
|
||||
|
||||
# The feature decorations were calculated without regard to whether pillar was empty. Need to ensure that
|
||||
# empty pillars remain set to zeros.
|
||||
voxel_count = features.shape[1]
|
||||
mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
|
||||
mask = torch.unsqueeze(mask, -1).type_as(features)
|
||||
features *= mask
|
||||
features = torch.nan_to_num(features)
|
||||
else:
|
||||
features = feature_decorator(features, num_voxels, coors, self.vx, self.vy, self.x_offset, self.y_offset, True, False, True)
|
||||
|
||||
# Forward pass through PFNLayers
|
||||
for rfn in self.rfn_layers:
|
||||
features = rfn(features)
|
||||
|
||||
return features.squeeze()
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class RadarEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
pts_voxel_encoder: Dict[str, Any],
|
||||
pts_middle_encoder: Dict[str, Any],
|
||||
pts_transformer_encoder=None,
|
||||
pts_bev_encoder=None,
|
||||
post_scatter=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.pts_voxel_encoder = build_backbone(pts_voxel_encoder)
|
||||
self.pts_middle_encoder = build_backbone(pts_middle_encoder)
|
||||
self.pts_transformer_encoder = build_backbone(pts_transformer_encoder) if pts_transformer_encoder is not None else None
|
||||
self.pts_bev_encoder = build_backbone(pts_bev_encoder) if pts_bev_encoder is not None else None
|
||||
self.post_scatter = build_backbone(post_scatter) if post_scatter is not None else None
|
||||
|
||||
def forward(self, feats, coords, batch_size, sizes, img_features=None):
|
||||
x = self.pts_voxel_encoder(feats, sizes, coords)
|
||||
|
||||
if self.pts_transformer_encoder is not None:
|
||||
x = self.pts_transformer_encoder(x, sizes, coords, batch_size)
|
||||
|
||||
x = self.pts_middle_encoder(x, coords, batch_size)
|
||||
|
||||
if self.post_scatter is not None:
|
||||
x = self.post_scatter(x, img_features)
|
||||
|
||||
if self.pts_bev_encoder is not None:
|
||||
x = self.pts_bev_encoder(x)
|
||||
|
||||
|
||||
return x
|
||||
|
||||
def visualize_pillars(self, feats, coords, sizes):
|
||||
nx, ny = 128, 128
|
||||
canvas = torch.zeros(
|
||||
nx*ny, dtype=sizes.dtype, device=sizes.device
|
||||
)
|
||||
indices = coords[:, 1] * ny + coords[:, 2]
|
||||
indices = indices.type(torch.long)
|
||||
canvas[indices] = sizes
|
||||
torch.save(canvas, 'sample_canvas')
|
||||
|
|
@ -55,6 +55,19 @@ class BEVFusion(Base3DFusionModel):
|
|||
)
|
||||
self.voxelize_reduce = encoders["lidar"].get("voxelize_reduce", True)
|
||||
|
||||
if encoders.get("radar") is not None:
|
||||
if encoders["radar"]["voxelize"].get("max_num_points", -1) > 0:
|
||||
voxelize_module = Voxelization(**encoders["radar"]["voxelize"])
|
||||
else:
|
||||
voxelize_module = DynamicScatter(**encoders["radar"]["voxelize"])
|
||||
self.encoders["radar"] = nn.ModuleDict(
|
||||
{
|
||||
"voxelize": voxelize_module,
|
||||
"backbone": build_backbone(encoders["radar"]["backbone"]),
|
||||
}
|
||||
)
|
||||
self.voxelize_reduce = encoders["radar"].get("voxelize_reduce", True)
|
||||
|
||||
if fuser is not None:
|
||||
self.fuser = build_fuser(fuser)
|
||||
else:
|
||||
|
|
@ -79,6 +92,10 @@ class BEVFusion(Base3DFusionModel):
|
|||
if heads[name] is not None:
|
||||
self.loss_scale[name] = 1.0
|
||||
|
||||
# If the camera's vtransform is a BEVDepth version, then we're using depth loss.
|
||||
self.use_depth_loss = ((encoders.get('camera', {}) or {}).get('vtransform', {}) or {}).get('type', '') in ['BEVDepth', 'AwareBEVDepth', 'DBEVDepth', 'AwareDBEVDepth']
|
||||
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self) -> None:
|
||||
|
|
@ -89,6 +106,7 @@ class BEVFusion(Base3DFusionModel):
|
|||
self,
|
||||
x,
|
||||
points,
|
||||
radar_points,
|
||||
camera2ego,
|
||||
lidar2ego,
|
||||
lidar2camera,
|
||||
|
|
@ -98,6 +116,7 @@ class BEVFusion(Base3DFusionModel):
|
|||
img_aug_matrix,
|
||||
lidar_aug_matrix,
|
||||
img_metas,
|
||||
gt_depths=None,
|
||||
) -> torch.Tensor:
|
||||
B, N, C, H, W = x.size()
|
||||
x = x.view(B * N, C, H, W)
|
||||
|
|
@ -114,6 +133,7 @@ class BEVFusion(Base3DFusionModel):
|
|||
x = self.encoders["camera"]["vtransform"](
|
||||
x,
|
||||
points,
|
||||
radar_points,
|
||||
camera2ego,
|
||||
lidar2ego,
|
||||
lidar2camera,
|
||||
|
|
@ -123,21 +143,35 @@ class BEVFusion(Base3DFusionModel):
|
|||
img_aug_matrix,
|
||||
lidar_aug_matrix,
|
||||
img_metas,
|
||||
depth_loss=self.use_depth_loss,
|
||||
gt_depths=gt_depths,
|
||||
)
|
||||
return x
|
||||
|
||||
def extract_lidar_features(self, x) -> torch.Tensor:
|
||||
feats, coords, sizes = self.voxelize(x)
|
||||
|
||||
def extract_features(self, x, sensor) -> torch.Tensor:
|
||||
feats, coords, sizes = self.voxelize(x, sensor)
|
||||
batch_size = coords[-1, 0] + 1
|
||||
x = self.encoders["lidar"]["backbone"](feats, coords, batch_size, sizes=sizes)
|
||||
x = self.encoders[sensor]["backbone"](feats, coords, batch_size, sizes=sizes)
|
||||
return x
|
||||
|
||||
# def extract_lidar_features(self, x) -> torch.Tensor:
|
||||
# feats, coords, sizes = self.voxelize(x)
|
||||
# batch_size = coords[-1, 0] + 1
|
||||
# x = self.encoders["lidar"]["backbone"](feats, coords, batch_size, sizes=sizes)
|
||||
# return x
|
||||
|
||||
# def extract_radar_features(self, x) -> torch.Tensor:
|
||||
# feats, coords, sizes = self.radar_voxelize(x)
|
||||
# batch_size = coords[-1, 0] + 1
|
||||
# x = self.encoders["radar"]["backbone"](feats, coords, batch_size, sizes=sizes)
|
||||
# return x
|
||||
|
||||
@torch.no_grad()
|
||||
@force_fp32()
|
||||
def voxelize(self, points):
|
||||
def voxelize(self, points, sensor):
|
||||
feats, coords, sizes = [], [], []
|
||||
for k, res in enumerate(points):
|
||||
ret = self.encoders["lidar"]["voxelize"](res)
|
||||
ret = self.encoders[sensor]["voxelize"](res)
|
||||
if len(ret) == 3:
|
||||
# hard voxelize
|
||||
f, c, n = ret
|
||||
|
|
@ -162,6 +196,36 @@ class BEVFusion(Base3DFusionModel):
|
|||
|
||||
return feats, coords, sizes
|
||||
|
||||
# @torch.no_grad()
|
||||
# @force_fp32()
|
||||
# def radar_voxelize(self, points):
|
||||
# feats, coords, sizes = [], [], []
|
||||
# for k, res in enumerate(points):
|
||||
# ret = self.encoders["radar"]["voxelize"](res)
|
||||
# if len(ret) == 3:
|
||||
# # hard voxelize
|
||||
# f, c, n = ret
|
||||
# else:
|
||||
# assert len(ret) == 2
|
||||
# f, c = ret
|
||||
# n = None
|
||||
# feats.append(f)
|
||||
# coords.append(F.pad(c, (1, 0), mode="constant", value=k))
|
||||
# if n is not None:
|
||||
# sizes.append(n)
|
||||
|
||||
# feats = torch.cat(feats, dim=0)
|
||||
# coords = torch.cat(coords, dim=0)
|
||||
# if len(sizes) > 0:
|
||||
# sizes = torch.cat(sizes, dim=0)
|
||||
# if self.voxelize_reduce:
|
||||
# feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
|
||||
# -1, 1
|
||||
# )
|
||||
# feats = feats.contiguous()
|
||||
|
||||
# return feats, coords, sizes
|
||||
|
||||
@auto_fp16(apply_to=("img", "points"))
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -176,6 +240,8 @@ class BEVFusion(Base3DFusionModel):
|
|||
img_aug_matrix,
|
||||
lidar_aug_matrix,
|
||||
metas,
|
||||
depths,
|
||||
radar=None,
|
||||
gt_masks_bev=None,
|
||||
gt_bboxes_3d=None,
|
||||
gt_labels_3d=None,
|
||||
|
|
@ -196,6 +262,8 @@ class BEVFusion(Base3DFusionModel):
|
|||
img_aug_matrix,
|
||||
lidar_aug_matrix,
|
||||
metas,
|
||||
depths,
|
||||
radar,
|
||||
gt_masks_bev,
|
||||
gt_bboxes_3d,
|
||||
gt_labels_3d,
|
||||
|
|
@ -217,12 +285,15 @@ class BEVFusion(Base3DFusionModel):
|
|||
img_aug_matrix,
|
||||
lidar_aug_matrix,
|
||||
metas,
|
||||
depths=None,
|
||||
radar=None,
|
||||
gt_masks_bev=None,
|
||||
gt_bboxes_3d=None,
|
||||
gt_labels_3d=None,
|
||||
**kwargs,
|
||||
):
|
||||
features = []
|
||||
auxiliary_losses = {}
|
||||
for sensor in (
|
||||
self.encoders if self.training else list(self.encoders.keys())[::-1]
|
||||
):
|
||||
|
|
@ -230,6 +301,7 @@ class BEVFusion(Base3DFusionModel):
|
|||
feature = self.extract_camera_features(
|
||||
img,
|
||||
points,
|
||||
radar,
|
||||
camera2ego,
|
||||
lidar2ego,
|
||||
lidar2camera,
|
||||
|
|
@ -239,11 +311,17 @@ class BEVFusion(Base3DFusionModel):
|
|||
img_aug_matrix,
|
||||
lidar_aug_matrix,
|
||||
metas,
|
||||
gt_depths=depths,
|
||||
)
|
||||
if self.use_depth_loss:
|
||||
feature, auxiliary_losses['depth'] = feature[0], feature[-1]
|
||||
elif sensor == "lidar":
|
||||
feature = self.extract_lidar_features(points)
|
||||
feature = self.extract_features(points, sensor)
|
||||
elif sensor == "radar":
|
||||
feature = self.extract_features(radar, sensor)
|
||||
else:
|
||||
raise ValueError(f"unsupported sensor: {sensor}")
|
||||
|
||||
features.append(feature)
|
||||
|
||||
if not self.training:
|
||||
|
|
@ -276,6 +354,11 @@ class BEVFusion(Base3DFusionModel):
|
|||
outputs[f"loss/{type}/{name}"] = val * self.loss_scale[type]
|
||||
else:
|
||||
outputs[f"stats/{type}/{name}"] = val
|
||||
if self.use_depth_loss:
|
||||
if 'depth' in auxiliary_losses:
|
||||
outputs["loss/depth"] = auxiliary_losses['depth']
|
||||
else:
|
||||
raise ValueError('Use depth loss is true, but depth loss not found')
|
||||
return outputs
|
||||
else:
|
||||
outputs = [{} for _ in range(batch_size)]
|
||||
|
|
@ -303,3 +386,4 @@ class BEVFusion(Base3DFusionModel):
|
|||
else:
|
||||
raise ValueError(f"unsupported head: {type}")
|
||||
return outputs
|
||||
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
from .lss import *
|
||||
from .depth_lss import *
|
||||
from .aware_bevdepth import *
|
||||
|
|
@ -0,0 +1,698 @@
|
|||
from typing import Tuple
|
||||
|
||||
from mmcv.cnn import build_conv_layer
|
||||
from mmcv.runner import force_fp32
|
||||
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
|
||||
from mmdet3d.models.builder import VTRANSFORMS
|
||||
from mmdet.models.backbones.resnet import BasicBlock
|
||||
|
||||
from .base import BaseTransform, BaseDepthTransform
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["AwareBEVDepth"]
|
||||
|
||||
|
||||
class DepthRefinement(nn.Module):
|
||||
"""
|
||||
pixel cloud feature extraction
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, mid_channels, out_channels):
|
||||
super(DepthRefinement, self).__init__()
|
||||
|
||||
self.reduce_conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels,
|
||||
mid_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False),
|
||||
nn.BatchNorm2d(mid_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(mid_channels,
|
||||
mid_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False),
|
||||
nn.BatchNorm2d(mid_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(mid_channels,
|
||||
mid_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False),
|
||||
nn.BatchNorm2d(mid_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
self.out_conv = nn.Sequential(
|
||||
nn.Conv2d(mid_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=True),
|
||||
# nn.BatchNorm3d(out_channels),
|
||||
# nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
@autocast(False)
|
||||
def forward(self, x):
|
||||
x = self.reduce_conv(x)
|
||||
x = self.conv(x) + x
|
||||
x = self.out_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class _ASPPModule(nn.Module):
|
||||
def __init__(self, inplanes, planes, kernel_size, padding, dilation,
|
||||
BatchNorm):
|
||||
super(_ASPPModule, self).__init__()
|
||||
self.atrous_conv = nn.Conv2d(inplanes,
|
||||
planes,
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=False)
|
||||
self.bn = BatchNorm(planes)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self._init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.atrous_conv(x)
|
||||
x = self.bn(x)
|
||||
|
||||
return self.relu(x)
|
||||
|
||||
def _init_weight(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
torch.nn.init.kaiming_normal_(m.weight)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
class ASPP(nn.Module):
|
||||
def __init__(self, inplanes, mid_channels=256, BatchNorm=nn.BatchNorm2d):
|
||||
super(ASPP, self).__init__()
|
||||
|
||||
dilations = [1, 6, 12, 18]
|
||||
|
||||
self.aspp1 = _ASPPModule(inplanes,
|
||||
mid_channels,
|
||||
1,
|
||||
padding=0,
|
||||
dilation=dilations[0],
|
||||
BatchNorm=BatchNorm)
|
||||
self.aspp2 = _ASPPModule(inplanes,
|
||||
mid_channels,
|
||||
3,
|
||||
padding=dilations[1],
|
||||
dilation=dilations[1],
|
||||
BatchNorm=BatchNorm)
|
||||
self.aspp3 = _ASPPModule(inplanes,
|
||||
mid_channels,
|
||||
3,
|
||||
padding=dilations[2],
|
||||
dilation=dilations[2],
|
||||
BatchNorm=BatchNorm)
|
||||
self.aspp4 = _ASPPModule(inplanes,
|
||||
mid_channels,
|
||||
3,
|
||||
padding=dilations[3],
|
||||
dilation=dilations[3],
|
||||
BatchNorm=BatchNorm)
|
||||
|
||||
self.global_avg_pool = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
nn.Conv2d(inplanes, mid_channels, 1, stride=1, bias=False),
|
||||
BatchNorm(mid_channels),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.conv1 = nn.Conv2d(int(mid_channels * 5),
|
||||
mid_channels,
|
||||
1,
|
||||
bias=False)
|
||||
self.bn1 = BatchNorm(mid_channels)
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
self._init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.aspp1(x)
|
||||
x2 = self.aspp2(x)
|
||||
x3 = self.aspp3(x)
|
||||
x4 = self.aspp4(x)
|
||||
x5 = self.global_avg_pool(x)
|
||||
x5 = F.interpolate(x5,
|
||||
size=x4.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=True)
|
||||
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
return self.dropout(x)
|
||||
|
||||
def _init_weight(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
torch.nn.init.kaiming_normal_(m.weight)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.ReLU,
|
||||
drop=0.0):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.drop1 = nn.Dropout(drop)
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop2 = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop1(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
|
||||
|
||||
class SELayer(nn.Module):
|
||||
def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid):
|
||||
super().__init__()
|
||||
self.conv_reduce = nn.Conv2d(channels, channels, 1, bias=True)
|
||||
self.act1 = act_layer()
|
||||
self.conv_expand = nn.Conv2d(channels, channels, 1, bias=True)
|
||||
self.gate = gate_layer()
|
||||
|
||||
def forward(self, x, x_se):
|
||||
x_se = self.conv_reduce(x_se)
|
||||
x_se = self.act1(x_se)
|
||||
x_se = self.conv_expand(x_se)
|
||||
return x * self.gate(x_se)
|
||||
|
||||
class DepthNet(nn.Module):
|
||||
def __init__(self, in_channels, mid_channels, context_channels,
|
||||
depth_channels):
|
||||
super(DepthNet, self).__init__()
|
||||
self.reduce_conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels,
|
||||
mid_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1),
|
||||
nn.BatchNorm2d(mid_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.context_conv = nn.Conv2d(mid_channels,
|
||||
context_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.bn = nn.BatchNorm1d(27)
|
||||
self.depth_mlp = Mlp(27, mid_channels, mid_channels)
|
||||
self.depth_se = SELayer(mid_channels) # NOTE: add camera-aware
|
||||
self.context_mlp = Mlp(27, mid_channels, mid_channels)
|
||||
self.context_se = SELayer(mid_channels) # NOTE: add camera-aware
|
||||
self.depth_conv_1 = nn.Sequential(
|
||||
BasicBlock(mid_channels, mid_channels),
|
||||
BasicBlock(mid_channels, mid_channels),
|
||||
BasicBlock(mid_channels, mid_channels),
|
||||
)
|
||||
self.depth_conv_2 = nn.Sequential(
|
||||
ASPP(mid_channels, mid_channels),
|
||||
build_conv_layer(cfg=dict(
|
||||
type='Conv2d',
|
||||
in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
)),
|
||||
nn.BatchNorm2d(mid_channels),
|
||||
)
|
||||
self.depth_conv_3 = nn.Sequential(
|
||||
nn.Conv2d(mid_channels,
|
||||
depth_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0),
|
||||
nn.BatchNorm2d(depth_channels),
|
||||
)
|
||||
self.export = False
|
||||
|
||||
def export_mode(self):
|
||||
self.export = True
|
||||
|
||||
@force_fp32()
|
||||
def forward(self, x, mats_dict):
|
||||
intrins = mats_dict['intrin_mats'][:, ..., :3, :3]
|
||||
batch_size = intrins.shape[0]
|
||||
num_cams = intrins.shape[1]
|
||||
ida = mats_dict['ida_mats'][:, ...]
|
||||
sensor2ego = mats_dict['sensor2ego_mats'][:, ..., :3, :]
|
||||
bda = mats_dict['bda_mat'].view(batch_size, 1, 4, 4).repeat(1, num_cams, 1, 1)
|
||||
|
||||
# If exporting, cache the MLP input, since it's based on
|
||||
# intrinsics and data augmentation, which are constant at inference time.
|
||||
if not hasattr(self, 'mlp_input') or not self.export:
|
||||
mlp_input = torch.cat(
|
||||
[
|
||||
torch.stack(
|
||||
[
|
||||
intrins[:, ..., 0, 0],
|
||||
intrins[:, ..., 1, 1],
|
||||
intrins[:, ..., 0, 2],
|
||||
intrins[:, ..., 1, 2],
|
||||
ida[:, ..., 0, 0],
|
||||
ida[:, ..., 0, 1],
|
||||
ida[:, ..., 0, 3],
|
||||
ida[:, ..., 1, 0],
|
||||
ida[:, ..., 1, 1],
|
||||
ida[:, ..., 1, 3],
|
||||
bda[:, ..., 0, 0],
|
||||
bda[:, ..., 0, 1],
|
||||
bda[:, ..., 1, 0],
|
||||
bda[:, ..., 1, 1],
|
||||
bda[:, ..., 2, 2],
|
||||
],
|
||||
dim=-1,
|
||||
),
|
||||
sensor2ego.view(batch_size, num_cams, -1),
|
||||
],
|
||||
-1,
|
||||
)
|
||||
self.mlp_input = self.bn(mlp_input.reshape(-1, mlp_input.shape[-1]))
|
||||
|
||||
|
||||
x = self.reduce_conv(x)
|
||||
context_se = self.context_mlp(self.mlp_input)[..., None, None]
|
||||
context = self.context_se(x, context_se)
|
||||
context = self.context_conv(context)
|
||||
depth_se = self.depth_mlp(self.mlp_input)[..., None, None]
|
||||
depth = self.depth_se(x, depth_se)
|
||||
depth = self.depth_conv_1(depth)
|
||||
depth = self.depth_conv_2(depth)
|
||||
depth = self.depth_conv_3(depth)
|
||||
|
||||
return torch.cat([depth, context], dim=1)
|
||||
|
||||
|
||||
@VTRANSFORMS.register_module()
|
||||
class AwareBEVDepth(BaseTransform):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
image_size: Tuple[int, int],
|
||||
feature_size: Tuple[int, int],
|
||||
xbound: Tuple[float, float, float],
|
||||
ybound: Tuple[float, float, float],
|
||||
zbound: Tuple[float, float, float],
|
||||
dbound: Tuple[float, float, float],
|
||||
use_points = 'lidar',
|
||||
downsample: int = 1,
|
||||
bevdepth_downsample: int = 16,
|
||||
bevdepth_refine: bool = True,
|
||||
depth_loss_factor: float = 3.0,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
image_size=image_size,
|
||||
feature_size=feature_size,
|
||||
xbound=xbound,
|
||||
ybound=ybound,
|
||||
zbound=zbound,
|
||||
dbound=dbound,
|
||||
use_points=use_points,
|
||||
)
|
||||
self.depth_loss_factor = depth_loss_factor
|
||||
self.downsample_factor = bevdepth_downsample
|
||||
self.bevdepth_refine = bevdepth_refine
|
||||
if self.bevdepth_refine:
|
||||
self.refinement = DepthRefinement(self.C, self.C, self.C)
|
||||
|
||||
self.depth_channels = self.frustum.shape[0]
|
||||
|
||||
mid_channels = in_channels
|
||||
self.depthnet = DepthNet(
|
||||
in_channels,
|
||||
mid_channels,
|
||||
self.C,
|
||||
self.D
|
||||
)
|
||||
|
||||
|
||||
if downsample > 1:
|
||||
assert downsample == 2, downsample
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(
|
||||
out_channels,
|
||||
out_channels,
|
||||
3,
|
||||
stride=downsample,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
else:
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
def export_mode(self):
|
||||
super().export_mode()
|
||||
self.depthnet.export_mode()
|
||||
|
||||
@force_fp32()
|
||||
def get_cam_feats(self, x, mats_dict):
|
||||
B, N, C, fH, fW = x.shape
|
||||
|
||||
x = x.view(B * N, C, fH, fW)
|
||||
|
||||
x = self.depthnet(x, mats_dict)
|
||||
depth = x[:, : self.D].softmax(dim=1)
|
||||
|
||||
x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2)
|
||||
|
||||
if self.bevdepth_refine:
|
||||
x = x.permute(0, 3, 1, 4, 2).contiguous() # [n, c, d, h, w] -> [n, h, c, w, d]
|
||||
n, h, c, w, d = x.shape
|
||||
x = x.view(-1, c, w, d)
|
||||
x = self.refinement(x)
|
||||
x = x.view(n, h, c, w, d).permute(0, 2, 4, 1, 3).contiguous().float()
|
||||
|
||||
x = x.view(B, N, self.C, self.D, fH, fW)
|
||||
x = x.permute(0, 1, 3, 4, 5, 2)
|
||||
return x, depth
|
||||
|
||||
|
||||
def get_depth_loss(self, depth_labels, depth_preds):
|
||||
if len(depth_labels.shape) == 5:
|
||||
# only key-frame will calculate depth loss
|
||||
depth_labels = depth_labels[:, 0, ...]
|
||||
|
||||
depth_labels = self.get_downsampled_gt_depth(depth_labels)
|
||||
depth_preds = depth_preds.permute(0, 2, 3, 1).contiguous().view(
|
||||
-1, self.depth_channels)
|
||||
fg_mask = torch.max(depth_labels, dim=1).values > 0.0
|
||||
|
||||
with autocast(enabled=False):
|
||||
depth_loss = (F.binary_cross_entropy(
|
||||
depth_preds[fg_mask],
|
||||
depth_labels[fg_mask],
|
||||
reduction='none',
|
||||
).sum() / max(1.0, fg_mask.sum()))
|
||||
|
||||
return self.depth_loss_factor * depth_loss
|
||||
|
||||
def get_downsampled_gt_depth(self, gt_depths):
|
||||
"""
|
||||
Input:
|
||||
gt_depths: [B, N, H, W]
|
||||
Output:
|
||||
gt_depths: [B*N*h*w, d]
|
||||
"""
|
||||
B, N, H, W = gt_depths.shape
|
||||
gt_depths = gt_depths.view(
|
||||
B * N,
|
||||
H // self.downsample_factor,
|
||||
self.downsample_factor,
|
||||
W // self.downsample_factor,
|
||||
self.downsample_factor,
|
||||
1,
|
||||
)
|
||||
gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous()
|
||||
gt_depths = gt_depths.view(
|
||||
-1, self.downsample_factor * self.downsample_factor)
|
||||
gt_depths_tmp = torch.where(gt_depths == 0.0,
|
||||
1e5 * torch.ones_like(gt_depths),
|
||||
gt_depths)
|
||||
gt_depths = torch.min(gt_depths_tmp, dim=-1).values
|
||||
gt_depths = gt_depths.view(B * N, H // self.downsample_factor,
|
||||
W // self.downsample_factor)
|
||||
|
||||
gt_depths = (gt_depths -
|
||||
(self.dbound[0] - self.dbound[2])) / self.dbound[2]
|
||||
gt_depths = torch.where(
|
||||
(gt_depths < self.depth_channels + 1) & (gt_depths >= 0.0),
|
||||
gt_depths, torch.zeros_like(gt_depths))
|
||||
gt_depths = F.one_hot(gt_depths.long(),
|
||||
num_classes=self.depth_channels + 1).view(
|
||||
-1, self.depth_channels + 1)[:, 1:]
|
||||
|
||||
return gt_depths.float()
|
||||
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = super().forward(*args, **kwargs)
|
||||
x, depth_pred = x[0], x[-1]
|
||||
x = self.downsample(x)
|
||||
if kwargs.get('depth_loss', False):
|
||||
# print(kwargs['gt_depths'])
|
||||
depth_loss = self.get_depth_loss(kwargs['gt_depths'], depth_pred)
|
||||
return x, depth_loss
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
@VTRANSFORMS.register_module()
|
||||
class AwareDBEVDepth(BaseDepthTransform):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
image_size: Tuple[int, int],
|
||||
feature_size: Tuple[int, int],
|
||||
xbound: Tuple[float, float, float],
|
||||
ybound: Tuple[float, float, float],
|
||||
zbound: Tuple[float, float, float],
|
||||
dbound: Tuple[float, float, float],
|
||||
use_points = 'lidar',
|
||||
depth_input = 'scalar',
|
||||
height_expand = False,
|
||||
downsample: int = 1,
|
||||
bevdepth_downsample: int = 16,
|
||||
bevdepth_refine: bool = True,
|
||||
depth_loss_factor: float = 3.0,
|
||||
add_depth_features = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
image_size=image_size,
|
||||
feature_size=feature_size,
|
||||
xbound=xbound,
|
||||
ybound=ybound,
|
||||
zbound=zbound,
|
||||
dbound=dbound,
|
||||
use_points=use_points,
|
||||
depth_input=depth_input,
|
||||
height_expand=height_expand,
|
||||
add_depth_features=add_depth_features,
|
||||
)
|
||||
self.depth_loss_factor = depth_loss_factor
|
||||
self.downsample_factor = bevdepth_downsample
|
||||
self.bevdepth_refine = bevdepth_refine
|
||||
if self.bevdepth_refine:
|
||||
self.refinement = DepthRefinement(self.C, self.C, self.C)
|
||||
|
||||
self.depth_channels = self.frustum.shape[0]
|
||||
|
||||
mid_channels = in_channels
|
||||
self.depthnet = DepthNet(
|
||||
in_channels+64,
|
||||
mid_channels,
|
||||
self.C,
|
||||
self.D
|
||||
)
|
||||
|
||||
dtransform_in_channels = 1 if depth_input=='scalar' else self.D
|
||||
if self.add_depth_features:
|
||||
dtransform_in_channels += 45
|
||||
|
||||
if depth_input == 'scalar':
|
||||
self.dtransform = nn.Sequential(
|
||||
nn.Conv2d(dtransform_in_channels, 8, 1),
|
||||
nn.BatchNorm2d(8),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(8, 32, 5, stride=4, padding=2),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(32, 64, 5, stride=2, padding=2),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(64, 64, 5, stride=2, padding=2),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
else:
|
||||
self.dtransform = nn.Sequential(
|
||||
nn.Conv2d(dtransform_in_channels, 32, 1),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(32, 32, 5, stride=4, padding=2),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(32, 64, 5, stride=2, padding=2),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(64, 64, 5, stride=2, padding=2),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
|
||||
|
||||
if downsample > 1:
|
||||
assert downsample == 2, downsample
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(
|
||||
out_channels,
|
||||
out_channels,
|
||||
3,
|
||||
stride=downsample,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
else:
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
@force_fp32()
|
||||
def get_cam_feats(self, x, d, mats_dict):
|
||||
B, N, C, fH, fW = x.shape
|
||||
|
||||
d = d.view(B * N, *d.shape[2:])
|
||||
x = x.view(B * N, C, fH, fW)
|
||||
|
||||
d = self.dtransform(d)
|
||||
|
||||
x = torch.cat([d, x], dim=1)
|
||||
x = self.depthnet(x, mats_dict)
|
||||
depth = x[:, : self.D].softmax(dim=1)
|
||||
|
||||
x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2)
|
||||
|
||||
if self.bevdepth_refine:
|
||||
x = x.permute(0, 3, 1, 4, 2).contiguous() # [n, c, d, h, w] -> [n, h, c, w, d]
|
||||
n, h, c, w, d = x.shape
|
||||
x = x.view(-1, c, w, d)
|
||||
x = self.refinement(x)
|
||||
x = x.view(n, h, c, w, d).permute(0, 2, 4, 1, 3).contiguous().float()
|
||||
|
||||
# Here, x.shape is [num_cams, num_channels, depth_bins, downsampled_height, downsampled_width]
|
||||
x = x.view(B, N, self.C, self.D, fH, fW)
|
||||
x = x.permute(0, 1, 3, 4, 5, 2)
|
||||
return x, depth
|
||||
|
||||
def export_mode(self):
|
||||
super().export_mode()
|
||||
self.depthnet.export_mode()
|
||||
|
||||
def get_depth_loss(self, depth_labels, depth_preds):
|
||||
# if len(depth_labels.shape) == 5:
|
||||
# # only key-frame will calculate depth loss
|
||||
# depth_labels = depth_labels[:, 0, ...]
|
||||
|
||||
depth_labels = self.get_downsampled_gt_depth(depth_labels)
|
||||
depth_preds = depth_preds.permute(0, 2, 3, 1).contiguous().view(
|
||||
-1, self.depth_channels)
|
||||
fg_mask = torch.max(depth_labels, dim=1).values > 0.0
|
||||
|
||||
with autocast(enabled=False):
|
||||
depth_loss = (F.binary_cross_entropy(
|
||||
depth_preds[fg_mask],
|
||||
depth_labels[fg_mask],
|
||||
reduction='none',
|
||||
).sum() / max(1.0, fg_mask.sum()))
|
||||
|
||||
return self.depth_loss_factor * depth_loss
|
||||
|
||||
def get_downsampled_gt_depth(self, gt_depths):
|
||||
"""
|
||||
Input:
|
||||
gt_depths: [B, N, H, W]
|
||||
Output:
|
||||
gt_depths: [B*N*h*w, d]
|
||||
"""
|
||||
B, N, H, W = gt_depths.shape
|
||||
gt_depths = gt_depths.view(
|
||||
B * N,
|
||||
H // self.downsample_factor,
|
||||
self.downsample_factor,
|
||||
W // self.downsample_factor,
|
||||
self.downsample_factor,
|
||||
1,
|
||||
)
|
||||
gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous()
|
||||
gt_depths = gt_depths.view(
|
||||
-1, self.downsample_factor * self.downsample_factor)
|
||||
gt_depths_tmp = torch.where(gt_depths == 0.0,
|
||||
1e5 * torch.ones_like(gt_depths),
|
||||
gt_depths)
|
||||
gt_depths = torch.min(gt_depths_tmp, dim=-1).values
|
||||
gt_depths = gt_depths.view(B * N, H // self.downsample_factor,
|
||||
W // self.downsample_factor)
|
||||
|
||||
gt_depths = (gt_depths -
|
||||
(self.dbound[0] - self.dbound[2])) / self.dbound[2]
|
||||
gt_depths = torch.where(
|
||||
(gt_depths < self.depth_channels + 1) & (gt_depths >= 0.0),
|
||||
gt_depths, torch.zeros_like(gt_depths))
|
||||
gt_depths = F.one_hot(gt_depths.long(),
|
||||
num_classes=self.depth_channels + 1).view(
|
||||
-1, self.depth_channels + 1)[:, 1:]
|
||||
|
||||
return gt_depths.float()
|
||||
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = super().forward(*args, **kwargs)
|
||||
x, depth_pred = x[0], x[-1]
|
||||
x = self.downsample(x)
|
||||
if kwargs.get('depth_loss', False):
|
||||
depth_loss = self.get_depth_loss(kwargs['gt_depths'], depth_pred)
|
||||
return x, depth_loss
|
||||
else:
|
||||
return x
|
||||
|
|
@ -8,6 +8,9 @@ from mmdet3d.ops import bev_pool
|
|||
|
||||
__all__ = ["BaseTransform", "BaseDepthTransform"]
|
||||
|
||||
def boolmask2idx(mask):
|
||||
# A utility function, workaround for ONNX not supporting 'nonzero'
|
||||
return torch.nonzero(mask).squeeze(1).tolist()
|
||||
|
||||
def gen_dx_bx(xbound, ybound, zbound):
|
||||
dx = torch.Tensor([row[2] for row in [xbound, ybound, zbound]])
|
||||
|
|
@ -29,6 +32,10 @@ class BaseTransform(nn.Module):
|
|||
ybound: Tuple[float, float, float],
|
||||
zbound: Tuple[float, float, float],
|
||||
dbound: Tuple[float, float, float],
|
||||
use_points='lidar',
|
||||
depth_input='scalar',
|
||||
height_expand=True,
|
||||
add_depth_features=True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
|
@ -38,6 +45,12 @@ class BaseTransform(nn.Module):
|
|||
self.ybound = ybound
|
||||
self.zbound = zbound
|
||||
self.dbound = dbound
|
||||
self.use_points = use_points
|
||||
assert use_points in ['radar', 'lidar']
|
||||
self.depth_input=depth_input
|
||||
assert depth_input in ['scalar', 'one-hot']
|
||||
self.height_expand = height_expand
|
||||
self.add_depth_features = add_depth_features
|
||||
|
||||
dx, bx, nx = gen_dx_bx(self.xbound, self.ybound, self.zbound)
|
||||
self.dx = nn.Parameter(dx, requires_grad=False)
|
||||
|
|
@ -167,6 +180,7 @@ class BaseTransform(nn.Module):
|
|||
self,
|
||||
img,
|
||||
points,
|
||||
radar,
|
||||
camera2ego,
|
||||
lidar2ego,
|
||||
lidar2camera,
|
||||
|
|
@ -199,10 +213,26 @@ class BaseTransform(nn.Module):
|
|||
extra_rots=extra_rots,
|
||||
extra_trans=extra_trans,
|
||||
)
|
||||
mats_dict = {
|
||||
'intrin_mats': camera_intrinsics,
|
||||
'ida_mats': img_aug_matrix,
|
||||
'bda_mat': lidar_aug_matrix,
|
||||
'sensor2ego_mats': camera2ego,
|
||||
}
|
||||
x = self.get_cam_feats(img, mats_dict)
|
||||
|
||||
x = self.get_cam_feats(img)
|
||||
use_depth = False
|
||||
if type(x) == tuple:
|
||||
x, depth = x
|
||||
use_depth = True
|
||||
|
||||
x = self.bev_pool(geom, x)
|
||||
return x
|
||||
|
||||
if use_depth:
|
||||
return x, depth
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class BaseDepthTransform(BaseTransform):
|
||||
|
|
@ -211,6 +241,7 @@ class BaseDepthTransform(BaseTransform):
|
|||
self,
|
||||
img,
|
||||
points,
|
||||
radar,
|
||||
sensor2ego,
|
||||
lidar2ego,
|
||||
lidar2camera,
|
||||
|
|
@ -232,12 +263,22 @@ class BaseDepthTransform(BaseTransform):
|
|||
camera2lidar_rots = camera2lidar[..., :3, :3]
|
||||
camera2lidar_trans = camera2lidar[..., :3, 3]
|
||||
|
||||
# print(img.shape, self.image_size, self.feature_size)
|
||||
if self.use_points == 'radar':
|
||||
points = radar
|
||||
|
||||
if self.height_expand:
|
||||
for b in range(len(points)):
|
||||
points_repeated = points[b].repeat_interleave(8, dim=0)
|
||||
points_repeated[:, 2] = torch.arange(0.25, 2.25, 0.25).repeat(points[b].shape[0])
|
||||
points[b] = points_repeated
|
||||
|
||||
batch_size = len(points)
|
||||
depth = torch.zeros(batch_size, img.shape[1], 1, *self.image_size).to(
|
||||
points[0].device
|
||||
)
|
||||
depth_in_channels = 1 if self.depth_input=='scalar' else self.D
|
||||
if self.add_depth_features:
|
||||
depth_in_channels += points[0].shape[1]
|
||||
|
||||
depth = torch.zeros(batch_size, img.shape[1], depth_in_channels, *self.image_size, device=points[0].device)
|
||||
|
||||
|
||||
for b in range(batch_size):
|
||||
cur_coords = points[b][:, :3]
|
||||
|
|
@ -275,7 +316,17 @@ class BaseDepthTransform(BaseTransform):
|
|||
for c in range(on_img.shape[0]):
|
||||
masked_coords = cur_coords[c, on_img[c]].long()
|
||||
masked_dist = dist[c, on_img[c]]
|
||||
depth[b, c, 0, masked_coords[:, 0], masked_coords[:, 1]] = masked_dist
|
||||
|
||||
if self.depth_input == 'scalar':
|
||||
depth[b, c, 0, masked_coords[:, 0], masked_coords[:, 1]] = masked_dist
|
||||
elif self.depth_input == 'one-hot':
|
||||
# Clamp depths that are too big to D
|
||||
# These can arise when the point range filter is different from the dbound.
|
||||
masked_dist = torch.clamp(masked_dist, max=self.D-1)
|
||||
depth[b, c, masked_dist.long(), masked_coords[:, 0], masked_coords[:, 1]] = 1.0
|
||||
|
||||
if self.add_depth_features:
|
||||
depth[b, c, -points[b].shape[-1]:, masked_coords[:, 0], masked_coords[:, 1]] = points[b][boolmask2idx(on_img[c])].transpose(0,1)
|
||||
|
||||
extra_rots = lidar_aug_matrix[..., :3, :3]
|
||||
extra_trans = lidar_aug_matrix[..., :3, 3]
|
||||
|
|
@ -289,6 +340,23 @@ class BaseDepthTransform(BaseTransform):
|
|||
extra_trans=extra_trans,
|
||||
)
|
||||
|
||||
x = self.get_cam_feats(img, depth)
|
||||
mats_dict = {
|
||||
'intrin_mats': intrins,
|
||||
'ida_mats': img_aug_matrix,
|
||||
'bda_mat': lidar_aug_matrix,
|
||||
'sensor2ego_mats': sensor2ego,
|
||||
}
|
||||
x = self.get_cam_feats(img, depth, mats_dict)
|
||||
|
||||
use_depth = False
|
||||
if type(x) == tuple:
|
||||
x, depth = x
|
||||
use_depth = True
|
||||
|
||||
x = self.bev_pool(geom, x)
|
||||
return x
|
||||
|
||||
if use_depth:
|
||||
return x, depth
|
||||
else:
|
||||
return x
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from mmcv.ops import (
|
|||
)
|
||||
|
||||
from .ball_query import ball_query
|
||||
from .feature_decorator import feature_decorator
|
||||
from .furthest_point_sample import (
|
||||
Points_Sampler,
|
||||
furthest_point_sample,
|
||||
|
|
@ -89,4 +90,5 @@ __all__ = [
|
|||
"PAConvCUDASAModule",
|
||||
"PAConvCUDASAModuleMSG",
|
||||
"bev_pool",
|
||||
"feature_decorator",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
from .feature_decorator import feature_decorator
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
|
||||
import torch
|
||||
|
||||
from mmdet3d.ops.feature_decorator import feature_decorator_ext
|
||||
|
||||
|
||||
__all__ = ["feature_decorator"]
|
||||
|
||||
def feature_decorator(features, num_voxels, coords, vx, vy, x_offset, y_offset, normalize_coords, use_cluster, use_center):
|
||||
result = torch.ops.feature_decorator_ext.feature_decorator_forward(features, coords, num_voxels, vx, vy, x_offset, y_offset, normalize_coords, use_cluster, use_center)
|
||||
return result
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
A = torch.ones((2, 20, 5), dtype=torch.float32).cuda()
|
||||
B = torch.ones(2, dtype=torch.int32).cuda()
|
||||
C = torch.ones((2, 4), dtype=torch.int32).cuda()
|
||||
D = feature_decorator(A, B, C)
|
||||
D = feature_decorator_ext.feature_decorator_forward(A, B, C)
|
||||
print(D.shape)
|
||||
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
#include <torch/torch.h>
|
||||
|
||||
// CUDA function declarations
|
||||
// void feature_decorator(int b, int d, int h, int w, int n, int c, int n_intervals, const float* x,
|
||||
// const int* geom_feats, const int* interval_starts, const int* interval_lengths, float* out);
|
||||
|
||||
void feature_decorator(float* out);
|
||||
|
||||
at::Tensor feature_decorator_forward(
|
||||
const at::Tensor _x,
|
||||
const at::Tensor _y,
|
||||
const at::Tensor _z,
|
||||
const double vx, const double vy, const double x_offset, const double y_offset,
|
||||
int normalize_coords, int use_cluster, int use_center
|
||||
) {
|
||||
int n = _x.size(0);
|
||||
int c = _x.size(1);
|
||||
int a = _x.size(2);
|
||||
auto options = torch::TensorOptions().dtype(_x.dtype()).device(_x.device());
|
||||
int decorate_dims = 0;
|
||||
if (use_cluster > 0) {
|
||||
decorate_dims += 3;
|
||||
}
|
||||
if (use_center > 0) {
|
||||
decorate_dims += 2;
|
||||
}
|
||||
|
||||
at::Tensor _out = torch::zeros({n, c, a+decorate_dims}, options);
|
||||
float* out = _out.data_ptr<float>();
|
||||
const float* x = _x.data_ptr<float>();
|
||||
const int* y = _y.data_ptr<int>();
|
||||
const int* z = _z.data_ptr<int>();
|
||||
feature_decorator(out);
|
||||
return _out;
|
||||
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("feature_decorator_forward", &feature_decorator_forward,
|
||||
"feature_decorator_forward");
|
||||
}
|
||||
|
||||
static auto registry =
|
||||
torch::RegisterOperators("feature_decorator_ext::feature_decorator_forward", &feature_decorator_forward);
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// __global__ void feature_decorator_kernel(int b, int d, int h, int w, int n, int c, int n_intervals,
|
||||
// const float *__restrict__ x,
|
||||
// const int *__restrict__ geom_feats,
|
||||
// const int *__restrict__ interval_starts,
|
||||
// const int *__restrict__ interval_lengths,
|
||||
// float* __restrict__ out) {
|
||||
// int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// int index = idx / c;
|
||||
// int cur_c = idx % c;
|
||||
// if (index >= n_intervals) return;
|
||||
// int interval_start = interval_starts[index];
|
||||
// int interval_length = interval_lengths[index];
|
||||
// const int* cur_geom_feats = geom_feats + interval_start * 4;
|
||||
// const float* cur_x = x + interval_start * c + cur_c;
|
||||
// float* cur_out = out + cur_geom_feats[3] * d * h * w * c +
|
||||
// cur_geom_feats[2] * h * w * c + cur_geom_feats[0] * w * c +
|
||||
// cur_geom_feats[1] * c + cur_c;
|
||||
// float psum = 0;
|
||||
// for(int i = 0; i < interval_length; i++){
|
||||
// psum += cur_x[i * c];
|
||||
// }
|
||||
// *cur_out = psum;
|
||||
// }
|
||||
|
||||
__global__ void feature_decorator_kernel(float* __restrict__ out) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// if (idx == 0) {
|
||||
// out = 5.0;
|
||||
// }
|
||||
out[0] = 5.0;
|
||||
out[15] = 6.0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// void feature_decorator(int b, int d, int h, int w, int n, int c, int n_intervals, const float* x,
|
||||
// const int* geom_feats, const int* interval_starts, const int* interval_lengths, float* out) {
|
||||
// feature_decorator_kernel<<<(int)ceil(((double)n_intervals * c / 256)), 256>>>(
|
||||
// b, d, h, w, n, c, n_intervals, x, geom_feats, interval_starts, interval_lengths, out
|
||||
// );
|
||||
// }
|
||||
|
||||
void feature_decorator(float* out) {
|
||||
feature_decorator_kernel<<<1, 1>>>(out);
|
||||
}
|
||||
|
||||
|
|
@ -1,51 +1,34 @@
|
|||
import os
|
||||
from collections import OrderedDict
|
||||
from os import path as osp
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from nuscenes.nuscenes import NuScenes
|
||||
from nuscenes.utils.geometry_utils import view_points
|
||||
from os import path as osp
|
||||
from pyquaternion import Quaternion
|
||||
from shapely.geometry import MultiPoint, box
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from mmdet3d.core.bbox.box_np_ops import points_cam2img
|
||||
from mmdet3d.datasets import NuScenesDataset
|
||||
|
||||
nus_categories = (
|
||||
"car",
|
||||
"truck",
|
||||
"trailer",
|
||||
"bus",
|
||||
"construction_vehicle",
|
||||
"bicycle",
|
||||
"motorcycle",
|
||||
"pedestrian",
|
||||
"traffic_cone",
|
||||
"barrier",
|
||||
)
|
||||
nus_categories = ('car', 'truck', 'trailer', 'bus', 'construction_vehicle',
|
||||
'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone',
|
||||
'barrier')
|
||||
|
||||
nus_attributes = (
|
||||
"cycle.with_rider",
|
||||
"cycle.without_rider",
|
||||
"pedestrian.moving",
|
||||
"pedestrian.standing",
|
||||
"pedestrian.sitting_lying_down",
|
||||
"vehicle.moving",
|
||||
"vehicle.parked",
|
||||
"vehicle.stopped",
|
||||
"None",
|
||||
)
|
||||
nus_attributes = ('cycle.with_rider', 'cycle.without_rider',
|
||||
'pedestrian.moving', 'pedestrian.standing',
|
||||
'pedestrian.sitting_lying_down', 'vehicle.moving',
|
||||
'vehicle.parked', 'vehicle.stopped', 'None')
|
||||
|
||||
|
||||
def create_nuscenes_infos(
|
||||
root_path, info_prefix, version="v1.0-trainval", max_sweeps=10
|
||||
):
|
||||
def create_nuscenes_infos(root_path,
|
||||
info_prefix,
|
||||
version='v1.0-trainval',
|
||||
max_sweeps=10,
|
||||
max_radar_sweeps=10):
|
||||
"""Create info file of nuscene dataset.
|
||||
|
||||
Given the raw data, generate its related info file in pkl format.
|
||||
|
||||
Args:
|
||||
root_path (str): Path of the data root.
|
||||
info_prefix (str): Prefix of the info file to be generated.
|
||||
|
|
@ -53,100 +36,96 @@ def create_nuscenes_infos(
|
|||
Default: 'v1.0-trainval'
|
||||
max_sweeps (int): Max number of sweeps.
|
||||
Default: 10
|
||||
max_radar_sweeps (int): Max number of radar sweeps.
|
||||
Default: 10
|
||||
"""
|
||||
from nuscenes.nuscenes import NuScenes
|
||||
|
||||
nusc = NuScenes(version=version, dataroot=root_path, verbose=True)
|
||||
from nuscenes.utils import splits
|
||||
|
||||
available_vers = ["v1.0-trainval", "v1.0-test", "v1.0-mini"]
|
||||
available_vers = ['v1.0-trainval', 'v1.0-test', 'v1.0-mini']
|
||||
assert version in available_vers
|
||||
if version == "v1.0-trainval":
|
||||
if version == 'v1.0-trainval':
|
||||
train_scenes = splits.train
|
||||
val_scenes = splits.val
|
||||
elif version == "v1.0-test":
|
||||
elif version == 'v1.0-test':
|
||||
train_scenes = splits.test
|
||||
val_scenes = []
|
||||
elif version == "v1.0-mini":
|
||||
elif version == 'v1.0-mini':
|
||||
train_scenes = splits.mini_train
|
||||
val_scenes = splits.mini_val
|
||||
else:
|
||||
raise ValueError("unknown")
|
||||
raise ValueError('unknown')
|
||||
|
||||
# filter existing scenes.
|
||||
available_scenes = get_available_scenes(nusc)
|
||||
available_scene_names = [s["name"] for s in available_scenes]
|
||||
train_scenes = list(filter(lambda x: x in available_scene_names, train_scenes))
|
||||
available_scene_names = [s['name'] for s in available_scenes]
|
||||
train_scenes = list(
|
||||
filter(lambda x: x in available_scene_names, train_scenes))
|
||||
val_scenes = list(filter(lambda x: x in available_scene_names, val_scenes))
|
||||
train_scenes = set(
|
||||
[
|
||||
available_scenes[available_scene_names.index(s)]["token"]
|
||||
for s in train_scenes
|
||||
]
|
||||
)
|
||||
val_scenes = set(
|
||||
[available_scenes[available_scene_names.index(s)]["token"] for s in val_scenes]
|
||||
)
|
||||
train_scenes = set([
|
||||
available_scenes[available_scene_names.index(s)]['token']
|
||||
for s in train_scenes
|
||||
])
|
||||
val_scenes = set([
|
||||
available_scenes[available_scene_names.index(s)]['token']
|
||||
for s in val_scenes
|
||||
])
|
||||
|
||||
test = "test" in version
|
||||
test = 'test' in version
|
||||
if test:
|
||||
print("test scene: {}".format(len(train_scenes)))
|
||||
print('test scene: {}'.format(len(train_scenes)))
|
||||
else:
|
||||
print(
|
||||
"train scene: {}, val scene: {}".format(len(train_scenes), len(val_scenes))
|
||||
)
|
||||
print('train scene: {}, val scene: {}'.format(
|
||||
len(train_scenes), len(val_scenes)))
|
||||
train_nusc_infos, val_nusc_infos = _fill_trainval_infos(
|
||||
nusc, train_scenes, val_scenes, test, max_sweeps=max_sweeps
|
||||
)
|
||||
nusc, train_scenes, val_scenes, test, max_sweeps=max_sweeps, max_radar_sweeps=max_radar_sweeps)
|
||||
|
||||
metadata = dict(version=version)
|
||||
if test:
|
||||
print("test sample: {}".format(len(train_nusc_infos)))
|
||||
print('test sample: {}'.format(len(train_nusc_infos)))
|
||||
data = dict(infos=train_nusc_infos, metadata=metadata)
|
||||
info_path = osp.join(root_path, "{}_infos_test.pkl".format(info_prefix))
|
||||
info_path = osp.join(root_path,
|
||||
'{}_infos_test_radar.pkl'.format(info_prefix))
|
||||
mmcv.dump(data, info_path)
|
||||
else:
|
||||
print(
|
||||
"train sample: {}, val sample: {}".format(
|
||||
len(train_nusc_infos), len(val_nusc_infos)
|
||||
)
|
||||
)
|
||||
print(info_prefix)
|
||||
print('train sample: {}, val sample: {}'.format(
|
||||
len(train_nusc_infos), len(val_nusc_infos)))
|
||||
data = dict(infos=train_nusc_infos, metadata=metadata)
|
||||
info_path = osp.join(root_path, "{}_infos_train.pkl".format(info_prefix))
|
||||
info_path = osp.join(info_prefix,
|
||||
'{}_infos_train_radar.pkl'.format(info_prefix))
|
||||
mmcv.dump(data, info_path)
|
||||
data["infos"] = val_nusc_infos
|
||||
info_val_path = osp.join(root_path, "{}_infos_val.pkl".format(info_prefix))
|
||||
data['infos'] = val_nusc_infos
|
||||
info_val_path = osp.join(info_prefix,
|
||||
'{}_infos_val_radar.pkl'.format(info_prefix))
|
||||
mmcv.dump(data, info_val_path)
|
||||
|
||||
|
||||
def get_available_scenes(nusc):
|
||||
"""Get available scenes from the input nuscenes class.
|
||||
|
||||
Given the raw data, get the information of available scenes for
|
||||
further info generation.
|
||||
|
||||
Args:
|
||||
nusc (class): Dataset class in the nuScenes dataset.
|
||||
|
||||
Returns:
|
||||
available_scenes (list[dict]): List of basic information for the
|
||||
available scenes.
|
||||
"""
|
||||
available_scenes = []
|
||||
print("total scene num: {}".format(len(nusc.scene)))
|
||||
print('total scene num: {}'.format(len(nusc.scene)))
|
||||
for scene in nusc.scene:
|
||||
scene_token = scene["token"]
|
||||
scene_rec = nusc.get("scene", scene_token)
|
||||
sample_rec = nusc.get("sample", scene_rec["first_sample_token"])
|
||||
sd_rec = nusc.get("sample_data", sample_rec["data"]["LIDAR_TOP"])
|
||||
scene_token = scene['token']
|
||||
scene_rec = nusc.get('scene', scene_token)
|
||||
sample_rec = nusc.get('sample', scene_rec['first_sample_token'])
|
||||
sd_rec = nusc.get('sample_data', sample_rec['data']['LIDAR_TOP'])
|
||||
has_more_frames = True
|
||||
scene_not_exist = False
|
||||
while has_more_frames:
|
||||
lidar_path, boxes, _ = nusc.get_sample_data(sd_rec["token"])
|
||||
lidar_path, boxes, _ = nusc.get_sample_data(sd_rec['token'])
|
||||
lidar_path = str(lidar_path)
|
||||
if os.getcwd() in lidar_path:
|
||||
# path from lyftdataset is absolute path
|
||||
lidar_path = lidar_path.split(f"{os.getcwd()}/")[-1]
|
||||
lidar_path = lidar_path.split(f'{os.getcwd()}/')[-1]
|
||||
# relative path
|
||||
if not mmcv.is_filepath(lidar_path):
|
||||
scene_not_exist = True
|
||||
|
|
@ -156,13 +135,17 @@ def get_available_scenes(nusc):
|
|||
if scene_not_exist:
|
||||
continue
|
||||
available_scenes.append(scene)
|
||||
print("exist scene num: {}".format(len(available_scenes)))
|
||||
print('exist scene num: {}'.format(len(available_scenes)))
|
||||
return available_scenes
|
||||
|
||||
|
||||
def _fill_trainval_infos(nusc, train_scenes, val_scenes, test=False, max_sweeps=10):
|
||||
def _fill_trainval_infos(nusc,
|
||||
train_scenes,
|
||||
val_scenes,
|
||||
test=False,
|
||||
max_sweeps=10,
|
||||
max_radar_sweeps=10):
|
||||
"""Generate the train/val infos from the raw data.
|
||||
|
||||
Args:
|
||||
nusc (:obj:`NuScenes`): Dataset class in the nuScenes dataset.
|
||||
train_scenes (list[str]): Basic information of training scenes.
|
||||
|
|
@ -170,101 +153,126 @@ def _fill_trainval_infos(nusc, train_scenes, val_scenes, test=False, max_sweeps=
|
|||
test (bool): Whether use the test mode. In the test mode, no
|
||||
annotations can be accessed. Default: False.
|
||||
max_sweeps (int): Max number of sweeps. Default: 10.
|
||||
|
||||
max_radar_sweeps (int): Max number of radar sweeps. Default: 10.
|
||||
Returns:
|
||||
tuple[list[dict]]: Information of training set and validation set
|
||||
that will be saved to the info file.
|
||||
"""
|
||||
train_nusc_infos = []
|
||||
val_nusc_infos = []
|
||||
token2idx = {}
|
||||
|
||||
i_ = 0
|
||||
|
||||
for sample in mmcv.track_iter_progress(nusc.sample):
|
||||
lidar_token = sample["data"]["LIDAR_TOP"]
|
||||
sd_rec = nusc.get("sample_data", sample["data"]["LIDAR_TOP"])
|
||||
cs_record = nusc.get("calibrated_sensor", sd_rec["calibrated_sensor_token"])
|
||||
pose_record = nusc.get("ego_pose", sd_rec["ego_pose_token"])
|
||||
location = nusc.get(
|
||||
"log", nusc.get("scene", sample["scene_token"])["log_token"]
|
||||
)["location"]
|
||||
# i_ += 1
|
||||
# if i_ > 6:
|
||||
# break
|
||||
|
||||
lidar_token = sample['data']['LIDAR_TOP']
|
||||
sd_rec = nusc.get('sample_data', sample['data']['LIDAR_TOP'])
|
||||
cs_record = nusc.get('calibrated_sensor',
|
||||
sd_rec['calibrated_sensor_token'])
|
||||
pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token'])
|
||||
lidar_path, boxes, _ = nusc.get_sample_data(lidar_token)
|
||||
|
||||
mmcv.check_file_exist(lidar_path)
|
||||
|
||||
info = {
|
||||
"lidar_path": lidar_path,
|
||||
"token": sample["token"],
|
||||
"sweeps": [],
|
||||
"cams": dict(),
|
||||
"lidar2ego_translation": cs_record["translation"],
|
||||
"lidar2ego_rotation": cs_record["rotation"],
|
||||
"ego2global_translation": pose_record["translation"],
|
||||
"ego2global_rotation": pose_record["rotation"],
|
||||
"timestamp": sample["timestamp"],
|
||||
"location": location,
|
||||
'lidar_path': lidar_path,
|
||||
'token': sample['token'],
|
||||
'sweeps': [],
|
||||
'cams': dict(),
|
||||
'radars': dict(),
|
||||
'lidar2ego_translation': cs_record['translation'],
|
||||
'lidar2ego_rotation': cs_record['rotation'],
|
||||
'ego2global_translation': pose_record['translation'],
|
||||
'ego2global_rotation': pose_record['rotation'],
|
||||
'timestamp': sample['timestamp'],
|
||||
'prev_token': sample['prev']
|
||||
}
|
||||
|
||||
l2e_r = info["lidar2ego_rotation"]
|
||||
l2e_t = info["lidar2ego_translation"]
|
||||
e2g_r = info["ego2global_rotation"]
|
||||
e2g_t = info["ego2global_translation"]
|
||||
l2e_r = info['lidar2ego_rotation']
|
||||
l2e_t = info['lidar2ego_translation']
|
||||
e2g_r = info['ego2global_rotation']
|
||||
e2g_t = info['ego2global_translation']
|
||||
l2e_r_mat = Quaternion(l2e_r).rotation_matrix
|
||||
e2g_r_mat = Quaternion(e2g_r).rotation_matrix
|
||||
|
||||
# obtain 6 image's information per frame
|
||||
camera_types = [
|
||||
"CAM_FRONT",
|
||||
"CAM_FRONT_RIGHT",
|
||||
"CAM_FRONT_LEFT",
|
||||
"CAM_BACK",
|
||||
"CAM_BACK_LEFT",
|
||||
"CAM_BACK_RIGHT",
|
||||
'CAM_FRONT',
|
||||
'CAM_FRONT_RIGHT',
|
||||
'CAM_FRONT_LEFT',
|
||||
'CAM_BACK',
|
||||
'CAM_BACK_LEFT',
|
||||
'CAM_BACK_RIGHT',
|
||||
]
|
||||
for cam in camera_types:
|
||||
cam_token = sample["data"][cam]
|
||||
cam_path, _, camera_intrinsics = nusc.get_sample_data(cam_token)
|
||||
cam_info = obtain_sensor2top(
|
||||
nusc, cam_token, l2e_t, l2e_r_mat, e2g_t, e2g_r_mat, cam
|
||||
)
|
||||
cam_info.update(camera_intrinsics=camera_intrinsics)
|
||||
info["cams"].update({cam: cam_info})
|
||||
cam_token = sample['data'][cam]
|
||||
cam_path, _, cam_intrinsic = nusc.get_sample_data(cam_token)
|
||||
cam_info = obtain_sensor2top(nusc, cam_token, l2e_t, l2e_r_mat,
|
||||
e2g_t, e2g_r_mat, cam)
|
||||
cam_info.update(cam_intrinsic=cam_intrinsic)
|
||||
info['cams'].update({cam: cam_info})
|
||||
|
||||
radar_names = ['RADAR_FRONT', 'RADAR_FRONT_LEFT', 'RADAR_FRONT_RIGHT', 'RADAR_BACK_LEFT', 'RADAR_BACK_RIGHT']
|
||||
|
||||
for radar_name in radar_names:
|
||||
radar_token = sample['data'][radar_name]
|
||||
radar_rec = nusc.get('sample_data', radar_token)
|
||||
sweeps = []
|
||||
|
||||
while len(sweeps) < max_radar_sweeps:
|
||||
if not radar_rec['prev'] == '':
|
||||
radar_path, _, radar_intrin = nusc.get_sample_data(radar_token)
|
||||
|
||||
radar_info = obtain_sensor2top(nusc, radar_token, l2e_t, l2e_r_mat,
|
||||
e2g_t, e2g_r_mat, radar_name)
|
||||
sweeps.append(radar_info)
|
||||
radar_token = radar_rec['prev']
|
||||
radar_rec = nusc.get('sample_data', radar_token)
|
||||
else:
|
||||
radar_path, _, radar_intrin = nusc.get_sample_data(radar_token)
|
||||
|
||||
radar_info = obtain_sensor2top(nusc, radar_token, l2e_t, l2e_r_mat,
|
||||
e2g_t, e2g_r_mat, radar_name)
|
||||
sweeps.append(radar_info)
|
||||
|
||||
info['radars'].update({radar_name: sweeps})
|
||||
# obtain sweeps for a single key-frame
|
||||
sd_rec = nusc.get("sample_data", sample["data"]["LIDAR_TOP"])
|
||||
sd_rec = nusc.get('sample_data', sample['data']['LIDAR_TOP'])
|
||||
sweeps = []
|
||||
while len(sweeps) < max_sweeps:
|
||||
if not sd_rec["prev"] == "":
|
||||
sweep = obtain_sensor2top(
|
||||
nusc, sd_rec["prev"], l2e_t, l2e_r_mat, e2g_t, e2g_r_mat, "lidar"
|
||||
)
|
||||
if not sd_rec['prev'] == '':
|
||||
sweep = obtain_sensor2top(nusc, sd_rec['prev'], l2e_t,
|
||||
l2e_r_mat, e2g_t, e2g_r_mat, 'lidar')
|
||||
sweeps.append(sweep)
|
||||
sd_rec = nusc.get("sample_data", sd_rec["prev"])
|
||||
sd_rec = nusc.get('sample_data', sd_rec['prev'])
|
||||
else:
|
||||
break
|
||||
info["sweeps"] = sweeps
|
||||
info['sweeps'] = sweeps
|
||||
# obtain annotation
|
||||
if not test:
|
||||
annotations = [
|
||||
nusc.get("sample_annotation", token) for token in sample["anns"]
|
||||
nusc.get('sample_annotation', token)
|
||||
for token in sample['anns']
|
||||
]
|
||||
locs = np.array([b.center for b in boxes]).reshape(-1, 3)
|
||||
dims = np.array([b.wlh for b in boxes]).reshape(-1, 3)
|
||||
rots = np.array([b.orientation.yaw_pitch_roll[0] for b in boxes]).reshape(
|
||||
-1, 1
|
||||
)
|
||||
rots = np.array([b.orientation.yaw_pitch_roll[0]
|
||||
for b in boxes]).reshape(-1, 1)
|
||||
velocity = np.array(
|
||||
[nusc.box_velocity(token)[:2] for token in sample["anns"]]
|
||||
)
|
||||
[nusc.box_velocity(token)[:2] for token in sample['anns']])
|
||||
valid_flag = np.array(
|
||||
[
|
||||
(anno["num_lidar_pts"] + anno["num_radar_pts"]) > 0
|
||||
for anno in annotations
|
||||
],
|
||||
dtype=bool,
|
||||
).reshape(-1)
|
||||
[(anno['num_lidar_pts'] + anno['num_radar_pts']) > 0
|
||||
for anno in annotations],
|
||||
dtype=bool).reshape(-1)
|
||||
# convert velo from global to lidar
|
||||
for i in range(len(boxes)):
|
||||
velo = np.array([*velocity[i], 0.0])
|
||||
velo = velo @ np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T
|
||||
velo = velo @ np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(
|
||||
l2e_r_mat).T
|
||||
velocity[i] = velo[:2]
|
||||
|
||||
names = [b.name for b in boxes]
|
||||
|
|
@ -275,28 +283,52 @@ def _fill_trainval_infos(nusc, train_scenes, val_scenes, test=False, max_sweeps=
|
|||
# we need to convert rot to SECOND format.
|
||||
gt_boxes = np.concatenate([locs, dims, -rots - np.pi / 2], axis=1)
|
||||
assert len(gt_boxes) == len(
|
||||
annotations
|
||||
), f"{len(gt_boxes)}, {len(annotations)}"
|
||||
info["gt_boxes"] = gt_boxes
|
||||
info["gt_names"] = names
|
||||
info["gt_velocity"] = velocity.reshape(-1, 2)
|
||||
info["num_lidar_pts"] = np.array([a["num_lidar_pts"] for a in annotations])
|
||||
info["num_radar_pts"] = np.array([a["num_radar_pts"] for a in annotations])
|
||||
info["valid_flag"] = valid_flag
|
||||
annotations), f'{len(gt_boxes)}, {len(annotations)}'
|
||||
info['gt_boxes'] = gt_boxes
|
||||
info['gt_names'] = names
|
||||
info['gt_velocity'] = velocity.reshape(-1, 2)
|
||||
info['num_lidar_pts'] = np.array(
|
||||
[a['num_lidar_pts'] for a in annotations])
|
||||
info['num_radar_pts'] = np.array(
|
||||
[a['num_radar_pts'] for a in annotations])
|
||||
info['valid_flag'] = valid_flag
|
||||
|
||||
if sample["scene_token"] in train_scenes:
|
||||
if sample['scene_token'] in train_scenes:
|
||||
train_nusc_infos.append(info)
|
||||
token2idx[info['token']] = ('train', len(train_nusc_infos) - 1)
|
||||
else:
|
||||
val_nusc_infos.append(info)
|
||||
token2idx[info['token']] = ('val', len(val_nusc_infos) - 1)
|
||||
|
||||
for info in train_nusc_infos:
|
||||
prev_token = info['prev_token']
|
||||
if prev_token == '':
|
||||
info['prev'] = -1
|
||||
else:
|
||||
prev_set, prev_idx = token2idx[prev_token]
|
||||
assert prev_set == 'train'
|
||||
info['prev'] = prev_idx
|
||||
|
||||
for info in val_nusc_infos:
|
||||
prev_token = info['prev_token']
|
||||
if prev_token == '':
|
||||
info['prev'] = -1
|
||||
else:
|
||||
prev_set, prev_idx = token2idx[prev_token]
|
||||
assert prev_set == 'val'
|
||||
info['prev'] = prev_idx
|
||||
|
||||
return train_nusc_infos, val_nusc_infos
|
||||
|
||||
|
||||
def obtain_sensor2top(
|
||||
nusc, sensor_token, l2e_t, l2e_r_mat, e2g_t, e2g_r_mat, sensor_type="lidar"
|
||||
):
|
||||
def obtain_sensor2top(nusc,
|
||||
sensor_token,
|
||||
l2e_t,
|
||||
l2e_r_mat,
|
||||
e2g_t,
|
||||
e2g_r_mat,
|
||||
sensor_type='lidar'):
|
||||
"""Obtain the info with RT matric from general sensor to Top LiDAR.
|
||||
|
||||
Args:
|
||||
nusc (class): Dataset class in the nuScenes dataset.
|
||||
sensor_token (str): Sample data token corresponding to the
|
||||
|
|
@ -308,53 +340,48 @@ def obtain_sensor2top(
|
|||
e2g_r_mat (np.ndarray): Rotation matrix from ego to global
|
||||
in shape (3, 3).
|
||||
sensor_type (str): Sensor to calibrate. Default: 'lidar'.
|
||||
|
||||
Returns:
|
||||
sweep (dict): Sweep information after transformation.
|
||||
"""
|
||||
sd_rec = nusc.get("sample_data", sensor_token)
|
||||
cs_record = nusc.get("calibrated_sensor", sd_rec["calibrated_sensor_token"])
|
||||
pose_record = nusc.get("ego_pose", sd_rec["ego_pose_token"])
|
||||
data_path = str(nusc.get_sample_data_path(sd_rec["token"]))
|
||||
sd_rec = nusc.get('sample_data', sensor_token)
|
||||
cs_record = nusc.get('calibrated_sensor',
|
||||
sd_rec['calibrated_sensor_token'])
|
||||
pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token'])
|
||||
data_path = str(nusc.get_sample_data_path(sd_rec['token']))
|
||||
if os.getcwd() in data_path: # path from lyftdataset is absolute path
|
||||
data_path = data_path.split(f"{os.getcwd()}/")[-1] # relative path
|
||||
data_path = data_path.split(f'{os.getcwd()}/')[-1] # relative path
|
||||
sweep = {
|
||||
"data_path": data_path,
|
||||
"type": sensor_type,
|
||||
"sample_data_token": sd_rec["token"],
|
||||
"sensor2ego_translation": cs_record["translation"],
|
||||
"sensor2ego_rotation": cs_record["rotation"],
|
||||
"ego2global_translation": pose_record["translation"],
|
||||
"ego2global_rotation": pose_record["rotation"],
|
||||
"timestamp": sd_rec["timestamp"],
|
||||
'data_path': data_path,
|
||||
'type': sensor_type,
|
||||
'sample_data_token': sd_rec['token'],
|
||||
'sensor2ego_translation': cs_record['translation'],
|
||||
'sensor2ego_rotation': cs_record['rotation'],
|
||||
'ego2global_translation': pose_record['translation'],
|
||||
'ego2global_rotation': pose_record['rotation'],
|
||||
'timestamp': sd_rec['timestamp']
|
||||
}
|
||||
l2e_r_s = sweep["sensor2ego_rotation"]
|
||||
l2e_t_s = sweep["sensor2ego_translation"]
|
||||
e2g_r_s = sweep["ego2global_rotation"]
|
||||
e2g_t_s = sweep["ego2global_translation"]
|
||||
l2e_r_s = sweep['sensor2ego_rotation']
|
||||
l2e_t_s = sweep['sensor2ego_translation']
|
||||
e2g_r_s = sweep['ego2global_rotation']
|
||||
e2g_t_s = sweep['ego2global_translation']
|
||||
|
||||
# obtain the RT from sensor to Top LiDAR
|
||||
# sweep->ego->global->ego'->lidar
|
||||
l2e_r_s_mat = Quaternion(l2e_r_s).rotation_matrix
|
||||
e2g_r_s_mat = Quaternion(e2g_r_s).rotation_matrix
|
||||
R = (l2e_r_s_mat.T @ e2g_r_s_mat.T) @ (
|
||||
np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T
|
||||
)
|
||||
np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T)
|
||||
T = (l2e_t_s @ e2g_r_s_mat.T + e2g_t_s) @ (
|
||||
np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T
|
||||
)
|
||||
T -= (
|
||||
e2g_t @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T)
|
||||
+ l2e_t @ np.linalg.inv(l2e_r_mat).T
|
||||
)
|
||||
sweep["sensor2lidar_rotation"] = R.T # points @ R.T + T
|
||||
sweep["sensor2lidar_translation"] = T
|
||||
np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T)
|
||||
T -= e2g_t @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T
|
||||
) + l2e_t @ np.linalg.inv(l2e_r_mat).T
|
||||
sweep['sensor2lidar_rotation'] = R.T # points @ R.T + T
|
||||
sweep['sensor2lidar_translation'] = T
|
||||
return sweep
|
||||
|
||||
|
||||
def export_2d_annotation(root_path, info_path, version, mono3d=True):
|
||||
"""Export 2d annotation from the info file and raw data.
|
||||
|
||||
Args:
|
||||
root_path (str): Root path of the raw data.
|
||||
info_path (str): Path of the info file.
|
||||
|
|
@ -363,14 +390,14 @@ def export_2d_annotation(root_path, info_path, version, mono3d=True):
|
|||
"""
|
||||
# get bbox annotations for camera
|
||||
camera_types = [
|
||||
"CAM_FRONT",
|
||||
"CAM_FRONT_RIGHT",
|
||||
"CAM_FRONT_LEFT",
|
||||
"CAM_BACK",
|
||||
"CAM_BACK_LEFT",
|
||||
"CAM_BACK_RIGHT",
|
||||
'CAM_FRONT',
|
||||
'CAM_FRONT_RIGHT',
|
||||
'CAM_FRONT_LEFT',
|
||||
'CAM_BACK',
|
||||
'CAM_BACK_LEFT',
|
||||
'CAM_BACK_RIGHT',
|
||||
]
|
||||
nusc_infos = mmcv.load(info_path)["infos"]
|
||||
nusc_infos = mmcv.load(info_path)['infos']
|
||||
nusc = NuScenes(version=version, dataroot=root_path, verbose=True)
|
||||
# info_2d_list = []
|
||||
cat2Ids = [
|
||||
|
|
@ -381,97 +408,100 @@ def export_2d_annotation(root_path, info_path, version, mono3d=True):
|
|||
coco_2d_dict = dict(annotations=[], images=[], categories=cat2Ids)
|
||||
for info in mmcv.track_iter_progress(nusc_infos):
|
||||
for cam in camera_types:
|
||||
cam_info = info["cams"][cam]
|
||||
cam_info = info['cams'][cam]
|
||||
coco_infos = get_2d_boxes(
|
||||
nusc,
|
||||
cam_info["sample_data_token"],
|
||||
visibilities=["", "1", "2", "3", "4"],
|
||||
mono3d=mono3d,
|
||||
)
|
||||
(height, width, _) = mmcv.imread(cam_info["data_path"]).shape
|
||||
coco_2d_dict["images"].append(
|
||||
cam_info['sample_data_token'],
|
||||
visibilities=['', '1', '2', '3', '4'],
|
||||
mono3d=mono3d)
|
||||
(height, width, _) = mmcv.imread(cam_info['data_path']).shape
|
||||
coco_2d_dict['images'].append(
|
||||
dict(
|
||||
file_name=cam_info["data_path"].split("data/nuscenes/")[-1],
|
||||
id=cam_info["sample_data_token"],
|
||||
token=info["token"],
|
||||
cam2ego_rotation=cam_info["sensor2ego_rotation"],
|
||||
cam2ego_translation=cam_info["sensor2ego_translation"],
|
||||
ego2global_rotation=info["ego2global_rotation"],
|
||||
ego2global_translation=info["ego2global_translation"],
|
||||
camera_intrinsics=cam_info["camera_intrinsics"],
|
||||
file_name=cam_info['data_path'].split('data/nuscenes/')
|
||||
[-1],
|
||||
id=cam_info['sample_data_token'],
|
||||
token=info['token'],
|
||||
cam2ego_rotation=cam_info['sensor2ego_rotation'],
|
||||
cam2ego_translation=cam_info['sensor2ego_translation'],
|
||||
ego2global_rotation=info['ego2global_rotation'],
|
||||
ego2global_translation=info['ego2global_translation'],
|
||||
cam_intrinsic=cam_info['cam_intrinsic'],
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
)
|
||||
height=height))
|
||||
for coco_info in coco_infos:
|
||||
if coco_info is None:
|
||||
continue
|
||||
# add an empty key for coco format
|
||||
coco_info["segmentation"] = []
|
||||
coco_info["id"] = coco_ann_id
|
||||
coco_2d_dict["annotations"].append(coco_info)
|
||||
coco_info['segmentation'] = []
|
||||
coco_info['id'] = coco_ann_id
|
||||
coco_2d_dict['annotations'].append(coco_info)
|
||||
coco_ann_id += 1
|
||||
if mono3d:
|
||||
json_prefix = f"{info_path[:-4]}_mono3d"
|
||||
json_prefix = f'{info_path[:-4]}_mono3d'
|
||||
else:
|
||||
json_prefix = f"{info_path[:-4]}"
|
||||
mmcv.dump(coco_2d_dict, f"{json_prefix}.coco.json")
|
||||
json_prefix = f'{info_path[:-4]}'
|
||||
mmcv.dump(coco_2d_dict, f'{json_prefix}.coco.json')
|
||||
|
||||
|
||||
def get_2d_boxes(nusc, sample_data_token: str, visibilities: List[str], mono3d=True):
|
||||
def get_2d_boxes(nusc,
|
||||
sample_data_token: str,
|
||||
visibilities: List[str],
|
||||
mono3d=True):
|
||||
"""Get the 2D annotation records for a given `sample_data_token`.
|
||||
|
||||
Args:
|
||||
sample_data_token (str): Sample data token belonging to a camera \
|
||||
keyframe.
|
||||
visibilities (list[str]): Visibility filter.
|
||||
mono3d (bool): Whether to get boxes with mono3d annotation.
|
||||
|
||||
Return:
|
||||
list[dict]: List of 2D annotation record that belongs to the input
|
||||
`sample_data_token`.
|
||||
"""
|
||||
|
||||
# Get the sample data and the sample corresponding to that sample data.
|
||||
sd_rec = nusc.get("sample_data", sample_data_token)
|
||||
sd_rec = nusc.get('sample_data', sample_data_token)
|
||||
|
||||
assert sd_rec["sensor_modality"] == "camera", (
|
||||
"Error: get_2d_boxes only works" " for camera sample_data!"
|
||||
)
|
||||
if not sd_rec["is_key_frame"]:
|
||||
raise ValueError("The 2D re-projections are available only for keyframes.")
|
||||
assert sd_rec[
|
||||
'sensor_modality'] == 'camera', 'Error: get_2d_boxes only works' \
|
||||
' for camera sample_data!'
|
||||
if not sd_rec['is_key_frame']:
|
||||
raise ValueError(
|
||||
'The 2D re-projections are available only for keyframes.')
|
||||
|
||||
s_rec = nusc.get("sample", sd_rec["sample_token"])
|
||||
s_rec = nusc.get('sample', sd_rec['sample_token'])
|
||||
|
||||
# Get the calibrated sensor and ego pose
|
||||
# record to get the transformation matrices.
|
||||
cs_rec = nusc.get("calibrated_sensor", sd_rec["calibrated_sensor_token"])
|
||||
pose_rec = nusc.get("ego_pose", sd_rec["ego_pose_token"])
|
||||
camera_intrinsic = np.array(cs_rec["camera_intrinsic"])
|
||||
cs_rec = nusc.get('calibrated_sensor', sd_rec['calibrated_sensor_token'])
|
||||
pose_rec = nusc.get('ego_pose', sd_rec['ego_pose_token'])
|
||||
camera_intrinsic = np.array(cs_rec['camera_intrinsic'])
|
||||
|
||||
# Get all the annotation with the specified visibilties.
|
||||
ann_recs = [nusc.get("sample_annotation", token) for token in s_rec["anns"]]
|
||||
ann_recs = [
|
||||
ann_rec for ann_rec in ann_recs if (ann_rec["visibility_token"] in visibilities)
|
||||
nusc.get('sample_annotation', token) for token in s_rec['anns']
|
||||
]
|
||||
ann_recs = [
|
||||
ann_rec for ann_rec in ann_recs
|
||||
if (ann_rec['visibility_token'] in visibilities)
|
||||
]
|
||||
|
||||
repro_recs = []
|
||||
|
||||
for ann_rec in ann_recs:
|
||||
# Augment sample_annotation with token information.
|
||||
ann_rec["sample_annotation_token"] = ann_rec["token"]
|
||||
ann_rec["sample_data_token"] = sample_data_token
|
||||
ann_rec['sample_annotation_token'] = ann_rec['token']
|
||||
ann_rec['sample_data_token'] = sample_data_token
|
||||
|
||||
# Get the box in global coordinates.
|
||||
box = nusc.get_box(ann_rec["token"])
|
||||
box = nusc.get_box(ann_rec['token'])
|
||||
|
||||
# Move them to the ego-pose frame.
|
||||
box.translate(-np.array(pose_rec["translation"]))
|
||||
box.rotate(Quaternion(pose_rec["rotation"]).inverse)
|
||||
box.translate(-np.array(pose_rec['translation']))
|
||||
box.rotate(Quaternion(pose_rec['rotation']).inverse)
|
||||
|
||||
# Move them to the calibrated sensor frame.
|
||||
box.translate(-np.array(cs_rec["translation"]))
|
||||
box.rotate(Quaternion(cs_rec["rotation"]).inverse)
|
||||
box.translate(-np.array(cs_rec['translation']))
|
||||
box.rotate(Quaternion(cs_rec['rotation']).inverse)
|
||||
|
||||
# Filter out the corners that are not in front of the calibrated
|
||||
# sensor.
|
||||
|
|
@ -480,9 +510,8 @@ def get_2d_boxes(nusc, sample_data_token: str, visibilities: List[str], mono3d=T
|
|||
corners_3d = corners_3d[:, in_front]
|
||||
|
||||
# Project 3d box to 2d.
|
||||
corner_coords = (
|
||||
view_points(corners_3d, camera_intrinsic, True).T[:, :2].tolist()
|
||||
)
|
||||
corner_coords = view_points(corners_3d, camera_intrinsic,
|
||||
True).T[:, :2].tolist()
|
||||
|
||||
# Keep only corners that fall within the image.
|
||||
final_coords = post_process_coords(corner_coords)
|
||||
|
|
@ -495,49 +524,44 @@ def get_2d_boxes(nusc, sample_data_token: str, visibilities: List[str], mono3d=T
|
|||
min_x, min_y, max_x, max_y = final_coords
|
||||
|
||||
# Generate dictionary record to be included in the .json file.
|
||||
repro_rec = generate_record(
|
||||
ann_rec, min_x, min_y, max_x, max_y, sample_data_token, sd_rec["filename"]
|
||||
)
|
||||
repro_rec = generate_record(ann_rec, min_x, min_y, max_x, max_y,
|
||||
sample_data_token, sd_rec['filename'])
|
||||
|
||||
# If mono3d=True, add 3D annotations in camera coordinates
|
||||
if mono3d and (repro_rec is not None):
|
||||
loc = box.center.tolist()
|
||||
|
||||
dim = box.wlh
|
||||
dim[[0, 1, 2]] = dim[[1, 2, 0]] # convert wlh to our lhw
|
||||
dim = dim.tolist()
|
||||
|
||||
rot = box.orientation.yaw_pitch_roll[0]
|
||||
rot = [-rot] # convert the rot to our cam coordinate
|
||||
dim = box.wlh.tolist()
|
||||
rot = [box.orientation.yaw_pitch_roll[0]]
|
||||
|
||||
global_velo2d = nusc.box_velocity(box.token)[:2]
|
||||
global_velo3d = np.array([*global_velo2d, 0.0])
|
||||
e2g_r_mat = Quaternion(pose_rec["rotation"]).rotation_matrix
|
||||
c2e_r_mat = Quaternion(cs_rec["rotation"]).rotation_matrix
|
||||
cam_velo3d = (
|
||||
global_velo3d @ np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(c2e_r_mat).T
|
||||
)
|
||||
e2g_r_mat = Quaternion(pose_rec['rotation']).rotation_matrix
|
||||
c2e_r_mat = Quaternion(cs_rec['rotation']).rotation_matrix
|
||||
cam_velo3d = global_velo3d @ np.linalg.inv(
|
||||
e2g_r_mat).T @ np.linalg.inv(c2e_r_mat).T
|
||||
velo = cam_velo3d[0::2].tolist()
|
||||
|
||||
repro_rec["bbox_cam3d"] = loc + dim + rot
|
||||
repro_rec["velo_cam3d"] = velo
|
||||
repro_rec['bbox_cam3d'] = loc + dim + rot
|
||||
repro_rec['velo_cam3d'] = velo
|
||||
|
||||
center3d = np.array(loc).reshape([1, 3])
|
||||
center2d = points_cam2img(center3d, camera_intrinsic, with_depth=True)
|
||||
repro_rec["center2d"] = center2d.squeeze().tolist()
|
||||
center2d = points_cam2img(
|
||||
center3d, camera_intrinsic, with_depth=True)
|
||||
repro_rec['center2d'] = center2d.squeeze().tolist()
|
||||
# normalized center2D + depth
|
||||
# if samples with depth < 0 will be removed
|
||||
if repro_rec["center2d"][2] <= 0:
|
||||
if repro_rec['center2d'][2] <= 0:
|
||||
continue
|
||||
|
||||
ann_token = nusc.get("sample_annotation", box.token)["attribute_tokens"]
|
||||
ann_token = nusc.get('sample_annotation',
|
||||
box.token)['attribute_tokens']
|
||||
if len(ann_token) == 0:
|
||||
attr_name = "None"
|
||||
attr_name = 'None'
|
||||
else:
|
||||
attr_name = nusc.get("attribute", ann_token[0])["name"]
|
||||
attr_name = nusc.get('attribute', ann_token[0])['name']
|
||||
attr_id = nus_attributes.index(attr_name)
|
||||
repro_rec["attribute_name"] = attr_name
|
||||
repro_rec["attribute_id"] = attr_id
|
||||
repro_rec['attribute_name'] = attr_name
|
||||
repro_rec['attribute_id'] = attr_id
|
||||
|
||||
repro_recs.append(repro_rec)
|
||||
|
||||
|
|
@ -549,12 +573,10 @@ def post_process_coords(
|
|||
) -> Union[Tuple[float, float, float, float], None]:
|
||||
"""Get the intersection of the convex hull of the reprojected bbox corners
|
||||
and the image canvas, return None if no intersection.
|
||||
|
||||
Args:
|
||||
corner_coords (list[int]): Corner coordinates of reprojected
|
||||
bounding box.
|
||||
imsize (tuple[int]): Size of the image canvas.
|
||||
|
||||
Return:
|
||||
tuple [float]: Intersection of the convex hull of the 2D box
|
||||
corners and the image canvas.
|
||||
|
|
@ -565,8 +587,7 @@ def post_process_coords(
|
|||
if polygon_from_2d_box.intersects(img_canvas):
|
||||
img_intersection = polygon_from_2d_box.intersection(img_canvas)
|
||||
intersection_coords = np.array(
|
||||
[coord for coord in img_intersection.exterior.coords]
|
||||
)
|
||||
[coord for coord in img_intersection.exterior.coords])
|
||||
|
||||
min_x = min(intersection_coords[:, 0])
|
||||
min_y = min(intersection_coords[:, 1])
|
||||
|
|
@ -578,18 +599,10 @@ def post_process_coords(
|
|||
return None
|
||||
|
||||
|
||||
def generate_record(
|
||||
ann_rec: dict,
|
||||
x1: float,
|
||||
y1: float,
|
||||
x2: float,
|
||||
y2: float,
|
||||
sample_data_token: str,
|
||||
filename: str,
|
||||
) -> OrderedDict:
|
||||
def generate_record(ann_rec: dict, x1: float, y1: float, x2: float, y2: float,
|
||||
sample_data_token: str, filename: str) -> OrderedDict:
|
||||
"""Generate one 2D annotation record given various informations on top of
|
||||
the 2D bounding box coordinates.
|
||||
|
||||
Args:
|
||||
ann_rec (dict): Original 3d annotation record.
|
||||
x1 (float): Minimum value of the x coordinate.
|
||||
|
|
@ -599,7 +612,6 @@ def generate_record(
|
|||
sample_data_token (str): Sample data token.
|
||||
filename (str):The corresponding image file where the annotation
|
||||
is present.
|
||||
|
||||
Returns:
|
||||
dict: A sample 2D annotation record.
|
||||
- file_name (str): flie name
|
||||
|
|
@ -611,39 +623,43 @@ def generate_record(
|
|||
- iscrowd (int): whether the area is crowd
|
||||
"""
|
||||
repro_rec = OrderedDict()
|
||||
repro_rec["sample_data_token"] = sample_data_token
|
||||
repro_rec['sample_data_token'] = sample_data_token
|
||||
coco_rec = dict()
|
||||
|
||||
relevant_keys = [
|
||||
"attribute_tokens",
|
||||
"category_name",
|
||||
"instance_token",
|
||||
"next",
|
||||
"num_lidar_pts",
|
||||
"num_radar_pts",
|
||||
"prev",
|
||||
"sample_annotation_token",
|
||||
"sample_data_token",
|
||||
"visibility_token",
|
||||
'attribute_tokens',
|
||||
'category_name',
|
||||
'instance_token',
|
||||
'next',
|
||||
'num_lidar_pts',
|
||||
'num_radar_pts',
|
||||
'prev',
|
||||
'sample_annotation_token',
|
||||
'sample_data_token',
|
||||
'visibility_token',
|
||||
]
|
||||
|
||||
for key, value in ann_rec.items():
|
||||
if key in relevant_keys:
|
||||
repro_rec[key] = value
|
||||
|
||||
repro_rec["bbox_corners"] = [x1, y1, x2, y2]
|
||||
repro_rec["filename"] = filename
|
||||
repro_rec['bbox_corners'] = [x1, y1, x2, y2]
|
||||
repro_rec['filename'] = filename
|
||||
|
||||
coco_rec["file_name"] = filename
|
||||
coco_rec["image_id"] = sample_data_token
|
||||
coco_rec["area"] = (y2 - y1) * (x2 - x1)
|
||||
coco_rec['file_name'] = filename
|
||||
coco_rec['image_id'] = sample_data_token
|
||||
coco_rec['area'] = (y2 - y1) * (x2 - x1)
|
||||
|
||||
if repro_rec["category_name"] not in NuScenesDataset.NameMapping:
|
||||
if repro_rec['category_name'] not in NuScenesDataset.NameMapping:
|
||||
return None
|
||||
cat_name = NuScenesDataset.NameMapping[repro_rec["category_name"]]
|
||||
coco_rec["category_name"] = cat_name
|
||||
coco_rec["category_id"] = nus_categories.index(cat_name)
|
||||
coco_rec["bbox"] = [x1, y1, x2 - x1, y2 - y1]
|
||||
coco_rec["iscrowd"] = 0
|
||||
cat_name = NuScenesDataset.NameMapping[repro_rec['category_name']]
|
||||
coco_rec['category_name'] = cat_name
|
||||
coco_rec['category_id'] = nus_categories.index(cat_name)
|
||||
coco_rec['bbox'] = [x1, y1, x2 - x1, y2 - y1]
|
||||
coco_rec['iscrowd'] = 0
|
||||
|
||||
return coco_rec
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_nuscenes_infos('data/nuscenes/', 'radar_nuscenes_5sweeps')
|
||||
Loading…
Reference in New Issue