You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/03/30 03:24:31 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] ONNX export support for GRU (#20060)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 0a3719a  [v1.x] ONNX export support for GRU (#20060)
0a3719a is described below

commit 0a3719a706bdf6a93316ad1e6e7326c86a73138e
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Mon Mar 29 20:22:31 2021 -0700

    [v1.x] ONNX export support for GRU (#20060)
    
    * export gru
    
    * fix sanity
    
    * reduce change state_size
    
    Co-authored-by: Wei Chu <we...@amazon.com>
    Co-authored-by: Zhaoqi Zhu <zh...@gmail.com>
---
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 350 ++++++++++++++-------
 tests/python-pytest/onnx/test_operators.py         |  31 +-
 2 files changed, 251 insertions(+), 130 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 22e6282..eb91a3c 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -4240,10 +4240,6 @@ def convert_RNN(node, **kwargs):
 
     name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    mode = str(attrs.get('mode'))
-    if mode != 'lstm':
-        raise NotImplementedError('Currently RNN onnx export only supports lstm mode')
-
     bidirectional = str(attrs.get('bidirectional', 'False'))
     if bidirectional != 'False':
         raise NotImplementedError('Currently RNN onnx export only supports bidirectional is False')
@@ -4270,128 +4266,246 @@ def convert_RNN(node, **kwargs):
     data = input_nodes[0]
     param = input_nodes[1]
     initial_h = input_nodes[2]
-    initial_c = input_nodes[3]
 
     nodes = []
 
-    if num_layers == 2:
-        create_tensor([0], name+'_0', kwargs['initializer'])
-        create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
-        create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
+    mode = str(attrs.get('mode'))
+    if mode == 'lstm':
+        initial_c = input_nodes[3]
+        if num_layers == 2:
+            create_tensor([0], name+'_0', kwargs['initializer'])
+            create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
+            create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
+            create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
+            create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+            create_tensor([4*4*state_size*state_size], name+'_WR_offset', kwargs['initializer'])
 
-        create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
-        create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+            nodes += [
+                make_node('Shape', [data], [name+'_data_shape']),
+                make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+
+                # Layer 0
+                # get W
+                make_node('Slice', [param, name+'_0', name+'_4*state_size^2'], [name+'_W0_1d']),
+                make_node('Split', [name+'_W0_1d'], [name+'_W00', name+'_W01', name+'_W02', name+'_W03']),
+                make_node('Concat', [name+'_W00', name+'_W03', name+'_W01', name+'_W02'], [name+'_W0_'], axis=0),
+                make_node('Reshape', [name+'_W0_', name+'_WR_shape'], [name+'_W0']),
+                # get R
+                make_node('Add', [name+'_4*state_size^2', name+'_4*state_size^2'], [name+'_R0_offset']),
+                make_node('Slice', [param, name+'_4*state_size^2', name+'_R0_offset'], [name+'_R0_1d']),
+                make_node('Split', [name+'_R0_1d'], [name+'_R00', name+'_R01', name+'_R02', name+'_R03']),
+                make_node('Concat', [name+'_R00', name+'_R03', name+'_R01', name+'_R02'], [name+'_R0_'], axis=0),
+                make_node('Reshape', [name+'_R0_', name+'_WR_shape'], [name+'_R0']),
+                # get B
+                make_node('Add', [name+'_WR_offset', name+'_8*state_size'], [name+'_B0_offset']),
+                make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']),
+                make_node('Split', [name+'_B0_1d'], [name+'_B00', name+'_B01', name+'_B02', name+'_B03',
+                                                     name+'_B04', name+'_B05', name+'_B06', name+'_B07']),
+                make_node('Concat', [name+'_B00', name+'_B03', name+'_B01', name+'_B02',
+                                     name+'_B04', name+'_B07', name+'_B05', name+'_B06'], [name+'_B0_'], axis=0),
+                make_node('Reshape', [name+'_B0_', name+'_B_shape'], [name+'_B0']),
+                # get initial states
+                make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0),
+                make_node('Split', [initial_c], [name+'_initial_c0', name+'_initial_c1'], axis=0),
+                # get seq_len
+                make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+                make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
+                # Layer 0 LSTM
+                make_node('LSTM', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len',
+                                   name+'_initial_h0', name+'_initial_c0'],
+                          [name+'_lstm0_out_', name+'_lstm0_h', name+'_lstm0_c'], hidden_size=state_size),
+                make_node('Squeeze', [name+'_lstm0_out_'], [name+'_lstm0_out'], axes=[1]),
+
+                # Layer 1
+                # get W
+                make_node('Add', [name+'_R0_offset', name+'_4*state_size^2'], [name+'_W1_offset']),
+                make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']),
+                make_node('Split', [name+'_W1_1d'], [name+'_W10', name+'_W11', name+'_W12', name+'_W13']),
+                make_node('Concat', [name+'_W10', name+'_W13', name+'_W11', name+'_W12'], [name+'_W1_'], axis=0),
+                make_node('Reshape', [name+'_W1_', name+'_WR_shape'], [name+'_W1']),
+                # get R
+                make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']),
+                make_node('Split', [name+'_R1_1d'], [name+'_R10', name+'_R11', name+'_R12', name+'_R13']),
+                make_node('Concat', [name+'_R10', name+'_R13', name+'_R11', name+'_R12'], [name+'_R1_'], axis=0),
+                make_node('Reshape', [name+'_R1_', name+'_WR_shape'], [name+'_R1']),
+                # get B
+                make_node('Add', [name+'_B0_offset', name+'_8*state_size'], [name+'_B1_offset']),
+                make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']),
+                make_node('Split', [name+'_B1_1d'], [name+'_B10', name+'_B11', name+'_B12', name+'_B13',
+                                                     name+'_B14', name+'_B15', name+'_B16', name+'_B17']),
+                make_node('Concat', [name+'_B10', name+'_B13', name+'_B11', name+'_B12',
+                                     name+'_B14', name+'_B17', name+'_B15', name+'_B16'], [name+'_B1_'], axis=0),
+                make_node('Reshape', [name+'_B1_', name+'_B_shape'], [name+'_B1']),
+                # Layer 1 LSTM
+                make_node('LSTM', [name+'_lstm0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len',
+                                   name+'_initial_h1', name+'_initial_c1'],
+                          [name+'_lstm1_out_', name+'_lstm1_h', name+'_lstm1_c'], hidden_size=state_size),
+                make_node('Squeeze', [name+'_lstm1_out_'], [name], axes=[1]),
+                make_node('Concat', [name+'_lstm0_h', name+'_lstm1_h'], [name+'1'], axis=0),
+                make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0),
+            ]
+        elif num_layers == 1:
+            create_tensor([0], name+'_0', kwargs['initializer'])
+            create_tensor([1], name+'_1', kwargs['initializer'])
+            create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
+            create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
+            create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
+            create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer'])
+            create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
 
