Posted to commits@mxnet.apache.org by re...@apache.org on 2019/09/25 05:29:51 UTC
[incubator-mxnet] branch numpy_prs updated: Numpy det and slogdet
operator (#15861)
This is an automated email from the ASF dual-hosted git repository.
reminisce pushed a commit to branch numpy_prs
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/numpy_prs by this push:
new a2a1b68 Numpy det and slogdet operator (#15861)
a2a1b68 is described below
commit a2a1b68c0ae2cdd0d8d3af4bb56a85ba173390e4
Author: ckt624 <ck...@gmail.com>
AuthorDate: Wed Sep 25 01:28:58 2019 -0400
Numpy det and slogdet operator (#15861)
* Add alias.
Add tests.
Add slogdet tests.
Add docs
Change shapes
Change tests.
Change slogdet tests
Change style.
* Fix
* Fix
---
python/mxnet/_numpy_op_doc.py | 122 +++++++++++++
src/operator/tensor/la_op.cc | 2 +
src/operator/tensor/la_op.cu | 2 +
src/operator/tensor/la_op.h | 7 +-
tests/python/unittest/test_numpy_op.py | 305 +++++++++++++++++++++------------
5 files changed, 332 insertions(+), 106 deletions(-)
diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index 6d2776e..f168d56 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -20,6 +20,128 @@
"""Doc placeholder for numpy ops with prefix _np."""
+def _np__linalg_det(a):
+ """
+ det(a)
+
+ Compute the determinant of an array.
+
+ Parameters
+ ----------
+ a : (..., M, M) ndarray
+ Input array to compute determinants for.
+
+ Returns
+ -------
+ det : (...) ndarray
+ Determinant of `a`.
+
+ See Also
+ --------
+ slogdet : Another way to represent the determinant, more suitable
+ for large matrices where underflow/overflow may occur.
+
+ Notes
+ -----
+
+ Broadcasting rules apply, see the `numpy.linalg` documentation for
+ details.
+
+ The determinant is computed via LU factorization using the LAPACK
+ routine z/dgetrf.
+
+ Examples
+ --------
+ The determinant of a 2-D array ``[[a, b], [c, d]]`` is ``ad - bc``:
+
+ >>> a = np.array([[1, 2], [3, 4]])
+ >>> np.linalg.det(a)
+ -2.0
+
+ Computing determinants for a stack of matrices:
+
+ >>> a = np.array([ [[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]] ])
+ >>> a.shape
+ (3, 2, 2)
+ >>> np.linalg.det(a)
+ array([-2., -3., -8.])
+ """
+ pass
+
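For reference: the Notes above say the determinant comes from an LU
factorization via the LAPACK getrf routine. A minimal NumPy/SciPy sketch of
that idea -- assuming SciPy is available, and purely illustrative rather than
MXNet's actual kernel -- recovers the determinant from the factorization's
diagonal and pivot signs:

    import numpy as np
    from scipy.linalg import lu_factor

    def det_via_lu(a):
        # PA = LU with partial pivoting, the same scheme LAPACK getrf uses
        lu, piv = lu_factor(a)
        # each index with piv[i] != i records one row interchange, and every
        # interchange flips the sign of the determinant
        sign = (-1.0) ** np.count_nonzero(piv != np.arange(a.shape[0]))
        return sign * np.prod(np.diag(lu))

    print(det_via_lu(np.array([[1., 2.], [3., 4.]])))  # -2.0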
+
+def _np__linalg_slogdet(a):
+ """
+ slogdet(a)
+
+ Compute the sign and (natural) logarithm of the determinant of an array.
+
+ If an array has a very small or very large determinant, then a call to
+ `det` may overflow or underflow. This routine is more robust against such
+ issues, because it computes the logarithm of the determinant rather than
+ the determinant itself.
+
+ Parameters
+ ----------
+ a : (..., M, M) ndarray
+ Input array; the last two dimensions must form square (M, M) matrices.
+
+ Returns
+ -------
+ sign : (...) ndarray
+ A number representing the sign of the determinant. For a real matrix,
+ this is 1, 0, or -1.
+ logdet : (...) ndarray
+ The natural log of the absolute value of the determinant.
+
+ If the determinant is zero, then `sign` will be 0 and `logdet` will be
+ -Inf. In all cases, the determinant is equal to ``sign * np.exp(logdet)``.
+
+ See Also
+ --------
+ det
+
+ Notes
+ -----
+
+ Broadcasting rules apply, see the `numpy.linalg` documentation for
+ details.
+
+ The determinant is computed via LU factorization using the LAPACK
+ routine z/dgetrf.
+
+ Examples
+ --------
+ The determinant of a 2-D array ``[[a, b], [c, d]]`` is ``ad - bc``:
+
+ >>> a = np.array([[1, 2], [3, 4]])
+ >>> (sign, logdet) = np.linalg.slogdet(a)
+ >>> (sign, logdet)
+ (-1., 0.69314718055994529)
+ >>> sign * np.exp(logdet)
+ -2.0
+
+ Computing log-determinants for a stack of matrices:
+
+ >>> a = np.array([ [[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]] ])
+ >>> a.shape
+ (3, 2, 2)
+ >>> sign, logdet = np.linalg.slogdet(a)
+ >>> (sign, logdet)
+ (array([-1., -1., -1.]), array([ 0.69314718, 1.09861229, 2.07944154]))
+ >>> sign * np.exp(logdet)
+ array([-2., -3., -8.])
+
+ This routine succeeds where ordinary `det` does not:
+
+ >>> np.linalg.det(np.eye(500) * 0.1)
+ 0.0
+ >>> np.linalg.slogdet(np.eye(500) * 0.1)
+ (1., -1151.2925464970228)
+ """
+ pass
+
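In the same spirit, the log-determinant is the sum of log-magnitudes of the
LU diagonal with the sign tracked separately, which is why slogdet stays
finite where det would underflow or overflow. Another illustrative sketch
under the same SciPy assumption:

    import numpy as np
    from scipy.linalg import lu_factor

    def slogdet_via_lu(a):
        lu, piv = lu_factor(a)
        d = np.diag(lu)
        # pivot interchanges and negative diagonal entries both flip the sign
        sign = (-1.0) ** np.count_nonzero(piv != np.arange(a.shape[0]))
        sign *= np.prod(np.sign(d))
        # summing logs avoids multiplying n tiny (or huge) factors together
        return sign, np.sum(np.log(np.abs(d)))

    print(slogdet_via_lu(np.eye(500) * 0.1))  # (1.0, -1151.2925...)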
+
def _np_ones_like(a):
"""
ones_like(a)
diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc
index ce7d1d5..24968ec 100644
--- a/src/operator/tensor/la_op.cc
+++ b/src/operator/tensor/la_op.cc
@@ -941,6 +941,7 @@ NNVM_REGISTER_OP(_backward_linalg_inverse)
NNVM_REGISTER_OP(_linalg_det)
.add_alias("linalg_det")
+.add_alias("_np__linalg_det")
.describe(R"code(Compute the determinant of a matrix.
Input is a tensor *A* of dimension *n >= 2*.
@@ -991,6 +992,7 @@ NNVM_REGISTER_OP(_backward_linalg_det)
.set_attr<FCompute>("FCompute<cpu>", LaOpDetBackward<cpu, 1, det_backward>);
NNVM_REGISTER_OP(_linalg_slogdet)
+.add_alias("_np__linalg_slogdet")
.add_alias("linalg_slogdet")
.describe(R"code(Compute the sign and log of the determinant of a matrix.
Input is a tensor *A* of dimension *n >= 2*.
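The two _np__linalg_* aliases are what expose these backend kernels under the
NumPy-style namespace; the `_np__linalg_` prefix is the naming convention that
routes an operator into `mx.np.linalg`, which is what lets the new tests call
np.linalg.det and np.linalg.slogdet. Assuming a build that includes this
patch, front-end usage looks like:

    from mxnet import np, npx

    npx.set_np()  # opt in to NumPy-compatible shape and array semantics

    a = np.array([[1., 2.], [3., 4.]])
    print(np.linalg.det(a))      # resolved through the _np__linalg_det alias
    print(np.linalg.slogdet(a))  # resolved through _np__linalg_slogdet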
diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu
index 68c3318..ec622f0 100644
--- a/src/operator/tensor/la_op.cu
+++ b/src/operator/tensor/la_op.cu
@@ -100,12 +100,14 @@ NNVM_REGISTER_OP(_backward_linalg_inverse)
.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 2, 1, inverse_backward>);
NNVM_REGISTER_OP(_linalg_det)
+.add_alias("_np__linalg_det")
.set_attr<FCompute>("FCompute<gpu>", LaOpDetForward<gpu, 1, det>);
NNVM_REGISTER_OP(_backward_linalg_det)
.set_attr<FCompute>("FCompute<gpu>", LaOpDetBackward<gpu, 1, det_backward>);
NNVM_REGISTER_OP(_linalg_slogdet)
+.add_alias("_np__linalg_slogdet")
.set_attr<FCompute>("FCompute<gpu>", LaOpDetForward<gpu, 2, slogdet>);
NNVM_REGISTER_OP(_backward_linalg_slogdet)
diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h
index e024693..bb56dc5 100644
--- a/src/operator/tensor/la_op.h
+++ b/src/operator/tensor/la_op.h
@@ -26,6 +26,7 @@
#define MXNET_OPERATOR_TENSOR_LA_OP_H_
#include <mxnet/operator_util.h>
+#include <mxnet/imperative.h>
#include <vector>
#include <algorithm>
#include "../mshadow_op.h"
@@ -428,7 +429,11 @@ inline bool DetShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in[ndim-2], in[ndim-1]) << "Input A's last two dimensions must be equal";
mxnet::TShape out;
if (ndim == 2) {
- out = mxnet::TShape(1, 1);
+ if (Imperative::Get()->is_np_shape()) {
+ out = mxnet::TShape(0, 1);
+ } else {
+ out = mxnet::TShape(1, 1);
+ }
} else {
out = mxnet::TShape(in.begin(), in.end() - 2);
}
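This DetShape change is what makes the output shape follow NumPy semantics:
with is_np_shape() enabled, a single M x M input produces a 0-dim scalar
ndarray (TShape(0, 1), i.e. ndim 0), whereas the legacy path keeps the old
shape-(1,) vector (TShape(1, 1)). A small sketch of the visible difference,
again assuming a patched build:

    from mxnet import np, npx

    npx.set_np()                 # is_np_shape() is now true
    d = np.linalg.det(np.eye(3))
    print(d.shape)               # () -- a 0-dim scalar, matching NumPy
    # without npx.set_np(), the legacy linalg_det operator returns shape (1,)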
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 264f7c0..9515f09 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -27,7 +27,7 @@ from mxnet.test_utils import check_numeric_gradient, use_np, collapse_sum_like
from common import assertRaises, with_seed
import random
import scipy.stats as ss
-from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry
+from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, collapse_sum_like
from mxnet.runtime import Features
import platform
@@ -140,8 +140,8 @@ def test_np_tensordot():
test_tensordot = TestTensordot(axes)
if hybridize:
test_tensordot.hybridize()
- a = rand_ndarray(shape = a_shape, dtype = dtype).as_np_ndarray()
- b = rand_ndarray(shape = b_shape, dtype = dtype).as_np_ndarray()
+ a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray()
+ b = rand_ndarray(shape=b_shape, dtype=dtype).as_np_ndarray()
a.attach_grad()
b.attach_grad()
@@ -166,7 +166,7 @@ def test_np_tensordot():
b_sym = mx.sym.Variable("b").as_np_ndarray()
mx_sym = mx.sym.np.tensordot(a_sym, b_sym, axes).as_nd_ndarray()
check_numeric_gradient(mx_sym, [a.as_nd_ndarray(), b.as_nd_ndarray()],
- rtol=1e-1, atol=1e-1, dtype = dtype)
+ rtol=1e-1, atol=1e-1, dtype=dtype)
@with_seed()
@@ -220,6 +220,96 @@ def test_np_dot():
@with_seed()
@use_np
+def test_np_linalg_det():
+ class TestDet(HybridBlock):
+ def __init__(self):
+ super(TestDet, self).__init__()
+
+ def hybrid_forward(self, F, a):
+ return F.np.linalg.det(a)
+
+ # test non-zero-size input
+ tensor_shapes = [
+ (5, 5),
+ (3, 3, 3),
+ (2, 2, 2, 2, 2),
+ (1, 1)
+ ]
+
+ for hybridize in [True, False]:
+ for shape in tensor_shapes:
+ for dtype in [_np.float32, _np.float64]:
+ a_shape = (1,) + shape
+ test_det = TestDet()
+ if hybridize:
+ test_det.hybridize()
+ a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray()
+ a.attach_grad()
+
+ np_out = _np.linalg.det(a.asnumpy())
+ with mx.autograd.record():
+ mx_out = test_det(a)
+ assert mx_out.shape == np_out.shape
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1)
+ mx_out.backward()
+
+ # Test imperative once again
+ mx_out = np.linalg.det(a)
+ np_out = _np.linalg.det(a.asnumpy())
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1, use_broadcast=False)
+
+ # test numeric gradient
+ a_sym = mx.sym.Variable("a").as_np_ndarray()
+ mx_sym = mx.sym.np.linalg.det(a_sym).as_nd_ndarray()
+ check_numeric_gradient(mx_sym, [a.as_nd_ndarray()],
+ rtol=1e-1, atol=1e-1, dtype=dtype)
+
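The check_numeric_gradient call above validates the operator's backward pass
against central finite differences. For det, the analytic gradient comes from
Jacobi's formula, d det(A)/dA = det(A) * inv(A)^T; a self-contained NumPy
check of the same comparison for a well-conditioned matrix:

    import numpy as _np

    A = _np.array([[1., 2.], [3., 5.]])
    analytic = _np.linalg.det(A) * _np.linalg.inv(A).T
    eps, numeric = 1e-6, _np.zeros_like(A)
    for i in range(2):
        for j in range(2):
            Ap, Am = A.copy(), A.copy()
            Ap[i, j] += eps
            Am[i, j] -= eps
            numeric[i, j] = (_np.linalg.det(Ap) - _np.linalg.det(Am)) / (2 * eps)
    print(_np.allclose(analytic, numeric))  # True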
+
+@with_seed()
+@use_np
+def test_np_linalg_slogdet():
+ class TestSlogdet(HybridBlock):
+ def __init__(self):
+ super(TestSlogdet, self).__init__()
+
+ def hybrid_forward(self, F, a):
+ return F.np.linalg.slogdet(a)
+
+ # test non-zero-size input
+ tensor_shapes = [
+ (5, 5),
+ (3, 3, 3),
+ (2, 2, 2, 2, 2),
+ (1, 1)
+ ]
+
+ for hybridize in [True, False]:
+ for a_shape in tensor_shapes:
+ for dtype in [_np.float32, _np.float64]:
+ test_slogdet = TestSlogdet()
+ if hybridize:
+ test_slogdet.hybridize()
+ a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray()
+ a.attach_grad()
+
+ np_out = _np.linalg.slogdet(a.asnumpy())
+ with mx.autograd.record():
+ mx_out = test_slogdet(a)
+ assert mx_out[0].shape == np_out[0].shape
+ assert mx_out[1].shape == np_out[1].shape
+ assert_almost_equal(mx_out[0].asnumpy(), np_out[0], rtol=1e-1, atol=1e-1, use_broadcast=False)
+ assert_almost_equal(mx_out[1].asnumpy(), np_out[1], rtol=1e-1, atol=1e-1, use_broadcast=False)
+ mx_out[1].backward()
+
+ # Test imperative once again
+ mx_out = np.linalg.slogdet(a)
+ np_out = _np.linalg.slogdet(a.asnumpy())
+ assert_almost_equal(mx_out[0].asnumpy(), np_out[0], rtol=1e-1, atol=1e-1, use_broadcast=False)
+ assert_almost_equal(mx_out[1].asnumpy(), np_out[1], rtol=1e-1, atol=1e-1, use_broadcast=False)
+
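Note that the test backpropagates through mx_out[1] only: sign is piecewise
constant, so its gradient is zero almost everywhere, while the logdet output
has gradient d log|det(A)|/dA = inv(A)^T. A quick NumPy check of that
identity:

    import numpy as _np

    A = _np.array([[2., 1.], [1., 3.]])
    analytic = _np.linalg.inv(A).T   # expected gradient of log|det(A)|
    eps, numeric = 1e-6, _np.zeros_like(A)
    for i in range(2):
        for j in range(2):
            Ap, Am = A.copy(), A.copy()
            Ap[i, j] += eps
            Am[i, j] -= eps
            numeric[i, j] = (_np.linalg.slogdet(Ap)[1]
                             - _np.linalg.slogdet(Am)[1]) / (2 * eps)
    print(_np.allclose(analytic, numeric))  # True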
+
+@with_seed()
+@use_np
def test_np_ldexp():
class TestLdexp(HybridBlock):
def __init__(self):
@@ -231,25 +321,24 @@ def test_np_ldexp():
def _np_ldexp(x1, x2):
return x1 * _np.power(2.0, x2)
- def dldx(x1, x2):
- grad_a = _np.power(2.0, x2)
- grad_b = _np_ldexp(x1, x2) * _np.log(2.0)
- if len(x1) == 1:
- grad_a = _np.sum(grad_a)
- if len(x2) == 1:
- grad_b = _np.sum(grad_b)
+ def dldx(x1, x2):
+ out = _np_ldexp(x1, x2)
+ grad_a = _np.broadcast_to(_np.power(2.0, x2), out.shape)
+ grad_b = _np.broadcast_to(out * _np.log(2.0), out.shape)
+ grad_a = collapse_sum_like(grad_a, x1.shape)
+ grad_b = collapse_sum_like(grad_b, x2.shape)
return [grad_a, grad_b]
- shapes = [
+ shapes = [
((3, 1), (3, 1)),
((3, 1, 2), (3, 1, 2)),
- ((1, ),(1, )),
+ ((1, ), (1, )),
((1, ), (2, )),
((3, ), (1, )),
((3, 0), (3, 0)), # zero-size shape
((0, 1), (0, 1)), # zero-size shape
((2, 0, 2), (2, 0, 2)), # zero-size shape
- ]
+ ]
for hybridize in [True, False]:
for shape1, shape2 in shapes:
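The rewritten dldx above follows directly from out = x1 * 2**x2: the partial
derivatives are dout/dx1 = 2**x2 and dout/dx2 = out * ln 2, each broadcast to
out's shape and then summed back down to the corresponding input's shape by
collapse_sum_like. A plain-NumPy stand-in for that reduction (a hypothetical
helper named _collapse_sum_like, shown only to illustrate what the test-utils
function does):

    import numpy as _np

    def _collapse_sum_like(grad, shape):
        # sum away leading broadcast axes, then re-sum any axis that was 1
        for _ in range(grad.ndim - len(shape)):
            grad = grad.sum(axis=0)
        for i, dim in enumerate(shape):
            if dim == 1:
                grad = grad.sum(axis=i, keepdims=True)
        return grad

    x1, x2 = _np.ones((3, 1)), _np.ones((1, 4))
    out = x1 * _np.power(2.0, x2)          # broadcasts to shape (3, 4)
    grad_a = _np.broadcast_to(_np.power(2.0, x2), out.shape)
    print(_collapse_sum_like(grad_a, x1.shape).shape)  # (3, 1)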
@@ -257,7 +346,7 @@ def test_np_ldexp():
test_ldexp = TestLdexp()
if hybridize:
test_ldexp.hybridize()
- x1 = rand_ndarray(shape=shape1, dtype=dtype).as_np_ndarray()
+ x1 = rand_ndarray(shape=shape1, dtype=dtype).as_np_ndarray()
x1.attach_grad()
x2 = rand_ndarray(shape=shape2, dtype=dtype).as_np_ndarray()
x2.attach_grad()
@@ -266,17 +355,17 @@ def test_np_ldexp():
with mx.autograd.record():
mx_out = test_ldexp(x1, x2)
assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1, use_broadcast=False)
mx_out.backward()
np_backward = dldx(x1.asnumpy(), x2.asnumpy())
- assert_almost_equal(x1.grad.asnumpy(), np_backward[0], atol=1e-1, rtol=1e-1)
- assert_almost_equal(x2.grad.asnumpy(), np_backward[1], atol=1e-1, rtol=1e-1)
+ assert_almost_equal(x1.grad.asnumpy(), np_backward[0], atol=1e-1, rtol=1e-1, use_broadcast=False)
+ assert_almost_equal(x2.grad.asnumpy(), np_backward[1], atol=1e-1, rtol=1e-1, use_broadcast=False)
# Test imperative once again
mx_out = np.ldexp(x1, x2)
np_out = _np_ldexp(x1.asnumpy(), x2.asnumpy())
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1, use_broadcast=False)
@with_seed()
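The use_broadcast=False switch added throughout this file tightens these
comparisons: as far as this flag's semantics go (an assumption about
mxnet.test_utils.assert_almost_equal), it disables NumPy-style broadcasting
between the two operands, so a shape mismatch fails the assertion instead of
being silently broadcast away. Plain NumPy shows why that matters:

    import numpy as _np

    a = _np.zeros((3, 1))
    b = _np.zeros((3,))
    # broadcasting lets mismatched shapes compare "equal" element-wise,
    # which can hide genuine output-shape bugs in an almost-equal check
    print((a == b).shape)        # (3, 3)
    print(bool((a == b).all()))  # True, despite a.shape != b.shape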
@@ -310,16 +399,16 @@ def test_np_vdot():
with mx.autograd.record():
mx_out = test_vdot(a, b)
assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol = 1e-3, atol = 1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out.backward()
np_backward = vdot_backward(a.asnumpy(), b.asnumpy())
- assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol = 1e-2, atol=1e-2)
- assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol = 1e-2, atol=1e-2)
+ assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol=1e-2, atol=1e-2, use_broadcast=False)
+ assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol=1e-2, atol=1e-2, use_broadcast=False)
# Test imperative once again
mx_out = np.vdot(a, b)
np_out = _np.vdot(a.asnumpy(), b.asnumpy())
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
# test numeric gradient
if len(shape) > 0 and _np.prod(shape) > 0:
@@ -407,16 +496,16 @@ def test_np_inner():
with mx.autograd.record():
mx_out = test_inner(a, b)
assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol = 1e-3, atol = 1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out.backward()
np_backward = inner_backward(a.asnumpy(), b.asnumpy())
- assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol = 1e-2, atol=1e-2)
- assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol = 1e-2, atol=1e-2)
+ assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol=1e-2, atol=1e-2, use_broadcast=False)
+ assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol=1e-2, atol=1e-2, use_broadcast=False)
# Test imperative once again
mx_out = np.inner(a, b)
np_out = _np.inner(a.asnumpy(), b.asnumpy())
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
# test numeric gradient
a_sym = mx.sym.Variable("a").as_np_ndarray()
@@ -458,13 +547,13 @@ def test_np_outer():
with mx.autograd.record():
mx_out = test_outer(a, b)
assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out.backward()
# Test imperative once again
mx_out = np.outer(a, b)
np_out = _np.outer(a.asnumpy(), b.asnumpy())
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
# test numeric gradient
a_sym = mx.sym.Variable("a").as_np_ndarray()
@@ -617,7 +706,7 @@ def test_np_max_min():
y = test_gluon(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if itype == 'float16' else 1e-3,
- atol=1e-5 if itype == 'float16' else 1e-5)
+ atol=1e-5 if itype == 'float16' else 1e-5, use_broadcast=False)
y.backward()
# only check the gradient with hardcoded input
if is_int(itype):
@@ -631,7 +720,7 @@ def test_np_max_min():
else:
mx_out = np.min(x, axis=axis, keepdims=keepdims)
np_out = _np.amin(x.asnumpy(), axis=axis, keepdims=keepdims)
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
# test zero and zero dim
shapes = [(), (0), (2, 0), (0, 2, 1)]
@@ -690,7 +779,7 @@ def test_np_mean():
y = test_mean(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
- atol=1e-5 if dtype == 'float16' else 1e-5)
+ atol=1e-5 if dtype == 'float16' else 1e-5, use_broadcast=False)
y.backward()
N = x.size / y.size
@@ -706,7 +795,7 @@ def test_np_mean():
# test imperative
mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
@with_seed()
@@ -798,16 +887,17 @@ def test_np_linspace():
mx_ret = np.linspace(config, endpoint=endpoint, retstep=retstep, dtype=dtype)
np_ret = _np.linspace(config, endpoint=endpoint, retstep=retstep, dtype=dtype)
if retstep:
- assert_almost_equal(mx_ret[0].asnumpy(), np_ret[0], atol=1e-3, rtol=1e-5)
+ assert_almost_equal(mx_ret[0].asnumpy(), np_ret[0], atol=1e-3, rtol=1e-5, use_broadcast=False)
same(mx_ret[1], np_ret[1])
else:
- assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-3, rtol=1e-5)
+ assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-3, rtol=1e-5, use_broadcast=False)
# check for exception input
for config in exception_configs:
assertRaises(MXNetError, np.linspace, *config)
# check linspace equivalent to arange
for test_index in range(1000):
- assert_almost_equal(mx.np.linspace(0, test_index, test_index + 1).asnumpy(), _np.arange(test_index + 1))
+ assert_almost_equal(mx.np.linspace(0, test_index, test_index + 1).asnumpy(), _np.arange(test_index + 1),
+ use_broadcast=False)
class TestLinspace(HybridBlock):
def __init__(self, start, stop, num=50, endpoint=None, retstep=False, dtype=None, axis=0):
@@ -840,7 +930,7 @@ def test_np_linspace():
if hybridize:
net.hybridize()
mx_out = net(x)
- assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-3, rtol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-3, rtol=1e-5, use_broadcast=False)
@with_seed()
@@ -1234,11 +1324,12 @@ def test_np_unary_funcs():
with mx.autograd.record():
y = mx_func(mx_test_data)
assert y.shape == np_out.shape
- assert_almost_equal(y.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(y.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
if ref_grad:
y.backward()
- assert_almost_equal(mx_test_data.grad.asnumpy(), ref_grad(np_test_data), rtol=1e-1, atol=1e-2, equal_nan=True)
+ assert_almost_equal(mx_test_data.grad.asnumpy(), ref_grad(np_test_data), rtol=1e-1, atol=1e-2,
+ equal_nan=True, use_broadcast=False)
funcs = {
'absolute' : (lambda x: -1. * (x < 0) + (x > 0), -1.0, 1.0),
@@ -1313,14 +1404,14 @@ def test_npx_relu():
with mx.autograd.record():
mx_out = test_relu(x)
assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out.backward()
np_backward = np_relu_grad(x.asnumpy())
- assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out = npx.relu(x)
np_out = np_relu(x.asnumpy())
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
@with_seed()
@@ -1350,14 +1441,14 @@ def test_npx_sigmoid():
with mx.autograd.record():
mx_out = test_sigmoid(x)
assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out.backward()
np_backward = np_sigmoid_grad(np_out)
- assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out = npx.sigmoid(x)
np_out = np_sigmoid(x.asnumpy())
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
@with_seed()
@@ -1463,17 +1554,17 @@ def test_np_split():
y = test_split(a)
assert len(y) == len(expected_ret)
for mx_out, np_out in zip(y, expected_ret):
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx.autograd.backward(y)
- assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
+ assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5, use_broadcast=False)
# test imperative
mx_outs = np.split(a, indices_or_sections=indices_or_sections, axis=axis)
np_outs = _np.split(a.asnumpy(), indices_or_sections=indices_or_sections, axis=axis)
for mx_out, np_out in zip(mx_outs, np_outs):
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
@with_seed()
@@ -1513,19 +1604,19 @@ def test_np_concat():
y = test_concat(a, b, c, d)
assert y.shape == expected_ret.shape
- assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5, use_broadcast=False)
y.backward()
- assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
- assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
- assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
- assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)
+ assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5, use_broadcast=False)
+ assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5, use_broadcast=False)
+ assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5, use_broadcast=False)
+ assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5, use_broadcast=False)
# test imperative
mx_out = np.concatenate([a, b, c, d], axis=axis)
np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
@with_seed()
@@ -1568,10 +1659,10 @@ def test_np_stack():
y.backward()
- assert_almost_equal(mx_a.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5)
- assert_almost_equal(mx_b.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5)
- assert_almost_equal(mx_c.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5)
- assert_almost_equal(mx_d.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_a.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5, use_broadcast=False)
+ assert_almost_equal(mx_b.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5, use_broadcast=False)
+ assert_almost_equal(mx_c.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5, use_broadcast=False)
+ assert_almost_equal(mx_d.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5, use_broadcast=False)
np_out = _np.stack([np_a, np_b, np_c, np_d], axis=axis)
mx_out = np.stack([mx_a, mx_b, mx_c, mx_d], axis=axis)
@@ -1601,14 +1692,14 @@ def test_np_ravel():
with mx.autograd.record():
mx_out = test_ravel(x)
assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out.backward()
np_backward = _np.ones(shape)
- assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out = np.ravel(x)
np_out = _np.ravel(x.asnumpy())
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
@with_seed()
@@ -1919,15 +2010,15 @@ def test_np_cumsum():
with mx.autograd.record():
mx_out = test_cumsum(x)
assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out.backward()
np_backward = np_cumsum_backward(_np.ones(np_out.shape, dtype=otype),
axis=axis, dtype=otype).reshape(x.shape)
- assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out = np.cumsum(x, axis=axis, dtype=otype)
np_out = _np.cumsum(x.asnumpy(), axis=axis, dtype=otype)
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
@with_seed()
@@ -1958,7 +2049,7 @@ def test_np_choice():
expected_density = (weight.asnumpy() if weight is not None else
_np.array([1 / num_classes] * num_classes))
# test almost equal
- assert_almost_equal(generated_density, expected_density, rtol=1e-1, atol=1e-1)
+ assert_almost_equal(generated_density, expected_density, rtol=1e-1, atol=1e-1, use_broadcast=False)
# test shape
assert (samples.shape == shape)
@@ -1975,7 +2066,7 @@ def test_np_choice():
out = sampler(num_classes, 1, replace=False, p=weight).item()
bins[out] += 1
bins /= num_trials
- assert_almost_equal(bins, expected_freq, rtol=1e-1, atol=1e-1)
+ assert_almost_equal(bins, expected_freq, rtol=1e-1, atol=1e-1, use_broadcast=False)
def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None):
a = np.arange(set_size)
@@ -2136,7 +2227,7 @@ def test_np_linalg_norm():
net.hybridize()
mx_ret = net(a)
np_ret = _np.linalg.norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims)
- assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4)
+ assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4, use_broadcast=False)
@with_seed()
@@ -2195,18 +2286,18 @@ def test_np_copysign():
with mx.autograd.record():
mx_out = test_copysign(a1, a2)
assert mx_out.shape == expected_np.shape
- assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol, use_broadcast=False)
# Test gradient
mx_out.backward()
a1_grad, a2_grad = get_grad(a1_np, a2_np)
- assert_almost_equal(a1.grad.asnumpy(), a1_grad, rtol=rtol, atol=atol)
- assert_almost_equal(a2.grad.asnumpy(), a2_grad, rtol=rtol, atol=atol)
+ assert_almost_equal(a1.grad.asnumpy(), a1_grad, rtol=rtol, atol=atol, use_broadcast=False)
+ assert_almost_equal(a2.grad.asnumpy(), a2_grad, rtol=rtol, atol=atol, use_broadcast=False)
# Test imperative once again
mx_out = np.copysign(a1, a2)
expected_np = _np.copysign(a1_np, a2_np)
- assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol, use_broadcast=False)
types = ['float16', 'float32', 'float64']
for x_shape in shapes:
@@ -2220,12 +2311,12 @@ def test_np_copysign():
with mx.autograd.record():
mx_out = np.copysign(x, scalar)
assert mx_out.shape == expected_np.shape
- assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol, use_broadcast=False)
# Test gradient
mx_out.backward()
x_grad = get_grad_left(x_np, scalar)
- assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol)
+ assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol, use_broadcast=False)
# Test right
x_np = _np.array(_np.random.uniform(-2.0, 2.0, x_shape), dtype=dtype)
@@ -2236,12 +2327,12 @@ def test_np_copysign():
with mx.autograd.record():
mx_out = np.copysign(scalar, x)
assert mx_out.shape == expected_np.shape
- assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol, use_broadcast=False)
# Test gradient
mx_out.backward()
x_grad = get_grad_right(scalar, x_np)
- assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol)
+ assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol, use_broadcast=False)
@with_seed()
@@ -2310,22 +2401,22 @@ def test_np_svd():
# check UT @ L @ V == A
t = _np.matmul(UT * L[..., None, :], V)
assert t.shape == data_np.shape
- assert_almost_equal(t, data_np, rtol=rtol, atol=atol)
+ assert_almost_equal(t, data_np, rtol=rtol, atol=atol, use_broadcast=False)
# check UT @ U == I
I = _np.matmul(UT, _np.swapaxes(UT, -2, -1))
I_np = _np.ones_like(UT) * _np.eye(shape[-2])
assert I.shape == I_np.shape
- assert_almost_equal(I, I_np, rtol=rtol, atol=atol)
+ assert_almost_equal(I, I_np, rtol=rtol, atol=atol, use_broadcast=False)
# check U @ UT == I
I = _np.matmul(_np.swapaxes(UT, -2, -1), UT)
I_np = _np.ones_like(UT) * _np.eye(shape[-2])
assert I.shape == I_np.shape
- assert_almost_equal(I, I_np, rtol=rtol, atol=atol)
+ assert_almost_equal(I, I_np, rtol=rtol, atol=atol, use_broadcast=False)
# check V @ VT == I
I = _np.matmul(V, _np.swapaxes(V, -2, -1))
I_np = _np.ones_like(UT) * _np.eye(shape[-2])
assert I.shape == I_np.shape
- assert_almost_equal(I, I_np, rtol=rtol, atol=atol)
+ assert_almost_equal(I, I_np, rtol=rtol, atol=atol, use_broadcast=False)
# check descending singular values
s = [L[..., i] - L[..., i + 1] for i in range(L.shape[-1] - 1)]
s = _np.array(s)
@@ -2336,7 +2427,8 @@ def test_np_svd():
mx.autograd.backward(ret)
if ((s > 1e-5).all() and (L.size == 0 or (L > 1e-5).all())):
backward_expected = get_grad(ret[0].asnumpy(), ret[1].asnumpy(), ret[2].asnumpy())
- assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
+ assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol,
+ use_broadcast=False)
@with_seed()
@@ -2381,18 +2473,18 @@ def test_np_vstack():
with mx.autograd.record():
mx_out = test_vstack(*v)
assert mx_out.shape == expected_np.shape
- assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol, use_broadcast=False)
# Test gradient
mx_out.backward()
for i in range(3):
expected_grad = g(v_np[i])
- assert_almost_equal(v[i].grad.asnumpy(), expected_grad, rtol=rtol, atol=atol)
+ assert_almost_equal(v[i].grad.asnumpy(), expected_grad, rtol=rtol, atol=atol, use_broadcast=False)
# Test imperative once again
mx_out = np.vstack(v)
expected_np = _np.vstack(v_np)
- assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol, use_broadcast=False)
@with_seed()
@@ -2508,15 +2600,18 @@ def test_np_trace():
with mx.autograd.record():
out_mx = test_trace(data.as_np_ndarray())
assert out_mx.shape == expected_np.shape
- assert_almost_equal(out_mx.asnumpy(), expected_np, rtol=rtol, atol=atol)
+ assert_almost_equal(out_mx.asnumpy(), expected_np, rtol=rtol, atol=atol,
+ use_broadcast=False)
out_mx.backward()
backward_expected = g(data_np, axis1=axis1, axis2=axis2, offset=offset)
- assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
+ assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol,
+ use_broadcast=False)
# Test imperative once again
data = mx.nd.array(data_np, dtype=dtype)
out_mx = np.trace(data.as_np_ndarray(), axis1=axis1, axis2=axis2, offset=offset)
- assert_almost_equal(out_mx.asnumpy(), expected_np, rtol=rtol, atol=atol)
+ assert_almost_equal(out_mx.asnumpy(), expected_np, rtol=rtol, atol=atol,
+ use_broadcast=False)
# bad params
params = [
@@ -2564,11 +2659,11 @@ def test_np_windows():
if hybridize:
mx_func.hybridize()
mx_out = mx_func(x)
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
# test imperative
mx_out = getattr(np, func)(M=config, dtype=dtype)
np_out = np_func(M=config).astype(dtype)
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
@with_seed()
@@ -2600,15 +2695,15 @@ def test_np_flip():
with mx.autograd.record():
mx_out = test_flip(x)
assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol, use_broadcast=False)
mx_out.backward()
np_backward = _np.ones(np_out.shape)
- assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol)
+ assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol, use_broadcast=False)
# Test imperative once again
mx_out = np.flip(x, axis)
np_out = _np.flip(x.asnumpy(), axis)
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol, use_broadcast=False)
@with_seed()
@@ -2636,11 +2731,11 @@ def test_np_around():
np_out = _np.around(x.asnumpy(), d)
mx_out = test_around(x)
assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol, use_broadcast=False)
mx_out = np.around(x, d)
np_out = _np.around(x.asnumpy(), d)
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol, use_broadcast=False)
@with_seed()
@@ -2696,18 +2791,18 @@ def test_np_arctan2():
with mx.autograd.record():
mx_out = test_arctan2(x1, x2)
assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol, use_broadcast=False)
mx_out.backward()
np_backward_1 = x21 / (x11 * x11 + x21 * x21)
np_backward_2 = -1 * x11 / (x11 * x11 + x21 * x21)
np_backward_1 = dimReduce(np_backward_1, x11)
np_backward_2 = dimReduce(np_backward_2, x21)
- assert_almost_equal(x1.grad.asnumpy(), np_backward_1, rtol=rtol, atol=atol)
- assert_almost_equal(x2.grad.asnumpy(), np_backward_2, rtol=rtol, atol=atol)
+ assert_almost_equal(x1.grad.asnumpy(), np_backward_1, rtol=rtol, atol=atol, use_broadcast=False)
+ assert_almost_equal(x2.grad.asnumpy(), np_backward_2, rtol=rtol, atol=atol, use_broadcast=False)
mx_out = np.arctan2(x1, x2)
np_out = _np.arctan2(x1.asnumpy(), x2.asnumpy())
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol, use_broadcast=False)
@with_seed()
@@ -2733,13 +2828,13 @@ def test_np_nonzero():
np_out = _np.transpose(np_out)
mx_out = test_nonzero(x)
assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol, use_broadcast=False)
# Test imperative once again
mx_out = npx.nonzero(x)
np_out = _np.nonzero(x.asnumpy())
np_out = _np.transpose(np_out)
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol, use_broadcast=False)
@with_seed()
@@ -2794,18 +2889,18 @@ def test_np_hypot():
with mx.autograd.record():
mx_out = test_hypot(x1, x2)
assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol, use_broadcast=False)
mx_out.backward()
np_backward_1 = x11 / np_out
np_backward_2 = x21 / np_out
np_backward_1 = dimReduce(np_backward_1, x11)
np_backward_2 = dimReduce(np_backward_2, x21)
- assert_almost_equal(x1.grad.asnumpy(), np_backward_1, rtol=rtol, atol=atol)
- assert_almost_equal(x2.grad.asnumpy(), np_backward_2, rtol=rtol, atol=atol)
+ assert_almost_equal(x1.grad.asnumpy(), np_backward_1, rtol=rtol, atol=atol, use_broadcast=False)
+ assert_almost_equal(x2.grad.asnumpy(), np_backward_2, rtol=rtol, atol=atol, use_broadcast=False)
mx_out = np.hypot(x1, x2)
np_out = _np.hypot(x1.asnumpy(), x2.asnumpy())
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol, use_broadcast=False)
@with_seed()
@@ -2847,13 +2942,13 @@ def test_np_unique():
mx_out = test_unique(x)
assert mx_out[0].shape == np_out[0].shape
for i in range(4):
- assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5, use_broadcast=False)
# Test imperative once again
mx_out = np.unique(x, *config[1:])
np_out = _np.unique(x.asnumpy(), *config[1:])
for i in range(4):
- assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5)
+ assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5, use_broadcast=False)
if __name__ == '__main__':