You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pt...@apache.org on 2020/07/20 23:34:46 UTC
[incubator-mxnet] branch master updated: Improve test seeding in
test_numpy_interoperability.py (#18762)
This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 6bb3d72 Improve test seeding in test_numpy_interoperability.py (#18762)
6bb3d72 is described below
commit 6bb3d724189e1d5727fc7b2b41cab3863294ff99
Author: Dick Carter <dc...@nvidia.com>
AuthorDate: Mon Jul 20 16:33:38 2020 -0700
Improve test seeding in test_numpy_interoperability.py (#18762)
---
.../python/unittest/test_numpy_interoperability.py | 29 ++++++++++++++--------
1 file changed, 19 insertions(+), 10 deletions(-)
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 3b97864..d6b5595 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -29,7 +29,7 @@ from mxnet import np
from mxnet.test_utils import assert_almost_equal
from mxnet.test_utils import use_np
from mxnet.test_utils import is_op_runnable
-from common import assertRaises, with_seed, random_seed
+from common import assertRaises, with_seed, random_seed, setup_module, teardown_module
from mxnet.numpy_dispatch_protocol import with_array_function_protocol, with_array_ufunc_protocol
from mxnet.numpy_dispatch_protocol import _NUMPY_ARRAY_FUNCTION_LIST, _NUMPY_ARRAY_UFUNC_LIST
@@ -62,8 +62,15 @@ class OpArgMngr(object):
@staticmethod
def get_workloads(name):
+ if OpArgMngr._args == {}:
+ _prepare_workloads()
return OpArgMngr._args.get(name, None)
+ @staticmethod
+ def randomize_workloads():
+ # Force a new _prepare_workloads(), which will be based on new random numbers
+ OpArgMngr._args = {}
+
def _add_workload_all():
# check bad element in all positions
@@ -516,8 +523,8 @@ def _add_workload_linalg_cholesky():
shapes = [(1, 1), (2, 2), (3, 3), (50, 50), (3, 10, 10)]
dtypes = (np.float32, np.float64)
- for shape, dtype in itertools.product(shapes, dtypes):
- with random_seed(1):
+ with random_seed(1):
+ for shape, dtype in itertools.product(shapes, dtypes):
a = _np.random.randn(*shape)
t = list(range(len(shape)))
@@ -3183,9 +3190,6 @@ def _prepare_workloads():
_add_workload_vander()
-_prepare_workloads()
-
-
def _get_numpy_op_output(onp_op, *args, **kwargs):
onp_args = [arg.asnumpy() if isinstance(arg, np.ndarray) else arg for arg in args]
onp_kwargs = {k: v.asnumpy() if isinstance(v, np.ndarray) else v for k, v in kwargs.items()}
@@ -3197,7 +3201,7 @@ def _get_numpy_op_output(onp_op, *args, **kwargs):
return onp_op(*onp_args, **onp_kwargs)
-def _check_interoperability_helper(op_name, *args, **kwargs):
+def _check_interoperability_helper(op_name, rel_tol, abs_tol, *args, **kwargs):
strs = op_name.split('.')
if len(strs) == 1:
onp_op = getattr(_np, op_name)
@@ -3213,11 +3217,11 @@ def _check_interoperability_helper(op_name, *args, **kwargs):
assert type(out) == type(expected_out)
for arr, expected_arr in zip(out, expected_out):
if isinstance(arr, np.ndarray):
- assert_almost_equal(arr.asnumpy(), expected_arr, rtol=1e-3, atol=1e-4, use_broadcast=False, equal_nan=True)
+ assert_almost_equal(arr.asnumpy(), expected_arr, rtol=rel_tol, atol=abs_tol, use_broadcast=False, equal_nan=True)
else:
_np.testing.assert_equal(arr, expected_arr)
elif isinstance(out, np.ndarray):
- assert_almost_equal(out.asnumpy(), expected_out, rtol=1e-3, atol=1e-4, use_broadcast=False, equal_nan=True)
+ assert_almost_equal(out.asnumpy(), expected_out, rtol=rel_tol, atol=abs_tol, use_broadcast=False, equal_nan=True)
elif isinstance(out, _np.dtype):
_np.testing.assert_equal(out, expected_out)
else:
@@ -3229,6 +3233,7 @@ def _check_interoperability_helper(op_name, *args, **kwargs):
def check_interoperability(op_list):
+ OpArgMngr.randomize_workloads()
for name in op_list:
if name in _TVM_OPS and not is_op_runnable():
continue
@@ -3240,13 +3245,17 @@ def check_interoperability(op_list):
if name in ['full_like', 'zeros_like', 'ones_like'] and \
StrictVersion(platform.python_version()) < StrictVersion('3.0.0'):
continue
+ default_tols = (1e-3, 1e-4)
+ tols = {'linalg.tensorinv': (1e-2, 5e-3),
+ 'linalg.solve': (1e-3, 5e-2)}
+ (rel_tol, abs_tol) = tols.get(name, default_tols)
print('Dispatch test:', name)
workloads = OpArgMngr.get_workloads(name)
assert workloads is not None, 'Workloads for operator `{}` has not been ' \
'added for checking interoperability with ' \
'the official NumPy.'.format(name)
for workload in workloads:
- _check_interoperability_helper(name, *workload['args'], **workload['kwargs'])
+ _check_interoperability_helper(name, rel_tol, abs_tol, *workload['args'], **workload['kwargs'])
@with_seed()