You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/03/28 00:07:18 UTC

[incubator-mxnet] branch master updated: [MXNET-106][ONNX-MXNET] Adding ONNX Model zoo tests. (#10118)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8c6f6bb  [MXNET-106][ONNX-MXNET] Adding ONNX Model zoo tests. (#10118)
8c6f6bb is described below

commit 8c6f6bb9ce66046fc50fbfa23f198dbbca3f085d
Author: Rajan Singh <33...@users.noreply.github.com>
AuthorDate: Tue Mar 27 17:07:11 2018 -0700

    [MXNET-106][ONNX-MXNET] Adding ONNX Model zoo tests. (#10118)
    
    * [MXNET-106][ONNX-MXNET] Fix Maxpool and BatchNorm operators.
    
    Description:
    Return arg params and aux params separately from import_model
    Fixes:
          1. BatchNorm
          2. Maxpool
          3. GlobalPool
    
     -Added these tests to be included in backend tests
     -Made changes to super_resolution.py to accommodate the changes.
    
    * Fix GEMM operator, add pytorch tests
    
    * Add the support for Gather ONNX operator
    
    * Fix max_pooling
    
    * Resnet fixed. Infer shape issue fixed.
    
    * enable test_gather
    
    * disabling 'gather' test, as it is partially supported
    
    * Minor refactoring for pooling and dropout
    Some comments added.
    Added models for testing.
    
    * Revert "Fix max_pooling"
    Revert last commit to avoid a merge conflict in the future
    
    This reverts commit e4ee423a3afe3c3522c36a8332379a5ddfbad669.
    
    * Nit: lint fixes
    
    * Added flatten default test
    
    * Fixed example test.
    Added transpose test.
    
    * Enabling 'Cast' operator test
    
    * Model tests added:
    - Googlenet
    - Caffenet
    - bvlc_rcnn_ilsvrc13
    
    * Addressed comments on the PR
     - Enabled GlobalAveragePool and GlobalMaxPool
     - Removed 'Gather' support, as it is only partially supported and we don't have unit tests for it.
     - Nits: Added a few comments and changed variable names
    
    * Added user warning for unsupported attributes
    
    * Updates for Imagenet tutorial
    
    * Minor changes in imagenet tutorial.
    Small change in logic for backend test.
    
    * Update tutorial: Add check for whether the directory exists.
---
 docs/tutorials/onnx/inference_on_onnx_model.md     |  16 ++-
 example/onnx/super_resolution.py                   |  12 +-
 python/mxnet/contrib/onnx/_import/import_helper.py |   3 +-
 python/mxnet/contrib/onnx/_import/import_model.py  |   4 +-
 python/mxnet/contrib/onnx/_import/import_onnx.py   |  25 ++--
 .../mxnet/contrib/onnx/_import/op_translations.py  |  56 +++++++--
 .../contrib/onnx/_import/translation_utils.py      |  49 +++++++-
 tests/python-pytest/onnx/backend.py                |  18 ++-
 tests/python-pytest/onnx/backend_rep.py            |   7 +-
 tests/python-pytest/onnx/onnx_backend_test.py      |  83 ++++++++++---
 tests/python-pytest/onnx/onnx_test.py              | 130 ++++++++++++++++++++-
 11 files changed, 335 insertions(+), 68 deletions(-)

diff --git a/docs/tutorials/onnx/inference_on_onnx_model.md b/docs/tutorials/onnx/inference_on_onnx_model.md
index 2eb90cd..182a2ae 100644
--- a/docs/tutorials/onnx/inference_on_onnx_model.md
+++ b/docs/tutorials/onnx/inference_on_onnx_model.md
@@ -21,8 +21,8 @@ To run the tutorial you will need to have installed the following python modules
 
 ```python
 import numpy as np
-import onnx_mxnet
 import mxnet as mx
+from mxnet.contrib import onnx as onnx_mxnet
 from mxnet import gluon, nd
 %matplotlib inline
 import matplotlib.pyplot as plt
@@ -75,7 +75,8 @@ Create the model folder and download the zipped model
 
 
 ```python
-os.makedirs(model_folder, exist_ok=True)
+if not os.path.isdir(model_folder):
+    os.makedirs(model_folder)
 if not os.path.isfile(archive_file):  
     wget.download(url, model_folder)
 ```
@@ -108,7 +109,7 @@ We get the symbol and parameter objects
 
 
 ```python
-sym, params = onnx_mxnet.import_model(onnx_path)
+sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_path)
 ```
 
 We pick a context, CPU or GPU
@@ -124,9 +125,12 @@ And load them into a MXNet Gluon symbol block. For ONNX models the default input
 ```python
 net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('input_0'))
 net_params = net.collect_params()
