Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/24 19:59:23 UTC

[tvm] branch main updated: [Relay] [Pytorch] Add aten::maximum and aten::minimum (#11864)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 77d73b5b28 [Relay] [Pytorch] Add aten::maximum and aten::minimum (#11864)
77d73b5b28 is described below

commit 77d73b5b28c1b07f8d40d3339ef17adab5eb0eec
Author: Yuanjing Shi <yu...@octoml.ai>
AuthorDate: Fri Jun 24 12:59:18 2022 -0700

    [Relay] [Pytorch] Add aten::maximum and aten::minimum (#11864)
    
    * add maximum and minimum
    
    * cleanup
---
 python/tvm/relay/frontend/pytorch.py          | 10 ++++++++++
 tests/python/frontend/pytorch/test_forward.py | 16 ++++++++++++++++
 2 files changed, 26 insertions(+)
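
For context, a minimal sketch of how a model that calls torch.maximum / torch.minimum
can be imported once these converters are in place. The module, input names, and
shapes below are illustrative only and are not part of the commit;
relay.frontend.from_pytorch and torch.jit.trace are the standard entry points.

    import torch
    from tvm import relay

    class MaxMin(torch.nn.Module):
        def forward(self, lhs, rhs):
            # elementwise ops dispatched via aten::maximum / aten::minimum
            return torch.maximum(lhs, rhs) + torch.minimum(lhs, rhs)

    lhs, rhs = torch.rand(4, 4), torch.rand(4, 4)
    scripted = torch.jit.trace(MaxMin().eval(), (lhs, rhs))

    # the frontend maps aten::maximum / aten::minimum to relay.maximum / relay.minimum
    mod, params = relay.frontend.from_pytorch(
        scripted, [("lhs", (4, 4)), ("rhs", (4, 4))]
    )
    print(mod)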

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index ac7b52237a..ba0d025026 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -293,6 +293,14 @@ class PyTorchOpConverter:
     def min(self, inputs, input_types):
         return self.min_max_common("minimum", "min", inputs, input_types)
 
+    def maximum(self, inputs, input_types):
+        data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2])
+        return _op.maximum(data0, data1)
+
+    def minimum(self, inputs, input_types):
+        data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2])
+        return _op.minimum(data0, data1)
+
     def make_unary(self, name):
         def unary(inputs, input_types):
             # this is just to ensure tensor input
@@ -3020,6 +3028,8 @@ class PyTorchOpConverter:
             "aten::sub": self.sub,
             "aten::max": self.max,
             "aten::min": self.min,
+            "aten::maximum": self.maximum,
+            "aten::minimum": self.minimum,
             "aten::amax": self.max,
             "aten::amin": self.min,
             "aten::stft": self.stft,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 93071839d1..9609008c99 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -363,6 +363,22 @@ def test_min_max():
     verify_model(Min4(), input_data=input_data[0])
 
 
+@tvm.testing.uses_gpu
+def test_minimum_maximum():
+    class Maximum(Module):
+        def forward(self, lhs, rhs):
+            return torch.maximum(lhs, rhs)
+
+    class Minimum(Module):
+        def forward(self, lhs, rhs):
+            return torch.minimum(lhs, rhs)
+
+    input_data = [torch.rand((10, 10, 10, 10)), torch.rand((10, 10, 10, 10))]
+
+    verify_model(Maximum(), input_data=input_data)
+    verify_model(Minimum(), input_data=input_data)
+
+
 @tvm.testing.uses_gpu
 def test_forward_reciprocal():
     torch.set_grad_enabled(False)