You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/11/16 01:07:26 UTC

[tvm] branch main updated: [TVMScript] Use tir::Evaluate if expression is in statement context (#13396)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 86a5ceec27 [TVMScript] Use tir::Evaluate if expression is in statement context (#13396)
86a5ceec27 is described below

commit 86a5ceec271f241451b641d10b4c27e0cdeb1e89
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Tue Nov 15 19:07:20 2022 -0600

    [TVMScript] Use tir::Evaluate if expression is in statement context (#13396)
    
    * [TVMScript] Use tir::Evaluate if expression is in statement context
    
    In the previous version of the parser, an expression used as a
    statement was handled only as a special case for some intrinsic
    operators.  Since the new TVMScript parser was enabled in
    https://github.com/apache/tvm/pull/12496, any `PrimExpr` that
    appears on its own as a statement is silently ignored.  This commit
    updates the parser to instead wrap the bare `PrimExpr` in a
    `tir::Evaluate` node.
    
    This change effectively allows [expression
    statements](https://docs.python.org/3/reference/simple_stmts.html#expression-statements)
    in TVMScript, which are converted to `tir::Evaluate` nodes during
    parsing.
    
    * Update the printer to emit T.evaluate() for readability, except when the evaluated value is a CallNode (which is printed bare)
---
 python/tvm/script/parser/tir/parser.py             |  5 +++++
 src/printer/tvmscript_printer.cc                   | 19 ++++++++--------
 tests/python/unittest/test_tvmscript_roundtrip.py  | 10 +++++++++
 .../python/unittest/test_tvmscript_syntax_sugar.py | 26 ++++++++++++++++++++++
 4 files changed, 51 insertions(+), 9 deletions(-)

diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index 1370758f5a..0e74114ba2 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -20,6 +20,7 @@ import contextlib
 from functools import partial
 from typing import Any
 
+import tvm
 from tvm.ir import PrimType
 from tvm.tir import Buffer, IterVar, PrimExpr, Var
 
@@ -411,6 +412,10 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
     if isinstance(res, Frame):
         res.add_callback(partial(res.__exit__, None, None, None))
         res.__enter__()
+    elif isinstance(res, PrimExpr):
+        T.evaluate(res)
+    elif isinstance(res, (int, bool)):
+        T.evaluate(tvm.tir.const(res))
 
 
 @dispatch.register(token="tir", type_name="If")
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index d7a3a406e3..f1d68ee438 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -1275,16 +1275,17 @@ Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) {
 }
 
 Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
-  if (auto* call = op->value.as<CallNode>()) {
-    if (call->op.same_as(builtin::assume())) {
-      Doc doc;
-      doc << tir_prefix_ << ".assume(" << Print(call->args[0]) << ")";
-      return doc;
-    }
-  }
-
+  // When parsing TVMScript, a PrimExpr that occurs as a statement is
+  // automatically wrapped in `tir::Evaluate`.  Therefore, when
+  // printing, it's only necessary to print the value.  For
+  // readability, though, we still print T.evaluate() when the
+  // expression is something other than a call node.
   Doc doc;
-  doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")";
+  if (op->value.as<CallNode>()) {
+    doc << Print(op->value);
+  } else {
+    doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")";
+  }
   return doc;
 }
 
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index f22e61e183..b8c8379c8a 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3458,6 +3458,15 @@ def bool_cast():
     return func
 
 
+def implicit_evaluate():
+    @T.prim_func
+    def func(A: T.Buffer[1, "int32"]):
+        T.evaluate(T.assume(A[0] == 5))
+        A[0] = 10
+
+    return func
+
+
 ir_generator = tvm.testing.parameter(
     opt_gemm_normalize,
     opt_gemm_lower,
@@ -3509,6 +3518,7 @@ ir_generator = tvm.testing.parameter(
     bool_primitive,
     bool_cast,
     return_none,
+    implicit_evaluate,
 )
 
 
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index 16f1cb0494..a39354b955 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -402,5 +402,31 @@ def test_int64_loop():
     assert_structural_equal(int64_grid, int64_grid_expanded)
 
 
+def test_implicit_evaluate_assume():
+    @T.prim_func
+    def explicit(A: T.Buffer[1, "int32"]):
+        T.evaluate(T.assume(A[0] == 5))
+        A[0] = 10
+
+    @T.prim_func
+    def implicit(A: T.Buffer[1, "int32"]):
+        T.assume(A[0] == 5)
+        A[0] = 10
+
+    assert_structural_equal(implicit, explicit)
+
+
+def test_implicit_evaluate_call_extern():
+    @T.prim_func
+    def explicit(A: T.Buffer[1, "int32"]):
+        T.evaluate(T.call_extern("extern_func", A.data, dtype="int32"))
+
+    @T.prim_func
+    def implicit(A: T.Buffer[1, "int32"]):
+        T.call_extern("extern_func", A.data, dtype="int32")
+
+    assert_structural_equal(implicit, explicit)
+
+
 if __name__ == "__main__":
     tvm.testing.main()