-for param in params:
+for param in arg_params:
     if param in net_params:
-        net_params[param]._load_init(params[param], ctx=ctx)
+        net_params[param]._load_init(arg_params[param], ctx=ctx)
+for param in aux_params:
+    if param in net_params:
+        net_params[param]._load_init(aux_params[param], ctx=ctx)
 ```
 
 We can now cache the computational graph through [hybridization](https://mxnet.incubator.apache.org/tutorials/gluon/hybrid.html) to gain some performance
@@ -165,7 +169,7 @@ We can visualize the network (requires graphviz installed)
 
 
 ```python
-mx.visualization.plot_network(sym, shape={"input_0":inputs[0].shape}, node_attrs={"shape":"oval","fixedsize":"false"})
+mx.visualization.plot_network(sym,  node_attrs={"shape":"oval","fixedsize":"false"})
 ```
 
 
diff --git a/example/onnx/super_resolution.py b/example/onnx/super_resolution.py
index 1392b77..f7c7886 100644
--- a/example/onnx/super_resolution.py
+++ b/example/onnx/super_resolution.py
@@ -37,9 +37,9 @@ def import_onnx():
     download(model_url, 'super_resolution.onnx')
 
     LOGGER.info("Converting onnx format to mxnet's symbol and params...")
-    sym, params = onnx_mxnet.import_model('super_resolution.onnx')
+    sym, arg_params, aux_params = onnx_mxnet.import_model('super_resolution.onnx')
     LOGGER.info("Successfully Converted onnx format to mxnet's symbol and params...")
-    return sym, params
+    return sym, arg_params, aux_params
 
 def get_test_image():
     """Download and process the test image"""
@@ -53,12 +53,12 @@ def get_test_image():
     input_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]
     return input_image, img_cb, img_cr
 
-def perform_inference(sym, params, input_img, img_cb, img_cr):
+def perform_inference(sym, arg_params, aux_params, input_img, img_cb, img_cr):
     """Perform inference on image using mxnet"""
     # create module
     mod = mx.mod.Module(symbol=sym, data_names=['input_0'], label_names=None)
     mod.bind(for_training=False, data_shapes=[('input_0', input_img.shape)])
-    mod.set_params(arg_params=params, aux_params=None)
+    mod.set_params(arg_params=arg_params, aux_params=aux_params)
 
     # run inference
     batch = namedtuple('Batch', ['data'])
@@ -79,6 +79,6 @@ def perform_inference(sym, params, input_img, img_cb, img_cr):
     return result_img
 
 if __name__ == '__main__':
-    MX_SYM, MX_PARAM = import_onnx()
+    MX_SYM, MX_ARG_PARAM, MX_AUX_PARAM = import_onnx()
     INPUT_IMG, IMG_CB, IMG_CR = get_test_image()
-    perform_inference(MX_SYM, MX_PARAM, INPUT_IMG, IMG_CB, IMG_CR)
+    perform_inference(MX_SYM, MX_ARG_PARAM, MX_AUX_PARAM, INPUT_IMG, IMG_CB, IMG_CR)
diff --git a/python/mxnet/contrib/onnx/_import/import_helper.py b/python/mxnet/contrib/onnx/_import/import_helper.py
index 80541ec..175c2fb 100644
--- a/python/mxnet/contrib/onnx/_import/import_helper.py
+++ b/python/mxnet/contrib/onnx/_import/import_helper.py
@@ -27,7 +27,7 @@ from .op_translations import leaky_relu, _elu, _prelu, softmax, fully_connected
 from .op_translations import global_avgpooling, global_maxpooling, linalg_gemm
 from .op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm
 from .op_translations import dropout, local_response_norm, conv, deconv
-from .op_translations import reshape, cast, split, _slice, transpose, squeeze
+from .op_translations import reshape, cast, split, _slice, transpose, squeeze, flatten
 from .op_translations import reciprocal, squareroot, power, exponent, _log
 from .op_translations import reduce_max, reduce_mean, reduce_min, reduce_sum
 from .op_translations import reduce_prod, avg_pooling, max_pooling
@@ -83,6 +83,7 @@ _convert_map = {
     'Slice'             : _slice,
     'Transpose'         : transpose,
     'Squeeze'           : squeeze,
+    'Flatten'           : flatten,
     #Powers
     'Reciprocal'        : reciprocal,
     'Sqrt'              : squareroot,
diff --git a/python/mxnet/contrib/onnx/_import/import_model.py b/python/mxnet/contrib/onnx/_import/import_model.py
index 1df429b..d8d32a9 100644
--- a/python/mxnet/contrib/onnx/_import/import_model.py
+++ b/python/mxnet/contrib/onnx/_import/import_model.py
@@ -46,5 +46,5 @@ def import_model(model_file):
     except ImportError:
         raise ImportError("Onnx and protobuf need to be installed")
     model_proto = onnx.load(model_file)
-    sym, params = graph.from_onnx(model_proto.graph)
-    return sym, params
+    sym, arg_params, aux_params = graph.from_onnx(model_proto.graph)
+    return sym, arg_params, aux_params
diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py b/python/mxnet/contrib/onnx/_import/import_onnx.py
index 56181c7..037790c 100644
--- a/python/mxnet/contrib/onnx/_import/import_onnx.py
+++ b/python/mxnet/contrib/onnx/_import/import_onnx.py
@@ -61,12 +61,12 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
             raise NotImplementedError("Operator {} not implemented.".format(op_name))
         if isinstance(op_name, string_types):
             new_op = getattr(symbol, op_name, None)
+            if not new_op:
+                raise RuntimeError("Unable to map op_name {} to sym".format(op_name))
             if node_name is None:
                 mxnet_sym = new_op(*inputs, **new_attrs)
             else:
                 mxnet_sym = new_op(name=node_name, *inputs, **new_attrs)
-            if not mxnet_sym:
-                raise RuntimeError("Unable to map op_name {} to sym".format(op_name))
             return mxnet_sym
         return op_name
 
@@ -110,6 +110,10 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
                 self._nodes[name_input] = symbol.Variable(name=name_input)
                 self._renames[i.name] = name_input
 
+        # For storing arg  and aux params for the graph.
+        auxDict = {}
+        argDict = {}
+
         # constructing nodes, nodes are stored as directed acyclic graph
         # converting NodeProto message
         for node in graph.node:
@@ -120,19 +124,24 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
             inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
             mxnet_sym = self._convert_operator(node_name, op_name, onnx_attr, inputs)
 
-            assert len(node.output) == len(mxnet_sym.list_outputs()), (
-                "Output dimension mismatch between the onnx operator and the mxnet symbol " +
-                "{} vs {} for the operator - {}.".format(
-                    len(node.output), len(mxnet_sym.list_outputs()), op_name))
-            for k, i in zip(list(node.output), range(len(node.output))):
+            for k, i in zip(list(node.output), range(len(mxnet_sym.list_outputs()))):
                 self._nodes[k] = mxnet_sym[i]
+
+            # splitting params into args and aux params
+            for args in mxnet_sym.list_arguments():
+                if args in self._params:
+                    argDict.update({args: nd.array(self._params[args])})
+            for aux in mxnet_sym.list_auxiliary_states():
+                if aux in self._params:
+                    auxDict.update({aux: nd.array(self._params[aux])})
+
         # now return the outputs
         out = [self._nodes[i.name] for i in graph.output]
         if len(out) > 1:
             out = symbol.Group(out)
         else:
             out = out[0]
-        return out, self._params
+        return out, argDict, auxDict
 
     def _parse_array(self, tensor_proto):
         """Grab data in TensorProto and convert to numpy array."""
diff --git a/python/mxnet/contrib/onnx/_import/op_translations.py b/python/mxnet/contrib/onnx/_import/op_translations.py
index a67c181..de34132 100644
--- a/python/mxnet/contrib/onnx/_import/op_translations.py
+++ b/python/mxnet/contrib/onnx/_import/op_translations.py
@@ -164,10 +164,14 @@ def matrix_multiplication(attrs, inputs, cls):
 
 def batch_norm(attrs, inputs, cls):
     """Batch normalization."""
-    new_attrs = translation_utils._fix_attribute_names(attrs, {'epsilon' : 'eps'})
+    new_attrs = translation_utils._fix_attribute_names(attrs, {'epsilon' : 'eps',
+                                                               'is_test':'fix_gamma'})
     new_attrs = translation_utils._remove_attributes(new_attrs,
-                                                     ['spatial', 'is_test', 'consumed_inputs'])
+                                                     ['spatial', 'consumed_inputs'])
     new_attrs = translation_utils._add_extra_attributes(new_attrs, {'cudnn_off': 1})
+
+    # in test mode "fix_gamma" should be unset.
+    new_attrs['fix_gamma'] = 0 if new_attrs['fix_gamma'] == 1 else 1
     return 'BatchNorm', new_attrs, inputs
 
 
@@ -245,7 +249,7 @@ def global_maxpooling(attrs, inputs, cls):
     new_attrs = translation_utils._add_extra_attributes(attrs, {'global_pool': True,
                                                                 'kernel': (1, 1),
                                                                 'pool_type': 'max'})
-    return 'pooling', new_attrs, inputs
+    return 'Pooling', new_attrs, inputs
 
 
 def global_avgpooling(attrs, inputs, cls):
@@ -253,28 +257,49 @@ def global_avgpooling(attrs, inputs, cls):
     new_attrs = translation_utils._add_extra_attributes(attrs, {'global_pool': True,
                                                                 'kernel': (1, 1),
                                                                 'pool_type': 'avg'})
-    return 'pooling', new_attrs, inputs
+    return 'Pooling', new_attrs, inputs
 
 
 def linalg_gemm(attrs, inputs, cls):
     """Performs general matrix multiplication and accumulation"""
+    trans_a = 0
+    trans_b = 0
+    alpha = 1
+    beta = 1
+    if 'transA' in attrs:
+        trans_a = attrs['transA']
+    if 'transB' in attrs:
+        trans_b = attrs['transB']
+    if 'alpha' in attrs:
+        alpha = attrs['alpha']
+    if 'beta' in attrs:
+        beta = attrs['beta']
+    flatten_a = symbol.flatten(inputs[0])
+    matmul_op = symbol.linalg_gemm2(A=flatten_a, B=inputs[1],
+                                    transpose_a=trans_a, transpose_b=trans_b,
+                                    alpha=alpha)
+    gemm_op = symbol.broadcast_add(matmul_op, beta*inputs[2])
     new_attrs = translation_utils._fix_attribute_names(attrs, {'transA': 'transpose_a',
                                                                'transB': 'transpose_b'})
     new_attrs = translation_utils._remove_attributes(new_attrs, ['broadcast'])
-    return translation_utils._fix_gemm('FullyConnected', inputs, new_attrs, cls)
+    return gemm_op, new_attrs, inputs
 
-def local_response_norm(op_name, attrs, inputs):
+def local_response_norm(attrs, inputs, cls):
     """Local Response Normalization."""
     new_attrs = translation_utils._fix_attribute_names(attrs,
                                                        {'bias': 'knorm',
                                                         'size' : 'nsize'})
     return 'LRN', new_attrs, inputs
 
-def dropout(op_name, attrs, inputs):
+def dropout(attrs, inputs, cls):
     """Dropout Regularization."""
+    mode = 'training'
+    if attrs['is_test'] == 0:
+        mode = 'always'
     new_attrs = translation_utils._fix_attribute_names(attrs,
                                                        {'ratio': 'p'})
     new_attrs = translation_utils._remove_attributes(new_attrs, ['is_test'])
+    new_attrs = translation_utils._add_extra_attributes(new_attrs, {'mode': mode})
     return 'Dropout', new_attrs, inputs
 
 # Changing shape and type.
@@ -285,6 +310,7 @@ def reshape(attrs, inputs, cls):
 def cast(attrs, inputs, cls):
     """ Cast input to a given dtype"""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'to' : 'dtype'})
+    new_attrs['dtype'] = new_attrs['dtype'].lower()
     return 'cast', new_attrs, inputs
 
 def split(attrs, inputs, cls):
@@ -328,6 +354,15 @@ def squeeze(attrs, inputs, cls):
         mxnet_op = symbol.split(mxnet_op, axis=i-1, num_outputs=1, squeeze_axis=1)
     return mxnet_op, new_attrs, inputs
 
+
+def flatten(attrs, inputs, cls):
+    """Flattens the input array into a 2-D array by collapsing the higher dimensions."""
+    #Mxnet does not have axis support. By default uses axis=1
+    if 'axis' in attrs and attrs['axis'] != 1:
+        raise RuntimeError("Flatten operator only supports axis=1")
+    new_attrs = translation_utils._remove_attributes(attrs, ['axis'])
+    return 'Flatten', new_attrs, inputs
+
 #Powers
 def reciprocal(attrs, inputs, cls):
     """Returns the reciprocal of the argument, element-wise."""
@@ -387,8 +422,7 @@ def avg_pooling(attrs, inputs, cls):
                                                         'pads': 'pad',
                                                        })
     new_attrs = translation_utils._add_extra_attributes(new_attrs,
-                                                        {'pool_type': 'avg',
-                                                         'pooling_convention': 'valid'
+                                                        {'pooling_convention': 'valid'
                                                         })
     new_op = translation_utils._fix_pooling('avg', inputs, new_attrs)
 
@@ -402,9 +436,9 @@ def max_pooling(attrs, inputs, cls):
                                                         'strides': 'stride',
                                                         'pads': 'pad',
                                                        })
+
     new_attrs = translation_utils._add_extra_attributes(new_attrs,
-                                                        {'pool_type': 'avg',
-                                                         'pooling_convention': 'valid'
+                                                        {'pooling_convention': 'valid'
                                                         })
     new_op = translation_utils._fix_pooling('max', inputs, new_attrs)
 
diff --git a/python/mxnet/contrib/onnx/_import/translation_utils.py b/python/mxnet/contrib/onnx/_import/translation_utils.py
index 0fdef64..1d84bd7 100644
--- a/python/mxnet/contrib/onnx/_import/translation_utils.py
+++ b/python/mxnet/contrib/onnx/_import/translation_utils.py
@@ -90,10 +90,51 @@ def _fix_pooling(pool_type, inputs, new_attr):
     stride = new_attr.get('stride')
     kernel = new_attr.get('kernel')
     padding = new_attr.get('pad')
-    pad_width = (0, 0, 0, 0) + _pad_sequence_fix(padding, len(kernel))
-    new_pad_op = symbol.pad(inputs[0], mode='constant', pad_width=pad_width)
-    new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type,
-                                    stride=stride, kernel=kernel)
+
+    # Adding default stride.
+    if stride is None:
+        stride = (1,) * len(kernel)
+
+    # Add padding attr if not provided.
+    if padding is None:
+        padding = (0,) * len(kernel) * 2
+
+    # Mxnet Pad operator supports only 4D/5D tensors.
+    # For 1D case, these are the steps:
+    #    Step 1. Add extra dummy dimension to make it 4D. Adding to  axis = 2
+    #    Step 2. Apply padding to this changed tensor
+    #    Step 3. Remove the extra dimension added in step 1.
+    if len(kernel) == 1:
+        dummy_axis = 2
+        # setting 0 padding to the new dim to be added.
+        padding = (0, padding[0], 0, padding[1])
+        pad_width = (0, 0, 0, 0) + _pad_sequence_fix(padding, kernel_dim=2)
+
+        # Step 1.
+        curr_sym = symbol.expand_dims(inputs[0], axis=dummy_axis)
+
+        # Step 2. Common for all tensor sizes
+        new_pad_op = symbol.pad(curr_sym, mode='edge', pad_width=pad_width)
+
+        # Step 3: Removing extra dim added.
+        new_pad_op = symbol.split(new_pad_op, axis=dummy_axis, num_outputs=1, squeeze_axis=1)
+    else:
+        # For 2D/3D cases:
+        # Apply padding
+        pad_width = (0, 0, 0, 0) + _pad_sequence_fix(padding, kernel_dim=len(kernel))
+        curr_sym = inputs[0]
+
+        if pool_type == 'max':
+            # For max pool : mode = 'edge', we should replicate the
+            # edge values to pad, so that we only include  input data values
+            # for calculating 'max'
+            new_pad_op = symbol.pad(curr_sym, mode='edge', pad_width=pad_width)
+        else:
+            # For avg pool, we should add 'zeros' for padding  so mode='constant'
+            new_pad_op = symbol.pad(curr_sym, mode='constant', pad_width=pad_width)
+
+    # Apply pooling without pads.
+    new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel)
     return new_pooling_op
 
 def _fix_bias(op_name, attrs, num_inputs):
diff --git a/tests/python-pytest/onnx/backend.py b/tests/python-pytest/onnx/backend.py
index 3b99563..0e0a6a6 100644
--- a/tests/python-pytest/onnx/backend.py
+++ b/tests/python-pytest/onnx/backend.py
@@ -94,12 +94,15 @@ class MXNetBackend(Backend):
             result obtained after running the operator
         """
         graph = GraphProto()
-        sym, _ = graph.from_onnx(MXNetBackend.make_graph(node, inputs))
-        data_names = [i for i in sym.get_internals().list_inputs()]
+        sym, arg_params, aux_params = graph.from_onnx(MXNetBackend.make_graph(node, inputs))
+        data_names = [graph_input for graph_input in sym.list_inputs()
+                      if graph_input not in arg_params and graph_input not in aux_params]
         data_shapes = []
         dim_change_op_types = set(['ReduceMin', 'ReduceMax', 'ReduceMean',
                                    'ReduceProd', 'ReduceSum', 'Slice', 'Pad',
-                                   'Squeeze', 'Upsample', 'Reshape', 'Conv'])
+                                   'Squeeze', 'Upsample', 'Reshape', 'Conv',
+                                   'Concat', 'Softmax', 'Flatten', 'Transpose',
+                                   'GlobalAveragePool', 'GlobalMaxPool'])
 
         # Adding extra dimension of batch_size 1 if the batch_size is different for multiple inputs.
         for idx, input_name in enumerate(data_names):
@@ -123,7 +126,10 @@ class MXNetBackend(Backend):
         mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
 
         # initializing parameters for calculating result of each individual node
-        mod.init_params()
+        if arg_params is None and aux_params is None:
+            mod.init_params()
+        else:
+            mod.set_params(arg_params=arg_params, aux_params=aux_params)
 
         data_forward = []
         for idx, input_name in enumerate(data_names):
@@ -162,8 +168,8 @@ class MXNetBackend(Backend):
             used to run inference on the input model and return the result for comparison.
         """
         graph = GraphProto()
-        sym, params = graph.from_onnx(model.graph)
-        return MXNetBackendRep(sym, params, device)
+        sym, arg_params, aux_params = graph.from_onnx(model.graph)
+        return MXNetBackendRep(sym, arg_params, aux_params, device)
 
     @classmethod
     def supports_device(cls, device):
diff --git a/tests/python-pytest/onnx/backend_rep.py b/tests/python-pytest/onnx/backend_rep.py
index a125086..47ea6c1 100644
--- a/tests/python-pytest/onnx/backend_rep.py
+++ b/tests/python-pytest/onnx/backend_rep.py
@@ -37,9 +37,10 @@ import mxnet as mx
 class MXNetBackendRep(BackendRep):
     """Running model inference on mxnet engine and return the result
      to onnx test infrastructure for comparison."""
-    def __init__(self, symbol, params, device):
+    def __init__(self, symbol, arg_params, aux_params, device):
         self.symbol = symbol
-        self.params = params
+        self.arg_params = arg_params
+        self.aux_params = aux_params
         self.device = device
 
     def run(self, inputs, **kwargs):
@@ -67,7 +68,7 @@ class MXNetBackendRep(BackendRep):
                             label_names=None)
         mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)],
                  label_shapes=None)
-        mod.set_params(arg_params=self.params, aux_params=None)
+        mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params)
 
         # run inference
         batch = namedtuple('Batch', ['data'])
diff --git a/tests/python-pytest/onnx/onnx_backend_test.py b/tests/python-pytest/onnx/onnx_backend_test.py
index 28e2aae..4ea31e5 100644
--- a/tests/python-pytest/onnx/onnx_backend_test.py
+++ b/tests/python-pytest/onnx/onnx_backend_test.py
@@ -34,7 +34,7 @@ pytest_plugins = "onnx.backend.test.report",
 
 BACKEND_TEST = onnx.backend.test.BackendTest(mxnet_backend, __name__)
 
-IMPLEMENTED_OPERATORS = [
+IMPLEMENTED_OPERATORS_TEST = [
     #Generator Functions
     #'test_constant*', # Identity Function
     #'test_random_uniform',
@@ -57,37 +57,40 @@ IMPLEMENTED_OPERATORS = [
     'test_floor',
 
     ## Joining and spliting
-    #'test_concat.*',  #---Failing test
+    'test_concat',
 
     #Basic neural network functions
     'test_sigmoid',
     'test_relu',
-    #'test_constant_pad',
-    #'test_edge_pad',
-    #'test_reflect_pad',
+    'test_constant_pad',
+    'test_edge_pad',
+    'test_reflect_pad',
     'test_matmul',
     'test_leakyrelu',
     'test_elu',
-    #'test_softmax*',
+    'test_softmax_example',
+    'test_softmax_large_number',
+    'test_softmax_axis_2',
     'test_conv',
     'test_basic_conv',
-    #'test_globalmaxpool',
-    #'test_globalaveragepool',
-    #'test_batch_norm',
+    'test_transpose',
+    'test_globalmaxpool',
+    'test_globalaveragepool',
+    #'test_batch_norm', - tests to be added
+    #'test_gather',
 
     #Changing shape and type.
     'test_reshape_',
-    #'test_AvgPool2D*',
-    #'test_MaxPool2D*',
-    #'test_cast',
+    'test_cast',
     #'test_split',
     'test_slice_cpu',
     'test_default_axes', #make PR against onnx to fix the test name(grep-able)
     'test_slice_neg',
     #'test_slice_start_out_of_bounds',
     #'test_slice_end_out_of_bounds',
-    #'test_transpose*',
+    #'test_transpose',
     'test_squeeze_',
+    'test_flatten_default',
 
     #Powers
     'test_reciprocal',
@@ -103,12 +106,62 @@ IMPLEMENTED_OPERATORS = [
     'test_argmax',
     'test_argmin',
     'test_max',
-    'test_min'
+    'test_min',
+
+    #pytorch operator tests
+    #'test_operator_chunk',
+    #'test_operator_clip',
+    'test_operator_conv',
+    #'test_operator_equal',
+    'test_operator_exp',
+    #'test_operator_flatten',
+    #'test_operator_max',
+    'test_operator_maxpool',
+    'test_operator_non_float_params',
+    'test_operator_params',
+    'test_operator_permute2',
+    #'test_operator_transpose',
+    #'test_operator_view'
     ]
 
-for op_test in IMPLEMENTED_OPERATORS:
+BASIC_MODEL_TESTS = [
+    'test_AvgPool2D',
+    'test_BatchNorm',
+    'test_ConstantPad2d'
+    'test_Conv2d',
+    'test_ELU',
+    'test_LeakyReLU',
+    'test_MaxPool',
+    'test_PReLU',
+    'test_ReLU',
+    'test_Sigmoid',
+    'test_Softmax',
+    'test_softmax_functional',
+    'test_softmax_lastdim',
+    'test_Tanh'
+    ]
+
+STANDARD_MODEL = [
+    'test_bvlc_alexnet',
+    'test_densenet121',
+    #'test_inception_v1',
+    #'test_inception_v2',
+    'test_resnet50',
+    #'test_shufflenet',
+    'test_squeezenet',
+    'test_vgg16',
+    'test_vgg19'
+    ]
+
+for op_test in IMPLEMENTED_OPERATORS_TEST:
     BACKEND_TEST.include(op_test)
 
+for std_model_test in STANDARD_MODEL:
+    BACKEND_TEST.include(std_model_test)
+
+for basic_model_test in BASIC_MODEL_TESTS:
+    BACKEND_TEST.include(basic_model_test)
+
 # import all test cases at global scope to make them visible to python.unittest
 globals().update(BACKEND_TEST.enable_report().test_cases)
 
diff --git a/tests/python-pytest/onnx/onnx_test.py b/tests/python-pytest/onnx/onnx_test.py
index 016490a..ddc633e 100644
--- a/tests/python-pytest/onnx/onnx_test.py
+++ b/tests/python-pytest/onnx/onnx_test.py
@@ -21,19 +21,37 @@ This module contains operator tests which currently do not exist on
 ONNX backend test framework. Once we have PRs on the ONNX repo and get
 those PRs merged, this file will get EOL'ed.
 """
+# pylint: disable=too-many-locals,wrong-import-position,import-error
 from __future__ import absolute_import
 import sys
 import os
 import unittest
 import logging
 import hashlib
+import tarfile
+from collections import namedtuple
 import numpy as np
 import numpy.testing as npt
 from onnx import helper
-import backend as mxnet_backend
+from onnx import numpy_helper
+from onnx import TensorProto
+from mxnet.test_utils import download
+from mxnet.contrib import onnx as onnx_mxnet
+import mxnet as mx
 CURR_PATH = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(CURR_PATH, '../../python/unittest'))
 from common import with_seed
+import backend as mxnet_backend
+
+
+URLS = {
+    'bvlc_googlenet' :
+        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_googlenet.tar.gz',
+    'bvlc_reference_caffenet' :
+        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_reference_caffenet.tar.gz',
+    'bvlc_reference_rcnn_ilsvrc13' :
+        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_reference_rcnn_ilsvrc13.tar.gz',
+}
 
 @with_seed()
 def test_reduce_max():
@@ -93,9 +111,9 @@ def test_super_resolution_example():
     sys.path.insert(0, os.path.join(CURR_PATH, '../../../example/onnx/'))
     import super_resolution
 
-    sym, params = super_resolution.import_onnx()
+    sym, arg_params, aux_params = super_resolution.import_onnx()
     assert sym is not None
-    assert params is not None
+    assert arg_params is not None
 
     inputs = sym.list_inputs()
     assert len(inputs) == 9
@@ -116,7 +134,7 @@ def test_super_resolution_example():
                                   'transpose0']):
         assert key_item in attrs_keys
 
-    param_keys = params.keys()
+    param_keys = arg_params.keys()
     assert len(param_keys) == 8
     for i, param_item in enumerate(['param_5', 'param_4', 'param_7', 'param_6',
                                     'param_1', 'param_0', 'param_3', 'param_2']):
@@ -126,11 +144,111 @@ def test_super_resolution_example():
 
     output_img_dim = 672
     input_image, img_cb, img_cr = super_resolution.get_test_image()
-    result_img = super_resolution.perform_inference(sym, params, input_image,
-                                                    img_cb, img_cr)
+    result_img = super_resolution.perform_inference(sym, arg_params, aux_params,
+                                                    input_image, img_cb, img_cr)
 
     assert hashlib.md5(result_img.tobytes()).hexdigest() == '0d98393a49b1d9942106a2ed89d1e854'
     assert result_img.size == (output_img_dim, output_img_dim)
 
+def get_test_files(name):
+    """Extract tar file and returns model path and input, output data"""
+    tar_name = download(URLS.get(name), dirname=CURR_PATH.__str__())
+    # extract tar file
+    tar_path = os.path.join(CURR_PATH, tar_name)
+    tar = tarfile.open(tar_path.__str__(), "r:*")
+    tar.extractall(path=CURR_PATH.__str__())
+    tar.close()
+    data_dir = os.path.join(CURR_PATH, name)
+    model_path = os.path.join(data_dir, 'model.onnx')
+
+    inputs = []
+    outputs = []
+    # get test files
+    for test_file in os.listdir(data_dir):
+        case_dir = os.path.join(data_dir, test_file)
+        # skip the non-dir files
+        if not os.path.isdir(case_dir):
+            continue
+        input_file = os.path.join(case_dir, 'input_0.pb')
+        input_tensor = TensorProto()
+        with open(input_file, 'rb') as proto_file:
+            input_tensor.ParseFromString(proto_file.read())
+        inputs.append(numpy_helper.to_array(input_tensor))
+
+        output_tensor = TensorProto()
+        output_file = os.path.join(case_dir, 'output_0.pb')
+        with open(output_file, 'rb') as proto_file:
+            output_tensor.ParseFromString(proto_file.read())
+        outputs.append(numpy_helper.to_array(output_tensor))
+
+    return model_path, inputs, outputs
+
+def test_bvlc_googlenet():
+    """ Tests Googlenet model"""
+    model_path, inputs, outputs = get_test_files('bvlc_googlenet')
+    logging.info("Translating Googlenet model from ONNX to Mxnet")
+    sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
+
+    # run test for each test file
+    for input_data, output_data in zip(inputs, outputs):
+        # create module
+        mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
+        mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None)
+        mod.set_params(arg_params=arg_params, aux_params=aux_params,
+                       allow_missing=True, allow_extra=True)
+        # run inference
+        batch = namedtuple('Batch', ['data'])
+        mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
+
+        # verify the results
+        npt.assert_equal(mod.get_outputs()[0].shape, output_data.shape)
+        npt.assert_almost_equal(output_data, mod.get_outputs()[0].asnumpy(), decimal=3)
+    logging.info("Googlenet model conversion Successful")
+
+def test_bvlc_reference_caffenet():
+    """Tests the bvlc cafenet model"""
+    model_path, inputs, outputs = get_test_files('bvlc_reference_caffenet')
+    logging.info("Translating Caffenet model from ONNX to Mxnet")
+    sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
+
+    # run test for each test file
+    for input_data, output_data in zip(inputs, outputs):
+        # create module
+        mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
+        mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None)
+        mod.set_params(arg_params=arg_params, aux_params=aux_params,
+                       allow_missing=True, allow_extra=True)
+        # run inference
+        batch = namedtuple('Batch', ['data'])
+        mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
+
+        # verify the results
+        npt.assert_equal(mod.get_outputs()[0].shape, output_data.shape)
+        npt.assert_almost_equal(output_data, mod.get_outputs()[0].asnumpy(), decimal=3)
+    logging.info("Caffenet model conversion Successful")
+
+def test_bvlc_rcnn_ilsvrc13():
+    """Tests the bvlc rcnn model"""
+    model_path, inputs, outputs = get_test_files('bvlc_reference_rcnn_ilsvrc13')
+    logging.info("Translating rcnn_ilsvrc13 model from ONNX to Mxnet")
+    sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
+
+    # run test for each test file
+    for input_data, output_data in zip(inputs, outputs):
+        # create module
+        mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
+        mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None)
+        mod.set_params(arg_params=arg_params, aux_params=aux_params,
+                       allow_missing=True, allow_extra=True)
+        # run inference
+        batch = namedtuple('Batch', ['data'])
+        mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
+
+        # verify the results
+        npt.assert_equal(mod.get_outputs()[0].shape, output_data.shape)
+        npt.assert_almost_equal(output_data, mod.get_outputs()[0].asnumpy(), decimal=3)
+    logging.info("rcnn_ilsvrc13 model conversion Successful")
+
+
 if __name__ == '__main__':
     unittest.main()

-- 
To stop receiving notification emails like this one, please contact
skm@apache.org.