Posted to commits@mxnet.apache.org by sa...@apache.org on 2020/09/11 17:19:05 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] Update onnx support to work with onnx 1.7.0 with most CV models (#19017)

This is an automated email from the ASF dual-hosted git repository.

samskalicky pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new b888d3c  [v1.x] Update onnx support to work with onnx 1.7.0 with most CV models (#19017)
b888d3c is described below

commit b888d3c71aa7d3150bfefae8beee67a234586434
Author: Joe Evans <gi...@250hacks.net>
AuthorDate: Fri Sep 11 10:17:23 2020 -0700

    [v1.x] Update onnx support to work with onnx 1.7.0 with most CV models (#19017)
    
    * fix pooling_convention warning when convert model to onnx (#18529)
    
    * fix pooling_convention warning
    
    * fix pooling_convention warning
    
    * fix lint
    
    Co-authored-by: JackieWu <wk...@live.cn>
    
    * Prevent uninitialized variable error.
    
    * Initial work to get Dropout to work with onnx 1.7
    
    * Remove trailing whitespace for pylint.
    
    * Fix tensor initialization for Dropout operator input.
    
    * Update Clip operator to support latest ONNX opset versions by moving min/max attributes to inputs.
    
    * Fix whitespace.
    
    * Add support for importing Dropout operator in ONNX opset version >= 12.
    
    * Add support for importing ONNX opsets >= 11 to the Clip operator.
    
    * Add optional opset_version parameter that defaults to latest opset version supported by onnx. Pass this parameter to each graph layer when exporting.
    
    * Add optional parameter to create_model() that allows user to specify which onnx opset version they want to use when exporting, defaults to latest version supported by onnx.
    
    * Use opset_version argument to determine operator format.
    
    * Add an opset_version parameter to from_onnx() so at operator conversion time, we know what opset version to use.
    
    * For Clip and Dropout operators, use opset version from passed proto_obj, which reflects what opset version the onnx model uses.
    
    * Use same tolerances that are in master.
    
    * Change Pad operator to use inputs instead of attributes for newer opset versions. Check opset version instead of ONNX version for Pooling operator.
    
    * Add documentation for the opset_version parameter.
    
    * Add opset_version parameters to unit tests.
    
    * Add test script for testing inference with onnxruntime on CV models from gluon model zoo.
    
    * Add license and clean up imports.
    
    * Install onnxruntime in docker container for unit tests.
    
    * Add onnxruntime to test dependencies.
    
    * Install onnxruntime into CentOS docker image.
    
    * Disable testing squeezenet models for now.
    
    * Update onnx version.
    
    * Fix typo.
    
    * Use mx.image.imread instead of PIL module.
    
    * ONNX import: use Conv pad attribute for symmetrical padding (#18675)
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Install onnx in CentOS containers when installing python.
    
    * Update import and export of some ONNX ops to support newer opset versions - this gets all ONNX unit tests to pass with onnx 1.7.
    
    * Re-enable squeezenet model testing in onnxruntime.
    
    * Run the onnxruntime inference tests in the ONNX pipeline instead of normal unittests pipelines.
    
    * Add missed return value.
    
    * Refactor code based on review comment.
    
    * Since the onnx tests are only run on ubuntu_cpu images, we don't need to install onnx and onnxruntime in the CentOS containers.
    
    Co-authored-by: Liu, Hao <ha...@hotmail.com>
    Co-authored-by: JackieWu <wk...@live.cn>
    Co-authored-by: Joe Evans <jo...@amazon.com>
    Co-authored-by: Serge Panev <sp...@nvidia.com>
---
 ci/docker/install/ubuntu_onnx.sh                   |   4 +-
 ci/docker/runtime_functions.sh                     |   4 +-
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 246 +++++++++++++++------
 python/mxnet/contrib/onnx/mx2onnx/export_model.py  |  10 +-
 python/mxnet/contrib/onnx/mx2onnx/export_onnx.py   |  11 +-
 .../mxnet/contrib/onnx/onnx2mx/_op_translations.py | 136 +++++++++---
 python/mxnet/contrib/onnx/onnx2mx/import_model.py  |   3 +-
 python/mxnet/contrib/onnx/onnx2mx/import_onnx.py   |   8 +-
 .../mxnet/contrib/onnx/onnx2mx/import_to_gluon.py  |   3 +-
 tests/python-pytest/onnx/backend.py                |  15 +-
 tests/python-pytest/onnx/mxnet_export_test.py      |   2 +-
 tests/python-pytest/onnx/test_onnxruntime.py       | 131 +++++++++++
 tests/requirements.txt                             |   1 +
 13 files changed, 458 insertions(+), 116 deletions(-)

diff --git a/ci/docker/install/ubuntu_onnx.sh b/ci/docker/install/ubuntu_onnx.sh
index 44d6b9e..31eb5e8 100755
--- a/ci/docker/install/ubuntu_onnx.sh
+++ b/ci/docker/install/ubuntu_onnx.sh
@@ -30,5 +30,5 @@ echo "Installing libprotobuf-dev and protobuf-compiler ..."
 apt-get update || true
 apt-get install -y libprotobuf-dev protobuf-compiler
 
-echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX and tabulate ..."
-pip3 install pytest==3.6.3 pytest-cov==2.5.1 protobuf==3.5.2 onnx==1.3.0 Pillow==5.0.0 tabulate==0.7.5
+echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX, tabulate and onnxruntime..."
+pip3 install pytest==3.6.3 pytest-cov==2.5.1 protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.4.0
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index b5cbb9a..f8f2b57 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -1228,11 +1228,13 @@ unittest_centos7_gpu() {
 integrationtest_ubuntu_cpu_onnx() {
 	set -ex
 	export PYTHONPATH=./python/
-    export DMLC_LOG_STACK_TRACE_DEPTH=10
+	export MXNET_SUBGRAPH_VERBOSE=0
+	export DMLC_LOG_STACK_TRACE_DEPTH=10
 	tests/python-pytest/onnx/backend_test.py
 	pytest tests/python-pytest/onnx/mxnet_export_test.py
 	pytest tests/python-pytest/onnx/test_models.py
 	pytest tests/python-pytest/onnx/test_node.py
+	pytest tests/python-pytest/onnx/test_onnxruntime.py
 }
 
 integrationtest_ubuntu_gpu_python() {
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index b1ab40e..f03fabb 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -175,7 +175,7 @@ def convert_weights_and_inputs(node, **kwargs):
                 data_type=data_type,
                 dims=dims,
                 vals=np_arr.flatten().tolist(),
-                raw=False,
+                raw=False
             )
         )
 
@@ -462,36 +462,73 @@ def convert_pad(node, **kwargs):
     """Map MXNet's pad operator attributes to onnx's Pad operator
     and return the created node.
     """
+    opset_version = kwargs["opset_version"]
     name, input_nodes, attrs = get_inputs(node, kwargs)
 
     mxnet_pad_width = convert_string_to_list(attrs.get("pad_width"))
     onnx_pad_width = transform_padding(mxnet_pad_width)
 
     pad_mode = attrs.get("mode")
+    pad_value = np.float32(attrs.get("constant_value", 0.0))
 
-    if pad_mode == "constant":
-        pad_value = float(attrs.get("constant_value")) \
-            if "constant_value" in attrs else 0.0
-        node = onnx.helper.make_node(
-            'Pad',
-            inputs=input_nodes,
-            outputs=[name],
-            mode='constant',
-            value=pad_value,
-            pads=onnx_pad_width,
-            name=name
-        )
+    if opset_version >= 11:
+        # starting with opset 11, pads and constant_value are inputs instead of attributes
+        from onnx.helper import make_tensor, make_tensor_value_info
+        initializer = kwargs["initializer"]
+        pads_input_name = name + "_pads"
+        pads_input_type = onnx.TensorProto.INT64
+        pads_input_shape = np.shape(np.array(onnx_pad_width))
+        pads_value_node = make_tensor_value_info(pads_input_name, pads_input_type, pads_input_shape)
+        pads_tensor_node = make_tensor(pads_input_name, pads_input_type, pads_input_shape, onnx_pad_width)
+        initializer.append(pads_tensor_node)
+        input_nodes.append(pads_input_name)
+
+        if pad_mode == "constant":
+            const_input_name = name + "_constant"
+            const_input_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[pad_value.dtype]
+            const_value_node = make_tensor_value_info(const_input_name, const_input_type, ())
+            const_tensor_node = make_tensor(const_input_name, const_input_type, (), [pad_value])
+            initializer.append(const_tensor_node)
+            input_nodes.append(const_input_name)
+            pad_node = onnx.helper.make_node(
+                "Pad",
+                input_nodes,
+                [name],
+                mode=pad_mode,
+                name=name
+            )
+            return [pads_value_node, const_value_node, pad_node]
+        else:
+            pad_node = onnx.helper.make_node(
+                "Pad",
+                input_nodes,
+                [name],
+                mode=pad_mode,
+                name=name
+            )
+            return [pads_value_node, pad_node]
     else:
-        node = onnx.helper.make_node(
-            'Pad',
-            inputs=input_nodes,
-            outputs=[name],
-            mode=pad_mode,
-            pads=onnx_pad_width,
-            name=name
-        )
-
-    return [node]
+        if pad_mode == "constant":
+            node = onnx.helper.make_node(
+                'Pad',
+                inputs=input_nodes,
+                outputs=[name],
+                mode='constant',
+                value=pad_value,
+                pads=onnx_pad_width,
+                name=name
+            )
+            return [node]
+        else:
+            node = onnx.helper.make_node(
+                'Pad',
+                inputs=input_nodes,
+                outputs=[name],
+                mode=pad_mode,
+                pads=onnx_pad_width,
+                name=name
+            )
+            return [node]
 
 
 def create_helper_trans_node(op_name, input_node, node_name):
@@ -639,6 +676,7 @@ def convert_pooling(node, **kwargs):
     MaxPool/AveragePool/GlobalMaxPool/GlobalAveragePool operators
     based on the input node's attributes and return the created node.
     """
+    opset_version = kwargs["opset_version"]
     name, input_nodes, attrs = get_inputs(node, kwargs)
 
     kernel = eval(attrs["kernel"])
@@ -648,13 +686,14 @@ def convert_pooling(node, **kwargs):
     p_value = attrs.get('p_value', 'None')
 
     pooling_convention = attrs.get('pooling_convention', 'valid')
-
+    ceil_mode = False
     if pooling_convention == 'full':
-        pooling_warning = "Pooling: ONNX currently doesn't support pooling_convention. " \
-                          "This might lead to shape or accuracy issues. " \
-                          "https://github.com/onnx/onnx/issues/549"
-
-        logging.warning(pooling_warning)
+        if opset_version < 10:
+            pooling_warning = "Pooling: ONNX lower than 1.5.0 doesn't support pooling_convention. " \
+                              "This might lead to shape or accuracy issues. " \
+                              "https://github.com/onnx/onnx/issues/549"
+            logging.warning(pooling_warning)
+        ceil_mode = True
 
     pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
     pad_dims = pad_dims + pad_dims
@@ -694,15 +733,27 @@ def convert_pooling(node, **kwargs):
                 name=name
             )
         else:
-            node = onnx.helper.make_node(
-                pool_types[pool_type],
-                input_nodes,  # input
-                [name],
-                kernel_shape=kernel,
-                pads=pad_dims,
-                strides=stride,
-                name=name
-            )
+            if opset_version >= 10:
+                node = onnx.helper.make_node(
+                    pool_types[pool_type],
+                    input_nodes,  # input
+                    [name],
+                    kernel_shape=kernel,
+                    pads=pad_dims,
+                    strides=stride,
+                    name=name,
+                    ceil_mode=ceil_mode
+                )
+            else:
+                node = onnx.helper.make_node(
+                    pool_types[pool_type],
+                    input_nodes,  # input
+                    [name],
+                    kernel_shape=kernel,
+                    pads=pad_dims,
+                    strides=stride,
+                    name=name
+                )
 
     return [node]
 
@@ -945,17 +996,35 @@ def convert_dropout(node, **kwargs):
     and return the created node.
     """
     name, input_nodes, attrs = get_inputs(node, kwargs)
+    opset_version = kwargs["opset_version"]
 
     probability = float(attrs.get("p", 0.5))
 
-    dropout_node = onnx.helper.make_node(
-        "Dropout",
-        input_nodes,
-        [name],
-        ratio=probability,
-        name=name
-    )
-    return [dropout_node]
+    if opset_version >= 12:
+        # opset >= 12 requires the ratio to be an input
+        initializer = kwargs["initializer"]
+        ratio_input_name = name + "_ratio"
+        value_node = onnx.helper.make_tensor_value_info(ratio_input_name,
+                                                        onnx.TensorProto.FLOAT, ())
+        tensor_node = onnx.helper.make_tensor(ratio_input_name, onnx.TensorProto.FLOAT,
+                                              (), [probability])
+        initializer.append(tensor_node)
+        dropout_node = onnx.helper.make_node(
+            "Dropout",
+            [input_nodes[0], ratio_input_name],
+            [name],
+            name=name
+        )
+        return [value_node, dropout_node]
+    else:
+        dropout_node = onnx.helper.make_node(
+            "Dropout",
+            input_nodes,
+            [name],
+            ratio=probability,
+            name=name
+        )
+        return [dropout_node]
 
 
 @mx_op.register("Flatten")
@@ -971,19 +1040,46 @@ def convert_clip(node, **kwargs):
     and return the created node.
     """
     name, input_nodes, attrs = get_inputs(node, kwargs)
+    opset_version = kwargs["opset_version"]
 
-    a_min = np.float(attrs.get('a_min', -np.inf))
-    a_max = np.float(attrs.get('a_max', np.inf))
+    a_min = float(attrs.get('a_min', -np.inf))
+    a_max = float(attrs.get('a_max', np.inf))
 
-    clip_node = onnx.helper.make_node(
-        "Clip",
-        input_nodes,
-        [name],
-        name=name,
-        min=a_min,
-        max=a_max
-    )
-    return [clip_node]
+    if opset_version >= 11:
+        # opset >= 11 requires min/max to be inputs
+        initializer = kwargs["initializer"]
+        min_input_name = name + "_min"
+        max_input_name = name + "_max"
+        min_value_node = onnx.helper.make_tensor_value_info(min_input_name,
+                                                            onnx.TensorProto.FLOAT, ())
+        max_value_node = onnx.helper.make_tensor_value_info(max_input_name,
+                                                            onnx.TensorProto.FLOAT, ())
+        min_tensor_node = onnx.helper.make_tensor(min_input_name, onnx.TensorProto.FLOAT,
+                                                  (), [a_min])
+        max_tensor_node = onnx.helper.make_tensor(max_input_name, onnx.TensorProto.FLOAT,
+                                                  (), [a_max])
+        initializer.append(min_tensor_node)
+        initializer.append(max_tensor_node)
+        input_nodes.append(min_input_name)
+        input_nodes.append(max_input_name)
+        clip_node = onnx.helper.make_node(
+            "Clip",
+            input_nodes,
+            [name],
+            name=name
+        )
+        return [min_value_node, max_value_node, clip_node]
+
+    else:
+        clip_node = onnx.helper.make_node(
+            "Clip",
+            input_nodes,
+            [name],
+            name=name,
+            min=a_min,
+            max=a_max
+        )
+        return [clip_node]
 
 
 def scalar_op_helper(node, op_name, **kwargs):
@@ -2070,14 +2166,34 @@ def convert_topk(node, **kwargs):
     else:
         raise NotImplementedError("ONNX expects both value and indices as output")
 
-    topk_node = onnx.helper.make_node(
-        "TopK",
-        input_nodes,
-        outputs,
-        axis=axis,
-        k=k,
-        name=name
-    )
+    opset_version = kwargs['opset_version']
+    if opset_version >= 10:
+        from onnx.helper import make_tensor, make_tensor_value_info
+        initializer = kwargs["initializer"]
+        k_input_name = name + "_k"
+        k_input_type = onnx.TensorProto.INT64
+        k_value_node = make_tensor_value_info(k_input_name, k_input_type, ())
+        k_tensor_node = make_tensor(k_input_name, k_input_type, (), k)
+        initializer.append(k_tensor_node)
+        input_nodes.append(k_input_name)
+
+        topk_node = onnx.helper.make_node(
+            "TopK",
+            input_nodes,
+            outputs,
+            axis=axis,
+            name=name
+        )
+        return [k_value_node, topk_node]
+    else:
+        topk_node = onnx.helper.make_node(
+            "TopK",
+            input_nodes,
+            outputs,
+            axis=axis,
+            k=k,
+            name=name
+        )
 
     return [topk_node]
 
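The Pad and Pooling hunks above rest on two pieces of bookkeeping. MXNet's pad_width interleaves (begin, end) per axis while ONNX pads lists all begins and then all ends, and from opset 11 the pads and constant value move from attributes to initializer-backed inputs. Separately, MXNet's pooling_convention='full' is mapped to ceil_mode=True on opset >= 10 pooling nodes. The pure-Python sketch below restates that logic without needing onnx installed; mxnet_to_onnx_pads, pad_node_spec and pool_out_len are hypothetical helpers for illustration (not part of mxnet.contrib.onnx), and the returned dict only approximates a real NodeProto.

```python
from typing import Dict, List, Sequence


def mxnet_to_onnx_pads(pad_width: Sequence[int]) -> List[int]:
    """Reorder MXNet pad_width into ONNX pads (same ordering as transform_padding).

    MXNet: (x1_begin, x1_end, x2_begin, x2_end, ...)
    ONNX:  [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
    """
    if len(pad_width) % 2:
        raise ValueError("pad_width must contain (begin, end) pairs")
    return list(pad_width[0::2]) + list(pad_width[1::2])


def pad_node_spec(name: str, data: str, pad_width: Sequence[int], mode: str = "constant",
                  value: float = 0.0, opset_version: int = 11) -> Dict:
    """Hypothetical description of how a Pad node is laid out for a given opset."""
    pads = mxnet_to_onnx_pads(pad_width)
    if opset_version >= 11:
        # opset >= 11: pads (and the constant value) are initializer-backed inputs
        initializers = {name + "_pads": pads}
        inputs = [data, name + "_pads"]
        if mode == "constant":
            initializers[name + "_constant"] = [float(value)]
            inputs.append(name + "_constant")
        return {"inputs": inputs, "attrs": {"mode": mode}, "initializers": initializers}
    # opset < 11: pads and value are plain node attributes
    attrs = {"mode": mode, "pads": pads}
    if mode == "constant":
        attrs["value"] = float(value)
    return {"inputs": [data], "attrs": attrs, "initializers": {}}


def pool_out_len(size: int, kernel: int, stride: int, pad_begin: int, pad_end: int,
                 ceil_mode: bool) -> int:
    """Pooling output length along one axis; ceil_mode=True matches pooling_convention='full'."""
    span = size + pad_begin + pad_end - kernel
    steps = -(-span // stride) if ceil_mode else span // stride
    return steps + 1
```

For a length-7 axis with kernel 2 and stride 2, the 'valid' convention yields 3 outputs and 'full' yields 4, which is the kind of shape mismatch the old pooling warning referred to when ceil_mode could not be expressed.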
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
index 51a62ed..2fc7760 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
@@ -29,7 +29,7 @@ from ._export_helper import load_module
 
 
 def export_model(sym, params, input_shape, input_type=np.float32,
-                 onnx_file_path='model.onnx', verbose=False):
+                 onnx_file_path='model.onnx', verbose=False, opset_version=None):
     """Exports the MXNet model file, passed as a parameter, into ONNX model.
     Accepts both symbol,parameter objects as well as json and params filepaths as input.
     Operator support and coverage -
@@ -63,11 +63,15 @@ def export_model(sym, params, input_shape, input_type=np.float32,
 
     try:
         from onnx import helper, mapping
+        from onnx.defs import onnx_opset_version
     except ImportError:
         raise ImportError("Onnx and protobuf need to be installed. "
                           + "Instructions to install - https://github.com/onnx/onnx")
 
     converter = MXNetGraph()
+    if opset_version is None:
+        # default is to use latest opset version the onnx package supports
+        opset_version = onnx_opset_version()
 
     data_format = np.dtype(input_type)
     # if input parameters are strings(file paths), load files and create symbol parameter objects
@@ -76,11 +80,11 @@ def export_model(sym, params, input_shape, input_type=np.float32,
         sym_obj, params_obj = load_module(sym, params)
         onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape,
                                                        mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
-                                                       verbose=verbose)
+                                                       verbose=verbose, opset_version=opset_version)
     elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
         onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape,
                                                        mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
-                                                       verbose=verbose)
+                                                       verbose=verbose, opset_version=opset_version)
     else:
         raise ValueError("Input sym and params should either be files or objects")
 
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index 14aa52b..bd25336 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -157,7 +157,7 @@ class MXNetGraph(object):
         return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
                      for k, v in weights_dict.items()])
 
-    def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False):
+    def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False, opset_version=None):
         """Convert MXNet graph to ONNX graph
 
         Parameters
@@ -172,6 +172,8 @@ class MXNetGraph(object):
             Input data type e.g. np.float32
         verbose : Boolean
             If true will print logs of the model conversion
+        opset_version : Int
+            ONNX opset version to use for export, defaults to latest supported by onnx package
 
         Returns
         -------
@@ -181,10 +183,14 @@ class MXNetGraph(object):
         try:
             from onnx import (checker, helper, NodeProto, ValueInfoProto, TensorProto)
             from onnx.helper import make_tensor_value_info
+            from onnx.defs import onnx_opset_version
         except ImportError:
             raise ImportError("Onnx and protobuf need to be installed. "
                               + "Instructions to install - https://github.com/onnx/onnx")
 
+        if opset_version is None:
+            opset_version = onnx_opset_version()
+
         # When MXNet model is saved to json file , MXNet adds a node for label.
         # The name of this node is, name of the last node + "_label" ( i.e if last node
         # name is "Softmax", this node will have a name "Softmax_label". Also, the new node
@@ -246,7 +252,8 @@ class MXNetGraph(object):
                     proc_nodes=all_processed_nodes,
                     initializer=initializer,
                     index_lookup=index_lookup,
-                    idx=idx
+                    idx=idx,
+                    opset_version=opset_version
                 )
 
             if isinstance(converted, list):
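With this change export_onnx.py passes opset_version into every converter call next to initializer, and each converter decides whether a value becomes a node attribute or an initializer-backed input. A minimal sketch of that dispatch pattern follows, assuming a hypothetical dict-based registry standing in for mx_op.register and a constant LATEST_OPSET standing in for onnx.defs.onnx_opset_version() (assumed here to be 12, the opset shipped with onnx 1.7.0). Node and initializer shapes are simplified dicts, not ONNX protobufs.

```python
from typing import Callable, Dict, List, Optional

# Hypothetical stand-in for the converter registry (mx_op.register in MXNet).
_CONVERTERS: Dict[str, Callable] = {}

# Assumption: stands in for onnx.defs.onnx_opset_version().
LATEST_OPSET = 12


def register(op_name: str):
    """Decorator that records a converter under an MXNet op name."""
    def wrap(fn):
        _CONVERTERS[op_name] = fn
        return fn
    return wrap


@register("clip")
def convert_clip(name: str, inputs: List[str], attrs: Dict, **kwargs) -> Dict:
    opset_version = kwargs["opset_version"]
    a_min = float(attrs.get("a_min", float("-inf")))
    a_max = float(attrs.get("a_max", float("inf")))
    if opset_version >= 11:
        # opset >= 11: min/max become scalar initializers wired in as extra inputs
        initializer = kwargs["initializer"]
        initializer[name + "_min"] = a_min
        initializer[name + "_max"] = a_max
        return {"op": "Clip", "inputs": inputs + [name + "_min", name + "_max"], "attrs": {}}
    # older opsets: min/max stay as attributes
    return {"op": "Clip", "inputs": list(inputs), "attrs": {"min": a_min, "max": a_max}}


def convert_graph(nodes: List[Dict], opset_version: Optional[int] = None):
    """Convert every node, sharing one initializer container and one opset version."""
    if opset_version is None:
        # mirror the commit's default: the latest opset the tooling supports
        opset_version = LATEST_OPSET
    initializer: Dict[str, float] = {}
    converted = [
        _CONVERTERS[n["op"]](n["name"], n["inputs"], n["attrs"],
                             initializer=initializer, opset_version=opset_version)
        for n in nodes
    ]
    return converted, initializer
```

Because every converter receives the same initializer container, values appended there end up as graph initializers; that is why the new Clip, Dropout, Pad and TopK translations return extra tensor value-info entries alongside the operator node.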
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index 311fd86..51fe418 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -240,11 +240,24 @@ def relu(attrs, inputs, proto_obj):
 
 def pad(attrs, inputs, proto_obj):
     """ Add padding to input tensor"""
-    new_attrs = translation_utils._fix_attribute_names(attrs, {'pads'  : 'pad_width',
-                                                               'value' : 'constant_value'
-                                                              })
-    new_attrs['pad_width'] = translation_utils._pad_sequence_fix(new_attrs.get('pad_width'))
-    return 'pad', new_attrs, inputs
+    opset_version = proto_obj.opset_version
+    if 'mode' not in attrs.keys():
+        attrs['mode'] = 'constant'
+    if opset_version >= 11:
+        pads = list(proto_obj._params[inputs[1].name].asnumpy())
+        pads = tuple([int(i) for i in pads])
+        new_attrs = translation_utils._add_extra_attributes(attrs, {'pad_width': pads})
+        if len(inputs) == 3:
+            const = proto_obj._params[inputs[2].name].asnumpy()[0]
+            new_attrs = translation_utils._add_extra_attributes(new_attrs, {'constant_value': const})
+        new_attrs['pad_width'] = translation_utils._pad_sequence_fix(new_attrs.get('pad_width'))
+        return 'pad', new_attrs, inputs[0]
+    else:
+        new_attrs = translation_utils._fix_attribute_names(attrs, {'pads'  : 'pad_width',
+                                                                   'value' : 'constant_value'
+                                                                  })
+        new_attrs['pad_width'] = translation_utils._pad_sequence_fix(new_attrs.get('pad_width'))
+        return 'pad', new_attrs, inputs
 
 def matrix_multiplication(attrs, inputs, proto_obj):
     """Performs general matrix multiplication"""
@@ -334,14 +347,25 @@ def conv(attrs, inputs, proto_obj):
     no_bias = new_attrs['no_bias'] if 'no_bias' in new_attrs else 0
     bias = None if no_bias is True else inputs[2]
 
-    # Unlike ONNX, MXNet's convolution operator does not support asymmetric padding, so we first
-    # use 'Pad' operator, which supports asymmetric padding. Then use the convolution operator.
-    pad_width = (0, 0, 0, 0) + translation_utils._pad_sequence_fix(padding, kernel_dim=len(kernel))
-    pad_op = symbol.pad(inputs[0], mode='constant', pad_width=pad_width)
+    mxnet_pad = translation_utils._pad_sequence_fix(padding, kernel_dim=len(kernel))
 
-    conv_op = symbol.Convolution(pad_op, inputs[1], bias,
-                                 kernel=kernel, stride=stride, dilate=dilations,
-                                 num_filter=num_filter, num_group=num_group, no_bias=no_bias)
+    left_pads = mxnet_pad[0::2]
+    right_pads = mxnet_pad[1::2]
+    is_pad_sym = left_pads == right_pads
+
+    if not is_pad_sym:
+        # Unlike ONNX, MXNet's convolution operator does not support asymmetric padding, so we first
+        # use 'Pad' operator, which supports asymmetric padding. Then use the convolution operator.
+        pad_width = (0, 0, 0, 0) + mxnet_pad
+        pad_op = symbol.pad(inputs[0], mode='constant', pad_width=pad_width)
+        conv_op = symbol.Convolution(pad_op, inputs[1], bias,
+                                     kernel=kernel, stride=stride, dilate=dilations,
+                                     num_filter=num_filter, num_group=num_group, no_bias=no_bias)
+    else:
+        pad_width = left_pads
+        conv_op = symbol.Convolution(inputs[0], inputs[1], bias,
+                                     kernel=kernel, stride=stride, dilate=dilations, pad=pad_width,
+                                     num_filter=num_filter, num_group=num_group, no_bias=no_bias)
 
     return conv_op, new_attrs, inputs
 
@@ -356,7 +380,7 @@ def deconv(attrs, inputs, proto_obj):
     new_attrs = translation_utils._fix_bias('Deconvolution', new_attrs, len(inputs))
 
     new_attrs = translation_utils._fix_channels('Deconvolution', new_attrs, inputs, proto_obj)
-    kernel = new_attrs['kernel']
+    kernel = new_attrs['kernel'] if 'kernel' in new_attrs else []
     stride = new_attrs['stride'] if 'stride' in new_attrs else []
     padding = new_attrs['pad'] if 'pad' in new_attrs else []
     dilations = new_attrs['dilate'] if 'dilate' in new_attrs else []
@@ -446,12 +470,22 @@ def local_response_norm(attrs, inputs, proto_obj):
 def dropout(attrs, inputs, proto_obj):
     """Dropout Regularization."""
     mode = 'training'
+    opset_version = proto_obj.opset_version
     if 'is_test' in attrs and attrs['is_test'] == 0:
         mode = 'always'
-    new_attrs = translation_utils._fix_attribute_names(attrs,
-                                                       {'ratio': 'p'})
-    new_attrs = translation_utils._remove_attributes(new_attrs, ['is_test'])
+    new_attrs = translation_utils._remove_attributes(attrs, ['is_test'])
     new_attrs = translation_utils._add_extra_attributes(new_attrs, {'mode': mode})
+    if opset_version >= 12:
+        new_attrs = translation_utils._remove_attributes(new_attrs, ['seed'])
+        if len(inputs) == 2:
+            ratio_float = proto_obj._params[inputs[1].name].asnumpy()[0]
+            new_attrs = translation_utils._remove_attributes(new_attrs, ['p'])
+            new_attrs = translation_utils._add_extra_attributes(new_attrs, {'p': ratio_float})
+        elif len(inputs) == 1:
+            new_attrs = translation_utils._fix_attribute_names(new_attrs, {'ratio': 'p'})
+        return 'Dropout', new_attrs, inputs[0]
+    else:
+        new_attrs = translation_utils._fix_attribute_names(new_attrs, {'ratio': 'p'})
     return 'Dropout', new_attrs, inputs
 
 # Changing shape and type.
@@ -501,15 +535,30 @@ def _slice(attrs, inputs, proto_obj):
     """Returns a slice of the input tensor along multiple axes."""
     input_tensor_data = proto_obj.model_metadata.get('input_tensor_data')[0]
     input_shape = input_tensor_data[1]
-    new_attrs = translation_utils._fix_attribute_names(attrs,
-                                                       {'axes' : 'axis',
-                                                        'ends' : 'end',
-                                                        'starts' : 'begin'})
-    # onnx slice provides slicing on multiple axis. Adding multiple slice_axis operator
-    # for multiple axes from mxnet
-    begin = new_attrs.get('begin')
-    end = list(new_attrs.get('end'))
-    axes = new_attrs.get('axis', tuple(range(len(begin))))
+
+    if proto_obj.opset_version >= 10:
+        begin = proto_obj._params[inputs[1].name].asnumpy()
+        end = proto_obj._params[inputs[2].name].asnumpy()
+        if len(inputs) >= 4:
+            axes = list(proto_obj._params[inputs[3].name].asnumpy())
+            axes = tuple([int(i) for i in axes])
+        else:
+            axes = tuple(range(len(begin)))
+        new_attrs = translation_utils._add_extra_attributes(attrs, {'axes' : axes,
+                                                                    'begin' : begin,
+                                                                    'end' : end
+                                                                   })
+    else:
+        new_attrs = translation_utils._fix_attribute_names(attrs,
+                                                           {'axes' : 'axis',
+                                                            'ends' : 'end',
+                                                            'starts' : 'begin'})
+        # onnx slice provides slicing on multiple axis. Adding multiple slice_axis operator
+        # for multiple axes from mxnet
+        begin = new_attrs.get('begin')
+        end = list(new_attrs.get('end'))
+        axes = new_attrs.get('axis', tuple(range(len(begin))))
+
     for i, axis in enumerate(axes):
         end[i] = None if end[i] >= input_shape[axis] else end[i]
     slice_op = symbol.slice_axis(inputs[0], axis=axes[0], begin=begin[0], end=end[0])
@@ -549,13 +598,28 @@ def flatten(attrs, inputs, proto_obj):
 
 def clip(attrs, inputs, proto_obj):
     """Clips (limits) the values in an array."""
-    new_attrs = translation_utils._fix_attribute_names(attrs, {'min' : 'a_min',
-                                                               'max' : 'a_max'})
-    if 'a_max' not in new_attrs:
-        new_attrs = translation_utils._add_extra_attributes(new_attrs, {'a_max' : np.inf})
-    if 'a_min' not in new_attrs:
-        new_attrs = translation_utils._add_extra_attributes(new_attrs, {'a_min' : -np.inf})
-    return 'clip', new_attrs, inputs
+    opset_version = proto_obj.opset_version
+    if opset_version >= 11:
+        if len(inputs) == 1:
+            new_attrs = translation_utils._add_extra_attributes(new_attrs, {'a_max' : np.inf,
+                                                                            'a_min' : -np.inf})
+        elif len(inputs) == 2:
+            min_float = proto_obj._params[inputs[1].name].asnumpy()
+            new_attrs = translation_utils._add_extra_attributes(attrs, {'a_min': min_float[0],
+                                                                        'a_max': np.inf})
+        elif len(inputs) == 3:
+            min_float = proto_obj._params[inputs[1].name].asnumpy()
+            max_float = proto_obj._params[inputs[2].name].asnumpy()
+            new_attrs = translation_utils._add_extra_attributes(attrs, {'a_min': min_float[0],
+                                                                        'a_max': max_float[0]})
+    else:
+        new_attrs = translation_utils._fix_attribute_names(attrs, {'min' : 'a_min',
+                                                                   'max' : 'a_max'})
+        if 'a_max' not in new_attrs:
+            new_attrs = translation_utils._add_extra_attributes(new_attrs, {'a_max' : np.inf})
+        if 'a_min' not in new_attrs:
+            new_attrs = translation_utils._add_extra_attributes(new_attrs, {'a_min' : -np.inf})
+    return 'clip', new_attrs, inputs[0]
 
 def gather(attrs, inputs, proto_obj):
     """Gather elements from an input array along the given axis."""
@@ -790,4 +854,10 @@ def topk(attrs, inputs, proto_obj):
     new_attrs = translation_utils._add_extra_attributes(attrs,
                                                         {'ret_typ': 'both',
                                                          'dtype': 'int64'})
-    return 'topk', new_attrs, inputs
+    opset_version = proto_obj.opset_version
+    if opset_version >= 10:
+        k_vals = proto_obj._params[inputs[1].name].asnumpy()
+        new_attrs = translation_utils._add_extra_attributes(new_attrs, {'k': k_vals})
+        return 'topk', new_attrs, inputs[0]
+    else:
+        return 'topk', new_attrs, inputs
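From opset 11, ONNX `Clip` receives `min` and `max` as optional inputs instead of attributes. From opset 10, `TopK` receives `K` as a one-element int64 input. The translations above fold those constant inputs back into MXNet attributes. The sketch below shows the same idea in plain NumPy under a few assumptions:

- The constant inputs have already been fetched as arrays. The diff reads them from `proto_obj._params`.
- An omitted optional input is passed as `None`.
- `K` is converted to a plain `int`, on the assumption that this is what the MXNet `k` attribute expects.

`clip_attrs`, `topk_attrs` and `_scalar` are hypothetical helpers, not MXNet functions.

```python
import numpy as np


def _scalar(arr, default):
    # An omitted optional input is represented as None (assumption of this sketch).
    if arr is None:
        return default
    return float(np.asarray(arr).reshape(-1)[0])


def clip_attrs(opset_version, attrs, const_inputs):
    """Map ONNX Clip min/max (attributes or inputs) to MXNet a_min/a_max."""
    if opset_version >= 11:
        # Opset 11+: min and max arrive as optional input tensors (inputs[1:]).
        padded = list(const_inputs) + [None, None]
        return {'a_min': _scalar(padded[0], -np.inf),
                'a_max': _scalar(padded[1], np.inf)}
    # Opset < 11: min and max are node attributes; missing means unbounded.
    return {'a_min': attrs.get('min', -np.inf),
            'a_max': attrs.get('max', np.inf)}


def topk_attrs(opset_version, attrs, const_inputs):
    """Map ONNX TopK attributes/inputs to MXNet topk attributes."""
    new_attrs = dict(attrs, ret_typ='both', dtype='int64')
    if opset_version >= 10:
        # Opset 10+: K is a one-element int64 input rather than an attribute.
        new_attrs['k'] = int(np.asarray(const_inputs[0]).reshape(-1)[0])
    return new_attrs


# Example: an opset-11 Clip with only `max` supplied (ReLU6-style upper bound).
print(clip_attrs(11, {}, [None, np.array([6.0])]))
```

Padding the input list with `None` lets a single code path handle the zero-, one- and two-input forms of `Clip`.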
diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_model.py b/python/mxnet/contrib/onnx/onnx2mx/import_model.py
index 1c19543..d060b08 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/import_model.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_model.py
@@ -56,7 +56,8 @@ def import_model(model_file):
                           + "Instructions to install - https://github.com/onnx/onnx")
     # loads model file and returns ONNX protobuf object
     model_proto = onnx.load_model(model_file)
-    sym, arg_params, aux_params = graph.from_onnx(model_proto.graph)
+    model_opset_version = max([x.version for x in model_proto.opset_import])
+    sym, arg_params, aux_params = graph.from_onnx(model_proto.graph, opset_version=model_opset_version)
     return sym, arg_params, aux_params
 
 def get_model_metadata(model_file):
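`import_model` now computes the opset with `max([x.version for x in model_proto.opset_import])`. An ONNX model can import several operator-set domains, each with its own version number, so the maximum over all entries is not necessarily the version of the default operator set. The sketch below is a minimal variant that only considers the default domain (the empty string or `ai.onnx`). `default_opset_version` is hypothetical. `OpsetId` is a stand-in that mimics the `domain` and `version` fields of `onnx.OperatorSetIdProto`.

```python
from collections import namedtuple

# Stand-in for onnx.OperatorSetIdProto, which carries `domain` and `version`.
OpsetId = namedtuple('OpsetId', ['domain', 'version'])

# The default ONNX operator domain may be written as '' or 'ai.onnx'.
DEFAULT_DOMAINS = ('', 'ai.onnx')


def default_opset_version(opset_import):
    """Return the opset version imported for the default ONNX domain."""
    versions = [o.version for o in opset_import if o.domain in DEFAULT_DOMAINS]
    if not versions:
        raise ValueError('model does not import the default ONNX domain')
    return max(versions)


# Hypothetical model importing a custom domain with a higher version number.
imports = [OpsetId('', 9), OpsetId('com.example', 12)]
print(default_opset_version(imports))       # 9
print(max(x.version for x in imports))      # 12
```

With a real model, the same function can be called as `default_opset_version(model_proto.opset_import)`, because it only reads `.domain` and `.version`.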
diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
index 72913dd..c2be83d 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
@@ -36,6 +36,7 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
         self.aux_dict = {}
         self.arg_dict = {}
         self.model_metadata = {}
+        self.opset_version = 0
 
     def _convert_operator(self, node_name, op_name, attrs, inputs):
         """Convert from onnx operator to mxnet operator.
@@ -72,7 +73,7 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
             return mxnet_sym
         return op_name
 
-    def from_onnx(self, graph):
+    def from_onnx(self, graph, opset_version):
         """Construct symbol from onnx graph.
 
         Parameters
@@ -87,6 +88,7 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
         params : dict
             A dict of name: nd.array pairs, used as pretrained weights
         """
+        self.opset_version = opset_version
         # get input, output shapes
         self.model_metadata = self.get_graph_metadata(graph)
         # parse network inputs, aka parameters
@@ -156,7 +158,7 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
                    }
         return metadata
 
-    def graph_to_gluon(self, graph, ctx):
+    def graph_to_gluon(self, graph, ctx, opset_version):
         """Construct SymbolBlock from onnx graph.
 
         Parameters
@@ -171,7 +173,7 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
         sym_block :gluon.nn.SymbolBlock
             The returned gluon SymbolBlock
         """
-        sym, arg_params, aux_params = self.from_onnx(graph)
+        sym, arg_params, aux_params = self.from_onnx(graph, opset_version)
         metadata = self.get_graph_metadata(graph)
         data_names = [input_tensor[0] for input_tensor in metadata['input_tensor_data']]
         data_inputs = [symbol.var(data_name) for data_name in data_names]
diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py b/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
index 13ad5b9..f6e1036 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
@@ -49,5 +49,6 @@ def import_to_gluon(model_file, ctx):
         raise ImportError("Onnx and protobuf need to be installed. Instructions to"
                           + " install - https://github.com/onnx/onnx#installation")
     model_proto = onnx.load_model(model_file)
-    net = graph.graph_to_gluon(model_proto.graph, ctx)
+    model_opset_version = max([x.version for x in model_proto.opset_import])
+    net = graph.graph_to_gluon(model_proto.graph, ctx, model_opset_version)
     return net
diff --git a/tests/python-pytest/onnx/backend.py b/tests/python-pytest/onnx/backend.py
index 2f9e247..eb803f7 100644
--- a/tests/python-pytest/onnx/backend.py
+++ b/tests/python-pytest/onnx/backend.py
@@ -26,6 +26,7 @@ import numpy as np
 try:
     from onnx import helper, TensorProto, mapping
     from onnx.backend.base import Backend
+    from onnx.defs import onnx_opset_version
 except ImportError:
     raise ImportError("Onnx and protobuf need to be installed. Instructions to"
                       + " install - https://github.com/onnx/onnx#installation")
@@ -57,13 +58,16 @@ class MXNetBackend(Backend):
         params = {}
         params.update(arg_params)
         params.update(aux_params)
+        # use the latest opset version supported by the onnx library
+        opset_version = onnx_opset_version()
         # exporting to onnx graph proto format
         converter = MXNetGraph()
         graph_proto = converter.create_onnx_graph_proto(sym, params, in_shape=input_shape,
-                                                        in_type=mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')])
+                                                        in_type=mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')],
+                                                        opset_version=opset_version)
 
         # importing back to MXNET for verifying result.
-        sym, arg_params, aux_params = graph.from_onnx(graph_proto)
+        sym, arg_params, aux_params = graph.from_onnx(graph_proto, opset_version)
 
         return sym, arg_params, aux_params
 
@@ -95,8 +99,11 @@ class MXNetBackend(Backend):
         else:
             raise NotImplementedError("ONNX tests are run only for CPU context.")
 
+        # determine opset version model uses
+        model_opset_version = max([x.version for x in model.opset_import])
+
         if backend == 'mxnet':
-            sym, arg_params, aux_params = graph.from_onnx(model.graph)
+            sym, arg_params, aux_params = graph.from_onnx(model.graph, model_opset_version)
             if operation == 'export':
                 metadata = graph.get_graph_metadata(model.graph)
                 input_data = metadata['input_tensor_data']
@@ -107,7 +114,7 @@ class MXNetBackend(Backend):
             return MXNetBackendRep(sym, arg_params, aux_params, device)
         elif backend == 'gluon':
             if operation == 'import':
-                net = graph.graph_to_gluon(model.graph, ctx)
+                net = graph.graph_to_gluon(model.graph, ctx, model_opset_version)
                 return GluonBackendRep(net, device)
             elif operation == 'export':
                 raise NotImplementedError("Gluon->ONNX export not implemented.")
diff --git a/tests/python-pytest/onnx/mxnet_export_test.py b/tests/python-pytest/onnx/mxnet_export_test.py
index 90e92cc..947fa2f 100644
--- a/tests/python-pytest/onnx/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/mxnet_export_test.py
@@ -74,7 +74,7 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params=
         # Confirm network outputs are the same
         imported_net_output = _force_list(imported_net(data))
         for out, imp_out in zip(output, imported_net_output):
-            mx.test_utils.assert_almost_equal(out, imp_out)
+            mx.test_utils.assert_almost_equal(out, imp_out, atol=1e-5, rtol=1e-5)
 
 
 class TestExport(unittest.TestCase):
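The export test now passes explicit `atol=1e-5, rtol=1e-5`. This assumes `mx.test_utils.assert_almost_equal` applies the usual elementwise rule `|out - ref| <= atol + rtol * |ref|`, which is the rule `numpy.isclose` uses. Under that rule, the absolute term is what lets near-zero outputs pass after a float32 round trip through ONNX. A small NumPy illustration follows; `within_tolerance` is a hypothetical helper.

```python
import numpy as np


def within_tolerance(out, ref, rtol=1e-5, atol=1e-5):
    # Elementwise |out - ref| <= atol + rtol * |ref|, as in numpy.isclose.
    return bool(np.all(np.isclose(out, ref, rtol=rtol, atol=atol)))


# Hypothetical outputs before and after an export/import round trip.
exported = np.array([0.5, 1.0, 1e-7], dtype=np.float32)
imported = exported + np.array([3e-6, -4e-6, 2e-6], dtype=np.float32)

print(within_tolerance(imported, exported))            # True
print(within_tolerance(imported, exported, atol=0.0))  # False: last element is near zero
```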
diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
new file mode 100644
index 0000000..052b241
--- /dev/null
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+import onnxruntime
+
+import json
+import os
+import shutil
+import tempfile
+
+
+def test_cv_model_inference_onnxruntime():
+    def get_gluon_cv_model(model_name, tmp):
+        tmpfile = os.path.join(tmp, model_name)
+        ctx = mx.cpu(0)
+        net_fp32 = mx.gluon.model_zoo.vision.get_model(model_name, pretrained=True, ctx=ctx, root=tmp)
+        net_fp32.hybridize()
+        data = mx.nd.zeros((1,3,224,224), dtype='float32', ctx=ctx)
+        net_fp32.forward(data)
+        net_fp32.export(tmpfile, 0)
+        sym_file = tmpfile + '-symbol.json'
+        params_file = tmpfile + '-0000.params'
+        return sym_file, params_file
+
+    def export_model_to_onnx(sym_file, params_file):
+        input_shape = (1,3,224,224)
+        onnx_file = os.path.join(os.path.dirname(sym_file), "model.onnx")
+        converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, [input_shape],
+                                                            np.float32, onnx_file)
+        return converted_model_path
+
+    def normalize_image(imgfile):
+        image = mx.image.imread(imgfile).asnumpy()
+        image_data = np.array(image).transpose(2, 0, 1)
+        img_data = image_data.astype('float32')
+        mean_vec = np.array([0.485, 0.456, 0.406])
+        stddev_vec = np.array([0.229, 0.224, 0.225])
+        norm_img_data = np.zeros(img_data.shape).astype('float32')
+        for i in range(img_data.shape[0]):
+            norm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]
+        return norm_img_data.reshape(1, 3, 224, 224).astype('float32')
+
+    def get_prediction(model, image):
+        pass
+
+    def softmax(x):
+        x = x.reshape(-1)
+        e_x = np.exp(x - np.max(x))
+        return e_x / e_x.sum(axis=0)
+
+    def load_imgnet_labels():
+        mx.test_utils.download('https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/image_net_labels.json')
+        return np.array(json.load(open('image_net_labels.json', 'r')))
+
+    def download_test_images():
+        test_images = [
+            ['dog.jpg',['boxer']],
+            ['apron.jpg', ['apron', 'maillot']],
+            ['dolphin.jpg', ['great white shark','grey whale']],
+            ['hammerheadshark.jpg', ['tiger shark']],
+            ['lotus.jpg', ['pinwheel','pot']]
+        ]
+        for f,_ in test_images:
+            mx.test_utils.download('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/onnx/images/'+f+'?raw=true',
+                                   fname=f)
+        return test_images
+
+
+    test_models = [
+        'mobilenet1.0', 'mobilenet0.75', 'mobilenet0.5', 'mobilenet0.25',
+        'mobilenetv2_1.0', 'mobilenetv2_0.75', 'mobilenetv2_0.5', 'mobilenetv2_0.25',
+        'resnet18_v1', 'resnet18_v2', 'resnet34_v1', 'resnet34_v2', 'resnet50_v1', 'resnet50_v2',
+        'resnet101_v1', 'resnet101_v2', 'resnet152_v1', 'resnet152_v2',
+        'squeezenet1.0', 'squeezenet1.1',
+        'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn'
+    ]
+    labels = load_imgnet_labels()
+    test_images = download_test_images()
+
+    for model in test_models:
+        tmpdir = tempfile.mkdtemp()
+        sym_file, params_file = get_gluon_cv_model(model, tmpdir)
+        onnx_file = export_model_to_onnx(sym_file, params_file)
+        #print("exported onnx file: ",onnx_file)
+
+        # create onnxruntime session using the generated onnx file
+        ses_opt = onnxruntime.SessionOptions()
+        ses_opt.log_severity_level = 3
+        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+        input_name = session.get_inputs()[0].name
+
+        for img,classes in test_images:
+            img_data = normalize_image(img)
+            raw_result = session.run([], {input_name: img_data})
+            res = softmax(np.array(raw_result)).tolist()
+            class_idx = np.argmax(res)
+            #print("Image top classification:",labels[class_idx])
+            sort_idx = np.flip(np.squeeze(np.argsort(res)))
+            #print("\tTop labels: " + ",".join(labels[sort_idx[:5]]))
+            correct_classification = False
+            for label in labels[sort_idx[:5]]:
+                for c in classes:
+                    if c in label:
+                        correct_classification = True
+            assert correct_classification
+
+        # cleanup
+        shutil.rmtree(tmpdir)
+
+
+
+
+if __name__ == "__main__":
+    test_cv_model_inference_onnxruntime()
+
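The onnxruntime test normalizes each image one channel at a time, then checks whether any expected class name appears among the top-5 labels. Below is a vectorized sketch of both steps that broadcasts a (3, 1, 1) mean/std array over the image. `normalize_hwc_uint8` and `top5_contains` are hypothetical names. `labels` is assumed to be a NumPy array of strings, as returned by `load_imgnet_labels()`.

```python
import numpy as np

# ImageNet channel statistics, shaped (3, 1, 1) to broadcast over (3, H, W).
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)


def normalize_hwc_uint8(image):
    """HWC uint8 image -> (1, 3, H, W) float32, mean/std normalized."""
    chw = image.transpose(2, 0, 1).astype(np.float32) / 255.0
    return ((chw - MEAN) / STD)[np.newaxis, ...]


def softmax(x):
    x = np.asarray(x, dtype=np.float64).reshape(-1)
    e_x = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return e_x / e_x.sum()


def top5_contains(logits, labels, expected):
    """True if any expected class name is a substring of a top-5 label."""
    top5 = np.argsort(softmax(logits))[::-1][:5]
    return any(c in label for label in labels[top5] for c in expected)
```

Because the statistics broadcast, the same function works for any input height and width, not just 224x224.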
diff --git a/tests/requirements.txt b/tests/requirements.txt
index dde62c5..24764fb 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -6,3 +6,4 @@ nose-timer
 ipython
 numpy>1.16.0,<1.19.0  # Restrict numpy version to < 1.19.0 due to https://github.com/apache/incubator-mxnet/issues/18600
 scipy
+onnxruntime