You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/03/08 06:40:10 UTC

[incubator-mxnet] branch master updated: Fix ndarray assignment issue with basic indexing (#10022)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 39c0fd8  Fix ndarray assignment issue with basic indexing (#10022)
39c0fd8 is described below

commit 39c0fd82312e138ef6b7f6531adb1f2fe423cb07
Author: reminisce <wu...@gmail.com>
AuthorDate: Wed Mar 7 22:40:04 2018 -0800

    Fix ndarray assignment issue with basic indexing (#10022)
    
    * Fix ndarray assignment issue with basic indexing
    
    * Uncomment useful code
---
 python/mxnet/ndarray/ndarray.py       | 2 ++
 tests/python/unittest/test_ndarray.py | 5 +++++
 2 files changed, 7 insertions(+)

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 5ac2796..5367845 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -695,6 +695,8 @@ fixed-size items.
                 # may need to broadcast first
                 if isinstance(value, NDArray):
                     if value.handle is not self.handle:
+                        if value.shape != shape:
+                            value = value.broadcast_to(shape)
                         value.copyto(self)
                 elif isinstance(value, numeric_types):
                     _internal._full(shape=shape, ctx=self.context,
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index e96fb2f..16f08b0 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -992,6 +992,8 @@ def test_ndarray_indexing():
         def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
             if np_value is not None:
                 np_array[np_index] = np_value
+            elif isinstance(mx_value, mx.nd.NDArray):
+                np_array[np_index] = mx_value.asnumpy()
             else:
                 np_array[np_index] = mx_value
             mx_array[mx_index] = mx_value
@@ -1024,6 +1026,9 @@ def test_ndarray_indexing():
             # test value is an numeric_type
             assert_same(np_array, np_index, mx_array, index, np.random.randint(low=-10000, high=0))
             if len(indexed_array_shape) > 1:
+                # test NDArray with broadcast
+                assert_same(np_array, np_index, mx_array, index,
+                            mx.nd.random.uniform(low=-10000, high=0, shape=(indexed_array_shape[-1],)))
                 # test numpy array with broadcast
                 assert_same(np_array, np_index, mx_array, index,
                             np.random.randint(low=-10000, high=0, size=(indexed_array_shape[-1],)))

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.