You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/28 04:59:37 UTC
[tvm] branch main updated: [Relay][PyTorch] Add aten::lerp (#12167)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c35c9fd3a5 [Relay][PyTorch] Add aten::lerp (#12167)
c35c9fd3a5 is described below
commit c35c9fd3a5249cfb01093b08b35979db846dfa33
Author: xndcn <xn...@gmail.com>
AuthorDate: Thu Jul 28 12:59:30 2022 +0800
[Relay][PyTorch] Add aten::lerp (#12167)
---
python/tvm/relay/frontend/pytorch.py | 11 +++++++++++
tests/python/frontend/pytorch/test_forward.py | 15 +++++++++++++++
2 files changed, 26 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b88e08b719..1bd3232871 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -343,6 +343,16 @@ class PyTorchOpConverter:
diag_input = _op.zeros(input_shape, dtype=input_types[0])
return _op.matrix_set_diag(data, diag_input, k=(k1, k2))
+ def lerp(self, inputs, input_types):
+ if len(inputs) != 3:
+ msg = "Wrong number of arguments (%d) to parse." % (len(inputs))
+ raise AssertionError(msg)
+
+ start = inputs[0]
+ end = inputs[1]
+ weight = inputs[2]
+ return start + weight * (end - start)
+
def arange(self, inputs, input_types):
def _get_value(val, dtype):
# dtype is a tvm dtype
@@ -3412,6 +3422,7 @@ class PyTorchOpConverter:
"aten::stft": self.stft,
"aten::mul": self.make_elemwise("multiply"),
"aten::pow": self.make_elemwise("power"),
+ "aten::lerp": self.lerp,
"aten::arange": self.arange,
"aten::meshgrid": self.meshgrid,
"aten::div": self.make_elemwise("divide"),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6d7926396a..4332f3efe5 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4596,5 +4596,20 @@ def test_softmax_fuse():
tvm.testing.assert_allclose(out, output_torch, rtol=1e-5, atol=1e-5)
@tvm.testing.uses_gpu
def test_lerp():
    """Verify conversion of torch.lerp with both tensor and scalar weights."""

    def _lerp_fn(x, y, w):
        return torch.lerp(x, y, w)

    shape = [16]
    start, end, weight = (torch.rand(shape).float() for _ in range(3))

    # torch.lerp accepts the weight either as a tensor or as a scalar;
    # exercise both overloads through the converter.
    verify_model(_lerp_fn, [start, end, weight])
    verify_model(_lerp_fn, [start, end, weight[0]])
+
+
if __name__ == "__main__":
pytest.main([__file__])