Posted to commits@mxnet.apache.org by re...@apache.org on 2019/09/21 16:12:35 UTC
[incubator-mxnet] branch master updated: np compatible vstack (#15850)
This is an automated email from the ASF dual-hosted git repository.
reminisce pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 11f73ed np compatible vstack (#15850)
11f73ed is described below
commit 11f73ed59eb481b76a8ffca683293eec2cf41889
Author: Haozheng Fan <fh...@gmail.com>
AuthorDate: Sun Sep 22 00:11:28 2019 +0800
np compatible vstack (#15850)
---
python/mxnet/ndarray/numpy/_op.py | 53 +++++++++++-
python/mxnet/numpy/multiarray.py | 49 ++++++++++-
python/mxnet/symbol/numpy/_symbol.py | 35 +++++++-
src/operator/numpy/np_matrix_op-inl.h | 80 ++++++++++++++++++
src/operator/numpy/np_matrix_op.cc | 143 +++++++++++++++++++++++++++++++++
src/operator/numpy/np_matrix_op.cu | 7 ++
tests/python/unittest/test_numpy_op.py | 56 +++++++++++++
7 files changed, 419 insertions(+), 4 deletions(-)
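
For orientation before the diff: the new function mirrors numpy.vstack on the NumPy-compatible frontend. A minimal usage sketch (assumes a build of MXNet that includes this commit):

    from mxnet import np, npx
    npx.set_np()  # enable NumPy-compatible semantics

    a = np.array([1, 2, 3])        # shape (3,)
    b = np.array([2, 3, 4])        # shape (3,)
    # 1-D inputs of shape (N,) are treated as rows of shape (1, N)
    out = np.vstack((a, b))        # shape (2, 3)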
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 163d908..bb63e94 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -32,7 +32,7 @@ __all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mo
            'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
            'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
            'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
            'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel']
@@ -1991,6 +1991,57 @@ def stack(arrays, axis=0, out=None):
@set_module('mxnet.ndarray.numpy')
+def vstack(arrays, out=None):
+    r"""Stack arrays in sequence vertically (row wise).
+
+    This is equivalent to concatenation along the first axis after 1-D arrays
+    of shape `(N,)` have been reshaped to `(1,N)`. Rebuilds arrays divided by
+    `vsplit`.
+
+    This function makes the most sense for arrays with up to 3 dimensions,
+    for instance, pixel data with a height (first axis), width (second axis),
+    and r/g/b channels (third axis). The functions `concatenate` and `stack`
+    provide more general stacking and concatenation operations.
+
+    Parameters
+    ----------
+    arrays : sequence of ndarrays
+        The arrays must have the same shape along all but the first axis.
+        1-D arrays must have the same length.
+
+    Returns
+    -------
+    stacked : ndarray
+        The array formed by stacking the given arrays; it will be at least 2-D.
+
+    Examples
+    --------
+    >>> a = np.array([1, 2, 3])
+    >>> b = np.array([2, 3, 4])
+    >>> np.vstack((a, b))
+    array([[1., 2., 3.],
+           [2., 3., 4.]])
+
+    >>> a = np.array([[1], [2], [3]])
+    >>> b = np.array([[2], [3], [4]])
+    >>> np.vstack((a, b))
+    array([[1.],
+           [2.],
+           [3.],
+           [2.],
+           [3.],
+           [4.]])
+    """
+    def get_list(arrays):
+        if not hasattr(arrays, '__getitem__') and hasattr(arrays, '__iter__'):
+            raise ValueError("expected iterable for arrays but got {}".format(type(arrays)))
+        return [arr for arr in arrays]
+
+    arrays = get_list(arrays)
+    return _npi.vstack(*arrays)
+
+
+@set_module('mxnet.ndarray.numpy')
def maximum(x1, x2, out=None):
"""Returns element-wise maximum of the input arrays with broadcasting.
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 52dc9fb..f738d63 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -52,8 +52,8 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtrac
            'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
            'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
            'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
-           'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
-           'ravel']
+           'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
+           'copysign', 'ravel']
# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -3561,6 +3561,51 @@ def stack(arrays, axis=0, out=None):
@set_module('mxnet.numpy')
+def vstack(arrays, out=None):
+    r"""Stack arrays in sequence vertically (row wise).
+
+    This is equivalent to concatenation along the first axis after 1-D arrays
+    of shape `(N,)` have been reshaped to `(1,N)`. Rebuilds arrays divided by
+    `vsplit`.
+
+    This function makes the most sense for arrays with up to 3 dimensions,
+    for instance, pixel data with a height (first axis), width (second axis),
+    and r/g/b channels (third axis). The functions `concatenate` and `stack`
+    provide more general stacking and concatenation operations.
+
+    Parameters
+    ----------
+    arrays : sequence of ndarrays
+        The arrays must have the same shape along all but the first axis.
+        1-D arrays must have the same length.
+
+    Returns
+    -------
+    stacked : ndarray
+        The array formed by stacking the given arrays; it will be at least 2-D.
+
+    Examples
+    --------
+    >>> a = np.array([1, 2, 3])
+    >>> b = np.array([2, 3, 4])
+    >>> np.vstack((a, b))
+    array([[1., 2., 3.],
+           [2., 3., 4.]])
+
+    >>> a = np.array([[1], [2], [3]])
+    >>> b = np.array([[2], [3], [4]])
+    >>> np.vstack((a, b))
+    array([[1.],
+           [2.],
+           [3.],
+           [2.],
+           [3.],
+           [4.]])
+    """
+    return _mx_nd_np.vstack(arrays)
+
+
+@set_module('mxnet.numpy')
def maximum(x1, x2, out=None):
"""Returns element-wise maximum of the input arrays with broadcasting.
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 962fee2..03d3d0d 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -34,7 +34,7 @@ __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'rem
            'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
            'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
            'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
            'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel']
@@ -2397,6 +2397,39 @@ def stack(arrays, axis=0, out=None):
@set_module('mxnet.symbol.numpy')
+def vstack(arrays, out=None):
+    r"""Stack arrays in sequence vertically (row wise).
+
+    This is equivalent to concatenation along the first axis after 1-D arrays
+    of shape `(N,)` have been reshaped to `(1,N)`. Rebuilds arrays divided by
+    `vsplit`.
+
+    This function makes the most sense for arrays with up to 3 dimensions,
+    for instance, pixel data with a height (first axis), width (second axis),
+    and r/g/b channels (third axis). The functions `concatenate` and `stack`
+    provide more general stacking and concatenation operations.
+
+    Parameters
+    ----------
+    arrays : sequence of _Symbol
+        The arrays must have the same shape along all but the first axis.
+        1-D arrays must have the same length.
+
+    Returns
+    -------
+    stacked : _Symbol
+        The array formed by stacking the given arrays; it will be at least 2-D.
+    """
+    def get_list(arrays):
+        if not hasattr(arrays, '__getitem__') and hasattr(arrays, '__iter__'):
+            raise ValueError("expected iterable for arrays but got {}".format(type(arrays)))
+        return [arr for arr in arrays]
+
+    arrays = get_list(arrays)
+    return _npi.vstack(*arrays)
+
+
+@set_module('mxnet.symbol.numpy')
def maximum(x1, x2, out=None):
return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)
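
After hybridization, `F.np.vstack` inside a Gluon HybridBlock resolves to this symbol version. A sketch of that pattern (illustrative block name; run with numpy semantics enabled, e.g. under the @use_np decorator as in the test at the end of this diff):

    from mxnet.gluon import HybridBlock

    class RowStack(HybridBlock):
        # F is mxnet.ndarray before hybridize() and mxnet.symbol after
        def hybrid_forward(self, F, a, b):
            return F.np.vstack((a, b))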
diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h
index 6d3d9ea..fe2fadd 100644
--- a/src/operator/numpy/np_matrix_op-inl.h
+++ b/src/operator/numpy/np_matrix_op-inl.h
@@ -41,6 +41,14 @@ struct NumpyTransposeParam : public dmlc::Parameter<NumpyTransposeParam> {
}
};
+struct NumpyVstackParam : public dmlc::Parameter<NumpyVstackParam> {
+  int num_args;
+  DMLC_DECLARE_PARAMETER(NumpyVstackParam) {
+    DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
+      .describe("Number of inputs to be vstacked.");
+  }
+};
+
template<typename xpu>
void NumpyTranspose(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -60,6 +68,78 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
}
}
+template<typename xpu>
+void NumpyVstackForward(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow_op;
+
+  const NumpyVstackParam& param = nnvm::get<NumpyVstackParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), param.num_args);
+  CHECK_EQ(outputs.size(), 1);
+  CHECK_EQ(req.size(), 1);
+
+  // reshape if necessary
+  std::vector<TBlob> data(param.num_args);
+  for (int i = 0; i < param.num_args; i++) {
+    if (inputs[i].shape_.ndim() == 0 || inputs[i].shape_.ndim() == 1) {
+      TShape shape = Shape2(1, inputs[i].shape_.Size());
+      data[i] = inputs[i].reshape(shape);
+    } else {
+      data[i] = inputs[i];
+    }
+  }
+
+  // initialize ConcatOp
+  ConcatParam cparam;
+  cparam.num_args = param.num_args;
+  cparam.dim = 0;
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    ConcatOp<xpu, DType> op;
+    op.Init(cparam);
+    op.Forward(ctx, data, req, outputs);
+  });
+}
+
+template<typename xpu>
+void NumpyVstackBackward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<TBlob>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow_op;
+
+  const NumpyVstackParam& param = nnvm::get<NumpyVstackParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 1);
+  CHECK_EQ(outputs.size(), param.num_args);
+  CHECK_EQ(req.size(), param.num_args);
+
+  // reshape if necessary
+  std::vector<TBlob> data(param.num_args);
+  for (int i = 0; i < param.num_args; i++) {
+    if (outputs[i].shape_.ndim() == 0 || outputs[i].shape_.ndim() == 1) {
+      TShape shape = Shape2(1, outputs[i].shape_.Size());
+      data[i] = outputs[i].reshape(shape);
+    } else {
+      data[i] = outputs[i];
+    }
+  }
+
+  // initialize ConcatOp
+  ConcatParam cparam;
+  cparam.num_args = param.num_args;
+  cparam.dim = 0;
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    ConcatOp<xpu, DType> op;
+    op.Init(cparam);
+    op.Backward(ctx, inputs[0], req, data);
+  });
+}
+
} // namespace op
} // namespace mxnet
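
The two kernels above reduce vstack to the existing Concat operator: each 0-D or 1-D input is first viewed as a (1, N) row, then everything is concatenated along axis 0, and the backward pass runs Concat's backward over the same reshaped views. A plain-NumPy reference sketch of the forward reduction (a model of the computation, not the kernel itself):

    import numpy as np

    def vstack_reference(arrays):
        # mirror of NumpyVstackForward: promote 0-D/1-D inputs to a single
        # row, then concatenate along the first axis
        data = [a.reshape(1, a.size) if a.ndim <= 1 else a for a in arrays]
        return np.concatenate(data, axis=0)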
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index f88dd56..5509f34 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -25,6 +25,7 @@
#include <vector>
#include "./np_matrix_op-inl.h"
+#include "../nn/concat-inl.h"
namespace mxnet {
namespace op {
@@ -346,5 +347,147 @@ Examples::
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack")
.add_arguments(StackParam::__FIELDS__());
+bool NumpyVstackType(const nnvm::NodeAttrs& attrs,
+                     std::vector<int> *in_type,
+                     std::vector<int> *out_type) {
+  const NumpyVstackParam& param = nnvm::get<NumpyVstackParam>(attrs.parsed);
+  CHECK_EQ(in_type->size(), param.num_args);
+  CHECK_EQ(out_type->size(), 1);
+  int dtype = -1;
+  for (int i = 0; i < param.num_args; i++) {
+    if (dtype == -1) {
+      dtype = in_type->at(i);
+    }
+  }
+  if (dtype == -1) {
+    dtype = out_type->at(0);
+  }
+  for (int i = 0; i < param.num_args; i++) {
+    TYPE_ASSIGN_CHECK(*in_type, i, dtype);
+  }
+  TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
+  return dtype != -1;
+}
+
+bool NumpyVstackShape(const nnvm::NodeAttrs& attrs,
+                      mxnet::ShapeVector* in_attrs,
+                      mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(out_attrs->size(), 1U);
+  const NumpyVstackParam& param = nnvm::get<NumpyVstackParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), param.num_args);
+  std::vector<mxnet::TShape> in_attrs_tmp(param.num_args);
+  TShape dshape;
+  for (int i = 0; i < param.num_args; i++) {
+    if ((*in_attrs)[i].ndim() == 0) {
+      in_attrs_tmp[i] = TShape(2, 1);
+    } else if ((*in_attrs)[i].ndim() == 1) {
+      in_attrs_tmp[i] = TShape(2, 1);
+      in_attrs_tmp[i][1] = (*in_attrs)[i][0];
+    } else {
+      in_attrs_tmp[i] = (*in_attrs)[i];
+    }
+    TShape tmp(in_attrs_tmp[i].ndim(), -1);
+    shape_assign(&dshape, tmp);
+  }
+  TShape tmp((*out_attrs)[0].ndim(), -1);
+  shape_assign(&dshape, tmp);
+  for (int i = 0; i < param.num_args; i++) {
+    SHAPE_ASSIGN_CHECK(in_attrs_tmp, i, dshape);
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  if (dshape.ndim() == -1) {
+    return false;
+  }
+  int cnt = 0, sum = 0, pos = -1;
+  for (int i = 0; i < param.num_args; i++) {
+    TShape tmp = in_attrs_tmp[i];
+    if (!dim_size_is_known(tmp, 0)) {
+      cnt++;
+      pos = i;
+    } else {
+      sum += tmp[0];
+    }
+    tmp[0] = -1;
+    shape_assign(&dshape, tmp);
+  }
+  tmp = out_attrs->at(0);
+  if (!dim_size_is_known(tmp, 0)) {
+    cnt++;
+    pos = -1;
+  } else {
+    sum += tmp[0];
+  }
+  tmp[0] = -1;
+  shape_assign(&dshape, tmp);
+  for (int i = 0; i < param.num_args; i++) {
+    SHAPE_ASSIGN_CHECK(in_attrs_tmp, i, dshape);
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  dshape[0] = 0;
+  if (!shape_is_known(dshape)) {
+    return false;
+  }
+
+  dshape[0] = sum;
+  if (cnt == 0) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  } else if (cnt == 1) {
+    if (pos >= 0) {
+      in_attrs_tmp[pos][0] = out_attrs->at(0)[0] - sum;
+    } else {
+      out_attrs->at(0)[0] = sum;
+    }
+  } else {
+    return false;
+  }
+
+  for (int i = 0; i < param.num_args; i++) {
+    if (in_attrs->at(i).ndim() == 1) {
+      in_attrs->at(i)[0] = in_attrs_tmp[i][1];
+    } else if (in_attrs->at(i).ndim() >= 2) {
+      in_attrs->at(i) = in_attrs_tmp[i];
+    }
+  }
+
+  return true;
+}
+
+DMLC_REGISTER_PARAMETER(NumpyVstackParam);
+
+NNVM_REGISTER_OP(_npi_vstack)
+.describe(R"code()code" ADD_FILELINE)
+.set_attr_parser(ParamParser<NumpyVstackParam>)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+    const NumpyVstackParam& param = dmlc::get<NumpyVstackParam>(attrs.parsed);
+    return static_cast<uint32_t>(param.num_args);
+  })
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const nnvm::NodeAttrs& attrs) {
+    int num_args = dmlc::get<NumpyVstackParam>(attrs.parsed).num_args;
+    std::vector<std::string> ret;
+    for (int i = 0; i < num_args; i++) {
+      ret.push_back(std::string("arg") + std::to_string(i));
+    }
+    return ret;
+  })
+.set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyVstackShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyVstackType)
+.set_attr<FCompute>("FCompute<cpu>", NumpyVstackForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_np_vstack"})
+.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to vstack")
+.add_arguments(NumpyVstackParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_np_vstack)
+.set_attr_parser(ParamParser<NumpyVstackParam>)
+.set_num_inputs(1)
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+    const NumpyVstackParam& param = dmlc::get<NumpyVstackParam>(attrs.parsed);
+    return static_cast<uint32_t>(param.num_args);
+  })
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", NumpyVstackBackward<cpu>);
+
} // namespace op
} // namespace mxnet
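
NumpyVstackShape above follows the usual partial shape inference pattern: promote every input shape to at least 2-D, unify the trailing dimensions across inputs and output, then treat the first dimension as a sum, so a single unknown first dimension can be solved from the others. A simplified Python sketch of the first-dimension rule (hypothetical helper; -1 marks an unknown dimension):

    def infer_dim0(in_dim0s, out_dim0):
        # in_dim0s: first dims of the promoted input shapes, -1 if unknown
        known_sum = sum(d for d in in_dim0s if d != -1)
        unknown = in_dim0s.count(-1)
        if unknown == 0:
            return in_dim0s, known_sum           # output dim0 is the plain sum
        if unknown == 1 and out_dim0 != -1:
            fill = out_dim0 - known_sum          # solve the single unknown input
            return [fill if d == -1 else d for d in in_dim0s], out_dim0
        return None                              # not enough information yet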
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index 4ba527d..b017ad6 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -24,6 +24,7 @@
*/
#include "./np_matrix_op-inl.h"
+#include "../nn/concat-inl.h"
namespace mxnet {
namespace op {
@@ -46,5 +47,11 @@ NNVM_REGISTER_OP(_backward_np_concat)
NNVM_REGISTER_OP(_npi_stack)
.set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);
+NNVM_REGISTER_OP(_npi_vstack)
+.set_attr<FCompute>("FCompute<gpu>", NumpyVstackForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_np_vstack)
+.set_attr<FCompute>("FCompute<gpu>", NumpyVstackBackward<gpu>);
+
} // namespace op
} // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 2ce0fa4..87c7ab2 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2075,6 +2075,62 @@ def test_np_svd():
assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
+@with_seed()
+@use_np
+def test_np_vstack():
+    class TestVstack(HybridBlock):
+        def __init__(self):
+            super(TestVstack, self).__init__()
+
+        def hybrid_forward(self, F, a, *args):
+            return F.np.vstack([a] + list(args))
+
+    def g(data):
+        return _np.ones_like(data)
+
+    configs = [
+        ((), (), ()),
+        ((2), (2), (2)),
+        ((0), (0), (0)),
+        ((2, 2), (3, 2), (0, 2)),
+        ((2, 3), (1, 3), (4, 3)),
+        ((2, 2, 2), (3, 2, 2), (1, 2, 2)),
+        ((0, 1, 1), (4, 1, 1), (5, 1, 1)),
+        ((2), (0, 2), (2, 2))
+    ]
+    types = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']
+    for config in configs:
+        for hybridize in [True, False]:
+            for dtype in types:
+                test_vstack = TestVstack()
+                if hybridize:
+                    test_vstack.hybridize()
+                rtol = 1e-3
+                atol = 1e-5
+                v = []
+                v_np = []
+                for i in range(3):
+                    v_np.append(_np.array(_np.random.uniform(-10.0, 10.0, config[i]), dtype=dtype))
+                    v.append(mx.nd.array(v_np[i]).as_np_ndarray())
+                    v[i].attach_grad()
+                expected_np = _np.vstack(v_np)
+                with mx.autograd.record():
+                    mx_out = test_vstack(*v)
+                assert mx_out.shape == expected_np.shape
+                assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+
+                # Test gradient
+                mx_out.backward()
+                for i in range(3):
+                    expected_grad = g(v_np[i])
+                    assert_almost_equal(v[i].grad.asnumpy(), expected_grad, rtol=rtol, atol=atol)
+
+                # Test imperative once again
+                mx_out = np.vstack(v)
+                expected_np = _np.vstack(v_np)
+                assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+
+
if __name__ == '__main__':
import nose
nose.runmodule()
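
A note on the gradient check in the test above: vstack is a pure rearrangement, so each input element feeds exactly one output slot and, with the default ones head-gradient, every input gradient is ones_like(input), which is what g(data) encodes. A minimal imperative sketch (assumes this commit, with numpy semantics enabled as in the test):

    import mxnet as mx
    from mxnet import np, npx
    npx.set_np()

    a = np.array([1., 2.])
    b = np.array([3., 4.])
    a.attach_grad()
    b.attach_grad()
    with mx.autograd.record():
        out = np.vstack([a, b])    # shape (2, 2)
    out.backward()                 # head gradient defaults to ones
    # a.grad -> [1., 1.], b.grad -> [1., 1.]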