You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/03/23 01:36:03 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] Onnx Support for Transformer (#20048)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 833cb89  [v1.x] Onnx Support for Transformer (#20048)
833cb89 is described below

commit 833cb89e3e7a5262151a3b512d18a82d6de917be
Author: Zhaoqi Zhu <zh...@gmail.com>
AuthorDate: Mon Mar 22 18:34:10 2021 -0700

    [v1.x] Onnx Support for Transformer (#20048)
    
    * add ops
    
    * add transformer test
    
    * fix test
    
    * add unit test
    
    * fix sanity
    
    * add to ci
---
 ci/docker/runtime_functions.sh                     |   1 +
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 105 ++++++++++++-
 tests/python-pytest/onnx/test_onnxruntime.py       | 169 ++++++++++++++++++++-
 tests/python-pytest/onnx/test_operators.py         |  26 ++++
 4 files changed, 295 insertions(+), 6 deletions(-)

diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index a3484fc..9bfc841 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -1274,6 +1274,7 @@ integrationtest_ubuntu_cpu_onnx() {
     pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_action_recognition_model_inference_onnxruntime[inceptionv3_kinetics400]
     pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_dynamic_shape_bert_inference_onnxruntime
     pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_dynamic_shape_cv_inference_onnxruntime
+    pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_transformer_pretrained_inference_onnxruntime
 }
 
 integrationtest_ubuntu_gpu_python() {
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 462564a..0903376 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -2064,14 +2064,78 @@ def convert_broadcast_lesser(node, **kwargs):
     """Map MXNet's broadcast_lesser operator attributes to onnx's Less operator
     and return the created node.
     """
-    return create_basic_op_node('Less', node, kwargs)
+    from onnx.helper import make_node
+    name, input_nodes, _ = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    dtype = input_dtypes[0]
+    dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+
+    nodes = [
+        make_node('Less', [input_nodes[0], input_nodes[1]], [name+'_lt']),
+        make_node('Cast', [name+'_lt'], [name], to=dtype_t)
+    ]
+
+    return nodes
+
+
+@mx_op.register("broadcast_lesser_equal")
+def convert_broadcast_lesser_equal(node, **kwargs):
+    """Map MXNet's broadcast_lesser_equal operator to ONNX LessOrEqual (opset >= 12)
+    followed by a Cast of the bool result to the dtype of the first input."""
+    from onnx.helper import make_node
+    name, input_nodes, _ = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    # ONNX comparison ops emit bool tensors; cast back to the operand dtype
+    out_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[input_dtypes[0]]
+    lhs, rhs = input_nodes[0], input_nodes[1]
+    cmp_out = name + '_lt'
+
+    return [
+        make_node('LessOrEqual', [lhs, rhs], [cmp_out]),
+        make_node('Cast', [cmp_out], [name], to=out_type),
+    ]
+
+
+@mx_op.register("broadcast_greater_equal")
+def convert_broadcast_greater_equal(node, **kwargs):
+    """Map MXNet's broadcast_greater_equal operator to ONNX GreaterOrEqual (opset >= 12)
+    followed by a Cast of the bool result to the dtype of the first input."""
+    from onnx.helper import make_node
+    name, input_nodes, _ = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    # ONNX comparison ops emit bool tensors; cast back to the operand dtype
+    out_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[input_dtypes[0]]
+    lhs, rhs = input_nodes[0], input_nodes[1]
+    cmp_out = name + '_gt'
+
+    return [
+        make_node('GreaterOrEqual', [lhs, rhs], [cmp_out]),
+        make_node('Cast', [cmp_out], [name], to=out_type),
+    ]
+
 
 @mx_op.register("broadcast_greater")
 def convert_broadcast_greater(node, **kwargs):
     """Map MXNet's broadcast_greater operator attributes to onnx's Greater operator
     and return the created node.
     """
-    return create_basic_op_node('Greater', node, kwargs)
+    from onnx.helper import make_node
+    name, input_nodes, _ = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    # ONNX Greater yields a bool tensor; cast it to the first input's dtype
+    out_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[input_dtypes[0]]
+    lhs, rhs = input_nodes[0], input_nodes[1]
+    cmp_out = name + '_gt'
+
+    return [
+        make_node('Greater', [lhs, rhs], [cmp_out]),
+        make_node('Cast', [cmp_out], [name], to=out_type),
+    ]
+
 
 @mx_op.register("broadcast_equal")
 def convert_broadcast_equal(node, **kwargs):
@@ -2498,7 +2562,6 @@ def convert_layer_norm(node, **kwargs):
     axes = int(attrs.get('axis', -1))
     eps = attrs.get('eps', 9.99999975e-06)
 
-
     create_tensor([axes], name+"_axes", kwargs["initializer"])
     create_tensor([axes+1], name+"_axes+1", kwargs["initializer"])
     create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
@@ -2519,7 +2582,11 @@ def convert_layer_norm(node, **kwargs):
     if axes == -1:
         nodes += [
             make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]),
-            make_node("Add", [name+"_mul0_out", input_nodes[2]], [name], name=name)
+            # make_node("Add", [name+"_mul0_out", input_nodes[2]], [name])
+            # the Add operator triggers a weird NaN issue in onnxruntime
+            # a workaround is to use Neg + Sub
+            make_node('Neg', [input_nodes[2]], [name+'_neg']),
+            make_node("Sub", [name+"_mul0_out", name+'_neg'], [name])
         ]
     else:
         nodes += [
@@ -4301,7 +4368,7 @@ def convert_RNN(node, **kwargs):
 
 @mx_op.register('_rnn_param_concat')
 def convert_rnn_param_concat(node, **kwargs):
-    """Map MXNet’s _rnn_param_concat operator
+    """Map MXNet's _rnn_param_concat operator
     """
     from onnx.helper import make_node
     name, input_nodes, attrs = get_inputs(node, kwargs)
@@ -4313,3 +4380,31 @@ def convert_rnn_param_concat(node, **kwargs):
     ]
 
     return nodes
+
+
+@mx_op.register('_contrib_div_sqrt_dim')
+def convert_contrib_div_sqrt_dim(node, **kwargs):
+    """Map MXNet's _contrib_div_sqrt_dim operator: multiply the input by 1/sqrt(C),
+    where C is the size of the input's last axis, read at runtime via Shape."""
+    from onnx.helper import make_node
+    name, input_nodes, _ = get_inputs(node, kwargs)
+    dtype = get_input_dtypes(node, kwargs)[0]
+    dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+    # index constants for Sub/Slice, plus a 1 in the data dtype for the reciprocal
+    create_tensor([0], name+'_0', kwargs['initializer'])
+    create_tensor([1], name+'_1', kwargs['initializer'])
+    create_tensor([1], name+'_1_f', kwargs['initializer'], dtype=dtype)
+    # rank = Shape(Shape(x)); C = Shape(x)[rank-1:rank] sliced along axis 0
+    shape, rank, rank_m1 = name+'_shape', name+'_dim', name+'_dim_m1'
+    c_int, c_val, c_sqrt = name+'_c_', name+'_c', name+'_c_sqrt'
+    inv_sqrt = name+'_1_over_c_sqrt'
+    return [
+        make_node('Shape', [input_nodes[0]], [shape]),
+        make_node('Shape', [shape], [rank]),
+        make_node('Sub', [rank, name+'_1'], [rank_m1]),
+        make_node('Slice', [shape, rank_m1, rank, name+'_0'], [c_int]),
+        make_node('Cast', [c_int], [c_val], to=dtype_t),
+        make_node('Sqrt', [c_val], [c_sqrt]),
+        make_node('Div', [name+'_1_f', c_sqrt], [inv_sqrt]),
+        make_node('Mul', [input_nodes[0], inv_sqrt], [name]),
+    ]
diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
index e2a8329..ca4c0fe 100644
--- a/tests/python-pytest/onnx/test_onnxruntime.py
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -723,7 +723,7 @@ def test_distilbert_inference_onnxruntime(tmp_path, model_name):
 @with_seed()
 @pytest.mark.parametrize('model_name', [('standard_lstm_lm_200', 200), ('standard_lstm_lm_650', 650),
                                         ('standard_lstm_lm_1500', 1500)])
-@pytest.mark.parametrize('seq_length', [16, 32])
+@pytest.mark.parametrize('seq_length', [64, 128])
 def test_standard_rnn_lstm_pretrained_inference_onnxruntime(tmp_path, model_name, seq_length):
     try:
         import gluonnlp as nlp
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):
 
     finally:
         shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
+def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
+    """Export the inference components of a pretrained gluonnlp Transformer to ONNX
+    and compare onnxruntime predictions against MXNet, including step-wise decoding."""
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'WMT2014'
+        ctx = mx.cpu(0)
+        model, _, _ = nlp.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset)
+
+        model.hybridize(static_alloc=False)
+
+        batch = 7
+        seq_length = 16
+        C_in = 512
+        C_out = 512
+        src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+        step_input = mx.nd.random.uniform(0, 36794, shape=(batch,), dtype='float32')
+        src_valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+        encoder_outputs, _ = model.encode(src, valid_length=src_valid_length)
+        decoder_states = model.decoder.init_state_from_encoder(encoder_outputs, src_valid_length)
+        # run one MXNet decode step as a smoke test; its outputs are not compared
+        model.decode_step(step_input, decoder_states)
+
+        # skip export of 'decoder' as it's used for training only
+        for component in ['encoder', 'one_step_ahead_decoder', 'src_embed', 'tgt_embed',
+                          'tgt_proj']:
+            # HybridBlock.export writes <prefix>-symbol.json and <prefix>-0000.params,
+            # which export_to_onnx below reads back
+            prefix = "%s/%s" % (tmp_path, component)
+            getattr(model, component).export(prefix)
+
+        def export_to_onnx(prefix, input_shapes, input_types, **kwargs):
+            """Convert <prefix>-symbol.json/-0000.params to <prefix>.onnx; return its path."""
+            sym_file = "%s-symbol.json" % prefix
+            params_file = "%s-0000.params" % prefix
+            onnx_file = "%s.onnx" % prefix
+            return mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, input_types,
+                                                onnx_file, **kwargs)
+
+        def onnx_runtime_predict(onnx_file, onnx_inputs):
+            """Run onnx_file in onnxruntime, feeding onnx_inputs to its inputs in order."""
+            ses_opt = onnxruntime.SessionOptions()
+            ses_opt.log_severity_level = 3
+            session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+            input_names = [inp.name for inp in session.get_inputs()]
+            input_dict = {input_names[i]: arr.asnumpy() for i, arr in enumerate(onnx_inputs)}
+            return session.run(None, input_dict)
+
+        def verify_encoder():
+            inputs = mx.nd.random.uniform(-1, 1, shape=(batch, seq_length, C_in), dtype='float32')
+            valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+            pred = model.encoder(inputs, valid_length=valid_length)
+
+            prefix = "%s/encoder" % tmp_path
+            input_shapes = [(batch, seq_length, C_in), (batch,)]
+            input_types = [np.float32, np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [inputs, valid_length]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred[0], pred_onx[0])
+
+        def verify_src_embed():
+            src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+            pred = model.src_embed(src)
+
+            prefix = "%s/src_embed" % tmp_path
+            input_shapes = [(batch, seq_length)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [src]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0])
+
+        def verify_tgt_embed():
+            tgt = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+            pred = model.tgt_embed(tgt)
+
+            prefix = "%s/tgt_embed" % tmp_path
+            input_shapes = [(batch, seq_length)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [tgt]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0])
+
+        def verify_tgt_proj():
+            decoder_out = mx.nd.random.uniform(0, 512, shape=(batch, seq_length, C_out),
+                                               dtype='float32')
+            pred = model.tgt_proj(decoder_out)
+
+            prefix = "%s/tgt_proj" % tmp_path
+            input_shapes = [(batch, seq_length, C_out)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [decoder_out]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)
+
+        def verify_one_step_ahead_decoder():
+            """Check the dynamic-shape decoder export against MXNet over 21 decode steps."""
+            prefix = "%s/one_step_ahead_decoder" % tmp_path
+
+            # perm maps (step data, state 0, state 1) to the ONNX graph's input order
+            perm = [2, 0, 1]
+            input_shapes = [(batch, seq_length, C_in), (batch, seq_length, C_out),
+                            (batch, seq_length)]
+            input_shapes = [input_shapes[i] for i in perm]
+            dynamic_input_shapes = [(batch, 'seq_length', C_in), (batch, 'seq_length', C_out),
+                                    (batch, 'seq_length')]
+            dynamic_input_shapes = [dynamic_input_shapes[i] for i in perm]
+            input_types = [np.float32, np.float32, np.float32]
+            # do a dynamic export
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types, dynamic=True,
+                                       dynamic_input_shapes=dynamic_input_shapes)
+
+            # step 0
+            step_input = mx.nd.random.uniform(-1, 1, shape=(batch, C_in), dtype='float32')
+            # mxnet
+            pred, step_states, _ = model.one_step_ahead_decoder(step_input, decoder_states)
+            # onnx
+            # note that we need to expand the sequence axis just like in here:
+            # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L831
+            input_onx = mx.nd.expand_dims(step_input, axis=1)
+            onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]]
+            onnx_inputs = [onnx_inputs[i] for i in perm]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0])
+
+            # steps 1..20, sampling from a range that widens with the step index
+            for i in range(1, 21):
+                step_input = mx.nd.random.uniform(-10*i, 10*i, shape=(batch, C_in), dtype='float32')
+                # mxnet
+                pred, step_states, _ = model.one_step_ahead_decoder(step_input, step_states)
+                # onnx
+                # note that we need to concat the step_input with the previous inputs
+                # just like in here:
+                # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L828
+                input_onx = mx.nd.concat(input_onx, mx.nd.expand_dims(step_input, axis=1), dim=1)
+                onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]]
+                onnx_inputs = [onnx_inputs[i] for i in perm]
+                pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+                assert_almost_equal(pred, pred_onx[0])
+
+        verify_encoder()
+        verify_src_embed()
+        verify_tgt_embed()
+        verify_tgt_proj()
+        verify_one_step_ahead_decoder()
+
+    finally:
+        shutil.rmtree(tmp_path)
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 36b687a..220f259 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -1234,3 +1234,29 @@ def test_onnx_export_RNN(tmp_path, dtype, state_size, input_size, num_layers, ba
     state = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
     cell = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
     op_export_test('rnn', M, [x, param, state, cell], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])
