64 lines
2.0 KiB
Python
64 lines
2.0 KiB
Python
|
|
import numpy as np
|
||
|
|
import torch
|
||
|
|
|
||
|
|
|
||
|
|
def scatter_nd(indices, updates, shape):
    """PyTorch port of TensorFlow's ``scatter_nd``.

    Creates a zero tensor of ``shape`` and writes ``updates`` into it at the
    positions addressed by ``indices``.

    Args:
        indices: integer tensor of shape ``(..., ndim)``; each vector along the
            last axis addresses the first ``ndim`` dimensions of ``shape``.
        updates: tensor that is reshaped to
            ``indices.shape[:-1] + shape[ndim:]`` before being written.
        shape: output shape, any sequence of ints (list or tuple).

    Returns:
        Tensor of ``shape`` with ``updates.dtype`` on ``updates.device``.

    Note:
        No input validation is performed, so use this carefully. Unlike
        TensorFlow, repeated indices are NOT summed: when an index repeats,
        which update ends up in the output is unspecified.
    """
    # Normalise so slicing below yields a list that can be concatenated.
    shape = list(shape)
    ret = torch.zeros(shape, dtype=updates.dtype, device=updates.device)
    ndim = indices.shape[-1]
    output_shape = list(indices.shape[:-1]) + shape[ndim:]
    # reshape (not view) so non-contiguous index tensors are accepted too.
    flat_indices = indices.reshape(-1, ndim)
    # One index tensor per addressed dimension, then Ellipsis for the rest.
    # A tuple is required; indexing with a plain list is deprecated.
    slices = tuple(flat_indices[:, i] for i in range(ndim)) + (Ellipsis,)
    ret[slices] = updates.reshape(output_shape)
    return ret
|
||
|
|
|
||
|
|
|
||
|
|
class SparseConvTensor:
    """Sparse tensor holding per-site features together with their coordinates."""

    def __init__(self, features, indices, spatial_shape, batch_size, grid=None):
        """
        Args:
            features: per-site feature tensor; ``features.shape[1]`` is used as
                the channel count by ``dense()``.
            indices: coordinate tensor with one row per site, converted to
                ``torch.int32`` if it is not already. ``dense()`` scatters with
                it against ``[batch_size, *spatial_shape, C]``, so it is
                presumably ``(N, 1 + len(spatial_shape))`` with the batch index
                first — confirm against callers.
            spatial_shape: sizes of the spatial dimensions.
            batch_size: number of samples in the batch.
            grid: pre-allocated grid tensor. Should be used when the volume of
                the spatial shape is very large.
        """
        self.features = features
        # Tensor.int() returns a new tensor instead of converting in place, so
        # the result must be assigned back for the conversion to take effect.
        if indices.dtype != torch.int32:
            indices = indices.int()
        self.indices = indices
        self.spatial_shape = spatial_shape
        self.batch_size = batch_size
        # Lookup table read by find_indice_pair(); nothing in this class fills
        # it, so presumably callers populate it — verify.
        self.indice_dict = {}
        self.grid = grid

    @property
    def spatial_size(self):
        """Number of cells in one sample's spatial grid (product of spatial_shape)."""
        return np.prod(self.spatial_shape)

    def find_indice_pair(self, key):
        """Return ``indice_dict[key]`` if present, otherwise ``None``.

        A ``None`` key always yields ``None``.
        """
        if key is None:
            return None
        return self.indice_dict.get(key)

    def dense(self, channels_first=True):
        """Scatter the sparse features into a dense tensor; empty sites are zero.

        Args:
            channels_first: if True return ``[batch_size, C, *spatial_shape]``
                (made contiguous); otherwise ``[batch_size, *spatial_shape, C]``.
        """
        output_shape = (
            [self.batch_size] + list(self.spatial_shape) + [self.features.shape[1]]
        )
        res = scatter_nd(self.indices.long(), self.features, output_shape)
        if not channels_first:
            return res
        ndim = len(self.spatial_shape)
        # Move the trailing channel axis (index ndim + 1) to position 1:
        # (0, ndim + 1, 1, 2, ..., ndim).
        trans_params = [0, ndim + 1, *range(1, ndim + 1)]
        return res.permute(*trans_params).contiguous()

    @property
    def sparity(self):
        """Fraction of grid sites that are occupied (name kept as-is for compatibility).

        Computed as ``indices.shape[0] / prod(spatial_shape) / batch_size``,
        i.e. an occupancy/density ratio rather than its complement.
        """
        return self.indices.shape[0] / np.prod(self.spatial_shape) / self.batch_size
|