You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/06 04:54:24 UTC

[GitHub] spidyDev closed pull request #10001: Add op avg pool arg max min

spidyDev closed pull request #10001: Add op avg pool arg max min
URL: https://github.com/apache/incubator-mxnet/pull/10001
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/__init__.py
index 36ee21305bf..63cd8ce2664 100644
--- a/python/mxnet/contrib/__init__.py
+++ b/python/mxnet/contrib/__init__.py
@@ -28,5 +28,5 @@
 from . import tensorboard
 
 from . import text
-
+from . import onnx
 from . import io
diff --git a/python/mxnet/contrib/onnx/__init__.py b/python/mxnet/contrib/onnx/__init__.py
new file mode 100644
index 00000000000..eff91206298
--- /dev/null
+++ b/python/mxnet/contrib/onnx/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Module for importing and exporting ONNX models."""
+
+from ._import.import_model import import_model
diff --git a/python/mxnet/contrib/onnx/_import/__init__.py b/python/mxnet/contrib/onnx/_import/__init__.py
new file mode 100644
index 00000000000..002cfa92583
--- /dev/null
+++ b/python/mxnet/contrib/onnx/_import/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""ONNX Import module"""
+from . import import_model
+from . import import_onnx
diff --git a/python/mxnet/contrib/onnx/_import/import_helper.py b/python/mxnet/contrib/onnx/_import/import_helper.py
new file mode 100644
index 00000000000..2f964bac534
--- /dev/null
+++ b/python/mxnet/contrib/onnx/_import/import_helper.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=invalid-name
+"""Operator attributes conversion"""
+from .op_translations import add, absolute, negative, reduce_max, reshape
+from .op_translations import reduce_mean, avg_pooling
+from .op_translations import sigmoid
+from .op_translations import argmax, argmin
+
+# _convert_map: ONNX operator type -> converter callable.
+# Every converter is called as converter(op_name, attrs, inputs) and returns
+# (mxnet_op_name_or_built_symbol, new_attrs, inputs).
+_convert_map = {
+    # Arithmetic operators
+    'Abs': absolute,
+    'Add': add,
+    'Neg': negative,
+    # Basic neural network functions
+    'Sigmoid': sigmoid,
+    # Pooling
+    'AveragePool': avg_pooling,
+    # Changing shape and type
+    'Reshape': reshape,
+    # Reduce functions
+    'ReduceMax': reduce_max,
+    'ReduceMean': reduce_mean,
+    # Sorting and searching
+    'ArgMax': argmax,
+    'ArgMin': argmin,
+}
diff --git a/python/mxnet/contrib/onnx/_import/import_model.py b/python/mxnet/contrib/onnx/_import/import_model.py
new file mode 100644
index 00000000000..d35a27929cd
--- /dev/null
+++ b/python/mxnet/contrib/onnx/_import/import_model.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""import function"""
+# pylint: disable=no-member
+from .import_onnx import GraphProto
+
+
+def import_model(model_file):
+    """Imports the supplied ONNX model file into MXNet symbol and parameters.
+
+    The ``onnx`` package is imported inside this function rather than at
+    module import time. ``mxnet.contrib`` imports ``mxnet.contrib.onnx``,
+    which imports this module, so an eager import would make
+    ``mxnet.contrib`` unusable whenever onnx is not installed.
+
+    Parameters
+    ----------
+    model_file : str
+        ONNX model file name
+
+    Returns
+    -------
+    sym : mx.symbol
+        Compatible mxnet symbol
+
+    params : dict of str to mx.ndarray
+        Dict of converted parameters stored in mx.ndarray format
+
+    Raises
+    ------
+    ImportError
+        If onnx (and its protobuf dependency) is not installed.
+    """
+    try:
+        import onnx
+    except ImportError:
+        raise ImportError("Onnx and protobuf need to be installed")
+
+    graph = GraphProto()
+
+    # loads model file and returns ONNX protobuf object
+    model_proto = onnx.load(model_file)
+    sym, params = graph.from_onnx(model_proto.graph)
+    return sym, params
diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py b/python/mxnet/contrib/onnx/_import/import_onnx.py
new file mode 100644
index 00000000000..d35172b0649
--- /dev/null
+++ b/python/mxnet/contrib/onnx/_import/import_onnx.py
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=invalid-name,too-many-locals,no-self-use
+""" Support import export formats."""
+from __future__ import absolute_import as _abs
+from .... import symbol
+from .... import ndarray as nd
+from .import_helper import _convert_map
+
+def _convert_operator(node_name, op_name, attrs, inputs, convert_map=None):
+    """Convert an ONNX operator into an MXNet symbol.
+
+    The converter registered for ``op_name`` translates the operator name and
+    attributes. If it returns an operator name, that operator is looked up in
+    ``mxnet.symbol`` and instantiated. If it returns an already-built symbol,
+    that symbol is used as-is.
+
+    Parameters
+    ----------
+    node_name : str or None
+        Name for the created symbol; None when the ONNX node is unnamed.
+    op_name : str
+        ONNX operator type, such as Add or AveragePool.
+    attrs : dict
+        Dict of ONNX operator attributes.
+    inputs : list
+        List of input symbols of the operator.
+    convert_map : dict, optional
+        Dict of ONNX op name -> callable taking (op_name, attrs, inputs) and
+        returning (new_op_name_or_symbol, new_attrs, inputs). Falls back to
+        the module's ``_convert_map`` when not given (or empty).
+
+    Returns
+    -------
+    Symbol
+        The converted mxnet operator.
+
+    Raises
+    ------
+    NotImplementedError
+        If no converter is registered for ``op_name``.
+    RuntimeError
+        If the converter names an operator missing from ``mxnet.symbol``.
+    """
+    convert_map = convert_map if convert_map else _convert_map
+    if op_name not in convert_map:
+        raise NotImplementedError("Operator {} not implemented.".format(op_name))
+    op_name, new_attrs, inputs = convert_map[op_name](op_name, attrs, inputs)
+
+    if not isinstance(op_name, str):
+        # The converter already built the symbol (e.g. pooling with explicit padding).
+        return op_name
+
+    new_op = getattr(symbol, op_name, None)
+    # Validate the lookup before calling it. Otherwise an unknown name would
+    # surface as a TypeError on None instead of this message.
+    if new_op is None:
+        raise RuntimeError("Unable to map op_name {} to sym".format(op_name))
+    return new_op(name=node_name, *inputs, **new_attrs)
+
+
+class GraphProto(object): # pylint: disable=too-few-public-methods
+    """Converts an ONNX GraphProto into an MXNet symbol and parameters.
+    Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
+    """
+    def __init__(self):
+        # Symbols by name: renamed graph inputs/params and every node output.
+        self._nodes = {}
+        # nd.array parameters; renamed to 'param_N' once matched in graph.input.
+        self._params = {}
+        # Original ONNX input name -> 'input_N' / 'param_N'.
+        self._renames = {}
+        self._num_input = 0
+        self._num_param = 0
+
+    def from_onnx(self, graph):
+        """Construct symbol from onnx graph.
+
+        Input names in an onnx graph carry no meaning (they may be just
+        "1", "2", ...). For convenience, real inputs are renamed to
+        "input_0", "input_1", ... and initializer-backed inputs to
+        "param_0", "param_1", ...
+
+        Parameters
+        ----------
+        graph : onnx protobuf object
+            The loaded onnx graph
+
+        Returns
+        -------
+        sym : symbol.Symbol
+            The returned mxnet symbol (a symbol.Group when the graph has
+            more than one output)
+        params : dict
+            A dict of name: nd.array pairs, used as pretrained weights
+        """
+        # parse initializers: the parameters stored in the graph
+        for init_tensor in graph.initializer:
+            if not init_tensor.name.strip():
+                raise ValueError("Tensor's name is required.")
+            self._params[init_tensor.name] = self._parse_array(init_tensor)
+
+        # converting GraphProto inputs into mxnet Variables
+        for i in graph.input:
+            if i.name in self._params:
+                # i is a param instead of input
+                name_param = 'param_{}'.format(self._num_param)
+                self._num_param += 1
+                self._params[name_param] = self._params.pop(i.name)
+                self._nodes[name_param] = symbol.Variable(name=name_param,
+                                                          shape=self._params[name_param].shape)
+                self._renames[i.name] = name_param
+            else:
+                name_input = 'input_{}'.format(self._num_input)
+                self._num_input += 1
+                self._nodes[name_input] = symbol.Variable(name=name_input)
+                self._renames[i.name] = name_input
+
+        # constructing nodes, nodes are stored as directed acyclic graph
+        # converting NodeProto message
+        for node in graph.node:
+            op_name = node.op_type
+            # unnamed nodes are passed on as None
+            node_name = node.name.strip() or None
+            onnx_attr = self._parse_attr(node.attribute)
+            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
+            op = _convert_operator(node_name, op_name, onnx_attr, inputs)
+
+            assert len(node.output) == len(op.list_outputs()), (
+                "Output dimension mismatch between the onnx operator and the mxnet symbol " +
+                "{} vs {} for the operator - {}.".format(
+                    len(node.output), len(op.list_outputs()), op_name))
+            # register each mxnet output under its ONNX output name
+            for idx, output_name in enumerate(node.output):
+                self._nodes[output_name] = op[idx]
+        # now return the outputs
+        out = [self._nodes[i.name] for i in graph.output]
+        if len(out) > 1:
+            out = symbol.Group(out)
+        else:
+            out = out[0]
+        return out, self._params
+
+    def _parse_array(self, tensor_proto):
+        """Grab data in TensorProto and convert it to an mxnet nd.array."""
+        try:
+            from onnx.numpy_helper import to_array
+        except ImportError as e:
+            raise ImportError("Unable to import onnx which is required {}".format(e))
+        np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
+        return nd.array(np_array)
+
+    def _parse_attr(self, attr_proto):
+        """Convert a list of AttributeProto to a dict, with names as keys.
+
+        Scalars (f/i/s) and tensors/graphs (t/g) are stored as-is; repeated
+        fields (floats/ints/strings) become tuples.
+
+        Raises
+        ------
+        NotImplementedError
+            For repeated tensor or graph attributes.
+        ValueError
+            When no supported field is set on an attribute.
+        """
+        attrs = {}
+        for a in attr_proto:
+            for f in ['f', 'i', 's']:
+                if a.HasField(f):
+                    attrs[a.name] = getattr(a, f)
+            for f in ['floats', 'ints', 'strings']:
+                if list(getattr(a, f)):
+                    assert a.name not in attrs, "Only one type of attr is allowed"
+                    attrs[a.name] = tuple(getattr(a, f))
+            for f in ['t', 'g']:
+                if a.HasField(f):
+                    attrs[a.name] = getattr(a, f)
+            for f in ['tensors', 'graphs']:
+                if list(getattr(a, f)):
+                    raise NotImplementedError("Field {} is not supported in mxnet.".format(f))
+            if a.name not in attrs:
+                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
+        return attrs
diff --git a/python/mxnet/contrib/onnx/_import/op_translations.py b/python/mxnet/contrib/onnx/_import/op_translations.py
new file mode 100644
index 00000000000..19c0e7a3dec
--- /dev/null
+++ b/python/mxnet/contrib/onnx/_import/op_translations.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+""" Module for translating ONNX operators into Mxnet operatoes"""
+# pylint: disable=unused-argument,protected-access
+from . import translation_utils
+
+
+# Arithmetic Operations
+def add(op_name, attrs, inputs):
+    """Element-wise sum of two tensors (ONNX ``Add``).
+
+    ``broadcast == 1`` selects ``broadcast_add``; anything else selects
+    ``elemwise_add``. No ONNX attributes are forwarded to MXNet.
+    """
+    mx_op = 'broadcast_add' if attrs.get('broadcast') == 1 else 'elemwise_add'
+    return mx_op, {}, inputs
+
+
+def absolute(op_name, attrs, inputs):
+    """Element-wise absolute value; ONNX ``Abs`` maps to MXNet ``abs``."""
+    mx_op_name = 'abs'
+    return mx_op_name, attrs, inputs
+
+
+def argmax(op_name, attrs, inputs):
+    """Indices of the maximum values along an axis (ONNX ``ArgMax``).
+
+    ONNX defaults ``axis`` to 0 and ``keepdims`` to 1, whereas MXNet's
+    ``argmax`` defaults to a flattened argmax without kept dimensions.
+    The ONNX defaults are therefore filled in when the node omits them.
+    The caller's attrs dict is not modified.
+    """
+    new_attrs = dict(attrs)
+    new_attrs.setdefault('axis', 0)
+    new_attrs.setdefault('keepdims', 1)
+    return 'argmax', new_attrs, inputs
+
+
+def argmin(op_name, attrs, inputs):
+    """Indices of the minimum values along an axis (ONNX ``ArgMin``).
+
+    ONNX defaults ``axis`` to 0 and ``keepdims`` to 1, whereas MXNet's
+    ``argmin`` defaults to a flattened argmin without kept dimensions.
+    The ONNX defaults are therefore filled in when the node omits them.
+    The caller's attrs dict is not modified.
+    """
+    new_attrs = dict(attrs)
+    new_attrs.setdefault('axis', 0)
+    new_attrs.setdefault('keepdims', 1)
+    return 'argmin', new_attrs, inputs
+
+
+def avg_pooling(op_name, attrs, inputs):
+    """Average pooling (ONNX ``AveragePool``).
+
+    ONNX attribute names are mapped to MXNet's (``kernel_shape`` -> ``kernel``,
+    ``strides`` -> ``stride``, ``pads`` -> ``pad``). The pooling symbol is
+    built by ``translation_utils._fix_pooling``, which inserts an explicit pad
+    operator because ONNX allows asymmetric padding. The built symbol is
+    returned in place of an operator name. The pool type is derived there
+    from ``op_name``, so no constant attributes need to be injected here.
+    """
+    new_attrs = translation_utils._fix_attribute_names(attrs,
+                                                       {'kernel_shape': 'kernel',
+                                                        'strides': 'stride',
+                                                        'pads': 'pad'})
+    # NOTE(review): the explicitly padded zeros are counted in each window's
+    # average, while ONNX AveragePool (opset 1) divides by the element count
+    # excluding pads -- confirm results for padded inputs.
+    op = translation_utils._fix_pooling(op_name, inputs, new_attrs)
+    return op, new_attrs, inputs
+
+
+def negative(op_name, attrs, inputs):
+    """Element-wise negation; ONNX ``Neg`` maps to MXNet ``negative``."""
+    translated = "negative"
+    return translated, attrs, inputs
+
+
+# Basic neural network functions
+def sigmoid(op_name, attrs, inputs):
+    """Element-wise logistic sigmoid; ONNX ``Sigmoid`` maps to MXNet ``sigmoid``."""
+    mxnet_name = "sigmoid"
+    return mxnet_name, attrs, inputs
+
+
+# Changing shape and type.
+def reshape(op_name, attrs, inputs):
+    """Reshape the input according to its ``shape`` attribute.
+
+    The ONNX attributes are forwarded unchanged to MXNet ``reshape``.
+    """
+    target = "reshape"
+    return target, attrs, inputs
+
+
+# Reduce Functions
+def _reduce_attrs(attrs):
+    """Translate ONNX Reduce* attributes into MXNet reduction attributes.
+
+    ``axes`` becomes ``axis``. It is omitted when absent, so MXNet reduces
+    over all axes, which is also the ONNX behaviour. ``keepdims`` defaults
+    to 1 as in ONNX, since MXNet's own default drops the reduced axes.
+    """
+    new_attrs = {}
+    if 'axes' in attrs:
+        new_attrs['axis'] = attrs['axes']
+    new_attrs['keepdims'] = attrs.get('keepdims', 1)
+    return new_attrs
+
+
+def reduce_max(op_name, attrs, inputs):
+    """Reduce the array along the given axes by maximum value"""
+    return 'max', _reduce_attrs(attrs), inputs
+
+
+def reduce_mean(op_name, attrs, inputs):
+    """Reduce the array along the given axes by mean value"""
+    return 'mean', _reduce_attrs(attrs), inputs
diff --git a/python/mxnet/contrib/onnx/_import/translation_utils.py b/python/mxnet/contrib/onnx/_import/translation_utils.py
new file mode 100644
index 00000000000..87d53ba398d
--- /dev/null
+++ b/python/mxnet/contrib/onnx/_import/translation_utils.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Utilities used for translating operators from Onnx to Mxnet."""
+# pylint: disable=
+from __future__ import absolute_import as _abs
+from .... import symbol
+
+
+def _fix_attribute_names(attrs, change_map):
+    """
+    Rename attributes as per the values in the change_map dictionary.
+
+    Attributes whose names are not in change_map are passed through
+    unchanged. Entries of change_map without a matching attribute are
+    ignored; they are never inserted into the result.
+
+    Parameters
+    ----------
+    attrs : dict
+        Dict of operator attributes
+    change_map : dict
+        Dict of onnx attribute name to mxnet attribute names.
+
+    Returns
+    -------
+    new_attr : dict
+        Converted dict of operator attributes.
+    """
+    new_attr = {}
+    for name, value in attrs.items():
+        # Rename known ONNX attributes, keep everything else as-is.
+        new_attr[change_map.get(name, name)] = value
+    return new_attr
+
+
+def _pad_sequence_fix(attr, kernelDim=None):
+    """Reorder onnx's pads sequence to match mxnet's pad_width.
+
+    mxnet: (x1_begin, x1_end, ... , xn_begin, xn_end)
+    onnx:  (x1_begin, x2_begin, ... , xn_begin, x1_end, x2_end, ... , xn_end)
+
+    Parameters
+    ----------
+    attr : sequence or None
+        ONNX pads; None or empty means no padding.
+    kernelDim : int, optional
+        Number of spatial axes. When given, the result is zero-filled up to
+        ``2 * kernelDim`` entries.
+
+    Returns
+    -------
+    tuple
+        Pads in mxnet's interleaved (begin, end) order.
+
+    Raises
+    ------
+    ValueError
+        If ``attr`` has an odd number of entries.
+    """
+    attr = tuple(attr) if attr else ()
+    if len(attr) % 2 != 0:
+        raise ValueError("pads must contain an even number of values, got {}".format(attr))
+    half = len(attr) // 2
+    # Pair each begin value with its matching end value.
+    new_attr = tuple(v for pair in zip(attr[:half], attr[half:]) for v in pair)
+    # Making sure pad values are in the attr for all axes.
+    if kernelDim is not None:
+        missing = kernelDim * 2 - len(new_attr)
+        if missing > 0:
+            new_attr += (0,) * missing
+    return new_attr
+
+
+def _fix_pooling(op_name, inputs, new_attr):
+    """Build an mxnet pooling symbol for an onnx pooling operator.
+
+    onnx pooling supports asymmetrical padding, which mxnet's Pooling does
+    not, so padding is applied by a separate pad operator before pooling.
+
+    Parameters
+    ----------
+    op_name : str
+        onnx operator type; 'AveragePool' gives average pooling, anything
+        else max pooling.
+    inputs : list
+        Input symbols; only the first one is pooled.
+    new_attr : dict
+        Pooling attributes already renamed to mxnet names
+        ('kernel', 'stride', 'pad').
+
+    Returns
+    -------
+    Symbol
+        The pooling symbol.
+
+    Raises
+    ------
+    ValueError
+        If no 'kernel' attribute is present.
+    """
+    pool_type = 'avg' if op_name == 'AveragePool' else 'max'
+    kernel = new_attr.get('kernel')
+    if not kernel:
+        raise ValueError("{} requires the 'kernel_shape' attribute.".format(op_name))
+    stride = new_attr.get('stride')
+    # A missing 'pads' attribute means no padding at all.
+    padding = new_attr.get('pad') or ()
+    spatial_pads = _pad_sequence_fix(padding, len(kernel))
+
+    data = inputs[0]
+    if any(spatial_pads):
+        # (0, 0, 0, 0) leaves the first two axes unpadded; presumably batch
+        # and channel of an NC... layout -- TODO confirm for all callers.
+        pad_width = (0, 0, 0, 0) + spatial_pads
+        data = symbol.pad(data, mode='constant', pad_width=pad_width)
+
+    pool_kwargs = {'pool_type': pool_type, 'kernel': kernel}
+    # Only forward stride when the model specifies it.
+    if stride is not None:
+        pool_kwargs['stride'] = stride
+    return symbol.Pooling(data, **pool_kwargs)
diff --git a/tests/ci_build/install/install_testdeps.sh b/tests/ci_build/install/install_testdeps.sh
index 26da186dd27..d127adb3070 100644
--- a/tests/ci_build/install/install_testdeps.sh
+++ b/tests/ci_build/install/install_testdeps.sh
@@ -19,7 +19,7 @@
 set -e
 set -x
 
