You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2020/07/27 17:09:23 UTC
[incubator-mxnet] branch v1.x updated: [v1.x] add large matrix
tests for linalg ops: det, inverse, trsm, trmm (#18744)
This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 566d9d3 [v1.x] add large matrix tests for linalg ops: det, inverse, trsm, trmm (#18744)
566d9d3 is described below
commit 566d9d348adeb88a2e7c159ae3a6a8ace7d91740
Author: Manu Seth <22...@users.noreply.github.com>
AuthorDate: Mon Jul 27 10:04:30 2020 -0700
[v1.x] add large matrix tests for linalg ops: det, inverse, trsm, trmm (#18744)
* add linalg large matrix tests
* add batch inputs linalg tests
* reducing bsize to 1 to save time
* move matrix generator to utils
* passing mat size as arg
* import util fn
* fix sanity
* add mx
* call backward
* merge fn
* update grad value
* refactor tests
* add mx
* add shape check
Co-authored-by: Ubuntu <ub...@ip-172-31-41-26.us-west-2.compute.internal>
---
python/mxnet/test_utils.py | 13 +++++
tests/nightly/test_large_array.py | 120 +++++++++++++++++++++++++++++++++++++-
2 files changed, 132 insertions(+), 1 deletion(-)
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 9a70f6e..f5d2979 100755
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -297,6 +297,19 @@ def create_vector(size, dtype=np.int64):
a = mx.nd.arange(0, size, dtype=dtype)
return a
+# For testing Large Square Matrix with total size > 2^32 elements
+def get_identity_mat(size):
+ A = mx.nd.zeros((size, size))
+ for i in range(size):
+ A[i, i] = 1
+ return A
+
+# For testing Batch of Large Square Matrix with total size > 2^32 elements
+def get_identity_mat_batch(size):
+ A = get_identity_mat(size)
+ A_np = A.asnumpy()
+ return mx.nd.array([A_np, A_np])
+
def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None,
data_init=None, rsp_indices=None, modifier_func=None,
shuffle_csr_indices=False, ctx=None):
diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index f2128ba..020a707 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -25,7 +25,7 @@ import mxnet as mx
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, '../python/unittest/'))
-from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor
+from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor, get_identity_mat, get_identity_mat_batch
from mxnet import gluon, nd
from common import with_seed, with_post_test_cleanup
from nose.tools import with_setup
@@ -1207,9 +1207,127 @@ def test_linalg():
assert A.grad[0,0,0] == 4
assert_almost_equal(A.grad[1,0,0], nd.array([0.4]), rtol=1e-3, atol=1e-5)
+ def check_det():
+ def run_det(inp):
+ inp.attach_grad()
+ with mx.autograd.record():
+ out = mx.nd.linalg.det(inp)
+ return inp.grad, out
+
+ A = get_identity_mat(LARGE_SQ_X)
+ grad, out = run_det(A)
+ assert(out.shape == (1,))
+ assert(out[0] == 1)
+ out.backward()
+ assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
+ assert(grad[0, 0] == 1)
+
+ def check_inverse():
+ def run_inverse(inp):
+ inp.attach_grad()
+ with mx.autograd.record():
+ out = mx.nd.linalg.inverse(inp)
+ return inp.grad, out
+
+ A = get_identity_mat(LARGE_SQ_X)
+ grad, out = run_inverse(A)
+ assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X))
+ assert(out[0, 0] == 1)
+ out.backward()
+ assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
+ assert(grad[0, 0] == -1)
+
+ def check_trmm():
+ def run_trmm(inp):
+ inp.attach_grad()
+ with mx.autograd.record():
+ out = mx.nd.linalg.trmm(inp, inp)
+ return inp.grad, out
+
+ A = get_identity_mat(LARGE_SQ_X)
+ grad, out = run_trmm(A)
+ assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X))
+ assert(out[0, 0] == 1)
+ out.backward()
+ assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
+ assert(grad[0, 0] == 2)
+
+ def check_trsm():
+ def run_trsm(inp):
+ inp.attach_grad()
+ with mx.autograd.record():
+ out = mx.nd.linalg.trsm(inp, inp)
+ return inp.grad, out
+
+ A = get_identity_mat(LARGE_SQ_X)
+ grad, out = run_trsm(A)
+ assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X))
+ assert(out[0, 0] == 1)
+ out.backward()
+ assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
+ assert(grad[0, 0] == 0)
+
+ def check_batch_inverse():
+ def run_inverse(inp):
+ inp.attach_grad()
+ with mx.autograd.record():
+ out = mx.nd.linalg.inverse(inp)
+ return inp.grad, out
+
+ B = get_identity_mat_batch(LARGE_SQ_X)
+ grad, out = run_inverse(B)
+ assert(out.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
+ assert(out[0, 0, 0] == 1)
+ assert(out[1, 0, 0] == 1)
+ out.backward()
+ assert(grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
+ assert(grad[0, 0, 0] == -1)
+ assert(grad[1, 0, 0] == -1)
+
+ def check_batch_trmm():
+ def run_trmm(inp):
+ inp.attach_grad()
+ with mx.autograd.record():
+ out = mx.nd.linalg.trmm(inp, inp)
+ return inp.grad, out
+
+ B = get_identity_mat_batch(LARGE_SQ_X)
+ grad, out = run_trmm(B)
+ assert(out.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
+ assert(out[0, 0, 0] == 1)
+ assert(out[1, 0, 0] == 1)
+ out.backward()
+ assert(grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
+ assert(grad[0, 0, 0] == 2)
+ assert(grad[1, 0, 0] == 2)
+
+ def check_batch_trsm():
+ def run_trsm(inp):
+ inp.attach_grad()
+ with mx.autograd.record():
+ out = mx.nd.linalg.trsm(inp, inp)
+ return inp.grad, out
+
+ B = get_identity_mat_batch(LARGE_SQ_X)
+ grad, out = run_trsm(B)
+ assert(out.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
+ assert(out[0, 0, 0] == 1)
+ assert(out[1, 0, 0] == 1)
+ out.backward()
+ assert(grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
+ assert(grad[0, 0, 0] == 0)
+ assert(grad[1, 0, 0] == 0)
+
check_potrf()
check_potri()
check_syrk_batch()
+ check_det()
+ check_inverse()
+ check_trmm()
+ check_trsm()
+ check_batch_inverse()
+ check_batch_trmm()
+ check_batch_trsm()
def test_basic():