Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/04/14 03:29:30 UTC
[GitHub] [incubator-mxnet] TriLoo opened a new issue #18048: convert mxnet
op `split` to onnx `Split` output names error
URL: https://github.com/apache/incubator-mxnet/issues/18048
## Description
If a network contains `mx.nd.split()`, converting it to `onnx` may raise the error `the input of followed layer is not an output of previous layer's output`. The cause is that the outputs are renamed when `mx.nd.split(...)` is converted to an `onnx.Split()` node, *i.e.* the suffix `'_output' + str(i)` is appended to the layer name of `mx.nd.split()`. Meanwhile, the layers after `mx.nd.split()` that consume its outputs still expect the bare layer name of `mx.nd.split()` as their input. `onnx` then complains that the input is not an output of any previous layer.
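The mismatch can be pictured with a toy graph in plain Python. This is only a sketch of the naming problem, not the exporter's or the ONNX checker's actual code; the helper `find_dangling_inputs` and the node names `split0` and `concat0` are made up for illustration.

``` python
# Toy model of the name mismatch; needs neither MXNet nor ONNX.
# Node names and the helper below are hypothetical.

def find_dangling_inputs(nodes, graph_inputs):
    """Return (node_name, input_name) pairs whose input no earlier node produces."""
    available = set(graph_inputs)
    dangling = []
    for node in nodes:
        for inp in node["inputs"]:
            if inp not in available:
                dangling.append((node["name"], inp))
        available.update(node["outputs"])
    return dangling

split_outputs = ["split0_output0", "split0_output1", "split0_output2"]

# What the export produces: Split outputs are renamed, the consumer is not.
exported = [
    {"name": "split0", "inputs": ["data"], "outputs": split_outputs},
    {"name": "concat0", "inputs": ["split0", "split0", "split0"], "outputs": ["concat0"]},
]

# What a consistent graph would look like: the consumer uses the renamed outputs.
consistent = [
    {"name": "split0", "inputs": ["data"], "outputs": split_outputs},
    {"name": "concat0", "inputs": split_outputs, "outputs": ["concat0"]},
]

broken = find_dangling_inputs(exported, ["data"])
fixed = find_dangling_inputs(consistent, ["data"])
print(broken)  # three ('concat0', 'split0') pairs
print(fixed)   # []
```

In the first graph every input of `concat0` refers to a name that no earlier node outputs, which is the situation described above.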
## To Reproduce
``` python
import onnx
from onnx import checker
from mxnet import gluon, nd
from mxnet.contrib import onnx as mx_onnx


class TmpMulScalar(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(TmpMulScalar, self).__init__(**kwargs)
        with self.name_scope():
            self.conv = gluon.nn.Conv2D(1, 3, 1, 1, use_bias=False)
            self.val = 0.1

    def hybrid_forward(self, F, x):
        r, g, b = F.split(x, axis=1, num_outputs=3)  # split can cause error!!!
        r = g + self.val
        feat = self.conv(r)
        output = F.concat(feat, g, b, dim=1)
        return output


def try_mul_scalar():
    net = TmpMulScalar()
    net.initialize()
    data = nd.random.uniform(0.0, 1.0, (1, 3, 10, 10))
    net.hybridize()
    net(data)  # run one forward pass so the hybridized graph can be exported
    net.export('./temp')

    # Convert the exported symbol/params pair to ONNX.
    sym_file = './temp-symbol.json'
    param_file = './temp-0000.params'
    converted_file = mx_onnx.export_model(sym_file, param_file, [(1, 3, 10, 10)], onnx_file_path='./temp.onnx')
    print('converted_file: ', converted_file)

    # Validate the converted graph; this is where the error shows up.
    model_onnx = onnx.load_model(converted_file)
    checker.check_graph(model_onnx.graph)


try_mul_scalar()
```
### Steps to reproduce
Running the code above reproduces the error.
## Possible Solutions
1. Add a check in [get_inputs() - op_translation](https://github.com/apache/incubator-mxnet/blob/e3d7866e6854a5c11ab2b2c8bfb63de66f79e132/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py#L129), *i.e.* if the name of an input node contains `split`, change the input name to `proc_nodes[input_node_id].name + '_output' + str(ip[1])`, where `ip[1]` is the output index of `mx.nd.split()`.
The complete modified function is shown below:
``` python
def get_inputs(node, kwargs):
    """Helper function to get inputs"""
    name = node["name"]
    proc_nodes = kwargs["proc_nodes"]
    index_lookup = kwargs["index_lookup"]
    inputs = node["inputs"]
    attrs = node.get("attrs", {})

    input_nodes = []
    for ip in inputs:
        input_node_id = index_lookup[ip[0]]
        input_node_name = proc_nodes[input_node_id].name
        # Outputs of a split node are named '<name>_output<i>', so append the index.
        if 'split' in input_node_name:
            input_node_name = input_node_name + '_output' + str(ip[1])
        # input_nodes.append(proc_nodes[input_node_id].name)
        input_nodes.append(input_node_name)

    return name, input_nodes, attrs
```
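To see what the proposed check does to the input names, the function can be run on hand-built data without MXNet installed. This is a sketch under assumptions: `FakeNode` is a hypothetical stand-in for the entries of `proc_nodes` (only `.name` is used), and the node names, the `index_lookup` list and the `[node_id, output_index, version]` input triples are invented for illustration. The function body is repeated in condensed form so the block runs on its own.

``` python
from collections import namedtuple

# Hypothetical stand-in for the entries of proc_nodes; get_inputs only reads .name.
FakeNode = namedtuple("FakeNode", ["name"])


def get_inputs(node, kwargs):
    """Condensed copy of the proposed helper, repeated so this sketch is self-contained."""
    name = node["name"]
    proc_nodes = kwargs["proc_nodes"]
    index_lookup = kwargs["index_lookup"]
    input_nodes = []
    for ip in node["inputs"]:
        input_node_name = proc_nodes[index_lookup[ip[0]]].name
        if 'split' in input_node_name:
            # ip[1] is the index of the producer's output being consumed
            input_node_name = input_node_name + '_output' + str(ip[1])
        input_nodes.append(input_node_name)
    return name, input_nodes, node.get("attrs", {})


# Made-up graph: node 0 is the data input, node 1 is the split.
proc_nodes = [FakeNode("data"), FakeNode("split0")]
index_lookup = [0, 1]

# A consumer that reads outputs 0, 1 and 2 of the split node.
concat = {"name": "concat0", "inputs": [[1, 0, 0], [1, 1, 0], [1, 2, 0]], "attrs": {"dim": "1"}}
# A consumer that reads a node whose name does not contain 'split'.
split = {"name": "split0", "inputs": [[0, 0, 0]]}

lookup = {"proc_nodes": proc_nodes, "index_lookup": index_lookup}
_, concat_inputs, _ = get_inputs(concat, lookup)
_, split_inputs, _ = get_inputs(split, lookup)
print(concat_inputs)  # ['split0_output0', 'split0_output1', 'split0_output2']
print(split_inputs)   # ['data']
```

Because the check is a substring match on the node name, any producer whose name happens to contain `split` is renamed as well, whether or not it is a split operator.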