You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/05/15 00:50:21 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] onnx fix rnn (#20272)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 5fa2234  [v1.x] onnx fix rnn (#20272)
5fa2234 is described below

commit 5fa22343e3fc0ba119dc58399f8044af02fd48da
Author: Zhaoqi Zhu <zh...@gmail.com>
AuthorDate: Fri May 14 17:48:07 2021 -0700

    [v1.x] onnx fix rnn (#20272)
    
    * fix rnn
    
    * fix whitespace
---
 .../mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py | 2 +-
 .../mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
index c1eadf4..b73c5bf 100644
--- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
+++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
@@ -4525,7 +4525,7 @@ def convert_RNN(node, **kwargs):
     create_tensor([state_size], name+'_state_size', kwargs['initializer'])
     create_tensor([direction], name+'_direction', kwargs['initializer'])
 
-    tensor_1 = make_tensor(name+'_1_f', dtype, [1], [1])
+    tensor_1 = make_tensor(name+'_1_f', onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype], [1], [1])
 
     nodes = [
         make_node('Shape', [data], [name+'_data_shape']),
diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
index 1a20de4..812ade6 100644
--- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
+++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
@@ -208,7 +208,6 @@ def create_tensor(tensor_list, tensor_name, initializer, dtype='int64'):
         raw=False
     )
     initializer.append(tensor)
-    return tensor
 
 
 def create_helper_trans_node(node_name, input_node):
@@ -1009,7 +1008,7 @@ def convert_RNN(node, **kwargs):
     """Map MXNet's RNN operator attributes to onnx's operators
     and return the created node.
     """
-    from onnx.helper import make_node
+    from onnx.helper import make_node, make_tensor
     from onnx import TensorProto
 
     name, input_nodes, attrs = get_inputs(node, kwargs)
@@ -1047,7 +1046,8 @@ def convert_RNN(node, **kwargs):
     create_tensor([1], name+'_1', kwargs['initializer'])
     create_tensor([state_size], name+'_state_size', kwargs['initializer'])
     create_tensor([direction], name+'_direction', kwargs['initializer'])
-    tensor_1 = create_tensor([1], name+'_1_f', kwargs['initializer'], dtype)
+
+    tensor_1 = make_tensor(name+'_1_f', onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype], [1], [1])
 
     nodes = [
         make_node('Shape', [data], [name+'_data_shape']),