Posted to commits@tvm.apache.org by ma...@apache.org on 2022/08/12 09:09:29 UTC
[tvm] branch main updated: [PyTorch] Fix all_any_common with no default input (#12395)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a1c371f46c [PyTorch] Fix all_any_common with no default input (#12395)
a1c371f46c is described below
commit a1c371f46cf77dcdffa6f0ab55f5036bff1c5624
Author: Yuanjing Shi <yu...@octoml.ai>
AuthorDate: Thu Aug 11 23:09:22 2022 -1000
[PyTorch] Fix all_any_common with no default input (#12395)
* fix all_any_common with no default input
* work around
* better naming
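Background on the failure this fixes: when torch.all or torch.any is traced with no dim or keepdim argument, the resulting aten node carries only the input tensor, so the converter's previously unconditional reads of inputs[1] and inputs[2] failed. A minimal repro sketch, assuming the standard relay.frontend.from_pytorch entry point (the function and input names below are illustrative, not from the commit):

import torch
from tvm import relay

# torch.any with no dim/keepdim traces to an aten node with a single input.
def no_arg_any(x):
    return torch.any(x)

x = torch.rand(1, 2).bool()
traced = torch.jit.trace(no_arg_any, x)
# Prior to this fix, all_any_common read inputs[1]/inputs[2] here and failed.
mod, params = relay.frontend.from_pytorch(traced, [("x", ((1, 2), "bool"))])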
---
 python/tvm/relay/frontend/pytorch.py          | 10 ++++++++--
 tests/python/frontend/pytorch/test_forward.py |  5 +++++
 2 files changed, 13 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index ffe4b313c5..0e6d4caae0 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3253,8 +3253,14 @@ class PyTorchOpConverter:
         return (output, _op.stack(hy, 0), _op.stack(cy, 0))
 
     def all_any_common(self, op, inputs, input_types):
-        dim = inputs[1]
-        keepdim = inputs[2]
+        if len(inputs) >= 2:
+            dim = inputs[1]
+        else:
+            dim = None
+        if len(inputs) >= 3:
+            keepdim = inputs[2]
+        else:
+            keepdim = False
         if self.infer_type(inputs[0]).dtype != "bool":
             # The input dtype can be uint8.
             inp = _op.cast(inputs[0], "bool")
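The fallback values above mirror PyTorch's own semantics: with no dim, torch.all and torch.any reduce over every element (equivalent to axis=None in the Relay reduce ops), and keepdim defaults to False. A quick illustration in plain PyTorch/NumPy (not part of the commit):

import torch
import numpy as np

x = torch.arange(0, 3).to(torch.uint8)   # [0, 1, 2]
print(torch.all(x))                      # falsy: 0 breaks the full reduction
print(torch.any(x))                      # truthy: nonzero elements are present
print(np.all(x.numpy().astype(bool)))    # False, matching NumPy's axis=None default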
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6b1eb30a56..4c78ba4b85 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4385,11 +4385,16 @@ def test_all_any():
     def test_fn(f, dim=None, keepdim=False):
         return lambda x: f(x, dim=dim, keepdim=keepdim)
 
+    def test_fn_no_arg(f):
+        return lambda x: f(x)
+
     for f in [torch.all, torch.any]:
         verify_model(test_fn(f, 0), [torch.rand(1, 2).bool()])
         verify_model(test_fn(f, 0), [torch.arange(0, 3).to(torch.uint8)])
         verify_model(test_fn(f, 1), [torch.rand(4, 2).bool()])
         verify_model(test_fn(f, 0, keepdim=True), [torch.rand(4, 2).bool()])
+        verify_model(test_fn_no_arg(f), [torch.rand(1, 2).bool()])
+        verify_model(test_fn_no_arg(f), [torch.arange(0, 3).to(torch.uint8)])
@tvm.testing.uses_gpu
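Taken together with the converter change, the new test cases cover the no-argument form on both the bool fast path and the uint8 branch that casts to bool. A hedged end-to-end sketch of the uint8 case, assuming the graph executor on CPU (executor and target choices are illustrative, not taken from the commit):

import torch
import tvm
from tvm import relay

def no_arg_all(x):
    return torch.all(x)

x = torch.arange(0, 3).to(torch.uint8)
traced = torch.jit.trace(no_arg_all, x)
mod, params = relay.frontend.from_pytorch(traced, [("x", ((3,), "uint8"))])
# Compile and run on CPU, then compare against the PyTorch result.
func = relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm").evaluate()
tvm_res = func(x.numpy(), **params)
print(tvm_res, torch.all(x))  # both should report a falsy result, since 0 is in the input

Locally, the committed cases run as part of test_all_any in tests/python/frontend/pytorch/test_forward.py.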