You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by an...@apache.org on 2022/04/13 17:24:18 UTC
[tvm] branch main updated: sort axes (#10985)
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 814e856851 sort axes (#10985)
814e856851 is described below
commit 814e856851fcd142c43f57a9bd2f93a7594d1bf2
Author: Margaret Qian <ym...@gmail.com>
AuthorDate: Wed Apr 13 10:24:11 2022 -0700
sort axes (#10985)
Co-authored-by: Margaret Qian <mq...@octoml.ai>
---
python/tvm/relay/frontend/onnx.py | 3 ++-
tests/python/frontend/onnx/test_forward.py | 4 ----
2 files changed, 2 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index ab0eeb0910..168362e229 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1505,7 +1505,8 @@ class Unsqueeze(OnnxOpConverter):
rank_input = len(infer_type(inputs[0]).checked_type.shape)
num_new_axis = int(infer_type(inputs[1]).checked_type.shape[0])
- axes = relay.split(inputs[1], num_new_axis).astuple()
+ axes = relay.sort(inputs[1])
+ axes = relay.split(axes, num_new_axis).astuple()
result = inputs[0]
# TODO (AndrewZhaoLuo): investigate performance issues with consecutive
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 94fd0a5de4..12e02d5f29 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5125,10 +5125,6 @@ unsupported_onnx_tests = [
"test_triu_square",
"test_triu_square_neg",
"test_triu_zero",
- # These unsqueeze tests work, but take 2+ hrs to run
- "test_unsqueeze_three_axes",
- "test_unsqueeze_two_axes",
- "test_unsqueeze_unsorted_axes",
"test_unique_sorted_with_axis",
"test_unique_sorted_with_axis_3d",
"test_unique_sorted_with_negative_axis",