You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2020/09/06 12:31:47 UTC
[incubator-mxnet] branch master updated: Assure NDArray.reshape does not change the array size (#19078)
This is an automated email from the ASF dual-hosted git repository.
marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 62b7f03 Assure NDArray.reshape does not change the array size (#19078)
62b7f03 is described below
commit 62b7f030b7d32aa021afb092036dec6175b090ae
Author: r3stl355 <ul...@hotmail.com>
AuthorDate: Sun Sep 6 13:30:57 2020 +0100
Assure NDArray.reshape does not change the array size (#19078)
Co-authored-by: r3stl355 <ul...@amazon.com>
---
python/mxnet/ndarray/ndarray.py | 7 ++++++-
tests/python/unittest/test_ndarray.py | 3 ++-
2 files changed, 8 insertions(+), 2 deletions(-)
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 0f638a1..a6eae05 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1546,7 +1546,12 @@ fixed-size items.
c_array(ctypes.c_int64, shape),
reverse,
ctypes.byref(handle)))
- return self.__class__(handle=handle, writable=self.writable)
+ res = self.__class__(handle=handle, writable=self.writable)
+
+ # Array size should not change
+ if np.prod(res.shape) != np.prod(self.shape):
+ raise ValueError('Cannot reshape array of size {} into shape {}'.format(np.prod(self.shape), shape))
+ return res
def reshape_like(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`reshape_like`.
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index a01746e..9e80d48 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -237,7 +237,8 @@ def test_ndarray_reshape():
assert same(tensor.reshape(-1, 15).reshape(0, -4, 3, -1).asnumpy(), true_res.reshape(2, 3, 5).asnumpy())
assert same(tensor.reshape(-1, 0).asnumpy(), true_res.reshape(10, 3).asnumpy())
assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(), true_res.reshape(6, 5).asnumpy())
-
+ # https://github.com/apache/incubator-mxnet/issues/18886
+ assertRaises(ValueError, tensor.reshape, (2, 3))
@with_seed()
def test_ndarray_flatten():