-        create_tensor([4*4*state_size*state_size], name+'_WR_offset', kwargs['initializer'])
+            nodes += [
+                make_node('Shape', [data], [name+'_data_shape']),
+                make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+                # get W
+                make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']),
+                make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
+                make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']),
+                make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0),
+                make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
+                make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
+                # get R
+                make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']),
+                make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
+                make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']),
+                make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0),
+                make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
+                # get B
+                make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']),
+                make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
+                make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3',
+                                                    name+'_B4', name+'_B5', name+'_B6', name+'_B7']),
+                make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2',
+                                     name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0),
+                make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
+                # get seq_len
+                make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+                make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
+                # compute LSTM
+                make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c],
+                          [name+'0_', name+'1', name+'2'], hidden_size=state_size),
+                make_node('Squeeze', [name+'0_'], [name], axes=[1]),
+            ]
+        else:
+            raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
 
-        nodes += [
-            make_node('Shape', [data], [name+'_data_shape']),
-            make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
-
-            # Layer 0
-            # get W
-            make_node('Slice', [param, name+'_0', name+'_4*state_size^2'], [name+'_W0_1d']),
-            make_node('Split', [name+'_W0_1d'], [name+'_W00', name+'_W01', name+'_W02', name+'_W03']),
-            make_node('Concat', [name+'_W00', name+'_W03', name+'_W01', name+'_W02'], [name+'_W0_'], axis=0),
-            make_node('Reshape', [name+'_W0_', name+'_WR_shape'], [name+'_W0']),
-            # get R
-            make_node('Add', [name+'_4*state_size^2', name+'_4*state_size^2'], [name+'_R0_offset']),
-            make_node('Slice', [param, name+'_4*state_size^2', name+'_R0_offset'], [name+'_R0_1d']),
-            make_node('Split', [name+'_R0_1d'], [name+'_R00', name+'_R01', name+'_R02', name+'_R03']),
-            make_node('Concat', [name+'_R00', name+'_R03', name+'_R01', name+'_R02'], [name+'_R0_'], axis=0),
-            make_node('Reshape', [name+'_R0_', name+'_WR_shape'], [name+'_R0']),
-            # get B
-            make_node('Add', [name+'_WR_offset', name+'_8*state_size'], [name+'_B0_offset']),
-            make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']),
-            make_node('Split', [name+'_B0_1d'], [name+'_B00', name+'_B01', name+'_B02', name+'_B03',
-                                                 name+'_B04', name+'_B05', name+'_B06', name+'_B07']),
-            make_node('Concat', [name+'_B00', name+'_B03', name+'_B01', name+'_B02',
-                                 name+'_B04', name+'_B07', name+'_B05', name+'_B06'], [name+'_B0_'], axis=0),
-            make_node('Reshape', [name+'_B0_', name+'_B_shape'], [name+'_B0']),
-            # get initial states
-            make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0),
-            make_node('Split', [initial_c], [name+'_initial_c0', name+'_initial_c1'], axis=0),
-            # get seq_len
-            make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
-            make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
-            # Layer 0 LSTM
-            make_node('LSTM', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len',
-                               name+'_initial_h0', name+'_initial_c0'],
-                      [name+'_lstm0_out_', name+'_lstm0_h', name+'_lstm0_c'], hidden_size=state_size),
-            make_node('Squeeze', [name+'_lstm0_out_'], [name+'_lstm0_out'], axes=[1]),
-
-            # Layer 1
-            # get W
-            make_node('Add', [name+'_R0_offset', name+'_4*state_size^2'], [name+'_W1_offset']),
-            make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']),
-            make_node('Split', [name+'_W1_1d'], [name+'_W10', name+'_W11', name+'_W12', name+'_W13']),
-            make_node('Concat', [name+'_W10', name+'_W13', name+'_W11', name+'_W12'], [name+'_W1_'], axis=0),
-            make_node('Reshape', [name+'_W1_', name+'_WR_shape'], [name+'_W1']),
-            # get R
-            make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']),
-            make_node('Split', [name+'_R1_1d'], [name+'_R10', name+'_R11', name+'_R12', name+'_R13']),
-            make_node('Concat', [name+'_R10', name+'_R13', name+'_R11', name+'_R12'], [name+'_R1_'], axis=0),
-            make_node('Reshape', [name+'_R1_', name+'_WR_shape'], [name+'_R1']),
-            # get B
-            make_node('Add', [name+'_B0_offset', name+'_8*state_size'], [name+'_B1_offset']),
-            make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']),
-            make_node('Split', [name+'_B1_1d'], [name+'_B10', name+'_B11', name+'_B12', name+'_B13',
-                                                 name+'_B14', name+'_B15', name+'_B16', name+'_B17']),
-            make_node('Concat', [name+'_B10', name+'_B13', name+'_B11', name+'_B12',
-                                 name+'_B14', name+'_B17', name+'_B15', name+'_B16'], [name+'_B1_'], axis=0),
-            make_node('Reshape', [name+'_B1_', name+'_B_shape'], [name+'_B1']),
-            # Layer 1 LSTM
-            make_node('LSTM', [name+'_lstm0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len',
-                               name+'_initial_h1', name+'_initial_c1'],
-                      [name+'_lstm1_out_', name+'_lstm1_h', name+'_lstm1_c'], hidden_size=state_size),
-            make_node('Squeeze', [name+'_lstm1_out_'], [name], axes=[1]),
-            make_node('Concat', [name+'_lstm0_h', name+'_lstm1_h'], [name+'1'], axis=0),
-            make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0),
-        ]
-    elif num_layers == 1:
-        create_tensor([0], name+'_0', kwargs['initializer'])
-        create_tensor([1], name+'_1', kwargs['initializer'])
-        create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
-        create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
-        create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
-        create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer'])
-        create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+    elif mode == 'gru':
+        if num_layers == 2:
+            create_tensor([0], name+'_0', kwargs['initializer'])
+            create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
+            create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer'])
+            create_tensor([1, 3*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
+            create_tensor([1, 6*state_size], name+'_B_shape', kwargs['initializer'])
+            create_tensor([4*3*state_size*state_size], name+'_WR_offset', kwargs['initializer'])
 
-        nodes += [
-            make_node('Shape', [data], [name+'_data_shape']),
-            make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
-            # get W
-            make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']),
-            make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
-            make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']),
-            make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0),
-            make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
-            make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
-            # get R
-            make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']),
-            make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
-            make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']),
-            make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0),
-            make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
-            # get B
-            make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']),
-            make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
-            make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3',
-                                                name+'_B4', name+'_B5', name+'_B6', name+'_B7']),
-            make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2',
-                                 name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0),
-            make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
-            # get seq_len
-            make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
-            make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
-            # compute LSTM
-            make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c],
-                      [name+'0_', name+'1', name+'2'], hidden_size=state_size),
-            make_node('Squeeze', [name+'0_'], [name], axes=[1]),
-        ]
-    else:
-        raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
+            nodes += [
+                make_node('Shape', [data], [name+'_data_shape']),
+                make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+
+                # Layer 0
+                # get W
+                make_node('Slice', [param, name+'_0', name+'_3*state_size^2'], [name+'_W0_1d']),
+                make_node('Split', [name+'_W0_1d'], [name+'_W00', name+'_W01', name+'_W02']),
+                make_node('Concat', [name+'_W01', name+'_W00', name+'_W02'], [name+'_W0_'], axis=0),
+                make_node('Reshape', [name+'_W0_', name+'_WR_shape'], [name+'_W0']),
+                # get R
+                make_node('Add', [name+'_3*state_size^2', name+'_3*state_size^2'], [name+'_R0_offset']),
+                make_node('Slice', [param, name+'_3*state_size^2', name+'_R0_offset'], [name+'_R0_1d']),
+                make_node('Split', [name+'_R0_1d'], [name+'_R00', name+'_R01', name+'_R02']),
+                make_node('Concat', [name+'_R01', name+'_R00', name+'_R02'], [name+'_R0_'], axis=0),
+                make_node('Reshape', [name+'_R0_', name+'_WR_shape'], [name+'_R0']),
+                # get B
+                make_node('Add', [name+'_WR_offset', name+'_6*state_size'], [name+'_B0_offset']),
+                make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']),
+                make_node('Split', [name+'_B0_1d'], [name+'_B00', name+'_B01', name+'_B02',
+                                                     name+'_B03', name+'_B04', name+'_B05']),
+                make_node('Concat', [name+'_B01', name+'_B00', name+'_B02',
+                                     name+'_B04', name+'_B03', name+'_B05'], [name+'_B0_'], axis=0),
+                make_node('Reshape', [name+'_B0_', name+'_B_shape'], [name+'_B0']),
+                # get initial states
+                make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0),
+                # get seq_len
+                make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+                make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
+                # Layer 0 GRU
+                make_node('GRU', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len',
+                                  name+'_initial_h0'],
+                          [name+'_gru0_out_', name+'_gru0_h'], hidden_size=state_size, linear_before_reset=1),
+                make_node('Squeeze', [name+'_gru0_out_'], [name+'_gru0_out'], axes=[1]),
+
+                # Layer 1
+                # get W
+                make_node('Add', [name+'_R0_offset', name+'_3*state_size^2'], [name+'_W1_offset']),
+                make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']),
+                make_node('Split', [name+'_W1_1d'], [name+'_W10', name+'_W11', name+'_W12']),
+                make_node('Concat', [name+'_W11', name+'_W10', name+'_W12'], [name+'_W1_'], axis=0),
+                make_node('Reshape', [name+'_W1_', name+'_WR_shape'], [name+'_W1']),
+                # get R
+                make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']),
+                make_node('Split', [name+'_R1_1d'], [name+'_R10', name+'_R11', name+'_R12']),
+                make_node('Concat', [name+'_R11', name+'_R10', name+'_R12'], [name+'_R1_'], axis=0),
+                make_node('Reshape', [name+'_R1_', name+'_WR_shape'], [name+'_R1']),
+                # get B
+                make_node('Add', [name+'_B0_offset', name+'_6*state_size'], [name+'_B1_offset']),
+                make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']),
+                make_node('Split', [name+'_B1_1d'], [name+'_B10', name+'_B11', name+'_B12',
+                                                     name+'_B13', name+'_B14', name+'_B15']),
+                make_node('Concat', [name+'_B11', name+'_B10', name+'_B12',
+                                     name+'_B14', name+'_B13', name+'_B15'], [name+'_B1_'], axis=0),
+                make_node('Reshape', [name+'_B1_', name+'_B_shape'], [name+'_B1']),
+                # Layer 1 GRU
+                make_node('GRU', [name+'_gru0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len',
+                                  name+'_initial_h1'],
+                          [name+'_gru1_out_', name+'_gru1_h'], hidden_size=state_size, linear_before_reset=1),
+                make_node('Squeeze', [name+'_gru1_out_'], [name], axes=[1]),
+                make_node('Concat', [name+'_gru0_h', name+'_gru1_h'], [name+'1'], axis=0)
+            ]
 
