You are viewing a plain-text version of this content; the canonical link to it is not preserved in this rendering.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/04/26 07:04:42 UTC
[tvm] branch main updated: allow constant value let binding in script (#11115)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4dc47df369 allow constant value let binding in script (#11115)
4dc47df369 is described below
commit 4dc47df369f3116f7674e474ea655b4c9e2e25ab
Author: wrongtest <wr...@gmail.com>
AuthorDate: Tue Apr 26 15:04:35 2022 +0800
allow constant value let binding in script (#11115)
---
python/tvm/script/parser.py | 49 +++++++++++-----------
.../python/unittest/test_tvmscript_syntax_sugar.py | 16 +++++++
2 files changed, 41 insertions(+), 24 deletions(-)
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 92a730903b..b01ad383c3 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -574,32 +574,33 @@ class TVMScriptParser(Transformer):
arg_list = self.parse_arg_list(func, node.rhs)
func.handle(node, self.context, arg_list, node.rhs.func_name.span)
return self.parse_body(node)
- else:
- value = self.transform(node.rhs)
- if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
- # This is a little confusing because it only is true when
- # we have taken this branch. We might need to clarify what
- # exectly is allowed in Assignments in tvmscript.
- self.report_error(
- "Left hand side of assignment must be an unqualified variable",
- node.span,
- )
- ast_var = node.lhs[0]
+ if isinstance(node.rhs, (ast.Call, ast.Constant)):
+ # Pattern 4 of let binding
+ value = self.transform(node.rhs)
+ if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
+ # This is a little confusing because it only is true when
+ # we have taken this branch. We might need to clarify what
+ # exectly is allowed in Assignments in tvmscript.
+ self.report_error(
+ "Left hand side of assignment must be an unqualified variable",
+ node.span,
+ )
+ ast_var = node.lhs[0]
- if node.ty is None and hasattr(value, "dtype"):
- var_ty = value.dtype
- else:
- var_ty = self.parse_type(node.ty, ast_var)
+ if node.ty is None and hasattr(value, "dtype"):
+ var_ty = value.dtype
+ else:
+ var_ty = self.parse_type(node.ty, ast_var)
- var = tvm.te.var(
- ast_var.id.name,
- var_ty,
- span=tvm_span_from_synr(ast_var.span),
- )
- self.context.update_symbol(var.name, var, node)
- body = self.parse_body(node)
- self.context.remove_symbol(var.name)
- return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
+ var = tvm.te.var(
+ ast_var.id.name,
+ var_ty,
+ span=tvm_span_from_synr(ast_var.span),
+ )
+ self.context.update_symbol(var.name, var, node)
+ body = self.parse_body(node)
+ self.context.remove_symbol(var.name)
+ return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
self.report_error(
"""Assignments should be either
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index 1d3c8ab1f1..a0964ea4d7 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -249,5 +249,21 @@ def test_letstmt_bufferload_without_type_annotation():
T.evaluate(x)
+def test_letstmt_bind_with_constant():
+ @T.prim_func
+ def constant_binds():
+ x = 1
+ y = 42.0
+ T.evaluate(T.cast(x, "float32") + y)
+
+ @T.prim_func
+ def constant_binds_wrapped():
+ x = T.int32(1)
+ y = T.float32(42.0)
+ T.evaluate(T.cast(x, "float32") + y)
+
+ assert_structural_equal(constant_binds, constant_binds_wrapped)
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))