You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/10/21 06:11:44 UTC

[incubator-mxnet] branch master updated: fix type inference in index_copy. (#12890)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d3d343c  fix type inference in index_copy. (#12890)
d3d343c is described below

commit d3d343c782fbb9a90009df4cf0e63d9eadab4c3f
Author: Da Zheng <zh...@gmail.com>
AuthorDate: Sat Oct 20 23:11:30 2018 -0700

    fix type inference in index_copy. (#12890)
---
 src/operator/contrib/index_copy.cc     | 12 +++++++++++-
 tests/python/unittest/test_operator.py |  2 +-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/operator/contrib/index_copy.cc b/src/operator/contrib/index_copy.cc
index 07067a3..316c8a7 100644
--- a/src/operator/contrib/index_copy.cc
+++ b/src/operator/contrib/index_copy.cc
@@ -26,6 +26,16 @@
 namespace mxnet {
 namespace op {
 
+// old_tensor (input 0), new_tensor (input 2) and the output must share one dtype.
+static bool IndexCopyType(const nnvm::NodeAttrs& attrs, std::vector<int> *in_attrs,
+                          std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  for (int i : {0, 2}) TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(i));
+  for (int i : {0, 2}) TYPE_ASSIGN_CHECK(*in_attrs, i, out_attrs->at(0));
+  return out_attrs->at(0) != -1;  // index (input 1) keeps its own dtype, e.g. int64
+}
+
 NNVM_REGISTER_OP(_contrib_index_copy)
 .describe(R"code(Copies the elements of a `new_tensor` into the `old_tensor` by 
 selecting the indices in the order given in `index`. The output will be a new tensor 
@@ -56,7 +66,7 @@ mx.nd.contrib.index_copy(x, index, t)
 .set_num_inputs(3)
 .set_num_outputs(1)
 .set_attr<nnvm::FInferShape>("FInferShape", IndexCopyShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", IndexCopyType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_contrib_backward_index_copy"})
 .set_attr<FCompute>("FCompute<cpu>", IndexCopyForward<cpu>)
 .add_argument("old_tensor", "NDArray-or-Symbol", "Old tensor")
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 5df7d97..3292541 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4766,7 +4766,7 @@ def test_quantization_op():
 def test_index_copy():
     x = mx.nd.zeros((5,3))
     t = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])
-    index = mx.nd.array([0,4,2])
+    index = mx.nd.array([0,4,2], dtype=np.int64)
 
     x.attach_grad()
     t.attach_grad()