You are viewing a plain-text version of this content; the link to the canonical version was lost in the plain-text conversion.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/05/03 17:16:58 UTC
[incubator-mxnet] branch master updated: add reverse option to ndarray inplace reshape (#10767)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 66b2944 add reverse option to ndarray inplace reshape (#10767)
66b2944 is described below
commit 66b294434aeffa9ed3f1cf01416345549139bf23
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Thu May 3 10:16:48 2018 -0700
add reverse option to ndarray inplace reshape (#10767)
* add reverse option to ndarray inplace reshape
* update check
---
include/mxnet/c_api.h | 1 +
python/mxnet/ndarray/ndarray.py | 26 +++++++++++++++++------
src/c_api/c_api.cc | 3 ++-
tests/python/unittest/test_ndarray.py | 39 ++++++++++++++---------------------
4 files changed, 39 insertions(+), 30 deletions(-)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 3f04051..9ac90d6 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -663,6 +663,7 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
int ndim,
dim_t *dims,
+ bool reverse,
NDArrayHandle *out);
/*!
* \brief get the shape of the array
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 6b2ff23..2411932 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -989,6 +989,19 @@ fixed-size items.
- input shape = (2,3,4), shape = (-4,1,2,-2), output shape =(1,2,3,4)
- input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4)
+ - If the argument `reverse` is set to 1, then the special values are inferred from right
+ to left.
+
+ Example::
+
+ - without reverse=1, for input shape = (10,5,4), shape = (-1,0), output shape would be
+ (40,5).
+ - with reverse=1, output shape will be (50,4).
+
+ reverse : bool, default False
+ If true then the special values are inferred from right to left. Only supported as
+ keyword argument.
+
Returns
-------
@@ -1029,18 +1042,19 @@ fixed-size items.
elif not shape:
shape = kwargs.get('shape')
assert shape, "Shape must be provided."
- if len(kwargs) != 1:
- raise TypeError("Only 'shape' is supported as keyword argument. Got: {}."
- .format(', '.join(kwargs.keys())))
- else:
- assert not kwargs,\
- "Specifying both positional and keyword arguments is not allowed in reshape."
+ if not all(k in ['shape', 'reverse'] for k in kwargs):
+ raise TypeError(
+ "Got unknown keywords in reshape: {}. " \
+ "Accepted keyword arguments are 'shape' and 'reverse'.".format(
+ ', '.join([k for k in kwargs if k not in ['shape', 'reverse']])))
+ reverse = kwargs.get('reverse', False)
handle = NDArrayHandle()
# Actual reshape
check_call(_LIB.MXNDArrayReshape64(self.handle,
len(shape),
c_array(ctypes.c_int64, shape),
+ reverse,
ctypes.byref(handle)))
return NDArray(handle=handle, writable=self.writable)
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 34b4fd2..b3dcd6a 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -431,12 +431,13 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
int ndim,
dim_t *dims,
+ bool reverse,
NDArrayHandle *out) {
NDArray *ptr = new NDArray();
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
nnvm::Tuple<dim_t> shape(dims, dims+ndim);
- TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), false);
+ TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse);
*ptr = arr->ReshapeWithRecord(new_shape);
*out = ptr;
API_END_HANDLE_ERROR(delete ptr);
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 030816e..9ff2f1a 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -154,30 +154,23 @@ def test_ndarray_negate():
@with_seed()
def test_ndarray_reshape():
- tensor = mx.nd.array([[[1, 2], [3, 4]],
- [[5, 6], [7, 8]]])
- true_res = mx.nd.arange(8) + 1
- assert same(tensor.reshape((-1, )).asnumpy(), true_res.asnumpy())
- true_res = mx.nd.array([[1, 2, 3, 4],
- [5, 6, 7, 8]])
- assert same(tensor.reshape((2, -1)).asnumpy(), true_res.asnumpy())
- assert same(tensor.reshape((0, -1)).asnumpy(), true_res.asnumpy())
- true_res = mx.nd.array([[1, 2],
- [3, 4],
- [5, 6],
- [7, 8]])
- assert same(tensor.reshape((-1, 2)).asnumpy(), true_res.asnumpy())
- assert same(tensor.reshape(4, 2).asnumpy(), true_res.asnumpy())
- assert same(tensor.reshape(-1, 2).asnumpy(), true_res.asnumpy())
- true_res = mx.nd.arange(8) + 1
+ tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
+ true_res = mx.nd.arange(30) + 1
+ assert same(tensor.reshape((-1,)).asnumpy(), true_res.asnumpy())
+ assert same(tensor.reshape((2, -1)).asnumpy(), true_res.reshape(2, 15).asnumpy())
+ assert same(tensor.reshape((0, -1)).asnumpy(), true_res.reshape(2, 15).asnumpy())
+ assert same(tensor.reshape((-1, 2)).asnumpy(), true_res.reshape(15, 2).asnumpy())
+ assert same(tensor.reshape(6, 5).asnumpy(), true_res.reshape(6, 5).asnumpy())
+ assert same(tensor.reshape(-1, 2).asnumpy(), true_res.reshape(15, 2).asnumpy())
assert same(tensor.reshape(-1).asnumpy(), true_res.asnumpy())
- assert same(tensor.reshape(8).asnumpy(), true_res.asnumpy())
-
- assert same(tensor.reshape(0, -1).asnumpy(), true_res.reshape(2, 4).asnumpy())
- assert same(tensor.reshape(-1, 4).asnumpy(), true_res.reshape(2, 4).asnumpy())
- assert same(tensor.reshape(-2,).asnumpy(), true_res.reshape(2, 2, 2).asnumpy())
- assert same(tensor.reshape(-3, -1).asnumpy(), true_res.reshape(4, 2).asnumpy())
- assert same(tensor.reshape(-1, 4).reshape(0, -4, 2, -1).asnumpy(), true_res.reshape(2, 2, 2).asnumpy())
+ assert same(tensor.reshape(30).asnumpy(), true_res.asnumpy())
+ assert same(tensor.reshape(0, -1).asnumpy(), true_res.reshape(2, 15).asnumpy())
+ assert same(tensor.reshape(-1, 6).asnumpy(), true_res.reshape(5, 6).asnumpy())
+ assert same(tensor.reshape(-2,).asnumpy(), true_res.reshape(2, 3, 5).asnumpy())
+ assert same(tensor.reshape(-3, -1).asnumpy(), true_res.reshape(6, 5).asnumpy())
+ assert same(tensor.reshape(-1, 15).reshape(0, -4, 3, -1).asnumpy(), true_res.reshape(2, 3, 5).asnumpy())
+ assert same(tensor.reshape(-1, 0).asnumpy(), true_res.reshape(10, 3).asnumpy())
+ assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(), true_res.reshape(6, 5).asnumpy())
@with_seed()
--
To stop receiving notification emails like this one, please contact
jxie@apache.org.