You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pt...@apache.org on 2020/07/20 23:34:46 UTC

[incubator-mxnet] branch master updated: Improve test seeding in test_numpy_interoperability.py (#18762)

This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6bb3d72  Improve test seeding in test_numpy_interoperability.py (#18762)
6bb3d72 is described below

commit 6bb3d724189e1d5727fc7b2b41cab3863294ff99
Author: Dick Carter <dc...@nvidia.com>
AuthorDate: Mon Jul 20 16:33:38 2020 -0700

    Improve test seeding in test_numpy_interoperability.py (#18762)
---
 .../python/unittest/test_numpy_interoperability.py | 29 ++++++++++++++--------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 3b97864..d6b5595 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -29,7 +29,7 @@ from mxnet import np
 from mxnet.test_utils import assert_almost_equal
 from mxnet.test_utils import use_np
 from mxnet.test_utils import is_op_runnable
-from common import assertRaises, with_seed, random_seed
+from common import assertRaises, with_seed, random_seed, setup_module, teardown_module
 from mxnet.numpy_dispatch_protocol import with_array_function_protocol, with_array_ufunc_protocol
 from mxnet.numpy_dispatch_protocol import _NUMPY_ARRAY_FUNCTION_LIST, _NUMPY_ARRAY_UFUNC_LIST
 
@@ -62,8 +62,15 @@ class OpArgMngr(object):
 
     @staticmethod
     def get_workloads(name):
+        if OpArgMngr._args == {}:
+            _prepare_workloads()
         return OpArgMngr._args.get(name, None)
 
+    @staticmethod
+    def randomize_workloads():
+        # Force a new _prepare_workloads(), which will be based on new random numbers
+        OpArgMngr._args = {}
+
 
 def _add_workload_all():
     # check bad element in all positions
@@ -516,8 +523,8 @@ def _add_workload_linalg_cholesky():
     shapes = [(1, 1), (2, 2), (3, 3), (50, 50), (3, 10, 10)]
     dtypes = (np.float32, np.float64)
 
-    for shape, dtype in itertools.product(shapes, dtypes):
-        with random_seed(1):
+    with random_seed(1):
+        for shape, dtype in itertools.product(shapes, dtypes):
             a = _np.random.randn(*shape)
 
         t = list(range(len(shape)))
@@ -3183,9 +3190,6 @@ def _prepare_workloads():
     _add_workload_vander()
 
 
-_prepare_workloads()
-
-
 def _get_numpy_op_output(onp_op, *args, **kwargs):
     onp_args = [arg.asnumpy() if isinstance(arg, np.ndarray) else arg for arg in args]
     onp_kwargs = {k: v.asnumpy() if isinstance(v, np.ndarray) else v for k, v in kwargs.items()}
@@ -3197,7 +3201,7 @@ def _get_numpy_op_output(onp_op, *args, **kwargs):
     return onp_op(*onp_args, **onp_kwargs)
 
 
-def _check_interoperability_helper(op_name, *args, **kwargs):
+def _check_interoperability_helper(op_name, rel_tol, abs_tol, *args, **kwargs):
     strs = op_name.split('.')
     if len(strs) == 1:
         onp_op = getattr(_np, op_name)
@@ -3213,11 +3217,11 @@ def _check_interoperability_helper(op_name, *args, **kwargs):
         assert type(out) == type(expected_out)
         for arr, expected_arr in zip(out, expected_out):
             if isinstance(arr, np.ndarray):
-                assert_almost_equal(arr.asnumpy(), expected_arr, rtol=1e-3, atol=1e-4, use_broadcast=False, equal_nan=True)
+                assert_almost_equal(arr.asnumpy(), expected_arr, rtol=rel_tol, atol=abs_tol, use_broadcast=False, equal_nan=True)
             else:
                 _np.testing.assert_equal(arr, expected_arr)
     elif isinstance(out, np.ndarray):
-        assert_almost_equal(out.asnumpy(), expected_out, rtol=1e-3, atol=1e-4, use_broadcast=False, equal_nan=True)
+        assert_almost_equal(out.asnumpy(), expected_out, rtol=rel_tol, atol=abs_tol, use_broadcast=False, equal_nan=True)
     elif isinstance(out, _np.dtype):
         _np.testing.assert_equal(out, expected_out)
     else:
@@ -3229,6 +3233,7 @@ def _check_interoperability_helper(op_name, *args, **kwargs):
 
 
 def check_interoperability(op_list):
+    OpArgMngr.randomize_workloads()
     for name in op_list:
         if name in _TVM_OPS and not is_op_runnable():
             continue
@@ -3240,13 +3245,17 @@ def check_interoperability(op_list):
         if name in ['full_like', 'zeros_like', 'ones_like'] and \
                 StrictVersion(platform.python_version()) < StrictVersion('3.0.0'):
             continue
+        default_tols = (1e-3, 1e-4)
+        tols = {'linalg.tensorinv': (1e-2, 5e-3),
+                'linalg.solve':     (1e-3, 5e-2)}
+        (rel_tol, abs_tol) = tols.get(name, default_tols)
         print('Dispatch test:', name)
         workloads = OpArgMngr.get_workloads(name)
         assert workloads is not None, 'Workloads for operator `{}` has not been ' \
                                       'added for checking interoperability with ' \
                                       'the official NumPy.'.format(name)
         for workload in workloads:
-            _check_interoperability_helper(name, *workload['args'], **workload['kwargs'])
+            _check_interoperability_helper(name, rel_tol, abs_tol, *workload['args'], **workload['kwargs'])
 
 
 @with_seed()