You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/24 19:59:23 UTC
[tvm] branch main updated: [Relay] [Pytorch] Add aten::maximum and aten::minimum (#11864)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 77d73b5b28 [Relay] [Pytorch] Add aten::maximum and aten::minimum (#11864)
77d73b5b28 is described below
commit 77d73b5b28c1b07f8d40d3339ef17adab5eb0eec
Author: Yuanjing Shi <yu...@octoml.ai>
AuthorDate: Fri Jun 24 12:59:18 2022 -0700
[Relay] [Pytorch] Add aten::maximum and aten::minimum (#11864)
* add maximum and minimum
* cleanup
---
python/tvm/relay/frontend/pytorch.py | 10 ++++++++++
tests/python/frontend/pytorch/test_forward.py | 16 ++++++++++++++++
2 files changed, 26 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index ac7b52237a..ba0d025026 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -293,6 +293,14 @@ class PyTorchOpConverter:
     def min(self, inputs, input_types):
         return self.min_max_common("minimum", "min", inputs, input_types)
 
+    def maximum(self, inputs, input_types):
+        data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2])
+        return _op.maximum(data0, data1)
+
+    def minimum(self, inputs, input_types):
+        data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2])
+        return _op.minimum(data0, data1)
+
     def make_unary(self, name):
         def unary(inputs, input_types):
             # this is just to ensure tensor input
@@ -3020,6 +3028,8 @@ class PyTorchOpConverter:
             "aten::sub": self.sub,
             "aten::max": self.max,
             "aten::min": self.min,
+            "aten::maximum": self.maximum,
+            "aten::minimum": self.minimum,
             "aten::amax": self.max,
             "aten::amin": self.min,
             "aten::stft": self.stft,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 93071839d1..9609008c99 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -363,6 +363,22 @@ def test_min_max():
     verify_model(Min4(), input_data=input_data[0])
 
 
+@tvm.testing.uses_gpu
+def test_minimum_maximum():
+    class Maximum(Module):
+        def forward(self, lhs, rhs):
+            return torch.maximum(lhs, rhs)
+
+    class Minimum(Module):
+        def forward(self, lhs, rhs):
+            return torch.minimum(lhs, rhs)
+
+    input_data = [torch.rand((10, 10, 10, 10)), torch.rand((10, 10, 10, 10))]
+
+    verify_model(Maximum(), input_data=input_data)
+    verify_model(Minimum(), input_data=input_data)
+
+
 @tvm.testing.uses_gpu
 def test_forward_reciprocal():
     torch.set_grad_enabled(False)