bev-project/mmdet3d/ops/furthest_point_sample/points_sampler.py

158 lines
5.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from mmcv.runner import force_fp32
from torch import nn as nn
from typing import List
from .furthest_point_sample import furthest_point_sample, furthest_point_sample_with_dist
from .utils import calc_square_dist
def get_sampler_type(sampler_type):
    """Map a sampler type string to the sampler class implementing it.

    Args:
        sampler_type (str): One of "D-FPS", "F-FPS" or "FS".

    Returns:
        class: ``DFPS_Sampler`` for "D-FPS", ``FFPS_Sampler`` for "F-FPS",
            ``FS_Sampler`` for "FS". The class itself is returned, not an
            instance.

    Raises:
        ValueError: If ``sampler_type`` is none of the supported strings.
    """
    # Each class name is looked up only on the branch that returns it.
    if sampler_type == "D-FPS":
        return DFPS_Sampler
    if sampler_type == "F-FPS":
        return FFPS_Sampler
    if sampler_type == "FS":
        return FS_Sampler

    message = (
        'Only "sampler_type" of "D-FPS", "F-FPS", or "FS"' f" are supported, got {sampler_type}"
    )
    raise ValueError(message)
class Points_Sampler(nn.Module):
    """Points sampling.

    The point dimension is cut into consecutive index ranges. Each range is
    sampled by its own sampler, and the resulting indices are shifted back
    into the index space of the full input before being concatenated.

    Args:
        num_point (list[int]): Number of sample points for each range.
        fps_mod_list (list[str]): Type of FPS method, valid mod
            ['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS'].
            F-FPS: using feature distances for FPS.
            D-FPS: using Euclidean distances of points for FPS.
            FS: using F-FPS and D-FPS simultaneously.
        fps_sample_range_list (list[int]): Range of points to apply FPS.
            Each entry is the (exclusive) end index of its range along the
            point dimension; -1 means "up to the last point".
            Default: [-1].
    """

    def __init__(
        self,
        num_point: List[int],
        fps_mod_list: List[str] = ["D-FPS"],
        fps_sample_range_list: List[int] = [-1],
    ):
        super(Points_Sampler, self).__init__()
        # FPS would be applied to different fps_mod in the list,
        # so the length of the num_point should be equal to
        # fps_mod_list and fps_sample_range_list.
        assert len(num_point) == len(fps_mod_list) == len(fps_sample_range_list)
        self.num_point = num_point
        # Store a copy so that mutating the attribute can never alter the
        # shared default list object of this signature (or the caller's list).
        self.fps_sample_range_list = list(fps_sample_range_list)
        self.samplers = nn.ModuleList()
        for fps_mod in fps_mod_list:
            self.samplers.append(get_sampler_type(fps_mod)())
        # Attribute set alongside the ``force_fp32`` decorator on ``forward``.
        self.fp16_enabled = False

    @force_fp32()
    def forward(self, points_xyz, features):
        """Sample point indices range by range.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            features (Tensor | None): (B, C, N) Descriptors of the features.
                May be None, in which case None is forwarded to the samplers.

        Returns:
            Tensor: Indices of sampled points, expressed relative to the full
                ``points_xyz`` and concatenated along dim 1 over all ranges.
        """
        indices = []
        last_fps_end_index = 0
        for fps_sample_range, sampler, npoint in zip(
            self.fps_sample_range_list, self.samplers, self.num_point
        ):
            assert fps_sample_range < points_xyz.shape[1]

            if fps_sample_range == -1:
                sample_points_xyz = points_xyz[:, last_fps_end_index:]
                sample_features = (
                    features[:, :, last_fps_end_index:] if features is not None else None
                )
                # This range runs to the last point, so everything is consumed.
                next_fps_start_index = points_xyz.shape[1]
            else:
                sample_points_xyz = points_xyz[:, last_fps_end_index:fps_sample_range]
                sample_features = (
                    features[:, :, last_fps_end_index:fps_sample_range]
                    if features is not None
                    else None
                )
                # ``fps_sample_range`` is an absolute end index (it is used as
                # the slice stop above), so the next range starts exactly
                # there. Accumulating it with ``+=`` would skip points as soon
                # as a third range is configured.
                next_fps_start_index = fps_sample_range

            fps_idx = sampler(sample_points_xyz.contiguous(), sample_features, npoint)

            # Sampler indices are relative to the slice; shift them back.
            indices.append(fps_idx + last_fps_end_index)
            last_fps_end_index = next_fps_start_index

        indices = torch.cat(indices, dim=1)
        return indices
