You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/03/03 01:24:41 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] ONNX export support for RNN (#19958)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 98cb0b3  [v1.x] ONNX export support for RNN (#19958)
98cb0b3 is described below

commit 98cb0b31c7a3a85c2fefe7cad4d737b67f702cd7
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Tue Mar 2 17:23:22 2021 -0800

    [v1.x] ONNX export support for RNN (#19958)
    
    * convert RNN
    
    * use split
    
    * fix sanity
    
    * fix param
    
    * fix sanity
    
    * fix space
    
    * add note
    
    Co-authored-by: Wei Chu <we...@amazon.com>
---
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 87 ++++++++++++++++++++++
 python/mxnet/contrib/onnx/mx2onnx/export_onnx.py   |  6 +-
 tests/python-pytest/onnx/test_operators.py         | 14 ++++
 3 files changed, 106 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 2521cf5..7999576 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -4105,3 +4105,90 @@ def convert_sequence_reverse(node, **kwargs):
         ]
 
     return nodes
+
+
+@mx_op.register("RNN")
+def convert_RNN(node, **kwargs):
+    """Map MXNet's RNN operator (lstm mode only) to an ONNX LSTM node and return the created nodes.
+    The flat MXNet parameter input is sliced and reordered into ONNX W, R and B; unsupported attributes raise NotImplementedError.
+    """
+    from onnx.helper import make_node
+    from onnx import TensorProto
+
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    mode = str(attrs.get('mode'))
+    if mode != 'lstm':
+        raise NotImplementedError('Currently RNN onnx export only supports lstm mode')
+
+    bidirectional = str(attrs.get('bidirectional', 'False'))
+    if bidirectional != 'False':
+        raise NotImplementedError('Currently RNN onnx export only supports bidirectional is False')
+
+    num_layers = int(attrs.get('num_layers', '1'))
+    if num_layers != 1:
+        raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1')
+
+    p = float(attrs.get('p', '0'))
+    if p != 0:
+        raise NotImplementedError('Currently RNN onnx export only supports p equals to 0')
+
+    use_sequence_length = str(attrs.get('use_sequence_length', 'False'))
+    if use_sequence_length != 'False':
+        raise NotImplementedError('Currently RNN onnx export only supports use_sequence_length equals to False')
+
+    projection_size = str(attrs.get('projection_size', 'None'))
+    if projection_size != 'None':
+        raise NotImplementedError('Currently RNN onnx export only supports projection_size equals to None')
+
+    state_outputs = str(attrs.get('state_outputs', 'False'))
+    if state_outputs != 'True':  # the LSTM node below always produces three outputs
+        raise NotImplementedError('Currently RNN onnx export only supports state_outputs equals to True')
+
+    state_size = int(attrs.get('state_size'))
+    data = input_nodes[0]
+    param = input_nodes[1]
+    initial_h = input_nodes[2]
+    initial_c = input_nodes[3]
+
+    create_tensor([0], name+'_0', kwargs['initializer'])
+    create_tensor([1], name+'_1', kwargs['initializer'])
+    create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
+    create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
+    create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
+    create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer'])
+    create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+
+    nodes = [
+        make_node('Shape', [data], [name+'_data_shape']),  # assumes 3-D data: (seq_length, batch, input_size)
+        make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+        # get W: param[0 : 4*state_size*input_size], gate blocks reordered 0,3,1,2 (assumed MXNet i,f,g,o -> ONNX i,o,f,c)
+        make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']),
+        make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
+        make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']),
+        make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0),
+        make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
+        make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
+        # get R: next 4*state_size^2 values, same gate reorder, reshaped to [1, 4*state_size, state_size]
+        make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']),
+        make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
+        make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']),
+        make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0),
+        make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
+        # get B: next 8*state_size values as two 4-gate groups (presumably input then recurrent bias), each reordered
+        make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']),
+        make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
+        make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3',
+                                            name+'_B4', name+'_B5', name+'_B6', name+'_B7']),
+        make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2',
+                             name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0),
+        make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
+        # get seq_len: every batch entry uses the full seq_length (use_sequence_length is rejected above)
+        make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+        make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
+        # compute LSTM; outputs name+'1' / name+'2' match the '_state_output' / '_statecell_output' renaming in export_onnx.py
+        make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c],
+                  [name+'0_', name+'1', name+'2'], hidden_size=state_size),
+        make_node('Squeeze', [name+'0_'], [name], axes=[1]),  # drop axis 1 (num_directions in ONNX Y; size 1, unidirectional only)
+    ]
+    return nodes
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index af6af8b..898a8df 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -141,7 +141,11 @@ class MXNetGraph(object):
 
         out_names = list()
         for name in sym.list_outputs():
-            if name.endswith('_output'):
+            if name.endswith('_state_output'): # handel special cases for RNN operator
+                out_names.append(name[:-len('_state_output')]+'1')
+            elif name.endswith('_statecell_output'): # handel special cases for RNN operator
+                out_names.append(name[:-len('_statecell_output')]+'2')
+            elif name.endswith('_output'):
                 out_names.append(name[:-len('_output')])
             elif re.search('.*_output[0-9]$', name):
                 out_names.append(name[:-len('_output0')]+name[-1])
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index f73672b..8ebfbab 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -1204,3 +1204,17 @@ def test_onnx_export_sequence_reverse(tmp_path, dtype, params):
     seq_len = mx.nd.array(params[1])
     M1 = def_model('SequenceReverse', use_sequence_length=True)
     op_export_test('SequenceReverse1', M1, [x, seq_len], tmp_path)
+
+
+# onnx LSTM from opset 11 does not support float64
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('state_size', [128, 256, 512])
+def test_onnx_export_RNN(tmp_path, dtype, state_size):
+    # the current implementation fails assertion checks for large parm/state_size. 
+    M = def_model('RNN', mode='lstm', state_size=state_size, state_outputs=True,  num_layers=1, p=0)
+    x = mx.nd.random.normal(0, 10, (38, 1, 300), dtype=dtype)
+    batch_size = np.shape(x)[1]
+    input_size = np.shape(x)[2]
+    param = mx.nd.random.normal(0, 1, [4*state_size*input_size + 4*state_size*state_size + 8*state_size], dtype=dtype)
+    state = mx.nd.random.uniform(-1, 1, [1, batch_size, state_size], dtype=dtype)
+    cell = mx.nd.random.uniform(-1, 1, [1, batch_size, state_size], dtype=dtype)