You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/03/15 18:00:05 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] ONNX Support for pretrained StandardRNN models (#20017)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new c25da59  [v1.x] ONNX Support for pretrained StandardRNN models (#20017)
c25da59 is described below

commit c25da598eb78efcfaa6466935459c33ab67ff4e1
Author: Zhaoqi Zhu <zh...@gmail.com>
AuthorDate: Mon Mar 15 10:56:52 2021 -0700

    [v1.x] ONNX Support for pretrained StandardRNN models (#20017)
    
    * add one op (_rnn_param_concat converter);
    
    * add support for 2-layer LSTM
    
    * add standard RNN pretrained LSTM model test
    
    * fix seed
    
    * fix sanity
---
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 171 ++++++++++++++++-----
 tests/python-pytest/onnx/test_onnxruntime.py       |  53 +++++++
 tests/python-pytest/onnx/test_operators.py         |  30 ++--
 3 files changed, 205 insertions(+), 49 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 6e612d9..462564a 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -4154,8 +4154,6 @@ def convert_RNN(node, **kwargs):
         raise NotImplementedError('Currently RNN onnx export only supports bidirectional is False')
 
     num_layers = int(attrs.get('num_layers', '1'))
-    if num_layers != 1:
-        raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1')
 
     p = float(attrs.get('p', '0'))
     if p != 0:
@@ -4179,44 +4177,139 @@ def convert_RNN(node, **kwargs):
     initial_h = input_nodes[2]
     initial_c = input_nodes[3]
 
-    create_tensor([0], name+'_0', kwargs['initializer'])
-    create_tensor([1], name+'_1', kwargs['initializer'])
-    create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
-    create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
-    create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
-    create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer'])
-    create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+    nodes = []
+
+    if num_layers == 2:
+        create_tensor([0], name+'_0', kwargs['initializer'])
+        create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
+        create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
+
+        create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
+        create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+
+        create_tensor([4*4*state_size*state_size], name+'_WR_offset', kwargs['initializer'])
+
+        nodes += [
+            make_node('Shape', [data], [name+'_data_shape']),
+            make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+
+            # Layer 0
+            # get W
+            make_node('Slice', [param, name+'_0', name+'_4*state_size^2'], [name+'_W0_1d']),
+            make_node('Split', [name+'_W0_1d'], [name+'_W00', name+'_W01', name+'_W02', name+'_W03']),
+            make_node('Concat', [name+'_W00', name+'_W03', name+'_W01', name+'_W02'], [name+'_W0_'], axis=0),
+            make_node('Reshape', [name+'_W0_', name+'_WR_shape'], [name+'_W0']),
+            # get R
+            make_node('Add', [name+'_4*state_size^2', name+'_4*state_size^2'], [name+'_R0_offset']),
+            make_node('Slice', [param, name+'_4*state_size^2', name+'_R0_offset'], [name+'_R0_1d']),
+            make_node('Split', [name+'_R0_1d'], [name+'_R00', name+'_R01', name+'_R02', name+'_R03']),
+            make_node('Concat', [name+'_R00', name+'_R03', name+'_R01', name+'_R02'], [name+'_R0_'], axis=0),
+            make_node('Reshape', [name+'_R0_', name+'_WR_shape'], [name+'_R0']),
+            # get B
+            make_node('Add', [name+'_WR_offset', name+'_8*state_size'], [name+'_B0_offset']),
+            make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']),
+            make_node('Split', [name+'_B0_1d'], [name+'_B00', name+'_B01', name+'_B02', name+'_B03',
+                                                 name+'_B04', name+'_B05', name+'_B06', name+'_B07']),
+            make_node('Concat', [name+'_B00', name+'_B03', name+'_B01', name+'_B02',
+                                 name+'_B04', name+'_B07', name+'_B05', name+'_B06'], [name+'_B0_'], axis=0),
+            make_node('Reshape', [name+'_B0_', name+'_B_shape'], [name+'_B0']),
+            # get initial states
+            make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0),
+            make_node('Split', [initial_c], [name+'_initial_c0', name+'_initial_c1'], axis=0),
+            # get seq_len
+            make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+            make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
+            # Layer 0 LSTM
+            make_node('LSTM', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len',
+                               name+'_initial_h0', name+'_initial_c0'],
+                      [name+'_lstm0_out_', name+'_lstm0_h', name+'_lstm0_c'], hidden_size=state_size),
+            make_node('Squeeze', [name+'_lstm0_out_'], [name+'_lstm0_out'], axes=[1]),
+
+            # Layer 1
+            # get W
+            make_node('Add', [name+'_R0_offset', name+'_4*state_size^2'], [name+'_W1_offset']),
+            make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']),
+            make_node('Split', [name+'_W1_1d'], [name+'_W10', name+'_W11', name+'_W12', name+'_W13']),
+            make_node('Concat', [name+'_W10', name+'_W13', name+'_W11', name+'_W12'], [name+'_W1_'], axis=0),
+            make_node('Reshape', [name+'_W1_', name+'_WR_shape'], [name+'_W1']),
+            # get R
+            make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']),
+            make_node('Split', [name+'_R1_1d'], [name+'_R10', name+'_R11', name+'_R12', name+'_R13']),
+            make_node('Concat', [name+'_R10', name+'_R13', name+'_R11', name+'_R12'], [name+'_R1_'], axis=0),
+            make_node('Reshape', [name+'_R1_', name+'_WR_shape'], [name+'_R1']),
+            # get B
+            make_node('Add', [name+'_B0_offset', name+'_8*state_size'], [name+'_B1_offset']),
+            make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']),
+            make_node('Split', [name+'_B1_1d'], [name+'_B10', name+'_B11', name+'_B12', name+'_B13',
+                                                 name+'_B14', name+'_B15', name+'_B16', name+'_B17']),
+            make_node('Concat', [name+'_B10', name+'_B13', name+'_B11', name+'_B12',
+                                 name+'_B14', name+'_B17', name+'_B15', name+'_B16'], [name+'_B1_'], axis=0),
+            make_node('Reshape', [name+'_B1_', name+'_B_shape'], [name+'_B1']),
+            # Layer 1 LSTM
+            make_node('LSTM', [name+'_lstm0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len',
+                               name+'_initial_h1', name+'_initial_c1'],
+                      [name+'_lstm1_out_', name+'_lstm1_h', name+'_lstm1_c'], hidden_size=state_size),
+            make_node('Squeeze', [name+'_lstm1_out_'], [name], axes=[1]),
+            make_node('Concat', [name+'_lstm0_h', name+'_lstm1_h'], [name+'1'], axis=0),
+            make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0),
+        ]
+    elif num_layers == 1:
+        create_tensor([0], name+'_0', kwargs['initializer'])
+        create_tensor([1], name+'_1', kwargs['initializer'])
+        create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
+        create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
+        create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
+        create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer'])
+        create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+
+        nodes += [
+            make_node('Shape', [data], [name+'_data_shape']),
+            make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+            # get W
+            make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']),
+            make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
+            make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']),
+            make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0),
+            make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
+            make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
+            # get R
+            make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']),
+            make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
+            make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']),
+            make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0),
+            make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
+            # get B
+            make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']),
+            make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
+            make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3',
+                                                name+'_B4', name+'_B5', name+'_B6', name+'_B7']),
+            make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2',
+                                 name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0),
+            make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
+            # get seq_len
+            make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+            make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
+            # compute LSTM
+            make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c],
+                      [name+'0_', name+'1', name+'2'], hidden_size=state_size),
+            make_node('Squeeze', [name+'0_'], [name], axes=[1]),
+        ]
+    else:
+        raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
+
+    return nodes
+
+@mx_op.register('_rnn_param_concat')
+def convert_rnn_param_concat(node, **kwargs):
+    """Map MXNet’s _rnn_param_concat operator
+    """
+    from onnx.helper import make_node
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    axis = int(attrs.get('dim', 1))
 
     nodes = [
-        make_node('Shape', [data], [name+'_data_shape']),
-        make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
-        # get W
-        make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']),
-        make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
-        make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']),
-        make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0),
-        make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
-        make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
-        # get R
-        make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']),
-        make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
-        make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']),
-        make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0),
-        make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
-        # get B
-        make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']),
-        make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
-        make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3',
-                                            name+'_B4', name+'_B5', name+'_B6', name+'_B7']),
-        make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2',
-                             name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0),
-        make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
-        # get seq_len
-        make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
-        make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
-        # compute LSTM
-        make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c],
-                  [name+'0_', name+'1', name+'2'], hidden_size=state_size),
-        make_node('Squeeze', [name+'0_'], [name], axes=[1]),
+        make_node('Concat', input_nodes, [name], axis=axis)
     ]
