You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text rendering).
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/04 18:05:17 UTC

[GitHub] anirudh2290 closed pull request #10605: [MXNET-310] [ONNX-MXNet] API to import ONNX models into Gluon.

anirudh2290 closed pull request #10605: [MXNET-310] [ONNX-MXNet] API to import ONNX models into Gluon.
URL: https://github.com/apache/incubator-mxnet/pull/10605
 
 
   

This pull request was merged from a forked repository.
Because GitHub hides the original diff once a pull request is merged, the diff is
reproduced below for the sake of provenance:

Because this is a foreign pull request (opened from a fork), the diff is supplied
below; GitHub would not otherwise display it here:

diff --git a/ci/docker/install/ubuntu_onnx.sh b/ci/docker/install/ubuntu_onnx.sh
index 07acba01908..737c333afb6 100755
--- a/ci/docker/install/ubuntu_onnx.sh
+++ b/ci/docker/install/ubuntu_onnx.sh
@@ -30,5 +30,5 @@ echo "Installing libprotobuf-dev and protobuf-compiler ..."
 apt-get install -y libprotobuf-dev protobuf-compiler
 
 echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX and tabulate ..."
-pip2 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.5.2 onnx==1.1.1 Pillow==5.0.0 tabulate==0.7.5
-pip3 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.5.2 onnx==1.1.1 Pillow==5.0.0 tabulate==0.7.5
+pip2 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.5.2 onnx==1.2.1 Pillow==5.0.0 tabulate==0.7.5
+pip3 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.5.2 onnx==1.2.1 Pillow==5.0.0 tabulate==0.7.5
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 10bca17b5ff..fa9de6112ff 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -514,8 +514,9 @@ integrationtest_ubuntu_cpu_onnx() {
 	set -ex
 	export PYTHONPATH=./python/
 	python example/onnx/super_resolution.py
-	pytest tests/python-pytest/onnx/onnx_backend_test.py
-	pytest tests/python-pytest/onnx/onnx_test.py
+	pytest tests/python-pytest/onnx/import/mxnet_backend_test.py
+	pytest tests/python-pytest/onnx/import/onnx_import_test.py
+	pytest tests/python-pytest/onnx/import/gluon_backend_test.py
 }
 
 integrationtest_ubuntu_gpu_python() {
diff --git a/python/mxnet/contrib/onnx/__init__.py b/python/mxnet/contrib/onnx/__init__.py
index fb8488ca4f2..4f9296d3c56 100644
--- a/python/mxnet/contrib/onnx/__init__.py
+++ b/python/mxnet/contrib/onnx/__init__.py
@@ -17,3 +17,4 @@
 """Module for ONNX model format support for Apache MXNet."""
 
 from ._import.import_model import import_model, get_model_metadata
+from ._import.import_to_gluon import import_to_gluon
diff --git a/python/mxnet/contrib/onnx/_import/__init__.py b/python/mxnet/contrib/onnx/_import/__init__.py
index 002cfa92583..d0411df739b 100644
--- a/python/mxnet/contrib/onnx/_import/__init__.py
+++ b/python/mxnet/contrib/onnx/_import/__init__.py
@@ -19,3 +19,4 @@
 """ONNX Import module"""
 from . import import_model
 from . import import_onnx
+from . import import_to_gluon
diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py b/python/mxnet/contrib/onnx/_import/import_onnx.py
index db233578ff9..d81ec96537f 100644
--- a/python/mxnet/contrib/onnx/_import/import_onnx.py
+++ b/python/mxnet/contrib/onnx/_import/import_onnx.py
@@ -20,6 +20,7 @@
 """ Support import export formats."""
 from __future__ import absolute_import as _abs
 from .... import symbol
+from .... import cpu, gpu
 from .... import ndarray as nd
 from ....base import string_types
 from .import_helper import _convert_map as convert_map
@@ -33,6 +34,9 @@ def __init__(self):
         self._params = {}
         self._num_input = 0
         self._num_param = 0
+        self.aux_dict = {}
+        self.arg_dict = {}
+        self.model_metadata = {}
 
     def _convert_operator(self, node_name, op_name, attrs, inputs):
         """Convert from onnx operator to mxnet operator.
@@ -84,6 +88,8 @@ def from_onnx(self, graph):
         params : dict
             A dict of name: nd.array pairs, used as pretrained weights
         """
+        #get input, output shapes
+        self.model_metadata = self.get_graph_metadata(graph)
         # parse network inputs, aka parameters
         for init_tensor in graph.initializer:
             if not init_tensor.name.strip():
@@ -99,10 +105,6 @@ def from_onnx(self, graph):
             else:
                 self._nodes[i.name] = symbol.Variable(name=i.name)
 
-        # For storing arg  and aux params for the graph.
-        auxDict = {}
-        argDict = {}
-
         # constructing nodes, nodes are stored as directed acyclic graph
         # converting NodeProto message
         for node in graph.node:
@@ -119,10 +121,10 @@ def from_onnx(self, graph):
             # splitting params into args and aux params
             for args in mxnet_sym.list_arguments():
                 if args in self._params:
-                    argDict.update({args: nd.array(self._params[args])})
+                    self.arg_dict.update({args: nd.array(self._params[args])})
             for aux in mxnet_sym.list_auxiliary_states():
                 if aux in self._params:
-                    auxDict.update({aux: nd.array(self._params[aux])})
+                    self.aux_dict.update({aux: nd.array(self._params[aux])})
 
         # now return the outputs
         out = [self._nodes[i.name] for i in graph.output]
@@ -130,7 +132,7 @@ def from_onnx(self, graph):
             out = symbol.Group(out)
         else:
             out = out[0]
-        return out, argDict, auxDict
+        return out, self.arg_dict, self.aux_dict
 
     def get_graph_metadata(self, graph):
         """
@@ -155,6 +157,40 @@ def get_graph_metadata(self, graph):
                    }
         return metadata
 
+    def graph_to_gluon(self, graph, context):
+        """Construct SymbolBlock from onnx graph.
+
+        Parameters
+        ----------
+        graph : onnx protobuf object
+            The loaded onnx graph
+        context : str
+            context for mxnet module object. Should be 'CPU' or 'GPU'
+
+        Returns
+        -------
+        sym_block :gluon.nn.SymbolBlock
+            The returned gluon SymbolBlock
+        """
+        sym, arg_params, aux_params = self.from_onnx(graph)
+        metadata = self.get_graph_metadata(graph)
+        data_names = [input_tensor[0] for input_tensor in metadata['input_tensor_data']]
+        data_inputs = [symbol.var(data_name) for data_name in data_names]
+
+        ctx = gpu() if context == 'GPU' else cpu()
+        from ....gluon import SymbolBlock
+        net = SymbolBlock(outputs=sym, inputs=data_inputs)
+        net_params = net.collect_params()
+        for param in arg_params:
+            if param in net_params:
+                net_params[param].shape = arg_params[param].shape
+                net_params[param]._load_init(arg_params[param], ctx=ctx)
+        for param in aux_params:
+            if param in net_params:
+                net_params[param].shape = aux_params[param].shape
+                net_params[param]._load_init(aux_params[param], ctx=ctx)
+        return net
+
     def _parse_array(self, tensor_proto):
         """Grab data in TensorProto and convert to numpy array."""
         try:
diff --git a/python/mxnet/contrib/onnx/_import/import_to_gluon.py b/python/mxnet/contrib/onnx/_import/import_to_gluon.py
new file mode 100644
index 00000000000..eee968b32cd
--- /dev/null
+++ b/python/mxnet/contrib/onnx/_import/import_to_gluon.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Import ONNX model to gluon interface"""
+# pylint: disable=no-member
+
+from .import_onnx import GraphProto
+
+def import_to_gluon(model_file, context):
+    """
+    Imports the ONNX model files, passed as a parameter, into Gluon SymbolBlock object.
+
+    Parameters
+    ----------
+    model_file : str
+        ONNX model file name
+    context : str
+        context. Should be 'CPU' or 'GPU'
+
+    Returns
+    -------
+    sym_block : :class:`~mxnet.gluon.SymbolBlock`
+        A SymbolBlock object representing the given model file.
+    """
+    graph = GraphProto()
+    try:
+        import onnx
+    except ImportError:
+        raise ImportError("Onnx and protobuf need to be installed. Instructions to"
+                          + " install - https://github.com/onnx/onnx#installation")
+    model_proto = onnx.load(model_file)
+    net = graph.graph_to_gluon(model_proto.graph, context)
+    return net
diff --git a/python/mxnet/contrib/onnx/_import/op_translations.py b/python/mxnet/contrib/onnx/_import/op_translations.py
index 2fa517a8c7c..5df9d913f11 100644
--- a/python/mxnet/contrib/onnx/_import/op_translations.py
+++ b/python/mxnet/contrib/onnx/_import/op_translations.py
@@ -23,76 +23,86 @@
 
 # Method definitions for the callable objects mapped in the import_helper module
 
-def identity(attrs, inputs, cls):
+def identity(attrs, inputs, proto_obj):
     """Returns the identity function of the the input."""
     return 'identity', attrs, inputs
 
-def random_uniform(attrs, inputs, cls):
+def random_uniform(attrs, inputs, proto_obj):
     """Draw random samples from a uniform distribtuion."""
     new_attr = translation_utils._remove_attributes(attrs, ['seed'])
     return 'random_uniform', new_attr, inputs
 
-def random_normal(attrs, inputs, cls):
+def random_normal(attrs, inputs, proto_obj):
     """Draw random samples from a Gaussian distribution."""
     new_attr = translation_utils._remove_attributes(attrs, ['seed'])
     new_attr = translation_utils._fix_attribute_names(new_attr, {'mean' : 'loc'})
     return 'random_uniform', new_attr, inputs
 
 # Arithmetic Operations
-def add(attrs, inputs, cls):
+def add(attrs, inputs, proto_obj):
     """Adding two tensors"""
     new_attr = {}
     if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        op_value = translation_utils._fix_bias_shape('broadcast_add', inputs, cls)
+        broadcast_axis = attrs['axis']
+        op_value = translation_utils._fix_broadcast('broadcast_add', inputs,
+                                                    broadcast_axis, proto_obj)
         return op_value, new_attr, inputs
-    return 'elemwise_add', new_attr, inputs
+    return 'broadcast_add', new_attr, inputs
 
-def subtract(attrs, inputs, cls):
+def subtract(attrs, inputs, proto_obj):
     """Subtracting two tensors"""
     new_attr = {}
     if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        return 'broadcast_sub', new_attr, inputs
-    return 'elemwise_sub', new_attr, inputs
+        broadcast_axis = attrs['axis']
+        op_value = translation_utils._fix_broadcast('broadcast_sub', inputs,
+                                                    broadcast_axis, proto_obj)
+        return op_value, new_attr, inputs
+    return 'broadcast_sub', new_attr, inputs
 
 
-def multiply(attrs, inputs, cls):
+def multiply(attrs, inputs, proto_obj):
     """Multiply two tensors"""
     new_attr = {}
     if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        op_value = translation_utils._fix_bias_shape('broadcast_mul', inputs, cls)
+        broadcast_axis = attrs['axis']
+        op_value = translation_utils._fix_broadcast('broadcast_mul', inputs,
+                                                    broadcast_axis, proto_obj)
         return op_value, new_attr, inputs
-    return 'elemwise_mul', new_attr, inputs
+    return 'broadcast_mul', new_attr, inputs
 
-def divide(attrs, inputs, cls):
+def divide(attrs, inputs, proto_obj):
     """Divide two tensors"""
     new_attr = {}
     if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        return 'broadcast_div', new_attr, inputs
-    return 'elemwise_div', new_attr, inputs
+        broadcast_axis = attrs['axis']
+        op_value = translation_utils._fix_broadcast('broadcast_div', inputs,
+                                                    broadcast_axis, proto_obj)
+        return op_value, new_attr, inputs
+    return 'broadcast_div', new_attr, inputs
 
-def absolute(attrs, inputs, cls):
+def absolute(attrs, inputs, proto_obj):
     """Returns element-wise absolute value of the input."""
     return 'abs', attrs, inputs
 
-def negative(attrs, inputs, cls):
+def negative(attrs, inputs, proto_obj):
     """Negation of every element in a tensor"""
     return 'negative', attrs, inputs
 
-def add_n(attrs, inputs, cls):
+def add_n(attrs, inputs, proto_obj):
     """Elementwise sum of arrays"""
     return 'add_n', attrs, inputs
 
 # Sorting and Searching
-def argmax(attrs, inputs, cls):
+def argmax(attrs, inputs, proto_obj):
     """Returns indices of the maximum values along an axis"""
     return 'argmax', attrs, inputs
 
 
-def argmin(attrs, inputs, cls):
+def argmin(attrs, inputs, proto_obj):
     """Returns indices of the minimum values along an axis."""
     return 'argmin', attrs, inputs
 
-def maximum(attrs, inputs, cls):
+def maximum(attrs, inputs, proto_obj):
     """
     Elementwise maximum of arrays.
     MXNet maximum compares only two symbols at a time.
@@ -107,7 +117,7 @@ def maximum(attrs, inputs, cls):
         mxnet_op = inputs[0]
     return mxnet_op, attrs, inputs
 
-def minimum(attrs, inputs, cls):
+def minimum(attrs, inputs, proto_obj):
     """Elementwise minimum of arrays."""
     # MXNet minimum compares only two symbols at a time.
     # ONNX can send more than two to compare.
@@ -121,36 +131,35 @@ def minimum(attrs, inputs, cls):
     return mxnet_op, attrs, inputs
 
 #Hyperbolic functions
-def tanh(attrs, inputs, cls):
+def tanh(attrs, inputs, proto_obj):
     """Returns the hyperbolic tangent of the input array."""
     return 'tanh', attrs, inputs
 
 # Rounding
-def ceil(attrs, inputs, cls):
+def ceil(attrs, inputs, proto_obj):
     """ Calculate ceil value for input """
     return 'ceil', attrs, inputs
 
-def floor(attrs, inputs, cls):
+def floor(attrs, inputs, proto_obj):
     """ Calculate floor value for input """
     return 'floor', attrs, inputs
 
 # Joining and spliting
-def concat(attrs, inputs, cls):
+def concat(attrs, inputs, proto_obj):
     """ Joins input arrays along a given axis. """
     new_attrs = translation_utils._fix_attribute_names(attrs, {'axis': 'dim'})
     return 'concat', new_attrs, inputs
 
-
 # Basic neural network functions
-def sigmoid(attrs, inputs, cls):
+def sigmoid(attrs, inputs, proto_obj):
     """Computes elementwise sigmoid of the input array"""
     return 'sigmoid', attrs, inputs
 
-def relu(attrs, inputs, cls):
+def relu(attrs, inputs, proto_obj):
     """Computes rectified linear function."""
     return 'relu', attrs, inputs
 
-def pad(attrs, inputs, cls):
+def pad(attrs, inputs, proto_obj):
     """ Add padding to input tensor"""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'pads'  : 'pad_width',
                                                                'value' : 'constant_value'
@@ -158,24 +167,23 @@ def pad(attrs, inputs, cls):
     new_attrs['pad_width'] = translation_utils._pad_sequence_fix(new_attrs.get('pad_width'))
     return 'pad', new_attrs, inputs
 
-def matrix_multiplication(attrs, inputs, cls):
+def matrix_multiplication(attrs, inputs, proto_obj):
     """Performs general matrix multiplication"""
     return 'linalg_gemm2', attrs, inputs
 
-def batch_norm(attrs, inputs, cls):
+def batch_norm(attrs, inputs, proto_obj):
     """Batch normalization."""
-    new_attrs = translation_utils._fix_attribute_names(attrs, {'epsilon' : 'eps',
-                                                               'is_test':'fix_gamma'})
+    new_attrs = translation_utils._fix_attribute_names(attrs, {'epsilon': 'eps',
+                                                               'is_test': 'fix_gamma'})
     new_attrs = translation_utils._remove_attributes(new_attrs,
                                                      ['spatial', 'consumed_inputs'])
     new_attrs = translation_utils._add_extra_attributes(new_attrs, {'cudnn_off': 1})
 
     # in test mode "fix_gamma" should be unset.
-    new_attrs['fix_gamma'] = 0 if new_attrs['fix_gamma'] == 1 else 1
+    new_attrs['fix_gamma'] = not attrs.get('is_test', 1)
     return 'BatchNorm', new_attrs, inputs
 
-
-def leaky_relu(attrs, inputs, cls):
+def leaky_relu(attrs, inputs, proto_obj):
     """Leaky Relu function"""
     if 'alpha' in attrs:
         new_attrs = translation_utils._fix_attribute_names(attrs, {'alpha' : 'slope'})
@@ -183,7 +191,7 @@ def leaky_relu(attrs, inputs, cls):
         new_attrs = translation_utils._add_extra_attributes(attrs, {'slope': 0.01})
     return 'LeakyReLU', new_attrs, inputs
 
-def _elu(attrs, inputs, cls):
+def _elu(attrs, inputs, proto_obj):
     """Elu function"""
     if 'alpha' in attrs:
         new_attrs = translation_utils._fix_attribute_names(attrs, {'alpha' : 'slope'})
@@ -192,18 +200,18 @@ def _elu(attrs, inputs, cls):
     new_attrs = translation_utils._add_extra_attributes(new_attrs, {'act_type': 'elu'})
     return 'LeakyReLU', new_attrs, inputs
 
-def _prelu(attrs, inputs, cls):
+def _prelu(attrs, inputs, proto_obj):
     """PRelu function"""
     new_attrs = translation_utils._add_extra_attributes(attrs, {'act_type': 'prelu'})
     return 'LeakyReLU', new_attrs, inputs
 
-def softmax(attrs, inputs, cls):
+def softmax(attrs, inputs, proto_obj):
     """Softmax function."""
     if 'axis' not in attrs:
         attrs = translation_utils._add_extra_attributes(attrs, {'axis': 1})
     return 'softmax', attrs, inputs
 
-def conv(attrs, inputs, cls):
+def conv(attrs, inputs, proto_obj):
     """Compute N-D convolution on (N+2)-D input."""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'kernel_shape' : 'kernel',
                                                                'strides' : 'stride',
@@ -213,7 +221,7 @@ def conv(attrs, inputs, cls):
     new_attrs = translation_utils._add_extra_attributes(new_attrs, {'num_group' : 1})
     new_attrs = translation_utils._fix_bias('Convolution', new_attrs, len(inputs))
 
-    new_attrs = translation_utils._fix_channels('Convolution', new_attrs, inputs, cls)
+    new_attrs = translation_utils._fix_channels('Convolution', new_attrs, inputs, proto_obj)
     kernel = new_attrs['kernel']
     stride = new_attrs['stride'] if 'stride' in new_attrs else []
     padding = new_attrs['pad'] if 'pad' in new_attrs else []
@@ -234,7 +242,7 @@ def conv(attrs, inputs, cls):
 
     return conv_op, new_attrs, inputs
 
-def deconv(attrs, inputs, cls):
+def deconv(attrs, inputs, proto_obj):
     """Computes transposed convolution of the input tensor."""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'kernel_shape' : 'kernel',
                                                                'strides' : 'stride',
@@ -244,7 +252,7 @@ def deconv(attrs, inputs, cls):
     new_attrs = translation_utils._add_extra_attributes(new_attrs, {'num_group' : 1})
     new_attrs = translation_utils._fix_bias('Deconvolution', new_attrs, len(inputs))
 
-    new_attrs = translation_utils._fix_channels('Deconvolution', new_attrs, inputs, cls)
+    new_attrs = translation_utils._fix_channels('Deconvolution', new_attrs, inputs, proto_obj)
     kernel = new_attrs['kernel']
     stride = new_attrs['stride'] if 'stride' in new_attrs else []
     padding = new_attrs['pad'] if 'pad' in new_attrs else []
@@ -265,18 +273,18 @@ def deconv(attrs, inputs, cls):
 
     return deconv_op, new_attrs, inputs
 
-def fully_connected(attrs, inputs, cls):
+def fully_connected(attrs, inputs, proto_obj):
     """Applies a linear transformation: Y=XWT+b."""
     new_attrs = translation_utils._remove_attributes(attrs, ['axis'])
 
     new_attrs = translation_utils._fix_bias('FullyConnected', new_attrs, len(inputs))
 
-    new_attrs = translation_utils._fix_channels('FullyConnected', new_attrs, inputs, cls)
+    new_attrs = translation_utils._fix_channels('FullyConnected', new_attrs, inputs, proto_obj)
 
     return 'FullyConnected', new_attrs, inputs
 
 
-def global_maxpooling(attrs, inputs, cls):
+def global_maxpooling(attrs, inputs, proto_obj):
     """Performs max pooling on the input."""
     new_attrs = translation_utils._add_extra_attributes(attrs, {'global_pool': True,
                                                                 'kernel': (1, 1),
@@ -284,7 +292,7 @@ def global_maxpooling(attrs, inputs, cls):
     return 'Pooling', new_attrs, inputs
 
 
-def global_avgpooling(attrs, inputs, cls):
+def global_avgpooling(attrs, inputs, proto_obj):
     """Performs avg pooling on the input."""
     new_attrs = translation_utils._add_extra_attributes(attrs, {'global_pool': True,
                                                                 'kernel': (1, 1),
@@ -292,7 +300,7 @@ def global_avgpooling(attrs, inputs, cls):
     return 'Pooling', new_attrs, inputs
 
 
-def linalg_gemm(attrs, inputs, cls):
+def linalg_gemm(attrs, inputs, proto_obj):
     """Performs general matrix multiplication and accumulation"""
     trans_a = 0
     trans_b = 0
@@ -316,17 +324,17 @@ def linalg_gemm(attrs, inputs, cls):
     new_attrs = translation_utils._remove_attributes(new_attrs, ['broadcast'])
     return gemm_op, new_attrs, inputs
 
-def local_response_norm(attrs, inputs, cls):
+def local_response_norm(attrs, inputs, proto_obj):
     """Local Response Normalization."""
     new_attrs = translation_utils._fix_attribute_names(attrs,
                                                        {'bias': 'knorm',
                                                         'size' : 'nsize'})
     return 'LRN', new_attrs, inputs
 
-def dropout(attrs, inputs, cls):
+def dropout(attrs, inputs, proto_obj):
     """Dropout Regularization."""
     mode = 'training'
-    if attrs['is_test'] == 0:
+    if 'is_test' in attrs and attrs['is_test'] == 0:
         mode = 'always'
     new_attrs = translation_utils._fix_attribute_names(attrs,
                                                        {'ratio': 'p'})
@@ -335,23 +343,28 @@ def dropout(attrs, inputs, cls):
     return 'Dropout', new_attrs, inputs
 
 # Changing shape and type.
-def reshape(attrs, inputs, cls):
+def reshape(attrs, inputs, proto_obj):
     """Reshape the given array by the shape attribute."""
-    return 'reshape', attrs, inputs
-
-def cast(attrs, inputs, cls):
+    if len(inputs) == 1:
+        return 'reshape', attrs, inputs[0]
+    reshape_shape = list(proto_obj._params[inputs[1].name].asnumpy())
+    reshape_shape = [int(i) for i in reshape_shape]
+    new_attrs = {'shape': reshape_shape}
+    return 'reshape', new_attrs, inputs[:1]
+
+def cast(attrs, inputs, proto_obj):
     """ Cast input to a given dtype"""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'to' : 'dtype'})
     new_attrs['dtype'] = new_attrs['dtype'].lower()
     return 'cast', new_attrs, inputs
 
-def split(attrs, inputs, cls):
+def split(attrs, inputs, proto_obj):
     """Splits an array along a particular axis into multiple sub-arrays."""
     new_attrs = translation_utils._fix_attribute_names(attrs,
                                                        {'split' : 'num_outputs'})
     return 'split', new_attrs, inputs
 
-def _slice(attrs, inputs, cls):
+def _slice(attrs, inputs, proto_obj):
     """Returns a slice of the input tensor along multiple axes."""
     new_attrs = translation_utils._fix_attribute_names(attrs,
                                                        {'axes' : 'axis',
@@ -368,13 +381,13 @@ def _slice(attrs, inputs, cls):
             slice_op = symbol.slice_axis(slice_op, axis=axis, begin=begin[i], end=end[i])
     return slice_op, new_attrs, inputs
 
-def transpose(attrs, inputs, cls):
+def transpose(attrs, inputs, proto_obj):
     """Transpose the input array."""
     new_attrs = translation_utils._fix_attribute_names(attrs,
                                                        {'perm' : 'axes'})
     return 'transpose', new_attrs, inputs
 
-def squeeze(attrs, inputs, cls):
+def squeeze(attrs, inputs, proto_obj):
     """Remove single-dimensional entries from the shape of a tensor."""
     # MXNet doesnt have a squeeze operator.
     # Using "split" to perform similar operation.
@@ -387,7 +400,7 @@ def squeeze(attrs, inputs, cls):
     return mxnet_op, new_attrs, inputs
 
 
-def flatten(attrs, inputs, cls):
+def flatten(attrs, inputs, proto_obj):
     """Flattens the input array into a 2-D array by collapsing the higher dimensions."""
     #Mxnet does not have axis support. By default uses axis=1
     if 'axis' in attrs and attrs['axis'] != 1:
@@ -396,15 +409,15 @@ def flatten(attrs, inputs, cls):
     return 'Flatten', new_attrs, inputs
 
 #Powers
-def reciprocal(attrs, inputs, cls):
+def reciprocal(attrs, inputs, proto_obj):
     """Returns the reciprocal of the argument, element-wise."""
     return 'reciprocal', attrs, inputs
 
-def squareroot(attrs, inputs, cls):
+def squareroot(attrs, inputs, proto_obj):
     """Returns element-wise square-root value of the input."""
     return 'sqrt', attrs, inputs
 
-def power(attrs, inputs, cls):
+def power(attrs, inputs, proto_obj):
     """Returns element-wise result of base element raised to powers from exp element."""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'exponent':'exp'})
     if 'broadcast' in attrs and attrs['broadcast'] == 1:
@@ -412,41 +425,41 @@ def power(attrs, inputs, cls):
         return 'broadcast_power', new_attrs, inputs
     return 'pow', new_attrs, inputs
 
-def exponent(attrs, inputs, cls):
+def exponent(attrs, inputs, proto_obj):
     """Elementwise exponent of input array."""
     return 'exp', attrs, inputs
 
-def _log(attrs, inputs, cls):
+def _log(attrs, inputs, proto_obj):
     """Elementwise log of input array."""
     return 'log', attrs, inputs
 
 # Reduce Functions
-def reduce_max(attrs, inputs, cls):
+def reduce_max(attrs, inputs, proto_obj):
     """Reduce the array along a given axis by maximum value"""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
     return 'max', new_attrs, inputs
 
-def reduce_mean(attrs, inputs, cls):
+def reduce_mean(attrs, inputs, proto_obj):
     """Reduce the array along a given axis by mean value"""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
     return 'mean', new_attrs, inputs
 
-def reduce_min(attrs, inputs, cls):
+def reduce_min(attrs, inputs, proto_obj):
     """Reduce the array along a given axis by mean value"""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
     return 'min', new_attrs, inputs
 
-def reduce_sum(attrs, inputs, cls):
+def reduce_sum(attrs, inputs, proto_obj):
     """Reduce the array along a given axis by mean value"""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
     return 'sum', new_attrs, inputs
 
-def reduce_prod(attrs, inputs, cls):
+def reduce_prod(attrs, inputs, proto_obj):
     """Reduce the array along a given axis by mean value"""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
     return 'prod', new_attrs, inputs
 
-def avg_pooling(attrs, inputs, cls):
+def avg_pooling(attrs, inputs, proto_obj):
     """ Average pooling"""
     new_attrs = translation_utils._fix_attribute_names(attrs,
                                                        {'kernel_shape': 'kernel',
@@ -461,7 +474,7 @@ def avg_pooling(attrs, inputs, cls):
     return new_op, new_attrs, inputs
 
 
-def max_pooling(attrs, inputs, cls):
+def max_pooling(attrs, inputs, proto_obj):
     """ Average pooling"""
     new_attrs = translation_utils._fix_attribute_names(attrs,
                                                        {'kernel_shape': 'kernel',
diff --git a/python/mxnet/contrib/onnx/_import/translation_utils.py b/python/mxnet/contrib/onnx/_import/translation_utils.py
index 1d84bd70d7e..fe25a94baa7 100644
--- a/python/mxnet/contrib/onnx/_import/translation_utils.py
+++ b/python/mxnet/contrib/onnx/_import/translation_utils.py
@@ -20,6 +20,10 @@
 # pylint: disable=protected-access
 from __future__ import absolute_import as _abs
 from .... import symbol
+from .... import  module
+from .... import  context
+from .... import  ndarray as nd
+from .... import  io
 
 
 def _fix_attribute_names(attrs, change_map):
@@ -148,30 +152,31 @@ def _fix_bias(op_name, attrs, num_inputs):
         raise ValueError("Unexpected number of inputs for: {}".format(op_name))
     return attrs
 
-def _fix_bias_shape(op_name, inputs, cls):
+def _fix_broadcast(op_name, inputs, broadcast_axis, proto_obj):
     """A workaround to reshape bias term to (1, num_channel)."""
-    if int(len(cls._params)) > 0:
+    if int(len(proto_obj._params)) > 0:
         assert len(list(inputs)) == 2
 
-        op_sym = symbol.reshape(inputs[1], shape=(1, -1, 1, 1))
-        if op_name == 'broadcast_add':
-            op_sym = symbol.broadcast_add(op_sym, inputs[0])
-        elif op_name == 'broadcast_mul':
-            op_sym = symbol.broadcast_mul(op_sym, inputs[0])
+        input0_shape = get_input_shape(inputs[0], proto_obj)
+        #creating reshape shape
+        reshape_shape = list(len(input0_shape) * (1,))
+        reshape_shape[broadcast_axis] = -1
+        reshape_shape = tuple(reshape_shape)
+        reshape_op_sym = symbol.reshape(inputs[1], shape=reshape_shape)
+        op_sym = getattr(symbol, op_name)(inputs[0], reshape_op_sym)
     else:
         op_sym = op_name
     return op_sym
 
-
-def _fix_channels(op_name, attrs, inputs, cls):
+def _fix_channels(op_name, attrs, inputs, proto_obj):
     """A workaround for getting 'channels' or 'units' since onnx don't provide
     these attributes. We check the shape of weights provided to get the number.
     """
     weight_name = inputs[1].name
-    if not weight_name in cls._params:
+    if not weight_name in proto_obj._params:
         raise ValueError("Unable to get channels/units attr from onnx graph.")
     else:
-        wshape = cls._params[weight_name].shape
+        wshape = proto_obj._params[weight_name].shape
         assert len(wshape) >= 2, "Weights shape is invalid: {}".format(wshape)
 
         if op_name == 'FullyConnected':
@@ -188,7 +193,7 @@ def _fix_channels(op_name, attrs, inputs, cls):
     return attrs
 
 
-def _fix_gemm(op_name, inputs, old_attr, cls):
+def _fix_gemm(op_name, inputs, old_attr, proto_obj):
     """Using FullyConnected operator in place of linalg_gemm to perform same operation"""
     op_sym = getattr(symbol, op_name, None)
     alpha = float(old_attr.get('alpha', 1.0))
@@ -200,5 +205,38 @@ def _fix_gemm(op_name, inputs, old_attr, cls):
     if not trans_b:
         inputs[1] = symbol.transpose(inputs[1], axes=(1, 0))
     new_inputs = [alpha*inputs[0], inputs[1], beta*inputs[2]]
-    new_attr = {'num_hidden' : cls._params[inputs[2].name].shape[0]}
+    new_attr = {'num_hidden' : proto_obj._params[inputs[2].name].shape[0]}
     return op_sym, new_attr, new_inputs
+
+def get_input_shape(sym, proto_obj):
+    """Helper function to obtain the shape of an array"""
+    arg_params = proto_obj.arg_dict
+    aux_params = proto_obj.aux_dict
+
+    model_input_shape = [data[1] for data  in proto_obj.model_metadata.get('input_tensor_data')]
+    data_names = [data[0] for data  in proto_obj.model_metadata.get('input_tensor_data')]
+
+    #creating dummy inputs
+    inputs = []
+    for  in_shape in model_input_shape:
+        inputs.append(nd.ones(shape=in_shape))
+
+    data_shapes = []
+    for idx, input_name in enumerate(data_names):
+        data_shapes.append((input_name, inputs[idx].shape))
+
+    ctx = context.cpu()
+    # create a module
+    mod = module.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None)
+    mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
+    mod.set_params(arg_params=arg_params, aux_params=aux_params)
+
+    data_forward = []
+    for idx, input_name in enumerate(data_names):
+        val = inputs[idx]
+        data_forward.append(val)
+
+    mod.forward(io.DataBatch(data_forward))
+    result = mod.get_outputs()[0].asnumpy()
+
+    return result.shape
diff --git a/tests/python-pytest/onnx/backend.py b/tests/python-pytest/onnx/backend.py
deleted file mode 100644
index 0e0a6a680b7..00000000000
--- a/tests/python-pytest/onnx/backend.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# coding: utf-8
-"""backend wrapper for onnx test infrastructure"""
-import mxnet as mx
-from mxnet.contrib.onnx._import.import_onnx import GraphProto
-try:
-    from onnx import helper, TensorProto
-    from onnx.backend.base import Backend
-except ImportError:
-    raise ImportError("Onnx and protobuf need to be installed")
-from backend_rep import MXNetBackendRep
-
-# Using these functions for onnx test infrastructure.
-# Implemented by following onnx docs guide:
-# https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md
-# MXNetBackend class will take an ONNX model with inputs, perform a computation,
-# and then return the output.
-
-class MXNetBackend(Backend):
-    """MXNet backend for ONNX"""
-
-    @staticmethod
-    def make_graph(node, inputs):
-        """ Created ONNX GraphProto from node"""
-        initializer = []
-        tensor_input_info = []
-        tensor_output_info = []
-
-        # Adding input tensor info.
-        for index in range(len(node.input)):
-            tensor_input_info.append(
-                helper.make_tensor_value_info(str(node.input[index]), TensorProto.FLOAT, [1]))
-
-            # Creating an initializer for Weight params.
-            # Assumes that weight params is named as 'W'.
-            if node.input[index] == 'W':
-                dim = inputs[index].shape
-                param_tensor = helper.make_tensor(
-                    name=node.input[index],
-                    data_type=TensorProto.FLOAT,
-                    dims=dim,
-                    vals=inputs[index].flatten())
-
-                initializer.append(param_tensor)
-
-        # Adding output tensor info.
-        for index in range(len(node.output)):
-            tensor_output_info.append(
-                helper.make_tensor_value_info(str(node.output[index]), TensorProto.FLOAT, [1]))
-
-        # creating graph proto object.
-        graph_proto = helper.make_graph(
-            [node],
-            "test",
-            tensor_input_info,
-            tensor_output_info,
-            initializer=initializer)
-
-        return graph_proto
-
-    @classmethod
-    def run_node(cls, node, inputs, device='CPU'):
-        """Running individual node inference on mxnet engine and
-        return the result to onnx test infrastructure.
-
-        Parameters
-        ----------
-        node   : onnx node object
-            loaded onnx node (individual layer)
-        inputs : numpy array
-            input to run a node on
-        device : 'CPU'
-            device to run a node on
-
-        Returns
-        -------
-        params : numpy array
-            result obtained after running the operator
-        """
-        graph = GraphProto()
-        sym, arg_params, aux_params = graph.from_onnx(MXNetBackend.make_graph(node, inputs))
-        data_names = [graph_input for graph_input in sym.list_inputs()
-                      if graph_input not in arg_params and graph_input not in aux_params]
-        data_shapes = []
-        dim_change_op_types = set(['ReduceMin', 'ReduceMax', 'ReduceMean',
-                                   'ReduceProd', 'ReduceSum', 'Slice', 'Pad',
-                                   'Squeeze', 'Upsample', 'Reshape', 'Conv',
-                                   'Concat', 'Softmax', 'Flatten', 'Transpose',
-                                   'GlobalAveragePool', 'GlobalMaxPool'])
-
-        # Adding extra dimension of batch_size 1 if the batch_size is different for multiple inputs.
-        for idx, input_name in enumerate(data_names):
-            batch_size = 1
-            if len(inputs) > 1 and len(inputs[idx].shape) < 4 and  \
-                            len(set(x.shape[0] for x in inputs)) != 1:
-                tuples = ((batch_size,), inputs[idx].shape)
-                new_shape = sum(tuples, ())
-                data_shapes.append((input_name, new_shape))
-            else:
-                data_shapes.append((input_name, inputs[idx].shape))
-
-        # create module, passing cpu context
-        if device == 'CPU':
-            ctx = mx.cpu()
-        else:
-            raise NotImplementedError("Only CPU context is supported for now")
-
-        # create a module
-        mod = mx.mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None)
-        mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
-
-        # initializing parameters for calculating result of each individual node
-        if arg_params is None and aux_params is None:
-            mod.init_params()
-        else:
-            mod.set_params(arg_params=arg_params, aux_params=aux_params)
-
-        data_forward = []
-        for idx, input_name in enumerate(data_names):
-            # slice and pad operator tests needs 1 less dimension in forward pass
-            # otherwise it will throw an error.
-            # for squeeze operator, need to retain shape of input as provided
-            val = inputs[idx]
-            if node.op_type in dim_change_op_types:
-                data_forward.append(mx.nd.array(val))
-            else:
-                data_forward.append(mx.nd.array([val]))
-
-        mod.forward(mx.io.DataBatch(data_forward))
-        result = mod.get_outputs()[0].asnumpy()
-        if node.op_type in dim_change_op_types:
-            return [result]
-        return result
-
-    @classmethod
-    def prepare(cls, model, device='CPU', **kwargs):
-        """For running end to end model(used for onnx test backend)
-
-        Parameters
-        ----------
-        model  : onnx ModelProto object
-            loaded onnx graph
-        device : 'CPU'
-            specifying device to run test on
-        kwargs :
-            other arguments
-
-        Returns
-        -------
-        MXNetBackendRep : object
-            Returns object of MXNetBackendRep class which will be in turn
-            used to run inference on the input model and return the result for comparison.
-        """
-        graph = GraphProto()
-        sym, arg_params, aux_params = graph.from_onnx(model.graph)
-        return MXNetBackendRep(sym, arg_params, aux_params, device)
-
-    @classmethod
-    def supports_device(cls, device):
-        """Supports only CPU for testing"""
-        return device == 'CPU'
-
-prepare = MXNetBackend.prepare
-
-run_node = MXNetBackend.run_node
-
-supports_device = MXNetBackend.supports_device
diff --git a/tests/python-pytest/onnx/import/gluon_backend.py b/tests/python-pytest/onnx/import/gluon_backend.py
new file mode 100644
index 00000000000..d2946f7bb54
--- /dev/null
+++ b/tests/python-pytest/onnx/import/gluon_backend.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Gluon backend wrapper for onnx test infrastructure"""
+import mxnet as mx
+from mxnet import nd
+from mxnet.contrib.onnx._import.import_onnx import GraphProto
+import numpy as np
+try:
+    from onnx import helper, TensorProto
+    from onnx.backend.base import Backend
+except ImportError:
+    raise ImportError("Onnx and protobuf need to be installed. Instructions to"
+                      + " install - https://github.com/onnx/onnx#installation")
+from gluon_backend_rep import GluonBackendRep
+
+# GluonBackend class will take an ONNX model with inputs, perform a computation,
+# and then return the output.
+# Implemented by following onnx docs guide:
+# https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
+
+class GluonBackend(Backend):
+    """Gluon backend for ONNX"""
+
+    @classmethod
+    def prepare(cls, model, device='CPU', **kwargs):
+        """For running end to end model(used for onnx test backend)
+
+        Parameters
+        ----------
+        model  : onnx ModelProto object
+            loaded onnx graph
+        device : 'CPU'
+            specifying device to run test on
+        kwargs :
+            other arguments
+
+        Returns
+        -------
+        GluonBackendRep : object
+            Returns object of GluonBackendRep class which will be in turn
+            used to run inference on the input model and return the result for comparison.
+        """
+        graph = GraphProto()
+        net = graph.graph_to_gluon(model.graph, device)
+        return GluonBackendRep(net, device)
+
+    @classmethod
+    def supports_device(cls, device):
+        """Supports only CPU for testing"""
+        return device == 'CPU'
+
+prepare = GluonBackend.prepare
+
+supports_device = GluonBackend.supports_device
diff --git a/tests/python-pytest/onnx/import/gluon_backend_rep.py b/tests/python-pytest/onnx/import/gluon_backend_rep.py
new file mode 100644
index 00000000000..a90d350c8cd
--- /dev/null
+++ b/tests/python-pytest/onnx/import/gluon_backend_rep.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""gluon backend rep for onnx test infrastructure"""
+import numpy as np
+try:
+    from onnx.backend.base import BackendRep
+except ImportError:
+    raise ImportError("Onnx and protobuf need to be installed. Instructions to"
+                      + " install - https://github.com/onnx/onnx#installation")
+import mxnet as mx
+from mxnet import nd
+
+# GluonBackendRep object will be returned by GluonBackend's prepare method which is used to
+# execute a model repeatedly.
+# Inputs will be passed to the run method of MXNetBackendRep class, it will perform computation and
+# retrieve the corresponding results for comparison to the onnx backend.
+# https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py.
+# Implemented by following onnx docs guide:
+# https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
+
+class GluonBackendRep(BackendRep):
+    """Running model inference on gluon backend and return the result
+     to onnx test infrastructure for comparison."""
+    def __init__(self, net, device):
+        self.net = net
+        self.device = device
+
+    def run(self, inputs, **kwargs):
+        """Run model inference and return the result
+
+        Parameters
+        ----------
+        inputs : numpy array
+            input to run a layer on
+
+        Returns
+        -------
+        params : numpy array
+            result obtained after running the inference on mxnet
+        """
+        # create module, passing cpu context
+        if self.device == 'CPU':
+            ctx = mx.cpu()
+        else:
+            raise NotImplementedError("ONNX tests are run only for CPU context.")
+
+        # run inference
+        net_inputs = [nd.array(input_data, ctx=ctx) for input_data in inputs]
+        net_outputs = self.net(*net_inputs)
+        results = []
+        results.extend([o for o in net_outputs.asnumpy()])
+        result = np.array(results)
+
+        return [result]
diff --git a/tests/python-pytest/onnx/import/gluon_backend_test.py b/tests/python-pytest/onnx/import/gluon_backend_test.py
new file mode 100644
index 00000000000..6dd5f8a071c
--- /dev/null
+++ b/tests/python-pytest/onnx/import/gluon_backend_test.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""ONNX test backend wrapper"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+try:
+    import onnx.backend.test
+except ImportError:
+    raise ImportError("Onnx and protobuf need to be installed. Instructions to"
+                      + " install - https://github.com/onnx/onnx#installation")
+
+import gluon_backend
+import test_cases
+
+# This is a pytest magic variable to load extra plugins
+pytest_plugins = "onnx.backend.test.report",
+
+BACKEND_TESTS = onnx.backend.test.BackendTest(gluon_backend, __name__)
+
+for op_tests in test_cases.IMPLEMENTED_OPERATORS_TEST:
+    BACKEND_TESTS.include(op_tests)
+
+for std_model_test in test_cases.STANDARD_MODEL:
+    BACKEND_TESTS.include(std_model_test)
+
+for basic_model_test in test_cases.BASIC_MODEL_TESTS:
+    BACKEND_TESTS.include(basic_model_test)
+
+BACKEND_TESTS.exclude('.*broadcast.*')
+BACKEND_TESTS.exclude('.*bcast.*')
+
+# import all test cases at global scope to make them visible to python.unittest
+globals().update(BACKEND_TESTS.enable_report().test_cases)
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/python-pytest/onnx/import/mxnet_backend.py b/tests/python-pytest/onnx/import/mxnet_backend.py
new file mode 100644
index 00000000000..bbe8899dee1
--- /dev/null
+++ b/tests/python-pytest/onnx/import/mxnet_backend.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""MXNet backend wrapper for onnx test infrastructure"""
+import mxnet as mx
+from mxnet.contrib.onnx._import.import_onnx import GraphProto
+try:
+    from onnx import helper, TensorProto
+    from onnx.backend.base import Backend
+except ImportError:
+    raise ImportError("Onnx and protobuf need to be installed. Instructions to"
+                      + " install - https://github.com/onnx/onnx#installation")
+from mxnet_backend_rep import MXNetBackendRep
+
+# MXNetBackend class will take an ONNX model with inputs, perform a computation,
+# and then return the output.
+# Implemented by following onnx docs guide:
+# https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
+
+class MXNetBackend(Backend):
+    """MXNet backend for ONNX"""
+
+    @classmethod
+    def prepare(cls, model, device='CPU', **kwargs):
+        """For running end to end model(used for onnx test backend)
+
+        Parameters
+        ----------
+        model  : onnx ModelProto object
+            loaded onnx graph
+        device : 'CPU'
+            specifying device to run test on
+        kwargs :
+            other arguments
+
+        Returns
+        -------
+        MXNetBackendRep : object
+            Returns object of MXNetBackendRep class which will be in turn
+            used to run inference on the input model and return the result for comparison.
+        """
+        graph = GraphProto()
+        sym, arg_params, aux_params = graph.from_onnx(model.graph)
+        return MXNetBackendRep(sym, arg_params, aux_params, device)
+
+    @classmethod
+    def supports_device(cls, device):
+        """Supports only CPU for testing"""
+        return device == 'CPU'
+
+prepare = MXNetBackend.prepare
+
+supports_device = MXNetBackend.supports_device
diff --git a/tests/python-pytest/onnx/backend_rep.py b/tests/python-pytest/onnx/import/mxnet_backend_rep.py
similarity index 85%
rename from tests/python-pytest/onnx/backend_rep.py
rename to tests/python-pytest/onnx/import/mxnet_backend_rep.py
index 114a2eb7990..5ce29f54150 100644
--- a/tests/python-pytest/onnx/backend_rep.py
+++ b/tests/python-pytest/onnx/import/mxnet_backend_rep.py
@@ -16,18 +16,18 @@
 # under the License.
 
 # coding: utf-8
-"""backend rep for onnx test infrastructure"""
-from collections import namedtuple
+"""MXNet backend rep for onnx test infrastructure"""
 import numpy as np
 try:
     from onnx.backend.base import BackendRep
 except ImportError:
-    raise ImportError("Onnx and protobuf need to be installed")
+    raise ImportError("Onnx and protobuf need to be installed. Instructions to"
+                      + " install - https://github.com/onnx/onnx#installation")
 import mxnet as mx
 
 # Using these functions for onnx test infrastructure.
 # Implemented by following onnx docs guide:
-# https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md
+# https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
 # MXNetBackendRep object will be returned by MXNetBackend's prepare method which is used to
 # execute a model repeatedly.
 # Inputs will be passed to the run method of MXNetBackendRep class, it will perform computation and
@@ -56,13 +56,14 @@ def run(self, inputs, **kwargs):
         params : numpy array
             result obtained after running the inference on mxnet
         """
-        input_data = np.asarray(inputs[0], dtype='f')
-
+        data_forward = []
+        for val in inputs:
+            data_forward.append(mx.nd.array(val))
         # create module, passing cpu context
         if self.device == 'CPU':
             ctx = mx.cpu()
         else:
-            raise NotImplementedError("Only CPU context is supported for now")
+            raise NotImplementedError("ONNX tests are run only for CPU context.")
 
         # To fetch the data names of the input to the model we list the inputs of the symbol graph
         # and exclude the argument and auxiliary parameters from the list
@@ -80,8 +81,6 @@ def run(self, inputs, **kwargs):
         mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params)
 
         # run inference
-        batch = namedtuple('Batch', ['data'])
-
-        mod.forward(batch([mx.nd.array(input_data)]))
+        mod.forward(mx.io.DataBatch(data_forward))
         result = mod.get_outputs()[0].asnumpy()
         return [result]
diff --git a/tests/python-pytest/onnx/import/mxnet_backend_test.py b/tests/python-pytest/onnx/import/mxnet_backend_test.py
new file mode 100644
index 00000000000..06ce681907f
--- /dev/null
+++ b/tests/python-pytest/onnx/import/mxnet_backend_test.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""ONNX test backend wrapper"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+try:
+    import onnx.backend.test
+except ImportError:
+    raise ImportError("Onnx and protobuf need to be installed. Instructions to"
+                      + " install - https://github.com/onnx/onnx#installation")
+
+import mxnet_backend
+import test_cases
+
+# This is a pytest magic variable to load extra plugins
+pytest_plugins = "onnx.backend.test.report",
+
+BACKEND_TESTS = onnx.backend.test.BackendTest(mxnet_backend, __name__)
+
+for op_tests in test_cases.IMPLEMENTED_OPERATORS_TEST:
+    BACKEND_TESTS.include(op_tests)
+
+for std_model_test in test_cases.STANDARD_MODEL:
+    BACKEND_TESTS.include(std_model_test)
+
+for basic_model_test in test_cases.BASIC_MODEL_TESTS:
+    BACKEND_TESTS.include(basic_model_test)
+
+BACKEND_TESTS.exclude('.*broadcast.*')
+BACKEND_TESTS.exclude('.*bcast.*')
+
+# import all test cases at global scope to make them visible to python.unittest
+globals().update(BACKEND_TESTS.enable_report().test_cases)
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/python-pytest/onnx/onnx_test.py b/tests/python-pytest/onnx/import/onnx_import_test.py
similarity index 68%
rename from tests/python-pytest/onnx/onnx_test.py
rename to tests/python-pytest/onnx/import/onnx_import_test.py
index b3718c9beb8..741ae1febb1 100644
--- a/tests/python-pytest/onnx/onnx_test.py
+++ b/tests/python-pytest/onnx/import/onnx_import_test.py
@@ -39,107 +39,52 @@
 from mxnet.contrib import onnx as onnx_mxnet
 import mxnet as mx
 CURR_PATH = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.insert(0, os.path.join(CURR_PATH, '../../python/unittest'))
+sys.path.insert(0, os.path.join(CURR_PATH, '../../../python/unittest'))
 from common import with_seed
-import backend as mxnet_backend
+import mxnet_backend
 
 
 URLS = {
     'bvlc_googlenet' :
-        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_googlenet.tar.gz',
+        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/opset7/bvlc_googlenet.tar.gz',
     'bvlc_reference_caffenet' :
-        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_reference_caffenet.tar.gz',
+        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/opset7/bvlc_reference_caffenet.tar.gz',
     'bvlc_reference_rcnn_ilsvrc13' :
-        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_reference_rcnn_ilsvrc13.tar.gz',
+        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/opset7/bvlc_reference_rcnn_ilsvrc13.tar.gz',
 }
 
 @with_seed()
-def test_reduce_max():
-    """Test for ReduceMax operator"""
-    node_def = helper.make_node("ReduceMax", ["input1"], ["output"], axes=[1, 0], keepdims=1)
-    input1 = np.random.ranf([3, 10]).astype("float32")
-    output = mxnet_backend.run_node(node_def, [input1])[0]
-    numpy_op = np.max(input1, axis=(1, 0), keepdims=True)
-    npt.assert_almost_equal(output, numpy_op)
+def test_broadcast():
+    """Test for broadcasting in onnx operators."""
+    input1 = np.random.rand(1, 3, 4, 5).astype("float32")
+    input2 = np.random.rand(1, 5).astype("float32")
+    inputs = [helper.make_tensor_value_info("input1", TensorProto.FLOAT, shape=(1, 3, 4, 5)),
+              helper.make_tensor_value_info("input2", TensorProto.FLOAT, shape=(1, 5))]
 
-@with_seed()
-def test_reduce_mean():
-    """Test for ReduceMean operator"""
-    node_def = helper.make_node("ReduceMean", ["input1"], ["output"], axes=[1, 0], keepdims=1)
-    input1 = np.random.ranf([3, 10]).astype("float32")
-    output = mxnet_backend.run_node(node_def, [input1])[0]
-    numpy_op = np.mean(input1, axis=(1, 0), keepdims=True)
-    npt.assert_almost_equal(output, numpy_op, decimal=5)
-
-@with_seed()
-def test_reduce_min():
-    """Test for ReduceMin operator"""
-    node_def = helper.make_node("ReduceMin", ["input1"], ["output"], axes=[1, 0], keepdims=1)
-    input1 = np.random.ranf([3, 10]).astype("float32")
-    output = mxnet_backend.run_node(node_def, [input1])[0]
-    numpy_op = np.min(input1, axis=(1, 0), keepdims=True)
-    npt.assert_almost_equal(output, numpy_op)
+    outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=(1, 3, 4, 5))]
 
-@with_seed()
-def test_reduce_sum():
-    """Test for ReduceSum operator"""
-    node_def = helper.make_node("ReduceSum", ["input1"], ["output"], axes=[1, 0], keepdims=1)
-    input1 = np.random.ranf([3, 10]).astype("float32")
-    output = mxnet_backend.run_node(node_def, [input1])[0]
-    numpy_op = np.sum(input1, axis=(1, 0), keepdims=True)
-    npt.assert_almost_equal(output, numpy_op, decimal=5)
+    nodes = [helper.make_node("Add", ["input1", "input2"], ["output"])]
 
-@with_seed()
-def test_reduce_prod():
-    """Test for ReduceProd operator"""
-    node_def = helper.make_node("ReduceProd", ["input1"], ["output"], axes=[1, 0], keepdims=1)
-    input1 = np.random.ranf([3, 10]).astype("float32")
-    output = mxnet_backend.run_node(node_def, [input1])[0]
-    numpy_op = np.prod(input1, axis=(1, 0), keepdims=True)
-    npt.assert_almost_equal(output, numpy_op, decimal=5)
+    graph = helper.make_graph(nodes,
+                              "bcast_test",
+                              inputs,
+                              outputs)
 
-@with_seed()
-def test_squeeze():
-    """Test for Squeeze operator"""
-    node_def = helper.make_node("Squeeze", ["input1"], ["output"], axes=[1, 3])
-    input1 = np.random.ranf([3, 1, 2, 1, 4]).astype("float32")
-    output = mxnet_backend.run_node(node_def, [input1])[0]
-    npt.assert_almost_equal(output, np.squeeze(input1, axis=[1, 3]))
+    bcast_model = helper.make_model(graph)
+    
+    bkd_rep = mxnet_backend.prepare(bcast_model)
+    numpy_op = input1 + input2
+    output = bkd_rep.run([input1, input2])
+    npt.assert_almost_equal(output[0], numpy_op)
 
 def test_super_resolution_example():
     """Test the super resolution example in the example/onnx folder"""
-    sys.path.insert(0, os.path.join(CURR_PATH, '../../../example/onnx/'))
+    sys.path.insert(0, os.path.join(CURR_PATH, '../../../../example/onnx/'))
     import super_resolution
 
     sym, arg_params, aux_params = super_resolution.import_onnx()
-    assert sym is not None
-    assert arg_params is not None
-
-    inputs = sym.list_inputs()
-    assert len(inputs) == 9
-    for i, input_param in enumerate(['9', '7', '5', '3', '1', '2', '4', '6', '8']):
-        assert inputs[i] == input_param
-
-    assert len(sym.list_outputs()) == 1
-    assert sym.list_outputs()[0] == 'reshape5_output'
-
-    attrs_keys = sym.attr_dict().keys()
-    assert len(attrs_keys) == 23
-    for i, key_item in enumerate(['reshape4', 'convolution2', 'convolution0',
-                                  'transpose0', '6', 'reshape0', 'reshape2',
-                                  'reshape3', '3', 'reshape1', '5', '4', '7',
-                                  'convolution1', '9', '2', 'convolution3',
-                                  'reshape5', '8', 'pad1', 'pad0', 'pad3',
-                                  'pad2']):
-        assert key_item in attrs_keys
-
-    param_keys = arg_params.keys()
-    assert len(param_keys) == 8
-    for i, param_item in enumerate(['3', '2', '5', '4', '7', '6', '9', '8']):
-        assert param_item in param_keys
 
     logging.info("Asserted the result of the onnx model conversion")
-
     output_img_dim = 672
     input_image, img_cb, img_cr = super_resolution.get_test_image()
     result_img = super_resolution.perform_inference(sym, arg_params, aux_params,
diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py
new file mode 100644
index 00000000000..d408930970b
--- /dev/null
+++ b/tests/python-pytest/onnx/import/test_cases.py
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test Cases to be run for the import module"""
+
+IMPLEMENTED_OPERATORS_TEST = [
+    'test_random_uniform',
+    'test_random_normal',
+    'test_add',
+    'test_sub',
+    'test_mul',
+    'test_div',
+    'test_neg',
+    'test_abs',
+    'test_sum',
+    'test_tanh',
+    'test_ceil',
+    'test_floor',
+    'test_concat',
+    'test_sigmoid',
+    'test_relu',
+    'test_constant_pad',
+    'test_edge_pad',
+    'test_reflect_pad',
+    'test_reduce_min',
+    'test_reduce_max',
+    'test_reduce_mean',
+    'test_reduce_prod',
+    'test_squeeze',
+    'test_softmax_example',
+    'test_softmax_large_number',
+    'test_softmax_axis_2',
+    'test_transpose',
+    'test_globalmaxpool',
+    'test_globalaveragepool',
+    'test_slice_cpu',
+    'test_slice_neg',
+    'test_squeeze_',
+    'test_reciprocal',
+    'test_sqrt',
+    'test_pow',
+    'test_exp',
+    'test_argmax',
+    'test_argmin',
+    'test_min',
+    #pytorch operator tests
+    'test_operator_exp',
+    'test_operator_maxpool',
+    'test_operator_params',
+    'test_operator_permute2'
+    ]
+
+BASIC_MODEL_TESTS = [
+    'test_AvgPool2D',
+    'test_BatchNorm',
+    'test_ConstantPad2d',
+    'test_Conv2d',
+    'test_ELU',
+    'test_LeakyReLU',
+    'test_MaxPool',
+    'test_PReLU',
+    'test_ReLU',
+    'test_Sigmoid',
+    'test_Softmax',
+    'test_softmax_functional',
+    'test_softmax_lastdim',
+    'test_Tanh'
+    ]
+
+STANDARD_MODEL = [
+    'test_bvlc_alexnet',
+    'test_densenet121',
+    #'test_inception_v1',
+    #'test_inception_v2',
+    'test_resnet50',
+    #'test_shufflenet',
+    'test_squeezenet',
+    'test_zfnet512',
+    'test_vgg19'
+    ]
diff --git a/tests/python-pytest/onnx/onnx_backend_test.py b/tests/python-pytest/onnx/onnx_backend_test.py
deleted file mode 100644
index 4ea31e5aac9..00000000000
--- a/tests/python-pytest/onnx/onnx_backend_test.py
+++ /dev/null
@@ -1,169 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""ONNX test backend wrapper"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-
-import unittest
-try:
-    import onnx.backend.test
-except ImportError:
-    raise ImportError("Onnx and protobuf need to be installed")
-
-import backend as mxnet_backend
-
-# This is a pytest magic variable to load extra plugins
-pytest_plugins = "onnx.backend.test.report",
-
-BACKEND_TEST = onnx.backend.test.BackendTest(mxnet_backend, __name__)
-
-IMPLEMENTED_OPERATORS_TEST = [
-    #Generator Functions
-    #'test_constant*', # Identity Function
-    #'test_random_uniform',
-    #'test_random_normal',
-
-    #Arithmetic Operators
-    'test_add',
-    'test_sub',
-    'test_mul',
-    'test_div',
-    'test_neg',
-    'test_abs',
-    'test_sum',
-
-    #Hyperbolic functions
-    'test_tanh',
-
-    #Rounding
-    'test_ceil',
-    'test_floor',
-
-    ## Joining and spliting
-    'test_concat',
-
-    #Basic neural network functions
-    'test_sigmoid',
-    'test_relu',
-    'test_constant_pad',
-    'test_edge_pad',
-    'test_reflect_pad',
-    'test_matmul',
-    'test_leakyrelu',
-    'test_elu',
-    'test_softmax_example',
-    'test_softmax_large_number',
-    'test_softmax_axis_2',
-    'test_conv',
-    'test_basic_conv',
-    'test_transpose',
-    'test_globalmaxpool',
-    'test_globalaveragepool',
-    #'test_batch_norm', - tests to be added
-    #'test_gather',
-
-    #Changing shape and type.
-    'test_reshape_',
-    'test_cast',
-    #'test_split',
-    'test_slice_cpu',
-    'test_default_axes', #make PR against onnx to fix the test name(grep-able)
-    'test_slice_neg',
-    #'test_slice_start_out_of_bounds',
-    #'test_slice_end_out_of_bounds',
-    #'test_transpose',
-    'test_squeeze_',
-    'test_flatten_default',
-
-    #Powers
-    'test_reciprocal',
-    'test_sqrt',
-    'test_pow_example',
-    'test_pow_cpu',
-    'test_pow_bcast_cpu',
-    #'test_pow_bcast_axis0',
-    'test_log_',
-    'test_exp',
-
-    # Sorting and Searching
-    'test_argmax',
-    'test_argmin',
-    'test_max',
-    'test_min',
-
-    #pytorch operator tests
-    #'test_operator_chunk',
-    #'test_operator_clip',
-    'test_operator_conv',
-    #'test_operator_equal',
-    'test_operator_exp',
-    #'test_operator_flatten',
-    #'test_operator_max',
-    'test_operator_maxpool',
-    'test_operator_non_float_params',
-    'test_operator_params',
-    'test_operator_permute2',
-    #'test_operator_transpose',
-    #'test_operator_view'
-    ]
-
-BASIC_MODEL_TESTS = [
-    'test_AvgPool2D',
-    'test_BatchNorm',
-    'test_ConstantPad2d'
-    'test_Conv2d',
-    'test_ELU',
-    'test_LeakyReLU',
-    'test_MaxPool',
-    'test_PReLU',
-    'test_ReLU',
-    'test_Sigmoid',
-    'test_Softmax',
-    'test_softmax_functional',
-    'test_softmax_lastdim',
-    'test_Tanh'
-    ]
-
-STANDARD_MODEL = [
-    'test_bvlc_alexnet',
-    'test_densenet121',
-    #'test_inception_v1',
-    #'test_inception_v2',
-    'test_resnet50',
-    #'test_shufflenet',
-    'test_squeezenet',
-    'test_vgg16',
-    'test_vgg19'
-    ]
-
-for op_test in IMPLEMENTED_OPERATORS_TEST:
-    BACKEND_TEST.include(op_test)
-
-for std_model_test in STANDARD_MODEL:
-    BACKEND_TEST.include(std_model_test)
-
-for basic_model_test in BASIC_MODEL_TESTS:
-    BACKEND_TEST.include(basic_model_test)
-
-# import all test cases at global scope to make them visible to python.unittest
-globals().update(BACKEND_TEST.enable_report().test_cases)
-
-if __name__ == '__main__':
-    unittest.main()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services