bev-project/mmdet3d/ops/spconv/ops.py

216 lines
6.4 KiB
Python
Raw Normal View History

2022-06-03 12:21:18 +08:00
# Copyright 2019 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from . import sparse_conv_ext
def get_conv_output_size(input_size, kernel_size, stride, padding, dilation):
    """Return the per-axis output extent of a regular (non-transposed) convolution.

    Every argument is indexed per spatial axis, for ``len(input_size)`` axes.

    Args:
        input_size: Input extent along each axis.
        kernel_size: Kernel extent along each axis; ``-1`` is a sentinel that
            forces the output extent of that axis to 1.
        stride: Stride along each axis.
        padding: Padding applied on both sides of each axis.
        dilation: Dilation along each axis.

    Returns:
        list: One output extent per axis.
    """

    def _axis_extent(axis):
        # The regular formula is evaluated even for the -1 sentinel so that
        # any error it raises (e.g. a zero stride) still surfaces.
        numerator = (
            input_size[axis]
            + 2 * padding[axis]
            - dilation[axis] * (kernel_size[axis] - 1)
            - 1
        )
        regular = numerator // stride[axis] + 1
        return 1 if kernel_size[axis] == -1 else regular

    return [_axis_extent(axis) for axis in range(len(input_size))]
def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation, output_padding):
    """Return the per-axis output extent of a transposed convolution.

    Each axis is computed as
    ``(in - 1) * stride - 2 * padding + kernel + output_padding``.

    Args:
        input_size: Input extent along each axis.
        kernel_size: Kernel extent along each axis; ``-1`` is rejected.
        stride: Stride along each axis.
        padding: Padding along each axis.
        dilation: Accepted for signature parity with the forward variant.
            NOTE(review): it is not used in the formula below - confirm that
            this matches what the compiled extension expects.
        output_padding: Extra extent added to each axis of the output.

    Returns:
        list: One output extent per axis.

    Raises:
        ValueError: If any axis has ``kernel_size == -1``.
    """
    extents = []
    for axis, in_extent in enumerate(input_size):
        kernel = kernel_size[axis]
        if kernel == -1:
            raise ValueError("deconv don't support kernel_size < 0")
        upsampled = (in_extent - 1) * stride[axis]
        extents.append(upsampled - 2 * padding[axis] + kernel + output_padding[axis])
    return extents
def get_indice_pairs(
    indices,
    batch_size,
    spatial_shape,
    ksize=3,
    stride=1,
    padding=0,
    dilation=1,
    out_padding=0,
    subm=False,
    transpose=False,
    grid=None,
):
    """Compute the output shape and delegate to the matching extension kernel.

    The number of spatial axes is ``indices.shape[1] - 1`` (presumably the
    remaining column holds the batch index - confirm against callers).
    Scalar ``ksize``/``stride``/``padding``/``dilation``/``out_padding`` are
    broadcast to one entry per spatial axis; lists and tuples pass through
    unchanged.

    Args:
        indices: Object exposing ``.shape``; forwarded to the extension.
        batch_size: Forwarded to the extension.
        spatial_shape: Per-axis extent of the input; also the output shape
            when ``subm`` is true.
        ksize: Kernel extent, scalar or per-axis sequence.
        stride: Stride, scalar or per-axis sequence.
        padding: Padding, scalar or per-axis sequence.
        dilation: Dilation, scalar or per-axis sequence.
        out_padding: Output padding, scalar or per-axis sequence.
        subm: When true the output shape equals ``spatial_shape``.
        transpose: When true (and ``subm`` is false) the transposed-conv
            output size is used.
        grid: When not ``None``, the ``get_indice_pairs_grid_*`` kernels are
            used and ``grid`` is passed as their second argument.

    Returns:
        Whatever the selected ``sparse_conv_ext`` kernel returns.

    Raises:
        AssertionError: If some axis has both stride and dilation != 1.
        NotImplementedError: If there is no kernel for the axis count
            (2, 3, 4 without a grid; 2, 3 with a grid).
    """
    ndim = indices.shape[1] - 1

    def _per_axis(setting):
        # Broadcast a scalar to every spatial axis; keep sequences as given.
        if isinstance(setting, (list, tuple)):
            return setting
        return [setting] * ndim

    ksize = _per_axis(ksize)
    stride = _per_axis(stride)
    padding = _per_axis(padding)
    dilation = _per_axis(dilation)
    out_padding = _per_axis(out_padding)

    for axis_dilation, axis_stride in zip(dilation, stride):
        assert axis_stride == 1 or axis_dilation == 1, "don't support this."

    if subm:
        out_shape = spatial_shape
    elif transpose:
        out_shape = get_deconv_output_size(
            spatial_shape, ksize, stride, padding, dilation, out_padding
        )
    else:
        out_shape = get_conv_output_size(spatial_shape, ksize, stride, padding, dilation)

    use_grid = grid is not None
    if use_grid:
        kernel_names = {
            2: "get_indice_pairs_grid_2d",
            3: "get_indice_pairs_grid_3d",
        }
    else:
        kernel_names = {
            2: "get_indice_pairs_2d",
            3: "get_indice_pairs_3d",
            4: "get_indice_pairs_4d",
        }
    if ndim not in kernel_names:
        raise NotImplementedError
    # Resolve by name so only the selected extension symbol is ever touched.
    kernel = getattr(sparse_conv_ext, kernel_names[ndim])

    leading_args = (indices, grid) if use_grid else (indices,)
    return kernel(
        *leading_args,
        batch_size,
        out_shape,
        spatial_shape,
        ksize,
        stride,
        padding,
        dilation,
        out_padding,
        int(subm),
        int(transpose),
    )
