Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/03/17 01:54:34 UTC

[incubator-mxnet] branch v1.x updated: add ernie onnx test (#20030)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 66c26a0  add ernie onnx test (#20030)
66c26a0 is described below

commit 66c26a0f7399c98854dffca47e23c7ca867d776e
Author: Zhaoqi Zhu <zh...@gmail.com>
AuthorDate: Tue Mar 16 18:51:07 2021 -0700

    add ernie onnx test (#20030)
---
 tests/python-pytest/onnx/test_onnxruntime.py | 61 +++++++++++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
index bf32259..fa45bb6 100644
--- a/tests/python-pytest/onnx/test_onnxruntime.py
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -868,3 +868,64 @@ def test_dynamic_shape_bert_inference_onnxruntime(tmp_path, model):
     finally:
         shutil.rmtree(tmp_path)
 
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['ernie_12_768_12'])
+def test_ernie_inference_onnxruntime(tmp_path, model_name):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'baidu_ernie_uncased'
+        ctx = mx.cpu(0)
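+        # load a pretrained ERNIE model, truncated to 3 layers (via hparam_allow_override) to keep the test fast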
+        model, vocab = nlp.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            dataset_name=dataset,
+            pretrained=True,
+            use_pooler=True,
+            use_decoder=False,
+            num_layers=3,
+            hparam_allow_override=True,
+            use_classifier=False)
+
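+        # hybridize so the model builds a symbolic graph that can be exported below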
+        model.hybridize(static_alloc=True)
+
+        batch = 5
+        seq_length = 16
+        # create synthetic test data
+        inputs = mx.nd.random.uniform(0, 17964, shape=(batch, seq_length), dtype='float32')
+        token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
+        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+        seq_encoding, cls_encoding = model(inputs, token_types, valid_length)
+
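+        # export to MXNet symbol and params files; the forward pass above builds the cached graph that export needs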
+        prefix = "%s/ernie" % tmp_path
+        model.export(prefix)
+        sym_file = "%s-symbol.json" % prefix
+        params_file = "%s-0000.params" % prefix
+        onnx_file = "%s.onnx" % prefix
+
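+        # convert the exported MXNet model to ONNX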
+        input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
+        input_types = [np.float32, np.float32, np.float32]
+        converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes,
+                                                            input_types, onnx_file)
+
+        # create onnxruntime session using the generated onnx file
+        ses_opt = onnxruntime.SessionOptions()
+        ses_opt.log_severity_level = 3
+        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+
+        onnx_inputs = [inputs, token_types, valid_length]
+        input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
+        pred_onx, cls_onx = session.run(None, input_dict)
+
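+        # ONNX Runtime outputs should match the MXNet outputs within default tolerance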
+        assert_almost_equal(seq_encoding, pred_onx)
+        assert_almost_equal(cls_encoding, cls_onx)
+
+    finally:
+        shutil.rmtree(tmp_path)
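
To try the new test locally, assuming gluonnlp, onnx, and onnxruntime are installed in the environment, it can be selected by name (this pytest invocation is illustrative; CI may pass additional flags):

    pytest tests/python-pytest/onnx/test_onnxruntime.py -k test_ernie_inference_onnxruntime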