You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/08/12 09:09:29 UTC

[tvm] branch main updated: [PyTorch] Fix all_any_common with no default input (#12395)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a1c371f46c [PyTorch] Fix all_any_common with no default input (#12395)
a1c371f46c is described below

commit a1c371f46cf77dcdffa6f0ab55f5036bff1c5624
Author: Yuanjing Shi <yu...@octoml.ai>
AuthorDate: Thu Aug 11 23:09:22 2022 -1000

    [PyTorch] Fix all_any_common with no default input (#12395)
    
    * fix all_any_common with no default input
    
    * workaround
    
    * better naming
---
 python/tvm/relay/frontend/pytorch.py          | 10 ++++++++--
 tests/python/frontend/pytorch/test_forward.py |  5 +++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index ffe4b313c5..0e6d4caae0 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3253,8 +3253,14 @@ class PyTorchOpConverter:
         return (output, _op.stack(hy, 0), _op.stack(cy, 0))
 
     def all_any_common(self, op, inputs, input_types):
-        dim = inputs[1]
-        keepdim = inputs[2]
+        if len(inputs) >= 2:
+            dim = inputs[1]
+        else:
+            dim = None
+        if len(inputs) >= 3:
+            keepdim = inputs[2]
+        else:
+            keepdim = False
         if self.infer_type(inputs[0]).dtype != "bool":
             # The input dtype can be uint8.
             inp = _op.cast(inputs[0], "bool")
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6b1eb30a56..4c78ba4b85 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4385,11 +4385,16 @@ def test_all_any():
     def test_fn(f, dim=None, keepdim=False):
         return lambda x: f(x, dim=dim, keepdim=keepdim)
 
+    def test_fn_no_arg(f):
+        return lambda x: f(x)
+
     for f in [torch.all, torch.any]:
         verify_model(test_fn(f, 0), [torch.rand(1, 2).bool()])
         verify_model(test_fn(f, 0), [torch.arange(0, 3).to(torch.uint8)])
         verify_model(test_fn(f, 1), [torch.rand(4, 2).bool()])
         verify_model(test_fn(f, 0, keepdim=True), [torch.rand(4, 2).bool()])
+        verify_model(test_fn_no_arg(f), [torch.rand(1, 2).bool()])
+        verify_model(test_fn_no_arg(f), [torch.arange(0, 3).to(torch.uint8)])
 
 
 @tvm.testing.uses_gpu