You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by ro...@apache.org on 2019/02/12 21:21:53 UTC
[incubator-mxnet] branch master updated: ONNX export: Support equal length splits (#14121)
This is an automated email from the ASF dual-hosted git repository.
roshrini pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new ce031da ONNX export: Support equal length splits (#14121)
ce031da is described below
commit ce031daf5e4f343306e01b7b748fed2bad9df685
Author: Vandana Kannan <va...@users.noreply.github.com>
AuthorDate: Tue Feb 12 13:21:29 2019 -0800
ONNX export: Support equal length splits (#14121)
* ONNX export: Support equal length splits
* Fix lint error
* Add comment about checking for multiple outputs
---
.../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 6 ++++--
python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 23 ++++++++++++----------
.../mxnet/contrib/onnx/onnx2mx/_op_translations.py | 12 ++++++-----
tests/python-pytest/onnx/test_cases.py | 4 ++--
4 files changed, 26 insertions(+), 19 deletions(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index e077824..f9d170d 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1537,12 +1537,14 @@ def convert_slice_channel(node, **kwargs):
)
return [node]
elif squeeze_axis == 0 and num_outputs > 1:
+ in_shape = kwargs.get('in_shape')[0]
+ split = in_shape[axis] // num_outputs
node = onnx.helper.make_node(
"Split",
input_nodes,
- [name],
+ [name+'_output'+str(i) for i in range(num_outputs)],
axis=axis,
- split=[num_outputs],
+ split=[split for _ in range(num_outputs)],
name=name,
)
return [node]
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index d0d4501..a7b11fc 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -262,17 +262,20 @@ class MXNetGraph(object):
# If converted node is NodeProto, add it in processed nodes list
elif isinstance(converted_node, NodeProto):
onnx_processed_nodes.append(converted_node)
- node_name = converted_node.name if converted_node.name else converted_node.output[0]
- if node_name in graph_outputs:
- onnx_processed_outputs.append(
- make_tensor_value_info(
- name=node_name,
- elem_type=in_type,
- shape=graph_outputs[node_name]
+ # some operators have multiple outputs,
+ # therefore, check all output node names
+ node_names = list(converted_node.output)
+ for nodename in node_names:
+ if nodename in graph_outputs:
+ onnx_processed_outputs.append(
+ make_tensor_value_info(
+ name=nodename,
+ elem_type=in_type,
+ shape=graph_outputs[nodename]
+ )
)
- )
- if verbose:
- logging.info("Output node is: %s", converted_node.name)
+ if verbose:
+ logging.info("Output node is: %s", nodename)
elif isinstance(converted_node, TensorProto):
raise ValueError("Did not expect TensorProto")
else:
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index dc00fee..a7cef76 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -484,13 +484,15 @@ def split(attrs, inputs, proto_obj):
if not split_list:
num_outputs = len(proto_obj.model_metadata.get('output_tensor_data'))
else:
- raise NotImplementedError("Operator {} in MXNet does not support variable splits."
- "Tracking the issue to support variable split here: "
- "https://github.com/apache/incubator-mxnet/issues/11594"
- .format('split'))
+ if len(set(split_list)) == 1:
+ num_outputs = len(split_list)
+ else:
+ raise NotImplementedError("Operator {} in MXNet does not support variable splits."
+ "Tracking the issue to support variable split here: "
+ "https://github.com/apache/incubator-mxnet/issues/11594"
+ .format('split'))
new_attrs['num_outputs'] = num_outputs
-
return 'split', new_attrs, inputs
def _slice(attrs, inputs, proto_obj):
diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py
index b20db23..89b60d1 100644
--- a/tests/python-pytest/onnx/test_cases.py
+++ b/tests/python-pytest/onnx/test_cases.py
@@ -77,7 +77,8 @@ IMPLEMENTED_OPERATORS_TEST = {
'test_elu',
'test_max_',
'test_softplus',
- 'test_reduce_'
+ 'test_reduce_',
+ 'test_split_equal'
],
'import': ['test_gather',
'test_softsign',
@@ -88,7 +89,6 @@ IMPLEMENTED_OPERATORS_TEST = {
'test_averagepool_2d_precomputed_strides',
'test_averagepool_2d_strides',
'test_averagepool_3d',
- 'test_split_equal',
'test_hardmax'
],
'export': ['test_random_uniform',