You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/03/12 04:01:21 UTC

[incubator-mxnet] branch v1.x updated: Onnx Dynamic Shapes Support (#20001)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 601e3b9  Onnx Dynamic Shapes Support (#20001)
601e3b9 is described below

commit 601e3b9576e718156e82c5f348d8a3409a53f0ad
Author: Zhaoqi Zhu <zh...@gmail.com>
AuthorDate: Thu Mar 11 19:58:50 2021 -0800

    Onnx Dynamic Shapes Support (#20001)
    
    * basic dynamic shape support
    
    * add shape inference
    
    * fix pylint and doc string
    
    * add cv dynamic shape test
    
    * redesign interface + fix slice operator + add bert test
    
    * add dynamic tests to ci
    
    * fix slice
    
    * revert adding test
    
    * fix lint
    
    * add tests to ci
    
    * fix api backward compatibility
---
 ci/docker/runtime_functions.sh                     |   2 +
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py |   7 +-
 python/mxnet/contrib/onnx/mx2onnx/export_model.py  |  59 +++++++---
 python/mxnet/contrib/onnx/mx2onnx/export_onnx.py   | 127 ++++++++++++++-------
 tests/python-pytest/onnx/test_onnxruntime.py       | 111 +++++++++++++++++-
 5 files changed, 238 insertions(+), 68 deletions(-)

diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index ceed020..a3484fc 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -1272,6 +1272,8 @@ integrationtest_ubuntu_cpu_onnx() {
     pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_img_segmentation_model_inference_onnxruntime[deeplab_resnet50_citys]
     pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_pose_estimation_model_inference_onnxruntime[mobile_pose_mobilenet1.0]
     pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_action_recognition_model_inference_onnxruntime[inceptionv3_kinetics400]
+    pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_dynamic_shape_bert_inference_onnxruntime
+    pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_dynamic_shape_cv_inference_onnxruntime
 }
 
 integrationtest_ubuntu_gpu_python() {
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index c3aa9bc..f33e03e 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -2834,15 +2834,14 @@ def convert_slice(node, **kwargs):
         if ends[i] is None:
             ends[i] = 2**63-1
 
-    create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
-    create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)
-    create_const_scalar_node(name+'_len_s', np.int64(len(starts)), kwargs)
+    axes = [i for i in range(len(starts))]
+
+    create_tensor(axes, name+'_axes', kwargs['initializer'])
     create_tensor(starts, name+'_starts', kwargs['initializer'])
     create_tensor(ends, name+'_ends', kwargs['initializer'])
     create_tensor(steps, name+'_steps', kwargs['initializer'])
 
     nodes = [
-        make_node('Range', [name+'_0_s', name+'_len_s', name+'_1_s'], [name+'_axes']),
         make_node("Slice", [input_nodes[0], name+'_starts', name+'_ends', name+'_axes',
                             name+'_steps'], [name], name=name)
     ]
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
index f4b0c90..60e6a34 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
@@ -28,8 +28,10 @@ from .export_onnx import MXNetGraph
 from ._export_helper import load_module
 
 
-def export_model(sym, params, input_shape, input_type=np.float32,
-                 onnx_file_path='model.onnx', verbose=False, opset_version=None):
+def export_model(sym, params, in_shapes=None, in_types=np.float32,
+                 onnx_file_path='model.onnx', verbose=False, opset_version=None,
+                 dynamic=False, dynamic_input_shapes=None, run_shape_inference=False, input_type=None,
+                 input_shape=None):
     """Exports the MXNet model file, passed as a parameter, into ONNX model.
     Accepts both symbol,parameter objects as well as json and params filepaths as input.
     Operator support and coverage -
@@ -41,14 +43,24 @@ def export_model(sym, params, input_shape, input_type=np.float32,
         Path to the json file or Symbol object
     params : str or symbol object
         Path to the params file or params dictionary. (Including both arg_params and aux_params)
-    input_shape : List of tuple
+    in_shapes : List of tuple
         Input shape of the model e.g [(1,3,224,224)]
-    input_type : data type or list of data types
+    in_types : data type or list of data types
         Input data type e.g. np.float32, or [np.float32, np.int32]
     onnx_file_path : str
         Path where to save the generated onnx file
     verbose : Boolean
-        If true will print logs of the model conversion
+        If True will print logs of the model conversion
+    dynamic : Boolean
+        If True will allow for dynamic input shapes to the model
+    dynamic_input_shapes : list of tuple
+        Specifies the dynamic input shapes, using None for each dynamic dimension. If None, then all dimensions are set to None
+    run_shape_inference : Boolean
+        If True will run shape inference on the model
+    input_type : data type or list of data types
+        This is the old name of in_types. We keep this parameter name for backward compatibility
+    input_shape : List of tuple
+        This is the old name of in_shapes. We keep this parameter name for backward compatibility
 
     Returns
     -------
@@ -62,42 +74,57 @@ def export_model(sym, params, input_shape, input_type=np.float32,
     """
 
     try:
-        from onnx import helper, mapping
+        from onnx import helper, mapping, shape_inference
         from onnx.defs import onnx_opset_version
     except ImportError:
         raise ImportError("Onnx and protobuf need to be installed. "
                           + "Instructions to install - https://github.com/onnx/onnx")
 
+    if input_type is not None:
+        in_types = input_type
+
+    if input_shape is not None:
+        in_shapes = input_shape
+
     converter = MXNetGraph()
     if opset_version is None:
         # default is to use latest opset version the onnx package supports
         opset_version = onnx_opset_version()
 
-    if not isinstance(input_type, list):
-        input_type = [input_type for _ in range(len(input_shape))]
-    input_dtype = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(inp_type)] for inp_type in input_type]
+    if not isinstance(in_types, list):
+        in_types = [in_types for _ in range(len(in_shapes))]
+    in_types_t = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(i_t)] for i_t in in_types]
     # if input parameters are strings(file paths), load files and create symbol parameter objects
     if isinstance(sym, string_types) and isinstance(params, string_types):
         logging.info("Converting json and weight file to sym and params")
         sym_obj, params_obj = load_module(sym, params)
