You are viewing a plain text version of this content. The link to the canonical version was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/04/13 17:24:18 UTC

[tvm] branch main updated: sort axes (#10985)

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 814e856851 sort axes (#10985)
814e856851 is described below

commit 814e856851fcd142c43f57a9bd2f93a7594d1bf2
Author: Margaret Qian <ym...@gmail.com>
AuthorDate: Wed Apr 13 10:24:11 2022 -0700

    sort axes (#10985)
    
    Co-authored-by: Margaret Qian <mq...@octoml.ai>
---
 python/tvm/relay/frontend/onnx.py          | 3 ++-
 tests/python/frontend/onnx/test_forward.py | 4 ----
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index ab0eeb0910..168362e229 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1505,7 +1505,8 @@ class Unsqueeze(OnnxOpConverter):
 
         rank_input = len(infer_type(inputs[0]).checked_type.shape)
         num_new_axis = int(infer_type(inputs[1]).checked_type.shape[0])
-        axes = relay.split(inputs[1], num_new_axis).astuple()
+        axes = relay.sort(inputs[1])
+        axes = relay.split(axes, num_new_axis).astuple()
         result = inputs[0]
 
         # TODO (AndrewZhaoLuo): investigate performance issues with consecutive
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 94fd0a5de4..12e02d5f29 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5125,10 +5125,6 @@ unsupported_onnx_tests = [
     "test_triu_square",
     "test_triu_square_neg",
     "test_triu_zero",
-    # These unsqueeze tests work, but take 2+ hrs to run
-    "test_unsqueeze_three_axes",
-    "test_unsqueeze_two_axes",
-    "test_unsqueeze_unsorted_axes",
     "test_unique_sorted_with_axis",
     "test_unique_sorted_with_axis_3d",
     "test_unique_sorted_with_negative_axis",