Posted to commits@mxnet.apache.org by sk...@apache.org on 2020/08/19 21:12:29 UTC

[incubator-mxnet] branch master updated: Numpy Ops Large Tensor Tests (#18932)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 69ec338  Numpy Ops Large Tensor Tests (#18932)
69ec338 is described below

commit 69ec3389f301eff693c124c09a96a77782eba6de
Author: Zhaoqi Zhu <zh...@usc.edu>
AuthorDate: Wed Aug 19 14:10:55 2020 -0700

    Numpy Ops Large Tensor Tests (#18932)
    
    * move tests into nightly
    
    * Update test_np_large_array.py
    
    * Update test_np_large_array.py
    
    * Update test_np_large_array.py
    
    * Update test_np_large_array.py
    
    * add more tests
    
    * Add more tests
    
    * add more ops
    
    * add more tests
    
    * add npx tests, add backward test for some np tests
    
    * add backward tests to existing numpy tests and add npx tests
    
    * add more npx tests
    
    * revisit tests
    
    Co-authored-by: Ubuntu <ub...@ip-172-31-10-124.us-west-2.compute.internal>
---
 tests/nightly/test_np_large_array.py | 662 ++++++++++++++++++++++++++++++++++-
 1 file changed, 659 insertions(+), 3 deletions(-)

diff --git a/tests/nightly/test_np_large_array.py b/tests/nightly/test_np_large_array.py
index 7f13135..28d8aeb 100644
--- a/tests/nightly/test_np_large_array.py
+++ b/tests/nightly/test_np_large_array.py
@@ -37,7 +37,8 @@ LARGE_X = 100000000
 SMALL_X = 100
 SMALL_Y = 50
 INT_OVERFLOW = 2**31
-
+HALF_INT_OVERFLOW = 2**30
+DOUBLE_INT_OVERFLOW = 2**32
 
 @use_np
 def test_gluon_embedding():
@@ -78,7 +79,390 @@ def test_softmax():
         output = npx.softmax(input_data, axis=axis)
         assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5)
 
-#@pytest.mark.skip(reason="CI hasn't switch to ILP64 OpenBLAS yet")
+'''
+  _ _ _  _ _ __  _ __ _  _
+ | ' \ || | '  \| '_ \ || |
+ |_||_\_,_|_|_|_| .__/\_, |
+                |_|   |__/
+'''
+
+@use_np
+def test_ones():
+    A = np.ones((INT_OVERFLOW, 2))
+    assert A.shape == (INT_OVERFLOW, 2)
+    assert A[0][0] == 1
+
+@use_np
+def test_zeros():
+    A = np.zeros((INT_OVERFLOW, 2))
+    assert A.shape == (INT_OVERFLOW, 2)
+    assert A[0][0] == 0
+
+@use_np
+def test_abs():
+    A = np.ones((INT_OVERFLOW, 2))
+    A[0][0] = -1
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.abs(A)
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[0][0] == 1
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == -1
+
+@use_np
+def test_absolute():
+    A = np.ones((INT_OVERFLOW, 2))
+    A[0][0] = -1
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.absolute(A)
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[0][0] == 1
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == -1
+
+@use_np
+@pytest.mark.skip(reason='backward errors out on (2^30,2), gives wrong result \
+    on (2^31, 2)')
+def test_add():
+    INT_OVERFLOW = 2**30
+    A = np.ones((INT_OVERFLOW, 2))
+    B = np.ones((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        C = np.add(A, B)
+    assert C.shape == (INT_OVERFLOW, 2)
+    assert C[0][0] == 2
+    C.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == 1
+
+# this will fail; broadcast needs to be fixed
+# TODO add backward test after forward is fixed
+@use_np
+@pytest.mark.skip(reason='Does not support large tensor; to be fixed')
+def test_add_broadcast():
+    A = np.ones((INT_OVERFLOW, 2))
+    B = np.ones((INT_OVERFLOW, 1))
+    C = np.add(A, B)
+    assert C.shape == (INT_OVERFLOW, 2)
+    assert C[0][0] == 2
+
+@use_np
+def test_all():
+    A = np.ones((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.all(A)
+    assert B.asnumpy() == True
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == 0 
+
+@use_np
+def test_amin():
+    A = np.ones((INT_OVERFLOW, 2))
+    A[100][1] = -1
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.amin(A)
+    assert B.asnumpy() == -1.0
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == 0
+
+@use_np
+def test_amax():
+    A = np.zeros((INT_OVERFLOW, 2))
+    A[100][1] = 1
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.amax(A)
+    print(B)
+    assert B.asnumpy() == 1.0
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == 0
+
+@use_np
+def test_argmin():
+    A = np.ones((INT_OVERFLOW, 2))
+    A[10][1] = -1
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.argmin(A)
+    print(B)
+    assert B.asnumpy() == 21
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == 0
+
+@use_np
+def test_argmax():
+    A = np.zeros((INT_OVERFLOW, 2))
+    A[10][1] = 1
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.argmax(A)
+    print(B)
+    assert B.asnumpy() == 21
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == 0
+
+@use_np
+def test_trigonometric_family():
+    def batch_check(x, funcs):
+        for f in funcs:
+            one = np.ones((1))
+            x.attach_grad()
+            one.attach_grad()
+            with mx.autograd.record():
+                y = f(x)
+                _ = f(one)
+            assert y.shape == (INT_OVERFLOW, 2)
+            assert y[0][0] == _
+            y.backward()
+            _.backward()
+            assert x.grad.shape == (INT_OVERFLOW, 2)
+            assert x.grad[0][0] == one.grad
+    A = np.ones((INT_OVERFLOW, 2))
+    batch_check(A, [np.arccos, np.arccosh, np.arcsin, \
+        np.arcsinh, np.arctan, np.arctanh, np.sin, np.cos, \
+        np.tan, np.sinh, np.cosh, np.tanh])
+
+@use_np
+def test_any():
+    A = np.zeros((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.any(A)
+    assert B.asnumpy() == False
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == 0
+
+@use_np
+def test_append():
+    A = np.ones((1, INT_OVERFLOW))
+    B = np.ones((2, INT_OVERFLOW))
+    A.attach_grad() 
+    with mx.autograd.record():
+        C = np.append(A, B, axis=0)
+    assert C.shape == (3, INT_OVERFLOW)
+    assert C[2][0] == 1
+    C.backward()
+    assert A.grad.shape == (1, INT_OVERFLOW)
+    assert A[0][0] == 1
+
+@use_np
+def test_arange():
+    A = np.arange(INT_OVERFLOW, dtype='int32')
+    assert A.shape == (INT_OVERFLOW, )
+    assert A[100] == 100
+
+@use_np
+def test_argsort():
+    A = np.ones((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.argsort(A)
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[0][0] == 0
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A[0][0] == 1
+
+# broken
+# TODO add backward test after forward is fixed
+@use_np
+@pytest.mark.skip(reason='Does not support large tensor; to be fixed')
+def test_round():
+    A = np.ones((INT_OVERFLOW, 2))
+    B = np.round(A)
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[0][0] == 1
+
+# broken
+# TODO add backward test after forward is fixed
+@use_np
+@pytest.mark.skip(reason='Does not support large tensor; to be fixed')
+def test_array_split():
+    A = np.zeros((INT_OVERFLOW, 2))
+    B = np.array_split(A, 2)
+    print(B)
+    assert B[0].shape ==(HALF_INT_OVERFLOW, 2)
+    assert B[1].shape ==(HALF_INT_OVERFLOW, 2)
+    assert B[0][0][0] == 0
+
+@use_np
+def test_atleast_xd_family():
+    def batch_check(x, funcs, shapes):
+        for f, s in zip(funcs, shapes):
+            x.attach_grad()
+            with mx.autograd.record():
+                y = f(x)
+            assert y.shape == s
+            y.backward()
+            assert x.grad.shape == (INT_OVERFLOW, )
+            assert x.grad[0] == 0
+    A = np.zeros((INT_OVERFLOW))
+    batch_check(A, [np.atleast_1d, np.atleast_2d, np.atleast_3d], \
+            [(INT_OVERFLOW, ), (1, INT_OVERFLOW), (1, INT_OVERFLOW, 1)])
+
+@use_np
+def test_average():
+    A = np.ones((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.average(A)
+    assert B.asnumpy() == 1
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert_almost_equal(A.grad[0][0], np.array([1.0 / DOUBLE_INT_OVERFLOW]), \
+            rtol=1e-3, atol=1e-5)
+
+@use_np
+def test_bincount():
+    A = np.ones((INT_OVERFLOW), dtype='int32')
+    A[0] = 0
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.bincount(A)
+    assert B.shape == (2,)
+    assert B[-1] == INT_OVERFLOW - 1
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, )
+    assert A.grad[0] == 0
+
+# broken
+# TODO add backward test after forward is fixed
+@use_np
+@pytest.mark.skip(reason='Does not support large tensor; to be fixed')
+def test_bitwise_family():
+    def batch_check(x1, x2, funcs):
+        for f in funcs:
+            y = f(x1, x2)
+            one = np.ones((1), dtype='int32')
+            assert y.shape == (INT_OVERFLOW, 2)
+            assert y[0][0] == f(one, one)
+    # test on broadcast input
+    A = np.ones((INT_OVERFLOW, 1), dtype='int32')
+    B = np.ones((INT_OVERFLOW, 2), dtype='int32')
+    batch_check(A, B, [np.bitwise_and, np.bitwise_or, np.bitwise_xor])
+    C = np.bitwise_not(A)
+    assert C.shape == (INT_OVERFLOW, 1)
+    assert C[0] == np.bitwise_not(np.ones((1), dtype='int32')) 
+
+@use_np
+def test_blackman():
+    A = np.blackman((INT_OVERFLOW))
+    assert A.shape == (INT_OVERFLOW, )
+
+@use_np
+def test_broadcast_to():
+    A = np.ones((2))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.broadcast_to(A, (INT_OVERFLOW, 2))
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[0][0] == 1
+    B.backward()
+    assert A.grad.shape == (2, )
+    with mx.autograd.record():
+        B = np.broadcast_to(A.reshape(2, 1), (2, INT_OVERFLOW))
+    assert B.shape == (2, INT_OVERFLOW)
+    assert B[0][0] == 1
+    B.backward()
+    assert A.grad.shape == (2, )
+
+@use_np
+def test_root_family():
+    def batch_check(x, funcs, grads):
+        for f, g in zip(funcs, grads):
+            x.attach_grad()
+            with mx.autograd.record():
+                y = f(x)
+            assert y.shape == (INT_OVERFLOW, 2)
+            assert y[0][0] == 1
+            y.backward()
+            assert x.grad.shape == (INT_OVERFLOW, 2)
+            assert_almost_equal(A.grad[0][0], np.array(g), \
+                rtol=1e-3, atol=1e-5)
+    A = np.ones((INT_OVERFLOW, 2))
+    batch_check(A, [np.sqrt, np.cbrt], [0.5, 1.0 / 3])
+
+@use_np
+def test_ceil_floor():
+    def batch_check(x, funcs):
+        for f in funcs:
+            x.attach_grad()
+            with mx.autograd.record():
+                y = f(x)
+            assert y.shape == (INT_OVERFLOW, 2)
+            assert y[0][0] == 1
+            y.backward()
+            assert x.grad.shape == (INT_OVERFLOW, 2)
+            assert x.grad[0][0] == 0
+    A = np.ones((INT_OVERFLOW, 2))
+    batch_check(A, [np.ceil, np.floor])
+
+@use_np
+def test_clip():
+    A = np.ones((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.clip(A, 1, 1)
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[0][0] == 1
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == 1
+
+@use_np
+def test_column_stack():
+    A = np.ones(INT_OVERFLOW)
+    A.attach_grad()
+    with mx.autograd.record():
+        B = np.column_stack((A, A))
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[0][0] == 1
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, )
+    assert A.grad[0] == 2
+
+@use_np
+def test_concatenate():
+    def batch_check(x1, x2, axises, shapes):
+        for a, s in zip(axises, shapes):
+            x1.attach_grad()
+            with mx.autograd.record():
+                y = np.concatenate((x1, x2), axis=a)
+            assert y.shape == s
+            y.backward()
+            assert x1.grad.shape == (2, INT_OVERFLOW)
+            assert x1.grad[0][0] == 1
+    A = np.ones((2, INT_OVERFLOW))
+    B = np.ones((1, INT_OVERFLOW))
+    batch_check(A, B, [0, None], \
+            [(3, INT_OVERFLOW), (int(INT_OVERFLOW * 3), )])
+
+@use_np
+# backward not working https://github.com/apache/incubator-mxnet/issues/18952
+def test_copysign():
+    A = np.ones((INT_OVERFLOW, 2))
+    #A.attach_grad()
+    #with mx.autograd.record():
+    B = np.copysign(A, -1)
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[0][0] == -1
+    #B.backward()
+    #assert A.grad.shape == (INT_OVERFLOW, 2)
+    
+@pytest.mark.skip(reason="CI hasn't switched to ILP64 OpenBLAS yet")
 @use_np
 def test_dot():
     A = np.ones((1, INT_OVERFLOW), dtype='float32')
@@ -86,6 +470,278 @@ def test_dot():
     A.attach_grad()
     with mx.autograd.record():
         C = np.dot(A, B)
-    assert_almost_equal(C.asnumpy(), [INT_OVERFLOW], rtol=1e-5, atol=1e-5)
+    assert_almost_equal(C, [INT_OVERFLOW], rtol=1e-5, atol=1e-5)
     C.backward()
     assert A.grad.shape == (1, INT_OVERFLOW)
+    assert A.grad[0][0] == 1
+
+'''
+                                     _               _
+  _ _ _  _ _ __  _ __ _  _   _____ _| |_ ___ _ _  __(_)___ _ _
+ | ' \ || | '  \| '_ \ || | / -_) \ /  _/ -_) ' \(_-< / _ \ ' \
+ |_||_\_,_|_|_|_| .__/\_, | \___/_\_\\__\___|_||_/__/_\___/_||_|
+                |_|   |__/
+'''
+
+@use_np
+def test_activation():
+    A = np.zeros((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = npx.activation(A, act_type='sigmoid')
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[0][0] == 0.5
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert_almost_equal(A.grad[0][0], np.array([0.25]), \
+                rtol=1e-3, atol=1e-5)
+
+@use_np
+def test_arange_like():
+    A = np.zeros((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = npx.arange_like(A)
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[100][0] == 200
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == 0
+
+# TODO implement this test after dot is fixed for large tensors and we have
+# migrated to using ILP64 BLAS/LAPACK
+@use_np
+@pytest.mark.skip(reason='dot is known to not work on large tensors. PR to fix: https://github.com/apache/incubator-mxnet/pull/18925')
+def test_batch_dot():
+    assert False 
+
+@use_np
+def test_cast():
+    A = np.ones((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = npx.cast(A, dtype='float16')
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[0][0] == 1
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == 1
+
+@use_np
+def test_broadcast_like():
+    A = np.ones((1, 2))
+    B = np.zeros((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        C = npx.broadcast_like(A, B)
+    assert C.shape == (INT_OVERFLOW, 2)
+    assert C[0][0] == 1
+    C.backward()
+    assert A.grad.shape == (1, 2)
+    with mx.autograd.record():
+        C = npx.broadcast_like(A.reshape(2, 1), B.T)
+    assert C.shape == (2, INT_OVERFLOW)
+    assert C[0][0] == 1
+    C.backward()
+    assert A.grad.shape == (1, 2)
+    assert_almost_equal(A.grad[0][0], np.array([INT_OVERFLOW]), \
+                            rtol=1e-3, atol=1e-5)
+
+@use_np
+def test_constraint_check():
+    A = np.ones((2, INT_OVERFLOW))
+    constraint = (A > 0)
+    B = npx.constraint_check(constraint)
+    assert B.asnumpy() == True
+
+# broken
+@use_np
+@pytest.mark.skip(reason='Does not support large tensor; to be fixed')
+def test_batch_flatten():
+    A = np.ones((2, 1, INT_OVERFLOW))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = npx.batch_flatten(A)
+    assert B.shape == (2, INT_OVERFLOW)
+    assert B[0][0] == 1
+    B.backward()
+    assert A.grad.shape == (2, 1, INT_OVERFLOW)
+    assert A.grad[0][0][0] == 1
+
+# broken
+@use_np
+@pytest.mark.skip(reason='Does not support large tensor; to be fixed')
+def test_batch_norm():
+    A = np.ones((2, INT_OVERFLOW))
+    gamma = np.ones((2))
+    beta = np.zeros((2))
+    mov_mean = np.ones((2))
+    mov_var = np.ones((2))
+    A.attach_grad() 
+    with mx.autograd.record():
+        B = npx.batch_norm(A, gamma, beta, mov_mean, mov_var)
+    assert B.shape == (2, INT_OVERFLOW)
+    assert B[0][0] == 0
+    B.backward()
+    assert A.grad.shape == (2, INT_OVERFLOW)
+    assert A.grad[0][0] == 0
+
+@use_np
+@pytest.mark.skip(reason='segfault on (2, large)')
+def test_nonzero():
+    A = np.zeros((2, INT_OVERFLOW))
+    A[0][0] = 1
+    A.attach_grad()
+    with mx.autograd.record():
+        B = npx.nonzero(A)
+    assert B.shape == (1, 2)
+    assert B[0][0] == 0
+    B.backward()
+    assert A.grad.shape == (2, INT_OVERFLOW)
+    assert A.grad[0][0] == 0
+
+@use_np
+def test_one_hot():
+    A = np.zeros((INT_OVERFLOW))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = npx.one_hot(A, 2)
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[0][0] == 1
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, )
+    assert A.grad[0] == 0
+
+@use_np
+@pytest.mark.skip(reason='backward value broken on large tensor')
+def test_pick():
+    A = np.zeros((INT_OVERFLOW, 2))
+    B = np.zeros((INT_OVERFLOW))
+    A.attach_grad()
+    B.attach_grad()
+    with mx.autograd.record():
+        C = npx.pick(A, B)
+    assert C.shape == (INT_OVERFLOW, )
+    assert C[0] == 0
+    C.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert B.grad.shape == (INT_OVERFLOW, )
+    assert A.grad[0][0] == 1
+
+@use_np
+def test_scalar_poisson():
+    A = npx.scalar_poisson(lam=4, shape=(2, INT_OVERFLOW))
+    assert A.shape == (2, INT_OVERFLOW)
+
+@use_np
+def test_tensor_poisson():
+    lam = np.array([2.0, 4.0])
+    A = npx.tensor_poisson(lam, shape=(INT_OVERFLOW))
+    assert A.shape == (2, INT_OVERFLOW)
+
+@use_np
+def test_reshape():
+    A = np.ones((INT_OVERFLOW, 2))
+    A.attach_grad() 
+    with mx.autograd.record():
+        B = npx.reshape(A, (-5))
+    assert B.shape == (DOUBLE_INT_OVERFLOW, )
+    assert B[0] == 1
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == 1
+
+@use_np
+def test_reshape_like():
+    A = np.ones((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = npx.reshape_like(A, np.zeros((2, INT_OVERFLOW)))
+    assert B.shape == (2, INT_OVERFLOW)
+    assert B[0][0] == 1
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == 1
+
+@use_np
+def test_sigmoid():
+    A = np.zeros((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = npx.sigmoid(A)
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[0][0] == 0.5
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert_almost_equal(A.grad[0][0], np.array([0.25]), \
+                rtol=1e-3, atol=1e-5)
+
+@use_np
+@pytest.mark.skip(reason='Does not support large tensor; to be fixed')
+def test_shape_array():
+    A = np.zeros((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = npx.shape_array(A)
+    assert B[0] == INT_OVERFLOW and B[1] == 2
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, 2)
+    assert A.grad[0][0] == 0
+
+@use_np
+def test_stop_gradient():
+    A = np.ones((INT_OVERFLOW, 2))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = npx.stop_gradient(A * 3)
+    assert B.shape == (INT_OVERFLOW, 2)
+    assert B[0][0] == 3
+    B.backward()
+    # should be 3 if not for stop_gradient()
+    assert A.grad[0][0] == 0
+    
+@use_np
+def test_sequence_mask():
+    A = np.ones((2, 2, INT_OVERFLOW))
+    A.attach_grad()
+    with mx.autograd.record():
+        B = npx.sequence_mask(A, sequence_length=np.array([1,1]), \
+                use_sequence_length=True)
+    assert B.shape == (2, 2, INT_OVERFLOW)
+    assert B[0][0][0] == 1
+    assert B[1][0][0] == 0
+    B.backward()
+    assert A.grad.shape == (2, 2, INT_OVERFLOW)
+    assert A.grad[0][0][0] == 1
+
+@use_np
+def test_topk():
+    A = np.ones((2, INT_OVERFLOW))
+    A[0][100] = 2
+    A[1][200] = 2
+    A.attach_grad()
+    with mx.autograd.record():
+        B = npx.topk(A, k=2)
+    assert B.shape == (2, 2)
+    assert B[0][0] == 100 and B[1][0] == 200
+    B.backward()
+    assert A.grad.shape == (2, INT_OVERFLOW)
+    assert A.grad[0][0] == 0
+
+@use_np
+def test_slice():
+    A = np.ones((INT_OVERFLOW, 3))
+    A[100][1] = 2
+    B = npx.slice(A, begin=(100,1), end=(200,3))
+    assert B.shape == (100, 2)
+    assert B[0][0] == 2
+
+def test_smooth_l1():
+    A = np.arange((INT_OVERFLOW))
+    A.attach_grad() 
+    with mx.autograd.record():
+        B = npx.smooth_l1(A)
+    assert B.shape == (INT_OVERFLOW, )
+    assert B[1] == 0.5
+    B.backward()
+    assert A.grad.shape == (INT_OVERFLOW, )
+    assert A.grad[0] == 0