You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:00:47 UTC

[incubator-mxnet] 11/42: Numpy-compatible stack (#15027)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 257f9f4464350248bf9e2760521d3aaae322c23b
Author: Hao Jin <hj...@gmail.com>
AuthorDate: Fri May 31 14:49:34 2019 -0700

    Numpy-compatible stack (#15027)
    
    * numpy stack
    
    * migrate to use_np_shape
---
 python/mxnet/ndarray/numpy/_op.py      | 32 ++++++++++++++++++++-
 python/mxnet/numpy/multiarray.py       | 26 ++++++++++++++++-
 python/mxnet/symbol/numpy/_symbol.py   | 32 ++++++++++++++++++++-
 src/imperative/imperative.cc           |  4 ++-
 src/operator/numpy/np_matrix_op.cc     | 40 ++++++++++++++++++++++++++
 src/operator/numpy/np_matrix_op.cu     |  3 ++
 tests/python/unittest/test_numpy_op.py | 51 ++++++++++++++++++++++++++++++++++
 7 files changed, 184 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 72b890d..76825f1 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -24,7 +24,7 @@ from ...util import _sanity_check_params, set_module
 from ...context import current_context
 from . import _internal as _npi
 
-__all__ = ['zeros', 'ones', 'maximum', 'minimum']
+__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -171,3 +171,33 @@ def minimum(x1, x2, out=None):
     out : mxnet.numpy.ndarray or scalar
         The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
     return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def stack(arrays, axis=0, out=None):
+    """Join a sequence of arrays along a new axis.
+
+    The axis parameter specifies the index of the new axis in the dimensions of the result.
+    For example, if `axis=0` it will be the first dimension and if `axis=-1` it will be the last dimension.
+
+    Parameters
+    ----------
+    arrays : sequence of array_like
+        Each array must have the same shape. Must be an indexable sequence such as a
+        list or tuple; an iterable without `__getitem__` (e.g. a generator) raises
+        `ValueError`.
+    axis : int, optional
+        The axis in the result array along which the input arrays are stacked.
+    out : ndarray, optional
+        If provided, the destination to place the result. The shape must be correct,
+        matching that of what stack would have returned if no out argument were specified.
+
+    Returns
+    -------
+    stacked : ndarray
+        The stacked array has one more dimension than the input arrays."""
+    # Non-indexable iterables (e.g. generators) are rejected; NumPy's stack disallows them too.
+    if not hasattr(arrays, '__getitem__') and hasattr(arrays, '__iter__'):
+        raise ValueError("expected a sequence such as list or tuple for arrays but got {}"
+                         .format(type(arrays)))
+    return _npi.stack(*arrays, axis=axis, out=out)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index e9afd23..da7e61e 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -35,7 +35,7 @@ from ..context import current_context
 from ..ndarray import numpy as _mx_nd_np
 from ..ndarray.numpy import _internal as _npi
 
-__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum']
+__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack']
 
 
 # This function is copied from ndarray.py since pylint
@@ -1305,3 +1305,27 @@ def minimum(x1, x2, out=None):
     out : mxnet.numpy.ndarray or scalar
         The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
     return _mx_nd_np.minimum(x1, x2, out=out)
+
+
+@set_module('mxnet.numpy')
+def stack(arrays, axis=0, out=None):
+    """Join a sequence of arrays along a new axis.
+
+    The axis parameter specifies the index of the new axis in the dimensions of the result.
+    For example, if `axis=0` it will be the first dimension and if `axis=-1` it will be the last dimension.
+
+    Parameters
+    ----------
+    arrays : sequence of array_like
+        Each array must have the same shape. Pass a list or tuple; generators raise `ValueError`.
+    axis : int, optional
+        The axis in the result array along which the input arrays are stacked.
+    out : ndarray, optional
+        If provided, the destination to place the result. The shape must be correct,
+        matching that of what stack would have returned if no out argument were specified.
+
+    Returns
+    -------
+    stacked : ndarray
+        The stacked array has one more dimension than the input arrays."""
+    return _mx_nd_np.stack(arrays, axis=axis, out=out)  # delegate to the ndarray-level stack
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 6a03cdb..d55a878 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -29,7 +29,7 @@ from ..symbol import Symbol
 from .._internal import _set_np_symbol_class
 from . import _internal as _npi
 
-__all__ = ['zeros', 'ones', 'maximum', 'minimum']
+__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack']
 
 
 @set_module('mxnet.symbol.numpy')
@@ -1000,4 +1000,34 @@ def minimum(x1, x2, out=None):
     return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out)
 
 
+@set_module('mxnet.symbol.numpy')
+def stack(arrays, axis=0, out=None):
+    """Join a sequence of arrays along a new axis.
+
+    The axis parameter specifies the index of the new axis in the dimensions of the result.
+    For example, if `axis=0` it will be the first dimension and if `axis=-1` it will be the last dimension.
+
+    Parameters
+    ----------
+    arrays : sequence of array_like
+        Each array must have the same shape. Must be an indexable sequence such as a
+        list or tuple; an iterable without `__getitem__` (e.g. a generator) raises
+        `ValueError`.
+    axis : int, optional
+        The axis in the result array along which the input arrays are stacked.
+    out : ndarray, optional
+        If provided, the destination to place the result. The shape must be correct,
+        matching that of what stack would have returned if no out argument were specified.
+
+    Returns
+    -------
+    stacked : _Symbol
+        The stacked symbol has one more dimension than the input arrays."""
+    # Non-indexable iterables (e.g. generators) are rejected; NumPy's stack disallows them too.
+    if not hasattr(arrays, '__getitem__') and hasattr(arrays, '__iter__'):
+        raise ValueError("expected a sequence such as list or tuple for arrays but got {}"
+                         .format(type(arrays)))
+    return _npi.stack(*arrays, axis=axis, out=out)
+
+
 _set_np_symbol_class(_Symbol)
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index e2c0c9d..c00021c 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -313,7 +313,9 @@ std::vector<NDArray*> Imperative::Backward(
     } else {
       info.outputs.emplace_back(outputs[i]->shape(), outputs[i]->ctx(),
                                 true, outputs[i]->dtype());
-      info.outputs.back() = static_cast<real_t>(1.0);
+      if (info.outputs.back().shape().Size() != 0) {
+        info.outputs.back() = static_cast<real_t>(1.0);
+      }
     }
   }
 
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 6e93442..db479a0 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -212,5 +212,45 @@ NNVM_REGISTER_OP(_np_reshape)
 .add_argument("a", "NDArray-or-Symbol", "Array to be reshaped.")
 .add_arguments(NumpyReshapeParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_npi_stack)  // registers the operator and its CPU forward kernel