class DFPS_Sampler(nn.Module):
    """D-FPS sampler.

    Runs furthest point sampling on the point coordinates only.
    """

    def __init__(self):
        super().__init__()

    def forward(self, points, features, npoint):
        """Sample ``npoint`` indices from ``points`` with D-FPS.

        Args:
            points: Point tensor handed to ``furthest_point_sample`` after
                being made contiguous.
            features: Unused here; accepted so that this sampler can be
                called with the same arguments as the feature-based ones.
            npoint: Number of points to sample, forwarded unchanged.

        Returns:
            Whatever ``furthest_point_sample`` returns for the given input.
        """
        contiguous_points = points.contiguous()
        return furthest_point_sample(contiguous_points, npoint)
class FFPS_Sampler(nn.Module):
    """F-FPS sampler.

    Runs furthest point sampling on distances computed from the point
    coordinates joined with their features.
    """

    def __init__(self):
        super().__init__()

    def forward(self, points, features, npoint):
        """Sample ``npoint`` indices with F-FPS.

        Args:
            points: Point tensor; concatenated with the transposed features
                along dim 2.
            features: Feature tensor, must not be None. Its dims 1 and 2 are
                swapped before concatenation.
            npoint: Number of points to sample, forwarded unchanged.

        Returns:
            Whatever ``furthest_point_sample_with_dist`` returns for the
            computed distances.
        """
        assert features is not None, "feature input to FFPS_Sampler should not be None"
        # Bring features into the same layout as points before joining them.
        features_last = torch.transpose(features, 1, 2)
        joint = torch.cat((points, features_last), dim=2)
        # Distances of the joint descriptors against themselves.
        pairwise_dist = calc_square_dist(joint, joint, norm=False)
        return furthest_point_sample_with_dist(pairwise_dist, npoint)
class FS_Sampler(nn.Module):
    """FS_Sampling.

    Using F-FPS and D-FPS simultaneously: both samplings are run on the same
    input and their indices are concatenated along dim 1 (F-FPS first).
    """

    def __init__(self):
        super(FS_Sampler, self).__init__()

    def forward(self, points, features, npoint):
        """Sampling points with FS_Sampling.

        Args:
            points: Point tensor; used on its own for D-FPS and joined with
                the transposed features for F-FPS.
            features: Feature tensor, must not be None. Its dims 1 and 2 are
                swapped before being concatenated with ``points`` on dim 2.
            npoint: Number of points requested from each of the two
                samplings, forwarded unchanged to both.

        Returns:
            Tensor: F-FPS indices followed by D-FPS indices, concatenated
                along dim 1.
        """
        assert features is not None, "feature input to FS_Sampler should not be None"
        features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2)
        features_dist = calc_square_dist(features_for_fps, features_for_fps, norm=False)
        fps_idx_ffps = furthest_point_sample_with_dist(features_dist, npoint)
        # Make the coordinates contiguous before D-FPS, exactly as
        # DFPS_Sampler does, so a direct call with a sliced/strided tensor
        # behaves the same for both samplers. No-op for contiguous input.
        fps_idx_dfps = furthest_point_sample(points.contiguous(), npoint)
        fps_idx = torch.cat([fps_idx_ffps, fps_idx_dfps], dim=1)
        return fps_idx