You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2020/07/28 03:30:14 UTC
[incubator-mxnet] branch v1.x updated: [1.x][LT] Add forward, backward test for linalg.gemm2 (#18784)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new d009345 [1.x][LT] Add forward, backward test for linalg.gemm2 (#18784)
d009345 is described below
commit d0093458e3be5e76d78750043c4e5a3f01a7d056
Author: Chaitanya Prakash Bapat <ch...@gmail.com>
AuthorDate: Mon Jul 27 20:28:43 2020 -0700
[1.x][LT] Add forward, backward test for linalg.gemm2 (#18784)
* added forward, backward test for gemm2
* add backward check
* correct gradient assert
* move test inside linalg_ops
* add shape checks
---
tests/nightly/test_large_array.py | 20 ++++++++++++++++++++
1 file changed, 20 insertions(+)
diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index 020a707..306c827 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -1207,6 +1207,25 @@ def test_linalg():
assert A.grad[0,0,0] == 4
assert_almost_equal(A.grad[1,0,0], nd.array([0.4]), rtol=1e-3, atol=1e-5)
+    def check_gemm2():
+        # A[0, 0] = 0.1, so out[0, 0] = LARGE_X - 0.9 and the grad of B at [0, 0] = SMALL_Y - 0.9.
+        def run_gemm2(inp1, inp2):
+            inp1.attach_grad()
+            inp2.attach_grad()
+            with mx.autograd.record():
+                out = mx.nd.linalg.gemm2(inp1, inp2)
+            return inp1.grad, inp2.grad, out  # grad buffers are read after out.backward() below
+        inp1 = mx.nd.ones(shape=(SMALL_Y, LARGE_X))
+        inp1[0, 0] = 0.1
+        inp2 = mx.nd.ones(shape=(LARGE_X, SMALL_Y))
+        inp1_grad, inp2_grad, out = run_gemm2(inp1, inp2)
+        assert out.shape == (SMALL_Y, SMALL_Y)
+        # Index before comparing (no full asnumpy() of huge arrays); use a tolerance, not ==.
+        assert_almost_equal(out[0, 0], nd.array([LARGE_X - 0.9]), rtol=1e-3, atol=1e-5)
+        out.backward()
+        assert inp1_grad.shape == (SMALL_Y, LARGE_X) and inp2_grad.shape == (LARGE_X, SMALL_Y)
+        assert_almost_equal(inp2_grad[0, 0], nd.array([SMALL_Y - 0.9]), rtol=1e-3, atol=1e-5)
+
def check_det():
def run_det(inp):
inp.attach_grad()
@@ -1321,6 +1340,7 @@ def test_linalg():
check_potrf()
check_potri()
check_syrk_batch()
+ check_gemm2()
check_det()
check_inverse()
check_trmm()