-        onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape,
-                                                       input_dtype,
-                                                       verbose=verbose, opset_version=opset_version)
+        onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, in_shapes,
+                                                       in_types_t,
+                                                       verbose=verbose, opset_version=opset_version,
+                                                       dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes)
     elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
-        onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape,
-                                                       input_dtype,
-                                                       verbose=verbose, opset_version=opset_version)
+        onnx_graph = converter.create_onnx_graph_proto(sym, params, in_shapes,
+                                                       in_types_t,
+                                                       verbose=verbose, opset_version=opset_version,
+                                                       dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes)
     else:
         raise ValueError("Input sym and params should either be files or objects")
 
     # Create the model (ModelProto)
     onnx_model = helper.make_model(onnx_graph)
 
+    # Run shape inference on the model. Due to ONNX bug/incompatibility this may or may not crash
+    if run_shape_inference:
+        try:
+            onnx_model = shape_inference.infer_shapes(onnx_model)
+        except: # pylint: disable=bare-except
+            logging.info("Shape inference failed, original export is kept.")
+
     # Save model on disk
     with open(onnx_file_path, "wb") as file_handle:
         serialized = onnx_model.SerializeToString()
         file_handle.write(serialized)
-        logging.info("Input shape of the model %s ", input_shape)
+        logging.info("Input shape of the model %s ", in_shapes)
         logging.info("Exported ONNX file %s saved to disk", onnx_file_path)
 
     return onnx_file_path
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index 1790e4d..4cec698 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -122,29 +122,39 @@ class MXNetGraph(object):
         return arg_params, aux_params
 
     @staticmethod
-    def get_outputs(sym, params, in_shape, in_label, in_type):
-        """ Infer output shapes and return dictionary of output name to shape
+    def get_outputs(sym, params, in_shapes, output_label, in_types, dynamic=False,
+                    dynamic_input_shapes=None):
+        """Helper function to collect the output names, types, and shapes
 
-        :param :class:`~mxnet.symbol.Symbol` sym: symbol to perform infer shape on
-        :param dic of (str, nd.NDArray) params:
-        :param list of tuple(int, ...) in_shape: list of all input shapes
-        :param  in_label: name of label typically used in loss that may be left in graph. This name is
+        Parameters
+        ----------
+        sym : :class:`~mxnet.symbol.Symbol`
+            MXNet symbol object
+        params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+            Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
+        in_shapes : list of tuple
+            Input shapes
+        output_label : ``str``
+            Name of label typically used in loss that may be left in graph. This name is
             removed from list of inputs required by symbol
