You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/03/08 06:40:10 UTC
[incubator-mxnet] branch master updated: Fix ndarray assignment issue with basic indexing (#10022)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 39c0fd8 Fix ndarray assignment issue with basic indexing (#10022)
39c0fd8 is described below
commit 39c0fd82312e138ef6b7f6531adb1f2fe423cb07
Author: reminisce <wu...@gmail.com>
AuthorDate: Wed Mar 7 22:40:04 2018 -0800
Fix ndarray assignment issue with basic indexing (#10022)
* Fix ndarray assignment issue with basic indexing
* Uncomment useful code
---
python/mxnet/ndarray/ndarray.py | 2 ++
tests/python/unittest/test_ndarray.py | 5 +++++
2 files changed, 7 insertions(+)
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 5ac2796..5367845 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -695,6 +695,8 @@ fixed-size items.
# may need to broadcast first
if isinstance(value, NDArray):
if value.handle is not self.handle:
+ if value.shape != shape:
+ value = value.broadcast_to(shape)
value.copyto(self)
elif isinstance(value, numeric_types):
_internal._full(shape=shape, ctx=self.context,
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index e96fb2f..16f08b0 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -992,6 +992,8 @@ def test_ndarray_indexing():
def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
if np_value is not None:
np_array[np_index] = np_value
+ elif isinstance(mx_value, mx.nd.NDArray):
+ np_array[np_index] = mx_value.asnumpy()
else:
np_array[np_index] = mx_value
mx_array[mx_index] = mx_value
@@ -1024,6 +1026,9 @@ def test_ndarray_indexing():
# test value is an numeric_type
assert_same(np_array, np_index, mx_array, index, np.random.randint(low=-10000, high=0))
if len(indexed_array_shape) > 1:
+ # test NDArray with broadcast
+ assert_same(np_array, np_index, mx_array, index,
+ mx.nd.random.uniform(low=-10000, high=0, shape=(indexed_array_shape[-1],)))
# test numpy array with broadcast
assert_same(np_array, np_index, mx_array, index,
np.random.randint(low=-10000, high=0, size=(indexed_array_shape[-1],)))
--
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.