# File viewer metadata: 185 lines, 6.9 KiB, Python.
# Copyright 2019 Yan Yan
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import numpy as np
|
|
import unittest
|
|
|
|
|
|
class TestCase(unittest.TestCase):
    """`unittest.TestCase` extended with NumPy-aware array assertions.

    On failure the assertions print the mismatching elements (and where they
    are) to stdout before raising, which gives more detail than the
    `np.testing` helpers alone.
    """

    def _GetNdArray(self, a):
        """Return `a` unchanged if it is an ndarray, otherwise `np.array(a)`."""
        if not isinstance(a, np.ndarray):
            a = np.array(a)
        return a

    def assertAllEqual(self, a, b):
        """Asserts that two numpy arrays have the same values.

        NaNs located at the same position in both arrays count as equal.

        Args:
          a: the expected numpy ndarray or anything can be converted to one.
          b: the actual numpy ndarray or anything can be converted to one.

        Raises:
          AssertionError: if the shapes differ or any element differs.
        """
        a = self._GetNdArray(a)
        b = self._GetNdArray(b)
        self.assertEqual(
            a.shape, b.shape, "Shape mismatch: expected %s, got %s." % (a.shape, b.shape)
        )
        same = a == b

        # NaN != NaN under `==`, but np.testing.assert_array_equal (which
        # decides pass/fail below) treats matching NaNs as equal. Apply the
        # same rule for every floating dtype (float16, float32, float64, ...)
        # so matching NaNs never trigger spurious diagnostics.
        if np.issubdtype(a.dtype, np.floating):
            same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
        if not np.all(same):
            # Prints more details than np.testing.assert_array_equal.
            diff = np.logical_not(same)
            if a.ndim:
                where = np.where(diff)
                x = a[where]
                y = b[where]
                print("not equal where = ", where)
            else:
                # np.where is broken for scalars
                x, y = a, b
            print("not equal lhs = ", x)
            print("not equal rhs = ", y)
            np.testing.assert_array_equal(a, b)

    def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
        """Asserts that two numpy arrays, or dicts of same, have near values.

        This does not support nested dicts.

        Args:
          a: The expected numpy ndarray (or anything can be converted to one), or
            dict of same. Must be a dict iff `b` is a dict.
          b: The actual numpy ndarray (or anything can be converted to one), or
            dict of same. Must be a dict iff `a` is a dict.
          rtol: relative tolerance.
          atol: absolute tolerance.

        Raises:
          ValueError: if only one of `a` and `b` is a dict.
          AssertionError: if keys, shapes or values do not match.
        """
        is_a_dict = isinstance(a, dict)
        if is_a_dict != isinstance(b, dict):
            raise ValueError("Can't compare dict to non-dict, %s vs %s." % (a, b))
        if is_a_dict:
            self.assertCountEqual(
                a.keys(),
                b.keys(),
                msg="mismatched keys, expected %s, got %s" % (a.keys(), b.keys()),
            )
            for k in a:
                # Report only the entry being compared, not the whole dicts,
                # so the failure message points at the offending values.
                self._assertArrayLikeAllClose(
                    a[k],
                    b[k],
                    rtol=rtol,
                    atol=atol,
                    msg="%s: expected %s, got %s." % (k, a[k], b[k]),
                )
        else:
            self._assertArrayLikeAllClose(a, b, rtol=rtol, atol=atol)

    def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
        """Asserts that two array-likes have equal shape and near values.

        Args:
          a: the expected ndarray or anything convertible to one.
          b: the actual ndarray or anything convertible to one.
          rtol: relative tolerance.
          atol: absolute tolerance.
          msg: optional text forwarded to np.testing.assert_allclose as
            `err_msg`.

        Raises:
          AssertionError: if the shapes differ or the values are not close.
        """
        a = self._GetNdArray(a)
        b = self._GetNdArray(b)
        self.assertEqual(
            a.shape, b.shape, "Shape mismatch: expected %s, got %s." % (a.shape, b.shape)
        )
        # equal_nan=True mirrors the default of np.testing.assert_allclose,
        # which makes the final decision below; without it, arrays that only
        # "differ" by matching NaNs would print misleading diagnostics.
        if not np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=True):
            # Prints more details than np.testing.assert_allclose.
            #
            # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
            # checks whether two arrays are element-wise equal within a
            # tolerance. The relative difference (rtol * abs(b)) and the
            # absolute difference atol are added together to compare against
            # the absolute difference between a and b.  Here, we want to
            # print out which elements violate such conditions.
            cond = np.logical_or(
                np.abs(a - b) > atol + rtol * np.abs(b), np.isnan(a) != np.isnan(b)
            )
            if a.ndim:
                where = np.where(cond)
                x = a[where]
                y = b[where]
                print("not close where = ", where)
            else:
                # np.where is broken for scalars
                x, y = a, b
            print("not close lhs = ", x)
            print("not close rhs = ", y)
            print("not close dif = ", np.abs(x - y))
            print("not close tol = ", atol + rtol * np.abs(y))
            print("dtype = %s, shape = %s" % (a.dtype, a.shape))
            np.testing.assert_allclose(a, b, rtol=rtol, atol=atol, err_msg=msg)
