bev-project/mmdet3d/ops/spconv/test_utils.py

185 lines
6.9 KiB
Python
Raw Permalink Normal View History

2022-06-03 12:21:18 +08:00
# Copyright 2019 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import unittest

import numpy as np
class TestCase(unittest.TestCase):
    """``unittest.TestCase`` with numpy-aware equality/closeness assertions.

    On failure, the helpers print the mismatching positions and values before
    delegating the actual assertion to ``np.testing``.
    """

    def _GetNdArray(self, a):
        """Return ``a`` unchanged if it is an ndarray, else ``np.array(a)``."""
        if not isinstance(a, np.ndarray):
            a = np.array(a)
        return a

    def assertAllEqual(self, a, b):
        """Asserts that two numpy arrays have the same values.

        NaNs at the same positions are considered equal when both inputs have
        an inexact (floating-point or complex) dtype.

        Args:
            a: the expected numpy ndarray or anything can be converted to one.
            b: the actual numpy ndarray or anything can be converted to one.

        Raises:
            AssertionError: if the shapes or any values differ.
        """
        a = self._GetNdArray(a)
        b = self._GetNdArray(b)
        self.assertEqual(
            a.shape, b.shape, "Shape mismatch: expected %s, got %s." % (a.shape, b.shape)
        )
        same = a == b
        # np.inexact covers float16 and complex as well as float32/float64, so
        # NaNs in matching positions never produce spurious diagnostics.
        if np.issubdtype(a.dtype, np.inexact) and np.issubdtype(b.dtype, np.inexact):
            same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
        if not np.all(same):
            # Prints more details than np.testing.assert_array_equal.
            diff = np.logical_not(same)
            if a.ndim:
                where = np.where(diff)
                x = a[where]
                y = b[where]
                print("not equal where = ", where)
            else:
                # np.where is broken for scalars
                x, y = a, b
            print("not equal lhs = ", x)
            print("not equal rhs = ", y)
            np.testing.assert_array_equal(a, b)

    def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
        """Asserts that two numpy arrays, or dicts of same, have near values.

        This does not support nested dicts.

        Args:
            a: The expected numpy ndarray (or anything can be converted to one), or
                dict of same. Must be a dict iff `b` is a dict.
            b: The actual numpy ndarray (or anything can be converted to one), or
                dict of same. Must be a dict iff `a` is a dict.
            rtol: relative tolerance.
            atol: absolute tolerance.

        Raises:
            ValueError: if only one of `a` and `b` is a dict.
            AssertionError: if the keys, shapes or values do not match.
        """
        is_a_dict = isinstance(a, dict)
        if is_a_dict != isinstance(b, dict):
            raise ValueError("Can't compare dict to non-dict, %s vs %s." % (a, b))
        if is_a_dict:
            self.assertCountEqual(
                a.keys(),
                b.keys(),
                msg="mismatched keys, expected %s, got %s" % (a.keys(), b.keys()),
            )
            for k in a:
                # Report only this key's values, not both entire dicts.
                self._assertArrayLikeAllClose(
                    a[k],
                    b[k],
                    rtol=rtol,
                    atol=atol,
                    msg="%s: expected %s, got %s." % (k, a[k], b[k]),
                )
        else:
            self._assertArrayLikeAllClose(a, b, rtol=rtol, atol=atol)

    def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
        """Assert ``a`` and ``b`` have equal shapes and element-wise close values.

        Args:
            a: expected array-like.
            b: actual array-like.
            rtol: relative tolerance.
            atol: absolute tolerance.
            msg: optional extra text for the failure message.

        Raises:
            AssertionError: if the shapes differ or any element is not close.
        """
        a = self._GetNdArray(a)
        b = self._GetNdArray(b)
        self.assertEqual(
            a.shape, b.shape, "Shape mismatch: expected %s, got %s." % (a.shape, b.shape)
        )
        if not np.allclose(a, b, rtol=rtol, atol=atol):
            # Prints more details than np.testing.assert_allclose.
            #
            # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
            # checks whether two arrays are element-wise equal within a
            # tolerance. The relative difference (rtol * abs(b)) and the
            # absolute difference atol are added together to compare against
            # the absolute difference between a and b. Here, we want to
            # print out which elements violate such conditions.
            #
            # The diagnostics are computed in a common inexact dtype, because
            # the raw `a - b` raises TypeError for bool arrays and wraps
            # around for unsigned integers.
            diag_dtype = np.result_type(a, b, np.float64)
            a_diag = a.astype(diag_dtype)
            b_diag = b.astype(diag_dtype)
            cond = np.logical_or(
                np.abs(a_diag - b_diag) > atol + rtol * np.abs(b_diag),
                np.isnan(a_diag) != np.isnan(b_diag),
            )
            if a.ndim:
                where = np.where(cond)
                x, y = a[where], b[where]
                x_diag, y_diag = a_diag[where], b_diag[where]
                print("not close where = ", where)
            else:
                # np.where is broken for scalars
                x, y = a, b
                x_diag, y_diag = a_diag, b_diag
            print("not close lhs = ", x)
            print("not close rhs = ", y)
            print("not close dif = ", np.abs(x_diag - y_diag))
            print("not close tol = ", atol + rtol * np.abs(y_diag))
            print("dtype = %s, shape = %s" % (a.dtype, a.shape))
            # assert_allclose str()s err_msg, so None would show up as "None".
            np.testing.assert_allclose(
                a, b, rtol=rtol, atol=atol, err_msg="" if msg is None else msg
            )
