You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/16 06:11:47 UTC
[tvm] branch main updated: [TVMScript] Support roundtrip of LetNode (#11742)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 24010db6c0 [TVMScript] Support roundtrip of LetNode (#11742)
24010db6c0 is described below
commit 24010db6c0e90bc555f6d12e23381fa7b00cf25d
Author: wrongtest <wr...@gmail.com>
AuthorDate: Thu Jun 16 14:11:41 2022 +0800
[TVMScript] Support roundtrip of LetNode (#11742)
Add the previously missing TVMScript roundtrip support for `tir.LetNode`.
---
python/tvm/script/tir/scope_handler.py | 3 +++
tests/python/unittest/test_tvmscript_roundtrip.py | 10 ++++++++++
2 files changed, 13 insertions(+)
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
index 7d3250fe87..85882055d0 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -312,6 +312,9 @@ class Let(WithScopeHandler):
super().__init__(let, concise_scope=False, def_symbol=False)
+ def __call__(self, var: tvm.tir.Var, value: tvm.tir.PrimExpr, body: tvm.tir.PrimExpr):
+ return tvm.tir.Let(var, value, body)
+
@register
class Block(WithScopeHandler):
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 93bd0707c6..306f60f1b1 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3288,6 +3288,15 @@ def buffer_ramp_access_as_slice_index():
return buffer_ramp_access
+def let_expression():
+ @T.prim_func
+ def func():
+ x = T.var("int32")
+ T.evaluate(T.let(x, 1, x + 1))
+
+ return func
+
+
ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
@@ -3325,6 +3334,7 @@ ir_generator = tvm.testing.parameter(
pointer_type,
buffer_axis_separator,
buffer_ramp_access_as_slice_index,
+ let_expression,
)