+        elif num_layers == 1:
+            create_tensor([0], name+'_0', kwargs['initializer'])
+            create_tensor([1], name+'_1', kwargs['initializer'])
+            create_tensor([3*state_size], name+'_3*state_size', kwargs['initializer'])
+            create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
+            create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer'])
+            create_tensor([1, 3*state_size, state_size], name+'_R_shape', kwargs['initializer'])
+            create_tensor([1, 6*state_size], name+'_B_shape', kwargs['initializer'])
+
+            nodes += [
+                make_node('Shape', [data], [name+'_data_shape']),
+                make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+                # get W
+                make_node('Mul', [name+'_3*state_size', name+'_input_size'], [name+'_mul0']),
+                make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
+                make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2']),
+                make_node('Concat', [name+'_W1', name+'_W0', name+'_W2'], [name+'_W_'], axis=0),
+                make_node('Concat', [name+'_1', name+'_3*state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
+                make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
+                # get R
+                make_node('Add', [name+'_mul0', name+'_3*state_size^2'], [name+'_add0']),
+                make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
+                make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2']),
+                make_node('Concat', [name+'_R1', name+'_R0', name+'_R2'], [name+'_R_'], axis=0),
+                make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
+                # get B
+                make_node('Add', [name+'_add0', name+'_6*state_size'], [name+'_add1']),
+                make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
+                make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2',
+                                                    name+'_B3', name+'_B4', name+'_B5']),
+                make_node('Concat', [name+'_B1', name+'_B0', name+'_B2',
+                                     name+'_B4', name+'_B3', name+'_B5'], [name+'_B_'], axis=0),
+                make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
+                # get seq_len
+                make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+                make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
+                # compute GRU
+                make_node('GRU', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h],
+                          [name+'0_', name+'1'], hidden_size=state_size, linear_before_reset=1),
+                make_node('Squeeze', [name+'0_'], [name], axes=[1]),
+            ]
+        else:
+            raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1')
+
+    else:
+        raise NotImplementedError(f"Currently RNN onnx export does not support {mode} mode")
     return nodes
 
 @mx_op.register('_rnn_param_concat')
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 520ac40..51170e6 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -1221,27 +1221,34 @@ def test_onnx_export_sequence_reverse(tmp_path, dtype, params):
 
 
 # onnx LSTM from opset 11 does not support float64
