Posted to commits@mxnet.apache.org by re...@apache.org on 2019/09/21 16:12:35 UTC

[incubator-mxnet] branch master updated: np compatible vstack (#15850)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 11f73ed  np compatible vstack (#15850)
11f73ed is described below

commit 11f73ed59eb481b76a8ffca683293eec2cf41889
Author: Haozheng Fan <fh...@gmail.com>
AuthorDate: Sun Sep 22 00:11:28 2019 +0800

    np compatible vstack (#15850)
---
 python/mxnet/ndarray/numpy/_op.py      |  53 +++++++++++-
 python/mxnet/numpy/multiarray.py       |  49 ++++++++++-
 python/mxnet/symbol/numpy/_symbol.py   |  35 +++++++-
 src/operator/numpy/np_matrix_op-inl.h  |  80 ++++++++++++++++++
 src/operator/numpy/np_matrix_op.cc     | 143 +++++++++++++++++++++++++++++++++
 src/operator/numpy/np_matrix_op.cu     |   7 ++
 tests/python/unittest/test_numpy_op.py |  56 +++++++++++++
 7 files changed, 419 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 163d908..bb63e94 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -32,7 +32,7 @@ __all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mo
            'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
            'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
            'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
            'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel']
 
@@ -1991,6 +1991,57 @@ def stack(arrays, axis=0, out=None):
 
 
 @set_module('mxnet.ndarray.numpy')
+def vstack(arrays, out=None):
+    r"""Stack arrays in sequence vertically (row wise).
+
+    This is equivalent to concatenation along the first axis after 1-D arrays
+    of shape `(N,)` have been reshaped to `(1,N)`. Rebuilds arrays divided by
+    `vsplit`.
+
+    This function makes most sense for arrays with up to 3 dimensions. For
+    instance, for pixel-data with a height (first axis), width (second axis),
+    and r/g/b channels (third axis). The functions `concatenate` and `stack`
+    provide more general stacking and concatenation operations.
+
+    Parameters
+    ----------
+    arrays : sequence of ndarrays
+        The arrays must have the same shape along all but the first axis.
+        1-D arrays must have the same length.
+
+    Returns
+    -------
+    stacked : ndarray
+        The array formed by stacking the given arrays; it will be at least 2-D.
+
+    Examples
+    --------
+    >>> a = np.array([1, 2, 3])
+    >>> b = np.array([2, 3, 4])
+    >>> np.vstack((a, b))
+    array([[1., 2., 3.],
+           [2., 3., 4.]])
+
+    >>> a = np.array([[1], [2], [3]])
+    >>> b = np.array([[2], [3], [4]])
+    >>> np.vstack((a, b))
+    array([[1.],
+           [2.],
+           [3.],
+           [2.],
+           [3.],
+           [4.]])
+    """
+    def get_list(arrays):
+        if not hasattr(arrays, '__getitem__') and not hasattr(arrays, '__iter__'):
+            raise ValueError("expected iterable for arrays but got {}".format(type(arrays)))
+        return [arr for arr in arrays]
+
+    arrays = get_list(arrays)
+    return _npi.vstack(*arrays)
+
+
+@set_module('mxnet.ndarray.numpy')
 def maximum(x1, x2, out=None):
     """Returns element-wise maximum of the input arrays with broadcasting.
 
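For reference, the semantics documented above can be sketched in plain NumPy:
vstack is concatenation along the first axis after promoting 0-D and 1-D
inputs to shape (1, N). The helper name vstack_sketch below is illustrative,
not part of this change:

    import numpy as np

    def vstack_sketch(arrays):
        # Promote 0-D/1-D inputs to (1, N), then concatenate along axis 0.
        promoted = [a.reshape(1, -1) if a.ndim <= 1 else a for a in arrays]
        return np.concatenate(promoted, axis=0)

    a = np.array([1., 2., 3.])
    b = np.array([2., 3., 4.])
    assert np.array_equal(vstack_sketch([a, b]), np.vstack([a, b]))
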
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 52dc9fb..f738d63 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -52,8 +52,8 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtrac
            'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
            'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
            'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
-           'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
-           'ravel']
+           'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
+           'copysign', 'ravel']
 
 # Return code for dispatching indexing function call
 _NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -3561,6 +3561,51 @@ def stack(arrays, axis=0, out=None):
 
 
 @set_module('mxnet.numpy')
+def vstack(arrays, out=None):
+    r"""Stack arrays in sequence vertically (row wise).
+
+    This is equivalent to concatenation along the first axis after 1-D arrays
+    of shape `(N,)` have been reshaped to `(1,N)`. Rebuilds arrays divided by
+    `vsplit`.
+
+    This function makes most sense for arrays with up to 3 dimensions. For
+    instance, for pixel-data with a height (first axis), width (second axis),
+    and r/g/b channels (third axis). The functions `concatenate` and `stack`
+    provide more general stacking and concatenation operations.
+
+    Parameters
+    ----------
+    arrays : sequence of ndarrays
+        The arrays must have the same shape along all but the first axis.
+        1-D arrays must have the same length.
+
+    Returns
+    -------
+    stacked : ndarray
+        The array formed by stacking the given arrays; it will be at least 2-D.
+
+    Examples
+    --------
+    >>> a = np.array([1, 2, 3])
+    >>> b = np.array([2, 3, 4])
+    >>> np.vstack((a, b))
+    array([[1., 2., 3.],
+           [2., 3., 4.]])
+
+    >>> a = np.array([[1], [2], [3]])
+    >>> b = np.array([[2], [3], [4]])
+    >>> np.vstack((a, b))
+    array([[1.],
+           [2.],
+           [3.],
+           [2.],
+           [3.],
+           [4.]])
+    """
+    return _mx_nd_np.vstack(arrays)
+
+
+@set_module('mxnet.numpy')
 def maximum(x1, x2, out=None):
     """Returns element-wise maximum of the input arrays with broadcasting.
 
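A short imperative usage sketch; this assumes an MXNet build that includes
this commit, with the numpy-compatible mode switched on via npx.set_np():

    from mxnet import np, npx
    npx.set_np()

    a = np.array([1, 2, 3])
    b = np.array([2, 3, 4])
    c = np.vstack((a, b))   # front end delegates to mxnet.ndarray.numpy.vstack
    print(c.shape)          # (2, 3)
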
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 962fee2..03d3d0d 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -34,7 +34,7 @@ __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'rem
            'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
            'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
            'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
            'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel']
 
@@ -2397,6 +2397,39 @@ def stack(arrays, axis=0, out=None):
 
 
 @set_module('mxnet.symbol.numpy')
+def vstack(arrays, out=None):
+    r"""Stack arrays in sequence vertically (row wise).
+
+    This is equivalent to concatenation along the first axis after 1-D arrays
+    of shape `(N,)` have been reshaped to `(1,N)`. Rebuilds arrays divided by
+    `vsplit`.
+
+    This function makes most sense for arrays with up to 3 dimensions. For
+    instance, for pixel-data with a height (first axis), width (second axis),
+    and r/g/b channels (third axis). The functions `concatenate` and `stack`
+    provide more general stacking and concatenation operations.
+
+    Parameters
+    ----------
+    arrays : sequence of _Symbol
+        The arrays must have the same shape along all but the first axis.
+        1-D arrays must have the same length.
+
+    Returns
+    -------
+    stacked : _Symbol
+        The array formed by stacking the given arrays; it will be at least 2-D.
+    """
+    def get_list(arrays):
+        if not hasattr(arrays, '__getitem__') and not hasattr(arrays, '__iter__'):
+            raise ValueError("expected iterable for arrays but got {}".format(type(arrays)))
+        return [arr for arr in arrays]
+
+    arrays = get_list(arrays)
+    return _npi.vstack(*arrays)
+
+
+@set_module('mxnet.symbol.numpy')
 def maximum(x1, x2, out=None):
     return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)
 
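The symbolic version is typically reached through Gluon hybridization: after
hybridize(), F inside hybrid_forward is the symbol namespace, so F.np.vstack
resolves to the _Symbol implementation above. A minimal sketch (class name
illustrative):

    from mxnet import npx
    from mxnet.gluon import HybridBlock
    npx.set_np()

    class VstackBlock(HybridBlock):
        def hybrid_forward(self, F, a, b):
            # Imperative mode: F is mxnet.ndarray; traced mode: mxnet.symbol.
            return F.np.vstack((a, b))

    block = VstackBlock()
    block.hybridize()  # later calls trace through _symbol.py's vstack
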
diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h
index 6d3d9ea..fe2fadd 100644
--- a/src/operator/numpy/np_matrix_op-inl.h
+++ b/src/operator/numpy/np_matrix_op-inl.h
@@ -41,6 +41,14 @@ struct NumpyTransposeParam : public dmlc::Parameter<NumpyTransposeParam> {
   }
 };
 
+struct NumpyVstackParam : public dmlc::Parameter<NumpyVstackParam> {
+  int num_args;
+  DMLC_DECLARE_PARAMETER(NumpyVstackParam) {
+    DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
+    .describe("Number of inputs to be vstacked.");
+  }
+};
+
 template<typename xpu>
 void NumpyTranspose(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
@@ -60,6 +68,78 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu>
+void NumpyVstackForward(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow_op;
+
+  const NumpyVstackParam& param = nnvm::get<NumpyVstackParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), param.num_args);
+  CHECK_EQ(outputs.size(), 1);
+  CHECK_EQ(req.size(), 1);
+
+  // reshape if necessary
+  std::vector<TBlob> data(param.num_args);
+  for (int i = 0; i < param.num_args; i++) {
+    if (inputs[i].shape_.ndim() == 0 || inputs[i].shape_.ndim() == 1) {
+      TShape shape = Shape2(1, inputs[i].shape_.Size());
+      data[i] = inputs[i].reshape(shape);
+    } else {
+      data[i] = inputs[i];
+    }
+  }
+
+  // initialize ConcatOp
+  ConcatParam cparam;
+  cparam.num_args = param.num_args;
+  cparam.dim = 0;
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    ConcatOp<xpu, DType> op;
+    op.Init(cparam);
+    op.Forward(ctx, data, req, outputs);
+  });
+}
+
+template<typename xpu>
+void NumpyVstackBackward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<TBlob>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow_op;
+
+  const NumpyVstackParam& param = nnvm::get<NumpyVstackParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 1);
+  CHECK_EQ(outputs.size(), param.num_args);
+  CHECK_EQ(req.size(), param.num_args);
+
+  // reshape if necessary
+  std::vector<TBlob> data(param.num_args);
+  for (int i = 0; i < param.num_args; i++) {
+    if (outputs[i].shape_.ndim() == 0 || outputs[i].shape_.ndim() == 1) {
+      TShape shape = Shape2(1, outputs[i].shape_.Size());
+      data[i] = outputs[i].reshape(shape);
+    } else {
+      data[i] = outputs[i];
+    }
+  }
+
+  // initialize ConcatOp
+  ConcatParam cparam;
+  cparam.num_args = param.num_args;
+  cparam.dim = 0;
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    ConcatOp<xpu, DType> op;
+    op.Init(cparam);
+    op.Backward(ctx, inputs[0], req, data);
+  });
+}
+
 }  // namespace op
 }  // namespace mxnet
 
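Both kernels above reuse the existing Concat kernels: the forward pass
reshapes any 0-D/1-D input to (1, N) and concatenates along axis 0, and the
backward pass is the matching split of the output gradient. A plain-NumPy
sketch of that contract (function names illustrative only):

    import numpy as np

    def vstack_forward(inputs):
        # Promote 0-D/1-D inputs to (1, N), then concatenate along axis 0.
        data = [x.reshape(1, -1) if x.ndim <= 1 else x for x in inputs]
        return np.concatenate(data, axis=0)

    def vstack_backward(out_grad, input_shapes):
        # Each input contributed one row if it was 0-D/1-D, else shape[0] rows.
        rows = [1 if len(s) <= 1 else s[0] for s in input_shapes]
        pieces = np.split(out_grad, np.cumsum(rows)[:-1], axis=0)
        return [g.reshape(s) for g, s in zip(pieces, input_shapes)]
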
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index f88dd56..5509f34 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -25,6 +25,7 @@
 
 #include <vector>
 #include "./np_matrix_op-inl.h"
+#include "../nn/concat-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -346,5 +347,147 @@ Examples::
 .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack")
 .add_arguments(StackParam::__FIELDS__());
 
+bool NumpyVstackType(const nnvm::NodeAttrs& attrs,
+                     std::vector<int> *in_type,
+                     std::vector<int> *out_type) {
+  const NumpyVstackParam& param = nnvm::get<NumpyVstackParam>(attrs.parsed);
+  CHECK_EQ(in_type->size(), param.num_args);
+  CHECK_EQ(out_type->size(), 1);
+  int dtype = -1;
+  for (int i = 0; i < param.num_args; i++) {
+    if (dtype == -1) {
+      dtype = in_type->at(i);
+    }
+  }
+  if (dtype == -1) {
+    dtype = out_type->at(0);
+  }
+  for (int i = 0; i < param.num_args; i++) {
+    TYPE_ASSIGN_CHECK(*in_type, i, dtype);
+  }
+  TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
+  return dtype != -1;
+}
+
+bool NumpyVstackShape(const nnvm::NodeAttrs& attrs,
+                      mxnet::ShapeVector* in_attrs,
+                      mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(out_attrs->size(), 1U);
+  const NumpyVstackParam& param = nnvm::get<NumpyVstackParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), param.num_args);
+  std::vector<mxnet::TShape> in_attrs_tmp(param.num_args);
+  TShape dshape;
+  for (int i = 0; i < param.num_args; i++) {
+    if ((*in_attrs)[i].ndim() == 0) {
+      in_attrs_tmp[i] = TShape(2, 1);
+    } else if ((*in_attrs)[i].ndim() == 1) {
+      in_attrs_tmp[i] = TShape(2, 1);
+      in_attrs_tmp[i][1] = (*in_attrs)[i][0];
+    } else {
+      in_attrs_tmp[i] = (*in_attrs)[i];
+    }
+    TShape tmp(in_attrs_tmp[i].ndim(), -1);
+    shape_assign(&dshape, tmp);
+  }
+  TShape tmp((*out_attrs)[0].ndim(), -1);
+  shape_assign(&dshape, tmp);
+  for (int i = 0; i < param.num_args; i++) {
+    SHAPE_ASSIGN_CHECK(in_attrs_tmp, i, dshape);
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  if (dshape.ndim() == -1) {
+    return false;
+  }
+  int cnt = 0, sum = 0, pos = -1;
+  for (int i = 0; i < param.num_args; i++) {
+    TShape tmp = in_attrs_tmp[i];
+    if (!dim_size_is_known(tmp, 0)) {
+      cnt++;
+      pos = i;
+    } else {
+      sum += tmp[0];
+    }
+    tmp[0] = -1;
+    shape_assign(&dshape, tmp);
+  }
+  tmp = out_attrs->at(0);
+  if (!dim_size_is_known(tmp, 0)) {
+    cnt++;
+    pos = -1;
+  } else {
+    sum += tmp[0];
+  }
+  tmp[0] = -1;
+  shape_assign(&dshape, tmp);
+  for (int i = 0; i < param.num_args; i++) {
+    SHAPE_ASSIGN_CHECK(in_attrs_tmp, i, dshape);
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  dshape[0] = 0;
+  if (!shape_is_known(dshape)) {
+    return false;
+  }
+
+  dshape[0] = sum;
+  if (cnt == 0) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  } else if (cnt == 1) {
+    if (pos >= 0) {
+      in_attrs_tmp[pos][0] = out_attrs->at(0)[0] - sum;
+    } else {
+      out_attrs->at(0)[0] = sum;
+    }
+  } else {
+    return false;
+  }
+
+  for (int i = 0; i < param.num_args; i++) {
+    if (in_attrs->at(i).ndim() == 1) {
+      in_attrs->at(i)[0] = in_attrs_tmp[i][1];
+    } else if (in_attrs->at(i).ndim() >= 2) {
+      in_attrs->at(i) = in_attrs_tmp[i];
+    }
+  }
+
+  return true;
+}
+
+DMLC_REGISTER_PARAMETER(NumpyVstackParam);
+
+NNVM_REGISTER_OP(_npi_vstack)
+.describe(R"code()code" ADD_FILELINE)
+.set_attr_parser(ParamParser<NumpyVstackParam>)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+  const NumpyVstackParam& param = dmlc::get<NumpyVstackParam>(attrs.parsed);
+  return static_cast<uint32_t>(param.num_args);
+})
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const nnvm::NodeAttrs& attrs) {
+    int num_args = dmlc::get<NumpyVstackParam>(attrs.parsed).num_args;
+    std::vector<std::string> ret;
+    for (int i = 0; i < num_args; i++) {
+      ret.push_back(std::string("arg") + std::to_string(i));
+    }
+    return ret;
+  })
+.set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyVstackShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyVstackType)
+.set_attr<FCompute>("FCompute<cpu>", NumpyVstackForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_np_vstack"})
+.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to vstack")
+.add_arguments(NumpyVstackParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_np_vstack)
+.set_attr_parser(ParamParser<NumpyVstackParam>)
+.set_num_inputs(1)
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+  const NumpyVstackParam& param = dmlc::get<NumpyVstackParam>(attrs.parsed);
+  return static_cast<uint32_t>(param.num_args);
+})
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", NumpyVstackBackward<cpu>);
+
 }  // namespace op
 }  // namespace mxnet
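The shape function above promotes every input to at least 2-D ((1, 1) for
0-D, (1, N) for 1-D), requires all inputs to agree on every axis except the
first, and sums the leading dimensions; it can also deduce a single unknown
leading dimension from the output shape. A Python sketch of the fully-known
case (illustrative only):

    def vstack_infer_shape(in_shapes):
        # Promote each shape to at least 2-D, as NumpyVstackShape does.
        promoted = []
        for s in in_shapes:
            if len(s) == 0:
                promoted.append((1, 1))
            elif len(s) == 1:
                promoted.append((1, s[0]))
            else:
                promoted.append(tuple(s))
        trailing = promoted[0][1:]
        assert all(p[1:] == trailing for p in promoted), \
            "inputs must match on every axis except the first"
        return (sum(p[0] for p in promoted),) + trailing

    assert vstack_infer_shape([(2, 3), (1, 3), (4, 3)]) == (7, 3)
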
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index 4ba527d..b017ad6 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -24,6 +24,7 @@
  */
 
 #include "./np_matrix_op-inl.h"
+#include "../nn/concat-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -46,5 +47,11 @@ NNVM_REGISTER_OP(_backward_np_concat)
 NNVM_REGISTER_OP(_npi_stack)
 .set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);
 
+NNVM_REGISTER_OP(_npi_vstack)
+.set_attr<FCompute>("FCompute<gpu>", NumpyVstackForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_np_vstack)
+.set_attr<FCompute>("FCompute<gpu>", NumpyVstackBackward<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 2ce0fa4..87c7ab2 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2075,6 +2075,62 @@ def test_np_svd():
                     assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
 
 
+@with_seed()
+@use_np
+def test_np_vstack():
+    class TestVstack(HybridBlock):
+        def __init__(self):
+            super(TestVstack, self).__init__()
+
+        def hybrid_forward(self, F, a, *args):
+            return F.np.vstack([a] + list(args))
+
+    def g(data):
+        return _np.ones_like(data)
+
+    configs = [
+        ((), (), ()),
+        ((2,), (2,), (2,)),
+        ((0,), (0,), (0,)),
+        ((2, 2), (3, 2), (0, 2)),
+        ((2, 3), (1, 3), (4, 3)),
+        ((2, 2, 2), (3, 2, 2), (1, 2, 2)),
+        ((0, 1, 1), (4, 1, 1), (5, 1, 1)),
+        ((2,), (0, 2), (2, 2))
+    ]
+    types = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']
+    for config in configs:
+        for hybridize in [True, False]:
+            for dtype in types:
+                test_vstack = TestVstack()
+                if hybridize:
+                    test_vstack.hybridize()
+                rtol = 1e-3
+                atol = 1e-5
+                v = []
+                v_np = []
+                for i in range(3):
+                    v_np.append(_np.array(_np.random.uniform(-10.0, 10.0, config[i]), dtype=dtype))
+                    v.append(mx.nd.array(v_np[i]).as_np_ndarray())
+                    v[i].attach_grad()
+                expected_np = _np.vstack(v_np)
+                with mx.autograd.record():
+                    mx_out = test_vstack(*v)
+                assert mx_out.shape == expected_np.shape
+                assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+
+                # Test gradient
+                mx_out.backward()
+                for i in range(3):
+                    expected_grad = g(v_np[i])
+                    assert_almost_equal(v[i].grad.asnumpy(), expected_grad, rtol=rtol, atol=atol)
+
+                # Test imperative once again
+                mx_out = np.vstack(v)
+                expected_np = _np.vstack(v_np)
+                assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()