You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2022/09/21 13:05:01 UTC

[tvm] branch main updated: [TVMScript][Fix] Correct round-trip of explicit root block (#12673)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fdc6894b7d [TVMScript][Fix] Correct round-trip of explicit root block (#12673)
fdc6894b7d is described below

commit fdc6894b7dae096d0ec983292aa0a2a475843f56
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Wed Sep 21 08:04:53 2022 -0500

    [TVMScript][Fix] Correct round-trip of explicit root block (#12673)
    
    * [TVMScript][Fix] Correct round-trip of explicit root block
    
    Prior to this commit, when converting TIR to TVMScript, the root
    `tir::Block` was typically hidden.  When parsing, however,
    `tvm::tir::ScriptComplete` will wrap the function body in a root block
    if the primfunc contains at least one block and does not
    already have a root block.  As a result, if the root block is the only
    block present, it would be stripped by a round-trip.
    
    This commit tightens the condition for hiding the root `tir::Block`
    when converting to TVMScript, so that it is printed in cases where
    the autocompleter would reinsert it when parsing.
---
 include/tvm/tir/stmt_functor.h                    | 32 +++++++++++++++
 src/printer/tvmscript_printer.cc                  | 50 +++++++++++++++++++----
 src/tir/ir/script/script_complete.cc              | 37 +++++++++++++----
 tests/python/unittest/test_tvmscript_roundtrip.py | 21 ++++++++++
 4 files changed, 123 insertions(+), 17 deletions(-)

diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index 49b1f28e5d..2fc3b9678b 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -427,6 +427,38 @@ TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr,
  * \return The renewed func.
  */
 TVM_DLL PrimFunc RenewDefs(const PrimFunc& func);
+
+/*!
+ * \brief Check if the statement contains the specified node type.
+ *
+ * This utility potentially walks the entire statement, and should
+ * therefore not be used if it could otherwise be merged with another
+ * pass.
+ *
+ * \param stmt The statement to be searched
+ * \return Whether stmt contains Node
+ */
+template <typename Node, typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>>
+bool ContainsNode(const Stmt& stmt) {
+  struct Visitor : StmtVisitor {
+    // Early bail-out, if we already found the node.
+    void VisitStmt(const Stmt& stmt) {
+      if (contains_node) {
+        return;
+      }
+      StmtVisitor::VisitStmt(stmt);
+    }
+
+    void VisitStmt_(const Node* block) override { contains_node = true; }
+
+    bool contains_node{false};
+  };
+
+  Visitor visitor;
+  visitor(stmt);
+  return visitor.contains_node;
+}
+
 }  // namespace tir
 }  // namespace tvm
 
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 2072037358..936ac7580f 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -1664,19 +1664,53 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
   }
   // print body
   body << "# body" << Doc::NewLine();
