You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/25 03:42:41 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] ONNX support for SequenceReverse (#19954)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new b31c5c8  [v1.x] ONNX support for SequenceReverse (#19954)
b31c5c8 is described below

commit b31c5c88dd6dc49543d36735edc31fc995821414
Author: Zhaoqi Zhu <zh...@gmail.com>
AuthorDate: Wed Feb 24 19:41:12 2021 -0800

    [v1.x] ONNX support for SequenceReverse (#19954)
    
    * add support for sequencereverse
    
    * fix sanity
    
    * fix spelling
---
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 31 ++++++++++++++++++++++
 tests/python-pytest/onnx/test_operators.py         | 13 +++++++++
 2 files changed, 44 insertions(+)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 3e1bcd2..2ce0a6b 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -4033,3 +4033,34 @@ def convert_argsort(node, **kwargs):
         ]
 
     return nodes
+
+
+@mx_op.register('SequenceReverse')
+def convert_sequence_reverse(node, **kwargs):
+    """Map MXNet's SequenceReverse op to ONNX ReverseSequence (time axis 0, batch
+    axis 1); lengths come from input 1 unless use_sequence_length == 'False'."""
+    from onnx.helper import make_node
+    from onnx import TensorProto
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    batch_axis = 1  # layout used here: axis 0 = time, axis 1 = batch
+    time_axis = 0
+    use_sequence_length = attrs.get('use_sequence_length', 'False')  # string-valued
+
+    nodes = []
+    if use_sequence_length == 'False':  # NOTE(review): any other string (even 'false') takes the else branch
+        nodes += [
+            make_node('Shape', [input_nodes[0]], [name+'_shape']),
+            make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2']),  # NOTE(review): unsized 3-way split; rank-3 input only? TODO confirm
+            make_node('Expand', [name+'_dim0', name+'_dim1'], [name+'_seq_len']),  # dim1 copies of dim0: reverse each full sequence
+            make_node('ReverseSequence', [input_nodes[0], name+'_seq_len'], [name],
+                      batch_axis=batch_axis, time_axis=time_axis)
+        ]
+    else:
+        nodes += [
+            make_node('Cast', [input_nodes[1]], [name+'_seq_len'], to=int(TensorProto.INT64)),  # lengths cast to INT64
+            make_node('ReverseSequence', [input_nodes[0], name+'_seq_len'], [name],
+                      batch_axis=batch_axis, time_axis=time_axis)
+        ]
+
+    return nodes
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 21932b4..9931628 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -1181,3 +1181,16 @@ def test_onnx_export_take_raise(tmp_path, dtype, axis):
     y = mx.random.randint(0, 3, (6, 7)).astype(dtype)
     M = def_model('take', axis=axis, mode='raise')
     op_export_test('take', M, [x, y], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
+@pytest.mark.parametrize('params', [((6, 5, 4), [1, 2, 4, 5, 6]),  # (input shape, one length per axis-1 entry); all rank 3
+                                     ((7, 3, 5), [1, 7, 4]),
+                                     ((3, 2, 1), [1, 2])])
+def test_onnx_export_sequence_reverse(tmp_path, dtype, params):
+    x = mx.nd.random.uniform(0, 10, params[0]).astype(dtype)
+    M1 = def_model('SequenceReverse')  # no use_sequence_length: full-length reversal
+    op_export_test('SequenceReverse1', M1, [x], tmp_path)
+    seq_len = mx.nd.array(params[1])  # converter casts this input to INT64
+    M1 = def_model('SequenceReverse', use_sequence_length=True)  # NOTE(review): reuses M1/'SequenceReverse1'; M2/'SequenceReverse2' reads clearer
+    op_export_test('SequenceReverse1', M1, [x, seq_len], tmp_path)