def indice_conv(
    features, filters, indice_pairs, indice_pair_num, num_activate_out, inverse=False, subm=False
):
    """Dispatch to the extension's ``indice_conv`` kernel for the filter dtype.

    ``filters.dtype`` selects the kernel: ``torch.float32`` uses
    ``indice_conv_fp32`` and ``torch.half`` uses ``indice_conv_half``.
    ``inverse`` and ``subm`` are passed to the kernel as ints.

    Returns:
        Whatever the selected ``sparse_conv_ext`` kernel returns.

    Raises:
        NotImplementedError: If ``filters.dtype`` is neither float32 nor half.
    """
    weight_dtype = filters.dtype
    if weight_dtype == torch.float32:
        kernel = sparse_conv_ext.indice_conv_fp32
    elif weight_dtype == torch.half:
        kernel = sparse_conv_ext.indice_conv_half
    else:
        raise NotImplementedError
    return kernel(
        features,
        filters,
        indice_pairs,
        indice_pair_num,
        num_activate_out,
        int(inverse),
        int(subm),
    )
def fused_indice_conv(
    features, filters, bias, indice_pairs, indice_pair_num, num_activate_out, inverse, subm
):
    """Dispatch to the extension's fused (bias-taking) indice-conv kernel.

    Kernel selection, in order:
      * ``features.dtype == torch.half``   -> ``fused_indice_conv_half``
      * ``filters.dtype == torch.float32`` -> ``fused_indice_conv_fp32``

    NOTE(review): the half branch keys on ``features`` while the fp32 branch
    keys on ``filters``, so mixed-dtype inputs are handled asymmetrically
    (half features with fp32 filters take the half kernel; fp32 features
    with half filters raise). Confirm whether that is intended.

    ``inverse`` and ``subm`` are passed to the kernel as ints.

    Returns:
        Whatever the selected ``sparse_conv_ext`` kernel returns.

    Raises:
        NotImplementedError: If neither dtype condition above holds.
    """
    if features.dtype == torch.half:
        return sparse_conv_ext.fused_indice_conv_half(
            features,
            filters,
            bias,
            indice_pairs,
            indice_pair_num,
            num_activate_out,
            int(inverse),
            int(subm),
        )
    if filters.dtype == torch.float32:
        return sparse_conv_ext.fused_indice_conv_fp32(
            features,
            filters,
            bias,
            indice_pairs,
            indice_pair_num,
            num_activate_out,
            int(inverse),
            int(subm),
        )
    raise NotImplementedError
def indice_conv_backward(
    features, filters, out_bp, indice_pairs, indice_pair_num, inverse=False, subm=False
):
    """Dispatch to the extension's ``indice_conv_backward`` kernel by filter dtype.

    ``filters.dtype`` selects the kernel: ``torch.float32`` uses
    ``indice_conv_backward_fp32`` and ``torch.half`` uses
    ``indice_conv_backward_half``. ``inverse`` and ``subm`` are passed to the
    kernel as ints.

    Returns:
        Whatever the selected ``sparse_conv_ext`` kernel returns.

    Raises:
        NotImplementedError: If ``filters.dtype`` is neither float32 nor half.
    """
    weight_dtype = filters.dtype
    if weight_dtype == torch.float32:
        backward_kernel = sparse_conv_ext.indice_conv_backward_fp32
    elif weight_dtype == torch.half:
        backward_kernel = sparse_conv_ext.indice_conv_backward_half
    else:
        raise NotImplementedError
    return backward_kernel(
        features, filters, out_bp, indice_pairs, indice_pair_num, int(inverse), int(subm)
    )
def indice_maxpool(features, indice_pairs, indice_pair_num, num_activate_out):
    """Dispatch to the extension's ``indice_maxpool`` kernel by feature dtype.

    ``features.dtype`` selects the kernel: ``torch.float32`` uses
    ``indice_maxpool_fp32`` and ``torch.half`` uses ``indice_maxpool_half``.

    Returns:
        Whatever the selected ``sparse_conv_ext`` kernel returns.

    Raises:
        NotImplementedError: If ``features.dtype`` is neither float32 nor half.
    """
    feature_dtype = features.dtype
    if feature_dtype == torch.float32:
        pool_kernel = sparse_conv_ext.indice_maxpool_fp32
    elif feature_dtype == torch.half:
        pool_kernel = sparse_conv_ext.indice_maxpool_half
    else:
        raise NotImplementedError
    return pool_kernel(features, indice_pairs, indice_pair_num, num_activate_out)
def indice_maxpool_backward(features, out_features, out_bp, indice_pairs, indice_pair_num):
    """Dispatch to the extension's ``indice_maxpool_backward`` kernel by feature dtype.

    ``features.dtype`` selects the kernel: ``torch.float32`` uses
    ``indice_maxpool_backward_fp32`` and ``torch.half`` uses
    ``indice_maxpool_backward_half``.

    Returns:
        Whatever the selected ``sparse_conv_ext`` kernel returns.

    Raises:
        NotImplementedError: If ``features.dtype`` is neither float32 nor half.
    """
    feature_dtype = features.dtype
    if feature_dtype == torch.float32:
        backward_kernel = sparse_conv_ext.indice_maxpool_backward_fp32
    elif feature_dtype == torch.half:
        backward_kernel = sparse_conv_ext.indice_maxpool_backward_half
    else:
        raise NotImplementedError
    return backward_kernel(features, out_features, out_bp, indice_pairs, indice_pair_num)