def params_grid(*params):
    """Return the Cartesian product of ``params`` as a list of lists.

    The last parameter varies fastest, e.g.
    ``params_grid([1, 2], ["a", "b"])`` returns
    ``[[1, "a"], [1, "b"], [2, "a"], [2, "b"]]``.

    Args:
        *params: iterables of candidate values, one per parameter.

    Returns:
        One ``[v0, v1, ...]`` list per combination. Calling with no
        arguments yields ``[[]]`` (a single empty combination); if any
        iterable is empty the result is ``[]``.
    """
    return [list(combo) for combo in itertools.product(*params)]
def generate_sparse_data(
    shape,
    num_points,
    num_channels,
    integer=False,
    data_range=(-1, 1),
    with_dense=True,
    dtype=np.float32,
):
    """Generate a random batched sparse tensor (and optionally its dense form).

    For each batch element ``i``, ``num_points[i]`` distinct voxel coordinates
    are drawn without replacement from the grid ``shape`` and assigned random
    feature vectors of width ``num_channels``.

    Args:
        shape: spatial grid size, one entry per spatial dimension.
        num_points: number of active voxels per batch element; its length is
            the batch size.
        num_channels: number of feature channels per voxel.
        integer: if True, features come from ``np.random.randint`` over
            ``[data_range[0], data_range[1])``; otherwise from
            ``np.random.uniform`` over the same bounds.
        data_range: ``(low, high)`` bounds for the feature values.
        with_dense: if True, also return the equivalent dense tensor.
        dtype: dtype of the returned features.

    Returns:
        dict with
          - ``"features"``: ``(sum(num_points), num_channels)`` array of ``dtype``.
          - ``"features_dense"`` (only when ``with_dense``): array of shape
            ``(batch_size, num_channels, *shape)``, zero except at the sampled
            coordinates.
          - ``"indices"``: ``(sum(num_points), len(shape) + 1)`` int32 array;
            the first ``len(shape)`` columns are spatial coordinates and the
            LAST column is the batch index.

    Raises:
        ValueError: if any ``num_points`` entry is negative or larger than the
            number of voxels in ``shape``.
    """
    ndim = len(shape)
    num_points = np.array(num_points)
    batch_size = len(num_points)
    num_voxels = int(np.prod(shape))
    # Coordinates are sampled without replacement, so a batch element holds at
    # most one point per voxel; a negative count would slice from the end.
    if np.any(num_points < 0) or np.any(num_points > num_voxels):
        raise ValueError(
            "num_points entries must be in [0, %d] for shape %s, got %s."
            % (num_voxels, tuple(shape), num_points.tolist())
        )
    # Every voxel coordinate of the grid, one row per voxel.
    coors_total = np.stack(np.meshgrid(*[np.arange(0, s) for s in shape]), axis=-1)
    coors_total = coors_total.reshape(-1, ndim)
    batch_indices = []
    for i in range(batch_size):
        np.random.shuffle(coors_total)
        inds = coors_total[: num_points[i]]
        # Append the batch index as an extra last column.
        inds = np.pad(inds, ((0, 0), (0, 1)), mode="constant", constant_values=i)
        batch_indices.append(inds)
    total_points = int(num_points.sum())
    if integer:
        sparse_data = np.random.randint(
            data_range[0], data_range[1], size=[total_points, num_channels]
        ).astype(dtype)
    else:
        sparse_data = np.random.uniform(
            data_range[0], data_range[1], size=[total_points, num_channels]
        ).astype(dtype)
    indices = np.concatenate(batch_indices, axis=0)
    res = {
        "features": sparse_data,
    }
    if with_dense:
        dense_data = np.zeros([batch_size, num_channels, *shape], dtype=sparse_data.dtype)
        # Scatter all points in one advanced-index assignment. Coordinates are
        # unique within each batch element, so no write overlaps another.
        dense_data[(indices[:, -1], slice(None), *indices[:, :-1].T)] = sparse_data
        res["features_dense"] = dense_data
    res["indices"] = indices.astype(np.int32)
    return res