Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/04/06 22:33:53 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] MXNet export for ONNX 1.8 support (#20113)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 5b7826e  [v1.x] MXNet export for ONNX 1.8 support (#20113)
5b7826e is described below

commit 5b7826ee7886183f37247ff0ba7c48e0469593a1
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Tue Apr 6 15:29:49 2021 -0700

    [v1.x] MXNet export for ONNX 1.8 support (#20113)
    
    * onnx1.8 support
    
    * fix sanity
    
    * update protobuf version
    
    Co-authored-by: Wei Chu <we...@amazon.com>
---
 ci/docker/install/ubuntu_onnx.sh                   |    2 +-
 ci/docker/runtime_functions.sh                     |   12 +
 python/mxnet/onnx/mx2onnx/__init__.py              |    3 +-
 python/mxnet/onnx/mx2onnx/_export_model.py         |    8 +-
 python/mxnet/onnx/mx2onnx/_export_onnx.py          |   24 +-
 .../_op_translations_opset11.py}                   |    7 +-
 .../_op_translations/_op_translations_opset13.py   | 1302 ++++++++++++++++++++
 tests/nightly/JenkinsfileForBinaries               |   12 +-
 tools/license_header.py                            |    2 +
 9 files changed, 1353 insertions(+), 19 deletions(-)
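
The user-facing entry point is unchanged apart from the opset handling
described below. A minimal export sketch (file names and input shape are
hypothetical; any exported MXNet model works):

    import numpy as np
    from mxnet.onnx.mx2onnx import export_model

    # hypothetical symbol/params files produced by an earlier MXNet export
    export_model('model-symbol.json', 'model-0000.params',
                 in_shapes=[(1, 3, 224, 224)], in_types=np.float32,
                 onnx_file_path='model.onnx')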

diff --git a/ci/docker/install/ubuntu_onnx.sh b/ci/docker/install/ubuntu_onnx.sh
index 81c8755..657e514 100755
--- a/ci/docker/install/ubuntu_onnx.sh
+++ b/ci/docker/install/ubuntu_onnx.sh
@@ -30,4 +30,4 @@ echo "Installing libprotobuf-dev and protobuf-compiler ..."
 apt-get update || true
 apt-get install -y libprotobuf-dev protobuf-compiler
 
-pip3 install pytest==6.2.2 pytest-cov==2.11.1 pytest-xdist==2.2.1 protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.6.0 'numpy>1.16.0,<1.19.0' gluonnlp==0.10.0 gluoncv==0.8.0
+pip3 install pytest==6.2.2 pytest-cov==2.11.1 pytest-xdist==2.2.1 protobuf==3.13.0 onnx==1.8.1 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.6.0 'numpy>1.16.0,<1.19.0' gluonnlp==0.10.0 gluoncv==0.8.0
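
The onnx 1.8.1 pin brings default opset 13, and the protobuf pin is bumped
alongside it. A quick sanity check (a sketch) for the CI image:

    import onnx
    from onnx.defs import onnx_opset_version

    print(onnx.__version__)      # expect 1.8.1 in this image
    print(onnx_opset_version())  # 13 for onnx 1.8.x
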
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 2c3f367..17b6479 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -1594,6 +1594,18 @@ nightly_estimator() {
     nosetests test_sentiment_rnn.py
 }
 
+nightly_onnx_operator_tests() {
+    set -ex
+    export PYTHONPATH=./python/
+    export MXNET_SUBGRAPH_VERBOSE=0
+    export DMLC_LOG_STACK_TRACE_DEPTH=10
+    COV_ARG="--cov=./ --cov-report=xml --cov-append"
+    pip3 install onnx==1.8.1
+    pytest $COV_ARG --verbose tests/python-pytest/onnx/test_operators.py
+    pip3 install onnx==1.7.0
+    pytest $COV_ARG --verbose tests/python-pytest/onnx/test_operators.py
+}
+
 nightly_onnx_cv_batch1_tests() {
     set -ex
     export PYTHONPATH=./python/
diff --git a/python/mxnet/onnx/mx2onnx/__init__.py b/python/mxnet/onnx/mx2onnx/__init__.py
index d8a6d5a..20d110e 100644
--- a/python/mxnet/onnx/mx2onnx/__init__.py
+++ b/python/mxnet/onnx/mx2onnx/__init__.py
@@ -19,4 +19,5 @@
 """ONNX Export module"""
 
 from ._export_model import export_model
-from . import _op_translations
+from ._op_translations import _op_translations_opset11
+from ._op_translations import _op_translations_opset13
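
Importing both submodules is what populates the converter registry: each
module's @mx_op.register(...) decorators run at import time and file the
converter under its opset key in MXNetGraph.registry_. Conceptually (op
names illustrative):

    from mxnet.onnx.mx2onnx._export_onnx import MXNetGraph

    # after `import mxnet.onnx.mx2onnx`:
    # MXNetGraph.registry_ == {11: {'softmax': <func>, ...},
    #                          13: {'softmax': <func>, ...}}
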
diff --git a/python/mxnet/onnx/mx2onnx/_export_model.py b/python/mxnet/onnx/mx2onnx/_export_model.py
index d9be998..687899b 100644
--- a/python/mxnet/onnx/mx2onnx/_export_model.py
+++ b/python/mxnet/onnx/mx2onnx/_export_model.py
@@ -29,8 +29,8 @@ from ._export_helper import load_module
 
 
 def export_model(sym, params, in_shapes=None, in_types=np.float32,
-                 onnx_file_path='model.onnx', verbose=False, opset_version=None,
-                 dynamic=False, dynamic_input_shapes=None, run_shape_inference=False, input_type=None,
+                 onnx_file_path='model.onnx', verbose=False, dynamic=False,
+                 dynamic_input_shapes=None, run_shape_inference=False, input_type=None,
                  input_shape=None):
     """Exports the MXNet model file, passed as a parameter, into ONNX model.
     Accepts both symbol,parameter objects as well as json and params filepaths as input.
@@ -87,9 +87,7 @@ def export_model(sym, params, in_shapes=None, in_types=np.float32,
         in_shapes = input_shape
 
     converter = MXNetGraph()
-    if opset_version is None:
-        # default is to use latest opset version the onnx package supports
-        opset_version = onnx_opset_version()
+    opset_version = onnx_opset_version()
 
     if not isinstance(in_types, list):
         in_types = [in_types for _ in range(len(in_shapes))]
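
With the opset_version parameter removed, export_model always targets the
default opset of the installed onnx package; installing a different onnx
version is now the way to target an older opset. A small sketch of what
this resolves to (values per stock onnx releases):

    from onnx.defs import onnx_opset_version

    # 13 under onnx 1.8.x, 12 under onnx 1.7.0
    opset_version = onnx_opset_version()
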
diff --git a/python/mxnet/onnx/mx2onnx/_export_onnx.py b/python/mxnet/onnx/mx2onnx/_export_onnx.py
index 903b0cd..375e753 100644
--- a/python/mxnet/onnx/mx2onnx/_export_onnx.py
+++ b/python/mxnet/onnx/mx2onnx/_export_onnx.py
@@ -65,13 +65,14 @@ class MXNetGraph(object):
         self.output_tensors = []
 
     @staticmethod
-    def register(op_name):
+    def register(op_name, opset_version=11):
         """Register operators"""
         def wrapper(func):
             """Helper function to map functions"""
             try:
                 import onnx as _
-                MXNetGraph.registry_[op_name] = func
+                op_map = MXNetGraph.registry_.setdefault(opset_version, {})
+                op_map[op_name] = func
             except ImportError:
                 pass
             return func
@@ -81,10 +82,23 @@ class MXNetGraph(object):
     @staticmethod
     def convert_layer(node, **kwargs):
         """Convert MXNet layer to ONNX"""
+        try:
+            from onnx.defs import onnx_opset_version
+        except ImportError:
+            raise ImportError("Onnx and protobuf need to be installed. "
+                              + "Instructions to install - https://github.com/onnx/onnx")
+
         op = str(node["op"])
-        if op not in MXNetGraph.registry_:
-            raise AttributeError("No conversion function registered for op type %s yet." % op)
-        convert_func = MXNetGraph.registry_[op]
+        opset_version = kwargs.get("opset_version", onnx_opset_version())
+        # fallback to older opset versions if op is not registered in current version
+        for op_version in range(opset_version, 10, -1):
+            if op_version not in MXNetGraph.registry_ or op not in MXNetGraph.registry_[op_version]:
+                if op_version == 11:
+                    raise AttributeError("No conversion function registered for op type %s yet." % op)
+                continue
+            convert_func = MXNetGraph.registry_[op_version][op]
+            break
+
         ret = convert_func(node, **kwargs)
         # in case the conversion function does not specify the returned dtype, we just return None
         # as the second value
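
The fallback walk above lets an opset-13 export reuse an opset-11 converter
whenever no opset-13 override exists. A standalone illustration of the same
lookup (a sketch, not code from this commit):

    def resolve_converter(registry, op, opset_version):
        # walk from the requested opset down to 11; first hit wins
        for op_version in range(opset_version, 10, -1):
            if op in registry.get(op_version, {}):
                return registry[op_version][op]
        raise AttributeError("No conversion function registered for op type %s yet." % op)
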
diff --git a/python/mxnet/onnx/mx2onnx/_op_translations.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset11.py
similarity index 99%
rename from python/mxnet/onnx/mx2onnx/_op_translations.py
rename to python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset11.py
index ef65fa2..d683aad 100644
--- a/python/mxnet/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset11.py
@@ -56,7 +56,7 @@ Add new functions here with a decorator.
 import re
 import logging
 import numpy as np
-from ._export_onnx import MXNetGraph as mx_op
+from .._export_onnx import MXNetGraph as mx_op
 try:
     import onnx
 except ImportError:
@@ -1856,9 +1856,6 @@ def convert_slice_channel(node, **kwargs):
     axis = int(attrs.get('axis', 1))
     squeeze_axis = attrs.get('squeeze_axis', 'False')
 
-    create_tensor([axis], name+'_axis', kwargs['initializer'])
-    create_tensor([axis+1], name+'axis_p1', kwargs['initializer'])
-
     nodes = []
     if squeeze_axis in ['True', '1']:
         nodes += [
@@ -4502,7 +4499,7 @@ def convert_RNN(node, **kwargs):
                 make_node('Squeeze', [name+'0_'], [name], axes=[1]),
             ]
         else:
-            raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1')
+            raise NotImplementedError('Currently RNN onnx export only supports num_layers equal to 1 or 2')
 
     else:
         raise NotImplementedError(f"Currently RNN onnx export does not support {mode} mode")
diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
new file mode 100644
index 0000000..d0176d1
--- /dev/null
+++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
@@ -0,0 +1,1302 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Based on
+#  https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/
+# mx2onnx_converter_functions.py
+#  Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
+#
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions
+#  are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+#  EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+#  PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+#  CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+#  EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+#  PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+#  PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+#  OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# coding: utf-8
+# pylint: disable=too-many-locals,no-else-return,too-many-lines
+# pylint: disable=anomalous-backslash-in-string,eval-used
+# pylint: disable=too-many-function-args
+"""
+Conversion Functions for common layers.
+Add new functions here with a decorator.
+"""
+
+import re
+import logging
+import numpy as np
+from .._export_onnx import MXNetGraph as mx_op
+try:
+    import onnx
+except ImportError:
+    onnx = None
+
+OPSET_VERSION = 13
+
+def parse_helper(attrs, attrs_name, alt_value=None):
+    """Helper function to parse operator attributes in required format."""
+    tuple_re = re.compile(r'\([0-9L|,| ]+\)')
+    if not attrs:
+        return alt_value
+    attrs_str = None if attrs.get(attrs_name) is None else str(attrs.get(attrs_name))
+    if attrs_str is None:
+        return alt_value
+    attrs_match = tuple_re.search(attrs_str)
+    if attrs_match is not None:
+        if attrs_match.span() == (0, len(attrs_str)):
+            dims = eval(attrs_str)
+            return dims
+        else:
+            raise AttributeError("Malformed %s dimensions: %s" % (attrs_name, str(attrs_str)))
+    return alt_value
+
+def transform_padding(pad_width):
+    """Helper function to convert padding format for pad operator.
+    """
+    num_pad_values = len(pad_width)
+    onnx_pad_width = [0]*num_pad_values
+
+    start_index = 0
+    # num_pad_values will always be multiple of 2
+    end_index = int(num_pad_values/2)
+    for idx in range(0, num_pad_values):
+        if idx % 2 == 0:
+            onnx_pad_width[start_index] = pad_width[idx]
+            start_index += 1
+        else:
+            onnx_pad_width[end_index] = pad_width[idx]
+            end_index += 1
+
+    return onnx_pad_width
+
+
+def convert_string_to_list(string_val):
+    """Helper function to convert string to list.
+     Used to convert shape attribute string to list format.
+    """
+    result_list = []
+
+    list_string = string_val.split(',')
+    for val in list_string:
+        val = str(val.strip())
+        val = val.replace("(", "")
+        val = val.replace(")", "")
+        val = val.replace("L", "")
+        val = val.replace("[", "")
+        val = val.replace("]", "")
+        if val == "None":
+            result_list.append(None)
+        elif val != "":
+            result_list.append(int(val))
+
+    return result_list
+
+def get_boolean_attribute_value(attrs, attr_name):
+    """ Helper function to convert a string version
+    of Boolean attributes to integer for ONNX.
+    Takes attribute dictionary and attr_name as
+    parameters.
+    """
+    return 1 if attrs.get(attr_name, 0) in ["True", "1"] else 0
+
+def get_inputs(node, kwargs):
+    """Helper function to get inputs"""
+    name = node["name"]
+    outputs_lookup = kwargs["outputs_lookup"]
+    inputs = node["inputs"]
+    attrs = node.get("attrs", {})
+
+    input_nodes = []
+    for ip in inputs:
+        input_node_name = outputs_lookup[ip[0]][ip[1]].name
+        input_nodes.append(input_node_name)
+
+    return name, input_nodes, attrs
+
+def get_input_dtypes(node, kwargs):
+    """Helper function to get the dtypes of a node's inputs"""
+    outputs_lookup = kwargs['outputs_lookup']
+    inputs = node['inputs']
+    input_dtypes = []
+    for ip in inputs:
+        input_node_dtype = outputs_lookup[ip[0]][ip[1]].dtype
+        input_dtypes.append(input_node_dtype)
+    return input_dtypes
+
+def create_basic_op_node(op_name, node, kwargs):
+    """Helper function to create a basic operator
+    node that doesn't contain op specific attrs"""
+    name, input_nodes, _ = get_inputs(node, kwargs)
+
+    node = onnx.helper.make_node(
+        op_name,
+        input_nodes,
+        [name],
+        name=name
+    )
+    return [node]
+
+def create_const_scalar_node(input_name, value, kwargs):
+    """Helper function to create an initializer tensor node holding a
+    constant scalar value."""
+    from onnx.helper import make_tensor
+    initializer = kwargs["initializer"]
+    input_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype]
+    tensor_node = make_tensor(input_name, input_type, (), ([value]))
+    initializer.append(tensor_node)
+
+def create_const_node(input_name, value, kwargs):
+    """Helper function to create an initializer tensor node holding a
+    constant tensor value."""
+    from onnx.helper import make_tensor
+    initializer = kwargs["initializer"]
+    input_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype]
+    input_shape = value.shape
+    tensor_node = make_tensor(input_name, input_type, input_shape, value)
+    initializer.append(tensor_node)
+
+def create_tensor(tensor_list, tensor_name, initializer, dtype='int64'):
+    """Helper function to create and append an initializer tensor node
+    with a constant value."""
+    tensor_np = np.array(tensor_list, dtype=dtype)
+    data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[tensor_np.dtype]
+    dims = np.shape(tensor_np)
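+    # ONNX stores float16 payloads as uint16 bit patterns, hence the view below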
+    if dtype == np.float16:
+        tensor_np = tensor_np.view(dtype=np.uint16)
+    initializer.append(
+        onnx.helper.make_tensor(
+            name=tensor_name,
+            data_type=data_type,
+            dims=dims,
+            vals=tensor_np.flatten().tolist(),
+            raw=False
+        )
+    )
+
+
+def create_helper_trans_node(node_name, input_node):
+    """create extra transpose node for dot operator"""
+    trans_node = onnx.helper.make_node(
+        'Transpose',
+        inputs=[input_node],
+        outputs=[node_name],
+        name=node_name
+    )
+    return trans_node
+
+
+def scalar_op_helper(node, op_name, **kwargs):
+    """Helper function for scalar arithmetic operations"""
+    from onnx import numpy_helper
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    dtype = input_dtypes[0]
+    dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+
+    scalar_value = np.array([attrs.get("scalar", 1)],
+                            dtype=dtype)
+    initializer = kwargs["initializer"]
+    flag = True
+    # If the input is already an initializer, fold the scalar operation into
+    # a new initializer instead of emitting a runtime node
+    for i in initializer:
+        if i.name == input_nodes[0]:
+            if op_name == 'Mul':
+                new_initializer = numpy_helper.to_array(i) * scalar_value[0]
+            elif op_name == 'Sub':
+                if name.startswith("_rminusscalar"):
+                    new_initializer = scalar_value[0] - numpy_helper.to_array(i)
+                else:
+                    new_initializer = numpy_helper.to_array(i) - scalar_value[0]
+            elif op_name == 'Add':
+                new_initializer = numpy_helper.to_array(i) + scalar_value[0]
+            elif op_name == 'Div':
+                if name.startswith("_rdivscalar"):
+                    new_initializer = scalar_value[0] / numpy_helper.to_array(i)
+                else:
+                    new_initializer = numpy_helper.to_array(i) / scalar_value[0]
+            elif op_name == 'Pow':
+                new_initializer = numpy_helper.to_array(i) ** scalar_value[0]
+            flag = False
+            break
+
+    # else create a new tensor for the scalar value and add it to the initializer
+    if flag is True:
+        dims = np.shape(scalar_value)
+
+        scalar_op_name = "scalar_op" + str(kwargs["idx"])
+        tensor_node = onnx.helper.make_tensor_value_info(scalar_op_name, dtype_t, dims)
+
+        initializer.append(
+            onnx.helper.make_tensor(
+                name=scalar_op_name,
+                data_type=dtype_t,
+                dims=dims,
+                vals=scalar_value,
+                raw=False,
+            )
+        )
+
+        mul_node = onnx.helper.make_node(
+            op_name,
+            [input_nodes[0], scalar_op_name],
+            [name],
+            name=name
+        )
+
+        return [tensor_node, mul_node]
+    else:
+        dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[new_initializer.dtype]
+        dims = np.shape(new_initializer)
+
+        tensor_node = onnx.helper.make_tensor_value_info(name, dtype_t, dims)
+
+        initializer.append(
+            onnx.helper.make_tensor(
+                name=name,
+                data_type=dtype_t,
+                dims=dims,
+                vals=new_initializer.flatten(),
+                raw=False,
+            )
+        )
+        return [tensor_node]
+
+
+@mx_op.register("shape_array", OPSET_VERSION)
+def convert_shape(node, **kwargs):
+    """Map MXNet's shape_array operator attributes to onnx's Shape operator
+    and return the created node.
+    """
+    return create_basic_op_node('Shape', node, kwargs)
+
+
+@mx_op.register("_contrib_arange_like", OPSET_VERSION)
+def convert_arange_like(node, **kwargs):
+    """Map MXNet's arange_like operator attributes to onnx's Range and Reshape operators.
+    """
+    from onnx.helper import make_node
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    opset_version = kwargs['opset_version']
+    if opset_version < 11:
+        raise AttributeError("ONNX opset 11 or greater is required to export this operator")
+    # use the same dtype as that of the input node
+    dtype = input_dtypes[0]
+    dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+    axis = attrs.get('axis', 'None')
+    start = attrs.get('start', 0.)
+    step = attrs.get('step', 1.)
+    repeat = int(attrs.get('repeat', 1))
+    if repeat != 1:
+        raise NotImplementedError("arange_like operator with repeat != 1 not yet implemented.")
+
+    create_const_scalar_node(name+"_start", np.dtype(dtype).type(start), kwargs)
+    create_const_scalar_node(name+"_step", np.dtype(dtype).type(step), kwargs)
+    create_const_scalar_node(name+"_half_step", np.dtype(dtype).type(float(step)*0.5), kwargs)
+    create_tensor([0], name+"_0", kwargs["initializer"], dtype='int64')
+    nodes = []
+    if axis == 'None':
+        # output will be same shape as input
+        nodes += [
+            make_node("Shape", [input_nodes[0]], [name+"_shape0_out"]),
+            make_node("ReduceProd", [name+"_shape0_out"], [name+"_redprod0_out"]),
+            make_node("Squeeze", [name+"_redprod0_out", name+"_0"], [name+'_reshape0_out']),
+            make_node("Cast", [name+"_reshape0_out"], [name+"_cast0_out"], to=dtype_t),
+            make_node("Mul", [name+"_cast0_out", name+"_step"], [name+"_mul0_out"]),
+            make_node("Add", [name+"_mul0_out", name+"_start"], [name+"_add1_out"]),
+            make_node("Sub", [name+"_add1_out", name+"_half_step"], [name+"_sub0_out"]),
+            make_node("Range", [name+"_start", name+"_sub0_out", name+"_step"], [name+"_range0_out"]),
+            make_node("Reshape", [name+"_range0_out", name+"_shape0_out"], [name], name=name)
+        ]
+    else:
+        # determine shape of axis
+        create_tensor([int(axis)], name+"_axis_start", kwargs["initializer"], dtype='int64')
+        create_tensor([int(axis)+1], name+"_axis_end", kwargs["initializer"], dtype='int64')
+        nodes += [
+            make_node("Shape", [input_nodes[0]], [name+"_shape0_out"]),
+            make_node("Slice", [name+"_shape0_out", name+"_axis_start", name+"_axis_end"], [name+"_slice0_out"]),
+            make_node("ReduceProd", [name+"_slice0_out"], [name+"_reprod0_out"]),
+            make_node("Squeeze", [name+"_reprod0_out", name+"_0"], [name+"_reshape0_out"]),
+            make_node("Cast", [name+"_reshape0_out"], [name+"_cast0_out"], to=dtype_t),
+            make_node("Mul", [name+"_cast0_out", name+"_step"], [name+"_mul0_out"]),
+            make_node("Add", [name+"_mul0_out", name+"_start"], [name+"_add1_out"]),
+            make_node("Sub", [name+"_add1_out", name+"_half_step"], [name+"_sub0_out"]),
+            make_node("Range", [name+"_start", name+"_sub0_out", name+"_step"], [name], name=name)
+        ]
+
+    return nodes
+
+
+@mx_op.register("LayerNorm", OPSET_VERSION)
+def convert_layer_norm(node, **kwargs):
+    """Map MXNet's LayerNorm operator attributes to onnx operators.
+    """
+    from onnx.helper import make_node
+    from onnx import TensorProto
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    dtype = input_dtypes[0]
+
+    axes = int(attrs.get('axis', -1))
+    eps = attrs.get('eps', 9.99999975e-06)
+
+    create_tensor([axes], name+"_axes", kwargs["initializer"])
+    create_tensor([axes+1], name+"_axes+1", kwargs["initializer"])
+    create_tensor([0], name+"_0", kwargs["initializer"], dtype='int64')
+    create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
+    create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)
+    create_const_scalar_node(name+"_2_s", np.int64(2).astype(dtype), kwargs)
+    create_const_scalar_node(name+"_eps", np.float32(eps), kwargs)
+
+    nodes = [
+        make_node("ReduceMean", [input_nodes[0]], [name+"_rm0_out"], axes=[axes]),
+        make_node("Sub", [input_nodes[0], name+"_rm0_out"], [name+"_sub0_out"]),
+        make_node("Pow", [name+"_sub0_out", name+"_2_s"], [name+"_pow0_out"]),
+        make_node("ReduceMean", [name+"_pow0_out"], [name+"_rm1_out"], axes=[axes]),
+        make_node("Add", [name+"_rm1_out", name+"_eps"], [name+"_add0_out"]),
+        make_node("Sqrt", [name+"_add0_out"], [name+"_sqrt0_out"]),
+        make_node("Div", [name+"_sub0_out", name+"_sqrt0_out"], [name+"_div0_out"]),
+    ]
+
+    if axes == -1:
+        nodes += [
+            make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]),
+            # make_node("Add", [name+"_mul0_out", input_nodes[2]], [name])
+            # the Add operator triggers a weird NaN issue in onnxruntime
+            # a workaround is to use Neg + Sub
+            make_node('Neg', [input_nodes[2]], [name+'_neg']),
+            make_node("Sub", [name+"_mul0_out", name+'_neg'], [name])
+        ]
+    else:
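+        # build a shape that is 1 everywhere except at the target axis (via a
+        # one-hot mask) so gamma/beta can be reshaped and broadcast along it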
+        nodes += [
+            make_node("Shape", [input_nodes[0]], [name+"_shape0_out"]),
+            make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]),
+            make_node("Squeeze", [name+"_in_dim", name+"_0"], [name+"_in_dim_s"]),
+            make_node("Range", [name+"_0_s", name+"_in_dim_s", name+"_1_s"], [name+"_range"]),
+            make_node("Equal", [name+"_range", name+"_axes"], [name+"_equal"]),
+            make_node("Cast", [name+"_equal"], [name+"_one_hot"], to=int(TensorProto.INT64)),
+            make_node("Slice", [name+"_shape0_out", name+"_axes", name+"_axes+1"], [name+"_slice_out"]),
+            make_node("Squeeze", [name+"_slice_out", name+"_0"], [name+"_slice_out_s"]),
+            make_node("Sub", [name+"_slice_out_s", name+"_1_s"], [name+"_sub1_out"]),
+            make_node("Mul", [name+"_one_hot", name+"_sub1_out"], [name+"_mul0_out"]),
+            make_node("Add", [name+"_mul0_out", name+"_1_s"], [name+"_add1_out"]),
+            make_node('Reshape', [input_nodes[1], name+"_add1_out"], [name+"gamma_exp"]),
+            make_node('Reshape', [input_nodes[2], name+"_add1_out"], [name+"beta_exp"]),
+            make_node('Expand', [name+"gamma_exp", name+"_shape0_out"], [name+"gamma_exp1"]),
+            make_node('Expand', [name+"beta_exp", name+"_shape0_out"], [name+"beta_exp1"]),
+            make_node("Mul", [name+"_div0_out", name+"gamma_exp1"], [name+"_mul1_out"]),
+            make_node("Add", [name+"_mul1_out", name+"beta_exp1"], [name], name=name)
+        ]
+
+    return nodes
+
+
+@mx_op.register("broadcast_axis", OPSET_VERSION)
+def convert_broadcast_axis(node, **kwargs):
+    """Map MXNet's broadcast_axis
+    """
+    from onnx.helper import make_node
+    from onnx import TensorProto
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    axis = convert_string_to_list(attrs.get('axis', '()'))
+    size = convert_string_to_list(attrs.get('size', '()'))
+    assert len(axis) == len(size)
+
+    shape_name = name+'_shape_0'
+
+    create_tensor([0], name+'_0', kwargs["initializer"])
+    create_tensor([1], name+'_1', kwargs["initializer"])
+    create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
+    create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)
+
+    nodes = [
+        make_node('Shape', [input_nodes[0]], [shape_name]),
+        make_node('Shape', [shape_name], [name+'_in_dim']),
+        make_node('Squeeze', [name+'_in_dim', name+'_0'], [name+'_in_dim_s']),
+        make_node('Range', [name+'_0_s', name+'_in_dim_s', name+'_1_s'], [name+'_range']),
+    ]
+
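+    # per axis i: shape <- shape * ((size[i]-1) * one_hot(axis) + 1), which
+    # rewrites only the broadcast axis (dim 1) to size[i]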
+    for i, axis in enumerate(axis):
+        if axis not in (0, 1):
+            create_tensor([axis], name+'_'+str(axis), kwargs["initializer"])
+        create_tensor([size[i]-1], name+'_size_'+str(i), kwargs["initializer"])
+        nodes += [
+            make_node('Equal', [name+'_range', name+'_'+str(axis)], [name+'_equal_'+str(i)]),
+            make_node('Cast', [name+'_equal_'+str(i)], [name+'_cast_'+str(i)], to=int(TensorProto.INT64)),
+            make_node('Mul', [name+'_size_'+str(i), name+'_cast_'+str(i)], [name+'_mul_'+str(i)]),
+            make_node('Add', [name+'_mul_'+str(i), name+'_1'], [name+'_add_'+str(i)]),
+            make_node('Mul', [name+'_add_'+str(i), shape_name], [name+'_shape_'+str(i+1)])
+        ]
+        shape_name = name+'_shape_'+str(i+1)
+
+    nodes += [
+        make_node('Expand', [input_nodes[0], shape_name], [name], name=name)
+    ]
+
+    return nodes
+
+
+@mx_op.register("SequenceMask", OPSET_VERSION)
+def convert_sequencemask(node, **kwargs):
+    """Map MXNet's SequenceMask operator
+    """
+    from onnx.helper import make_node
+    from onnx import TensorProto
+
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    use_sequence_length = attrs.get('use_sequence_length', 'False')
+    mask_val = float(attrs.get('value', '0'))
+    axis = int(attrs.get('axis', '0'))
+
+    if(use_sequence_length == 'False'):
+        return [make_node('Identity', [input_nodes[0]], [name], name=name)]
+
+    create_tensor([0], name+'_0', kwargs["initializer"])
+    create_tensor([1], name+'_1', kwargs["initializer"])
+    create_tensor([2], name+'_2', kwargs["initializer"])
+    create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
+    create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)
+    create_const_scalar_node(name+'_2_s', np.int64(2), kwargs)
+    create_tensor([mask_val], name+'_mask_val', kwargs["initializer"], dtype='float32')
+
+    nodes = [
+        make_node('Shape', [input_nodes[0]], [name+'_in_shape']),
+        make_node('Slice', [name+'_in_shape', name+'_0', name+'_1'], [name+'_slice_0']),
+        make_node('Slice', [name+'_in_shape', name+'_1', name+'_2'], [name+'_slice_1']),
+        make_node('Concat', [name+'_slice_0', name+'_1'], [name+'_shape_0'], axis=0),
+        make_node('Shape', [name+'_in_shape'], [name+'_in_dim']),
+        make_node('Squeeze', [name+'_in_dim', name+'_0'], [name+'_in_dim_s']),
+        make_node('Range', [name+'_0_s', name+'_in_dim_s', name+'_1_s'], [name+'_range_0']),
+        make_node('Less', [name+'_range_0', name+'_2'], [name+'_less_0']),
+        make_node('Where', [name+'_less_0', name+'_in_shape', name+'_1'], [name+'_shape_1'])
+    ]
+
+    if(axis == 0):
+        nodes += [
+            make_node('Squeeze', [name+'_slice_0', name+'_0'], [name+'_max_len']),
+            make_node('Range', [name+'_0_s', name+'_max_len', name+'_1_s'], [name+'_range_1']),
+            make_node('Reshape', [name+'_range_1', name+'_shape_0'], [name+"_reshape_0"]),
+            make_node('Cast', [input_nodes[1]], [name+'_cast'], to=int(TensorProto.INT64)),
+            make_node('Less', [name+'_reshape_0', name+'_cast'], [name+'_less_1']),
+            make_node('Reshape', [name+'_less_1', name+'_shape_1'], [name+"_reshape_1"]),
+            make_node('Where', [name+'_reshape_1', input_nodes[0], name+'_mask_val'], [name], name=name),
+        ]
+    else:
+        nodes += [
+            make_node('Squeeze', [name+'_slice_1', name+'_0'], [name+'_max_len']),
+            make_node('Range', [name+'_0_s', name+'_max_len', name+'_1_s'], [name+'_range_1']),
+            make_node('Reshape', [input_nodes[1], name+'_shape_0'], [name+"_reshape_0"]),
+            make_node('Cast', [name+"_reshape_0"], [name+'_cast'], to=int(TensorProto.INT64)),
+            make_node('Less', [name+'_range_1', name+'_cast'], [name+'_less_1']),
+            make_node('Reshape', [name+'_less_1', name+'_shape_1'], [name+"_reshape_1"]),
+            make_node('Where', [name+'_reshape_1', input_nodes[0], name+'_mask_val'], [name], name=name),
+        ]
+    return nodes
+
+
+@mx_op.register("expand_dims", OPSET_VERSION)
+def convert_expand_dims(node, **kwargs):
+    """Map MXNet's expand_dims operator attributes to onnx's Unsqueeze operator
+    and return the created node.
+    """
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    axis = int(attrs.get("axis"))
+    create_tensor([axis], name+"_axis", kwargs["initializer"])
+    input_nodes.append(name+"_axis")
+    node = onnx.helper.make_node(
+        "Unsqueeze",
+        input_nodes,
+        [name],
+        name=name,
+    )
+    return [node]
+
+
+@mx_op.register("stack", OPSET_VERSION)
+def convert_stack(node, **kwargs):
+    """Map MXNet's stack operator to onnx operators.
+    """
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+    axis = int(attrs.get("axis", 0))
+    create_tensor([axis], name+"_axis", kwargs["initializer"])
+    idx = 0
+    nodes = []
+    for input_node in input_nodes:
+        nodes.append(onnx.helper.make_node(
+            "Unsqueeze",
+            inputs=[input_node, name+"_axis"],
+            outputs=[name+"_unsqueeze"+str(idx)]
+        ))
+        idx += 1
+
+    nodes.append(onnx.helper.make_node(
+        "Concat",
+        inputs=[name+"_unsqueeze"+str(i) for i in range(len(nodes))],
+        outputs=[name],
+        name=name,
+        axis=axis
+    ))
+    return nodes
+
+
+@mx_op.register("softmax", OPSET_VERSION)
+def convert_softmax(node, **kwargs):
+    """Map MXNet's softmax operator attributes to onnx's Softmax operator
+    and return the created node.
+    """
+    from onnx.helper import make_node
+    from onnx import TensorProto
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    axis = int(attrs.get("axis", -1))
+    temperature = str(attrs.get("temperature", 'None'))
+    if temperature == 'None':
+        temperature = 1.
+    else:
+        temperature = float(temperature)
+
+    use_length = str(attrs.get("use_length", 'None'))
+    dtype = input_dtypes[0]
+    dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+    data = input_nodes[0]
+
+    create_tensor([0], name+"_0", kwargs["initializer"])
+    if axis == -1 and temperature == 1.:
+        nodes = []
+        if use_length == "True":
+            # magic number close to the fp16 minimum (-65504)
+            create_tensor([-65500.0], name+"_mask_val", kwargs["initializer"], dtype=dtype)
+            create_tensor([1], name+"_1", kwargs["initializer"])
+            create_tensor([-1], name+"_-1", kwargs["initializer"])
+            create_const_scalar_node(name+"_0_s", np.int64(0), kwargs)
+            create_const_scalar_node(name+"_1_s", np.int64(1), kwargs)
+            nodes += [
+                make_node("Shape", [data], [name+"_shape"]),
+                make_node("Shape", [name+"_shape"], [name+"_dim"]),
+                make_node("Sub", [name+"_dim", name+"_1"], [name+"_dim_m1"]),
+                make_node("Slice", [name+"_shape", name+"_dim_m1", name+"_dim"],
+                          [name+"_dim_last_"]),
+                make_node("Squeeze", [name+"_dim_last_", name+"_0"], [name+"_dim_last"]),
+                make_node("Range", [name+"_0_s", name+"_dim_last", name+"_1_s"], [name+"_range"]),
+                make_node("Cast", [input_nodes[1]], [name+"_len"], to=int(TensorProto.INT64)),
+                make_node("Unsqueeze", [name+"_len", name+"_-1"], [name+"_len_unsqueezed"]),
+                make_node("Less", [name+"_range", name+"_len_unsqueezed"], [name+"_less"]),
+                make_node("Where", [name+'_less', data, name+"_mask_val"], [name+"_data_masked"])
+            ]
+            data = name+"_data_masked"
+
+        nodes += [
+            make_node("Softmax", [data], [name], axis=-1)
+        ]
+
+        return nodes
+
+    create_tensor([axis], name+"_axes", kwargs["initializer"])
+    create_tensor([temperature], name+"_tmp", kwargs["initializer"], dtype=dtype)
+    nodes = [
+        make_node("Div", [data, name+"_tmp"], [name+'_data']),
+        make_node("Exp", [name+'_data'], [name+"_exp_out"]),
+        make_node("ReduceSum", [name+"_exp_out", name+"_axes"], [name+"_rsum_out"], keepdims=1)
+    ]
+    if len(input_nodes) == 1:
+        nodes += [
+            make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name)
+        ]
+        return nodes
+    elif use_length == "True":
+        length = input_nodes[1]
+        create_tensor([1], name+"_1", kwargs["initializer"])
+        create_const_scalar_node(name+'_-1_s', np.int64(-1), kwargs)
+        create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
+        create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)
+        nodes += [
+            # cast data type
+            make_node("Cast", [length], [name+"_length"], to=int(TensorProto.INT64)),
+            make_node("Cast", [name+"_0"], [name+"_0_itype"], to=dtype_t),
+            make_node("Cast", [name+"_1"], [name+"_1_itype"], to=dtype_t),
+            # softmax output
+            make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name+"_div1_out"]),
+            # update axis
+            make_node("Shape", [data], [name+"_shape0_out"]),
+            make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]),
+            make_node("Add", [name+"_in_dim", name+"_axes"], [name+"_dim+axis"]),
+            make_node("Less", [name+"_axes", name+"_0_s"], [name+"_less0_out"]),
+            make_node("Where", [name+"_less0_out", name+"_dim+axis", name+"_axes"], [name+"_final_axis"]),
+            # data mask
+            make_node("Add", [name+"_final_axis", name+"_1_s"], [name+"_final_axis+1"]),
+            make_node("Slice", [name+"_shape0_out", name+"_final_axis", name+"_final_axis+1"], [name+"_axis_dim"]),
+            make_node("Squeeze", [name+"_axis_dim", name+"_0"], [name+"_axis_dim_s"]),
+            make_node("Range", [name+"_0_s", name+"_axis_dim_s", name+"_1_s"], [name+"_range0_out"]),
+            # one hot for axis
+            make_node("Squeeze", [name+"_in_dim", name+"_0"], [name+"_in_dim_s"]),
+            make_node("Range", [name+"_0_s", name+"_in_dim_s", name+"_1_s"], [name+"_range1_out"]),
+            make_node("Equal", [name+"_range1_out", name+"_final_axis"], [name+"_equal_out"]),
+            make_node("Cast", [name+"_equal_out"], [name+"_one_hot"], to=int(TensorProto.INT64)),
+            # reshape data mask for less
+            make_node("Sub", [name+"_axis_dim_s", name+"_1_s"], [name+"_sub0_out"]),
+            make_node("Mul", [name+"_one_hot", name+"_sub0_out"], [name+"_mul0_out"]),
+            make_node("Add", [name+"_mul0_out", name+"_1_s"], [name+"_add0_out"]),
+            make_node('Reshape', [name+"_range0_out", name+"_add0_out"], [name+"_reshape0_out"]),
+            # reshape length for less
+            make_node("Mul", [name+"_one_hot", name+"_-1_s"], [name+"_mul1_out"]),
+            make_node("Add", [name+"_mul1_out", name+"_1_s"], [name+"_add1_out"]),
+            make_node("Sub", [name+"_shape0_out", name+"_1_s"], [name+"_sub1_out"]),
+            make_node("Mul", [name+"_add1_out", name+"_sub1_out"], [name+"_mul2_out"]),
+            make_node("Add", [name+"_mul2_out", name+"_1_s"], [name+"_add2_out"]),
+            make_node('Reshape', [name+"_length", name+"_add2_out"], [name+"_reshape1_out"]),
+            # mask output
+            make_node("Less", [name+"_reshape0_out", name+"_reshape1_out"], [name+"_less_out"]),
+            make_node("Cast", [name+"_less_out"], [name+"_mask"], to=dtype_t),
+            make_node("Mul", [name+"_div1_out", name+"_mask"], [name+"_mul3_out"]),
+            make_node("ReduceSum", [name+"_mul3_out", name+"_axes"], [name+"_rsum1_out"], keepdims=1),
+            make_node("Equal", [name+"_rsum1_out", name+"_0_itype"], [name+"_equal1_out"]),
+            make_node("Where", [name+"_equal1_out", name+"_1_itype", name+"_rsum1_out"], [name+"_where_out"]),
+            make_node("Div", [name+"_mul3_out", name+"_where_out"], [name], name=name)
+        ]
+        return nodes
+
+    else:
+        raise NotImplementedError("use_length must be true when both data and length are paased in.")
+
+
+@mx_op.register("reverse", OPSET_VERSION)
+def convert_reverse(node, **kwargs):
+    """Map MXNet's reverse operator attributes to ONNX
+    """
+    from onnx.helper import make_node
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    axis = int(attrs.get('axis', 0))
+
+    # Transpose takes perm as a parameter, so we must 'pad' the input to a known dim (10 here)
+    perm = [i for i in range(10)]
+    perm[0], perm[axis] = axis, 0
+
+    create_tensor([10], name+'_10', kwargs['initializer'])
+    create_tensor([0], name+'_0', kwargs['initializer'])
+    create_tensor([1], name+'_1', kwargs['initializer'])
+    create_tensor([-1], name+'_m1', kwargs['initializer'])
+    create_tensor([axis], name+'_axis', kwargs['initializer'])
+    create_tensor([axis+1], name+'_axis_p1', kwargs['initializer'])
+    create_const_scalar_node(name+'_m1_s', np.int64(-1), kwargs)
+
+    nodes = [
+        make_node('Shape', [input_nodes[0]], [name+'_shape']),
+        make_node('Shape', [name+'_shape'], [name+'_dim']),
+        make_node('Sub', [name+'_10', name+'_dim'], [name+'_sub']),
+        make_node('Concat', [name+'_0', name+'_sub'], [name+'_concat'], axis=0),
+        make_node('Pad', [name+'_shape', name+'_concat', name+'_1'], [name+'_shape_10_dim']),
+        make_node('Reshape', [input_nodes[0], name+'_shape_10_dim'], [name+'_data_10_dim']),
+        make_node('Transpose', [name+'_data_10_dim'], [name+'_data_t'], perm=perm),
+        make_node('Slice', [name+'_shape', name+'_axis', name+'_axis_p1'], [name+'_axis_len']),
+        make_node('Sub', [name+'_axis_len', name+'_1'], [name+'_axis_len_m1']),
+        make_node('Squeeze', [name+'_axis_len_m1', name+'_0'], [name+'_axis_len_m1_s']),
+        make_node('Range', [name+'_axis_len_m1_s', name+'_m1_s', name+'_m1_s'], [name+'_indices']),
+        make_node('Gather', [name+'_data_t', name+'_indices'], [name+'_gather']),
+        make_node('Transpose', [name+'_gather'], [name+'_data_reversed'], perm=perm),
+        make_node('Reshape', [name+'_data_reversed', name+'_shape'], [name], name=name)
+    ]
+
+    return nodes
+
+
+@mx_op.register('repeat', OPSET_VERSION)
+def convert_repeat(node, **kwargs):
+    """Map MXNet's repeat operator attributes to onnx's Tile operator.
+    """
+    from onnx.helper import make_node
+    from onnx import TensorProto
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    opset_version = kwargs['opset_version']
+    if opset_version < 11:
+        raise AttributeError('ONNX opset 11 or greater is required to export this operator')
+
+    repeats = int(attrs.get('repeats', 1))
+    axis = attrs.get('axis', 'None')
+
+    if repeats <= 0:
+        raise NotImplementedError('repeat operator does not support parameter repeats <= 0')
+
+    nodes = []
+    if axis == 'None':
+        create_tensor([-1], name+'_-1', kwargs['initializer'])
+        create_tensor([repeats], name+'_rep', kwargs['initializer'])
+        create_tensor([1, repeats], name+'_repeats', kwargs['initializer'])
+        nodes += [
+            make_node('Shape', [input_nodes[0]], [name+'_shape']),
+            make_node('ReduceProd', [name+'_shape'], [name+'_size']),
+            make_node('Reshape', [input_nodes[0], name+'_size'], [name+'_flat']),
+            make_node('Unsqueeze', [name+'_flat', name+'_-1'], [name+'_unsqueeze']),
+            make_node('Tile', [name+'_unsqueeze', name+'_repeats'], [name+'_tile']),
+            make_node('Mul', [name+'_size', name+'_rep'], [name+'_new_size']),
+            make_node('Reshape', [name+'_tile', name+'_new_size'], [name], name=name)
+        ]
+    else:
+        axis = int(axis)
+        repeats -= 1
+        create_tensor([repeats], name+'_repeats', kwargs['initializer'])
+        create_tensor([1], name+'_1', kwargs['initializer'])
+        create_tensor([0], name+'_0', kwargs['initializer'])
+        create_tensor([axis], name+'_axis', kwargs['initializer'])
+        create_const_scalar_node(name+"_0_s", np.int64(0), kwargs)
+        create_const_scalar_node(name+"_1_s", np.int64(1), kwargs)
+        nodes += [
+            make_node('Shape', [input_nodes[0]], [name+'_shape']),
+            make_node('Shape', [name+'_shape'], [name+'_dim']),
+            make_node('Squeeze', [name+'_dim', name+'_0'], [name+'_dim_s']),
+            make_node('Range', [name+'_0_s', name+'_dim_s', name+'_1_s'], [name+'_range'])
+        ]
+        if axis < 0:
+            nodes += [
+                make_node('Add', [name+'_axis', name+'_dim'], [name+'_true_axis']),
+                make_node('Equal', [name+'_range', name+'_true_axis'], [name+'_one_hot'])
+                ]
+        else:
+            nodes += [
+                make_node('Equal', [name+'_range', name+'_axis'], [name+'_one_hot'])
+                ]
+        nodes += [
+            make_node('Cast', [name+'_one_hot'], [name+'_one_hot_int'], to=int(TensorProto.INT64)),
+            make_node('Mul', [name+'_repeats', name+'_one_hot_int'], [name+'_mul']),
+            make_node('Add', [name+'_mul', name+'_1'], [name+'_add']),
+            make_node('Concat', [name+'_1', name+'_add'], [name+'_repeats_tensor'], axis=0)
+            ]
+        if axis == -1:
+            nodes += [
+                make_node('Concat', [name+'_shape', name+'_1'], [name+'_unsqueeze_shape'], axis=0),
+                make_node('Reshape', [input_nodes[0], name+'_unsqueeze_shape'],
+                          [name+'_unsqueeze'])
+                ]
+        else:
+            create_tensor([axis+1], name+'_axis+1', kwargs['initializer'])
+            nodes += [
+                make_node('Unsqueeze', [input_nodes[0], name+'_axis+1'], [name+'_unsqueeze'])
+                ]
+        nodes += [
+            make_node('Tile', [name+'_unsqueeze', name+'_repeats_tensor'], [name+'_tile']),
+            make_node('Mul', [name+'_shape', name+'_add'], [name+'_new_shape']),
+            make_node('Reshape', [name+'_tile', name+'_new_shape'], [name], name=name)
+            ]
+
+    return nodes
+
+
+@mx_op.register('_contrib_box_nms', OPSET_VERSION)
+def convert_contrib_box_nms(node, **kwargs):
+    """Map MXNet's _contrib_box_nms operator to ONNX
+    """
+    from onnx.helper import make_node
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    dtype = input_dtypes[0]
+    #dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+
+    opset_version = kwargs['opset_version']
+    if opset_version < 11:
+        raise AttributeError('ONNX opset 11 or greater is required to export this operator')
+
+    overlap_thresh = float(attrs.get('overlap_thresh', '0.5'))
+    valid_thresh = float(attrs.get('valid_thresh', '0'))
+    topk = int(attrs.get('topk', '-1'))
+    coord_start = int(attrs.get('coord_start', '2'))
+    score_index = int(attrs.get('score_index', '1'))
+    id_index = int(attrs.get('id_index', '-1'))
+    force_suppress = attrs.get('force_suppress', 'True')
+    background_id = int(attrs.get('background_id', '-1'))
+    in_format = attrs.get('in_format', 'corner')
+    out_format = attrs.get('out_format', 'corner')
+
+    center_point_box = 0 if in_format == 'corner' else 1
+
+    if in_format != out_format:
+        raise NotImplementedError('box_nms does not currently support in_format != out_format')
+
+    if background_id != -1:
+        raise NotImplementedError('box_nms does not currently support background_id != -1')
+
+    if id_index != -1 or force_suppress == 'False':
+        logging.warning('box_nms: id_index != -1 and/or force_suppress == False detected. '
+                        'However, due to ONNX limitations, boxes of different categories will NOT '
+                        'be exempted from suppression. This might lead to different behavior than '
+                        'native MXNet')
+
+    create_tensor([coord_start], name+'_cs', kwargs['initializer'])
+    create_tensor([coord_start+4], name+'_cs_p4', kwargs['initializer'])
+    create_tensor([score_index], name+'_si', kwargs['initializer'])
+    create_tensor([score_index+1], name+'_si_p1', kwargs['initializer'])
+    create_tensor([topk], name+'_topk', kwargs['initializer'])
+    create_tensor([overlap_thresh], name+'_ot', kwargs['initializer'], dtype=np.float32)
+    create_tensor([valid_thresh], name+'_vt', kwargs['initializer'], dtype=np.float32)
+    create_tensor([-1], name+'_m1', kwargs['initializer'])
+    create_tensor([-1], name+'_m1_f', kwargs['initializer'], dtype=dtype)
+    create_tensor([0], name+'_0', kwargs['initializer'])
+    create_tensor([1], name+'_1', kwargs['initializer'])
+    create_tensor([2], name+'_2', kwargs['initializer'])
+    create_tensor([3], name+'_3', kwargs['initializer'])
+    create_tensor([0, 1, -1], name+'_scores_shape', kwargs['initializer'])
+    create_tensor([0, 0, 1, 0], name+'_pad', kwargs['initializer'])
+    create_tensor([0, -1], name+'_bat_spat_helper', kwargs['initializer'])
+    create_const_scalar_node(name+"_0_s", np.int64(0), kwargs)
+    create_const_scalar_node(name+"_1_s", np.int64(1), kwargs)
+
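+    # strategy: run NonMaxSuppression, scatter the surviving candidate indices
+    # into a dense per-position grid, sort with TopK, then gather from the
+    # padded candidate table so suppressed slots come back as -1 rows, as in
+    # native MXNet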
+    nodes = [
+        make_node('Shape', [input_nodes[0]], [name+'_shape']),
+        make_node('Shape', [name+'_shape'], [name+'_dim']),
+        make_node('Sub', [name+'_dim', name+'_2'], [name+'_dim_m2']),
+        make_node('Slice', [name+'_shape', name+'_dim_m2', name+'_dim'], [name+'_shape_last2']),
+        make_node('Concat', [name+'_m1', name+'_shape_last2'], [name+'_shape_3d'], axis=0),
+        make_node('Reshape', [input_nodes[0], name+'_shape_3d'], [name+'_data_3d']),
+        make_node('Slice', [name+'_data_3d', name+'_cs', name+'_cs_p4', name+'_m1'],
+                  [name+'_boxes']),
+        make_node('Slice', [name+'_data_3d', name+'_si', name+'_si_p1', name+'_m1'],
+                  [name+'_scores_raw']),
+        make_node('Reshape', [name+'_scores_raw', name+'_scores_shape'], [name+'_scores']),
+        make_node('Shape', [name+'_scores'], [name+'_scores_shape_actual']),
+        make_node('NonMaxSuppression',
+                  [name+'_boxes', name+'_scores', name+'_topk', name+'_ot', name+'_vt'],
+                  [name+'_nms'], center_point_box=center_point_box),
+        make_node('Slice', [name+'_nms', name+'_0', name+'_3', name+'_m1', name+'_2'],
+                  [name+'_nms_sliced']),
+        make_node('GatherND', [name+'_data_3d', name+'_nms_sliced'], [name+'_candidates']),
+        make_node('Pad', [name+'_candidates', name+'_pad', name+'_m1_f'], [name+'_cand_padded']),
+        make_node('Shape', [name+'_nms'], [name+'_nms_shape']),
+        make_node('Slice', [name+'_nms_shape', name+'_0', name+'_1'], [name+'_cand_cnt']),
+        make_node('Squeeze', [name+'_cand_cnt', name+'_0'], [name+'_cc_s']),
+        make_node('Range', [name+'_0_s', name+'_cc_s', name+'_1_s'], [name+'_cand_indices']),
+        make_node('Slice', [name+'_scores_shape_actual', name+'_0', name+'_3', name+'_m1',
+                            name+'_2'], [name+'_shape_bat_spat']),
+        make_node('Slice', [name+'_shape_bat_spat', name+'_1', name+'_2'], [name+'_spat_dim']),
+        make_node('Expand', [name+'_cand_cnt', name+'_shape_bat_spat'], [name+'_base_indices']),
+        make_node('ScatterND', [name+'_base_indices', name+'_nms_sliced', name+'_cand_indices'],
+                  [name+'_indices']),
+        make_node('TopK', [name+'_indices', name+'_spat_dim'], [name+'_indices_sorted', name+'__'],
+                  largest=0, axis=-1, sorted=1),
+        make_node('Gather', [name+'_cand_padded', name+'_indices_sorted'], [name+'_gather']),
+        make_node('Reshape', [name+'_gather', name+'_shape'], [name+'0'])
+    ]
+
+    return nodes
+
+
+@mx_op.register('_contrib_ROIAlign', OPSET_VERSION)
+def convert_contrib_roialign(node, **kwargs):
+    """Map MXNet's _contrib_ROIAlign
+    """
+    from onnx.helper import make_node
+    from onnx import TensorProto
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    pooled_size = convert_string_to_list(str(attrs.get('pooled_size')))
+    spatial_scale = float(attrs.get('spatial_scale'))
+    sample_ratio = int(attrs.get('sample_ratio', '0'))
+    position_sensitive = attrs.get('position_sensitive', 'False')
+    aligned = attrs.get('aligned', 'False')
+
+    if position_sensitive != 'False':
+        raise NotImplementedError('_contrib_ROIAlign does not currently support \
+                                   position_sensitive!=False')
+    if aligned != 'False':
+        raise NotImplementedError('_contrib_ROIAlign does not currently support \
+                                   aligned!=False')
+
+    create_tensor([0], name+'_0', kwargs['initializer'])
+    create_tensor([1], name+'_1', kwargs['initializer'])
+    create_tensor([5], name+'_5', kwargs['initializer'])
+
+    nodes = [
+        make_node('Slice', [input_nodes[1], name+'_1', name+'_5', name+'_1'], [name+'_rois']),
+        make_node('Slice', [input_nodes[1], name+'_0', name+'_1', name+'_1'], [name+'_inds__']),
+        make_node('Squeeze', [name+'_inds__', name+'_1'], [name+'_inds_']),
+        make_node('Cast', [name+'_inds_'], [name+'_inds'], to=int(TensorProto.INT64)),
+        make_node('RoiAlign', [input_nodes[0], name+'_rois', name+'_inds'], [name],
+                  mode='avg', output_height=pooled_size[0], output_width=pooled_size[1],
+                  sampling_ratio=sample_ratio, spatial_scale=spatial_scale)
+    ]
+
+    return nodes
+
+
+@mx_op.register("sum", OPSET_VERSION)
+def convert_sum(node, **kwargs):
+    """Map MXNet's sum operator attributes to onnx's ReduceSum operator
+    and return the created node.
+    """
+    from onnx.helper import make_node
+
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    mx_axis = attrs.get("axis", None)
+    axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
+
+    keepdims = get_boolean_attribute_value(attrs, "keepdims")
+
+    if axes:
+        create_tensor(axes, name+'_axes', kwargs['initializer'])
+        input_nodes.append(name+'_axes')
+        node = make_node(
+            'ReduceSum',
+            inputs=input_nodes,
+            outputs=[name],
+            keepdims=keepdims,
+            name=name
+        )
+        return [node]
+    else:
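+        # no axes given: ReduceSum collapses all dims; reshape the result to
+        # shape (1,) to match MXNet's output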
+        create_tensor([1], name+'_1', kwargs['initializer'])
+        nodes = [
+            onnx.helper.make_node(
+                'ReduceSum',
+                inputs=input_nodes,
+                outputs=[name+'_sum'],
+                keepdims=keepdims,
+            ),
+            make_node('Reshape', [name+'_sum', name+'_1'], [name], name=name),
+        ]
+    return nodes
+
+
+@mx_op.register("RNN", OPSET_VERSION)
+def convert_RNN(node, **kwargs):
+    """Map MXNet's RNN operator attributes to onnx's operators
+    and return the created node.
+    """
+    from onnx.helper import make_node
+    from onnx import TensorProto
+
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    bidirectional = str(attrs.get('bidirectional', 'False'))
+    if bidirectional != 'False':
+        raise NotImplementedError('Currently RNN onnx export only supports bidirectional equal to False')
+
+    num_layers = int(attrs.get('num_layers', '1'))
+
+    p = float(attrs.get('p', '0'))
+    if p != 0:
+        raise NotImplementedError('Currently RNN onnx export only supports p equal to 0')
+
+    use_sequence_length = str(attrs.get('use_sequence_length', 'False'))
+    if use_sequence_length != 'False':
+        raise NotImplementedError('Currently RNN onnx export only supports use_sequence_length equal to False')
+
+    projection_size = str(attrs.get('projection_size', 'None'))
+    if projection_size != 'None':
+        raise NotImplementedError('Currently RNN onnx export only supports projection_size equal to None')
+
+    state_outputs = str(attrs.get('state_outputs', 'False'))
+    if state_outputs != 'True':
+        raise NotImplementedError('Currently RNN onnx export only supports state_outputs equal to True')
+
+    state_size = int(attrs.get('state_size'))
+    data = input_nodes[0]
+    param = input_nodes[1]
+    initial_h = input_nodes[2]
+
+    nodes = []
+
+    mode = str(attrs.get('mode'))
+    create_tensor([1], name+'_1', kwargs['initializer'])
+    if mode == 'lstm':
+        initial_c = input_nodes[3]
+        if num_layers == 2:
+            create_tensor([0], name+'_0', kwargs['initializer'])
+            create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
+            create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
+            create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
+            create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+            create_tensor([4*4*state_size*state_size], name+'_WR_offset', kwargs['initializer'])
+
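+            # The flat MXNet parameter vector packs W0, R0, W1, R1 and then
+            # the two 8*state_size bias blocks; the layer-0 slicing assumes
+            # input_size == state_size. Each Split/Concat pair below reorders
+            # the gates from MXNet's (i, f, g, o) layout to ONNX's (i, o, f, c).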
+            nodes += [
+                make_node('Shape', [data], [name+'_data_shape']),
+                make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+
+                # Layer 0
+                # get W
+                make_node('Slice', [param, name+'_0', name+'_4*state_size^2'], [name+'_W0_1d']),
+                make_node('Split', [name+'_W0_1d'], [name+'_W00', name+'_W01', name+'_W02', name+'_W03']),
+                make_node('Concat', [name+'_W00', name+'_W03', name+'_W01', name+'_W02'], [name+'_W0_'], axis=0),
+                make_node('Reshape', [name+'_W0_', name+'_WR_shape'], [name+'_W0']),
+                # get R
+                make_node('Add', [name+'_4*state_size^2', name+'_4*state_size^2'], [name+'_R0_offset']),
+                make_node('Slice', [param, name+'_4*state_size^2', name+'_R0_offset'], [name+'_R0_1d']),
+                make_node('Split', [name+'_R0_1d'], [name+'_R00', name+'_R01', name+'_R02', name+'_R03']),
+                make_node('Concat', [name+'_R00', name+'_R03', name+'_R01', name+'_R02'], [name+'_R0_'], axis=0),
+                make_node('Reshape', [name+'_R0_', name+'_WR_shape'], [name+'_R0']),
+                # get B
+                make_node('Add', [name+'_WR_offset', name+'_8*state_size'], [name+'_B0_offset']),
+                make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']),
+                make_node('Split', [name+'_B0_1d'], [name+'_B00', name+'_B01', name+'_B02', name+'_B03',
+                                                     name+'_B04', name+'_B05', name+'_B06', name+'_B07']),
+                make_node('Concat', [name+'_B00', name+'_B03', name+'_B01', name+'_B02',
+                                     name+'_B04', name+'_B07', name+'_B05', name+'_B06'], [name+'_B0_'], axis=0),
+                make_node('Reshape', [name+'_B0_', name+'_B_shape'], [name+'_B0']),
+                # get initial states
+                make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0),
+                make_node('Split', [initial_c], [name+'_initial_c0', name+'_initial_c1'], axis=0),
+                # get seq_len
+                make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+                make_node('Cast', [name+'_seq_len_'], [name+'_seq_len'], to=int(TensorProto.INT32)),
+                # Layer 0 LSTM
+                make_node('LSTM', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len',
+                                   name+'_initial_h0', name+'_initial_c0'],
+                          [name+'_lstm0_out_', name+'_lstm0_h', name+'_lstm0_c'], hidden_size=state_size),
+                make_node('Squeeze', [name+'_lstm0_out_', name+'_1'], [name+'_lstm0_out']),
+
+                # Layer 1
+                # get W
+                make_node('Add', [name+'_R0_offset', name+'_4*state_size^2'], [name+'_W1_offset']),
+                make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']),
+                make_node('Split', [name+'_W1_1d'], [name+'_W10', name+'_W11', name+'_W12', name+'_W13']),
+                make_node('Concat', [name+'_W10', name+'_W13', name+'_W11', name+'_W12'], [name+'_W1_'], axis=0),
+                make_node('Reshape', [name+'_W1_', name+'_WR_shape'], [name+'_W1']),
+                # get R
+                make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']),
+                make_node('Split', [name+'_R1_1d'], [name+'_R10', name+'_R11', name+'_R12', name+'_R13']),
+                make_node('Concat', [name+'_R10', name+'_R13', name+'_R11', name+'_R12'], [name+'_R1_'], axis=0),
+                make_node('Reshape', [name+'_R1_', name+'_WR_shape'], [name+'_R1']),
+                # get B
+                make_node('Add', [name+'_B0_offset', name+'_8*state_size'], [name+'_B1_offset']),
+                make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']),
+                make_node('Split', [name+'_B1_1d'], [name+'_B10', name+'_B11', name+'_B12', name+'_B13',
+                                                     name+'_B14', name+'_B15', name+'_B16', name+'_B17']),
+                make_node('Concat', [name+'_B10', name+'_B13', name+'_B11', name+'_B12',
+                                     name+'_B14', name+'_B17', name+'_B15', name+'_B16'], [name+'_B1_'], axis=0),
+                make_node('Reshape', [name+'_B1_', name+'_B_shape'], [name+'_B1']),
+                # Layer 1 LSTM
+                make_node('LSTM', [name+'_lstm0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len',
+                                   name+'_initial_h1', name+'_initial_c1'],
+                          [name+'_lstm1_out_', name+'_lstm1_h', name+'_lstm1_c'], hidden_size=state_size),
+                make_node('Squeeze', [name+'_lstm1_out_', name+'_1'], [name]),
+                make_node('Concat', [name+'_lstm0_h', name+'_lstm1_h'], [name+'1'], axis=0),
+                make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0),
+            ]
+        elif num_layers == 1:
+            create_tensor([0], name+'_0', kwargs['initializer'])
+            create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
+            create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
+            create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
+            create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer'])
+            create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+
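+            # With one layer the input size is only known at runtime, so the
+            # W slice length (4*state_size*input_size) is computed in-graph
+            # with Mul; the gate reordering matches the two-layer case above.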
+            nodes += [
+                make_node('Shape', [data], [name+'_data_shape']),
+                make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+                # get W
+                make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']),
+                make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
+                make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']),
+                make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0),
+                make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
+                make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
+                # get R
+                make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']),
+                make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
+                make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']),
+                make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0),
+                make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
+                # get B
+                make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']),
+                make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
+                make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3',
+                                                    name+'_B4', name+'_B5', name+'_B6', name+'_B7']),
+                make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2',
+                                     name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0),
+                make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
+                # get seq_len
+                make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+                make_node('Cast', [name+'_seq_len_'], [name+'_seq_len'], to=int(TensorProto.INT32)),
+                # compute LSTM
+                make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c],
+                          [name+'0_', name+'1', name+'2'], hidden_size=state_size),
+                make_node('Squeeze', [name+'0_', name+'_1'], [name]),
+            ]
+        else:
+            raise NotImplementedError('Currently RNN onnx export only supports num_layers=1 or 2')
+
+    elif mode == 'gru':
+        if num_layers == 2:
+            create_tensor([0], name+'_0', kwargs['initializer'])
+            create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
+            create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer'])
+            create_tensor([1, 3*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
+            create_tensor([1, 6*state_size], name+'_B_shape', kwargs['initializer'])
+            create_tensor([4*3*state_size*state_size], name+'_WR_offset', kwargs['initializer'])
+
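+            # Same packing scheme as the two-layer LSTM case, but with three
+            # gates: Split/Concat reorders them from MXNet's (r, z, n) layout
+            # to ONNX's (z, r, h), and linear_before_reset=1 matches MXNet's
+            # cuDNN-style GRU, which applies the reset gate after the
+            # hidden-to-hidden linear transform.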
+            nodes += [
+                make_node('Shape', [data], [name+'_data_shape']),
+                make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+
+                # Layer 0
+                # get W
+                make_node('Slice', [param, name+'_0', name+'_3*state_size^2'], [name+'_W0_1d']),
+                make_node('Split', [name+'_W0_1d'], [name+'_W00', name+'_W01', name+'_W02']),
+                make_node('Concat', [name+'_W01', name+'_W00', name+'_W02'], [name+'_W0_'], axis=0),
+                make_node('Reshape', [name+'_W0_', name+'_WR_shape'], [name+'_W0']),
+                # get R
+                make_node('Add', [name+'_3*state_size^2', name+'_3*state_size^2'], [name+'_R0_offset']),
+                make_node('Slice', [param, name+'_3*state_size^2', name+'_R0_offset'], [name+'_R0_1d']),
+                make_node('Split', [name+'_R0_1d'], [name+'_R00', name+'_R01', name+'_R02']),
+                make_node('Concat', [name+'_R01', name+'_R00', name+'_R02'], [name+'_R0_'], axis=0),
+                make_node('Reshape', [name+'_R0_', name+'_WR_shape'], [name+'_R0']),
+                # get B
+                make_node('Add', [name+'_WR_offset', name+'_6*state_size'], [name+'_B0_offset']),
+                make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']),
+                make_node('Split', [name+'_B0_1d'], [name+'_B00', name+'_B01', name+'_B02',
+                                                     name+'_B03', name+'_B04', name+'_B05']),
+                make_node('Concat', [name+'_B01', name+'_B00', name+'_B02',
+                                     name+'_B04', name+'_B03', name+'_B05'], [name+'_B0_'], axis=0),
+                make_node('Reshape', [name+'_B0_', name+'_B_shape'], [name+'_B0']),
+                # get initial states
+                make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0),
+                # get seq_len
+                make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+                make_node('Cast', [name+'_seq_len_'], [name+'_seq_len'], to=int(TensorProto.INT32)),
+                # Layer 0 GRU
+                make_node('GRU', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len',
+                                  name+'_initial_h0'],
+                          [name+'_gru0_out_', name+'_gru0_h'], hidden_size=state_size, linear_before_reset=1),
+                make_node('Squeeze', [name+'_gru0_out_', name+'_1'], [name+'_gru0_out']),
+
+                # Layer 1
+                # get W
+                make_node('Add', [name+'_R0_offset', name+'_3*state_size^2'], [name+'_W1_offset']),
+                make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']),
+                make_node('Split', [name+'_W1_1d'], [name+'_W10', name+'_W11', name+'_W12']),
+                make_node('Concat', [name+'_W11', name+'_W10', name+'_W12'], [name+'_W1_'], axis=0),
+                make_node('Reshape', [name+'_W1_', name+'_WR_shape'], [name+'_W1']),
+                # get R
+                make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']),
+                make_node('Split', [name+'_R1_1d'], [name+'_R10', name+'_R11', name+'_R12']),
+                make_node('Concat', [name+'_R11', name+'_R10', name+'_R12'], [name+'_R1_'], axis=0),
+                make_node('Reshape', [name+'_R1_', name+'_WR_shape'], [name+'_R1']),
+                # get B
+                make_node('Add', [name+'_B0_offset', name+'_6*state_size'], [name+'_B1_offset']),
+                make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']),
+                make_node('Split', [name+'_B1_1d'], [name+'_B10', name+'_B11', name+'_B12',
+                                                     name+'_B13', name+'_B14', name+'_B15']),
+                make_node('Concat', [name+'_B11', name+'_B10', name+'_B12',
+                                     name+'_B14', name+'_B13', name+'_B15'], [name+'_B1_'], axis=0),
+                make_node('Reshape', [name+'_B1_', name+'_B_shape'], [name+'_B1']),
+                # Layer 1 GRU
+                make_node('GRU', [name+'_gru0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len',
+                                  name+'_initial_h1'],
+                          [name+'_gru1_out_', name+'_gru1_h'], hidden_size=state_size, linear_before_reset=1),
+                make_node('Squeeze', [name+'_gru1_out_', name+'_1'], [name]),
+                make_node('Concat', [name+'_gru0_h', name+'_gru1_h'], [name+'1'], axis=0)
+            ]
+
+        elif num_layers == 1:
+            create_tensor([0], name+'_0', kwargs['initializer'])
+            create_tensor([3*state_size], name+'_3*state_size', kwargs['initializer'])
+            create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
+            create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer'])
+            create_tensor([1, 3*state_size, state_size], name+'_R_shape', kwargs['initializer'])
+            create_tensor([1, 6*state_size], name+'_B_shape', kwargs['initializer'])
+
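+            # Mirrors the single-layer LSTM path: the W slice length
+            # (3*state_size*input_size) is computed in-graph from the runtime
+            # input size, with the same (r, z, n) -> (z, r, h) gate reordering
+            # as above.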
+            nodes += [
+                make_node('Shape', [data], [name+'_data_shape']),
+                make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+                # get W
+                make_node('Mul', [name+'_3*state_size', name+'_input_size'], [name+'_mul0']),
+                make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
+                make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2']),
+                make_node('Concat', [name+'_W1', name+'_W0', name+'_W2'], [name+'_W_'], axis=0),
+                make_node('Concat', [name+'_1', name+'_3*state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
+                make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
+                # get R
+                make_node('Add', [name+'_mul0', name+'_3*state_size^2'], [name+'_add0']),
+                make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
+                make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2']),
+                make_node('Concat', [name+'_R1', name+'_R0', name+'_R2'], [name+'_R_'], axis=0),
+                make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
+                # get B
+                make_node('Add', [name+'_add0', name+'_6*state_size'], [name+'_add1']),
+                make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
+                make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2',
+                                                    name+'_B3', name+'_B4', name+'_B5']),
+                make_node('Concat', [name+'_B1', name+'_B0', name+'_B2',
+                                     name+'_B4', name+'_B3', name+'_B5'], [name+'_B_'], axis=0),
+                make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
+                # get seq_len
+                make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+                make_node('Cast', [name+'_seq_len_'], [name+'_seq_len'], to=int(TensorProto.INT32)),
+                # compute GRU
+                make_node('GRU', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h],
+                          [name+'0_', name+'1'], hidden_size=state_size, linear_before_reset=1),
+                make_node('Squeeze', [name+'0_', name+'_1'], [name]),
+            ]
+        else:
+            raise NotImplementedError('Currently RNN onnx export only supports num_layers=1 or 2')
+
+    else:
+        raise NotImplementedError(f"Currently RNN onnx export does not support {mode} mode")
+    return nodes
+
+
+@mx_op.register('SliceChannel', OPSET_VERSION)
+def convert_slice_channel(node, **kwargs):
+    """Map MXNet's SliceChannel operator attributes to onnx's Squeeze or Split
+    operator based on squeeze_axis attribute
+    and return the created node.
+    """
+    from onnx.helper import make_node
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    num_outputs = int(attrs.get('num_outputs'))
+    axis = int(attrs.get('axis', 1))
+    squeeze_axis = attrs.get('squeeze_axis', 'False')
+
+    create_tensor([axis], name+'_axis', kwargs['initializer'])
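+    # The '_axis' initializer feeds Squeeze's axes input below; since opset 13
+    # the axes are an input tensor rather than an attribute.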
+
+    nodes = []
+    if squeeze_axis in ['True', '1']:
+        nodes += [
+            make_node('Split', [input_nodes[0]], [name+str(i)+'_' for i in range(num_outputs)],
+                      axis=axis)
+        ]
+        for i in range(num_outputs):
+            nodes += [
+                make_node('Squeeze', [name+str(i)+'_', name+'_axis'], [name+str(i)])
+            ]
+    else:
+        nodes += [
+            make_node('Split', [input_nodes[0]], [name+str(i) for i in range(num_outputs)],
+                      axis=axis)
+        ]
+
+    return nodes
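+
+# Hedged usage sketch (illustrative only): mx.sym.split(data, num_outputs=2,
+# axis=1, squeeze_axis=True) on a (2, 2, 4) input yields two (2, 4) outputs;
+# the converter above emits one ONNX Split plus one Squeeze per output.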
diff --git a/tests/nightly/JenkinsfileForBinaries b/tests/nightly/JenkinsfileForBinaries
index 939cb07..e44ad41 100755
--- a/tests/nightly/JenkinsfileForBinaries
+++ b/tests/nightly/JenkinsfileForBinaries
@@ -94,9 +94,17 @@ core_logic: {
         }
       }
     },
+    'ONNX-operator: CPU': {
+      node(NODE_LINUX_CPU) {
+        ws('workspace/onnx-operator-test-cpu') {
+          utils.unpack_and_init('cpu_int64', mx_cmake_lib)
+          utils.docker_run('ubuntu_nightly_cpu', 'nightly_onnx_operator_tests', false)
+        }
+      }
+    },
     'ONNX-CV-batch1: CPU': {
       node(NODE_LINUX_CPU) {
-        ws('workspace/onnx-cv-test-cpu') {
+        ws('workspace/onnx-cv-batch1-test-cpu') {
           utils.unpack_and_init('cpu_int64', mx_cmake_lib)
           utils.docker_run('ubuntu_nightly_cpu', 'nightly_onnx_cv_batch1_tests', false)
         }
@@ -104,7 +112,7 @@ core_logic: {
     },
     'ONNX-CV-batch2: CPU': {
       node(NODE_LINUX_CPU) {
-        ws('workspace/onnx-cv-test-cpu') {
+        ws('workspace/onnx-cv-batch2-test-cpu') {
           utils.unpack_and_init('cpu_int64', mx_cmake_lib)
           utils.docker_run('ubuntu_nightly_cpu', 'nightly_onnx_cv_batch2_tests', false)
         }
diff --git a/tools/license_header.py b/tools/license_header.py
index 71b2811..0e42f79 100755
--- a/tools/license_header.py
+++ b/tools/license_header.py
@@ -125,6 +125,8 @@ _WHITE_LIST = [
                # Dual-Licensed under Apache 2.0 and Nvidia BSD-3
                'python/mxnet/onnx/mx2onnx/_export_onnx.py',
                'python/mxnet/onnx/mx2onnx/_op_translations.py',
+               'python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset11.py',
+               'python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py',
 
                # Github template
                '.github/ISSUE_TEMPLATE/bug_report.md',