You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ro...@apache.org on 2019/02/12 21:21:53 UTC

[incubator-mxnet] branch master updated: ONNX export: Support equal length splits (#14121)

This is an automated email from the ASF dual-hosted git repository.

roshrini pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ce031da  ONNX export: Support equal length splits (#14121)
ce031da is described below

commit ce031daf5e4f343306e01b7b748fed2bad9df685
Author: Vandana Kannan <va...@users.noreply.github.com>
AuthorDate: Tue Feb 12 13:21:29 2019 -0800

    ONNX export: Support equal length splits (#14121)
    
    * ONNX export: Support equal length splits
    
    * Fix lint error
    
    * Add comment about checking for multiple outputs
---
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py |  6 ++++--
 python/mxnet/contrib/onnx/mx2onnx/export_onnx.py   | 23 ++++++++++++----------
 .../mxnet/contrib/onnx/onnx2mx/_op_translations.py | 12 ++++++-----
 tests/python-pytest/onnx/test_cases.py             |  4 ++--
 4 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index e077824..f9d170d 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1537,12 +1537,14 @@ def convert_slice_channel(node, **kwargs):
         )
         return [node]
     elif squeeze_axis == 0 and num_outputs > 1:
+        in_shape = kwargs.get('in_shape')[0]
+        split = in_shape[axis] // num_outputs
         node = onnx.helper.make_node(
             "Split",
             input_nodes,
-            [name],
+            [name+'_output'+str(i) for i in range(num_outputs)],
             axis=axis,
-            split=[num_outputs],
+            split=[split for _ in range(num_outputs)],
             name=name,
         )
         return [node]
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index d0d4501..a7b11fc 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -262,17 +262,20 @@ class MXNetGraph(object):
                     # If converted node is NodeProto, add it in processed nodes list
                     elif isinstance(converted_node, NodeProto):
                         onnx_processed_nodes.append(converted_node)
-                        node_name = converted_node.name if converted_node.name else converted_node.output[0]
-                        if node_name in graph_outputs:
-                            onnx_processed_outputs.append(
-                                make_tensor_value_info(
-                                    name=node_name,
-                                    elem_type=in_type,
-                                    shape=graph_outputs[node_name]
+                        # some operators have multiple outputs,
+                        # therefore, check all output node names
+                        node_names = list(converted_node.output)
+                        for nodename in node_names:
+                            if nodename in graph_outputs:
+                                onnx_processed_outputs.append(
+                                    make_tensor_value_info(
+                                        name=nodename,
+                                        elem_type=in_type,
+                                        shape=graph_outputs[nodename]
+                                    )
                                 )
-                            )
-                            if verbose:
-                                logging.info("Output node is: %s", converted_node.name)
+                                if verbose:
+                                    logging.info("Output node is: %s", nodename)
                     elif isinstance(converted_node, TensorProto):
                         raise ValueError("Did not expect TensorProto")
                     else:
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index dc00fee..a7cef76 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -484,13 +484,15 @@ def split(attrs, inputs, proto_obj):
     if not split_list:
         num_outputs = len(proto_obj.model_metadata.get('output_tensor_data'))
     else:
-        raise NotImplementedError("Operator {} in MXNet does not support variable splits."
-                                  "Tracking the issue to support variable split here: "
-                                  "https://github.com/apache/incubator-mxnet/issues/11594"
-                                  .format('split'))
+        if len(set(split_list)) == 1:
+            num_outputs = len(split_list)
+        else:
+            raise NotImplementedError("Operator {} in MXNet does not support variable splits."
+                                      "Tracking the issue to support variable split here: "
+                                      "https://github.com/apache/incubator-mxnet/issues/11594"
+                                      .format('split'))
 
     new_attrs['num_outputs'] = num_outputs
-
     return 'split', new_attrs, inputs
 
 def _slice(attrs, inputs, proto_obj):
diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py
index b20db23..89b60d1 100644
--- a/tests/python-pytest/onnx/test_cases.py
+++ b/tests/python-pytest/onnx/test_cases.py
@@ -77,7 +77,8 @@ IMPLEMENTED_OPERATORS_TEST = {
              'test_elu',
              'test_max_',
              'test_softplus',
-             'test_reduce_'
+             'test_reduce_',
+             'test_split_equal'
              ],
     'import': ['test_gather',
                'test_softsign',
@@ -88,7 +89,6 @@ IMPLEMENTED_OPERATORS_TEST = {
                'test_averagepool_2d_precomputed_strides',
                'test_averagepool_2d_strides',
                'test_averagepool_3d',
-               'test_split_equal',
                'test_hardmax'
                ],
     'export': ['test_random_uniform',