You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by la...@apache.org on 2021/05/03 11:57:34 UTC

[incubator-mxnet] branch master updated: [BUGFIX] fix numpy op fallback bug when ndarray in kwargs (#20233)

This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 00de7dd  [BUGFIX] fix numpy op fallback bug when ndarray in kwargs (#20233)
00de7dd is described below

commit 00de7dd6c62dfe48cf930cc2a1c1bc21817d6d6a
Author: JackieWu <wk...@live.cn>
AuthorDate: Mon May 3 19:55:43 2021 +0800

    [BUGFIX] fix numpy op fallback bug when ndarray in kwargs (#20233)
    
    Fixes #20232
---
 python/mxnet/numpy/multiarray.py       | 39 ++++++++++++++++++++--------------
 tests/python/unittest/test_numpy_op.py | 14 ++++++++++++
 2 files changed, 37 insertions(+), 16 deletions(-)

diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index dd6504e..b64c170 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -198,24 +198,32 @@ def _as_mx_np_array(object, ctx=None, zero_copy=False):
         raise TypeError('Does not support converting {} to mx.np.ndarray.'.format(str(type(object))))
 
 
-def _as_onp_array(object):
-    """Convert object to mxnet.numpy.ndarray."""
-    cur_ctx = None
+def _as_onp_array(object, cur_ctx=None):
+    """Convert ndarrays nested in object (list/tuple/dict) to numpy.ndarray; return (converted, ctx)."""
+    def _update_ctx(cur_ctx, tmp_ctx):
+        if cur_ctx is None:
+            cur_ctx = tmp_ctx
+        elif tmp_ctx is not None and cur_ctx != tmp_ctx:
+            raise ValueError('Ambiguous to set the context for the output ndarray since'
+                             ' input ndarrays are allocated on different devices: {} and {}'
+                             .format(str(cur_ctx), str(tmp_ctx)))
+        return cur_ctx
+
     if isinstance(object, ndarray):
         return object.asnumpy(), object.ctx
     elif isinstance(object, (list, tuple)):
         tmp = []
         for arr in object:
-            arr, tmp_ctx = _as_onp_array(arr)
-            # if isinstance(arr, (list, tuple)):
-            #     raise TypeError('type {} not supported'.format(str(type(arr))))
+            arr, tmp_ctx = _as_onp_array(arr, cur_ctx)
             tmp.append(arr)
-            if cur_ctx is None:
-                cur_ctx = tmp_ctx
-            elif tmp_ctx is not None and cur_ctx != tmp_ctx:
-                raise ValueError('Ambiguous to set the context for the output ndarray since'  # pylint: disable=too-few-format-args
-                                 ' input ndarrays are allocated on different devices: {} and {}'
-                                 .format(str(cur_ctx, tmp_ctx)))
+            cur_ctx = _update_ctx(cur_ctx, tmp_ctx)
+        return object.__class__(tmp), cur_ctx
+    elif isinstance(object, dict):
+        tmp = dict()
+        for key, value in object.items():
+            value, tmp_ctx = _as_onp_array(value, cur_ctx)
+            tmp[key] = value
+            cur_ctx = _update_ctx(cur_ctx, tmp_ctx)
         return object.__class__(tmp), cur_ctx
     else:
         return object, cur_ctx
@@ -377,13 +385,12 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
-                raise ValueError("Falling back to NumPy operator {} with autograd active is not supported."
-                                 "Please consider moving the operator to the outside of the autograd scope.")\
-                                 .format(func)
-            new_args, cur_ctx = _as_onp_array(args)
+                raise ValueError("Falling back to NumPy operator {} with autograd active is not supported. "
+                                 "Please consider moving the operator to the outside of the autograd scope."
+                                 .format(func))
+            cur_ctx = None
+            new_args, cur_ctx = _as_onp_array(args, cur_ctx)  # ctx is threaded through args then kwargs
+            new_kwargs, cur_ctx = _as_onp_array(kwargs, cur_ctx)  # so a device mismatch between them raises
             if cur_ctx is None:
                 raise ValueError('Unknown context for the input ndarrays. It is probably a bug. Please'
                                  ' create an issue on GitHub.')
-            new_kwargs = {}
-            for k, v in kwargs.items():
-                new_kwargs[k] = v.asnumpy() if isinstance(v, ndarray) else v
             if func not in _FALLBACK_ARRAY_FUNCTION_WARNED_RECORD:
                 import logging
                 logging.warning("np.%s is a fallback operator, "
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 1e253f2..ba8e327 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -10349,3 +10349,17 @@ def test_broadcast_like_different_types():
     z = mx.npx.broadcast_like(x, y, 1, 1)
     assert_almost_equal(z.asnumpy(), np.array([[0,0],[0,0]]))
     assert x.dtype == z.dtype
+
+
+@use_np
+def test_np_apply_along_axis_fallback():  # regression test for #20232
+    data = np.random.randint(-100, 100, (2, 3))
+    axis = 1
+    func1d = lambda x: x.mean()
+    np_y = _np.apply_along_axis(func1d, axis, data.asnumpy())
+    y1 = np.apply_along_axis(func1d, axis, data)  # ndarray passed positionally
+    y2 = np.apply_along_axis(func1d, axis, arr=data)  # ndarray passed via kwargs
+    assert_almost_equal(y1.asnumpy(), np_y)
+    assert y1.asnumpy().dtype == np_y.dtype
+    assert_almost_equal(y2.asnumpy(), np_y)
+    assert y2.asnumpy().dtype == np_y.dtype