You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/27 20:23:33 UTC

[GitHub] ptrendx closed pull request #12398: Add support for networks with multiple outputs in ONNX exporter

ptrendx closed pull request #12398: Add support for networks with multiple outputs in ONNX exporter
URL: https://github.com/apache/incubator-mxnet/pull/12398
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not otherwise display it for merged fork PRs:

diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index b02d970f9c2..48e59aedf58 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -228,8 +228,14 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
         # Deriving the output_label name.
         output_label = sym.get_internals()[len(sym.get_internals()) - 1].name + "_label"
 
-        # Determine output shape
-        output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, output_label)
+        # Determine outputs shapes
+        input_names = [n for n in sym.list_inputs() if n not in params]
+        input_pairs = {n: in_shape[i] for i, n in enumerate(input_names)}
+        _, output_shapes, _ = sym.get_internals().infer_shape(**input_pairs)
+
+        output_suffix = '_output'
+        output_names = [
+            o[:-len(output_suffix)] for o in sym.list_outputs() if o.endswith(output_suffix)]
 
         weights = MXNetGraph.convert_weights_to_numpy(params)
 
@@ -265,6 +271,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
                     weights=weights,
                     in_shape=in_shape[graph_input_idx],
                     in_type=in_type,
+                    out_shape=output_shapes[idx],
                     proc_nodes=all_processed_nodes,
                     initializer=initializer,
                     index_lookup=index_lookup)
@@ -279,6 +286,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
                     weights=weights,
                     in_shape=in_shape,
                     in_type=in_type,
+                    out_shape=output_shapes[idx],
                     proc_nodes=all_processed_nodes,
                     initializer=initializer,
                     index_lookup=index_lookup,
@@ -294,14 +302,14 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
                     # If converted node is NodeProto, add it in processed nodes list
                     elif isinstance(converted_node, NodeProto):
                         onnx_processed_nodes.append(converted_node)
-                        if idx == (len(mx_graph) - 1):
+                        if idx == (len(mx_graph) - 1) or converted_node.name in output_names:
                             # If converted node doesnt have name, use it from output field
                             if not converted_node.name:
                                 onnx_processed_outputs.append(
                                     make_tensor_value_info(
                                         name=converted_node.output[0],
                                         elem_type=in_type,
-                                        shape=output_shape
+                                        shape=output_shapes[idx]
                                     )
                                 )
                             else:
@@ -309,7 +317,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
                                     make_tensor_value_info(
                                         name=converted_node.name,
                                         elem_type=in_type,
-                                        shape=output_shape
+                                        shape=output_shapes[idx]
                                     )
                                 )
                             if verbose:
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py
index 9f91369d667..b007d174035 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -101,10 +101,10 @@ def forward_pass(sym, arg, aux, data_names, input_data):
     batch = namedtuple('Batch', ['data'])
     mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
 
-    return mod.get_outputs()[0].asnumpy()
+    return [output.asnumpy() for output in mod.get_outputs()]
 
 
-def test_models(model_name, input_shape, output_shape):
+def test_models(model_name, input_shape, output_shape, test_extra_output=False):
     """ Tests Googlenet model for both onnx import and export"""
     model_path, inputs, outputs = get_test_files(model_name)
     logging.info("Translating model from ONNX model zoo to Mxnet")
@@ -117,6 +117,12 @@ def test_models(model_name, input_shape, output_shape):
     new_model_name = "exported_" + model_name + ".onnx"
     onnx_file = os.path.join(dir_path, new_model_name)
 
+    if test_extra_output:
+        logging.info("Adding extra output to model")
+        sym_output = sym.get_internals()[sym.list_outputs()[0]]
+        id_output = mx.sym.identity(data=sym_output)
+        sym = mx.symbol.Group([sym_output, id_output])
+
     logging.info("Translating converted model from mxnet to ONNX")
     converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32, onnx_file)
 
@@ -133,11 +139,11 @@ def test_models(model_name, input_shape, output_shape):
     logging.info("Running inference on onnx re-import model in mxnet")
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
-
-        # verify the results
-        npt.assert_equal(result.shape, output_data.shape)
-        npt.assert_almost_equal(output_data, result, decimal=3)
+        results = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        for result in results:
+            # verify the results
+            npt.assert_equal(result.shape, output_data.shape)
+            npt.assert_almost_equal(output_data, result, decimal=3)
     logging.info(model_name + " conversion successful")
 
 
@@ -153,7 +159,7 @@ def test_model_accuracy(model_name, input_shape):
 
     expected_result= []
     for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)[0]
         expected_result.append(result)
 
     params = {}
@@ -175,7 +181,7 @@ def test_model_accuracy(model_name, input_shape):
 
     actual_result = []
     for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)[0]
         actual_result.append(result)
 
     # verify the results
@@ -232,7 +238,7 @@ def test_square():
     converted_model = onnx_mxnet.export_model(square, params, [np.shape(input1)], np.float32, "square.onnx")
 
     sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
-    result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
+    result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)[0]
 
     numpy_op = np.square(input1)
 
@@ -242,6 +248,7 @@ def test_square():
     test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
     test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))
     test_models("bvlc_reference_rcnn_ilsvrc13", (1, 3, 224, 224), (1, 200))
+    test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000), test_extra_output=True)
 
     # Comparing MXNet inference result, since MXNet results don't match
     # ONNX expected results due to AveragePool issue github issue(#10194)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services