Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/03/12 04:01:21 UTC
[incubator-mxnet] branch v1.x updated: Onnx Dynamic Shapes Support (#20001)
This is an automated email from the ASF dual-hosted git repository.
zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 601e3b9 Onnx Dynamic Shapes Support (#20001)
601e3b9 is described below
commit 601e3b9576e718156e82c5f348d8a3409a53f0ad
Author: Zhaoqi Zhu <zh...@gmail.com>
AuthorDate: Thu Mar 11 19:58:50 2021 -0800
Onnx Dynamic Shapes Support (#20001)
* basic dynamic shape support
* add shape inference
* fix pylint and doc string
* add cv dynamic shape test
* redesign interface + fix slice operator + add bert test
* add dynamic tests to ci
* fix slice
* revert adding test
* fix lint
* add tests to ci
* fix api backward compatibility
---
ci/docker/runtime_functions.sh | 2 +
.../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 7 +-
python/mxnet/contrib/onnx/mx2onnx/export_model.py | 59 +++++++---
python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 127 ++++++++++++++-------
tests/python-pytest/onnx/test_onnxruntime.py | 111 +++++++++++++++++-
5 files changed, 238 insertions(+), 68 deletions(-)
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index ceed020..a3484fc 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -1272,6 +1272,8 @@ integrationtest_ubuntu_cpu_onnx() {
pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_img_segmentation_model_inference_onnxruntime[deeplab_resnet50_citys]
pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_pose_estimation_model_inference_onnxruntime[mobile_pose_mobilenet1.0]
pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_action_recognition_model_inference_onnxruntime[inceptionv3_kinetics400]
+ pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_dynamic_shape_bert_inference_onnxruntime
+ pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_dynamic_shape_cv_inference_onnxruntime
}
integrationtest_ubuntu_gpu_python() {
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index c3aa9bc..f33e03e 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -2834,15 +2834,14 @@ def convert_slice(node, **kwargs):
if ends[i] is None:
ends[i] = 2**63-1
- create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
- create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)
- create_const_scalar_node(name+'_len_s', np.int64(len(starts)), kwargs)
+ axes = [i for i in range(len(starts))]
+
+ create_tensor(axes, name+'_axes', kwargs['initializer'])
create_tensor(starts, name+'_starts', kwargs['initializer'])
create_tensor(ends, name+'_ends', kwargs['initializer'])
create_tensor(steps, name+'_steps', kwargs['initializer'])
nodes = [
- make_node('Range', [name+'_0_s', name+'_len_s', name+'_1_s'], [name+'_axes']),
make_node("Slice", [input_nodes[0], name+'_starts', name+'_ends', name+'_axes',
name+'_steps'], [name], name=name)
]
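
The convert_slice change above drops the runtime Range node and instead emits the slice axes as a constant initializer, which keeps the op well-defined when input dimensions are unknown. A minimal, self-contained sketch of the resulting ONNX pattern, built with the public onnx helpers (tensor names and example values are illustrative, not taken from the converter):

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

# Constant initializers for starts/ends/axes/steps, mirroring the pattern
# the converter now emits (names and values here are illustrative).
starts = numpy_helper.from_array(np.array([0, 0], dtype=np.int64), name='slice_starts')
ends = numpy_helper.from_array(np.array([2**63 - 1, 3], dtype=np.int64), name='slice_ends')
axes = numpy_helper.from_array(np.array([0, 1], dtype=np.int64), name='slice_axes')
steps = numpy_helper.from_array(np.array([1, 1], dtype=np.int64), name='slice_steps')

# One Slice node consuming the data tensor plus the four constant inputs.
slice_node = helper.make_node(
    'Slice',
    inputs=['data', 'slice_starts', 'slice_ends', 'slice_axes', 'slice_steps'],
    outputs=['sliced'],
    name='slice0')

graph = helper.make_graph(
    [slice_node], 'slice_example',
    inputs=[helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 8])],
    outputs=[helper.make_tensor_value_info('sliced', TensorProto.FLOAT, [None, 3])],
    initializer=[starts, ends, axes, steps])

model = helper.make_model(graph)
onnx.checker.check_model(model)  # valid even though the first dim of 'data' is dynamic

Because the axes are known at export time, they can live in the initializer list alongside starts, ends, and steps instead of being computed on the fly.
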
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
index f4b0c90..60e6a34 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
@@ -28,8 +28,10 @@ from .export_onnx import MXNetGraph
from ._export_helper import load_module
-def export_model(sym, params, input_shape, input_type=np.float32,
- onnx_file_path='model.onnx', verbose=False, opset_version=None):
+def export_model(sym, params, in_shapes=None, in_types=np.float32,
+ onnx_file_path='model.onnx', verbose=False, opset_version=None,
+ dynamic=False, dynamic_input_shapes=None, run_shape_inference=False, input_type=None,
+ input_shape=None):
"""Exports the MXNet model file, passed as a parameter, into ONNX model.
Accepts both symbol,parameter objects as well as json and params filepaths as input.
Operator support and coverage -
@@ -41,14 +43,24 @@ def export_model(sym, params, input_shape, input_type=np.float32,
Path to the json file or Symbol object
params : str or symbol object
Path to the params file or params dictionary. (Including both arg_params and aux_params)
- input_shape : List of tuple
+ in_shapes : List of tuple
Input shape of the model e.g [(1,3,224,224)]
- input_type : data type or list of data types
+ in_types : data type or list of data types
Input data type e.g. np.float32, or [np.float32, np.int32]
onnx_file_path : str
Path where to save the generated onnx file
verbose : Boolean
- If true will print logs of the model conversion
+ If True will print logs of the model conversion
+ dynamic: Boolean
+ If True will allow for dynamic input shapes to the model
+ dynamic_input_shapes: list of tuple
+ Specifies the dynamic input_shapes. If None then all dimensions are set to None
+ run_shape_inference : Boolean
+ If True will run shape inference on the model
+ input_type : data type or list of data types
+ This is the old name of in_types. We keep this parameter name for backward compatibility
+ input_shape : List of tuple
+ This is the old name of in_shapes. We keep this parameter name for backward compatibility
Returns
-------
@@ -62,42 +74,57 @@ def export_model(sym, params, input_shape, input_type=np.float32,
"""
try:
- from onnx import helper, mapping
+ from onnx import helper, mapping, shape_inference
from onnx.defs import onnx_opset_version
except ImportError:
raise ImportError("Onnx and protobuf need to be installed. "
+ "Instructions to install - https://github.com/onnx/onnx")
+ if input_type is not None:
+ in_types = input_type
+
+ if input_shape is not None:
+ in_shapes = input_shape
+
converter = MXNetGraph()
if opset_version is None:
# default is to use latest opset version the onnx package supports
opset_version = onnx_opset_version()
- if not isinstance(input_type, list):
- input_type = [input_type for _ in range(len(input_shape))]
- input_dtype = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(inp_type)] for inp_type in input_type]
+ if not isinstance(in_types, list):
+ in_types = [in_types for _ in range(len(in_shapes))]
+ in_types_t = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(i_t)] for i_t in in_types]
# if input parameters are strings(file paths), load files and create symbol parameter objects
if isinstance(sym, string_types) and isinstance(params, string_types):
logging.info("Converting json and weight file to sym and params")
sym_obj, params_obj = load_module(sym, params)
- onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape,
- input_dtype,
- verbose=verbose, opset_version=opset_version)
+ onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, in_shapes,
+ in_types_t,
+ verbose=verbose, opset_version=opset_version,
+ dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes)
elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
- onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape,
- input_dtype,
- verbose=verbose, opset_version=opset_version)
+ onnx_graph = converter.create_onnx_graph_proto(sym, params, in_shapes,
+ in_types_t,
+ verbose=verbose, opset_version=opset_version,
+ dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes)
else:
raise ValueError("Input sym and params should either be files or objects")
# Create the model (ModelProto)
onnx_model = helper.make_model(onnx_graph)
+ # Run shape inference on the model. Due to ONNX bugs/incompatibilities this can fail for some models
+ if run_shape_inference:
+ try:
+ onnx_model = shape_inference.infer_shapes(onnx_model)
+ except: # pylint: disable=bare-except
+ logging.info("Shape inference failed, original export is kept.")
+
# Save model on disk
with open(onnx_file_path, "wb") as file_handle:
serialized = onnx_model.SerializeToString()
file_handle.write(serialized)
- logging.info("Input shape of the model %s ", input_shape)
+ logging.info("Input shape of the model %s ", in_shapes)
logging.info("Exported ONNX file %s saved to disk", onnx_file_path)
return onnx_file_path
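
Taken together, the export_model changes amount to: new in_shapes/in_types parameter names, the dynamic/dynamic_input_shapes flags, optional ONNX shape inference, and aliases for the old keyword names. A hedged usage sketch of the updated signature (the model name, gluoncv dependency, and file paths are assumptions for illustration, mirroring the tests added in this commit):

import numpy as np
import mxnet as mx
from gluoncv import model_zoo  # assumption: gluoncv is available, as in the tests below

# Save an MXNet model the way the tests do: hybridize, run once, export
net = model_zoo.get_model('mobilenet1.0', pretrained=True)
net.hybridize()
net(mx.nd.zeros((1, 3, 224, 224)))
net.export('mobilenet')  # writes mobilenet-symbol.json and mobilenet-0000.params

# New-style call: leave the batch dimension dynamic and optionally run ONNX shape inference
mx.contrib.onnx.export_model(
    'mobilenet-symbol.json', 'mobilenet-0000.params',
    in_shapes=[(1, 3, 224, 224)], in_types=np.float32,
    onnx_file_path='mobilenet.onnx',
    dynamic=True, dynamic_input_shapes=[(None, 3, 224, 224)],
    run_shape_inference=True)

# Old-style keywords still work; they are remapped onto in_shapes/in_types internally
mx.contrib.onnx.export_model(
    'mobilenet-symbol.json', 'mobilenet-0000.params',
    input_shape=[(1, 3, 224, 224)], input_type=np.float32,
    onnx_file_path='mobilenet_static.onnx')
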
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index 1790e4d..4cec698 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -122,29 +122,39 @@ class MXNetGraph(object):
return arg_params, aux_params
@staticmethod
- def get_outputs(sym, params, in_shape, in_label, in_type):
- """ Infer output shapes and return dictionary of output name to shape
+ def get_outputs(sym, params, in_shapes, output_label, in_types, dynamic=False,
+ dynamic_input_shapes=None):
+ """Helper function to collect the output names, types, and shapes
- :param :class:`~mxnet.symbol.Symbol` sym: symbol to perform infer shape on
- :param dic of (str, nd.NDArray) params:
- :param list of tuple(int, ...) in_shape: list of all input shapes
- :param in_label: name of label typically used in loss that may be left in graph. This name is
+ Parameters
+ ----------
+ sym : :class:`~mxnet.symbol.Symbol`
+ MXNet symbol object
+ params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+ Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
+ in_shapes : list of tuple
+ Input shapes
+ output_label : ``str``
+ Name of label typically used in loss that may be left in graph. This name is
removed from list of inputs required by symbol
- :return: dictionary of output name to shape
- :rtype: dict of (str, tuple(int, ...))
+ in_types : list of Int
+ Input ONNX data types
+ dynamic : Boolean
+ If True will allow for dynamic input shapes to the model
+ dynamic_input_shapes: list of tuple
+ Specifies the dynamic input_shapes. If None then all dimensions are set to None
+
+ Returns
+ -------
+ in_shapes : list of tuple
+ Updated input shapes
+ graph_outputs : dict ``str`` to dict
+ This maps output name to {'shape':tuple, 'dtype':Int}
"""
from onnx import mapping
import re
- # remove any input listed in params from sym.list_inputs() and bind them to the input shapes provided
- # by user. Also remove in_label, which is the name of the label symbol that may have been used
- # as the label for loss during training.
- inputs = {n: tuple(s) for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label],
- in_shape)}
- # Add params and their shape to list of inputs
- inputs.update({n: v.shape for n, v in params.items() if n in sym.list_inputs()})
- # Provide input data as well as input params to infer_shape()
- _, out_shapes, _ = sym.infer_shape(**inputs)
+ # Collect graph output names
out_names = list()
for name in sym.list_outputs():
if name.endswith('_state_output'): # handle special cases for RNN operator
@@ -159,24 +169,53 @@ class MXNetGraph(object):
logging.info("output '%s' does not end with '_output'", name)
out_names.append(name)
- assert len(out_shapes) == len(out_names)
+ # Collect graph output shapes
+ # Remove any input listed in params from sym.list_inputs() and bind them to the input shapes provided
+ # by user. Also remove output_label, which is the name of the label symbol that may have been used
+ # as the label for loss during training.
+ inputs = {n: tuple(s) for n, s in
+ zip([n for n in sym.list_inputs() if n not in params and n != output_label],
+ in_shapes)}
+ # Add params and their shape to list of inputs
+ inputs.update({n: v.shape for n, v in params.items() if n in sym.list_inputs()})
+ # Provide input data as well as input params to infer_shape()
+ _, out_shapes, _ = sym.infer_shape(**inputs)
+ if dynamic:
+ # Keep the dimensionality of the output shapes but change the values to None
+ out_shapes = [tuple(None for _ in i_s) for i_s in out_shapes]
- ## Infer output types
+ if dynamic_input_shapes is None:
+ # Set all dimensions to None
+ in_shapes = [tuple(None for _ in i_s) for i_s in in_shapes]
+ else:
+ assert len(in_shapes) == len(dynamic_input_shapes), "The length of " \
+ "dynamic_input_shapes must equal to the length of in_shapes."
+ for i_s, d_i_s in zip(in_shapes, dynamic_input_shapes):
+ assert len(i_s) == len(d_i_s), "The dimensionality " \
+ "of each shape must match."
+ in_shapes = dynamic_input_shapes
+ else:
+ assert dynamic_input_shapes is None, "dynamic_input_shapes is specified. Please " \
+ "set dynamic_input_shapes=True to enable dynamic input shapes"
+
+ # Collect graph output types
# Remove any input listed in params from sym.list_inputs() and bind them to the input types provided
- # by user. Also remove in_label
- in_dtype = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[t]
- for n, t in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_type)}
+ # by user. Also remove output_label
+ in_dtypes = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[t] for n, t in
+ zip([n for n in sym.list_inputs() if n not in params and n != output_label],
+ in_types)}
# Add params and their types to list of inputs
- in_dtype.update({n: v.dtype for n, v in params.items() if n in sym.list_inputs()})
- _, out_type, _ = sym.infer_type(**in_dtype)
+ in_dtypes.update({n: v.dtype for n, v in params.items() if n in sym.list_inputs()})
+ _, out_type, _ = sym.infer_type(**in_dtypes)
out_types = [mapping.NP_TYPE_TO_TENSOR_TYPE[o(0).dtype] for o in out_type]
- assert len(out_types) == len(out_names)
+ # Make sure the types, names, and shapes all match up
+ assert len(out_types) == len(out_names) == len(out_shapes)
- # bind output shapes/types with output names
+ # Bind output shapes/types with output names
graph_outputs = {n: {'shape': s, 'dtype': d} for n, s, d in zip(out_names, out_shapes, out_types)}
- return graph_outputs
+ return in_shapes, graph_outputs
@staticmethod
def convert_weights_to_numpy(weights_dict):
@@ -184,7 +223,8 @@ class MXNetGraph(object):
return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
for k, v in weights_dict.items()])
- def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False, opset_version=None):
+ def create_onnx_graph_proto(self, sym, params, in_shapes, in_types, verbose=False, opset_version=None,
+ dynamic=True, dynamic_input_shapes=None):
"""Convert MXNet graph to ONNX graph
Parameters
@@ -193,14 +233,18 @@ class MXNetGraph(object):
MXNet symbol object
params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
- in_shape : List of tuple
+ in_shapes : List of tuple
Input shape of the model e.g [(1,3,224,224)]
- in_type : data type
- Input data type e.g. np.float32
+ in_types : List of Int
+ Input ONNX data types
verbose : Boolean
If true will print logs of the model conversion
opset_version : Int
ONNX opset version to use for export, defaults to latest supported by onnx package
+ dynamic: Boolean
+ If True will allow for dynamic input shapes to the model
+ dynamic_input_shapes: list of tuple
+ Specifies the dynamic input_shapes. If None then all dimensions are set to None
Returns
-------
@@ -241,9 +285,9 @@ class MXNetGraph(object):
onnx_processed_outputs = []
outputs_lookup = []
- # Determine output shape
- graph_outputs = MXNetGraph.get_outputs(sym, params, in_shape, output_label, in_type)
-
+ # Determine graph output names, shapes, and dtypes. Also update in_shapes
+ in_shapes, graph_outputs = MXNetGraph.get_outputs(sym, params, in_shapes, output_label,
+ in_types, dynamic, dynamic_input_shapes)
appeared_names = set()
graph_input_idx = 0
for idx, node in enumerate(mx_graph):
@@ -260,26 +304,26 @@ class MXNetGraph(object):
# A node is an input node if its op_name is "null" and is not
# in params dict
if op == "null" and name not in params:
- # Handling graph input
+ # Handle graph input
- # Skipping output_label node, as this node is not part of graph
- # Refer "output_label" assignment above for more details.
+ # Skip output_label node, as this node is not part of graph
+ # Refer to "output_label" assignment above for more details.
if name == output_label:
continue
+
converted, dtypes = MXNetGraph.convert_layer(
node,
is_input=True,
mx_graph=mx_graph,
weights=weights,
- in_shape=in_shape[graph_input_idx],
- in_type=in_type[graph_input_idx],
+ in_shape=in_shapes[graph_input_idx],
+ in_type=in_types[graph_input_idx],
proc_nodes=all_processed_nodes,
initializer=initializer,
outputs_lookup=outputs_lookup)
graph_input_idx += 1
-
else:
- # Handling graph layers
+ # Handle graph layers
converted, dtypes = MXNetGraph.convert_layer(
node,
is_input=False,
@@ -291,7 +335,6 @@ class MXNetGraph(object):
idx=idx,
opset_version=opset_version
)
-
if isinstance(converted, list):
# Collect all the node's output names
node_possible_names = [name] + [name + str(i) for i in range(10)]
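
The core of the new get_outputs behaviour is a small shape rewrite: infer concrete shapes first, then, when dynamic export is requested, replace the inferred dimensions with None (or with the user-supplied dynamic_input_shapes). A standalone sketch of just that rewrite, under the assumption that shape inference has already produced out_shapes (the helper name is illustrative; it is not a function in the codebase):

def rewrite_dynamic_shapes(in_shapes, out_shapes, dynamic=False, dynamic_input_shapes=None):
    """Illustrative helper mirroring the shape rewrite added to MXNetGraph.get_outputs."""
    if not dynamic:
        assert dynamic_input_shapes is None, \
            "dynamic_input_shapes is specified; set dynamic=True to enable dynamic input shapes"
        return in_shapes, out_shapes

    # Keep the rank of every output shape but mark each dimension as unknown
    out_shapes = [tuple(None for _ in shape) for shape in out_shapes]

    if dynamic_input_shapes is None:
        # No per-input spec supplied: every input dimension becomes dynamic
        in_shapes = [tuple(None for _ in shape) for shape in in_shapes]
    else:
        # The spec must cover every input and match each input's rank
        assert len(in_shapes) == len(dynamic_input_shapes)
        for static_shape, dyn_shape in zip(in_shapes, dynamic_input_shapes):
            assert len(static_shape) == len(dyn_shape)
        in_shapes = dynamic_input_shapes
    return in_shapes, out_shapes


# Example: only the batch dimension of the single input is dynamic
ins, outs = rewrite_dynamic_shapes([(1, 3, 224, 224)], [(1, 1000)],
                                   dynamic=True,
                                   dynamic_input_shapes=[(None, 3, 224, 224)])
assert ins == [(None, 3, 224, 224)] and outs == [(None, None)]
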
diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
index 3bcf27f..2fba092 100644
--- a/tests/python-pytest/onnx/test_onnxruntime.py
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -55,6 +55,13 @@ class GluonModel():
[self.input_shape], self.input_dtype, onnx_file)
return onnx_file
+ def export_onnx_dynamic(self, dynamic_input_shapes):
+ onnx_file = self.modelpath + ".onnx"
+ mx.contrib.onnx.export_model(self.modelpath + "-symbol.json", self.modelpath + "-0000.params",
+ [self.input_shape], self.input_dtype, onnx_file, dynamic=True,
+ dynamic_input_shapes=dynamic_input_shapes)
+ return onnx_file
+
def predict(self, data):
return self.model(data)
@@ -584,11 +591,7 @@ def test_roberta_inference_onnxruntime(tmp_path, model_name):
sequence_outputs, attention_outputs= model(inputs, valid_length, masked_positions)
- model_dir = f'roberta_model'
- if not os.path.isdir(model_dir):
- os.mkdir(model_dir)
-
- prefix = '%s/%s' % (model_dir, model_name)
+ prefix = "%s/roberta" % tmp_path
model.export(prefix)
sym_file = "%s-symbol.json" % prefix
@@ -608,7 +611,7 @@ def test_roberta_inference_onnxruntime(tmp_path, model_name):
pred = sess.run(None, input_dict)
assert_almost_equal(sequence_outputs, pred[0])
- assert_almost_equal(attension_outputs, pred[1])
+ assert_almost_equal(attention_outputs, pred[1])
finally:
shutil.rmtree(tmp_path)
@@ -716,3 +719,99 @@ def test_distilbert_inference_onnxruntime(tmp_path, model_name):
finally:
shutil.rmtree(tmp_path)
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['mobilenet1.0', 'inceptionv3', 'darknet53', 'resnest14'])
+def test_dynamic_shape_cv_inference_onnxruntime(tmp_path, model_name):
+ tmp_path = str(tmp_path)
+ try:
+ M = GluonModel(model_name, (1, 3, 512, 512), 'float32', tmp_path)
+ dynamic_input_shapes = [(None, 3, 512, 512)]
+ onnx_file = M.export_onnx_dynamic(dynamic_input_shapes)
+
+ # create onnxruntime session using the generated onnx file
+ ses_opt = onnxruntime.SessionOptions()
+ ses_opt.log_severity_level = 3
+ sess = onnxruntime.InferenceSession(onnx_file, ses_opt)
+
+ # test on a different batch size
+ x = mx.random.uniform(0, 10, (5, 3, 512, 512))
+ in_tensors = [x]
+ input_dict = dict((sess.get_inputs()[i].name, in_tensors[i].asnumpy()) for i in range(len(in_tensors)))
+ pred_on = sess.run(None, input_dict)
+
+ pred_mx = M.predict(x)
+
+ assert_almost_equal(pred_mx, pred_on[0])
+
+ finally:
+ shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model', ['bert_12_768_12'])
+def test_dynamic_shape_bert_inference_onnxruntime(tmp_path, model):
+ tmp_path = str(tmp_path)
+ try:
+ import gluonnlp as nlp
+ dataset = 'book_corpus_wiki_en_uncased'
+ ctx = mx.cpu(0)
+ model, vocab = nlp.model.get_model(
+ name=model,
+ ctx=ctx,
+ dataset_name=dataset,
+ pretrained=True,
+ use_pooler=True,
+ use_decoder=False,
+ num_layers = 3,
+ hparam_allow_override = True,
+ use_classifier=False)
+
+ model.hybridize(static_alloc=True)
+
+ batch = 5
+ seq_length = 16
+ # create synthetic test data
+ inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
+ token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
+ valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+ seq_encoding, cls_encoding = model(inputs, token_types, valid_length)
+
+ prefix = "%s/bert" % tmp_path
+ model.export(prefix)
+ sym_file = "%s-symbol.json" % prefix
+ params_file = "%s-0000.params" % prefix
+ onnx_file = "%s.onnx" % prefix
+
+ dynamic_input_shapes = [(None, seq_length), (None, seq_length), (None,)]
+ input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
+ input_types = [np.float32, np.float32, np.float32]
+ converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes,
+ input_types, onnx_file,
+ dynamic=True,
+ dynamic_input_shapes=dynamic_input_shapes)
+
+ # create onnxruntime session using the generated onnx file
+ ses_opt = onnxruntime.SessionOptions()
+ ses_opt.log_severity_level = 3
+ session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+
+ # test on a different batch size
+ batch = 7
+ seq_length = 16
+ inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
+ token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
+ valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+ seq_encoding, cls_encoding = model(inputs, token_types, valid_length)
+
+ onnx_inputs = [inputs, token_types, valid_length]
+ input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
+ pred_onx, cls_onx = session.run(None, input_dict)
+
+ assert_almost_equal(seq_encoding, pred_onx)
+ assert_almost_equal(cls_encoding, cls_onx)
+
+ finally:
+ shutil.rmtree(tmp_path)
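
A quick way to confirm a dynamic export worked is to load the ONNX file and check that the marked input dimensions carry no fixed dim_value; a hedged sketch, assuming a file exported as in the earlier mobilenet example:

import onnx

model = onnx.load('mobilenet.onnx')  # file name from the earlier export sketch
for graph_input in model.graph.input:
    dims = graph_input.type.tensor_type.shape.dim
    # A dynamic dimension has no fixed dim_value set on its Dimension proto
    rendered = [d.dim_value if d.HasField('dim_value') else 'dynamic' for d in dims]
    print(graph_input.name, rendered)
# Expected output resembles: data ['dynamic', 3, 224, 224]
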