+
     return nodes
diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
index 2fba092..bf32259 100644
--- a/tests/python-pytest/onnx/test_onnxruntime.py
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -721,6 +721,58 @@ def test_distilbert_inference_onnxruntime(tmp_path, model_name):
 
 
 @with_seed()
+@pytest.mark.parametrize('model_name', [('standard_lstm_lm_200', 200), ('standard_lstm_lm_650', 650),
+                                        ('standard_lstm_lm_1500', 1500)])
+@pytest.mark.parametrize('seq_length', [16, 32])
+def test_standard_rnn_lstm_pretrained_inference_onnxruntime(tmp_path, model_name, seq_length):
+    try:
+        import gluonnlp as nlp
+        ctx = mx.cpu()
+        dataset= 'wikitext-2'
+        model, _ = nlp.model.get_model(
+            name=model_name[0],
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset,
+            dropout=0)
+        model.hybridize()
+
+        batch = 2
+        num_hidden = model_name[1]
+        num_layers = 2
+        inputs = mx.nd.random.randint(0, 33278, shape=(seq_length, batch),
+                                      ctx=ctx).astype('float32')
+        begin_state = model.begin_state(func=mx.nd.random.uniform, low=0, high=1,
+                                        batch_size=batch, dtype='float32', ctx=ctx)
+        out, out_state= model(inputs, begin_state)
+
+        prefix = "%s/standard_rnn_lstm" % tmp_path
+        model.export(prefix)
+        sym_file = "%s-symbol.json" % prefix
+        params_file = "%s-0000.params" % prefix
+        onnx_file = "%s.onnx" % prefix
+
+        input_shapes = [(seq_length, batch), np.shape(begin_state[0]), np.shape(begin_state[1])]
+        converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes,
+                                                            [np.float32, np.float32, np.float32],
+                                                            onnx_file, verbose=True)
+        sess_options = onnxruntime.SessionOptions()
+        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        sess = onnxruntime.InferenceSession(onnx_file, sess_options)
+
+        in_tensors = [inputs, begin_state[0], begin_state[1]]
+        input_dict = dict((sess.get_inputs()[i].name, in_tensors[i].asnumpy()) for i in range(len(in_tensors)))
+        pred = sess.run(None, input_dict)
+
+        assert_almost_equal(out, pred[2])
+        assert_almost_equal(out_state[0], pred[0])
+        assert_almost_equal(out_state[1], pred[1])
+
+    finally:
+        shutil.rmtree(tmp_path)
+
+
+@with_seed()
 @pytest.mark.parametrize('model_name', ['mobilenet1.0', 'inceptionv3', 'darknet53', 'resnest14'])
 def test_dynamic_shape_cv_inference_onnxruntime(tmp_path, model_name):
     tmp_path = str(tmp_path)
