You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/16 06:11:47 UTC

[tvm] branch main updated: [TVMScript] Support roundtrip of LetNode (#11742)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 24010db6c0 [TVMScript] Support roundtrip of LetNode (#11742)
24010db6c0 is described below

commit 24010db6c0e90bc555f6d12e23381fa7b00cf25d
Author: wrongtest <wr...@gmail.com>
AuthorDate: Thu Jun 16 14:11:41 2022 +0800

    [TVMScript] Support roundtrip of LetNode (#11742)
    
    Adds the missing TVMScript roundtrip support for `tir.LetNode`.
---
 python/tvm/script/tir/scope_handler.py            |  3 +++
 tests/python/unittest/test_tvmscript_roundtrip.py | 10 ++++++++++
 2 files changed, 13 insertions(+)

diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
index 7d3250fe87..85882055d0 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -312,6 +312,9 @@ class Let(WithScopeHandler):
 
         super().__init__(let, concise_scope=False, def_symbol=False)
 
+    def __call__(self, var: tvm.tir.Var, value: tvm.tir.PrimExpr, body: tvm.tir.PrimExpr):
+        return tvm.tir.Let(var, value, body)
+
 
 @register
 class Block(WithScopeHandler):
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 93bd0707c6..306f60f1b1 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3288,6 +3288,15 @@ def buffer_ramp_access_as_slice_index():
     return buffer_ramp_access
 
 
+def let_expression():
+    @T.prim_func
+    def func():
+        x = T.var("int32")
+        T.evaluate(T.let(x, 1, x + 1))
+
+    return func
+
+
 ir_generator = tvm.testing.parameter(
     opt_gemm_normalize,
     opt_gemm_lower,
@@ -3325,6 +3334,7 @@ ir_generator = tvm.testing.parameter(
     pointer_type,
     buffer_axis_separator,
     buffer_ramp_access_as_slice_index,
+    let_expression,
 )