+@pytest.mark.parametrize('shapes', [((3, 3, 3), (1, 3)), ((4, 5, 6, 7), (6, 7))])
+def test_onnx_export_broadcast_lesser_equal(tmp_path, dtype, shapes):
+    # integer-valued operands in [0, 5), so equal elements are likely
+    lhs, rhs = [mx.nd.random.uniform(0, 5, s).astype('int32').astype(dtype) for s in shapes]
+    model = def_model('broadcast_lesser_equal')
+    op_export_test('broadcast_lesser_equal', model, [lhs, rhs], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])
+@pytest.mark.parametrize('shapes', [((3, 3, 3), (1, 3)), ((4, 5, 6, 7), (6, 7))])
+def test_onnx_export_broadcast_greater_equal(tmp_path, dtype, shapes):
+    # integer-valued operands in [0, 5), so equal elements are likely
+    lhs, rhs = [mx.nd.random.uniform(0, 5, s).astype('int32').astype(dtype) for s in shapes]
+    model = def_model('broadcast_greater_equal')
+    op_export_test('broadcast_greater_equal', model, [lhs, rhs], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
+@pytest.mark.parametrize('shape', [(3, 4, 5), (6, 7), (8,)])
+def test_onnx_export_contrib_div_sqrt_dim(tmp_path, dtype, shape):
+    data = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
+    model = def_model('contrib.div_sqrt_dim')
+    op_export_test('contrib_div_sqrt_dim', model, [data], tmp_path)