You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by th...@apache.org on 2018/11/30 21:33:30 UTC
[incubator-mxnet] branch master updated: Add resiliency to onnx
export code (#13426)
This is an automated email from the ASF dual-hosted git repository.
thomasdelteil pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new b58de74 Add resiliency to onnx export code (#13426)
b58de74 is described below
commit b58de7494d9bf329af9730da91b5f6c21348cbff
Author: Sina Afrooze <si...@gmail.com>
AuthorDate: Fri Nov 30 13:33:16 2018 -0800
Add resiliency to onnx export code (#13426)
* Added resiliency to ONNX export code
- With the previous infer-shape implementation, the export code still worked if the input shape was given as a list instead of a tuple, or if extra parameters that do not exist in the symbol were provided. The fixes in this commit restore that behavior to prevent compatibility issues with existing export code.
* Fixed name of net in unit test
* Fix pylint
---
python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 5 +++--
.../python-pytest/onnx/export/mxnet_export_test.py | 21 +++++++++++++++++++--
2 files changed, 22 insertions(+), 4 deletions(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index 14c674f..84db5de 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -134,9 +134,10 @@ class MXNetGraph(object):
# remove any input listed in params from sym.list_inputs() and bind them to the input shapes provided
# by user. Also remove in_label, which is the name of the label symbol that may have been used
# as the label for loss during training.
- inputs = {n: s for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_shape)}
+ inputs = {n: tuple(s) for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label],
+ in_shape)}
# Add params and their shape to list of inputs
- inputs.update({n: v.shape for n, v in params.items()})
+ inputs.update({n: v.shape for n, v in params.items() if n in sym.list_inputs()})
# Provide input data as well as input params to infer_shape()
_, out_shapes, _ = sym.infer_shape(**inputs)
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py
index f4144fd6..964d0e7 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -286,18 +286,19 @@ def _optional_group(symbols, group=False):
return symbols
-def _check_onnx_export(net, group_outputs=False):
+def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params={}):
net.initialize()
data = nd.random.uniform(0, 1, (1, 1024))
output = _force_list(net(data)) # initialize weights
net_sym = _optional_group(net(sym.Variable('data')), group_outputs)
net_params = {name:param._reduce() for name, param in net.collect_params().items()}
+ net_params.update(extra_params)
with tempfile.TemporaryDirectory() as tmpdirname:
onnx_file_path = os.path.join(tmpdirname, 'net.onnx')
export_path = onnx_mxnet.export_model(
sym=net_sym,
params=net_params,
- input_shape=[data.shape],
+ input_shape=[shape_type(data.shape)],
onnx_file_path=onnx_file_path)
assert export_path == onnx_file_path
# Try importing the model to symbol
@@ -340,6 +341,22 @@ def test_onnx_export_multi_output():
_check_onnx_export(net, group_outputs=True)
+@with_seed()
+def test_onnx_export_list_shape():
+ net = nn.HybridSequential(prefix='list_shape_net')
+ with net.name_scope():
+ net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
+ _check_onnx_export(net, shape_type=list)
+
+
+@with_seed()
+def test_onnx_export_extra_params():
+ net = nn.HybridSequential(prefix='extra_params_net')
+ with net.name_scope():
+ net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
+ _check_onnx_export(net, extra_params={'extra_param': nd.array([1, 2])})
+
+
if __name__ == '__main__':
test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))