You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text rendering.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2020/11/06 10:29:49 UTC

[incubator-mxnet] branch master updated: Add test case for oneDNN RNN (#19465)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f3c69c1  Add test case for oneDNN RNN (#19465)
f3c69c1 is described below

commit f3c69c1adfbd866681976248b08823fa4c1e26ed
Author: bgawrych <ba...@intel.com>
AuthorDate: Fri Nov 6 11:27:53 2020 +0100

    Add test case for oneDNN RNN (#19465)
---
 tests/python/mkl/test_mkldnn.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index e94001b..41a0c5f 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -29,6 +29,7 @@ from mxnet.gluon import nn
 from mxnet.test_utils import *
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.append(os.path.join(curr_path, '../unittest/'))
+import itertools
 
 @pytest.mark.seed(1234)
 def test_mkldnn_ndarray_slice():
@@ -649,3 +650,30 @@ def test_elemwise_add():
     for stype in stypes:
         check_elemwise_add_training(stype)
 
+def test_rnn():
+    """Smoke-test npx.rnn forward/backward (oneDNN path) for rnn_relu, rnn_tanh and gru."""
+    SEQ_LENGTH = [2**10, 2**5]
+    STATE_SIZE = [1, 2]
+    BATCH_SIZE = [4]
+    INPUT_SIZE = [4]
+    def batch_check(seq_length, state_size, batch_size, input_size):
+        # Flat params per gate, 1 layer: i2h + h2h weights plus two bias vectors.
+        gate_params = (input_size + state_size + 2) * state_size
+        for m, num_gates in (('rnn_relu', 1), ('rnn_tanh', 1), ('gru', 3)):
+            p = mx.np.random.normal(0, 1, (gate_params * num_gates,))
+            data = mx.np.random.normal(0, 1, (seq_length, batch_size, input_size))
+            state = mx.np.random.normal(0, 1, (1, batch_size, state_size))
+            data.attach_grad()
+            state.attach_grad()
+            with mx.autograd.record():
+                y = mx.npx.rnn(data=data, parameters=p, mode=m,
+                               state=state, state_size=state_size, num_layers=1)
+            assert y.shape == (seq_length, batch_size, state_size)
+            assert type(y[0]).__name__ == 'ndarray'
+            y.backward()
+            # Check the gradients themselves, not the (unchanged) input tensors.
+            assert data.grad.shape == (seq_length, batch_size, input_size)
+            assert state.grad.shape == (1, batch_size, state_size)
+
+    for sl, ss, bs, in_s in itertools.product(SEQ_LENGTH, STATE_SIZE, BATCH_SIZE, INPUT_SIZE):
+        batch_check(sl, ss, bs, in_s)
\ No newline at end of file