You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/28 04:59:37 UTC

[tvm] branch main updated: [Relay][PyTorch] Add aten::lerp (#12167)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c35c9fd3a5 [Relay][PyTorch] Add aten::lerp (#12167)
c35c9fd3a5 is described below

commit c35c9fd3a5249cfb01093b08b35979db846dfa33
Author: xndcn <xn...@gmail.com>
AuthorDate: Thu Jul 28 12:59:30 2022 +0800

    [Relay][PyTorch] Add aten::lerp (#12167)
---
 python/tvm/relay/frontend/pytorch.py          | 11 +++++++++++
 tests/python/frontend/pytorch/test_forward.py | 15 +++++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b88e08b719..1bd3232871 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -343,6 +343,16 @@ class PyTorchOpConverter:
         diag_input = _op.zeros(input_shape, dtype=input_types[0])
         return _op.matrix_set_diag(data, diag_input, k=(k1, k2))
 
+    def lerp(self, inputs, input_types):
+        """Convert aten::lerp(start, end, weight) to start + weight * (end - start)."""
+        if len(inputs) != 3:
+            msg = "Wrong number of arguments (%d) to parse." % (len(inputs))
+            raise AssertionError(msg)
+        # weight may arrive as a Python scalar (aten::lerp.Scalar); promote it to a
+        # typed Relay const so it can be combined with the tensor operands.
+        start, end, weight = self.pytorch_promote_types(inputs, input_types)
+        return start + weight * (end - start)
+
     def arange(self, inputs, input_types):
         def _get_value(val, dtype):
             # dtype is a tvm dtype
@@ -3412,6 +3422,7 @@ class PyTorchOpConverter:
             "aten::stft": self.stft,
             "aten::mul": self.make_elemwise("multiply"),
             "aten::pow": self.make_elemwise("power"),
+            "aten::lerp": self.lerp,
             "aten::arange": self.arange,
             "aten::meshgrid": self.meshgrid,
             "aten::div": self.make_elemwise("divide"),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6d7926396a..4332f3efe5 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4596,5 +4596,20 @@ def test_softmax_fuse():
         tvm.testing.assert_allclose(out, output_torch, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.uses_gpu
+def test_lerp():
+    def test_fn(x, y, w):
+        return torch.lerp(x, y, w)
+
+    input_shape = [16]
+    x = torch.rand(input_shape).float()
+    y = torch.rand(input_shape).float()
+    w = torch.rand(input_shape).float()
+
+    # weight as a tensor and as a 0-dim tensor (w[0]); Python float weights are untested
+    verify_model(test_fn, [x, y, w])
+    verify_model(test_fn, [x, y, w[0]])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])