bev-project/mmdet3d/ops/spconv/pool.py

82 lines
2.9 KiB
Python
Raw Permalink Normal View History

2022-06-03 12:21:18 +08:00
# Copyright 2019 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import functional as Fsp
from . import ops
from .modules import SparseModule
from .structure import SparseConvTensor
class SparseMaxPool(SparseModule):
def __init__(self, ndim, kernel_size, stride=1, padding=0, dilation=1, subm=False):
super(SparseMaxPool, self).__init__()
if not isinstance(kernel_size, (list, tuple)):
kernel_size = [kernel_size] * ndim
if not isinstance(stride, (list, tuple)):
stride = [stride] * ndim
if not isinstance(padding, (list, tuple)):
padding = [padding] * ndim
if not isinstance(dilation, (list, tuple)):
dilation = [dilation] * ndim
self.ndim = ndim
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.subm = subm
self.dilation = dilation
def forward(self, input):
assert isinstance(input, SparseConvTensor)
features = input.features
device = features.device
indices = input.indices
spatial_shape = input.spatial_shape
batch_size = input.batch_size
if not self.subm:
out_spatial_shape = ops.get_conv_output_size(
spatial_shape, self.kernel_size, self.stride, self.padding, self.dilation
)
else:
out_spatial_shape = spatial_shape
outids, indice_pairs, indice_pairs_num = ops.get_indice_pairs(
indices,
batch_size,
spatial_shape,
self.kernel_size,
self.stride,
self.padding,
self.dilation,
0,
self.subm,
)
out_features = Fsp.indice_maxpool(
features, indice_pairs.to(device), indice_pairs_num.to(device), outids.shape[0]
)
out_tensor = SparseConvTensor(out_features, outids, out_spatial_shape, batch_size)
out_tensor.indice_dict = input.indice_dict
out_tensor.grid = input.grid
return out_tensor
class SparseMaxPool2d(SparseMaxPool):
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
super(SparseMaxPool2d, self).__init__(2, kernel_size, stride, padding, dilation)
class SparseMaxPool3d(SparseMaxPool):
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
super(SparseMaxPool3d, self).__init__(3, kernel_size, stride, padding, dilation)