You are viewing a plain-text version of this content. The link to the canonical version (the word "here" in the original page) is not preserved in this text.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/04/26 07:04:42 UTC

[tvm] branch main updated: allow constant value let binding in script (#11115)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4dc47df369 allow constant value let binding in script (#11115)
4dc47df369 is described below

commit 4dc47df369f3116f7674e474ea655b4c9e2e25ab
Author: wrongtest <wr...@gmail.com>
AuthorDate: Tue Apr 26 15:04:35 2022 +0800

    allow constant value let binding in script (#11115)
---
 python/tvm/script/parser.py                        | 49 +++++++++++-----------
 .../python/unittest/test_tvmscript_syntax_sugar.py | 16 +++++++
 2 files changed, 41 insertions(+), 24 deletions(-)

diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 92a730903b..b01ad383c3 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -574,32 +574,33 @@ class TVMScriptParser(Transformer):
                 arg_list = self.parse_arg_list(func, node.rhs)
                 func.handle(node, self.context, arg_list, node.rhs.func_name.span)
                 return self.parse_body(node)
-            else:
-                value = self.transform(node.rhs)
-                if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
-                    # This is a little confusing because it only is true when
-                    # we have taken this branch. We might need to clarify what
-                    # exectly is allowed in Assignments in tvmscript.
-                    self.report_error(
-                        "Left hand side of assignment must be an unqualified variable",
-                        node.span,
-                    )
-                ast_var = node.lhs[0]
+        if isinstance(node.rhs, (ast.Call, ast.Constant)):
+            # Pattern 4 of let binding
+            value = self.transform(node.rhs)
+            if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
+                # This is a little confusing because it only is true when
+                # we have taken this branch. We might need to clarify what
+                # exectly is allowed in Assignments in tvmscript.
+                self.report_error(
+                    "Left hand side of assignment must be an unqualified variable",
+                    node.span,
+                )
+            ast_var = node.lhs[0]
 
-                if node.ty is None and hasattr(value, "dtype"):
-                    var_ty = value.dtype
-                else:
-                    var_ty = self.parse_type(node.ty, ast_var)
+            if node.ty is None and hasattr(value, "dtype"):
+                var_ty = value.dtype
+            else:
+                var_ty = self.parse_type(node.ty, ast_var)
 
-                var = tvm.te.var(
-                    ast_var.id.name,
-                    var_ty,
-                    span=tvm_span_from_synr(ast_var.span),
-                )
-                self.context.update_symbol(var.name, var, node)
-                body = self.parse_body(node)
-                self.context.remove_symbol(var.name)
-                return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
+            var = tvm.te.var(
+                ast_var.id.name,
+                var_ty,
+                span=tvm_span_from_synr(ast_var.span),
+            )
+            self.context.update_symbol(var.name, var, node)
+            body = self.parse_body(node)
+            self.context.remove_symbol(var.name)
+            return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
 
         self.report_error(
             """Assignments should be either
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index 1d3c8ab1f1..a0964ea4d7 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -249,5 +249,21 @@ def test_letstmt_bufferload_without_type_annotation():
         T.evaluate(x)
 
 
+def test_letstmt_bind_with_constant():  # let-bind bare Python constants in TVMScript
+    @T.prim_func
+    def constant_binds():  # RHS values are plain Python literals
+        x = 1  # no annotation; dtype must be inferred from the literal
+        y = 42.0
+        T.evaluate(T.cast(x, "float32") + y)
+
+    @T.prim_func
+    def constant_binds_wrapped():  # same program with explicit dtype wrappers
+        x = T.int32(1)
+        y = T.float32(42.0)
+        T.evaluate(T.cast(x, "float32") + y)
+
+    assert_structural_equal(constant_binds, constant_binds_wrapped)  # literals must infer int32 / float32
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))