You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/10/13 04:04:32 UTC
[incubator-mxnet] 01/03: [Numpy] Numpy compatible dstack (#15871)
This is an automated email from the ASF dual-hosted git repository.
haoj pushed a commit to branch numpy_pr_merge
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 6913fe29b6bb6eba6b07a8a6843f355fe0f91c14
Author: Mike <ma...@connect.hku.hk>
AuthorDate: Fri Oct 11 17:51:00 2019 -0400
[Numpy] Numpy compatible dstack (#15871)
* Add dstack that passes CPU tests
Register dstack on GPU
Minor comment fix
Minor syntax fix
Syntax fix according to comments
header fix
* Fix sanity
---
python/mxnet/ndarray/numpy/_op.py | 46 ++++++++++++++-
python/mxnet/numpy/multiarray.py | 48 +++++++++++++++-
python/mxnet/symbol/numpy/_symbol.py | 33 ++++++++++-
src/operator/nn/concat-inl.h | 62 ++++++++++++++++++++
src/operator/numpy/np_matrix_op.cc | 100 ++++++++++++++++++++++++++++++++-
src/operator/numpy/np_matrix_op.cu | 7 +++
tests/python/unittest/test_numpy_op.py | 61 ++++++++++++++++++++
7 files changed, 350 insertions(+), 7 deletions(-)
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index aea8b19..2846d2b 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -32,8 +32,8 @@ __all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mo
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
- 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
- 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
+ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
+ 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
@@ -2468,6 +2468,48 @@ def vstack(arrays, out=None):
@set_module('mxnet.ndarray.numpy')
+def dstack(arrays):
+ """
+ Stack arrays in sequence depth wise (along third axis).
+ This is equivalent to concatenation along the third axis after 2-D arrays
+ of shape `(M,N)` have been reshaped to `(M,N,1)` and 1-D arrays of shape
+ `(N,)` have been reshaped to `(1,N,1)`. Rebuilds arrays divided by
+ `dsplit`.
+ This function makes most sense for arrays with up to 3 dimensions. For
+ instance, for pixel-data with a height (first axis), width (second axis),
+ and r/g/b channels (third axis). The functions `concatenate`, `stack` and
+ `block` provide more general stacking and concatenation operations.
+
+ Parameters
+ ----------
+ tup : sequence of arrays
+ The arrays must have the same shape along all but the third axis.
+ 1-D or 2-D arrays must have the same shape.
+
+ Returns
+ -------
+ stacked : ndarray
+ The array formed by stacking the given arrays, will be at least 3-D.
+
+ Examples
+ --------
+ >>> a = np.array((1,2,3))
+ >>> b = np.array((2,3,4))
+ >>> np.dstack((a,b))
+ array([[[1, 2],
+ [2, 3],
+ [3, 4]]])
+ >>> a = np.array([[1],[2],[3]])
+ >>> b = np.array([[2],[3],[4]])
+ >>> np.dstack((a,b))
+ array([[[1, 2]],
+ [[2, 3]],
+ [[3, 4]]])
+ """
+ return _npi.dstack(*arrays)
+
+
+@set_module('mxnet.ndarray.numpy')
def maximum(x1, x2, out=None):
"""Returns element-wise maximum of the input arrays with broadcasting.
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index d3ae4d1..00a7709 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -51,8 +51,8 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtrac
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'histogram', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
- 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
- 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
+ 'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var',
+ 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
@@ -4011,6 +4011,50 @@ def vstack(arrays, out=None):
@set_module('mxnet.numpy')
+def dstack(arrays):
+ """
+ Stack arrays in sequence depth wise (along third axis).
+
+ This is equivalent to concatenation along the third axis after 2-D arrays
+ of shape `(M,N)` have been reshaped to `(M,N,1)` and 1-D arrays of shape
+ `(N,)` have been reshaped to `(1,N,1)`. Rebuilds arrays divided by
+ `dsplit`.
+
+ This function makes most sense for arrays with up to 3 dimensions. For
+ instance, for pixel-data with a height (first axis), width (second axis),
+ and r/g/b channels (third axis). The functions `concatenate`, `stack` and
+ `block` provide more general stacking and concatenation operations.
+
+ Parameters
+ ----------
+ tup : sequence of arrays
+ The arrays must have the same shape along all but the third axis.
+ 1-D or 2-D arrays must have the same shape.
+
+ Returns
+ -------
+ stacked : ndarray
+ The array formed by stacking the given arrays, will be at least 3-D.
+
+ Examples
+ --------
+ >>> a = np.array((1,2,3))
+ >>> b = np.array((2,3,4))
+ >>> np.dstack((a,b))
+ array([[[1, 2],
+ [2, 3],
+ [3, 4]]])
+ >>> a = np.array([[1],[2],[3]])
+ >>> b = np.array([[2],[3],[4]])
+ >>> np.dstack((a,b))
+ array([[[1, 2]],
+ [[2, 3]],
+ [[3, 4]]])
+ """
+ return _npi.dstack(*arrays)
+
+
+@set_module('mxnet.numpy')
def maximum(x1, x2, out=None):
"""Returns element-wise maximum of the input arrays with broadcasting.
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 9a90942..de11cfb 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -34,8 +34,8 @@ __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'rem
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
- 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
- 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
+ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
+ 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
@@ -2662,6 +2662,35 @@ def vstack(arrays, out=None):
@set_module('mxnet.symbol.numpy')
+def dstack(arrays):
+ """
+ Stack arrays in sequence depth wise (along third axis).
+
+ This is equivalent to concatenation along the third axis after 2-D arrays
+ of shape `(M,N)` have been reshaped to `(M,N,1)` and 1-D arrays of shape
+ `(N,)` have been reshaped to `(1,N,1)`. Rebuilds arrays divided by
+ `dsplit`.
+
+ This function makes most sense for arrays with up to 3 dimensions. For
+ instance, for pixel-data with a height (first axis), width (second axis),
+ and r/g/b channels (third axis). The functions `concatenate`, `stack` and
+ `block` provide more general stacking and concatenation operations.
+
+ Parameters
+ ----------
+ tup : sequence of _Symbol
+ The arrays must have the same shape along all but the first axis.
+ 1-D arrays must have the same length.
+
+ Returns
+ -------
+ stacked : _Symbol
+ The array formed by stacking the given arrays, will be at least 2-D.
+ """
+ return _npi.dstack(*arrays)
+
+
+@set_module('mxnet.symbol.numpy')
def maximum(x1, x2, out=None):
return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)
diff --git a/src/operator/nn/concat-inl.h b/src/operator/nn/concat-inl.h
index 7a58ae6..1fb20ac 100644
--- a/src/operator/nn/concat-inl.h
+++ b/src/operator/nn/concat-inl.h
@@ -142,6 +142,37 @@ void ConcatCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
}
template<typename xpu>
+void DStackCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                   const std::vector<TBlob>& inputs,
+                   const std::vector<OpReqType>& req,
+                   const std::vector<TBlob>& outputs) {
+  // dstack runs the generic concat kernel pinned to the third axis.
+  ConcatParam concat_param = nnvm::get<ConcatParam>(attrs.parsed);
+  concat_param.dim = 2;
+
+  // Present every input as at least 3-D: () -> (1, 1, 1), (N,) -> (1, N, 1),
+  // (M, N) -> (M, N, 1). Any other rank is forwarded untouched.
+  std::vector<TBlob> views(inputs.size());
+  for (int k = 0; k < concat_param.num_args; ++k) {
+    const TBlob& src = inputs[k];
+    const int nd = src.shape_.ndim();
+    if (nd < 0 || nd > 2) {
+      views[k] = src;
+      continue;
+    }
+    TShape promoted(3, 1);
+    if (nd == 1) {
+      promoted[1] = src.shape_[0];
+    } else if (nd == 2) {
+      promoted[0] = src.shape_[0];
+      promoted[1] = src.shape_[1];
+    }
+    views[k] = src.reshape(promoted);
+  }
+
+  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kData0].type_flag_, DType, {
+    ConcatOp<xpu, DType> concat_op;
+    concat_op.Init(concat_param);
+    concat_op.Forward(ctx, views, req, outputs);
+  });
+}
+
+template<typename xpu>
void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
@@ -154,6 +185,37 @@ void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
});
}
+template<typename xpu>
+void DStackGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  // The dstack gradient is the generic concat backward on the third axis.
+  ConcatParam concat_param = nnvm::get<ConcatParam>(attrs.parsed);
+  concat_param.dim = 2;
+
+  // Mirror the forward promotion on the gradient buffers so they line up with
+  // the (at least) 3-D output gradient: () -> (1, 1, 1), (N,) -> (1, N, 1),
+  // (M, N) -> (M, N, 1); any other rank is used as-is.
+  std::vector<TBlob> grad_views(outputs.size());
+  for (int k = 0; k < concat_param.num_args; ++k) {
+    const TBlob& dst = outputs[k];
+    const int nd = dst.shape_.ndim();
+    if (nd < 0 || nd > 2) {
+      grad_views[k] = dst;
+      continue;
+    }
+    TShape promoted(3, 1);
+    if (nd == 1) {
+      promoted[1] = dst.shape_[0];
+    } else if (nd == 2) {
+      promoted[0] = dst.shape_[0];
+      promoted[1] = dst.shape_[1];
+    }
+    grad_views[k] = dst.reshape(promoted);
+  }
+
+  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kOut].type_flag_, DType, {
+    ConcatOp<xpu, DType> concat_op;
+    concat_op.Init(concat_param);
+    concat_op.Backward(ctx, inputs[concat_enum::kOut], req, grad_views);
+  });
+}
+
/*!
* \brief concat CSRNDArray on the first dimension.
*/
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index f54f325..64d4a03 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -255,6 +255,67 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_shape,
mxnet::ShapeVector *out_shape);
+bool DStackShape(const nnvm::NodeAttrs& attrs,
+                 mxnet::ShapeVector *in_shape,
+                 mxnet::ShapeVector *out_shape) {
+  using namespace mshadow;
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
+  // dstack always joins along the third axis, whatever param_.dim holds.
+  const int axis = 2;
+  // The shape an input is treated as by dstack: () -> (1, 1, 1),
+  // (N,) -> (1, N, 1), (M, N) -> (M, N, 1). Shapes of rank >= 3 and shapes
+  // whose rank is still unknown are returned unchanged.
+  auto promote = [](const mxnet::TShape& s) -> mxnet::TShape {
+    if (!mxnet::ndim_is_known(s) || s.ndim() >= 3) return s;
+    mxnet::TShape t(3, 1);
+    if (s.ndim() == 1) {
+      t[1] = s[0];
+    } else if (s.ndim() == 2) {
+      t[0] = s[0];
+      t[1] = s[1];
+    }
+    return t;
+  };
+
+  // Merge all inputs (stacking axis masked out) into dshape and sum their
+  // sizes along the stacking axis. Work on copies: the caller's shapes must
+  // keep their own rank and their known size along `axis`.
+  mxnet::TShape dshape;
+  dim_t size = 0;
+  bool has_unknown_dim_size = false;
+  for (int i = 0; i < param_.num_args; ++i) {
+    mxnet::TShape tmp = promote((*in_shape)[i]);
+    if (tmp.ndim() > 0) {
+      CheckAxis(axis, tmp.ndim());
+      if (!mxnet::dim_size_is_known(tmp, axis)) {
+        has_unknown_dim_size = true;
+      } else {
+        size += tmp[axis];
+      }
+      tmp[axis] = -1;
+      shape_assign(&dshape, tmp);
+    }
+  }
+
+  // A (partially) known output shape also constrains the non-stacking axes.
+  mxnet::TShape out_tmp = (*out_shape)[0];
+  if (out_tmp.ndim() > 0) {
+    CheckAxis(axis, out_tmp.ndim());
+    out_tmp[axis] = -1;
+    shape_assign(&dshape, out_tmp);
+  }
+
+  if (dshape.ndim() == -1) return false;
+  CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated";
+
+  // Propagate the merged shape back to every input in that input's own rank;
+  // compatibility of rank < 3 inputs is checked on their 3-D promoted view.
+  for (int i = 0; i < param_.num_args; ++i) {
+    mxnet::TShape &ishape = (*in_shape)[i];
+    if (mxnet::ndim_is_known(ishape) && ishape.ndim() < 3) {
+      mxnet::TShape promoted = promote(ishape);
+      CHECK(shape_assign(&promoted, dshape))
+        << "Incompatible input shape: expected " << dshape << ", got " << promote(ishape);
+      if (ishape.ndim() == 1) {
+        ishape[0] = promoted[1];
+      } else if (ishape.ndim() == 2) {
+        ishape[0] = promoted[0];
+        ishape[1] = promoted[1];
+      }
+    } else {
+      CHECK(shape_assign(&ishape, dshape))
+        << "Incompatible input shape: expected " << dshape << ", got " << ishape;
+    }
+  }
+
+  if (!has_unknown_dim_size) {
+    dshape[axis] = size;
+  }
+  CHECK(shape_assign(&(*out_shape)[0], dshape))
+    << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
+
+  return shape_is_known(dshape);
+}
+
bool ConcatType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type,
std::vector<int> *out_type);
@@ -269,7 +330,6 @@ struct NumpyConcatGrad {
}
};
-
NNVM_REGISTER_OP(_npi_concatenate)
.describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
@@ -490,6 +550,44 @@ NNVM_REGISTER_OP(_backward_np_vstack)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", NumpyVstackBackward<cpu>);
+// Forward op behind np.dstack: variadic (num_args inputs named data0, data1,
+// ...) depth-wise concatenation producing a single output named "out".
+NNVM_REGISTER_OP(_npi_dstack)
+.describe(R"code(Stack tensors in sequence depthwise (in third dimension))code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+ const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+ return params.num_args;
+})
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<ConcatParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+ std::vector<std::string> ret;
+ for (int i = 0; i < params.num_args; ++i) {
+ ret.push_back(std::string("data") + std::to_string(i));
+ }
+ return ret;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"out"};
+})
+.set_attr<std::string>("key_var_num_args", "num_args")
+// Type inference is shared with concat; shape inference is dstack-specific
+// (inputs of rank < 3 are promoted to 3-D before joining on axis 2).
+.set_attr<nnvm::FInferType>("FInferType", ConcatType)
+.set_attr<mxnet::FInferShape>("FInferShape", DStackShape)
+.set_attr<FCompute>("FCompute<cpu>", DStackCompute<cpu>)
+// Gradients are computed by _backward_np_dstack, registered below.
+.set_attr<nnvm::FGradient>("FGradient", NumpyConcatGrad{"_backward_np_dstack"})
+.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
+.add_arguments(ConcatParam::__FIELDS__());
+
+// Backward of _npi_dstack: one gradient output per forward input (num_args),
+// computed by DStackGradCompute.
+NNVM_REGISTER_OP(_backward_np_dstack)
+.set_num_outputs([](const NodeAttrs& attrs) {
+ const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+ return params.num_args;
+})
+.set_attr_parser(ParamParser<ConcatParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", DStackGradCompute<cpu>);
+
inline bool NumpyRollShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index caab410..125cd91 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -53,6 +53,12 @@ NNVM_REGISTER_OP(_npi_vstack)
NNVM_REGISTER_OP(_backward_np_vstack)
.set_attr<FCompute>("FCompute<gpu>", NumpyVstackBackward<gpu>);
+// GPU kernels for the dstack forward/backward ops; their remaining
+// attributes are registered in np_matrix_op.cc.
+NNVM_REGISTER_OP(_npi_dstack)
+.set_attr<FCompute>("FCompute<gpu>", DStackCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_np_dstack)
+.set_attr<FCompute>("FCompute<gpu>", DStackGradCompute<gpu>);
+
NNVM_REGISTER_OP(_np_roll)
.set_attr<FCompute>("FCompute<gpu>", NumpyRollCompute<gpu>);
@@ -90,5 +96,6 @@ NNVM_REGISTER_OP(_npi_flip)
NNVM_REGISTER_OP(_backward_npi_flip)
.set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);
+
} // namespace op
} // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 89fe576..978d5d3 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1687,6 +1687,67 @@ def test_np_stack():
@with_seed()
@use_np
+def test_np_dstack():
+ class TestDStack(HybridBlock):
+ def __init__(self):
+ super(TestDStack, self).__init__()
+
+ def hybrid_forward(self, F, a, *args):
+ return F.np.dstack([a] + list(args))
+
+ def get_new_shape(shape):
+ if len(shape) < 3:
+ return shape
+ axis = 2
+ shape_lst = list(shape)
+ shape_lst[axis] = random.randint(0, 5)
+ return tuple(shape_lst)
+
+ shapes = [
+ (),
+ (1,),
+ (2,1),
+ (2,2,4),
+ (2,0,0),
+ (0,1,3),
+ (2,0,3),
+ (2,3,4,5)
+ ]
+ for hybridize in [True, False]:
+ for shape in shapes:
+ test_dstack = TestDStack()
+ if hybridize:
+ test_dstack.hybridize()
+ # test symbolic forward
+ a = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray()
+ a.attach_grad()
+ b = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray()
+ b.attach_grad()
+ c = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray()
+ c.attach_grad()
+ d = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray()
+ d.attach_grad()
+ with mx.autograd.record():
+ mx_out = test_dstack(a, b, c, d)
+ np_out = _np.dstack((a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()))
+ assert mx_out.shape == np_out.shape
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+ # test symbolic backward
+ mx_out.backward()
+ assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
+ assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
+ assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
+ assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)
+
+ # test imperative
+ mx_out = np.dstack((a, b, c, d))
+ np_out = _np.dstack((a.asnumpy(),b.asnumpy(), c.asnumpy(), d.asnumpy()))
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+
+@with_seed()
+@use_np
def test_np_ravel():
class TestRavel(HybridBlock):
def __init__(self):