-  if (op->body->IsInstance<BlockRealizeNode>() &&
-      op->body.as<BlockRealizeNode>()->iter_values.empty()) {
-    const BlockNode* block = op->body.as<BlockRealizeNode>()->block.get();
-    if (block->annotations.empty() && !ContainsOptionalInfo(GetRef<Stmt>(block))) {
-      // Skip print root block
-      body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine();
-      body << PrintBlockBody(block);
+
+  Optional<Block> elided_root_block_body = [&]() -> Optional<Block> {
+    auto block_realize = op->body.as<BlockRealizeNode>();
+    if (!block_realize || block_realize->iter_values.size()) {
+      return NullOpt;
+    }
+
+    const auto& block = block_realize->block;
+    if (block->annotations.size() || ContainsOptionalInfo(block)) {
+      return NullOpt;
+    }
+
+    // The autocomplete might recognize the body itself as being a
+    // root block, and fail to insert it.
+    bool autocomplete_would_insert_root_block = [&]() -> bool {
+      if (block->alloc_buffers.size()) {
+        return true;
+      }
+
+      auto* block_realize = block->body.as<BlockRealizeNode>();
+      if (block_realize && block_realize->block->iter_vars.size()) {
+        return true;
+      }
+      if (!block_realize && ContainsNode<BlockRealizeNode>(block->body)) {
+        return true;
+      }
+      return false;
+    }();
+
+    if (autocomplete_would_insert_root_block) {
+      return block;
     } else {
-      body << PrintBody(op->body);
+      return NullOpt;
     }
+  }();
+
+  if (elided_root_block_body) {
+    // Skip printing of root block in cases where tvm::tir::ScriptComplete
+    // would re-insert it.
+    body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine();
+    body << PrintBlockBody(elided_root_block_body.value().get());
   } else {
+    // If this is a non-root block, or is an unskippable root block,
+    // just print it without skipping.
     body << PrintBody(op->body);
   }
+
   // print func attrs
   Doc header_attr;
   if (primFunc->attrs.defined()) {
diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc
index b11ca6650a..c44083108d 100644
--- a/src/tir/ir/script/script_complete.cc
+++ b/src/tir/ir/script/script_complete.cc
@@ -105,16 +105,35 @@ PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates) {
   for (const auto& alloc : root_allocates) {
     buffer_var_map.Set(alloc->data, alloc);
   }
-  bool contain_root = root_allocates.empty() && func->body->IsInstance<BlockRealizeNode>() &&
-                      Downcast<BlockRealize>(func->body)->block->iter_vars.empty();
-  ScriptCompleter script_completer(&buffer_var_map);
-  // generate surrounding loops automatically
-  Stmt res = script_completer(func->body);
-  // generate root block automatically
-  if ((script_completer.contains_block || root_allocates.size()) && !contain_root) {
-    res = Block({}, {}, {}, "root", res, NullOpt, root_allocates);
-    res = BlockRealize({}, Bool(true), Downcast<Block>(res));
+
+  Stmt res = func->body;
+
+  // Generate root block automatically.  This is done before
+  // ScriptCompleter, in order to fill the root block's T.reads() and
+  // T.writes() annotations, as if it had been explicitly written.
+  bool should_insert_root = [&]() -> bool {
+    if (root_allocates.size()) {
+      return true;
+    }
+    auto* block_realize = func->body.as<BlockRealizeNode>();
+    if (block_realize && block_realize->block->iter_vars.size()) {
+      return true;
+    }
+    if (!block_realize && ContainsNode<BlockRealizeNode>(func->body)) {
+      return true;
+    }
+    return false;
+  }();
+
+  if (should_insert_root) {
+    Block root_block({}, {}, {}, "root", std::move(res), NullOpt, root_allocates);
+    res = BlockRealize({}, Bool(true), std::move(root_block));
   }
+
+  // generate surrounding loops automatically
+  ScriptCompleter script_completer(&buffer_var_map);
+  res = script_completer(std::move(res));
+
   if (func->body.same_as(res)) {
     return func;
   } else {
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 1f5871b488..e139d2111b 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3142,6 +3142,25 @@ def func_root_attr():
     return func_root_attr
 
 
+def func_trivial_root_block():
+    @T.prim_func
+    def func(A: T.Buffer[1, "int32"]):
+        with T.block("root"):
+            A[0] = 0
+
+    return func
+
+
+def func_nested_root_block():
+    @T.prim_func
+    def func(A: T.Buffer[1, "int32"]):
+        with T.block("root"):
+            with T.block("block"):
+                A[0] = 0
+
+    return func
+
+
 def func_T_ptr_let_statement():
     @T.prim_func
     def func_T_ptr_let_statement(
@@ -3418,6 +3437,8 @@ ir_generator = tvm.testing.parameter(
     func_with_target_spec_by_config,
     func_with_target_spec_by_str,
     func_root_attr,
+    func_trivial_root_block,
+    func_nested_root_block,
     func_T_ptr_let_statement,
     func_T_ptr_allocate,
     llvm_intrin_call,