You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2021/05/03 11:57:34 UTC
[incubator-mxnet] branch master updated: [BUGFIX] fix numpy op fallback bug when ndarray in kwargs (#20233)
This is an automated email from the ASF dual-hosted git repository.
lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 00de7dd [BUGFIX] fix numpy op fallback bug when ndarray in kwargs (#20233)
00de7dd is described below
commit 00de7dd6c62dfe48cf930cc2a1c1bc21817d6d6a
Author: JackieWu <wk...@live.cn>
AuthorDate: Mon May 3 19:55:43 2021 +0800
[BUGFIX] fix numpy op fallback bug when ndarray in kwargs (#20233)
Fixes #20232
---
python/mxnet/numpy/multiarray.py | 45 ++++++++++++++++++++++++++-------------------
tests/python/unittest/test_numpy_op.py | 14 ++++++++++++++
2 files changed, 40 insertions(+), 19 deletions(-)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index dd6504e..b64c170 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -198,24 +198,32 @@ def _as_mx_np_array(object, ctx=None, zero_copy=False):
         raise TypeError('Does not support converting {} to mx.np.ndarray.'.format(str(type(object))))


-def _as_onp_array(object):
-    """Convert object to mxnet.numpy.ndarray."""
-    cur_ctx = None
+def _as_onp_array(object, cur_ctx=None):
+    """Recursively convert mx.np.ndarrays in object to numpy.ndarrays; return (result, ctx)."""
+    def _update_ctx(cur_ctx, tmp_ctx):
+        if cur_ctx is None:
+            cur_ctx = tmp_ctx
+        elif tmp_ctx is not None and cur_ctx != tmp_ctx:
+            raise ValueError('Ambiguous to set the context for the output ndarray since'
+                             ' input ndarrays are allocated on different devices: {} and {}'
+                             .format(str(cur_ctx), str(tmp_ctx)))
+        return cur_ctx
+
     if isinstance(object, ndarray):
         return object.asnumpy(), object.ctx
     elif isinstance(object, (list, tuple)):
         tmp = []
         for arr in object:
-            arr, tmp_ctx = _as_onp_array(arr)
-            # if isinstance(arr, (list, tuple)):
-            # raise TypeError('type {} not supported'.format(str(type(arr))))
+            arr, tmp_ctx = _as_onp_array(arr, cur_ctx)
             tmp.append(arr)
-            if cur_ctx is None:
-                cur_ctx = tmp_ctx
-            elif tmp_ctx is not None and cur_ctx != tmp_ctx:
-                raise ValueError('Ambiguous to set the context for the output ndarray since' # pylint: disable=too-few-format-args
-                                 ' input ndarrays are allocated on different devices: {} and {}'
-                                 .format(str(cur_ctx, tmp_ctx)))
+            cur_ctx = _update_ctx(cur_ctx, tmp_ctx)
+        return object.__class__(tmp), cur_ctx
+    elif isinstance(object, dict):
+        tmp = {}
+        for key, value in object.items():
+            value, tmp_ctx = _as_onp_array(value, cur_ctx)
+            tmp[key] = value
+            cur_ctx = _update_ctx(cur_ctx, tmp_ctx)
         return object.__class__(tmp), cur_ctx
     else:
         return object, cur_ctx
@@ -377,13 +385,12 @@ class ndarray(NDArray): # pylint: disable=invalid-name
-                raise ValueError("Falling back to NumPy operator {} with autograd active is not supported."
-                                 "Please consider moving the operator to the outside of the autograd scope.")\
-                                 .format(func)
-            new_args, cur_ctx = _as_onp_array(args)
+                raise ValueError("Falling back to NumPy operator {} with autograd active is not supported. "
+                                 "Please consider moving the operator to the outside of the autograd scope."
+                                 .format(func))
+            cur_ctx = None
+            new_args, cur_ctx = _as_onp_array(args, cur_ctx)
+            new_kwargs, cur_ctx = _as_onp_array(kwargs, cur_ctx)
             if cur_ctx is None:
                 raise ValueError('Unknown context for the input ndarrays. It is probably a bug. Please'
                                  ' create an issue on GitHub.')
-            new_kwargs = {}
-            for k, v in kwargs.items():
-                new_kwargs[k] = v.asnumpy() if isinstance(v, ndarray) else v
             if func not in _FALLBACK_ARRAY_FUNCTION_WARNED_RECORD:
                 import logging
                 logging.warning("np.%s is a fallback operator, "
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 1e253f2..ba8e327 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -10349,3 +10349,17 @@ def test_broadcast_like_different_types():
     z = mx.npx.broadcast_like(x, y, 1, 1)
     assert_almost_equal(z.asnumpy(), np.array([[0,0],[0,0]]))
     assert x.dtype == z.dtype
+
+
+@use_np
+def test_np_apply_along_axis_fallback():
+    data = np.random.randint(-100, 100, (2, 3))
+    axis = 1
+    func1d = lambda x: x.mean()
+    np_y = _np.apply_along_axis(func1d, axis, data.asnumpy())
+    y1 = np.apply_along_axis(func1d, axis, data)
+    y2 = np.apply_along_axis(func1d, axis, arr=data)
+    assert_almost_equal(y1.asnumpy(), np_y)
+    assert y1.asnumpy().dtype == np_y.dtype
+    assert_almost_equal(y2.asnumpy(), np_y)
+    assert y2.asnumpy().dtype == np_y.dtype