@@ -815,3 +867,4 @@ def test_dynamic_shape_bert_inference_onnxruntime(tmp_path, model):
 
     finally:
         shutil.rmtree(tmp_path)
+
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 0dc283b..36b687a 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -1214,13 +1214,23 @@ def test_onnx_export_sequence_reverse(tmp_path, dtype, params):
 
 # onnx LSTM from opset 11 does not support float64
 @pytest.mark.parametrize('dtype', ['float32'])
-@pytest.mark.parametrize('state_size', [128, 256, 512])
-def test_onnx_export_RNN(tmp_path, dtype, state_size):
-    # the current implementation fails assertion checks for large parm/state_size. 
-    M = def_model('RNN', mode='lstm', state_size=state_size, state_outputs=True,  num_layers=1, p=0)
-    x = mx.nd.random.normal(0, 10, (38, 1, 300), dtype=dtype)
-    batch_size = np.shape(x)[1]
-    input_size = np.shape(x)[2]
-    param = mx.nd.random.normal(0, 1, [4*state_size*input_size + 4*state_size*state_size + 8*state_size], dtype=dtype)
-    state = mx.nd.random.uniform(-1, 1, [1, batch_size, state_size], dtype=dtype)
-    cell = mx.nd.random.uniform(-1, 1, [1, batch_size, state_size], dtype=dtype)
+@pytest.mark.parametrize('state_size', [32, 40])
+@pytest.mark.parametrize('input_size', [32, 40, 64])
+@pytest.mark.parametrize('num_layers', [1, 2])
+@pytest.mark.parametrize('batch_size', [1, 3, 5])
+@pytest.mark.parametrize('seq_length', [16, 32])
+def test_onnx_export_RNN(tmp_path, dtype, state_size, input_size, num_layers, batch_size, seq_length):
+    # TODO: The current implementation fails assertion checks for large parm/state_size. 
+    
+    # for num_layers >= 2, input_size must equal to state_size
+    if num_layers >= 2 and input_size != state_size:
+        return
+    
+    M = def_model('RNN', mode='lstm', state_size=state_size, state_outputs=True,  num_layers=num_layers, p=0)
+    x = mx.nd.random.normal(0, 10, (seq_length, batch_size, input_size), dtype=dtype)
+    param = mx.nd.random.normal(0, 1, [num_layers*4*state_size*input_size +
+                                       num_layers*4*state_size*state_size +
+                                       num_layers*8*state_size], dtype=dtype)
+    state = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
+    cell = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
+    op_export_test('rnn', M, [x, param, state, cell], tmp_path)