Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2021/03/19 00:59:30 UTC

[GitHub] [incubator-mxnet] Zha0q1 opened a new pull request #20048: [wip][v1.x] Onnx Support for Transformer

Zha0q1 opened a new pull request #20048:
URL: https://github.com/apache/incubator-mxnet/pull/20048


   TODO:
   - add description
   - add operator unit tests


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] waytrue17 commented on pull request #20048: [v1.x] Onnx Support for Transformer

Posted by GitBox <gi...@apache.org>.
waytrue17 commented on pull request #20048:
URL: https://github.com/apache/incubator-mxnet/pull/20048#issuecomment-804319822


   LGTM, thanks





[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #20048: [v1.x] Onnx Support for Transformer

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on a change in pull request #20048:
URL: https://github.com/apache/incubator-mxnet/pull/20048#discussion_r598365615



##########
File path: tests/python-pytest/onnx/test_onnxruntime.py
##########
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):
 
     finally:
         shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
+def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'WMT2014'
+        ctx = mx.cpu(0)
+        model, _, _ = nlp.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset)
+
+        model.hybridize(static_alloc=False)
+
+        batch = 7
+        seq_length = 16
+        C_in = 512
+        C_out = 512
+        src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+        step_input = mx.nd.random.uniform(0, 36794, shape=(batch,), dtype='float32')
+        src_valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+        encoder_outputs, encoder_additional_outputs = model.encode(src,
+                                                                   valid_length=src_valid_length)
+
+        decoder_states = model.decoder.init_state_from_encoder(encoder_outputs, src_valid_length)
+
+        step_output, states, additional_outputs = model.decode_step(step_input, decoder_states)
+
+        # skip export of 'decoder' as it's used for training only
+        for component in ['encoder', 'one_step_ahead_decoder', 'src_embed', 'tgt_embed',
+                         'tgt_proj']:
+
+            prefix = "%s/%s" %(tmp_path, component)
+            component = getattr(model, component)
+            component.export(prefix)
+            sym_file = "%s-symbol.json" % prefix
+            params_file = "%s-0000.params" % prefix
+            onnx_file = "%s.onnx" % prefix
+
+        def export_to_onnx(prefix, input_shapes, input_types, **kwargs):
+            sym_file = "%s-symbol.json" % prefix
+            params_file = "%s-0000.params" % prefix
+            onnx_file = "%s.onnx" % prefix
+            return mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, input_types,
+                                                onnx_file, **kwargs)
+
+        def onnx_runtime_predict(onnx_file, onnx_inputs):
+            ses_opt = onnxruntime.SessionOptions()
+            ses_opt.log_severity_level = 3
+            session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+            input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy())
+                            for i in range(len(onnx_inputs)))
+            return session.run(None, input_dict)
+
+        def verify_encoder():
+            inputs = mx.nd.random.uniform(-1, 1, shape=(batch, seq_length, C_in), dtype='float32')
+            valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+            pred = model.encoder(inputs, valid_length=valid_length)
+
+            prefix = "%s/encoder" %tmp_path
+            input_shapes = [(batch, seq_length, C_in), (batch,)]
+            input_types = [np.float32, np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [inputs, valid_length]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred[0], pred_onx[0])
+
+        def verify_src_embed():
+            src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+            pred = model.src_embed(src)
+
+            prefix = "%s/src_embed" %tmp_path
+            input_shapes = [(batch, seq_length)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [src]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0])
+
+        def verify_tgt_embed():
+            tgt = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+            pred = model.tgt_embed(tgt)
+
+            prefix = "%s/tgt_embed" %tmp_path
+            input_shapes = [(batch, seq_length)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [tgt]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0])
+
+        def verify_tgt_proj():
+            decoder_out = mx.nd.random.uniform(0, 512, shape=(batch, seq_length, C_out),
+                                               dtype='float32')
+            pred = model.tgt_proj(decoder_out)
+
+            prefix = "%s/tgt_proj" %tmp_path
+            input_shapes = [(batch, seq_length, C_out)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [decoder_out]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)
+
+        def verify_one_step_ahead_decoder():
+            prefix = "%s/one_step_ahead_decoder" %tmp_path
+
+            # the input data order
+            perm = [2, 0, 1]

Review comment:
       I used a perm list so that the in_shapes and in_types lists can keep the same order as the inputs passed to the native model. The converted ONNX model just happens to take them in a different order. I think this is more consistent; what do you think?
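
To illustrate the perm approach described in this comment, here is a minimal sketch. The input names and shapes are hypothetical placeholders (the actual lists are not shown in this hunk), and it assumes that perm[i] is the native-order index of the input that the exported ONNX graph expects at position i.

```python
# Hypothetical stand-ins: these names and shapes are illustrative, not from the PR.
native_names = ["step_input", "mem_value", "mem_masks"]
native_shapes = [(7,), (7, 16, 512), (7, 16)]

# Assumption: perm[i] is the native-order index of the i-th ONNX input.
perm = [2, 0, 1]


def reorder(items, perm):
    """Return items rearranged so that position i holds items[perm[i]]."""
    return [items[i] for i in perm]


# The lists are declared once in native order and permuted only for export.
onnx_names = reorder(native_names, perm)
onnx_shapes = reorder(native_shapes, perm)
print(onnx_names)
print(onnx_shapes)
```

The trade-off raised in the review is readability: writing the lists directly in ONNX order avoids the extra indirection, while keeping them in native order and permuting makes the correspondence with the MXNet call explicit.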







[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #20048: [wip][v1.x] Onnx Support for Transformer

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #20048:
URL: https://github.com/apache/incubator-mxnet/pull/20048#issuecomment-802427953


   Hey @Zha0q1 , Thanks for submitting the PR 
   All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands: 
   - To trigger all jobs: @mxnet-bot run ci [all] 
   - To trigger specific jobs: @mxnet-bot run ci [job1, job2] 
   *** 
   **CI supported jobs**: [centos-gpu, unix-gpu, miscellaneous, edge, website, windows-cpu, unix-cpu, windows-gpu, sanity, centos-cpu, clang]
   *** 
   _Note_: 
    Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin. 
   All CI tests must pass before the PR can be merged. 
   





[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #20048: [v1.x] Onnx Support for Transformer

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on a change in pull request #20048:
URL: https://github.com/apache/incubator-mxnet/pull/20048#discussion_r598365388



##########
File path: tests/python-pytest/onnx/test_onnxruntime.py
##########
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):
 
     finally:
         shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
+def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'WMT2014'
+        ctx = mx.cpu(0)
+        model, _, _ = nlp.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset)
+
+        model.hybridize(static_alloc=False)
+
+        batch = 7
+        seq_length = 16
+        C_in = 512

Review comment:
       You can refer to this file https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py for C_in and C_out. They are defined by the pretrained model, so we need to set them to the same values as in the pretrained model.







[GitHub] [incubator-mxnet] waytrue17 commented on a change in pull request #20048: [v1.x] Onnx Support for Transformer

Posted by GitBox <gi...@apache.org>.
waytrue17 commented on a change in pull request #20048:
URL: https://github.com/apache/incubator-mxnet/pull/20048#discussion_r598031213



##########
File path: tests/python-pytest/onnx/test_onnxruntime.py
##########
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):
 
     finally:
         shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
+def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'WMT2014'
+        ctx = mx.cpu(0)
+        model, _, _ = nlp.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset)
+
+        model.hybridize(static_alloc=False)
+
+        batch = 7
+        seq_length = 16
+        C_in = 512

Review comment:
       What are C_in and C_out? Should we also test when `C_in != C_out`?

##########
File path: tests/python-pytest/onnx/test_onnxruntime.py
##########
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):
 
     finally:
         shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
+def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'WMT2014'
+        ctx = mx.cpu(0)
+        model, _, _ = nlp.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset)
+
+        model.hybridize(static_alloc=False)
+
+        batch = 7
+        seq_length = 16
+        C_in = 512
+        C_out = 512
+        src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')

Review comment:
       Curious: doesn't `src` need to be an int type?

##########
File path: tests/python-pytest/onnx/test_onnxruntime.py
##########
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):
 
     finally:
         shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
+def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'WMT2014'
+        ctx = mx.cpu(0)
+        model, _, _ = nlp.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset)
+
+        model.hybridize(static_alloc=False)
+
+        batch = 7
+        seq_length = 16
+        C_in = 512
+        C_out = 512
+        src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+        step_input = mx.nd.random.uniform(0, 36794, shape=(batch,), dtype='float32')
+        src_valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+        encoder_outputs, encoder_additional_outputs = model.encode(src,
+                                                                   valid_length=src_valid_length)
+
+        decoder_states = model.decoder.init_state_from_encoder(encoder_outputs, src_valid_length)
+
+        step_output, states, additional_outputs = model.decode_step(step_input, decoder_states)
+
+        # skip export of 'decoder' as it's used for training only
+        for component in ['encoder', 'one_step_ahead_decoder', 'src_embed', 'tgt_embed',
+                         'tgt_proj']:
+
+            prefix = "%s/%s" %(tmp_path, component)
+            component = getattr(model, component)
+            component.export(prefix)
+            sym_file = "%s-symbol.json" % prefix
+            params_file = "%s-0000.params" % prefix
+            onnx_file = "%s.onnx" % prefix
+
+        def export_to_onnx(prefix, input_shapes, input_types, **kwargs):
+            sym_file = "%s-symbol.json" % prefix
+            params_file = "%s-0000.params" % prefix
+            onnx_file = "%s.onnx" % prefix
+            return mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, input_types,
+                                                onnx_file, **kwargs)
+
+        def onnx_runtime_predict(onnx_file, onnx_inputs):
+            ses_opt = onnxruntime.SessionOptions()
+            ses_opt.log_severity_level = 3
+            session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+            input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy())
+                            for i in range(len(onnx_inputs)))
+            return session.run(None, input_dict)
+
+        def verify_encoder():
+            inputs = mx.nd.random.uniform(-1, 1, shape=(batch, seq_length, C_in), dtype='float32')
+            valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+            pred = model.encoder(inputs, valid_length=valid_length)
+
+            prefix = "%s/encoder" %tmp_path
+            input_shapes = [(batch, seq_length, C_in), (batch,)]
+            input_types = [np.float32, np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [inputs, valid_length]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred[0], pred_onx[0])
+
+        def verify_src_embed():
+            src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+            pred = model.src_embed(src)
+
+            prefix = "%s/src_embed" %tmp_path
+            input_shapes = [(batch, seq_length)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [src]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0])
+
+        def verify_tgt_embed():
+            tgt = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+            pred = model.tgt_embed(tgt)
+
+            prefix = "%s/tgt_embed" %tmp_path
+            input_shapes = [(batch, seq_length)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [tgt]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0])
+
+        def verify_tgt_proj():
+            decoder_out = mx.nd.random.uniform(0, 512, shape=(batch, seq_length, C_out),
+                                               dtype='float32')
+            pred = model.tgt_proj(decoder_out)
+
+            prefix = "%s/tgt_proj" %tmp_path
+            input_shapes = [(batch, seq_length, C_out)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [decoder_out]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)
+
+        def verify_one_step_ahead_decoder():
+            prefix = "%s/one_step_ahead_decoder" %tmp_path
+
+            # the input data order
+            perm = [2, 0, 1]

Review comment:
       Could we put the correct order when instantiating the list instead of using perm?







[GitHub] [incubator-mxnet] Zha0q1 merged pull request #20048: [v1.x] Onnx Support for Transformer

Posted by GitBox <gi...@apache.org>.
Zha0q1 merged pull request #20048:
URL: https://github.com/apache/incubator-mxnet/pull/20048


   





[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #20048: [wip][v1.x] Onnx Support for Transformer

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on a change in pull request #20048:
URL: https://github.com/apache/incubator-mxnet/pull/20048#discussion_r597887392



##########
File path: tests/python-pytest/onnx/test_onnxruntime.py
##########
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):
 
     finally:
         shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
+def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'WMT2014'
+        ctx = mx.cpu(0)
+        model, _, _ = nlp.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset)
+
+        model.hybridize(static_alloc=False)
+
+        batch = 7
+        seq_length = 16
+        C_in = 512
+        C_out = 512
+        src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+        step_input = mx.nd.random.uniform(0, 36794, shape=(batch,), dtype='float32')
+        src_valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+        encoder_outputs, encoder_additional_outputs = model.encode(src,
+                                                                   valid_length=src_valid_length)
+
+        decoder_states = model.decoder.init_state_from_encoder(encoder_outputs, src_valid_length)
+
+        step_output, states, additional_outputs = model.decode_step(step_input, decoder_states)
+
+        # skip export of 'decoder' as it's used for training only
+        for component in ['encoder', 'one_step_ahead_decoder', 'src_embed', 'tgt_embed',
+                         'tgt_proj']:
+
+            prefix = "%s/%s" %(tmp_path, component)
+            component = getattr(model, component)
+            component.export(prefix)
+            sym_file = "%s-symbol.json" % prefix
+            params_file = "%s-0000.params" % prefix
+            onnx_file = "%s.onnx" % prefix
+
+        def export_to_onnx(prefix, input_shapes, input_types, **kwargs):
+            sym_file = "%s-symbol.json" % prefix
+            params_file = "%s-0000.params" % prefix
+            onnx_file = "%s.onnx" % prefix
+            return mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, input_types,
+                                                onnx_file, **kwargs)
+
+        def onnx_runtime_predict(onnx_file, onnx_inputs):
+            ses_opt = onnxruntime.SessionOptions()
+            ses_opt.log_severity_level = 3
+            session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+            input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy())
+                            for i in range(len(onnx_inputs)))
+            return session.run(None, input_dict)
+
+        def verify_encoder():
+            inputs = mx.nd.random.uniform(-1, 1, shape=(batch, seq_length, C_in), dtype='float32')
+            valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+            pred = model.encoder(inputs, valid_length=valid_length)
+
+            prefix = "%s/encoder" %tmp_path
+            input_shapes = [(batch, seq_length, C_in), (batch,)]
+            input_types = [np.float32, np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [inputs, valid_length]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred[0], pred_onx[0])
+
+        def verify_src_embed():
+            src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+            pred = model.src_embed(src)
+
+            prefix = "%s/src_embed" %tmp_path
+            input_shapes = [(batch, seq_length)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [src]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0])
+
+        def verify_tgt_embed():
+            tgt = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+            pred = model.tgt_embed(tgt)
+
+            prefix = "%s/tgt_embed" %tmp_path
+            input_shapes = [(batch, seq_length)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [tgt]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0])
+
+        def verify_tgt_proj():
+            decoder_out = mx.nd.random.uniform(0, 512, shape=(batch, seq_length, C_out),
+                                               dtype='float32')
+            pred = model.tgt_proj(decoder_out)
+
+            prefix = "%s/tgt_proj" %tmp_path
+            input_shapes = [(batch, seq_length, C_out)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [decoder_out]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)
+
+        def verify_one_step_ahead_decoder():

Review comment:
       @sxjscience, would you help take a quick look at this function? Thanks!







[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #20048: [wip][v1.x] Onnx Support for Transformer

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on a change in pull request #20048:
URL: https://github.com/apache/incubator-mxnet/pull/20048#discussion_r597339143



##########
File path: python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
##########
@@ -2064,14 +2064,80 @@ def convert_broadcast_lesser(node, **kwargs):
     """Map MXNet's broadcast_lesser operator attributes to onnx's Less operator
     and return the created node.
     """
-    return create_basic_op_node('Less', node, kwargs)
+    from onnx.helper import make_node, make_tensor
+    name, input_nodes, _ = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    dtype = input_dtypes[0]
+    dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+
+    nodes = [
+        make_node('Less', [input_nodes[0], input_nodes[1]], [name+'_lt']),
+        make_node('Cast', [name+'_lt'], [name], to=dtype_t)
+    ]
+
+    return nodes
+
+
+@mx_op.register("broadcast_lesser_equal")
+def convert_broadcast_lesser(node, **kwargs):
+    """Map MXNet's broadcast_lesser operator attributes to onnx's Less operator
+    and return the created node.
+    """
+    from onnx.helper import make_node, make_tensor
+    name, input_nodes, _ = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    dtype = input_dtypes[0]
+    dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+
+    nodes = [
+        make_node('LessOrEqual', [input_nodes[0], input_nodes[1]], [name+'_lt']),
+        make_node('Cast', [name+'_lt'], [name], to=dtype_t)
+    ]
+
+    return nodes
+
+
+@mx_op.register("broadcast_greater_equal")
+def convert_broadcast_lesser(node, **kwargs):
+    """Map MXNet's broadcast_lesser operator attributes to onnx's Less operator

Review comment:
       TODO: fix doc
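
For readers following the hunk above: ONNX comparison operators such as Less produce boolean tensors, and the converter in the diff follows each comparison with a Cast back to the dtype of the first input. A plain-Python sketch of that two-step pattern, using a hypothetical helper name and lists in place of tensors:

```python
def lesser_then_cast(lhs, rhs, dtype=float):
    """Hypothetical stand-in for the Less -> Cast node pair in the diff."""
    # Step 1 (the 'Less' node): the elementwise comparison yields booleans.
    mask = [a < b for a, b in zip(lhs, rhs)]
    # Step 2 (the 'Cast' node): convert the booleans back to the input dtype.
    return [dtype(m) for m in mask]


result = lesser_then_cast([1.0, 5.0, 3.0], [2.0, 4.0, 3.0])
print(result)
```

The same shape applies to the LessOrEqual variant in the diff; only the comparison in step 1 changes.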







[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #20048: [v1.x] Onnx Support for Transformer

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on a change in pull request #20048:
URL: https://github.com/apache/incubator-mxnet/pull/20048#discussion_r598365876



##########
File path: tests/python-pytest/onnx/test_onnxruntime.py
##########
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):
 
     finally:
         shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
+def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'WMT2014'
+        ctx = mx.cpu(0)
+        model, _, _ = nlp.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset)
+
+        model.hybridize(static_alloc=False)
+
+        batch = 7
+        seq_length = 16
+        C_in = 512
+        C_out = 512
+        src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')

Review comment:
       No, it's float in the original MXNet model too. I don't think this matters, because the operator will apply ceiling/flooring.
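
A toy illustration of the point made in this comment, in plain Python. It assumes, purely for the sake of the sketch, that fractional indices are floored before the row lookup; whether the real operator floors, truncates, or rounds is not confirmed in this thread. The table and function names are hypothetical.

```python
import math

# Toy 3-row embedding table; the values are illustrative only.
table = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]


def embed(indices, table):
    """Look up one table row per index, flooring fractional indices first."""
    # Assumption: a float index selects the row at floor(index).
    return [table[math.floor(i)] for i in indices]


rows = embed([0.2, 1.9, 2.0], table)
print(rows)
```

Under that assumption, random float inputs drawn from [0, vocab_size) still select valid rows, which is why the float dtype would not change what the test compares.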



