You are viewing a plain text version of this content; the link to the canonical version was not preserved in this text export.
Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/08/19 10:11:59 UTC
[incubator-mxnet] branch master updated: Test_take, add additional axis (#20532)
This is an automated email from the ASF dual-hosted git repository.
akarbown pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9a4dcf4 Test_take, add additional axis (#20532)
9a4dcf4 is described below
commit 9a4dcf47535fd99ef1bd3e13edefc413bded8335
Author: mozga <ma...@intel.com>
AuthorDate: Thu Aug 19 12:10:25 2021 +0200
Test_take, add additional axis (#20532)
---
tests/python/unittest/test_operator.py | 72 +++++++++++++++++-----------------
1 file changed, 36 insertions(+), 36 deletions(-)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 0e07c37..cbae11e 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4188,47 +4188,47 @@ def test_take(mode, out_of_range, data_ndim, idx_ndim):
for _ in range(idx_ndim):
idx_shape += (np.random.randint(low=1, high=5), )
- data = mx.sym.Variable('a')
- idx = mx.sym.Variable('indices')
- idx = mx.sym.BlockGrad(idx)
- result = mx.sym.take(a=data, indices=idx, axis=axis, mode=mode)
- exe = result._simple_bind(default_context(), a=data_shape,
- indices=idx_shape)
- data_real = np.random.normal(size=data_shape).astype('float32')
- if out_of_range:
- idx_real = np.random.randint(low=-data_shape[axis], high=data_shape[axis], size=idx_shape)
- if mode == 'raise':
- idx_real[idx_real == 0] = 1
- idx_real *= data_shape[axis]
- else:
- idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
- if axis < 0:
- axis += len(data_shape)
+ data = mx.sym.Variable('a')
+ idx = mx.sym.Variable('indices')
+ idx = mx.sym.BlockGrad(idx)
+ result = mx.sym.take(a=data, indices=idx, axis=axis, mode=mode)
+ exe = result._simple_bind(default_context(), a=data_shape,
+ indices=idx_shape)
+ data_real = np.random.normal(size=data_shape).astype('float32')
+ if out_of_range:
+ idx_real = np.random.randint(low=-data_shape[axis], high=data_shape[axis], size=idx_shape)
+ if mode == 'raise':
+ idx_real[idx_real == 0] = 1
+ idx_real *= data_shape[axis]
+ else:
+ idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
+ if axis < 0:
+ axis += len(data_shape)
- grad_out = np.ones((data_shape[0:axis] if axis > 0 else ()) + idx_shape + (data_shape[axis+1:] if axis < len(data_shape) - 1 else ()), dtype='float32')
- grad_in = np.zeros(data_shape, dtype='float32')
+ grad_out = np.ones((data_shape[0:axis] if axis > 0 else ()) + idx_shape + (data_shape[axis+1:] if axis < len(data_shape) - 1 else ()), dtype='float32')
+ grad_in = np.zeros(data_shape, dtype='float32')
- exe.arg_dict['a'][:] = mx.nd.array(data_real)
- exe.arg_dict['indices'][:] = mx.nd.array(idx_real)
- exe.forward(is_train=True)
- if out_of_range and mode == 'raise':
- try:
- mx_out = exe.outputs[0].asnumpy()
- except MXNetError as e:
- return
- else:
- # Did not raise exception
- assert False, "did not raise %s" % MXNetError.__name__
+ exe.arg_dict['a'][:] = mx.nd.array(data_real)
+ exe.arg_dict['indices'][:] = mx.nd.array(idx_real)
+ exe.forward(is_train=True)
+ if out_of_range and mode == 'raise':
+ try:
+ mx_out = exe.outputs[0].asnumpy()
+ except MXNetError as e:
+ return
+ else:
+ # Did not raise exception
+ assert False, "did not raise %s" % MXNetError.__name__
- assert_almost_equal(exe.outputs[0], np.take(data_real, idx_real, axis=axis, mode=mode))
+ assert_almost_equal(exe.outputs[0], np.take(data_real, idx_real, axis=axis, mode=mode))
- for i in np.nditer(idx_real):
- if mode == 'clip':
- i = np.clip(i, 0, data_shape[axis])
- grad_helper(grad_in, axis, i)
+ for i in np.nditer(idx_real):
+ if mode == 'clip':
+ i = np.clip(i, 0, data_shape[axis])
+ grad_helper(grad_in, axis, i)
- exe.backward([mx.nd.array(grad_out)])
- assert_almost_equal(exe.grad_dict['a'], grad_in)
+ exe.backward([mx.nd.array(grad_out)])
+ assert_almost_equal(exe.grad_dict['a'], grad_in)
def test_grid_generator():