You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2022/09/21 13:05:01 UTC
[tvm] branch main updated: [TVMScript][Fix] Correct round-trip of explicit root block (#12673)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fdc6894b7d [TVMScript][Fix] Correct round-trip of explicit root block (#12673)
fdc6894b7d is described below
commit fdc6894b7dae096d0ec983292aa0a2a475843f56
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Wed Sep 21 08:04:53 2022 -0500
[TVMScript][Fix] Correct round-trip of explicit root block (#12673)
* [TVMScript][Fix] Correct round-trip of explicit root block
Prior to this commit, when converting TIR to TVMScript, the root
`tir::Block` is typically hidden. When parsing, however,
`tvm::tir::ScriptComplete` will wrap the function body in a root block
if the primfunc contains at least one block and does not
already have a root block. As a result, if the root block is the only
block present, it would be stripped by a round-trip.
This commit tightens the condition for hiding the root `tir::Block`
when converting to TVMScript, so that it is printed in cases where
the autocompleter would reinsert it when parsing.
---
include/tvm/tir/stmt_functor.h | 32 +++++++++++++++
src/printer/tvmscript_printer.cc | 50 +++++++++++++++++++----
src/tir/ir/script/script_complete.cc | 37 +++++++++++++----
tests/python/unittest/test_tvmscript_roundtrip.py | 21 ++++++++++
4 files changed, 123 insertions(+), 17 deletions(-)
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index 49b1f28e5d..2fc3b9678b 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -427,6 +427,38 @@ TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr,
* \return The renewed func.
*/
TVM_DLL PrimFunc RenewDefs(const PrimFunc& func);
+
+/*!
+ * \brief Check if the statement contains the specified node type.
+ *
+ * This utility potentially walks the entire statement, and should
+ * therefore not be used if it could otherwise be merged with another
+ * pass.
+ *
+ * \param stmt The statement to be searched
+ * \return Whether stmt contains Node
+ */
+template <typename Node, typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>>
+bool ContainsNode(const Stmt& stmt) {
+ struct Visitor : StmtVisitor {
+ // Early bail-out, if we already found the node.
+ void VisitStmt(const Stmt& stmt) {
+ if (contains_node) {
+ return;
+ }
+ StmtVisitor::VisitStmt(stmt);
+ }
+
+ void VisitStmt_(const Node* block) override { contains_node = true; }
+
+ bool contains_node{false};
+ };
+
+ Visitor visitor;
+ visitor(stmt);
+ return visitor.contains_node;
+}
+
} // namespace tir
} // namespace tvm
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 2072037358..936ac7580f 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -1664,19 +1664,53 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
}
// print body
body << "# body" << Doc::NewLine();
- if (op->body->IsInstance<BlockRealizeNode>() &&
- op->body.as<BlockRealizeNode>()->iter_values.empty()) {
- const BlockNode* block = op->body.as<BlockRealizeNode>()->block.get();
- if (block->annotations.empty() && !ContainsOptionalInfo(GetRef<Stmt>(block))) {
- // Skip print root block
- body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine();
- body << PrintBlockBody(block);
+
+ Optional<Block> elided_root_block_body = [&]() -> Optional<Block> {
+ auto block_realize = op->body.as<BlockRealizeNode>();
+ if (!block_realize || block_realize->iter_values.size()) {
+ return NullOpt;
+ }
+
+ const auto& block = block_realize->block;
+ if (block->annotations.size() || ContainsOptionalInfo(block)) {
+ return NullOpt;
+ }
+
+ // The autocomplete might recognize the body itself as being a
+ // root block, and fail to insert it.
+ bool autocomplete_would_insert_root_block = [&]() -> bool {
+ if (block->alloc_buffers.size()) {
+ return true;
+ }
+
+ auto* block_realize = block->body.as<BlockRealizeNode>();
+ if (block_realize && block_realize->block->iter_vars.size()) {
+ return true;
+ }
+ if (!block_realize && ContainsNode<BlockRealizeNode>(block->body)) {
+ return true;
+ }
+ return false;
+ }();
+
+ if (autocomplete_would_insert_root_block) {
+ return block;
} else {
- body << PrintBody(op->body);
+ return NullOpt;
}
+ }();
+
+ if (elided_root_block_body) {
+ // Skip printing of root block in cases where tvm::tir::ScriptComplete
+ // would re-insert it.
+ body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine();
+ body << PrintBlockBody(elided_root_block_body.value().get());
} else {
+ // If this is a non-root block, or is an unskippable root block,
+ // just print it without skipping.
body << PrintBody(op->body);
}
+
// print func attrs
Doc header_attr;
if (primFunc->attrs.defined()) {
diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc
index b11ca6650a..c44083108d 100644
--- a/src/tir/ir/script/script_complete.cc
+++ b/src/tir/ir/script/script_complete.cc
@@ -105,16 +105,35 @@ PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates) {
for (const auto& alloc : root_allocates) {
buffer_var_map.Set(alloc->data, alloc);
}
- bool contain_root = root_allocates.empty() && func->body->IsInstance<BlockRealizeNode>() &&
- Downcast<BlockRealize>(func->body)->block->iter_vars.empty();
- ScriptCompleter script_completer(&buffer_var_map);
- // generate surrounding loops automatically
- Stmt res = script_completer(func->body);
- // generate root block automatically
- if ((script_completer.contains_block || root_allocates.size()) && !contain_root) {
- res = Block({}, {}, {}, "root", res, NullOpt, root_allocates);
- res = BlockRealize({}, Bool(true), Downcast<Block>(res));
+
+ Stmt res = func->body;
+
+ // Generate root block automatically. This is done before
+ // ScriptCompleter, in order to fill the root block's T.reads() and
+ // T.writes() annotations, as if it had been explicitly written.
+ bool should_insert_root = [&]() -> bool {
+ if (root_allocates.size()) {
+ return true;
+ }
+ auto* block_realize = func->body.as<BlockRealizeNode>();
+ if (block_realize && block_realize->block->iter_vars.size()) {
+ return true;
+ }
+ if (!block_realize && ContainsNode<BlockRealizeNode>(func->body)) {
+ return true;
+ }
+ return false;
+ }();
+
+ if (should_insert_root) {
+ Block root_block({}, {}, {}, "root", std::move(res), NullOpt, root_allocates);
+ res = BlockRealize({}, Bool(true), std::move(root_block));
}
+
+ // generate surrounding loops automatically
+ ScriptCompleter script_completer(&buffer_var_map);
+ res = script_completer(std::move(res));
+
if (func->body.same_as(res)) {
return func;
} else {
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 1f5871b488..e139d2111b 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3142,6 +3142,25 @@ def func_root_attr():
return func_root_attr
+def func_trivial_root_block():
+ @T.prim_func
+ def func(A: T.Buffer[1, "int32"]):
+ with T.block("root"):
+ A[0] = 0
+
+ return func
+
+
+def func_nested_root_block():
+ @T.prim_func
+ def func(A: T.Buffer[1, "int32"]):
+ with T.block("root"):
+ with T.block("block"):
+ A[0] = 0
+
+ return func
+
+
def func_T_ptr_let_statement():
@T.prim_func
def func_T_ptr_let_statement(
@@ -3418,6 +3437,8 @@ ir_generator = tvm.testing.parameter(
func_with_target_spec_by_config,
func_with_target_spec_by_str,
func_root_attr,
+ func_trivial_root_block,
+ func_nested_root_block,
func_T_ptr_let_statement,
func_T_ptr_allocate,
llvm_intrin_call,