You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:01:16 UTC

[incubator-mxnet] 40/42: [Numpy] Numpy hstack (#15302)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 431e2f8e3cecb3c1c9aec98df234449e3fb7b087
Author: Mike <ma...@connect.hku.hk>
AuthorDate: Sat Jul 13 01:08:11 2019 +0800

    [Numpy] Numpy hstack (#15302)
    
    * Add numpy compatible hstack that currently passes CPU tests
    
    Register hstack on GPU
    
    Register backward function on GPU
    
    Add some comments
    
    Fix the issues pointed out in code review
    
    Minor syntax fix according to comments
    
    Add docs
    
    * Minor syntax fix
---
 python/mxnet/ndarray/numpy/_op.py      | 42 +++++++++++++++-
 python/mxnet/numpy/multiarray.py       | 42 +++++++++++++++-
 python/mxnet/symbol/numpy/_symbol.py   | 42 +++++++++++++++-
 src/operator/nn/concat-inl.h           | 44 ++++++++++++++++
 src/operator/numpy/np_matrix_op.cc     | 91 ++++++++++++++++++++++++++++++++++
 src/operator/numpy/np_matrix_op.cu     |  6 +++
 tests/python/unittest/test_numpy_op.py | 65 ++++++++++++++++++++++++
 7 files changed, 329 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index ff0e8c8..76ed88c 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -33,7 +33,7 @@ __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
            'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'eye',
            'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
            'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin',
-           'argsort']
+           'argsort', 'hstack']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -511,6 +511,46 @@ def concatenate(seq, axis=0, out=None):
 
 
 @set_module('mxnet.ndarray.numpy')
+def hstack(arrays):
+    """
+    Stack arrays in sequence horizontally (column wise).
+    This is equivalent to concatenation along the second axis,
+    except for 1-D arrays where it concatenates along the first axis.
+    Rebuilds arrays divided by hsplit.
+    This function makes most sense for arrays with up to 3 dimensions.
+    For instance, for pixel-data with a height (first axis), width (second axis),
+    and r/g/b channels (third axis). The functions concatenate,
+    stack and block provide more general stacking and concatenation operations.
+
+    Parameters
+    ----------
+    arrays : sequence of ndarrays
+        The arrays must have the same shape along all but the second axis, except 1-D arrays which can be any length.
+
+    Returns
+    -------
+    stacked : ndarray
+        The array formed by stacking the given arrays.
+
+    Examples
+    --------
+    >>> from mxnet import np,npx
+    >>> a = np.array((1,2,3))
+    >>> b = np.array((2,3,4))
+    >>> np.hstack((a,b))
+    array([1., 2., 3., 2., 3., 4.])
+
+    >>> a = np.array([[1],[2],[3]])
+    >>> b = np.array([[2],[3],[4]])
+    >>> np.hstack((a,b))
+    array([[1., 2.],
+           [2., 3.],
+           [3., 4.]])
+    """
+    return _npi.hstack(*arrays)
+
+
+@set_module('mxnet.ndarray.numpy')
 def add(x1, x2, out=None):
     """Add arguments element-wise.
 
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 83fcfc1..d20db96 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -48,7 +48,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', '
            'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'eye', 'sin', 'cos',
            'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
            'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin',
-           'argsort']
+           'argsort', 'hstack']
 
 
 # This function is copied from ndarray.py since pylint
@@ -1566,6 +1566,46 @@ def stack(arrays, axis=0, out=None):
 
 
 @set_module('mxnet.numpy')
+def hstack(arrays):
+    """
+    Stack arrays in sequence horizontally (column wise).
+    This is equivalent to concatenation along the second axis,
+    except for 1-D arrays where it concatenates along the first axis.
+    Rebuilds arrays divided by hsplit.
+    This function makes most sense for arrays with up to 3 dimensions.
+    For instance, for pixel-data with a height (first axis), width (second axis),
+    and r/g/b channels (third axis). The functions concatenate,
+    stack and block provide more general stacking and concatenation operations.
+
+    Parameters
+    ----------
+    arrays : sequence of ndarrays
+        The arrays must have the same shape along all but the second axis, except 1-D arrays which can be any length.
+
+    Returns
+    -------
+    stacked : ndarray
+        The array formed by stacking the given arrays.
+
+    Examples
+    --------
+    >>> from mxnet import np,npx
+    >>> a = np.array((1,2,3))
+    >>> b = np.array((2,3,4))
+    >>> np.hstack((a,b))
+    array([1., 2., 3., 2., 3., 4.])
+
+    >>> a = np.array([[1],[2],[3]])
+    >>> b = np.array([[2],[3],[4]])
+    >>> np.hstack((a,b))
+    array([[1., 2.],
+           [2., 3.],
+           [3., 4.]])
+    """
+    return _npi.hstack(*arrays)
+
+
+@set_module('mxnet.numpy')
 def arange(start, stop=None, step=1, dtype=None, ctx=None):
     """Return evenly spaced values within a given interval.
 
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 92e0563..987ed61 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -33,7 +33,7 @@ __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arang
            'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
            'expand_dims', 'tile', 'linspace', 'eye', 'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt',
            'abs', 'exp', 'arctan', 'sign', 'log', 'degrees', 'log2', 'rint', 'radians', 'mean',
