Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/04/14 03:29:30 UTC

[GitHub] [incubator-mxnet] TriLoo opened a new issue #18048: convert mxnet op `split` to onnx `Split` output names error

URL: https://github.com/apache/incubator-mxnet/issues/18048
 
 
   ## Description
   If a network contains `mx.nd.split()`, converting it to ONNX may fail with an error saying that the input of the following layer is not an output of any previous layer. The cause is that the outputs are renamed when `mx.nd.split(...)` is converted to an ONNX `Split` node: `'_output' + str(i)` is appended to the layer name of `mx.nd.split()` for each output `i`. Meanwhile, the layers that consume the outputs of `mx.nd.split()` still expect the bare layer name of `mx.nd.split()` as their input, so ONNX complains that this input is not an output of any previous layer.
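
A minimal pure-Python sketch of the mismatch, assuming the naming scheme described above (the converted `Split` node emits `<name>_output<i>` while consumers still reference `<name>`). The node names `split0` and `plus0` and the helpers `split_outputs` and `find_dangling_inputs` are hypothetical; they only mimic the ONNX rule that every node input must be a graph input or an output of an earlier node, and are not the converter's or ONNX's actual code.

```python
from typing import Dict, List, Tuple

# Hypothetical, simplified stand-in for ONNX nodes in topological order:
# (node name, input names, output names).
Node = Tuple[str, List[str], List[str]]


def split_outputs(name: str, num_outputs: int) -> List[str]:
    """Output names assumed for a converted Split node: <name>_output<i>."""
    return [f"{name}_output{i}" for i in range(num_outputs)]


def find_dangling_inputs(graph_inputs: List[str], nodes: List[Node]) -> Dict[str, List[str]]:
    """Return, per node, the inputs not produced by a graph input or an earlier node."""
    available = set(graph_inputs)
    dangling: Dict[str, List[str]] = {}
    for name, inputs, outputs in nodes:
        missing = [i for i in inputs if i not in available]
        if missing:
            dangling[name] = missing
        available.update(outputs)
    return dangling


# Buggy conversion: the consumer still refers to the bare Split node name.
buggy = [
    ("split0", ["data"], split_outputs("split0", 3)),
    ("plus0", ["split0"], ["plus0"]),
]
# Fixed conversion: the consumer refers to the indexed output name.
fixed = [
    ("split0", ["data"], split_outputs("split0", 3)),
    ("plus0", ["split0_output1"], ["plus0"]),
]

print(find_dangling_inputs(["data"], buggy))  # {'plus0': ['split0']}
print(find_dangling_inputs(["data"], fixed))  # {}
```

In the buggy graph nothing produces a tensor named `split0`, which is exactly the kind of dangling input the ONNX checker rejects.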
   
   ## To Reproduce
   ``` python
   from mxnet import gluon, nd
   
   import onnx
   from onnx import checker
   from mxnet.contrib import onnx as mx_onnx
   
   
   class TmpMulScalar(gluon.HybridBlock):
       def __init__(self, **kwargs):
           super(TmpMulScalar, self).__init__(**kwargs)
   
           with self.name_scope():
               self.conv = gluon.nn.Conv2D(1, 3, 1, 1, use_bias=False)
   
           self.val = 0.1
   
       def hybrid_forward(self, F, x):
           # Splitting into three outputs is what triggers the ONNX naming error.
           r, g, b = F.split(x, axis=1, num_outputs=3)
           r = g + self.val
           feat = self.conv(r)
           output = F.concat(feat, g, b, dim=1)
   
           return output
   
   
   def try_mul_scalar():
       net = TmpMulScalar()
       net.initialize()
       data = nd.random.uniform(0.0, 1.0, (1, 3, 10, 10))
   
       # Hybridize and run once so that export() has a cached graph.
       net.hybridize()
       net(data)
   
       # Writes ./temp-symbol.json and ./temp-0000.params.
       net.export('./temp')
   
       sym_file = './temp-symbol.json'
       param_file = './temp-0000.params'
   
       converted_file = mx_onnx.export_model(sym_file, param_file, [(1, 3, 10, 10)], onnx_file_path='./temp.onnx')
       print('converted_file: ', converted_file)
   
       # Validate the converted graph with the ONNX checker.
       model_onnx = onnx.load_model(converted_file)
       checker.check_graph(model_onnx.graph)
   
   
   if __name__ == '__main__':
       try_mul_scalar()
   ```
   
   ### Steps to reproduce
   Run the code above (calling `try_mul_scalar()`) to reproduce the error.
   
   ## Possible Solutions
   1. Add a check in [get_inputs() - op_translation](https://github.com/apache/incubator-mxnet/blob/e3d7866e6854a5c11ab2b2c8bfb63de66f79e132/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py#L129): if the name of an input node contains `split`, change the input name to `proc_nodes[input_node_id].name + '_output' + str(ip[1])`, where `ip[1]` is the index of the `mx.nd.split()` output being consumed.
   
   The complete modified function is shown below:
   
   ``` python
   def get_inputs(node, kwargs):
       """Helper function to get inputs"""
       name = node["name"]
       proc_nodes = kwargs["proc_nodes"]
       index_lookup = kwargs["index_lookup"]
       inputs = node["inputs"]
       attrs = node.get("attrs", {})
   
       input_nodes = []
       for ip in inputs:
           # ip[0] is the producing node's id, ip[1] the index of its output.
           input_node_id = index_lookup[ip[0]]
           input_node_name = proc_nodes[input_node_id].name
           # Split nodes name their outputs "<name>_output<i>", so refer to
           # the specific output instead of the bare node name.
           if 'split' in input_node_name:
               input_node_name = input_node_name + '_output' + str(ip[1])
           input_nodes.append(input_node_name)
   
       return name, input_nodes, attrs
   ```
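
To exercise the proposed `get_inputs()` in isolation, the sketch below feeds it hand-built stand-ins. Assumptions: `proc_nodes` entries only need a `.name` attribute (a hypothetical `ProcNode` namedtuple here instead of real converted ONNX nodes), `index_lookup` is a plain dict, each entry of `node["inputs"]` is assumed to be `[node_id, output_index, version]` as in MXNet's symbol JSON, and the node names `data`, `split0` and `_plusscalar0` are made up for illustration.

```python
from collections import namedtuple

# Hypothetical stand-in for converted nodes; get_inputs only reads .name.
ProcNode = namedtuple("ProcNode", ["name"])


def get_inputs(node, kwargs):
    """Proposed helper: resolve input names, indexing Split outputs."""
    name = node["name"]
    proc_nodes = kwargs["proc_nodes"]
    index_lookup = kwargs["index_lookup"]
    inputs = node["inputs"]
    attrs = node.get("attrs", {})

    input_nodes = []
    for ip in inputs:
        input_node_id = index_lookup[ip[0]]
        input_node_name = proc_nodes[input_node_id].name
        # Refer to "<name>_output<i>" when the producer is a Split node.
        if 'split' in input_node_name:
            input_node_name = input_node_name + '_output' + str(ip[1])
        input_nodes.append(input_node_name)

    return name, input_nodes, attrs


# Hypothetical graph: data -> split0 (3 outputs); _plusscalar0 consumes output 1.
kwargs = {
    "proc_nodes": [ProcNode("data"), ProcNode("split0")],
    "index_lookup": {0: 0, 1: 1},
}
node = {"name": "_plusscalar0", "inputs": [[1, 1, 0]], "attrs": {"scalar": "0.1"}}

print(get_inputs(node, kwargs))
# ('_plusscalar0', ['split0_output1'], {'scalar': '0.1'})
```

Because the check matches on the substring `split`, any unrelated node whose name happens to contain `split` would also be renamed; keying on the operator type of the producing node rather than its name would be a stricter alternative.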
