Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:00:50 UTC
[incubator-mxnet] 14/42: numpy concatenate (#15104)
This is an automated email from the ASF dual-hosted git repository.
haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 80cc4babdf6ceacabde0ab186de1729f6325d91c
Author: Hao Jin <hj...@gmail.com>
AuthorDate: Tue Jun 4 15:55:27 2019 -0700
numpy concatenate (#15104)
---
python/mxnet/ndarray/numpy/_op.py | 27 ++++++++++++-
python/mxnet/numpy/multiarray.py | 29 +++++++++++++-
python/mxnet/symbol/numpy/_symbol.py | 27 ++++++++++++-
src/operator/nn/concat.cc | 12 +++---
src/operator/numpy/np_matrix_op.cc | 58 +++++++++++++++++++++++++++
src/operator/numpy/np_matrix_op.cu | 4 ++
src/operator/quantization/quantized_concat.cc | 12 +++---
tests/python/unittest/test_numpy_op.py | 51 +++++++++++++++++++++++
8 files changed, 204 insertions(+), 16 deletions(-)
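In user-facing terms, this commit adds a NumPy-compatible `concatenate` to the mxnet.numpy, mxnet.ndarray.numpy, and mxnet.symbol.numpy front ends, backed by a new `_npi_concatenate` operator that reuses the existing Concat kernels. A minimal sketch of the imperative call, mirroring the checks in the new test at the bottom of the diff (a sketch only, assuming a build of this numpy branch; imports follow the test file):

    import mxnet as mx
    from mxnet import np  # the mxnet.numpy front end extended below

    # Classic ndarrays converted to np ndarrays, as the new test does:
    a = mx.nd.random.uniform(shape=(2, 3)).as_np_ndarray()
    b = mx.nd.random.uniform(shape=(4, 3)).as_np_ndarray()

    c = np.concatenate([a, b], axis=0)  # join along an existing axis
    assert c.shape == (6, 3)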
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 34218e3..6c83e1f 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -24,7 +24,7 @@ from ...util import _sanity_check_params, set_module
from ...context import current_context
from . import _internal as _npi
-__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax']
+__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax']
@set_module('mxnet.ndarray.numpy')
@@ -277,3 +277,28 @@ def argmax(a, axis=None, out=None):
        with the dimension along `axis` removed.
    """
    return _npi.argmax(a, axis=axis, keepdims=False, out=out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def concatenate(seq, axis=0, out=None):
+    """Join a sequence of arrays along an existing axis.
+
+    Parameters
+    ----------
+    seq : sequence of array_like
+        The arrays must have the same shape, except in the dimension
+        corresponding to `axis` (the first, by default).
+    axis : int, optional
+        The axis along which the arrays will be joined. If axis is None,
+        arrays are flattened before use. Default is 0.
+    out : ndarray, optional
+        If provided, the destination to place the result. The shape must be
+        correct, matching the shape that concatenate would have returned if
+        no out argument were specified.
+
+    Returns
+    -------
+    res : ndarray
+        The concatenated array.
+    """
+    return _npi.concatenate(*seq, dim=axis, out=out)
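Note that the wrapper unpacks the Python sequence into the operator's variadic inputs and forwards the user-facing `axis` keyword as the legacy Concat parameter `dim`. The shape contract in the docstring matches stock NumPy: inputs must agree on every dimension except `axis`, and the output's extent along `axis` is the sum of the inputs'. A quick check against reference NumPy:

    import numpy as np

    a = np.zeros((2, 3))
    b = np.zeros((4, 3))
    # Shapes agree everywhere except axis 0, so this is legal:
    assert np.concatenate([a, b], axis=0).shape == (6, 3)
    # A mismatch on a non-concatenation axis is an error, e.g. joining
    # (2, 3) with (2, 4) along axis 0 raises ValueError.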
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 212dfe3..6b3dcde 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -37,8 +37,8 @@ from ..context import current_context
from ..ndarray import numpy as _mx_nd_np
from ..ndarray.numpy import _internal as _npi
-__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange',
-           'argmax']
+__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack',
+           'concatenate', 'arange', 'argmax']
# This function is copied from ndarray.py since pylint
@@ -1486,3 +1486,28 @@ def argmax(a, axis=None, out=None):
        with the dimension along `axis` removed.
    """
    return _mx_nd_np.argmax(a, axis, out)
+
+
+@set_module('mxnet.numpy')
+def concatenate(seq, axis=0, out=None):
+    """Join a sequence of arrays along an existing axis.
+
+    Parameters
+    ----------
+    seq : sequence of array_like
+        The arrays must have the same shape, except in the dimension
+        corresponding to `axis` (the first, by default).
+    axis : int, optional
+        The axis along which the arrays will be joined. If axis is None,
+        arrays are flattened before use. Default is 0.
+    out : ndarray, optional
+        If provided, the destination to place the result. The shape must be
+        correct, matching the shape that concatenate would have returned if
+        no out argument were specified.
+
+    Returns
+    -------
+    res : ndarray
+        The concatenated array.
+    """
+    return _mx_nd_np.concatenate(seq, axis=axis, out=out)
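This copy is the public mxnet.numpy entry point; it simply delegates to the ndarray-level function above, which in turn invokes the internal `_npi.concatenate`, so the three identical docstrings in this commit describe a single operator. A hypothetical equivalence check (the `nd_np` import alias is inferred from the file layout in this diff, not confirmed by it):

    import mxnet as mx
    from mxnet import np
    from mxnet.ndarray import numpy as nd_np  # hypothetical import alias

    a = mx.nd.ones((2, 2)).as_np_ndarray()
    b = mx.nd.zeros((2, 2)).as_np_ndarray()
    r1 = np.concatenate([a, b], axis=1)     # mxnet.numpy front end
    r2 = nd_np.concatenate([a, b], axis=1)  # ndarray-level function above
    assert (r1.asnumpy() == r2.asnumpy()).all()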
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index b2d8a5b..7a55547 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -29,7 +29,7 @@ from ..symbol import Symbol
from .._internal import _set_np_symbol_class
from . import _internal as _npi
-__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax']
+__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax']
@set_module('mxnet.symbol.numpy')
@@ -1061,6 +1061,31 @@ def stack(arrays, axis=0, out=None):
@set_module('mxnet.symbol.numpy')
+def concatenate(seq, axis=0, out=None):
+    """Join a sequence of arrays along an existing axis.
+
+    Parameters
+    ----------
+    seq : sequence of array_like
+        The arrays must have the same shape, except in the dimension
+        corresponding to `axis` (the first, by default).
+    axis : int, optional
+        The axis along which the arrays will be joined. If axis is None,
+        arrays are flattened before use. Default is 0.
+    out : ndarray, optional
+        If provided, the destination to place the result. The shape must be
+        correct, matching the shape that concatenate would have returned if
+        no out argument were specified.
+
+    Returns
+    -------
+    res : ndarray
+        The concatenated array.
+    """
+    return _npi.concatenate(*seq, dim=axis, out=out)
+
+
+@set_module('mxnet.symbol.numpy')
def arange(start, stop=None, step=1, dtype=None, ctx=None):
    """Return evenly spaced values within a given interval.
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 8fb2298..cda9c9a 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -32,9 +32,9 @@
namespace mxnet {
namespace op {
-static bool ConcatShape(const nnvm::NodeAttrs& attrs,
-                        mxnet::ShapeVector *in_shape,
-                        mxnet::ShapeVector *out_shape) {
+bool ConcatShape(const nnvm::NodeAttrs& attrs,
+                 mxnet::ShapeVector *in_shape,
+                 mxnet::ShapeVector *out_shape) {
  using namespace mshadow;
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
@@ -138,9 +138,9 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
  return shape_is_known(dshape);
}
-static bool ConcatType(const nnvm::NodeAttrs& attrs,
-                       std::vector<int> *in_type,
-                       std::vector<int> *out_type) {
+bool ConcatType(const nnvm::NodeAttrs& attrs,
+                std::vector<int> *in_type,
+                std::vector<int> *out_type) {
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  int dtype = -1;
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index db479a0..80d70e5 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -24,6 +24,7 @@
*/
#include "./np_matrix_op-inl.h"
+#include "../nn/concat-inl.h"
namespace mxnet {
namespace op {
@@ -252,5 +253,62 @@ Examples::
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack")
.add_arguments(StackParam::__FIELDS__());
+bool ConcatShape(const nnvm::NodeAttrs& attrs,
+                 mxnet::ShapeVector *in_shape,
+                 mxnet::ShapeVector *out_shape);
+
+bool ConcatType(const nnvm::NodeAttrs& attrs,
+                std::vector<int> *in_type,
+                std::vector<int> *out_type);
+
+struct NumpyConcatGrad {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
+                                          const std::vector<nnvm::NodeEntry>& ograds) const {
+    CHECK_EQ(ograds.size(), 1);
+    std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
+    return MakeGradNode(op_name, n, heads, n->attrs.dict);
+  }
+};
+
+
+NNVM_REGISTER_OP(_npi_concatenate)
+.describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+  return params.num_args;
+})
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<ConcatParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+    std::vector<std::string> ret;
+    for (int i = 0; i < params.num_args; ++i) {
+      ret.push_back(std::string("data") + std::to_string(i));
+    }
+    return ret;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"out"};
+})
+.set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<nnvm::FInferType>("FInferType", ConcatType)
+.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
+.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", NumpyConcatGrad{"_backward_np_concat"})
+.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
+.add_arguments(ConcatParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_np_concat)
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+  return params.num_args;
+})
+.set_attr_parser(ParamParser<ConcatParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);
+
} // namespace op
} // namespace mxnet
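Two notes on this file. First, `ConcatShape` and `ConcatType` lost their `static` qualifier in concat.cc above precisely so they can be reused here through forward declarations instead of duplicating the shape/type inference. Second, `NumpyConcatGrad` passes the single output gradient, together with the node's attribute dict (so `num_args` and `dim` survive), to `_backward_np_concat`, whose `ConcatGradCompute` slices the gradient back to the inputs. A sketch of the resulting behavior, condensed from the new test:

    import mxnet as mx
    from mxnet import np

    a = mx.nd.random.uniform(shape=(2, 2)).as_np_ndarray()
    b = mx.nd.random.uniform(shape=(3, 2)).as_np_ndarray()
    a.attach_grad()
    b.attach_grad()
    with mx.autograd.record():
        y = np.concatenate([a, b], axis=0)
    y.backward()  # unit head gradient, sliced back to each input
    assert (a.grad.asnumpy() == 1).all()
    assert (b.grad.asnumpy() == 1).all()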
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index 615dd26..5980e81 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -23,6 +23,7 @@
* \brief GPU Implementation of numpy matrix operations
*/
#include "./np_matrix_op-inl.h"
+#include "../nn/concat-inl.h"
namespace mxnet {
namespace op {
@@ -36,5 +37,8 @@ NNVM_REGISTER_OP(_np_reshape)
NNVM_REGISTER_OP(_npi_stack)
.set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);
+NNVM_REGISTER_OP(_npi_concatenate)
+.set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>);
+
} // namespace op
} // namespace mxnet
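The GPU side only needs to bind the already templated kernel; shape and type inference are shared with the CPU registration. A hypothetical smoke test, assuming a CUDA build with at least one device:

    import mxnet as mx
    from mxnet import np

    ctx = mx.gpu(0)
    a = mx.nd.ones((2, 3), ctx=ctx).as_np_ndarray()
    b = mx.nd.ones((2, 3), ctx=ctx).as_np_ndarray()
    out = np.concatenate([a, b], axis=1)  # dispatches to ConcatCompute<gpu>
    assert out.shape == (2, 6)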
diff --git a/src/operator/quantization/quantized_concat.cc b/src/operator/quantization/quantized_concat.cc
index f7a810b..5835701 100644
--- a/src/operator/quantization/quantized_concat.cc
+++ b/src/operator/quantization/quantized_concat.cc
@@ -28,8 +28,8 @@
namespace mxnet {
namespace op {
-static bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_shape,
-                        mxnet::ShapeVector* out_shape) {
+static bool QuantizedConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_shape,
+                                 mxnet::ShapeVector* out_shape) {
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args * 3));
  CHECK_EQ(out_shape->size(), 3U);
@@ -74,8 +74,8 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_sha
  return shape_is_known(dshape);
}
-static bool ConcatType(const nnvm::NodeAttrs& attrs, std::vector<int>* in_type,
-                       std::vector<int>* out_type) {
+static bool QuantizedConcatType(const nnvm::NodeAttrs& attrs, std::vector<int>* in_type,
+                                std::vector<int>* out_type) {
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_type->size(), static_cast<size_t>(param_.num_args * 3));
  CHECK_EQ(out_type->size(), 3U);
@@ -130,8 +130,8 @@ If any input holds int8, then the output will be int8. Otherwise output will be
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
-.set_attr<nnvm::FInferType>("FInferType", ConcatType)
-.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
+.set_attr<nnvm::FInferType>("FInferType", QuantizedConcatType)
+.set_attr<mxnet::FInferShape>("FInferShape", QuantizedConcatShape)
.set_attr<std::string>("key_var_num_args", "num_args")
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());
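The renames here are the flip side of the concat.cc change: with `ConcatShape` and `ConcatType` now exported with external linkage, keeping identically named file-local helpers in this translation unit invites confusion (and shadowing), so the quantized variants get unambiguous names while their behavior is left untouched.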
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 9804aea..d00573e 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -633,6 +633,57 @@ def test_np_linalg_norm():
assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4)
+@with_seed()
+@npx.use_np_shape
+def test_np_concat():
+    class TestConcat(HybridBlock):
+        def __init__(self, axis=None):
+            super(TestConcat, self).__init__()
+            self._axis = axis
+
+        def hybrid_forward(self, F, a, *args):
+            return F.np.concatenate([a] + list(args), axis=self._axis)
+
+    def get_new_shape(shape, axis):
+        shape_lst = list(shape)
+        shape_lst[axis] = random.randint(0, 3)
+        return tuple(shape_lst)
+
+    for shape in [(0, 0), (2, 3)]:
+        for hybridize in [True, False]:
+            for axis in range(2):
+                # test gluon
+                test_concat = TestConcat(axis=axis)
+                if hybridize:
+                    test_concat.hybridize()
+
+                a = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
+                a.attach_grad()
+                b = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
+                b.attach_grad()
+                c = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
+                c.attach_grad()
+                d = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
+                d.attach_grad()
+                expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
+                with mx.autograd.record():
+                    y = test_concat(a, b, c, d)
+                assert y.shape == expected_ret.shape
+                assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
+
+                y.backward()
+
+                assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
+                assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
+                assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
+                assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)
+
+                # test imperative
+                mx_out = np.concatenate([a, b, c, d], axis=axis)
+                np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+
if __name__ == '__main__':
import nose
nose.runmodule()