# Copyright 2019 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from . import sparse_conv_ext


def get_conv_output_size(input_size, kernel_size, stride, padding, dilation):
    """Compute the dense output spatial shape of a sparse convolution."""
    ndim = len(input_size)
    output_size = []
    for i in range(ndim):
        size = (input_size[i] + 2 * padding[i] - dilation[i] *
                (kernel_size[i] - 1) - 1) // stride[i] + 1
        if kernel_size[i] == -1:
            # A kernel_size of -1 collapses that dimension to a single output.
            output_size.append(1)
        else:
            output_size.append(size)
    return output_size
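
# Illustrative check of the formula above (the shapes are made up for this
# example): a 3D input of spatial shape [41, 1600, 1408] convolved with
# ksize=[3, 3, 3], stride=[2, 2, 2], padding=[1, 1, 1], dilation=[1, 1, 1]
# gives (41 + 2 - 2 - 1) // 2 + 1 = 21 along the first axis and [21, 800, 704]
# overall:
#
#     get_conv_output_size([41, 1600, 1408], [3, 3, 3], [2, 2, 2], [1, 1, 1],
#                          [1, 1, 1])  # -> [21, 800, 704]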


def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation, output_padding):
    """Compute the dense output spatial shape of a sparse transposed (deconv) layer."""
    ndim = len(input_size)
    output_size = []
    for i in range(ndim):
        if kernel_size[i] == -1:
            raise ValueError("deconv does not support kernel_size < 0")
        size = (input_size[i] - 1) * stride[i] - 2 * padding[i] + kernel_size[i] + output_padding[i]
        output_size.append(size)
    return output_size
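
# Illustrative counterpart for the transposed case (again with made-up
# shapes): an input of spatial shape [21] with ksize=[3], stride=[2],
# padding=[1] and output_padding=[0] maps back to
# (21 - 1) * 2 - 2 + 3 + 0 = 41; output_padding=[1] would give 42 instead:
#
#     get_deconv_output_size([21], [3], [2], [1], [1], [0])  # -> [41]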


def get_indice_pairs(
    indices,
    batch_size,
    spatial_shape,
    ksize=3,
    stride=1,
    padding=0,
    dilation=1,
    out_padding=0,
    subm=False,
    transpose=False,
    grid=None,
):
    """Build the index pairs between input and output active sites for a sparse conv."""
    ndim = indices.shape[1] - 1
    # Broadcast scalar conv parameters to one value per spatial dimension.
    if not isinstance(ksize, (list, tuple)):
        ksize = [ksize] * ndim
    if not isinstance(stride, (list, tuple)):
        stride = [stride] * ndim
    if not isinstance(padding, (list, tuple)):
        padding = [padding] * ndim
    if not isinstance(dilation, (list, tuple)):
        dilation = [dilation] * ndim
    if not isinstance(out_padding, (list, tuple)):
        out_padding = [out_padding] * ndim

    for d, s in zip(dilation, stride):
        assert any([s == 1, d == 1]), "stride and dilation cannot both be > 1 in the same axis"

    if not subm:
        if transpose:
            out_shape = get_deconv_output_size(
                spatial_shape, ksize, stride, padding, dilation, out_padding
            )
        else:
            out_shape = get_conv_output_size(spatial_shape, ksize, stride, padding, dilation)
    else:
        # Submanifold convolution keeps the input spatial shape.
        out_shape = spatial_shape

    if grid is None:
        if ndim == 2:
            get_indice_pairs_func = sparse_conv_ext.get_indice_pairs_2d
        elif ndim == 3:
            get_indice_pairs_func = sparse_conv_ext.get_indice_pairs_3d
        elif ndim == 4:
            get_indice_pairs_func = sparse_conv_ext.get_indice_pairs_4d
        else:
            raise NotImplementedError
        return get_indice_pairs_func(
            indices,
            batch_size,
            out_shape,
            spatial_shape,
            ksize,
            stride,
            padding,
            dilation,
            out_padding,
            int(subm),
            int(transpose),
        )
    else:
        # Variant that uses a preallocated grid tensor supplied by the caller.
        if ndim == 2:
            get_indice_pairs_func = sparse_conv_ext.get_indice_pairs_grid_2d
        elif ndim == 3:
            get_indice_pairs_func = sparse_conv_ext.get_indice_pairs_grid_3d
        else:
            raise NotImplementedError
        return get_indice_pairs_func(
            indices,
            grid,
            batch_size,
            out_shape,
            spatial_shape,
            ksize,
            stride,
            padding,
            dilation,
            out_padding,
            int(subm),
            int(transpose),
        )
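
# Illustrative sketch of a typical 3D call (made-up sizes, placeholder names):
# indices is an (N, 4) integer tensor holding a batch index plus three spatial
# coordinates, so ndim resolves to 3 and the call is routed to
# sparse_conv_ext.get_indice_pairs_3d. The exact return layout is defined by
# the compiled extension; the names below stand for the output indices and the
# pair bookkeeping consumed by indice_conv / indice_maxpool further down:
#
#     outids, indice_pairs, indice_pair_num = get_indice_pairs(
#         indices, batch_size=2, spatial_shape=[41, 1600, 1408],
#         ksize=3, stride=2, padding=1)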


def indice_conv(
    features, filters, indice_pairs, indice_pair_num, num_activate_out, inverse=False, subm=False
):
    """Dispatch the sparse convolution forward kernel by filter dtype."""
    if filters.dtype == torch.float32:
        return sparse_conv_ext.indice_conv_fp32(
            features,
            filters,
            indice_pairs,
            indice_pair_num,
            num_activate_out,
            int(inverse),
            int(subm),
        )
    elif filters.dtype == torch.half:
        return sparse_conv_ext.indice_conv_half(
            features,
            filters,
            indice_pairs,
            indice_pair_num,
            num_activate_out,
            int(inverse),
            int(subm),
        )
    else:
        raise NotImplementedError
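
# Continuing the sketch from get_indice_pairs (placeholder names, float32
# path): filters is the dense weight tensor of the layer, the pair tensors
# come straight from get_indice_pairs, and num_activate_out is the number of
# active output sites, i.e. the number of rows of output indices:
#
#     out_features = indice_conv(features, filters, indice_pairs,
#                                indice_pair_num, outids.shape[0])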


def fused_indice_conv(
    features, filters, bias, indice_pairs, indice_pair_num, num_activate_out, inverse, subm
):
    """Dispatch the fused convolution + bias kernel by feature dtype."""
    if features.dtype == torch.half:
        func = sparse_conv_ext.fused_indice_conv_half
    elif features.dtype == torch.float32:
        func = sparse_conv_ext.fused_indice_conv_fp32
    else:
        raise NotImplementedError

    return func(
        features,
        filters,
        bias,
        indice_pairs,
        indice_pair_num,
        num_activate_out,
        int(inverse),
        int(subm),
    )


def indice_conv_backward(
    features, filters, out_bp, indice_pairs, indice_pair_num, inverse=False, subm=False
):
    """Dispatch the sparse convolution backward kernel by filter dtype."""
    if filters.dtype == torch.float32:
        return sparse_conv_ext.indice_conv_backward_fp32(
            features, filters, out_bp, indice_pairs, indice_pair_num, int(inverse), int(subm)
        )
    elif filters.dtype == torch.half:
        return sparse_conv_ext.indice_conv_backward_half(
            features, filters, out_bp, indice_pairs, indice_pair_num, int(inverse), int(subm)
        )
    else:
        raise NotImplementedError
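
# The forward/backward pair above is typically tied together by a
# torch.autograd.Function, roughly as sketched here (class and variable names
# are placeholders; gradients are only produced for features and filters):
#
#     class _SparseConvFunction(torch.autograd.Function):
#
#         @staticmethod
#         def forward(ctx, features, filters, indice_pairs, indice_pair_num,
#                     num_activate_out):
#             ctx.save_for_backward(indice_pairs, indice_pair_num, features,
#                                   filters)
#             return indice_conv(features, filters, indice_pairs,
#                                indice_pair_num, num_activate_out)
#
#         @staticmethod
#         def backward(ctx, grad_output):
#             indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
#             input_bp, filters_bp = indice_conv_backward(
#                 features, filters, grad_output.contiguous(), indice_pairs,
#                 indice_pair_num)
#             return input_bp, filters_bp, None, None, None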


def indice_maxpool(features, indice_pairs, indice_pair_num, num_activate_out):
    """Dispatch the sparse max pooling forward kernel by feature dtype."""
    if features.dtype == torch.float32:
        return sparse_conv_ext.indice_maxpool_fp32(
            features, indice_pairs, indice_pair_num, num_activate_out
        )
    elif features.dtype == torch.half:
        return sparse_conv_ext.indice_maxpool_half(
            features, indice_pairs, indice_pair_num, num_activate_out
        )
    else:
        raise NotImplementedError
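
# The pooling pair mirrors the convolution pair above: the index pairs come
# from get_indice_pairs called with the pooling kernel size and stride, and
# indice_maxpool_backward below takes both the input and the pooled output
# features so the kernel can recover which input produced each maximum.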


def indice_maxpool_backward(features, out_features, out_bp, indice_pairs, indice_pair_num):
    """Dispatch the sparse max pooling backward kernel by feature dtype."""
    if features.dtype == torch.float32:
        return sparse_conv_ext.indice_maxpool_backward_fp32(
            features, out_features, out_bp, indice_pairs, indice_pair_num
        )
    elif features.dtype == torch.half:
        return sparse_conv_ext.indice_maxpool_backward_half(
            features, out_features, out_bp, indice_pairs, indice_pair_num
        )
    else:
        raise NotImplementedError