You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/08/19 10:11:59 UTC

[incubator-mxnet] branch master updated: Test_take, add additional axis (#20532)

This is an automated email from the ASF dual-hosted git repository.

akarbown pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a4dcf4  Test_take, add additional axis (#20532)
9a4dcf4 is described below

commit 9a4dcf47535fd99ef1bd3e13edefc413bded8335
Author: mozga <ma...@intel.com>
AuthorDate: Thu Aug 19 12:10:25 2021 +0200

    Test_take, add additional axis (#20532)
---
 tests/python/unittest/test_operator.py | 72 +++++++++++++++++-----------------
 1 file changed, 36 insertions(+), 36 deletions(-)

diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 0e07c37..cbae11e 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4188,47 +4188,47 @@ def test_take(mode, out_of_range, data_ndim, idx_ndim):
         for _ in range(idx_ndim):
             idx_shape += (np.random.randint(low=1, high=5), )
 
-    data = mx.sym.Variable('a')
-    idx = mx.sym.Variable('indices')
-    idx = mx.sym.BlockGrad(idx)
-    result = mx.sym.take(a=data, indices=idx, axis=axis, mode=mode)
-    exe = result._simple_bind(default_context(), a=data_shape,
-                             indices=idx_shape)
-    data_real = np.random.normal(size=data_shape).astype('float32')
-    if out_of_range:
-        idx_real = np.random.randint(low=-data_shape[axis], high=data_shape[axis], size=idx_shape)
-        if mode == 'raise':
-            idx_real[idx_real == 0] = 1
-            idx_real *= data_shape[axis]
-    else:
-        idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
-    if axis < 0:
-        axis += len(data_shape)
+        data = mx.sym.Variable('a')
+        idx = mx.sym.Variable('indices')
+        idx = mx.sym.BlockGrad(idx)
+        result = mx.sym.take(a=data, indices=idx, axis=axis, mode=mode)
+        exe = result._simple_bind(default_context(), a=data_shape,
+                                indices=idx_shape)
+        data_real = np.random.normal(size=data_shape).astype('float32')
+        if out_of_range:
+            idx_real = np.random.randint(low=-data_shape[axis], high=data_shape[axis], size=idx_shape)
+            if mode == 'raise':
+                idx_real[idx_real == 0] = 1
+                idx_real *= data_shape[axis]
+        else:
+            idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
+        if axis < 0:
+            axis += len(data_shape)
 
-    grad_out = np.ones((data_shape[0:axis] if axis > 0 else ()) + idx_shape + (data_shape[axis+1:] if axis < len(data_shape) - 1 else ()), dtype='float32')
-    grad_in = np.zeros(data_shape, dtype='float32')
+        grad_out = np.ones((data_shape[0:axis] if axis > 0 else ()) + idx_shape + (data_shape[axis+1:] if axis < len(data_shape) - 1 else ()), dtype='float32')
+        grad_in = np.zeros(data_shape, dtype='float32')
 
-    exe.arg_dict['a'][:] = mx.nd.array(data_real)
-    exe.arg_dict['indices'][:] = mx.nd.array(idx_real)
-    exe.forward(is_train=True)
-    if out_of_range and mode == 'raise':
-        try:
-            mx_out = exe.outputs[0].asnumpy()
-        except MXNetError as e:
-            return
-        else:
-            # Did not raise exception
-            assert False, "did not raise %s" % MXNetError.__name__
+        exe.arg_dict['a'][:] = mx.nd.array(data_real)
+        exe.arg_dict['indices'][:] = mx.nd.array(idx_real)
+        exe.forward(is_train=True)
+        if out_of_range and mode == 'raise':
+            try:
+                mx_out = exe.outputs[0].asnumpy()
+            except MXNetError as e:
+                return
+            else:
+                # Did not raise exception
+                assert False, "did not raise %s" % MXNetError.__name__
 
-    assert_almost_equal(exe.outputs[0], np.take(data_real, idx_real, axis=axis, mode=mode))
+        assert_almost_equal(exe.outputs[0], np.take(data_real, idx_real, axis=axis, mode=mode))
 
-    for i in np.nditer(idx_real):
-        if mode == 'clip':
-            i = np.clip(i, 0, data_shape[axis])
-        grad_helper(grad_in, axis, i)
+        for i in np.nditer(idx_real):
+            if mode == 'clip':
+                i = np.clip(i, 0, data_shape[axis])
+            grad_helper(grad_in, axis, i)
 
-    exe.backward([mx.nd.array(grad_out)])
-    assert_almost_equal(exe.grad_dict['a'], grad_in)
+        exe.backward([mx.nd.array(grad_out)])
+        assert_almost_equal(exe.grad_dict['a'], grad_in)
 
 
 def test_grid_generator():