-        :return: dictionary of output name to shape
-        :rtype: dict of (str, tuple(int, ...))
+        in_types : list of Int
+            Input ONNX data types
+        dynamic : Boolean
+            If True will allow for dynamic input shapes to the model
+        dynamic_input_shapes : list of tuple
+            Specifies the dynamic input_shapes. If None then all dimensions are set to None
+
+        Returns
+        -------
+        in_shapes : list of tuple
+            Updated input shapes
+        graph_outputs : dict ``str`` to dict
+            This maps output name to {'shape':tuple, 'dtype':Int}
         """
         from onnx import mapping
         import re
-        # remove any input listed in params from sym.list_inputs() and bind them to the input shapes provided
-        # by user. Also remove in_label, which is the name of the label symbol that may have been used
-        # as the label for loss during training.
-        inputs = {n: tuple(s) for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label],
-                                              in_shape)}
-        # Add params and their shape to list of inputs
-        inputs.update({n: v.shape for n, v in params.items() if n in sym.list_inputs()})
-        # Provide input data as well as input params to infer_shape()
-        _, out_shapes, _ = sym.infer_shape(**inputs)
 
+        # Collect graph output names
         out_names = list()
         for name in sym.list_outputs():
             if name.endswith('_state_output'): # handel special cases for RNN operator
@@ -159,24 +169,53 @@ class MXNetGraph(object):
                 logging.info("output '%s' does not end with '_output'", name)
                 out_names.append(name)
 
-        assert len(out_shapes) == len(out_names)
+        # Collect graph output shapes
+        # Remove any input listed in params from sym.list_inputs() and bind them to the input shapes provided
+        # by user. Also remove output_label, which is the name of the label symbol that may have been used
+        # as the label for loss during training.
+        inputs = {n: tuple(s) for n, s in
+                  zip([n for n in sym.list_inputs() if n not in params and n != output_label],
+                      in_shapes)}
+        # Add params and their shape to list of inputs
+        inputs.update({n: v.shape for n, v in params.items() if n in sym.list_inputs()})
+        # Provide input data as well as input params to infer_shape()
+        _, out_shapes, _ = sym.infer_shape(**inputs)
+        if dynamic:
+            # Keep the dimensionality of the output shapes but change the values to None
+            out_shapes = [tuple(None for _ in i_s) for i_s in out_shapes]
 
-        ## Infer output types
+            if dynamic_input_shapes is None:
+                # Set all dimensions to None
+                in_shapes = [tuple(None for _ in i_s) for i_s in in_shapes]
+            else:
+                assert len(in_shapes) == len(dynamic_input_shapes), "The length of " \
+                    "dynamic_input_shapes must equal to the length of in_shapes."
+                for i_s, d_i_s in zip(in_shapes, dynamic_input_shapes):
+                    assert len(i_s) == len(d_i_s), "The dimensionality " \
+                        "of each shape must match."
+                in_shapes = dynamic_input_shapes
+        else:
+            assert dynamic_input_shapes is None, "dynamic_input_shapes is specified. Please " \
+                "set dynamic_input_shapes=True to enable dynamic input shapes"
+
+        # Collect graph output types
         # Remove any input listed in params from sym.list_inputs() and bind them to the input types provided
-        # by user. Also remove in_label
-        in_dtype = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[t]
-                    for n, t in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_type)}
+        # by user. Also remove output_label
+        in_dtypes = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[t] for n, t in
+                     zip([n for n in sym.list_inputs() if n not in params and n != output_label],
+                         in_types)}
         # Add params and their types to list of inputs
-        in_dtype.update({n: v.dtype for n, v in params.items() if n in sym.list_inputs()})
-        _, out_type, _ = sym.infer_type(**in_dtype)
+        in_dtypes.update({n: v.dtype for n, v in params.items() if n in sym.list_inputs()})
+        _, out_type, _ = sym.infer_type(**in_dtypes)
         out_types = [mapping.NP_TYPE_TO_TENSOR_TYPE[o(0).dtype] for o in out_type]
 
-        assert len(out_types) == len(out_names)
+        # Make sure the types, names, and shapes all line up
+        assert len(out_types) == len(out_names) == len(out_shapes)
 
-        # bind output shapes/types with output names
+        # Bind output shapes/types with output names
         graph_outputs = {n: {'shape': s, 'dtype': d} for n, s, d in zip(out_names, out_shapes, out_types)}
 
-        return graph_outputs
+        return in_shapes, graph_outputs
 
     @staticmethod
     def convert_weights_to_numpy(weights_dict):
@@ -184,7 +223,8 @@ class MXNetGraph(object):
         return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
                      for k, v in weights_dict.items()])
 
-    def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False, opset_version=None):
+    def create_onnx_graph_proto(self, sym, params, in_shapes, in_types, verbose=False, opset_version=None,
+                                dynamic=True, dynamic_input_shapes=None):
         """Convert MXNet graph to ONNX graph
 
         Parameters