+@pytest.mark.parametrize('mode', ['lstm', 'gru'])
 @pytest.mark.parametrize('dtype', ['float32'])
-@pytest.mark.parametrize('state_size', [32, 40])
-@pytest.mark.parametrize('input_size', [32, 40, 64])
+@pytest.mark.parametrize('state_size', [16, 32])
+@pytest.mark.parametrize('input_size', [16, 32, 64])
 @pytest.mark.parametrize('num_layers', [1, 2])
-@pytest.mark.parametrize('batch_size', [1, 3, 5])
+@pytest.mark.parametrize('batch_size', [1, 2, 4])
 @pytest.mark.parametrize('seq_length', [16, 32])
-def test_onnx_export_RNN(tmp_path, dtype, state_size, input_size, num_layers, batch_size, seq_length):
+def test_onnx_export_RNN(tmp_path, mode, dtype, state_size, input_size, num_layers, batch_size, seq_length):
     # TODO: The current implementation fails assertion checks for large parm/state_size. 
-    
+
     # for num_layers >= 2, input_size must equal to state_size
     if num_layers >= 2 and input_size != state_size:
         return
-    
-    M = def_model('RNN', mode='lstm', state_size=state_size, state_outputs=True,  num_layers=num_layers, p=0)
+    factor = 3
+    if mode == 'lstm':
+        factor = 4
+
+    M = def_model('RNN', mode=mode, state_size=state_size, state_outputs=True,  num_layers=num_layers, p=0)
     x = mx.nd.random.normal(0, 10, (seq_length, batch_size, input_size), dtype=dtype)
-    param = mx.nd.random.normal(0, 1, [num_layers*4*state_size*input_size +
-                                       num_layers*4*state_size*state_size +
-                                       num_layers*8*state_size], dtype=dtype)
+    param = mx.nd.random.normal(0, 1, [num_layers*factor*state_size*input_size +
+                                       num_layers*factor*state_size*state_size +
+                                       num_layers*2*factor*state_size], dtype=dtype)
     state = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
-    cell = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
-    op_export_test('rnn', M, [x, param, state, cell], tmp_path)
+    if mode == 'lstm':
+        cell = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
+        op_export_test('rnn', M, [x, param, state, cell], tmp_path)
+    else:
+        op_export_test('rnn', M, [x, param, state], tmp_path)
 
 
 @pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])