You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/26 08:00:11 UTC

[GitHub] ThomasDelteil closed pull request #13390: Onnx multi output

ThomasDelteil closed pull request #13390: Onnx multi output
URL: https://github.com/apache/incubator-mxnet/pull/13390
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index b02d970f9c2..14c674f56f2 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -53,12 +53,8 @@
 from __future__ import unicode_literals
 import logging
 import json
-import numpy as np
 
-from .... import context
 from .... import ndarray as nd
-from .... import io
-from .... import module as mod
 
 
 class MXNetGraph(object):
@@ -95,60 +91,6 @@ def convert_layer(node, **kwargs):
         convert_func = MXNetGraph.registry_[op]
         return convert_func(node, **kwargs)
 
-    @staticmethod
-    def forward_pass(inputs, sym, arg_params, aux_params, output_label):
-        """Do a forward pass based on the sym and params to get the shape
-        of the output using dummy data
-
-        Parameters
-        ----------
-        inputs   : json string
-
-        sym : :class:`~mxnet.symbol.Symbol`
-            MXNet symbol object
-        arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
-            Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
-        aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
-            Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
-
-        Returns
-        -------
-        shape : Shape
-            Output shape
-        """
-        # if label is not provided, MXNet adds label "softmax_label" by default
-        # while running load_checkpoint which is not actually a graph input. So ignoring it here
-        data_names = [graph_input for graph_input in sym.list_inputs()
-                      if graph_input not in arg_params and graph_input not in aux_params
-                      and graph_input != output_label]
-
-        data_shapes = []
-        # Adding extra dimension of batch_size 1 if the batch_size is different for multiple inputs.
-        for idx, input_name in enumerate(data_names):
-            data_shapes.append((input_name, inputs[idx].shape))
-
-        # create module, passing cpu context
-        ctx = context.cpu()
-        test_mod = mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None)
-        test_mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
-
-        # initializing parameters for calculating result of each individual node
-        if arg_params is None and aux_params is None:
-            test_mod.init_params()
-        else:
-            test_mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True)
-
-        data_forward = []
-        for idx, input_name in enumerate(data_names):
-            val = inputs[idx]
-            data_forward.append(nd.array(val))
-
-        test_mod.forward(io.DataBatch(data_forward))
-        result = test_mod.get_outputs()[0].asnumpy()
-
-        return result.shape
-
-
     @staticmethod
     def split_params(sym, params):
         """Helper function to split params dictionary into args and aux params
@@ -177,15 +119,40 @@ def split_params(sym, params):
                 aux_params.update({aux: nd.array(params[aux])})
         return arg_params, aux_params
 
-
     @staticmethod
-    def infer_output_shape(sym, params, in_shape, output_label):
-        """Infer output shape by doing a forward pass using dummy inputs """
-        # create dummy input
-        inputs = [np.random.randn(*input_shape) for input_shape in in_shape]
-        arg, aux = MXNetGraph.split_params(sym, params)
-        return MXNetGraph.forward_pass(inputs, sym, arg, aux, output_label)
+    def get_outputs(sym, params, in_shape, in_label):
+        """ Infer output shapes and return dictionary of output name to shape
+
+        :param :class:`~mxnet.symbol.Symbol` sym: symbol to perform infer shape on
+        :param dic of (str, nd.NDArray) params:
+        :param list of tuple(int, ...) in_shape: list of all input shapes
+        :param  in_label: name of label typically used in loss that may be left in graph. This name is
+            removed from list of inputs required by symbol
+        :return: dictionary of output name to shape
+        :rtype: dict of (str, tuple(int, ...))
+        """
+        # remove any input listed in params from sym.list_inputs() and bind them to the input shapes provided
+        # by user. Also remove in_label, which is the name of the label symbol that may have been used
+        # as the label for loss during training.
+        inputs = {n: s for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_shape)}
+        # Add params and their shape to list of inputs
+        inputs.update({n: v.shape for n, v in params.items()})
+        # Provide input data as well as input params to infer_shape()
+        _, out_shapes, _ = sym.infer_shape(**inputs)
+
+        out_names = list()
+        for name in sym.list_outputs():
+            if name.endswith('_output'):
+                out_names.append(name[:-len('_output')])
+            else:
+                logging.warning("output '%s' does not end with '_output'", name)
+                out_names.append(name)
 
+        assert len(out_shapes) == len(out_names)
+        # bind output shapes with output names
+        graph_outputs = {n: s for n, s in zip(out_names, out_shapes)}
+
+        return graph_outputs
 
     @staticmethod
     def convert_weights_to_numpy(weights_dict):
@@ -228,9 +195,6 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
         # Deriving the output_label name.
         output_label = sym.get_internals()[len(sym.get_internals()) - 1].name + "_label"
 
-        # Determine output shape
-        output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, output_label)
-
         weights = MXNetGraph.convert_weights_to_numpy(params)
 
         mx_graph = json.loads(sym.tojson())["nodes"]
@@ -242,6 +206,9 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
         onnx_processed_outputs = []
         index_lookup = []
 
+        # Determine output shape
+        graph_outputs = MXNetGraph.get_outputs(sym, params, in_shape, output_label)
+
         graph_input_idx = 0
         for idx, node in enumerate(mx_graph):
             op = node["op"]
@@ -294,24 +261,15 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
                     # If converted node is NodeProto, add it in processed nodes list
                     elif isinstance(converted_node, NodeProto):
                         onnx_processed_nodes.append(converted_node)
-                        if idx == (len(mx_graph) - 1):
-                            # If converted node doesnt have name, use it from output field
-                            if not converted_node.name:
-                                onnx_processed_outputs.append(
-                                    make_tensor_value_info(
-                                        name=converted_node.output[0],
-                                        elem_type=in_type,
-                                        shape=output_shape
-                                    )
-                                )
-                            else:
-                                onnx_processed_outputs.append(
-                                    make_tensor_value_info(
-                                        name=converted_node.name,
-                                        elem_type=in_type,
-                                        shape=output_shape
-                                    )
+                        node_name = converted_node.name if converted_node.name else converted_node.output[0]
+                        if node_name in graph_outputs:
+                            onnx_processed_outputs.append(
+                                make_tensor_value_info(
+                                    name=node_name,
+                                    elem_type=in_type,
+                                    shape=graph_outputs[node_name]
                                 )
+                            )
                             if verbose:
                                 logging.info("Output node is: %s", converted_node.name)
                     elif isinstance(converted_node, TensorProto):
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py
index 9f91369d667..bbff7833fe2 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -28,11 +28,14 @@
 import unittest
 import logging
 import tarfile
+import tempfile
 from collections import namedtuple
 import numpy as np
 import numpy.testing as npt
 from onnx import numpy_helper, helper
 from onnx import TensorProto
+from mxnet import nd, sym
+from mxnet.gluon import nn
 from mxnet.test_utils import download
 from mxnet.contrib import onnx as onnx_mxnet
 import mxnet as mx
@@ -238,6 +241,79 @@ def test_square():
 
     npt.assert_almost_equal(result, numpy_op)
 
+
+def _assert_sym_equal(lhs, rhs):
+    assert lhs.list_inputs() == rhs.list_inputs()  # input names must be identical
+    assert len(lhs.list_outputs()) == len(rhs.list_outputs())  # number of outputs must be identical
+
+
+def _force_list(output):
+    if isinstance(output, nd.NDArray):
+        return [output]
+    return list(output)
+
+
+def _optional_group(symbols, group=False):
+    if group:
+        return sym.Group(symbols)
+    else:
+        return symbols
+
+
+def _check_onnx_export(net, group_outputs=False):
+    net.initialize()
+    data = nd.random.uniform(0, 1, (1, 1024))
+    output = _force_list(net(data))  # initialize weights
+    net_sym = _optional_group(net(sym.Variable('data')), group_outputs)
+    net_params = {name:param._reduce() for name, param in net.collect_params().items()}
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        onnx_file_path = os.path.join(tmpdirname, 'net.onnx')
+        export_path = onnx_mxnet.export_model(
+            sym=net_sym,
+            params=net_params,
+            input_shape=[data.shape],
+            onnx_file_path=onnx_file_path)
+        assert export_path == onnx_file_path
+        # Try importing the model to symbol
+        _assert_sym_equal(net_sym, onnx_mxnet.import_model(export_path)[0])
+
+        # Try importing the model to gluon
+        imported_net = onnx_mxnet.import_to_gluon(export_path, ctx=None)
+        _assert_sym_equal(net_sym, _optional_group(imported_net(sym.Variable('data')), group_outputs))
+
+        # Confirm network outputs are the same
+        imported_net_output = _force_list(imported_net(data))
+        for out, imp_out in zip(output, imported_net_output):
+            mx.test_utils.assert_almost_equal(out.asnumpy(), imp_out.asnumpy())
+
+
+@with_seed()
+def test_onnx_export_single_output():
+    net = nn.HybridSequential(prefix='single_output_net')
+    with net.name_scope():
+        net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
+    _check_onnx_export(net)
+
+
+@with_seed()
+def test_onnx_export_multi_output():
+    class MultiOutputBlock(nn.HybridBlock):
+        def __init__(self):
+            super(MultiOutputBlock, self).__init__()
+            with self.name_scope():
+                self.net = nn.HybridSequential()
+                for i in range(10):
+                    self.net.add(nn.Dense(100 + i * 10, activation='relu'))
+
+        def hybrid_forward(self, F, x):
+            out = tuple(block(x) for block in self.net._children.values())
+            return out
+
+    net = MultiOutputBlock()
+    assert len(sym.Group(net(sym.Variable('data'))).list_outputs()) == 10
+    _check_onnx_export(net, group_outputs=True)
+
+
 if __name__ == '__main__':
     test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
     test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services