@@ -193,14 +233,18 @@ class MXNetGraph(object):
             MXNet symbol object
         params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
             Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
-        in_shape : List of tuple
+        in_shapes : List of tuple
             Input shape of the model e.g [(1,3,224,224)]
-        in_type : data type
-            Input data type e.g. np.float32
+        in_types : List of Int
+            Input ONNX data types
         verbose : Boolean
             If true will print logs of the model conversion
         opset_version : Int
             ONNX opset version to use for export, defaults to latest supported by onnx package
+        dynamic : Boolean
+            If True will allow for dynamic input shapes to the model
+        dynamic_input_shapes : list of tuple
+            Specifies the dynamic input shapes, using None for each dynamic dimension. If None, then all dimensions are set to None
 
         Returns
         -------
@@ -241,9 +285,9 @@ class MXNetGraph(object):
         onnx_processed_outputs = []
         outputs_lookup = []
 
-        # Determine output shape
-        graph_outputs = MXNetGraph.get_outputs(sym, params, in_shape, output_label, in_type)
-
+        # Determine graph output names, shapes, and dtypes. Also update in_shapes
+        in_shapes, graph_outputs = MXNetGraph.get_outputs(sym, params, in_shapes, output_label,
+                                                          in_types, dynamic, dynamic_input_shapes)
         appeared_names = set()
         graph_input_idx = 0
         for idx, node in enumerate(mx_graph):
@@ -260,26 +304,26 @@ class MXNetGraph(object):
             # A node is an input node if its op_name is "null" and is not
             # in params dict
             if op == "null" and name not in params:
-                # Handling graph input
+                # Handle graph input
 
-                # Skipping output_label node, as this node is not part of graph
-                # Refer "output_label" assignment above for more details.
+                # Skip output_label node, as this node is not part of graph
+                # Refer to "output_label" assignment above for more details.
                 if name == output_label:
                     continue
+
                 converted, dtypes = MXNetGraph.convert_layer(
                     node,
                     is_input=True,
                     mx_graph=mx_graph,
                     weights=weights,
-                    in_shape=in_shape[graph_input_idx],
-                    in_type=in_type[graph_input_idx],
+                    in_shape=in_shapes[graph_input_idx],
+                    in_type=in_types[graph_input_idx],
                     proc_nodes=all_processed_nodes,
                     initializer=initializer,
                     outputs_lookup=outputs_lookup)
                 graph_input_idx += 1
-
             else:
-                # Handling graph layers
+                # Handle graph layers
                 converted, dtypes = MXNetGraph.convert_layer(
                     node,
                     is_input=False,
@@ -291,7 +335,6 @@ class MXNetGraph(object):
                     idx=idx,
                     opset_version=opset_version
                 )
-
             if isinstance(converted, list):
                 # Collect all the node's output names
                 node_possible_names = [name] + [name + str(i) for i in range(10)]
diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
index 3bcf27f..2fba092 100644
--- a/tests/python-pytest/onnx/test_onnxruntime.py
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -55,6 +55,13 @@ class GluonModel():
                                      [self.input_shape], self.input_dtype, onnx_file)
         return onnx_file
 
+    def export_onnx_dynamic(self, dynamic_input_shapes):
+        onnx_file = self.modelpath + ".onnx"
+        mx.contrib.onnx.export_model(self.modelpath + "-symbol.json", self.modelpath + "-0000.params",
+                                     [self.input_shape], self.input_dtype, onnx_file, dynamic=True,
+                                     dynamic_input_shapes=dynamic_input_shapes)
+        return onnx_file
+
     def predict(self, data):
         return self.model(data)
 
