Posted to commits@tvm.apache.org by ru...@apache.org on 2023/09/16 04:09:29 UTC
[tvm] branch main updated: [Relay][Bugfix] fix the wrong calculate logic of operator flip in PyTorch frontend (#15752)
This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 61d5be0470 [Relay][Bugfix] fix the wrong calculate logic of operator flip in PyTorch frontend (#15752)
61d5be0470 is described below
commit 61d5be04705c4c5117af89d4af070f997c070e45
Author: Qingchao Shen <qi...@outlook.com>
AuthorDate: Sat Sep 16 12:09:22 2023 +0800
[Relay][Bugfix] fix the wrong calculate logic of operator flip in PyTorch frontend (#15752)
The original implementation of the Flip converter in the PyTorch frontend mistook the `axis` attribute of the Flip operator for an integer. As a result, it only used the first element of `axis` and produced wrong results whenever `axis` contained more than one dimension. According to the PyTorch documentation [here](https://pytorch.org/docs/stable/generated/torch.flip.html), this argument (named `dims` in `torch.flip`) is a list or tuple.
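To illustrate the difference, the following sketch uses NumPy as a stand-in for both PyTorch and Relay. It assumes that `np.flip` with a tuple of axes plays the role of `torch.flip`, and `flip_first_axis_only` is a hypothetical helper that mimics the old converter's use of `axis[0]`.

```python
import numpy as np


def flip_first_axis_only(data, axis):
    """Hypothetical helper mimicking the old converter: only axis[0] is used."""
    return np.flip(data, axis=axis[0])


# A 2x3x4 tensor with distinct values, matching the shape used in the test.
x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
axis = [0, 1]

# Reference semantics: flip along every listed dimension, like torch.flip(x, [0, 1]).
expected = np.flip(x, axis=tuple(axis))

# The old behaviour silently ignores axis 1, so the results disagree.
old_result = flip_first_axis_only(x, axis)
print(np.array_equal(old_result, expected))  # False

# With a single-element list the two agree, which is why single-axis tests still passed.
print(np.array_equal(flip_first_axis_only(x, [2]), np.flip(x, axis=(2,))))  # True
```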
This PR fixes the `torch.flip` converter so that it reverses the input once along each dimension listed in `axis`, and updates the regression test to pass `axis` as a list, including a multi-axis case (`axis=[0, 1]`).
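The fix relies on the fact that reversals along distinct axes commute, so applying one reverse per listed axis, in any order, matches a single multi-axis flip. The NumPy sketch below checks this for every ordering of a set of axes that includes a negative index. `flip_sequential` is a hypothetical helper that mirrors the loop in the new converter, with `np.flip` standing in for `_op.reverse`.

```python
import itertools

import numpy as np


def flip_sequential(data, axis):
    """Hypothetical helper mirroring the fixed converter: reverse once per axis."""
    out = data
    for ax in axis:
        out = np.flip(out, axis=ax)
    return out


x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
axes = [0, 1, -1]

# Single-call flip over all listed axes; -1 is normalized to the last axis.
expected = np.flip(x, axis=tuple(axes))

# The sequential version must agree with the reference for every ordering of the axes.
all_match = all(
    np.array_equal(flip_sequential(x, list(perm)), expected)
    for perm in itertools.permutations(axes)
)
print(all_match)  # True
```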
---
python/tvm/relay/frontend/pytorch.py | 5 ++++-
tests/python/frontend/pytorch/test_forward.py | 11 ++++++-----
2 files changed, 10 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9ddd04b5b4..89dcad03e6 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2977,7 +2977,10 @@ class PyTorchOpConverter:
     def flip(self, inputs, input_types):
         data = inputs[0]
         axis = inputs[1]
-        return _op.transform.reverse(data, axis=axis[0])
+        out = data
+        for ax in axis:
+            out = _op.reverse(out, ax)
+        return out
 
     def bidir_rnn_cell(self, input_seqs, weights_dicts, act=_op.tanh):
         """
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9ee03512e7..6bbb9ef5cc 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4899,13 +4899,14 @@ def test_forward_flip():
             self.axis = axis
 
         def forward(self, x):
-            return x.flip([self.axis])
+            return x.flip(self.axis)
 
     input_t = torch.randn(2, 3, 4)
-    verify_model(Flip(axis=0), input_data=input_t)
-    verify_model(Flip(axis=1), input_data=input_t)
-    verify_model(Flip(axis=2), input_data=input_t)
-    verify_model(Flip(axis=-1), input_data=input_t)
+    verify_model(Flip(axis=[0]), input_data=input_t)
+    verify_model(Flip(axis=[1]), input_data=input_t)
+    verify_model(Flip(axis=[2]), input_data=input_t)
+    verify_model(Flip(axis=[-1]), input_data=input_t)
+    verify_model(Flip(axis=[0, 1]), input_data=input_t)
 
 
 def test_annotate_span():