-           'reciprocal', 'square', 'arcsin', 'argsort']
+           'reciprocal', 'square', 'arcsin', 'argsort', 'hstack']
 
 
 def _num_outputs(sym):
@@ -1191,6 +1191,46 @@ def stack(arrays, axis=0, out=None):
 
 
 @set_module('mxnet.symbol.numpy')
+def hstack(arrays):
+    """
+    Stack arrays in sequence horizontally (column wise).
+    This is equivalent to concatenation along the second axis,
+    except for 1-D arrays where it concatenates along the first axis.
+    Rebuilds arrays divided by hsplit.
+    This function makes most sense for arrays with up to 3 dimensions.
+    For instance, for pixel-data with a height (first axis), width (second axis),
+    and r/g/b channels (third axis). The functions concatenate,
+    stack and block provide more general stacking and concatenation operations.
+
+    Parameters
+    ----------
+    arrays : sequence of _Symbol
+        The arrays must have the same shape along all but the second axis, except 1-D arrays which can be any length.
+
+    Returns
+    -------
+    stacked : _Symbol
+        The array formed by stacking the given arrays.
+
+    Examples
+    --------
+    >>> from mxnet import np,npx
+    >>> a = np.array((1,2,3))
+    >>> b = np.array((2,3,4))
+    >>> np.hstack((a,b))
+    array([1., 2., 3., 2., 3., 4.])
+
+    >>> a = np.array([[1],[2],[3]])
+    >>> b = np.array([[2],[3],[4]])
+    >>> np.hstack((a,b))
+    array([[1., 2.],
+           [2., 3.],
+           [3., 4.]])
+    """
+    return _npi.hstack(*arrays)
+
+
+@set_module('mxnet.symbol.numpy')
 def concatenate(seq, axis=0, out=None):
     """Join a sequence of arrays along an existing axis.
 
diff --git a/src/operator/nn/concat-inl.h b/src/operator/nn/concat-inl.h
index 7a58ae6..7548f6b 100644
--- a/src/operator/nn/concat-inl.h
+++ b/src/operator/nn/concat-inl.h
@@ -142,6 +142,28 @@ void ConcatCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
 }
 
 template<typename xpu>
+void HStackCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                   const std::vector<TBlob>& inputs,
+                   const std::vector<OpReqType>& req,
+                   const std::vector<TBlob>& outputs) {
+  ConcatParam param = nnvm::get<ConcatParam>(attrs.parsed);
+  param.dim = inputs[0].shape_.ndim() > 1 ? 1 : 0;
+  std::vector<TBlob> modified_inputs(inputs.size());
+  for (int i = 0; i < param.num_args; ++i) {
+    if (inputs[i].shape_.ndim() == 0) {
+      modified_inputs[i] = inputs[i].reshape(TShape(1, 1));
+    } else {
+      modified_inputs[i] = inputs[i];
+    }
+  }
+  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kData0].type_flag_, DType, {
+    ConcatOp<xpu, DType> op;
+    op.Init(param);
+    op.Forward(ctx, modified_inputs, req, outputs);
+  });
+}
+
+template<typename xpu>
 void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                        const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
@@ -154,6 +176,28 @@ void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
   });
 }
 
+template<typename xpu>
+void HStackGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  ConcatParam param = nnvm::get<ConcatParam>(attrs.parsed);
+  param.dim = inputs[0].shape_.ndim() > 1 ? 1 : 0;
+  std::vector<TBlob> modified_outputs(outputs.size());
+  for (int i = 0; i < param.num_args; ++i) {
+    if (outputs[i].shape_.ndim() == 0) {
+      modified_outputs[i] = outputs[i].reshape(TShape(1, 1));
+    } else {
+      modified_outputs[i] = outputs[i];
+    }
+  }
+  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kOut].type_flag_, DType, {
+    ConcatOp<xpu, DType> op;
+    op.Init(param);
+    op.Backward(ctx, inputs[concat_enum::kOut], req, modified_outputs);
+  });
+}
+
 /*!
  * \brief concat CSRNDArray on the first dimension.
  */
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 1323447..be0c67b 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -55,6 +55,59 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
   return shape_is_known(ret);
 }
 
+bool HStackShape(const nnvm::NodeAttrs& attrs,
+                 mxnet::ShapeVector *in_shape,
+                 mxnet::ShapeVector *out_shape) {
+  using namespace mshadow;
+  ConcatParam param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
+  mxnet::TShape dshape;
+  dim_t size = 0;
+  bool has_unknown_dim_size = false;
+  int axis = (*in_shape)[0].ndim() > 1 ? 1 : 0;
+  param_.dim = axis;
+  for (int i = 0; i < param_.num_args; ++i) {
+    // a scalar tensor is treated as a one-dimensional vector
+    if ((*in_shape)[i].ndim() == 0) {
+      (*in_shape)[i] = mxnet::TShape(1, 1);
+    }
+    mxnet::TShape &tmp = (*in_shape)[i];
+    if (tmp.ndim() > 0) {
+      CheckAxis(axis, tmp.ndim());
+      if (!mxnet::dim_size_is_known(tmp, axis)) {
+        has_unknown_dim_size = true;
+      } else {
+        size += tmp[axis];
+      }
+      tmp[axis] = -1;
+      shape_assign(&dshape, tmp);
+    }
+  }
+
+  mxnet::TShape tmp = (*out_shape)[0];
+  if (tmp.ndim() > 0) {
+    axis = CheckAxis(param_.dim, tmp.ndim());
+    tmp[axis] = -1;
+    shape_assign(&dshape, tmp);
+  }
+
+  if (dshape.ndim() == -1) return false;
+  CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated";
+
+  for (int i = 0; i < param_.num_args; ++i) {
+    CHECK(shape_assign(&(*in_shape)[i], dshape))
+        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
+  }
+
+  if (!has_unknown_dim_size) {
+    dshape[axis] = size;
+  }
+  CHECK(shape_assign(&(*out_shape)[0], dshape))
+      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
+
+  return shape_is_known(dshape);
+}
+
 NNVM_REGISTER_OP(_np_transpose)
 .describe(R"code(Permute the dimensions of an array.
 
@@ -310,6 +363,44 @@ NNVM_REGISTER_OP(_backward_np_concat)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);
 
+NNVM_REGISTER_OP(_npi_hstack)
+.describe(R"code(Stack tensors horizontally (in second dimension))code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+  return params.num_args;
+})
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<ConcatParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+    std::vector<std::string> ret;
+    for (int i = 0; i < params.num_args; ++i) {
+      ret.push_back(std::string("data") + std::to_string(i));
+    }
+    return ret;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"out"};
+})
+.set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<nnvm::FInferType>("FInferType", ConcatType)
+.set_attr<mxnet::FInferShape>("FInferShape", HStackShape)
+.set_attr<FCompute>("FCompute<cpu>", HStackCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", NumpyConcatGrad{"_backward_np_hstack"})
+.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
+.add_arguments(ConcatParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_np_hstack)
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+  return params.num_args;
+})
+.set_attr_parser(ParamParser<ConcatParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", HStackGradCompute<cpu>);
+
 bool NumpySqueezeShape(const nnvm::NodeAttrs& attrs,
                        mxnet::ShapeVector *in_attrs,
                        mxnet::ShapeVector *out_attrs) {
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index 5354820..7f0c866 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -43,6 +43,12 @@ NNVM_REGISTER_OP(_npi_concatenate)
 NNVM_REGISTER_OP(_backward_np_concat)
 .set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);
 
+NNVM_REGISTER_OP(_npi_hstack)
+.set_attr<FCompute>("FCompute<gpu>", HStackCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_np_hstack)
+.set_attr<FCompute>("FCompute<gpu>", HStackGradCompute<gpu>);
+
 NNVM_REGISTER_OP(_np_squeeze)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
 
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 06f0994..6e3ca16 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -972,6 +972,71 @@ def test_np_concat():
 
 @with_seed()
 @npx.use_np_shape
+def test_np_hstack():
+    class TestHStack(HybridBlock):
+        def __init__(self):
+            super(TestHStack, self).__init__()
+        
+        def hybrid_forward(self, F, a, *args):
+            return F.np.hstack([a] + list(args))
+
+    def get_new_shape(shape):
+        if len(shape) == 0:
+            l = random.randint(0,3)
+            if l == 0:
+                return shape 
+            else:
+                return (l,)
+        shape_lst = list(shape)
+        axis = 1 if len(shape) > 1 else 0
+        shape_lst[axis] = random.randint(0, 5)
+        return tuple(shape_lst)
+
+    shapes = [
+        (),
+        (1,),
+        (2,1),
+        (2,2,4),
+        (2,0,0),
+        (0,1,3),
+        (2,0,3),
+        (2,3,4,5)
+    ]
+    for hybridize in [True, False]:
+        for shape in shapes:
+            test_hstack = TestHStack()
+            if hybridize:
+                test_hstack.hybridize()
+            # test symbolic forward
+            a = np.random.uniform(size=get_new_shape(shape))
+            a.attach_grad()
+            b = np.random.uniform(size=get_new_shape(shape))
+            b.attach_grad()
+            c = np.random.uniform(size=get_new_shape(shape))
+            c.attach_grad()
+            d = np.random.uniform(size=get_new_shape(shape))
+            d.attach_grad()
+            with mx.autograd.record():
+                mx_out = test_hstack(a, b, c, d)
+            np_out = _np.hstack((a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()))
+            assert mx_out.shape == np_out.shape
+            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+            # test symbolic backward
+            mx_out.backward()
+            assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
+            assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
+            assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
+            assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)
+
+            # test imperative
+            mx_out = np.hstack((a, b, c, d))
+            np_out = _np.hstack((a.asnumpy(),b.asnumpy(), c.asnumpy(), d.asnumpy()))
+            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+
+@with_seed()
+@npx.use_np_shape
 def test_np_swapaxes():
     config = [((0, 1, 2), 0, 1),
               ((0, 1, 2), -1, -2),