You are viewing a plain text version of this content. The canonical version is linked from the original mailing-list archive page; that hyperlink is not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/05/03 17:16:58 UTC

[incubator-mxnet] branch master updated: add reverse option to ndarray inplace reshape (#10767)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 66b2944  add reverse option to ndarray inplace reshape (#10767)
66b2944 is described below

commit 66b294434aeffa9ed3f1cf01416345549139bf23
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Thu May 3 10:16:48 2018 -0700

    add reverse option to ndarray inplace reshape (#10767)
    
    * add reverse option to ndarray inplace reshape
    
    * update check
---
 include/mxnet/c_api.h                 |  1 +
 python/mxnet/ndarray/ndarray.py       | 26 +++++++++++++++++------
 src/c_api/c_api.cc                    |  3 ++-
 tests/python/unittest/test_ndarray.py | 39 ++++++++++++++---------------------
 4 files changed, 39 insertions(+), 30 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 3f04051..9ac90d6 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -663,6 +663,7 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
 MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
                                  int ndim,
                                  dim_t *dims,
+                                 bool reverse,
                                  NDArrayHandle *out);
 /*!
  * \brief get the shape of the array
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 6b2ff23..2411932 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -989,6 +989,19 @@ fixed-size items.
               - input shape = (2,3,4), shape = (-4,1,2,-2), output shape =(1,2,3,4)
               - input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4)
 
+            - If the argument `reverse` is set to 1, then the special values are inferred from right
+              to left.
+
+              Example::
+
+              - without reverse=1, for input shape = (10,5,4), shape = (-1,0), output shape would be
+                (40,5).
+              - with reverse=1, output shape will be (50,4).
+
+        reverse : bool, default False
+            If true then the special values are inferred from right to left. Only supported as
+            keyword argument.
+
 
         Returns
         -------
@@ -1029,18 +1042,19 @@ fixed-size items.
         elif not shape:
             shape = kwargs.get('shape')
             assert shape, "Shape must be provided."
-            if len(kwargs) != 1:
-                raise TypeError("Only 'shape' is supported as keyword argument. Got: {}."
-                                .format(', '.join(kwargs.keys())))
-        else:
-            assert not kwargs,\
-                "Specifying both positional and keyword arguments is not allowed in reshape."
+        if not all(k in ['shape', 'reverse'] for k in kwargs):
+            raise TypeError(
+                "Got unknown keywords in reshape: {}. " \
+                "Accepted keyword arguments are 'shape' and 'reverse'.".format(
+                    ', '.join([k for k in kwargs if k not in ['shape', 'reverse']])))
+        reverse = kwargs.get('reverse', False)
         handle = NDArrayHandle()
 
         # Actual reshape
         check_call(_LIB.MXNDArrayReshape64(self.handle,
                                            len(shape),
                                            c_array(ctypes.c_int64, shape),
+                                           reverse,
                                            ctypes.byref(handle)))
         return NDArray(handle=handle, writable=self.writable)
 
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 34b4fd2..b3dcd6a 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -431,12 +431,13 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
 MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
                                  int ndim,
                                  dim_t *dims,
+                                 bool reverse,
                                  NDArrayHandle *out) {
   NDArray *ptr = new NDArray();
   API_BEGIN();
   NDArray *arr = static_cast<NDArray*>(handle);
   nnvm::Tuple<dim_t> shape(dims, dims+ndim);
-  TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), false);
+  TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse);
   *ptr = arr->ReshapeWithRecord(new_shape);
   *out = ptr;
   API_END_HANDLE_ERROR(delete ptr);
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 030816e..9ff2f1a 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -154,30 +154,23 @@ def test_ndarray_negate():
 
 @with_seed()
 def test_ndarray_reshape():
-    tensor  = mx.nd.array([[[1, 2], [3, 4]],
-                           [[5, 6], [7, 8]]])
-    true_res = mx.nd.arange(8) + 1
-    assert same(tensor.reshape((-1, )).asnumpy(), true_res.asnumpy())
-    true_res  = mx.nd.array([[1, 2, 3, 4],
-                             [5, 6, 7, 8]])
-    assert same(tensor.reshape((2, -1)).asnumpy(), true_res.asnumpy())
-    assert same(tensor.reshape((0, -1)).asnumpy(), true_res.asnumpy())
-    true_res  = mx.nd.array([[1, 2],
-                             [3, 4],
-                             [5, 6],
-                             [7, 8]])
-    assert same(tensor.reshape((-1, 2)).asnumpy(), true_res.asnumpy())
-    assert same(tensor.reshape(4, 2).asnumpy(), true_res.asnumpy())
-    assert same(tensor.reshape(-1, 2).asnumpy(), true_res.asnumpy())
-    true_res = mx.nd.arange(8) + 1
+    tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
+    true_res = mx.nd.arange(30) + 1
+    assert same(tensor.reshape((-1,)).asnumpy(), true_res.asnumpy())
+    assert same(tensor.reshape((2, -1)).asnumpy(), true_res.reshape(2, 15).asnumpy())
+    assert same(tensor.reshape((0, -1)).asnumpy(), true_res.reshape(2, 15).asnumpy())
+    assert same(tensor.reshape((-1, 2)).asnumpy(), true_res.reshape(15, 2).asnumpy())
+    assert same(tensor.reshape(6, 5).asnumpy(), true_res.reshape(6, 5).asnumpy())
+    assert same(tensor.reshape(-1, 2).asnumpy(), true_res.reshape(15, 2).asnumpy())
     assert same(tensor.reshape(-1).asnumpy(), true_res.asnumpy())
-    assert same(tensor.reshape(8).asnumpy(), true_res.asnumpy())
-
-    assert same(tensor.reshape(0, -1).asnumpy(), true_res.reshape(2, 4).asnumpy())
-    assert same(tensor.reshape(-1, 4).asnumpy(), true_res.reshape(2, 4).asnumpy())
-    assert same(tensor.reshape(-2,).asnumpy(), true_res.reshape(2, 2, 2).asnumpy())
-    assert same(tensor.reshape(-3, -1).asnumpy(), true_res.reshape(4, 2).asnumpy())
-    assert same(tensor.reshape(-1, 4).reshape(0, -4, 2, -1).asnumpy(), true_res.reshape(2, 2, 2).asnumpy())
+    assert same(tensor.reshape(30).asnumpy(), true_res.asnumpy())
+    assert same(tensor.reshape(0, -1).asnumpy(), true_res.reshape(2, 15).asnumpy())
+    assert same(tensor.reshape(-1, 6).asnumpy(), true_res.reshape(5, 6).asnumpy())
+    assert same(tensor.reshape(-2,).asnumpy(), true_res.reshape(2, 3, 5).asnumpy())
+    assert same(tensor.reshape(-3, -1).asnumpy(), true_res.reshape(6, 5).asnumpy())
+    assert same(tensor.reshape(-1, 15).reshape(0, -4, 3, -1).asnumpy(), true_res.reshape(2, 3, 5).asnumpy())
+    assert same(tensor.reshape(-1, 0).asnumpy(), true_res.reshape(10, 3).asnumpy())
+    assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(), true_res.reshape(6, 5).asnumpy())
 
 
 @with_seed()

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.