You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/10/21 06:11:44 UTC
[incubator-mxnet] branch master updated: fix type inference in index_copy. (#12890)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new d3d343c fix type inference in index_copy. (#12890)
d3d343c is described below
commit d3d343c782fbb9a90009df4cf0e63d9eadab4c3f
Author: Da Zheng <zh...@gmail.com>
AuthorDate: Sat Oct 20 23:11:30 2018 -0700
fix type inference in index_copy. (#12890)
---
src/operator/contrib/index_copy.cc | 12 +++++++++++-
tests/python/unittest/test_operator.py | 2 +-
2 files changed, 12 insertions(+), 2 deletions(-)
diff --git a/src/operator/contrib/index_copy.cc b/src/operator/contrib/index_copy.cc
index 07067a3..316c8a7 100644
--- a/src/operator/contrib/index_copy.cc
+++ b/src/operator/contrib/index_copy.cc
@@ -26,6 +26,16 @@
namespace mxnet {
namespace op {
+static bool IndexCopyType(const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ CHECK_EQ(in_attrs->size(), 3U);
+ CHECK_EQ(out_attrs->size(), 1U);
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+ TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+ return out_attrs->at(0) != -1;
+}
+
NNVM_REGISTER_OP(_contrib_index_copy)
.describe(R"code(Copies the elements of a `new_tensor` into the `old_tensor` by
selecting the indices in the order given in `index`. The output will be a new tensor
@@ -56,7 +66,7 @@ mx.nd.contrib.index_copy(x, index, t)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape", IndexCopyShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", IndexCopyType)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_contrib_backward_index_copy"})
.set_attr<FCompute>("FCompute<cpu>", IndexCopyForward<cpu>)
.add_argument("old_tensor", "NDArray-or-Symbol", "Old tensor")
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 5df7d97..3292541 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4766,7 +4766,7 @@ def test_quantization_op():
def test_index_copy():
x = mx.nd.zeros((5,3))
t = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])
- index = mx.nd.array([0,4,2])
+ index = mx.nd.array([0,4,2], dtype=np.int64)
x.attach_grad()
t.attach_grad()