You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2020/09/06 12:31:47 UTC

[incubator-mxnet] branch master updated: Assure NDArray.reshape does not change the array size (#19078)

This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 62b7f03  Assure NDArray.reshape does not change the array size (#19078)
62b7f03 is described below

commit 62b7f030b7d32aa021afb092036dec6175b090ae
Author: r3stl355 <ul...@hotmail.com>
AuthorDate: Sun Sep 6 13:30:57 2020 +0100

    Assure NDArray.reshape does not change the array size (#19078)
    
    Co-authored-by: r3stl355 <ul...@amazon.com>
---
 python/mxnet/ndarray/ndarray.py       | 7 ++++++-
 tests/python/unittest/test_ndarray.py | 3 ++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 0f638a1..a6eae05 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1546,7 +1546,12 @@ fixed-size items.
                                            c_array(ctypes.c_int64, shape),
                                            reverse,
                                            ctypes.byref(handle)))
-        return self.__class__(handle=handle, writable=self.writable)
+        res = self.__class__(handle=handle, writable=self.writable)
+
+        # Array size should not change
+        if np.prod(res.shape) != np.prod(self.shape):
+            raise ValueError('Cannot reshape array of size {} into shape {}'.format(np.prod(self.shape), shape))
+        return res
 
     def reshape_like(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`reshape_like`.
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index a01746e..9e80d48 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -237,7 +237,8 @@ def test_ndarray_reshape():
     assert same(tensor.reshape(-1, 15).reshape(0, -4, 3, -1).asnumpy(), true_res.reshape(2, 3, 5).asnumpy())
     assert same(tensor.reshape(-1, 0).asnumpy(), true_res.reshape(10, 3).asnumpy())
     assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(), true_res.reshape(6, 5).asnumpy())
-
+    # https://github.com/apache/incubator-mxnet/issues/18886
+    assertRaises(ValueError, tensor.reshape, (2, 3))
 
 @with_seed()
 def test_ndarray_flatten():