-pip install cpplint 'pylint==1.4.4' 'astroid==1.3.6'
+pip install cpplint 'pylint==1.4.4' 'astroid==1.3.6' 'onnx==1.0.1' 'protobuf==3.0.0'
 
 pip3 install nose
 ln -s -f /opt/bin/nosetests /usr/local/bin/nosetests3
diff --git a/tests/python/onnx_test_utils/backend.py b/tests/python/onnx_test_utils/backend.py
new file mode 100644
index 00000000000..e18036e54e1
--- /dev/null
+++ b/tests/python/onnx_test_utils/backend.py
@@ -0,0 +1,187 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=too-many-locals,invalid-name,no-member
+"""backend wrapper for onnx test infrastructure"""
+from collections import namedtuple
+import mxnet as mx
+from mxnet.contrib.onnx._import.import_onnx import GraphProto
+try:
+    from onnx import helper, TensorProto
+    from onnx.backend.base import Backend
+except ImportError:
+    raise ImportError("Onnx and protobuf need to be installed")
+from backend_rep import MXNetBackendRep
+
+# Using these functions for onnx test infrastructure.
+# Implemented by following onnx docs guide:
+# https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md
+# MXNetBackend class will take an ONNX model with inputs, perform a computation,
+# and then return the output.
+
+
+class MXNetBackend(Backend):
+    """MXNet backend for ONNX"""
+
+    @staticmethod
+    def make_graph(node, inputs):
+        """Create an ONNX GraphProto holding just ``node``.
+
+        Every node input and output is declared as a FLOAT tensor of shape
+        [1]. An input named 'W' is additionally stored as an initializer
+        built from the corresponding entry of ``inputs``.
+        """
+        initializer = []
+        tensor_input_info = []
+        tensor_output_info = []
+
+        # Adding input tensor info.
+        for index, input_name in enumerate(node.input):
+            tensor_input_info.append(
+                helper.make_tensor_value_info(str(input_name), TensorProto.FLOAT, [1]))
+
+            # Creating an initializer for Weight params.
+            # Assumes that weight params is named as 'W'.
+            # TODO: Handle multiple weight params.
+            # TODO: Add for "bias" if needed
+            if input_name == 'W':
+                param_tensor = helper.make_tensor(
+                    name=input_name,
+                    data_type=TensorProto.FLOAT,
+                    dims=inputs[index].shape,
+                    vals=inputs[index].flatten())
+                initializer.append(param_tensor)
+
+        # Adding output tensor info.
+        for output_name in node.output:
+            tensor_output_info.append(
+                helper.make_tensor_value_info(str(output_name), TensorProto.FLOAT, [1]))
+
+        # creating graph proto object.
+        return helper.make_graph(
+            [node],
+            "test",
+            tensor_input_info,
+            tensor_output_info,
+            initializer=initializer)
+
+    @classmethod
+    def run_node(cls, node, inputs, device='CPU'):
+        """Running individual node inference on mxnet engine and
+        return the result to onnx test infrastructure.
+
+        Parameters
+        ----------
+        node   : onnx node object
+            loaded onnx node (individual layer)
+        inputs : list of numpy arrays
+            inputs to run a node on
+        device : 'CPU'
+            device to run a node on
+
+        Returns
+        -------
+        params : numpy array
+            result obtained after running the operator (wrapped in a list
+            for the op types listed in ``dim_change_op_types``)
+        """
+        graph = GraphProto()
+        sym, params = graph.from_onnx(MXNetBackend.make_graph(node, inputs))
+        # GraphProto names the non-parameter graph inputs 'input_0', 'input_1', ...
+        # Match on the prefix so multi-digit names such as 'input_10' are included.
+        data_names = [i for i in sym.get_internals().list_inputs() if i.startswith("input_")]
+        data_shapes = []
+        dim_change_op_types = set(['ReduceMin', 'ReduceMax', 'ReduceMean',
+                                   'ReduceProd', 'ReduceSum', 'Slice', 'Pad',
+                                   'Squeeze', 'Upsample', 'Reshape', 'Conv'])
+
+        # Adding extra dimension of batch_size 1 if the batch_size is different for multiple inputs.
+        # NOTE(review): inputs[idx] assumes the data inputs come first in ``inputs``
+        # (params such as 'W' are not part of data_names) -- confirm for weighted ops.
+        for idx, input_name in enumerate(data_names):
+            batch_size = 1
+            if len(inputs[idx].shape) < 4 and len(inputs) > 1 and \
+                            len(set(x.shape[0] for x in inputs)) != 1:
+                new_shape = (batch_size,) + tuple(inputs[idx].shape)
+                data_shapes.append((input_name, new_shape))
+            else:
+                data_shapes.append((input_name, inputs[idx].shape))
+
+        # create module, passing cpu context
+        if device == 'CPU':
+            ctx = mx.cpu()
+        else:
+            raise NotImplementedError("Only CPU context is supported for now")
+
+        # create a module
+        mod = mx.mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None)
+        mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
+
+        # initializing parameters for calculating result of each individual node
+        if params:
+            mod.set_params(arg_params=params, aux_params=params)
+        else:
+            mod.init_params()
+
+        batch = namedtuple('Batch', ['data'])
+
+        data_forward = []
+        for idx in range(len(data_names)):
+            # slice and pad operator tests needs 1 less dimension in forward pass
+            # otherwise it will throw an error.
+            # for squeeze operator, need to retain shape of input as provided
+            val = inputs[idx]
+            if node.op_type in dim_change_op_types:
+                data_forward.append(mx.nd.array(val))
+            else:
+                data_forward.append(mx.nd.array([val]))
+
+        mod.forward(batch(data_forward))
+        result = mod.get_outputs()[0].asnumpy()
+        if node.op_type in dim_change_op_types:
+            return [result]
+        return result
+
+    @classmethod
+    def prepare(cls, model, device='CPU', **kwargs):
+        """For running end to end model(used for onnx test backend)
+
+        Parameters
+        ----------
+        model  : onnx ModelProto object
+            loaded onnx graph
+        device : 'CPU'
+            specifying device to run test on
+        kwargs :
+            other arguments
+
+        Returns
+        -------
+        MXNetBackendRep : object
+            Returns object of MXNetBackendRep class which will be in turn
+            used to run inference on the input model and return the result for comparison.
+        """
+        graph = GraphProto()
+        sym, params = graph.from_onnx(model.graph)
+        return MXNetBackendRep(sym, params, device)
+
+    @classmethod
+    def supports_device(cls, device):
+        """Supports only CPU for testing"""
+        return device == 'CPU'
+
+# Module-level entry points: this module itself is handed to
+# onnx.backend.test.BackendTest by onnx_backend_test.py.
+prepare, run_node, supports_device = (
+    MXNetBackend.prepare, MXNetBackend.run_node, MXNetBackend.supports_device)
diff --git a/tests/python/onnx_test_utils/backend_rep.py b/tests/python/onnx_test_utils/backend_rep.py
new file mode 100644
index 00000000000..c6ceabc825f
--- /dev/null
+++ b/tests/python/onnx_test_utils/backend_rep.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=too-few-public-methods
+"""backend rep for onnx test infrastructure"""
+from collections import namedtuple
+import numpy as np
+try:
+    from onnx.backend.base import BackendRep
+except ImportError:
+    raise ImportError("Onnx and protobuf need to be installed")
+import mxnet as mx
+
+# Using these functions for onnx test infrastructure.
+# Implemented by following onnx docs guide:
+# https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md
+# MXNetBackendRep object will be returned by MXNetBackend's prepare method which is used to
+# execute a model repeatedly.
+# Inputs will be passed to the run method of MXNetBackendRep class, it will perform computation and
+# retrieve the corresponding results for comparison to the onnx backend.
+# https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py.
+
+class MXNetBackendRep(BackendRep):
+    """Running model inference on mxnet engine and return the result
+     to onnx test infrastructure for comparison."""
+    def __init__(self, symbol, params, device):
+        self.symbol = symbol
+        self.params = params
+        self.device = device
+
+    def run(self, inputs, **kwargs):
+        """Run model inference and return the result
+
+        Parameters
+        ----------
+        inputs : list of numpy arrays
+            Model inputs. The i-th array is bound to the data name
+            'input_i', matching the names GraphProto gives to the graph's
+            non-parameter inputs.
+
+        Returns
+        -------
+        list of numpy array
+            Single-element list holding the model's first output.
+        """
+        input_data = [np.asarray(value, dtype='f') for value in inputs]
+        data_names = ['input_{}'.format(idx) for idx in range(len(input_data))]
+
+        # create module, passing cpu context
+        if self.device == 'CPU':
+            ctx = mx.cpu()
+        else:
+            raise NotImplementedError("Only CPU context is supported for now")
+
+        mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx,
+                            label_names=None)
+        mod.bind(for_training=False,
+                 data_shapes=[(name, data.shape) for name, data in zip(data_names, input_data)],
+                 label_shapes=None)
+        mod.set_params(arg_params=self.params, aux_params=None)
+
+        # run inference
+        batch = namedtuple('Batch', ['data'])
+
+        mod.forward(batch([mx.nd.array(data) for data in input_data]))
+        result = mod.get_outputs()[0].asnumpy()
+        return [result]
diff --git a/tests/python/unittest/onnx_backend_test.py b/tests/python/unittest/onnx_backend_test.py
new file mode 100644
index 00000000000..35a52429984
--- /dev/null
+++ b/tests/python/unittest/onnx_backend_test.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Run onnx's backend test suite against the MXNet backend for implemented operators."""
+# pylint: disable=invalid-name,import-error,wrong-import-position
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+try:
+    import onnx.backend.test
+except ImportError:
+    raise ImportError("Onnx and protobuf need to be installed")
+
+from os import sys  # same module object as a plain 'import sys' (os imports sys itself)
+sys.path.append('../onnx_test_utils')  # NOTE(review): resolved against the CWD, not this file
+import backend as mxnet_backend  # presumably onnx_test_utils/backend.py -- confirm
+
+# pytest magic variable: loads onnx's report plugin (the trailing comma makes it a tuple)
+pytest_plugins = "onnx.backend.test.report",
+
+backend_test = onnx.backend.test.BackendTest(mxnet_backend, __name__)
+
+implemented_operators = [  # NOTE(review): include() may treat these as regexes, not globs
+    # Arithmetic operators and argmax/argmin
+    'test_add*',
+    'test_neg*',
+    'test_abs*',
+    'test_argmax*',
+    'test_argmin*',
+    # Basic neural network functions
+    'test_sigmoid*',
+    # Changing shape and type
+    'test_reshape_*',
+    'test_AvgPool2D*'  # pooling; add a trailing comma before appending more entries
+    ]
+
+for op_test in implemented_operators:
+    backend_test.include(op_test)
+
+# Expose the generated test cases as module globals so unittest discovery can collect them
+globals().update(backend_test
+                 .enable_report()
+                 .test_cases)
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/python/unittest/test_layers.py b/tests/python/unittest/test_layers.py
new file mode 100644
index 00000000000..d2272d06c76
--- /dev/null
+++ b/tests/python/unittest/test_layers.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests that run single ONNX operator nodes through the MXNet backend."""
+# pylint: disable=import-error,no-self-use
+
+from __future__ import absolute_import
+import unittest
+import numpy as np
+import numpy.testing as npt
+from onnx import helper
+import backend as mxnet_backend  # NOTE(review): no sys.path setup here; must already be importable
+
+class TestLayers(unittest.TestCase):
+    """Run one ONNX node through the MXNet backend and compare its output with
+    the equivalent numpy computation. Temporary until onnx's backend test suite
+    covers these operators."""
+
+    def _random_array(self, shape):
+        """Return a float32 array of the given shape with values in [0, 1)."""
+        return np.random.ranf(shape).astype("float32")  # unseeded: differs per run
+
+    def test_reduce_max(self):
+        """ReduceMax over axes [1, 0] with keepdims=1 must match np.max."""
+        node_def = helper.make_node("ReduceMax", ["input1"], ["output"], axes=[1, 0], keepdims=1)
+        input1 = self._random_array([3, 10])
+        output = mxnet_backend.run_node(node_def, [input1])[0]
+        numpy_op = np.max(input1, axis=(1, 0), keepdims=True)
+        npt.assert_almost_equal(output, numpy_op)
+
+    def test_reduce_mean(self):
+        """ReduceMean over axes [1, 0] with keepdims=1 must match np.mean."""
+        node_def = helper.make_node("ReduceMean", ["input1"], ["output"], axes=[1, 0], keepdims=1)
+        input1 = self._random_array([3, 10])
+        output = mxnet_backend.run_node(node_def, [input1])[0]
+        numpy_op = np.mean(input1, axis=(1, 0), keepdims=True)
+        npt.assert_almost_equal(output, numpy_op, decimal=5)  # presumably loosened for float32 rounding
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/python/unittest/test_super_resolution.py b/tests/python/unittest/test_super_resolution.py
new file mode 100644
index 00000000000..80c55ec32e7
--- /dev/null
+++ b/tests/python/unittest/test_super_resolution.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Import the super_resolution ONNX model into MXNet and run it on one image (needs network)."""
+# pylint: disable=invalid-name
+from __future__ import absolute_import as _abs
+from __future__ import print_function
+from collections import namedtuple
+import mxnet as mx
+from mxnet.test_utils import download
+import mxnet.contrib.onnx as onnx_mxnet
+import numpy as np
+from PIL import Image
+
+model_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_resolution.onnx'
+
+download(model_url, 'super_resolution.onnx')
+
+print("Converting onnx format to mxnet's symbol and params...")
+sym, params = onnx_mxnet.import_model('super_resolution.onnx')
+
+# Load the test image and keep only its Y (luminance) channel as a 1x1xHxW input
+input_image_dim = 224
+output_image_dim = 672  # NOTE(review): never used below; the output size is not checked
+img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg'
+download(img_url, 'super_res_input.jpg')
+img = Image.open('super_res_input.jpg').resize((input_image_dim, input_image_dim))
+img_ycbcr = img.convert("YCbCr")
+img_y, img_cb, img_cr = img_ycbcr.split()
+x = np.array(img_y)[np.newaxis, np.newaxis, :, :]
+
+# Bind an inference-only module whose input is named 'input_0'
+mod = mx.mod.Module(symbol=sym, data_names=['input_0'], label_names=None)
+mod.bind(for_training=False, data_shapes=[('input_0', x.shape)])
+mod.set_params(arg_params=params, aux_params=None)
+
+# run inference; a namedtuple with a single 'data' field serves as the batch
+Batch = namedtuple('Batch', ['data'])
+mod.forward(Batch([mx.nd.array(x)]))
+
+# Merge the predicted Y channel with bicubic-resized Cb/Cr channels and save as RGB
+img_out_y = Image.fromarray(np.uint8(mod.get_outputs()[0][0][0].asnumpy().clip(0, 255)), mode='L')
+
+result_img = Image.merge(
+    "YCbCr", [img_out_y,
+              img_cb.resize(img_out_y.size, Image.BICUBIC),
+              img_cr.resize(img_out_y.size, Image.BICUBIC)]).convert("RGB")
+result_img.save("super_res_output.jpg")


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services