You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2020/08/27 14:28:31 UTC
[incubator-mxnet] branch master updated: Numpy Gather ND Large Tensor fix (#18981)
This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 998c6ce Numpy Gather ND Large Tensor fix (#18981)
998c6ce is described below
commit 998c6cea73fd788ebfa3e30b7182a8397e939d0e
Author: Zhaoqi Zhu <zh...@usc.edu>
AuthorDate: Thu Aug 27 07:21:20 2020 -0700
Numpy Gather ND Large Tensor fix (#18981)
* Update test_np_large_array.py
* fix indexing
* update tests
Co-authored-by: Zhu <zh...@3c22fbbb4e1a.ant.amazon.com>
Co-authored-by: Ubuntu <ub...@ip-172-31-38-169.us-west-2.compute.internal>
---
src/operator/tensor/indexing_op.cc | 10 +++++-----
tests/nightly/test_np_large_array.py | 16 ++++++++++++++++
2 files changed, 21 insertions(+), 5 deletions(-)
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index e256a02..b3a527e 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -455,7 +455,7 @@ void GatherNDCheckBoundCPU(mshadow::Stream<cpu> *s, const DType* idx_ptr, index_
using namespace mxnet_op;
Kernel<set_zero, cpu>::Launch(s, M, is_valid_dim_ptr);
Kernel<is_valid_check_gather_nd, cpu>::Launch(s, M, is_valid_dim_ptr, idx_ptr, N, mshape);
- for (int m = 0; m < M; m++) {
+ for (index_t m = 0; m < M; m++) {
if (is_valid_dim_ptr[m] > mshape[m] - 1 || is_valid_dim_ptr[m] < - mshape[m]) {
LOG(FATAL)<< "IndexError: index " << is_valid_dim_ptr[m] << " is out of bounds for axis "
<< m << " with size " << mshape[m];
@@ -476,12 +476,12 @@ void GatherNDForwardCPU(const nnvm::NodeAttrs& attrs,
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
const mxnet::TShape& dshape = inputs[0].shape_;
const mxnet::TShape& ishape = inputs[1].shape_;
- int M = ishape[0];
- int N = ishape.Size() / M;
- int K = dshape.ProdShape(M, dshape.ndim());
+ index_t M = ishape[0];
+ index_t N = ishape.Size() / M;
+ index_t K = dshape.ProdShape(M, dshape.ndim());
mshadow::Shape<10> strides;
mshadow::Shape<10> mshape;
- for (int i = M-1, stride = K; i >= 0; stride *= dshape[i], --i) {
+ for (index_t i = M-1, stride = K; i >= 0; stride *= dshape[i], --i) {
strides[i] = stride;
mshape[i] = dshape[i];
}
diff --git a/tests/nightly/test_np_large_array.py b/tests/nightly/test_np_large_array.py
index 6a473fd..692b29e 100644
--- a/tests/nightly/test_np_large_array.py
+++ b/tests/nightly/test_np_large_array.py
@@ -1044,6 +1044,22 @@ def test_save_load():
assert B[0][0][100] == 100
@use_np
+def test_gather_nd():
+ A = np.ones((1, 2, INT_OVERFLOW))
+ A [0, 1, 100] = 100
+ A.attach_grad()
+ with mx.autograd.record():
+ B = npx.gather_nd(data=A, \
+ indices=np.array([[0, 0] , [0, 1], [INT_OVERFLOW-1, 100]], \
+ dtype='int64'))
+ assert B.shape == (2, )
+ assert B[0] == 1 and B[1] == 100
+ B.backward()
+ assert A.grad.shape == (1, 2, INT_OVERFLOW)
+ assert A.grad[0, 0, 0] == 0
+ assert A.grad[0, 0, INT_OVERFLOW-1] == 1
+
+@use_np
def test_random_bernoulli():
prob = np.zeros((INT_OVERFLOW))
prob[0] = 1