You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/03/17 01:54:34 UTC
[incubator-mxnet] branch v1.x updated: add ernie onnx test (#20030)
This is an automated email from the ASF dual-hosted git repository.
zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 66c26a0 add ernie onnx test (#20030)
66c26a0 is described below
commit 66c26a0f7399c98854dffca47e23c7ca867d776e
Author: Zhaoqi Zhu <zh...@gmail.com>
AuthorDate: Tue Mar 16 18:51:07 2021 -0700
add ernie onnx test (#20030)
---
tests/python-pytest/onnx/test_onnxruntime.py | 58 ++++++++++++++++++++++++++++
1 file changed, 58 insertions(+)
diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
index bf32259..fa45bb6 100644
--- a/tests/python-pytest/onnx/test_onnxruntime.py
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -868,3 +868,61 @@ def test_dynamic_shape_bert_inference_onnxruntime(tmp_path, model):
finally:
shutil.rmtree(tmp_path)
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['ernie_12_768_12'])
+def test_ernie_inference_onnxruntime(tmp_path, model_name):
+ tmp_path = str(tmp_path)
+ try:
+ import gluonnlp as nlp
+ dataset = 'baidu_ernie_uncased'
+ ctx = mx.cpu(0)
+ model, vocab = nlp.model.get_model(
+ name=model_name,
+ ctx=ctx,
+ dataset_name=dataset,
+ pretrained=True,
+ use_pooler=True,
+ use_decoder=False,
+ num_layers = 3,
+ hparam_allow_override = True,
+ use_classifier=False)
+
+ model.hybridize(static_alloc=True)
+
+ batch = 5
+ seq_length = 16
+ # create synthetic test data
+ inputs = mx.nd.random.uniform(0, 17964, shape=(batch, seq_length), dtype='float32')
+ token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
+ valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+ seq_encoding, cls_encoding = model(inputs, token_types, valid_length)
+
+ prefix = "%s/ernie" % tmp_path
+ model.export(prefix)
+ sym_file = "%s-symbol.json" % prefix
+ params_file = "%s-0000.params" % prefix
+ onnx_file = "%s.onnx" % prefix
+
+ input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
+ input_types = [np.float32, np.float32, np.float32]
+ converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes,
+ input_types, onnx_file)
+
+ # create onnxruntime session using the generated onnx file
+ ses_opt = onnxruntime.SessionOptions()
+ ses_opt.log_severity_level = 3
+ session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+
+ seq_encoding, cls_encoding = model(inputs, token_types, valid_length)
+
+ onnx_inputs = [inputs, token_types, valid_length]
+ input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
+ pred_onx, cls_onx = session.run(None, input_dict)
+
+ assert_almost_equal(seq_encoding, pred_onx)
+ assert_almost_equal(cls_encoding, cls_onx)
+
+ finally:
+ shutil.rmtree(tmp_path)