You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/11/11 06:20:21 UTC

[tvm] branch main updated: Add default for split op (#9489)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e09bb2  Add default for split op (#9489)
1e09bb2 is described below

commit 1e09bb290d0a7ef91045b11df932c6ddf5bb58ce
Author: anwang2009 <an...@gmail.com>
AuthorDate: Wed Nov 10 22:19:37 2021 -0800

    Add default for split op (#9489)
    
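    * Split converter: only build explicit split indices when the `split` attribute has more than one entry, and drop an unused `indices_or_sections` assignment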
    
    * add default split test case
---
 python/tvm/relay/frontend/onnx.py          | 3 +--
 tests/python/frontend/onnx/test_forward.py | 3 +++
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 3c88f65..5813f63 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1461,9 +1461,8 @@ class Split(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         splits = attr.get("split", None)
-        if splits is not None:
+        if splits is not None and len(splits) > 1:
             indices = []
-            attr["indices_or_sections"] = []
             index = 0
             for i in splits[:-1]:
                 index += i
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index dd1c773..f8870ed 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1966,6 +1966,9 @@ def test_split(target, dev):
     verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False)
     # Split a single value to a single value
     verify_split([1], [[1]], [1], pass_split=True)
+    # Test that the default case modifies nothing when split list has length one
+    verify_split([[1.0, 2.0]], [[1.0, 2.0]], [2], 1)
+    verify_split([[1.0, 2.0]], [[1.0, 2.0]], [1], 0)
 
 
 @tvm.testing.parametrize_targets