You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/02/14 19:15:32 UTC
[incubator-mxnet] branch master updated: Relaxing type requirements for slice_like op (#14097)
This is an automated email from the ASF dual-hosted git repository.
zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 73a9f1c Relaxing type requirements for slice_like op (#14097)
73a9f1c is described below
commit 73a9f1ceb3d96b819472050514a59a5eae4baa92
Author: Przemyslaw Tredak <pt...@gmail.com>
AuthorDate: Thu Feb 14 11:15:10 2019 -0800
Relaxing type requirements for slice_like op (#14097)
* Relaxing types for slice_like op
* Added test
* Fix typo in test
* Fix lint
---
src/operator/tensor/matrix_op.cc | 11 ++++++++++-
tests/python/unittest/test_operator.py | 14 ++++++++++++++
2 files changed, 24 insertions(+), 1 deletion(-)
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index e5d354b..3a244ac 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -661,7 +661,16 @@ Example::
return std::vector<std::string>{"data", "shape_like"};
})
.set_attr<nnvm::FInferShape>("FInferShape", SliceLikeShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ CHECK_EQ(in_attrs->size(), 2) << " in operator " << attrs.name;
+ std::vector<int> checked_in_attrs = { (*in_attrs)[0] };
+ bool ret = !type_is_none((*in_attrs)[1]) &&
+ ElemwiseType<1, 1>(attrs, &checked_in_attrs, out_attrs);
+ (*in_attrs)[0] = checked_in_attrs[0];
+ return ret;
+ })
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice_like"})
.set_attr<FCompute>("FCompute<cpu>", SliceLikeForward<cpu>)
.add_argument("data", "NDArray-or-Symbol", "Source input")
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 1f42215..fc003b2 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2516,6 +2516,20 @@ def test_slice_like():
assert_allclose(xgrad1.asnumpy(), mx.nd.zeros_like(xgrad1).asnumpy())
@with_seed()
+def test_slice_like_different_types():
+ x = [[ 1., 2., 3., 4.],
+ [ 5., 6., 7., 8.],
+ [ 9., 10., 11., 12.]]
+
+ y = [[ 0., 0., 0.],
+ [ 0., 0., 0.]]
+
+ x = mx.nd.array(x)
+ y = mx.nd.array(y).astype('int32')
+ z = mx.nd.slice_like(x, y)
+ assert_allclose(z.asnumpy(), [[1,2,3],[5,6,7]])
+
+@with_seed()
def test_flip():
for ndim in range(1, 6):
for t in range(5):