You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/03/05 23:26:52 UTC

[incubator-mxnet] branch master updated: Relaxing type requirements for reshape_like op (#14325)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d754da3  Relaxing type requirements for reshape_like op (#14325)
d754da3 is described below

commit d754da3aafd783952cbd41faccde20f85701ce74
Author: Przemyslaw Tredak <pt...@gmail.com>
AuthorDate: Tue Mar 5 15:26:32 2019 -0800

    Relaxing type requirements for reshape_like op (#14325)
    
    * Relax type requirements in reshape_like
    
    * Add test
    
    * Fix lint
    
    * Retrigger CI
---
 src/operator/tensor/elemwise_unary_op_basic.cc | 11 ++++++++++-
 tests/python/unittest/test_operator.py         | 10 ++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 4aaf4df..19a9ac8 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -481,7 +481,16 @@ Negative indices are supported, and `None` can be used for either `lhs_end` or `
     [](const NodeAttrs& attrs) { return std::vector<uint32_t>(1, 1); })
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
 .set_attr<mxnet::FInferShape>("FInferShape", ReshapeLikeShapeCompute)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs& attrs,
+                                             std::vector<int> *in_attrs,
+                                             std::vector<int> *out_attrs) {
+    CHECK_EQ(in_attrs->size(), 2) << " in operator " << attrs.name;
+    std::vector<int> checked_in_attrs = { (*in_attrs)[0] };
+    bool ret = !type_is_none((*in_attrs)[1]) &&
+               ElemwiseType<1, 1>(attrs, &checked_in_attrs, out_attrs);
+    (*in_attrs)[0] = checked_in_attrs[0];
+    return ret;
+  })
 .set_attr<nnvm::FGradient>(
     "FGradient",  [](const nnvm::NodePtr& n,
                      const std::vector<nnvm::NodeEntry>& ograds) {
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 500d2f9..0ac530c 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2530,6 +2530,16 @@ def test_slice_like_different_types():
     assert_allclose(z.asnumpy(), [[1,2,3],[5,6,7]])
 
 @with_seed()
+def test_reshape_like_different_types():
+    x = mx.nd.zeros((2, 3))
+
+    y = mx.nd.array([[1, 2], [3, 4], [5, 6]])
+
+    y = mx.nd.array(y).astype('int32')
+    z = mx.nd.reshape_like(x, y)
+    assert_allclose(z.asnumpy(), [[0,0],[0,0],[0,0]])
+
+@with_seed()
 def test_flip():
     for ndim in range(1, 6):
         for t in range(5):