|
|
|
|
|
|
def params_grid(*params):
    """Return the Cartesian product of the given parameter collections.

    Combinations are ordered with the last argument varying fastest, e.g.
    `params_grid([1, 2], "ab")` gives
    `[[1, "a"], [1, "b"], [2, "a"], [2, "b"]]`.

    Args:
      *params: any number of iterables, one per parameter, each holding the
        candidate values for that parameter.

    Returns:
      A list of lists; each inner list holds one value per argument. The
      result is `[]` if any argument is empty and `[[]]` when called with no
      arguments.
    """
    # Imported locally so the function does not rely on the file's import
    # header being updated.
    from itertools import product

    return [list(combo) for combo in product(*params)]
|
|
|
|
|
|
def generate_sparse_data(
    shape,
    num_points,
    num_channels,
    integer=False,
    data_range=(-1, 1),
    with_dense=True,
    dtype=np.float32,
):
    """Generate random sparse features on a grid, optionally with a dense copy.

    For every batch element, `num_points[i]` distinct grid cells are drawn at
    random (using the global `np.random` state) and each drawn cell receives a
    random feature vector of length `num_channels`.

    Args:
      shape: sequence of ints, the spatial size of the grid.
      num_points: sequence of ints, number of active cells per batch element;
        its length is the batch size.
      num_channels: number of feature channels per active cell.
      integer: if True, draw features with `np.random.randint` (upper bound of
        `data_range` excluded); otherwise with `np.random.uniform`.
      data_range: `(low, high)` bounds passed to the random sampler.
      with_dense: if True, also return the dense tensor equivalent.
      dtype: dtype of the returned feature arrays.

    Returns:
      A dict with:
        "features": array of shape `[sum(num_points), num_channels]`.
        "features_dense": (only if `with_dense`) array of shape
          `[batch_size, num_channels, *shape]`, zero except at active cells.
        "indices": int32 array of shape `[sum(num_points), len(shape) + 1]`;
          each row is the cell coordinate followed by the batch index in the
          last column.

    Raises:
      ValueError: if any entry of `num_points` is negative or exceeds the
        number of cells in the grid (distinct cells could not be drawn and
        "features" and "indices" would disagree in length).
    """
    dense_shape = tuple(shape)
    ndim = len(dense_shape)
    num_points = np.array(num_points)
    batch_size = len(num_points)

    # One coordinate row per grid cell. The order produced by meshgrid does
    # not matter because the rows are shuffled before every draw.
    coors_total = np.stack(np.meshgrid(*[np.arange(0, s) for s in dense_shape]), axis=-1)
    coors_total = coors_total.reshape(-1, ndim)

    num_cells = len(coors_total)
    if np.any(num_points < 0) or np.any(num_points > num_cells):
        raise ValueError(
            "num_points entries must be between 0 and %d (number of cells in shape %s), got %s."
            % (num_cells, dense_shape, num_points.tolist())
        )

    batch_indices = []
    for i in range(batch_size):
        # Shuffling then taking a prefix samples distinct cells.
        np.random.shuffle(coors_total)
        inds_total = coors_total[: num_points[i]]
        # Append the batch index as an extra trailing column.
        inds_total = np.pad(inds_total, ((0, 0), (0, 1)), mode="constant", constant_values=i)
        batch_indices.append(inds_total)

    sampler = np.random.randint if integer else np.random.uniform
    sparse_data = sampler(
        data_range[0], data_range[1], size=[num_points.sum(), num_channels]
    ).astype(dtype)

    res = {
        "features": sparse_data,
    }
    indices = np.concatenate(batch_indices, axis=0)
    if with_dense:
        dense_data = np.zeros([batch_size, num_channels, *dense_shape], dtype=sparse_data.dtype)
        # Scatter all points at once. The index arrays (batch column and one
        # array per spatial axis) are separated by the channel slice, so NumPy
        # places the point dimension first: the indexed view has shape
        # [num_total_points, num_channels], matching `sparse_data` row by row.
        scatter_index = (indices[:, -1], slice(None), *indices[:, :-1].T)
        dense_data[scatter_index] = sparse_data
        res["features_dense"] = dense_data
    res["indices"] = indices.astype(np.int32)
    return res
|