You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by th...@apache.org on 2018/11/30 21:33:30 UTC

[incubator-mxnet] branch master updated: Add resiliency to onnx export code (#13426)

This is an automated email from the ASF dual-hosted git repository.

thomasdelteil pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b58de74  Add resiliency to onnx export code (#13426)
b58de74 is described below

commit b58de7494d9bf329af9730da91b5f6c21348cbff
Author: Sina Afrooze <si...@gmail.com>
AuthorDate: Fri Nov 30 13:33:16 2018 -0800

    Add resiliency to onnx export code (#13426)
    
    * Added resiliency to onnx export code
    
    - With the previous infer-shape implementation, the export code still worked when an input shape was given as a list instead of a tuple, or when extra parameters that do not exist in the symbol were provided. The fixes in this commit restore that behavior to prevent compatibility issues with existing export code.
    
    * Fixed name of net in unittest
    
    * Fix pylint
---
 python/mxnet/contrib/onnx/mx2onnx/export_onnx.py    |  5 +++--
 .../python-pytest/onnx/export/mxnet_export_test.py  | 21 +++++++++++++++++++--
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index 14c674f..84db5de 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -134,9 +134,10 @@ class MXNetGraph(object):
         # remove any input listed in params from sym.list_inputs() and bind them to the input shapes provided
         # by user. Also remove in_label, which is the name of the label symbol that may have been used
         # as the label for loss during training.
-        inputs = {n: s for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_shape)}
+        inputs = {n: tuple(s) for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label],
+                                              in_shape)}
         # Add params and their shape to list of inputs
-        inputs.update({n: v.shape for n, v in params.items()})
+        inputs.update({n: v.shape for n, v in params.items() if n in sym.list_inputs()})
         # Provide input data as well as input params to infer_shape()
         _, out_shapes, _ = sym.infer_shape(**inputs)
 
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py
index f4144fd6..964d0e7 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -286,18 +286,19 @@ def _optional_group(symbols, group=False):
         return symbols
 
 
-def _check_onnx_export(net, group_outputs=False):
+def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params={}):
     net.initialize()
     data = nd.random.uniform(0, 1, (1, 1024))
     output = _force_list(net(data))  # initialize weights
     net_sym = _optional_group(net(sym.Variable('data')), group_outputs)
     net_params = {name:param._reduce() for name, param in net.collect_params().items()}
+    net_params.update(extra_params)
     with tempfile.TemporaryDirectory() as tmpdirname:
         onnx_file_path = os.path.join(tmpdirname, 'net.onnx')
         export_path = onnx_mxnet.export_model(
             sym=net_sym,
             params=net_params,
-            input_shape=[data.shape],
+            input_shape=[shape_type(data.shape)],
             onnx_file_path=onnx_file_path)
         assert export_path == onnx_file_path
         # Try importing the model to symbol
@@ -340,6 +341,22 @@ def test_onnx_export_multi_output():
     _check_onnx_export(net, group_outputs=True)
 
 
+@with_seed()
+def test_onnx_export_list_shape():
+    net = nn.HybridSequential(prefix='list_shape_net')
+    with net.name_scope():
+        net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
+    _check_onnx_export(net, shape_type=list)
+
+
+@with_seed()
+def test_onnx_export_extra_params():
+    net = nn.HybridSequential(prefix='extra_params_net')
+    with net.name_scope():
+        net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
+    _check_onnx_export(net, extra_params={'extra_param': nd.array([1, 2])})
+
+
 if __name__ == '__main__':
     test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
     test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))