Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/11/15 20:56:55 UTC
[incubator-mxnet] branch master updated: Fix a bug in index_copy (#13218)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e7f9770 Fix a bug in index_copy (#13218)
e7f9770 is described below
commit e7f97701a097353ced06b0074a7f41def2b4c40b
Author: Da Zheng <zh...@gmail.com>
AuthorDate: Fri Nov 16 04:56:40 2018 +0800
Fix a bug in index_copy (#13218)
* fix.
* add test.
* retrigger
---
 src/operator/contrib/index_copy-inl.h  | 14 ++++++++------
 tests/python/unittest/test_operator.py | 19 ++++++++++---------
 2 files changed, 18 insertions(+), 15 deletions(-)
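
For context, contrib.index_copy produces a copy of its first input with the rows selected by the index replaced by the corresponding rows of the second input. A minimal NumPy sketch of the forward semantics, inferred from the updated test at the bottom of this patch (the index_copy_ref helper is ours, for illustration only):

import numpy as np

def index_copy_ref(x, index, t):
    # Copy x, then overwrite the rows named by index with the rows of t.
    out = x.copy()
    out[index] = t
    return out

x = np.zeros((5, 3))
t = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float64)
index = np.array([0, 4, 2])
# Rows 0, 4, and 2 of the output come from t; rows 1 and 3 stay zero.
print(index_copy_ref(x, index, t))
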
diff --git a/src/operator/contrib/index_copy-inl.h b/src/operator/contrib/index_copy-inl.h
index b97138a..923fb0f 100644
--- a/src/operator/contrib/index_copy-inl.h
+++ b/src/operator/contrib/index_copy-inl.h
@@ -32,6 +32,7 @@
 #include "../elemwise_op_common.h"
 #include "../mshadow_op.h"
 #include "../mxnet_op.h"
+#include "../tensor/init_op.h"
 
 namespace mxnet {
 namespace op {
@@ -83,12 +84,12 @@ void IndexCopyForward(const nnvm::NodeAttrs& attrs,
   });
 }
 
-template<int req>
 struct index_copy_backward {
   template<typename DType, typename IType>
   MSHADOW_XINLINE static void Map(int i,
                                   int dim,
                                   int index_size,
+                                  int req1, int req2,
                                   DType* out_grad,
                                   IType* index,
                                   DType* in_grad_1,
@@ -98,12 +99,12 @@ struct index_copy_backward {
       int idx = static_cast<int>(index[p]);
       if (i >= idx*dim && i < (idx+1)*dim) {
         int offset = i - idx*dim;
-        KERNEL_ASSIGN(in_grad_2[p*dim+offset], req, out_grad[i]);
+        KERNEL_ASSIGN(in_grad_2[p*dim+offset], req2, out_grad[i]);
         return;
       }
     }
     // Copy to in_grad_1
-    KERNEL_ASSIGN(in_grad_1[i], req, out_grad[i]);
+    KERNEL_ASSIGN(in_grad_1[i], req1, out_grad[i]);
   }
 };
 
@@ -122,18 +123,19 @@ void IndexCopyBackward(const nnvm::NodeAttrs& attrs,
   const TBlob& in_grad_2 = outputs[2];
   int dim = inputs[3].Size() / inputs[2].Size();
   int index_size = inputs[2].Size();
+  Fill<false>(s, outputs[0], req[0], 0);
+  Fill<false>(s, outputs[2], req[2], 0);
   // index_copy_backward
   MSHADOW_TYPE_SWITCH(out_grad.type_flag_, DType, {
     MSHADOW_TYPE_SWITCH(index.type_flag_, IType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        mxnet_op::Kernel<index_copy_backward<req_type>, xpu>::Launch(s,
+      mxnet_op::Kernel<index_copy_backward, xpu>::Launch(s,
                                   out_grad.Size(),
                                   dim, index_size,
+                                  req[0], req[2],
                                   out_grad.dptr<DType>(),
                                   index.dptr<IType>(),
                                   in_grad_1.dptr<DType>(),
                                   in_grad_2.dptr<DType>());
-      });
     });
   });
 }
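
The change above does two things. First, the kernel's single compile-time req template parameter, which applied the same write request to both gradient outputs, is replaced by per-output req1/req2 values passed at runtime, so in_grad_1 and in_grad_2 can honor different request types. Second, both gradient buffers are zero-filled with Fill<false> before the kernel launches: the kernel returns early for elements that fall in an indexed row, so without the fill those entries of in_grad_1 would be left uninitialized. A rough NumPy sketch of the gradients the fixed kernel is meant to compute, inferred from the diff and the expected values in the updated test (the helper name is ours):

import numpy as np

def index_copy_backward_ref(out_grad, index):
    x_grad = out_grad.copy()
    x_grad[index] = 0          # overwritten rows of x get zero gradient
    t_grad = out_grad[index]   # t's gradient comes from the copied rows
    return x_grad, t_grad

x_grad, t_grad = index_copy_backward_ref(np.ones((5, 3)), np.array([0, 4, 2]))
# x_grad is zero except rows 1 and 3; t_grad is all ones,
# matching x_grad and t_grad in the test below.
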
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 5fe9e3e..283a282 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4773,24 +4773,25 @@ def test_index_copy():
     x = mx.nd.zeros((5,3))
     t = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])
     index = mx.nd.array([0,4,2], dtype=np.int64)
+    tensor = mx.nd.array([[1,2,3],[0,0,0],[7,8,9],[0,0,0],[4,5,6]])
+    x_grad = mx.nd.array([[0,0,0],[1,1,1],[0,0,0],[1,1,1],[0,0,0]])
+    t_grad = mx.nd.array([[1,1,1],[1,1,1],[1,1,1]])
 
-    x.attach_grad()
     t.attach_grad()
-    index.attach_grad()
-
     with mx.autograd.record():
         out = mx.nd.contrib.index_copy(x, index, t)
     out.backward()
+    assert same(out.asnumpy(), tensor.asnumpy())
+    assert same(t.grad.asnumpy(), t_grad.asnumpy())
 
-    tensor = mx.nd.array([[1,2,3],[0,0,0],[7,8,9],[0,0,0],[4,5,6]])
-    x_grad = mx.nd.array([[0,0,0],[1,1,1],[0,0,0],[1,1,1],[0,0,0]])
-    t_grad = mx.nd.array([[1,1,1],[1,1,1],[1,1,1]])
-    index_grad = mx.nd.array([0,0,0])
-
+    x.attach_grad()
+    t.attach_grad()
+    with mx.autograd.record():
+        out = mx.nd.contrib.index_copy(x, index, t)
+    out.backward()
     assert same(out.asnumpy(), tensor.asnumpy())
     assert same(x.grad.asnumpy(), x_grad.asnumpy())
     assert same(t.grad.asnumpy(), t_grad.asnumpy())
-    assert same(index.grad.asnumpy(), index_grad.asnumpy())
 
 @with_seed()
 def test_div_sqrt_dim():
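
To verify the fix locally, here is a self-contained variant of the updated test, assuming an MXNet build that includes this commit. It exercises the case where both x and t require gradients, which is where the gradient of x must come out zero in the overwritten rows:

import numpy as np
import mxnet as mx

x = mx.nd.zeros((5, 3))
t = mx.nd.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = mx.nd.array([0, 4, 2], dtype=np.int64)

x.attach_grad()
t.attach_grad()
with mx.autograd.record():
    out = mx.nd.contrib.index_copy(x, index, t)
out.backward()  # implicit head gradient of ones

assert np.array_equal(
    x.grad.asnumpy(),
    np.array([[0,0,0],[1,1,1],[0,0,0],[1,1,1],[0,0,0]], dtype=np.float32))
assert np.array_equal(t.grad.asnumpy(), np.ones((3, 3), dtype=np.float32))
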