Posted to commits@tvm.apache.org by ru...@apache.org on 2023/09/16 04:09:29 UTC

[tvm] branch main updated: [Relay][Bugfix] fix the wrong calculate logic of operator flip in PyTorch frontend (#15752)

This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 61d5be0470 [Relay][Bugfix] fix the wrong calculate logic of operator flip in PyTorch frontend (#15752)
61d5be0470 is described below

commit 61d5be04705c4c5117af89d4af070f997c070e45
Author: Qingchao Shen <qi...@outlook.com>
AuthorDate: Sat Sep 16 12:09:22 2023 +0800

    [Relay][Bugfix] fix the wrong calculate logic of operator flip in PyTorch frontend (#15752)
    
    The original implementation of `flip` in the PyTorch converter treated the `axis` attribute as if it were a single integer: it passed only `axis[0]` to `reverse`, so it ignored every axis after the first and gave a wrong result whenever `axis` contained more than one element. According to the PyTorch documentation [here](https://pytorch.org/docs/stable/generated/torch.flip.html), the axes argument of `torch.flip` (named `dims` there) is a list or tuple.
    
    This PR corrects the `torch.flip` converter so that it reverses the input along every requested axis in turn, and extends `test_forward_flip` with a regression case that flips along two axes (`[0, 1]`).
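To make the difference concrete, here is a minimal pure-Python sketch on a 2x3 nested list. The helpers `reverse_axis`, `flip_buggy`, and `flip_fixed` are hypothetical stand-ins invented for illustration: they mimic the loop structure of the old and new converter code, not TVM's actual `reverse` operator.

```python
def reverse_axis(data, axis, ndim):
    """Reverse a nested-list tensor along one axis (negative axes allowed)."""
    axis = axis % ndim  # normalize, e.g. -1 -> ndim - 1
    if axis == 0:
        return data[::-1]
    # Recurse into each sub-list, which has one dimension fewer.
    return [reverse_axis(sub, axis - 1, ndim - 1) for sub in data]


def flip_buggy(data, axes, ndim):
    # Old behaviour (hypothetical mirror): only the first axis is honoured.
    return reverse_axis(data, axes[0], ndim)


def flip_fixed(data, axes, ndim):
    # New behaviour (hypothetical mirror): reverse along every axis in turn.
    out = data
    for ax in axes:
        out = reverse_axis(out, ax, ndim)
    return out


x = [[1, 2, 3],
     [4, 5, 6]]
print(flip_buggy(x, [0, 1], 2))  # [[4, 5, 6], [1, 2, 3]]
print(flip_fixed(x, [0, 1], 2))  # [[6, 5, 4], [3, 2, 1]]
```

Only the fixed version matches what `torch.flip(x, [0, 1])` produces; the buggy one silently drops axis 1.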
---
 python/tvm/relay/frontend/pytorch.py          |  5 ++++-
 tests/python/frontend/pytorch/test_forward.py | 11 ++++++-----
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9ddd04b5b4..89dcad03e6 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2977,7 +2977,10 @@ class PyTorchOpConverter:
     def flip(self, inputs, input_types):
         data = inputs[0]
         axis = inputs[1]
-        return _op.transform.reverse(data, axis=axis[0])
+        out = data
+        for ax in axis:
+            out = _op.reverse(out, ax)
+        return out
 
     def bidir_rnn_cell(self, input_seqs, weights_dicts, act=_op.tanh):
         """
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9ee03512e7..6bbb9ef5cc 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4899,13 +4899,14 @@ def test_forward_flip():
             self.axis = axis
 
         def forward(self, x):
-            return x.flip([self.axis])
+            return x.flip(self.axis)
 
     input_t = torch.randn(2, 3, 4)
-    verify_model(Flip(axis=0), input_data=input_t)
-    verify_model(Flip(axis=1), input_data=input_t)
-    verify_model(Flip(axis=2), input_data=input_t)
-    verify_model(Flip(axis=-1), input_data=input_t)
+    verify_model(Flip(axis=[0]), input_data=input_t)
+    verify_model(Flip(axis=[1]), input_data=input_t)
+    verify_model(Flip(axis=[2]), input_data=input_t)
+    verify_model(Flip(axis=[-1]), input_data=input_t)
+    verify_model(Flip(axis=[0, 1]), input_data=input_t)
 
 
 def test_annotate_span():
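The updated test checks TVM against PyTorch through `verify_model`. As a lightweight sanity check of the property the fix relies on (a multi-axis flip equals a sequence of single-axis reversals), here is a NumPy sketch using the same input shape and axis lists as the test. It assumes NumPy's `np.flip` has the same semantics as both `torch.flip` and Relay's `reverse` for these cases, and `flip_one_axis_at_a_time` is a hypothetical helper, not part of TVM.

```python
import numpy as np


def flip_one_axis_at_a_time(x, axes):
    # Mirrors the fixed converter loop: one single-axis reversal per entry.
    out = x
    for ax in axes:
        out = np.flip(out, axis=ax)
    return out


# Deterministic input with the same shape as the regression test.
x = np.arange(24).reshape(2, 3, 4)
for axes in ([0], [1], [2], [-1], [0, 1]):
    expected = np.flip(x, axis=tuple(axes))  # multi-axis flip in one call
    actual = flip_one_axis_at_a_time(x, axes)
    np.testing.assert_array_equal(actual, expected)
    print(axes, "ok")
```

Because each reversal acts on a different axis, the order of the loop does not matter, which is why a simple sequential loop is sufficient in the converter.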