+.describe(R"code(Join a sequence of arrays along a new axis.
+
+The axis parameter specifies the index of the new axis in the dimensions of the
+result. For example, if axis=0 it will be the first dimension and if axis=-1 it
+will be the last dimension.
+
+Examples::
+
+  x = [1, 2]
+  y = [3, 4]
+
+  stack(x, y) = [[1, 2],
+                 [3, 4]]
+  stack(x, y, axis=1) = [[1, 3],
+                         [2, 4]]
+)code")
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {  // input count is read from num_args
+    const StackParam& param = dmlc::get<StackParam>(attrs.parsed);
+    return static_cast<uint32_t>(param.num_args);
+  })
+.set_num_outputs(1)  // a single stacked output
+.set_attr_parser(ParamParser<StackParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {  // inputs are named arg0, arg1, ..., one per num_args
+    uint32_t num_args = dmlc::get<StackParam>(attrs.parsed).num_args;
+    std::vector<std::string> ret;
+    for (uint32_t i = 0; i < num_args; ++i) {
+      ret.push_back(std::string("arg") + std::to_string(i));
+    }
+    return ret;
+  })
+.set_attr<std::string>("key_var_num_args", "num_args")  // presumably ties arity to num_args; confirm
+.set_attr<mxnet::FInferShape>("FInferShape", StackOpShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
+.set_attr<FCompute>("FCompute<cpu>", StackOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_stack"})  // backward op
+.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack")
+.add_arguments(StackParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index 5bf36e5..615dd26 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -33,5 +33,8 @@ NNVM_REGISTER_OP(_np_transpose)
 NNVM_REGISTER_OP(_np_reshape)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
 
+NNVM_REGISTER_OP(_npi_stack)  // GPU forward kernel for the op registered in np_matrix_op.cc
+.set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index e43b91f..853cb50 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -313,6 +313,57 @@ def test_np_minimum():
     check_minimum(np.zeros(()), np.ones((5, 1, 4)))
 
 
+@with_seed()
+@mx.use_np_shape
+def test_np_stack():
+    """Validate `stack` against NumPy.
+
+    Covers the symbolic return type, forward values and shapes, input gradients, and
+    the imperative front end, with and without hybridization.
+    """
+    class TestStack(HybridBlock):
+        """Block that stacks all of its inputs along a fixed axis."""
+        def __init__(self, axis=None):
+            super(TestStack, self).__init__()
+            self._axis = axis
+
+        def hybrid_forward(self, F, a, *args):
+            return F.np.stack([a] + list(args), axis=self._axis)
+
+    # The symbolic front end must hand back the numpy-flavoured symbol class.
+    sym_inputs = [mx.sym.Variable(name).as_np_ndarray() for name in ("a", "b", "c", "d")]
+    ret = mx.sym.np.stack(sym_inputs)
+    assert type(ret) == mx.sym.np._Symbol
+
+    for shape in [(0, 0), (2, 3)]:
+        for hybridize in [True, False]:
+            for axis in range(2):
+                test_stack = TestStack(axis=axis)
+                if hybridize:
+                    test_stack.hybridize()
+                # Four random float32 operands, drawn one after another.
+                np_inputs = [_np.random.uniform(-1.0, 1.0, shape).astype(_np.float32)
+                             for _ in range(4)]
+                mx_inputs = [np.array(np_arr) for np_arr in np_inputs]
+                for mx_arr in mx_inputs:
+                    mx_arr.attach_grad()
+                expected_ret = _np.stack(np_inputs, axis=axis)
+                with mx.autograd.record():
+                    y = test_stack(*mx_inputs)
+                assert y.shape == expected_ret.shape
+                assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
+
+                # Every input is expected to receive an all-ones gradient.
+                y.backward()
+                for mx_arr in mx_inputs:
+                    assert_almost_equal(mx_arr.grad.asnumpy(), _np.ones(shape),
+                                        rtol=1e-3, atol=1e-5)
+
+                # The imperative front end must match NumPy exactly.
+                np_out = _np.stack(np_inputs, axis=axis)
+                mx_out = np.stack(mx_inputs, axis=axis)
+                assert same(mx_out.asnumpy(), np_out)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()