You are viewing a plain text version of this content. (The "here" link to the canonical version is not preserved in this copy.)
Posted to commits@mxnet.apache.org by zh...@apache.org on 2020/12/16 06:16:30 UTC

[incubator-mxnet] branch master updated: large tensor tests batch 6 (#19617)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9ccc53d  large tensor tests batch 6 (#19617)
9ccc53d is described below

commit 9ccc53dd27f213e52e4d088319b194520ebceb79
Author: Zhaoqi Zhu <zh...@gmail.com>
AuthorDate: Tue Dec 15 22:15:33 2020 -0800

    large tensor tests batch 6 (#19617)
    
    * last batch of large tensor tests
    
    * Update test_np_large_array.py
    
    * Update test_np_large_array.py
---
 src/operator/numpy/np_matmul_op-inl.h |   2 +-
 tests/nightly/test_np_large_array.py  | 228 +++++++++++++++++++++++++++++++++-
 2 files changed, 224 insertions(+), 6 deletions(-)

diff --git a/src/operator/numpy/np_matmul_op-inl.h b/src/operator/numpy/np_matmul_op-inl.h
index 8f1b4f9..935de8b 100644
--- a/src/operator/numpy/np_matmul_op-inl.h
+++ b/src/operator/numpy/np_matmul_op-inl.h
@@ -108,7 +108,7 @@ struct SumByShape {
    * \note in_size >= out_size
    */
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* output, DType* input,
+  MSHADOW_XINLINE static void Map(index_t i, DType* output, DType* input,
                                   size_t in_size, size_t out_size,
                                   const int req){
     // i is the global position in flattened output
diff --git a/tests/nightly/test_np_large_array.py b/tests/nightly/test_np_large_array.py
index 99de1b7..a423253 100644
--- a/tests/nightly/test_np_large_array.py
+++ b/tests/nightly/test_np_large_array.py
@@ -1691,13 +1691,23 @@ def test_arange_like():
     assert A.grad.shape == (INT_OVERFLOW, 2)
     assert A.grad[0][0] == 0
 
-# TODO implement this test after dot is fixed for large tensors and we have
-# migrated to using ILP64 BLAS/LAPACK
 @use_np
-@pytest.mark.skip(reason='dot is known to not work on large tensors. PR to fix: https://github.com/apache/incubator-mxnet/pull/18925')
 def test_batch_dot():
-    assert False 
-
+    # (2, 1, N) @ (2, N, 1) with N = INT_OVERFLOW; only the last entries are nonzero
+    lhs = np.zeros((2, 1, INT_OVERFLOW))
+    rhs = np.zeros((2, INT_OVERFLOW, 1))
+    lhs[-1, -1, -1], rhs[-1, -1, -1] = 2, 3
+    for arr in (lhs, rhs):
+        arr.attach_grad()
+    with mx.autograd.record():
+        out = npx.batch_dot(lhs, rhs)
+        out.backward()
+    assert out.shape == (2, 1, 1)
+    assert out[1, 0, 0] == 6  # 2 * 3 from the last batch
+    assert lhs.grad.shape == lhs.shape
+    assert lhs.grad[-1, -1, -1] == 3  # grad w.r.t. lhs picks up rhs
+    assert rhs.grad.shape == rhs.shape
+    assert rhs.grad[-1, -1, -1] == 2  # grad w.r.t. rhs picks up lhs
 
 @use_np
 def test_cast():
@@ -2723,6 +2733,212 @@ def test_insert():
     assertRaises(MXNetError, np.insert, arr=inp3, obj=np.array([2, 2], dtype=np.int64), values=np.array([5, 6]))
 
 
+@use_np
+def test_moveaxis():
+    src = np.zeros((2, 1, INT_OVERFLOW))
+    src[0, 0, -1], src[1, 0, -1] = 1, 2
+    src.attach_grad()
+    with mx.autograd.record():
+        moved = np.moveaxis(src, 2, 0)  # (2, 1, N) -> (N, 2, 1)
+        moved.backward()
+    assert moved.shape == (INT_OVERFLOW, 2, 1)
+    assert moved[-1, 0, 0] == 1 and moved[-1, 1, 0] == 2
+    assert src.grad.shape == src.shape
+    assert src.grad[-1, -1, -1] == 1  # axis permutation passes ones straight through
+
+
+@use_np
+def test_newaxis():
+    base = np.zeros((2, INT_OVERFLOW))
+    base[-1, -1] = 1
+    lead = base[np.newaxis, :, :]  # prepend a unit axis
+    assert lead.shape == (1, 2, INT_OVERFLOW)
+    assert lead[0, -1, -1] == 1
+    both = lead[:, :, :, np.newaxis]  # then append a trailing unit axis
+    assert both.shape == (1, 2, INT_OVERFLOW, 1)
+    assert both[0, -1, -1, 0] == 1
+
+
+@use_np
+def test_triu_indices():
+    N = 2**16  # N * N == 2**32, beyond the int32 index range
+    data = np.triu_indices(N, 1)
+    assert data[0].shape == (N * (N - 1) // 2,)  # strictly-upper pairs, exact int
+    assert data[0][-1] == N - 2 and data[1][-1] == N - 1
+
+
+@use_np
+def test_triu_indices_from():
+    N = 2**16  # N * N == 2**32, beyond the int32 index range
+    arr = np.zeros((N, N))
+    data = np.triu_indices_from(arr, 1)
+    assert data[0].shape == (N * (N - 1) // 2,)  # strictly-upper pairs, exact int
+    assert data[0][-1] == N - 2 and data[1][-1] == N - 1
+
+
+@use_np
+def test_empty():
+    data = np.empty((2, INT_OVERFLOW), dtype='float64')
+    data[:] = 0  # np.empty contents are undefined; define them before reading
+    data = data + 1
+    assert data.shape == (2, INT_OVERFLOW) and data[-1, -1] == 1
+
+
+@use_np
+def test_shape_reshape():
+    wide = np.zeros((2, INT_OVERFLOW))
+    wide[0, -1] = 1
+    assert np.shape(wide) == (2, INT_OVERFLOW)
+    tall = np.reshape(wide, (INT_OVERFLOW, 2))
+    assert np.shape(wide) == (2, INT_OVERFLOW)  # the input keeps its shape
+    assert np.shape(tall) == (INT_OVERFLOW, 2)
+    assert tall[HALF_INT_OVERFLOW-1, 1] == 1  # wide[0, -1], assuming INT_OVERFLOW == 2 * HALF_INT_OVERFLOW
+
+
+@use_np
+def test_copy():
+    inp = np.zeros((2, INT_OVERFLOW))
+    inp[1, -1] = 2
+    out = np.copy(inp)
+    out[0, -1] = 3  # writing to the copy must leave the source untouched
+    assert out.shape == inp.shape
+    assert inp[0, -1] == 0 and inp[1, -1] == 2
+    assert out[0, -1] == 3 and out[1, -1] == 2  # the copy carries the source data
+
+
+@use_np
+def test_broadcast_arrays():
+    row = np.ones((INT_OVERFLOW,))
+    row[-1] = 2
+    col = np.array([[3], [4]])
+    for arr in (row, col):
+        arr.attach_grad()
+    with mx.autograd.record():
+        b_row, b_col = np.broadcast_arrays(row, col)
+        b_row.backward()
+        b_col.backward()
+    assert b_row.shape == (2, INT_OVERFLOW)
+    assert b_row[-1, -1] == 2
+    assert b_col.shape == (2, INT_OVERFLOW)
+    assert b_col[0, -1] == 3 and b_col[1, -1] == 4
+    assert row.grad.shape == row.shape
+    assert row.grad[-1] == 2  # summed over the 2 broadcast rows
+    assert col.grad.shape == col.shape
+    assert col.grad[-1, -1] == INT_OVERFLOW  # summed over INT_OVERFLOW columns
+
+
+@use_np
+def test_inner():
+    dense = np.ones((INT_OVERFLOW,))
+    sparse = np.zeros((INT_OVERFLOW,))
+    sparse[-1] = 3
+    for vec in (dense, sparse):
+        vec.attach_grad()
+    with mx.autograd.record():
+        out = np.inner(dense, sparse)
+        out.backward()
+    assert out.shape == ()  # inner product of two 1-D arrays is a scalar
+    assert out == 3
+    assert dense.grad.shape == dense.shape
+    assert dense.grad[-1] == 3  # grad w.r.t. one operand is the other operand
+    assert sparse.grad.shape == sparse.shape
+    assert sparse.grad[-1] == 1
+
+
+@use_np
+def test_matmul():
+    lhs = np.ones((1, 2, INT_OVERFLOW), dtype='float64')
+    rhs = np.ones((INT_OVERFLOW, 1), dtype='float64')
+    for mat in (lhs, rhs):
+        mat.attach_grad()
+    with mx.autograd.record():
+        out = np.matmul(lhs, rhs)  # (1, 2, N) @ (N, 1) -> (1, 2, 1)
+        out.backward()
+    assert out.shape == (1, 2, 1)
+    assert out[0, 0, 0] == INT_OVERFLOW  # float64 so the sum of N ones stays exact
+    assert lhs.grad.shape == lhs.shape
+    assert lhs.grad[-1, -1, -1] == 1
+    assert rhs.grad.shape == rhs.shape
+    assert rhs.grad[-1, -1] == 2  # summed over the two lhs rows
+
+
+@use_np
+def test_outer():
+    col = np.ones((INT_OVERFLOW,), dtype='float64')
+    col[-1] = 2
+    row = np.ones((2,), dtype='float64')
+    row[-1] = 3
+    for vec in (col, row):
+        vec.attach_grad()
+    with mx.autograd.record():
+        out = np.outer(col, row)
+        out.backward()
+    assert out.shape == (INT_OVERFLOW, 2)
+    assert out[-1, 0] == 2 and out[0, -1] == 3 and out[-1, -1] == 6
+    assert col.grad.shape == col.shape
+    assert col.grad[0] == 2 + 3 - 1  # sum(row) == 1 + 3
+    assert row.grad.shape == row.shape
+    assert row.grad[0] == INT_OVERFLOW + 2 - 1  # sum(col): INT_OVERFLOW - 1 ones plus a 2
+
+
+@use_np
+def test_tensordot():
+    lhs = np.ones((1, INT_OVERFLOW, 2), dtype='float64')
+    lhs[0, -1, 1] = 2
+    rhs = np.ones((INT_OVERFLOW, 1, 1), dtype='float64')
+    rhs[-1, 0, 0] = 3
+    for arr in (lhs, rhs):
+        arr.attach_grad()
+    with mx.autograd.record():
+        out = np.tensordot(lhs, rhs, axes=[[0, 1], [1, 0]])
+        out.backward()
+    assert out.shape == (2, 1)  # lhs axes (0, 1) contracted with rhs axes (1, 0)
+    assert out[0] == INT_OVERFLOW + 3 - 1 and out[1] == INT_OVERFLOW + 6 - 1
+    assert lhs.grad.shape == lhs.shape
+    assert lhs.grad[-1, -1, -1] == 3 and lhs.grad[0, 0, 0] == 1
+    assert rhs.grad.shape == rhs.shape
+    assert rhs.grad[-1, -1, -1] == 3 and rhs.grad[0, 0, 0] == 2
+
+
+@use_np
+def test_vdot():
+    lhs = np.ones((2, INT_OVERFLOW))
+    rhs = np.zeros((INT_OVERFLOW, 2))
+    lhs[0, -1] = 2
+    rhs[HALF_INT_OVERFLOW-1, 1] = 3
+    for arr in (lhs, rhs):
+        arr.attach_grad()
+    with mx.autograd.record():
+        out = np.vdot(lhs, rhs)  # operands of different shapes are compared flattened
+        out.backward()
+    assert out.shape == ()
+    assert out == 6
+    assert lhs.grad.shape == lhs.shape
+    assert lhs.grad[0, -1] == 3 and lhs.grad[-1, -1] == 0
+    assert rhs.grad.shape == rhs.shape
+    assert rhs.grad[HALF_INT_OVERFLOW-1, 1] == 2 and rhs.grad[-1, -1] == 1
+
+
+@use_np
+def test_dot():
+    row = np.zeros((1, INT_OVERFLOW))
+    col = np.zeros((INT_OVERFLOW, 1))
+    row[-1, -1], col[-1, -1] = 2, 3
+    for arr in (row, col):
+        arr.attach_grad()
+    with mx.autograd.record():
+        out = np.dot(row, col)  # (1, N) . (N, 1) -> (1, 1)
+        out.backward()
+    # only the last entries are nonzero, so the product is 2 * 3
+    assert out.shape == (1, 1)
+    assert out[0, 0] == 6
+    assert row.grad.shape == row.shape
+    assert row.grad[-1, -1] == 3
+    assert col.grad.shape == col.shape
+    assert col.grad[-1, -1] == 2
+
+
+@use_np
 def test_convolution():
     dim = 2
     batch_size = 1
@@ -2768,6 +2984,7 @@ def test_deconvolution():
     assert inp.grad[0][0][0][0] == 0
 
 
+@use_np
 def test_dropout():
     shape = (LARGE_X, SMALL_Y)
     inp = mx.np.ones(shape=shape)
@@ -2780,6 +2997,7 @@ def test_dropout():
     assert inp.grad.shape == shape
     assert _np.count_nonzero(inp.grad[0] == 2) != 0
 
+
 @use_np
 def test_log_softmax():
     LOG_SOFTMAX_VAL = -18.420681