You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/05/10 15:58:52 UTC

[tvm] branch main updated: add onnx reverse sequence op (#7771)

This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2077ee2  add onnx reverse sequence op (#7771)
2077ee2 is described below

commit 2077ee2f9a5b3c1c4f9793742387fdbc8e848494
Author: alter-xp <xp...@linux.alibaba.com>
AuthorDate: Mon May 10 23:58:26 2021 +0800

    add onnx reverse sequence op (#7771)
    
    Co-authored-by: xp224797 <xp...@alibaba-inc.com>
---
 python/tvm/relay/frontend/onnx.py          | 10 +++++++
 tests/python/frontend/onnx/test_forward.py | 42 +++++++++++++++++++++++++++---
 2 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index a62e505..e4a6885 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2269,6 +2269,15 @@ class NonZero(OnnxOpConverter):
         return _op.transpose(output, axes=(1, 0))
 
 
+class ReverseSequence(OnnxOpConverter):
+    """Operator converter for ReverseSequence; ONNX defaults: time_axis=0, batch_axis=1."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        time_axis, batch_axis = attr.get("time_axis", 0), attr.get("batch_axis", 1)
+        return _op.reverse_sequence(inputs[0], inputs[1], time_axis, batch_axis)
+
+
 class TopK(OnnxOpConverter):
     """Operator converter for TopK"""
 
@@ -3007,6 +3016,7 @@ def _get_convert_map(opset):
         "QuantizeLinear": QuantizeLinear.get_converter(opset),
         "DequantizeLinear": DequantizeLinear.get_converter(opset),
         "DynamicQuantizeLinear": DynamicQuantizeLinear.get_converter(opset),
+        "ReverseSequence": ReverseSequence.get_converter(opset),
     }
 
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index f878fa9..8965584 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4225,9 +4225,6 @@ unsupported_onnx_tests = [
     "test_resize_upsample_sizes_nearest_ceil_half_pixel/",
     "test_resize_upsample_sizes_nearest_floor_align_corners/",
     "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/",
-    # ----
-    "test_reversesequence_batch/",
-    "test_reversesequence_time/",
     "test_rnn_seq_length/",
     "test_round/",
     "test_scan9_sum/",
@@ -4350,6 +4347,44 @@ def test_aten():
     verify_embedding_bag(32, 2, [3, 3])
 
 
+def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis):
+    """Build a one-node ReverseSequence model and check it via verify_with_ort_with_inputs."""
+    x_shape = list(x.shape)
+    value_info = helper.make_tensor_value_info
+    in_infos = [
+        value_info("x", TensorProto.FLOAT, x_shape),
+        value_info("sequence_lens", TensorProto.INT64, list(sequence_lens.shape)),
+    ]
+    out_infos = [value_info("y", TensorProto.FLOAT, x_shape)]
+
+    rev_node = helper.make_node(
+        "ReverseSequence",
+        inputs=["x", "sequence_lens"],
+        outputs=["y"],
+        time_axis=time_axis,
+        batch_axis=batch_axis,
+    )
+    graph = helper.make_graph(
+        [rev_node], "reverse_sequence_test", inputs=in_infos, outputs=out_infos
+    )
+    model = helper.make_model(graph, producer_name="reverse_sequence_test")
+    # The single output "y" is declared with the same shape as the input "x".
+    verify_with_ort_with_inputs(model, [x, sequence_lens], [x.shape])
+
+
+@tvm.testing.uses_gpu
+def test_reverse_sequence():
+    """Check ReverseSequence with the batch on axis 0, then with the batch on axis 1."""
+    # Same 4x4 matrix as [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]].
+    x = np.arange(16, dtype=np.float32).reshape(4, 4)
+
+    ascending = np.array([1, 2, 3, 4], dtype=np.int64)
+    verify_reverse_sequence(x, ascending, batch_axis=0, time_axis=1)
+
+    descending = np.array([4, 3, 2, 1], dtype=np.int64)
+    verify_reverse_sequence(x, descending, batch_axis=1, time_axis=0)
+
+
 if __name__ == "__main__":
     test_flatten()
     test_reshape()
@@ -4430,3 +4465,4 @@ if __name__ == "__main__":
     test_cumsum()
     test_wrong_input()
     test_aten()
+    test_reverse_sequence()