@@ -584,11 +591,7 @@ def test_roberta_inference_onnxruntime(tmp_path, model_name):
 
         sequence_outputs, attention_outputs= model(inputs, valid_length, masked_positions)    
 
-        model_dir = f'roberta_model'
-        if not os.path.isdir(model_dir):
-            os.mkdir(model_dir)
-
-        prefix = '%s/%s' % (model_dir, model_name)
+        prefix = "%s/roberta" % tmp_path
         model.export(prefix)
 
         sym_file = "%s-symbol.json" % prefix
@@ -608,7 +611,7 @@ def test_roberta_inference_onnxruntime(tmp_path, model_name):
         pred = sess.run(None, input_dict)
 
         assert_almost_equal(sequence_outputs, pred[0])
-        assert_almost_equal(attension_outputs, pred[1])
+        assert_almost_equal(attention_outputs, pred[1])
 
     finally:
         shutil.rmtree(tmp_path)
@@ -716,3 +719,99 @@ def test_distilbert_inference_onnxruntime(tmp_path, model_name):
     finally:
         shutil.rmtree(tmp_path)
 
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['mobilenet1.0', 'inceptionv3', 'darknet53', 'resnest14'])
+def test_dynamic_shape_cv_inference_onnxruntime(tmp_path, model_name):
+    tmp_path = str(tmp_path)
+    try:
+        M = GluonModel(model_name, (1, 3, 512, 512), 'float32', tmp_path)
+        dynamic_input_shapes = [(None, 3, 512, 512)]
+        onnx_file = M.export_onnx_dynamic(dynamic_input_shapes)
+
+        # create onnxruntime session using the generated onnx file
+        ses_opt = onnxruntime.SessionOptions()
+        ses_opt.log_severity_level = 3
+        sess = onnxruntime.InferenceSession(onnx_file, ses_opt)
+
+        # test on a different batch size
+        x = mx.random.uniform(0, 10, (5, 3, 512, 512))
+        in_tensors = [x]
+        input_dict = dict((sess.get_inputs()[i].name, in_tensors[i].asnumpy()) for i in range(len(in_tensors)))
+        pred_on = sess.run(None, input_dict)
+
+        pred_mx = M.predict(x)
+
+        assert_almost_equal(pred_mx, pred_on[0])
+
+    finally:
+        shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model', ['bert_12_768_12'])
+def test_dynamic_shape_bert_inference_onnxruntime(tmp_path, model):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'book_corpus_wiki_en_uncased'
+        ctx = mx.cpu(0)
+        model, vocab = nlp.model.get_model(
+            name=model,
+            ctx=ctx,
+            dataset_name=dataset,
+            pretrained=True,
+            use_pooler=True,
+            use_decoder=False,
+            num_layers = 3,
+            hparam_allow_override = True,
+            use_classifier=False)
+
+        model.hybridize(static_alloc=True)
+
+        batch = 5
+        seq_length = 16
+        # create synthetic test data
+        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
+        token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
+        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+        seq_encoding, cls_encoding = model(inputs, token_types, valid_length)
+
+        prefix = "%s/bert" % tmp_path
+        model.export(prefix)
+        sym_file = "%s-symbol.json" % prefix
+        params_file = "%s-0000.params" % prefix
+        onnx_file = "%s.onnx" % prefix
+
+        dynamic_input_shapes = [(None, seq_length), (None, seq_length), (None,)]
+        input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
+        input_types = [np.float32, np.float32, np.float32]
+        converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes,
+                                                            input_types, onnx_file,
+                                                            dynamic=True,
+                                                            dynamic_input_shapes=dynamic_input_shapes)
+
+        # create onnxruntime session using the generated onnx file
+        ses_opt = onnxruntime.SessionOptions()
+        ses_opt.log_severity_level = 3
+        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+
+        # test on a different batch size
+        batch = 7
+        seq_length = 16
+        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
+        token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
+        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+        seq_encoding, cls_encoding = model(inputs, token_types, valid_length)
+
+        onnx_inputs = [inputs, token_types, valid_length]
+        input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
+        pred_onx, cls_onx = session.run(None, input_dict)
+
+        assert_almost_equal(seq_encoding, pred_onx)
+        assert_almost_equal(cls_encoding, cls_onx)
+
+    finally:
+        shutil.rmtree(tmp_path)