Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/03/07 07:10:20 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] ONNX export support for multiple input data types (#19796)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new b14aae2  [v1.x] ONNX export support for multiple input data types (#19796)
b14aae2 is described below

commit b14aae2822f08f0267eec38cb81adccce316bf0a
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Sat Mar 6 23:07:46 2021 -0800

    [v1.x] ONNX export support for multiple input data types (#19796)
    
    * add test
    
    * support multiple input nodes
    
    * fix sanity
    
    * update input dtype
    
    * fix typo
    
    * update export_onnx
    
    * fix sanity
    
    * fix space
    
    * update import
    
    * fix sanity
    
    * remove float64 from test_where
    
    * update test
    
    * fix bert test input type
    
    * enable default input_type
    
    * more default fix
    
    * fix typo
    
    * fix empty lines
    
    Co-authored-by: Wei Chu <we...@amazon.com>
---
 python/mxnet/contrib/onnx/mx2onnx/export_model.py | 12 ++---
 python/mxnet/contrib/onnx/mx2onnx/export_onnx.py  | 22 ++++++---
 python/mxnet/contrib/onnx/onnx2mx/import_onnx.py  |  3 +-
 tests/python-pytest/onnx/backend.py               |  7 +--
 tests/python-pytest/onnx/mxnet_export_test.py     |  1 +
 tests/python-pytest/onnx/test_node.py             |  6 +--
 tests/python-pytest/onnx/test_onnxruntime.py      | 55 +++++++++++++++++++++++
 tests/python-pytest/onnx/test_operators.py        |  7 ++-
 8 files changed, 91 insertions(+), 22 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
index 2fc7760..f4b0c90 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
@@ -43,8 +43,8 @@ def export_model(sym, params, input_shape, input_type=np.float32,
         Path to the params file or params dictionary. (Including both arg_params and aux_params)
     input_shape : List of tuple
         Input shape of the model e.g [(1,3,224,224)]
-    input_type : data type
-        Input data type e.g. np.float32
+    input_type : data type or list of data types
+        Input data type e.g. np.float32, or [np.float32, np.int32]
     onnx_file_path : str
         Path where to save the generated onnx file
     verbose : Boolean
@@ -73,17 +73,19 @@ def export_model(sym, params, input_shape, input_type=np.float32,
         # default is to use latest opset version the onnx package supports
         opset_version = onnx_opset_version()
 
-    data_format = np.dtype(input_type)
+    if not isinstance(input_type, list):
+        input_type = [input_type for _ in range(len(input_shape))]
+    input_dtype = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(inp_type)] for inp_type in input_type]
     # if input parameters are strings(file paths), load files and create symbol parameter objects
     if isinstance(sym, string_types) and isinstance(params, string_types):
         logging.info("Converting json and weight file to sym and params")
         sym_obj, params_obj = load_module(sym, params)
         onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape,
-                                                       mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
+                                                       input_dtype,
                                                        verbose=verbose, opset_version=opset_version)
     elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
         onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape,
-                                                       mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
+                                                       input_dtype,
                                                        verbose=verbose, opset_version=opset_version)
     else:
         raise ValueError("Input sym and params should either be files or objects")
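
The hunk above lets input_type be either a single dtype (repeated for every input, preserving the old behavior) or one dtype per entry of input_shape. A minimal usage sketch with hypothetical file names:

    import numpy as np
    import mxnet as mx

    # Hypothetical two-input model: a float32 image tensor plus an int32 index
    # vector. input_type now accepts a list with one dtype per input_shape
    # entry; a bare np.float32 still works and is broadcast to every input.
    mx.contrib.onnx.export_model(
        sym='net-symbol.json',                   # hypothetical file names
        params='net-0000.params',
        input_shape=[(1, 3, 224, 224), (1, 10)],
        input_type=[np.float32, np.int32],
        onnx_file_path='net.onnx')
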
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index 89f061d..59f9e20 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -155,14 +155,19 @@ class MXNetGraph(object):
 
         assert len(out_shapes) == len(out_names)
 
-        # infer output types
-        args = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[in_type] for n in sym.list_inputs()}
-        _, out_type, _ = sym.infer_type(**args)
+        ## Infer output types
+        # Drop from sym.list_inputs() any input that appears in params, plus the in_label
+        # node, and bind the remaining inputs to the user-provided input types
+        in_dtype = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[t]
+                    for n, t in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_type)}
+        # Add params and their types to list of inputs
+        in_dtype.update({n: v.dtype for n, v in params.items() if n in sym.list_inputs()})
+        _, out_type, _ = sym.infer_type(**in_dtype)
         out_types = [mapping.NP_TYPE_TO_TENSOR_TYPE[o(0).dtype] for o in out_type]
 
         assert len(out_types) == len(out_names)
 
-        # bind output shapes with output names
+        # bind output shapes/types with output names
         graph_outputs = {n: {'shape': s, 'dtype': d} for n, s, d in zip(out_names, out_shapes, out_types)}
 
         return graph_outputs
@@ -256,13 +261,18 @@ class MXNetGraph(object):
                     mx_graph=mx_graph,
                     weights=weights,
                     in_shape=in_shape[graph_input_idx],
-                    in_type=in_type,
+                    in_type=in_type[graph_input_idx],
                     proc_nodes=all_processed_nodes,
                     initializer=initializer,
                     outputs_lookup=outputs_lookup)
                 graph_input_idx += 1
 
             else:
+                # Handle the case of a node with no graph input: default to float32
+                intype = 1  # 1 == TensorProto.FLOAT in ONNX tensor types
+                if len(in_type) > 0:
+                    intype = in_type[0]
+
                 # Handling graph layers
                 converted = MXNetGraph.convert_layer(
                     node,
@@ -270,7 +280,7 @@ class MXNetGraph(object):
                     mx_graph=mx_graph,
                     weights=weights,
                     in_shape=in_shape,
-                    in_type=in_type,
+                    in_type=intype,
                     proc_nodes=all_processed_nodes,
                     initializer=initializer,
                     outputs_lookup=outputs_lookup,
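
To make the new type-inference step concrete, here is a standalone sketch of the binding logic in get_outputs() with hypothetical names; note the real code first maps the ONNX tensor-type enums in in_type through TENSOR_TYPE_TO_NP_TYPE, whereas plain numpy dtypes are used directly here:

    import numpy as np
    import mxnet as mx

    # Hypothetical stand-ins for the variables inside get_outputs() above.
    list_inputs = ['data0', 'data1', 'weight', 'softmax_label']  # sym.list_inputs()
    params = {'weight': mx.nd.zeros((4, 4), dtype='float32')}
    in_label = 'softmax_label'
    in_type = [np.float32, np.int32]  # user dtypes, one per real data input

    # Zip the user dtypes against the data inputs only (params and the label
    # node are filtered out), then fold in the params' own stored dtypes.
    data_inputs = [n for n in list_inputs if n not in params and n != in_label]
    in_dtype = dict(zip(data_inputs, in_type))
    in_dtype.update({n: v.dtype for n, v in params.items()})
    # get_outputs() then calls sym.infer_type(**in_dtype) to derive output types.
    print(in_dtype)
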
diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
index c2be83d..d51c51c 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
@@ -147,7 +147,8 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
         for graph_input in graph.input:
             if graph_input.name not in _params:
                 shape = [val.dim_value for val in graph_input.type.tensor_type.shape.dim]
-                input_data.append((graph_input.name, tuple(shape)))
+                dtype = graph_input.type.tensor_type.elem_type
+                input_data.append((graph_input.name, tuple(shape), dtype))
 
         output_data = []
         for graph_out in graph.output:
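
The importer change above records a (name, shape, dtype) triple per graph input instead of a (name, shape) pair. A quick way to inspect those triples for any model, assuming an onnx version that still ships the onnx.mapping module (file name hypothetical):

    import onnx
    from onnx import mapping

    model = onnx.load('net.onnx')  # hypothetical file
    for graph_input in model.graph.input:
        shape = tuple(d.dim_value for d in graph_input.type.tensor_type.shape.dim)
        elem_type = graph_input.type.tensor_type.elem_type  # TensorProto enum, e.g. 1 == FLOAT
        print(graph_input.name, shape, mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
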
diff --git a/tests/python-pytest/onnx/backend.py b/tests/python-pytest/onnx/backend.py
index eb803f7..6d8b1af 100644
--- a/tests/python-pytest/onnx/backend.py
+++ b/tests/python-pytest/onnx/backend.py
@@ -50,7 +50,7 @@ class MXNetBackend(Backend):
         cls.operation = operation
 
     @staticmethod
-    def perform_import_export(sym, arg_params, aux_params, input_shape):
+    def perform_import_export(sym, arg_params, aux_params, input_shape, input_dtype):
         """ Import ONNX model to mxnet model and then export to ONNX model
             and then import it back to mxnet for verifying the result"""
         graph = GraphProto()
@@ -63,7 +63,7 @@ class MXNetBackend(Backend):
         # exporting to onnx graph proto format
         converter = MXNetGraph()
         graph_proto = converter.create_onnx_graph_proto(sym, params, in_shape=input_shape,
-                                                        in_type=mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')],
+                                                        in_type=input_dtype,
                                                         opset_version=opset_version)
 
         # importing back to MXNET for verifying result.
@@ -108,8 +108,9 @@ class MXNetBackend(Backend):
                 metadata = graph.get_graph_metadata(model.graph)
                 input_data = metadata['input_tensor_data']
                 input_shape = [data[1] for data in input_data]
+                input_dtype = [data[2] for data in input_data]
                 sym, arg_params, aux_params = MXNetBackend.perform_import_export(sym, arg_params, aux_params,
-                                                                                 input_shape)
+                                                                                 input_shape, input_dtype)
 
             return MXNetBackendRep(sym, arg_params, aux_params, device)
         elif backend == 'gluon':
diff --git a/tests/python-pytest/onnx/mxnet_export_test.py b/tests/python-pytest/onnx/mxnet_export_test.py
index 947fa2f..82e628b 100644
--- a/tests/python-pytest/onnx/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/mxnet_export_test.py
@@ -62,6 +62,7 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params=
             sym=net_sym,
             params=net_params,
             input_shape=[shape_type(data.shape)],
+            input_type=[data.dtype],
             onnx_file_path=onnx_file_path)
         assert export_path == onnx_file_path
         # Try importing the model to symbol
diff --git a/tests/python-pytest/onnx/test_node.py b/tests/python-pytest/onnx/test_node.py
index 0b7fd9c..686ea4c 100644
--- a/tests/python-pytest/onnx/test_node.py
+++ b/tests/python-pytest/onnx/test_node.py
@@ -154,7 +154,7 @@ class TestNode(unittest.TestCase):
 
                 if mxnet_specific:
                     onnxmodelfile = onnx_mxnet.export_model(test_op, {}, [np.shape(ip) for ip in inputs],
-                                                            np.float32,
+                                                            [ip.dtype for ip in inputs],
                                                             onnx_name + ".onnx")
                     onnxmodel = load_model(onnxmodelfile)
                 else:
@@ -190,9 +190,9 @@ class TestNode(unittest.TestCase):
                                                       onnx_file_path=outsym.name + ".onnx")
 
             sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
-        result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
+            result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
 
-        npt.assert_almost_equal(result, forward_op)
+            npt.assert_almost_equal(result, forward_op)
 
     def test_imports(self):
         for test in import_test_cases:
diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
index 86e19fa..48a8a43 100644
--- a/tests/python-pytest/onnx/test_onnxruntime.py
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -555,3 +555,58 @@ def test_action_recognition_model_inference_onnxruntime(tmp_path, model, act_rec
     finally:
         shutil.rmtree(tmp_path)
 
+@with_seed()
+@pytest.mark.parametrize('model', ['bert_12_768_12'])
+def test_bert_inference_onnxruntime(tmp_path, model):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'book_corpus_wiki_en_uncased'
+        ctx = mx.cpu(0)
+        model, vocab = nlp.model.get_model(
+            name=model,
+            ctx=ctx,
+            dataset_name=dataset,
+            pretrained=False,
+            use_pooler=True,
+            use_decoder=False,
+            use_classifier=False)
+        model.initialize(ctx=ctx)
+        model.hybridize(static_alloc=True)
+
+        batch = 5
+        seq_length = 16
+        # create synthetic test data
+        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
+        token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
+        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+        seq_encoding, cls_encoding = model(inputs, token_types, valid_length)
+
+        prefix = "%s/bert" % tmp_path
+        model.export(prefix)
+        sym_file = "%s-symbol.json" % prefix
+        params_file = "%s-0000.params" % prefix
+        onnx_file = "%s.onnx" % prefix
+
+
+        input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
+        input_types = [np.float32, np.float32, np.float32]
+        converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, input_types, onnx_file)
+
+
+        # create onnxruntime session using the generated onnx file
+        ses_opt = onnxruntime.SessionOptions()
+        ses_opt.log_severity_level = 3
+        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+        onnx_inputs = [inputs, token_types, valid_length]
+        input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
+        pred_onx, cls_onx = session.run(None, input_dict)
+
+        assert_almost_equal(seq_encoding, pred_onx)
+        assert_almost_equal(cls_encoding, cls_onx)
+
+    finally:
+        shutil.rmtree(tmp_path)
+
+
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 8ebfbab..3bf3cf6 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -45,16 +45,15 @@ def op_export_test(model_name, Model, inputs, tmp_path, dummy_input=False, onnx_
         model.export(model_path, epoch=0)
         sym_file = '{}-symbol.json'.format(model_path)
         params_file = '{}-0000.params'.format(model_path)
-        dtype = inputs[0].dtype
         onnx_file = '{}/{}.onnx'.format(tmp_path, model_name)
         mx.contrib.onnx.export_model(sym_file, params_file, [inp.shape for inp in inputs],
-                                     dtype, onnx_file)
+                                     [inp.dtype for inp in inputs], onnx_file)
         return onnx_file
 
     def onnx_rt(onnx_file, inputs):
         sess = rt.InferenceSession(onnx_file)
         dtype_0 = inputs[0].asnumpy().dtype
-        input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy().astype(dtype_0)) for i in range(len(inputs)))
+        input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy()) for i in range(len(inputs)))
         pred = sess.run(None, input_dict)
         return pred
 
@@ -560,7 +559,7 @@ def test_onnx_export_equal_scalar(tmp_path, dtype, scalar):
     op_export_test('_internal._equal_scalar', M, [x], tmp_path)
 
 
-@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
+@pytest.mark.parametrize("dtype", ["float16", "float32", "int32", "int64"])
 @pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)])
 def test_onnx_export_where(tmp_path, dtype, shape):
     M = def_model('where')
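
Taken together, the changes let a test like the one above round-trip per-input dtypes end to end. A standalone sketch in the spirit of op_export_test, with hypothetical shapes, input names, and file name (all inputs float32 here, matching the test's single-dtype parametrization):

    import numpy as np
    import mxnet as mx
    import onnxruntime as rt

    # Build a small 'where' graph and export it, passing the new list form of
    # input_type with one dtype per input (symbol + empty params dict, as in
    # test_node.py above).
    cond, x, y = (mx.sym.Variable(n) for n in ('cond', 'x', 'y'))
    sym = mx.sym.where(cond, x, y)
    onnx_file = mx.contrib.onnx.export_model(
        sym, {}, [(3, 3)] * 3, [np.float32] * 3, 'where.onnx')

    # Run through onnxruntime; per the onnx_rt change above, each input keeps
    # its own dtype instead of being cast to the first input's dtype.
    sess = rt.InferenceSession(onnx_file)
    feeds = {'cond': np.eye(3).astype(np.float32),
             'x': np.ones((3, 3), dtype=np.float32),
             'y': np.zeros((3, 3), dtype=np.float32)}
    print(sess.run(None, feeds)[0])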