You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/03/01 20:59:51 UTC
[tvm] branch main updated: [TensorIR] introduce Block and
BlockRealize (#312) (#7553)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 057a673 [TensorIR] introduce Block and BlockRealize (#312) (#7553)
057a673 is described below
commit 057a673986f0ab50c6be8335339b2beb01e3a1f4
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Tue Mar 2 04:59:32 2021 +0800
[TensorIR] introduce Block and BlockRealize (#312) (#7553)
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Tianqi Chen <tq...@users.noreply.github.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Tianqi Chen <tq...@users.noreply.github.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
---
include/tvm/tir/stmt.h | 248 +++++++++++++++++++++++++++++++-
include/tvm/tir/stmt_functor.h | 8 ++
python/tvm/tir/__init__.py | 1 +
python/tvm/tir/stmt.py | 162 +++++++++++++++++++++
src/tir/ir/stmt.cc | 219 ++++++++++++++++++++++++++++
src/tir/ir/stmt_functor.cc | 109 ++++++++++++++
tests/cpp/ir_functor_test.cc | 41 ++++++
tests/python/unittest/test_tir_nodes.py | 82 +++++++++++
8 files changed, 869 insertions(+), 1 deletion(-)
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 093d49c..074bcdd 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -862,7 +862,7 @@ class For : public Stmt {
};
/*!
- * \brief A prefetch hint for abuffer
+ * \brief A prefetch hint for a buffer
*/
class PrefetchNode : public StmtNode {
public:
@@ -905,6 +905,252 @@ class Prefetch : public Stmt {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
};
+/*!
+ * \brief Representing the region of multi-dimensional buffer access.
+ */
+class BufferRegionNode : public Object {
+ public:
+ /*! \brief The buffer of the buffer region. */
+ Buffer buffer;
+ /*! \brief The region array of the buffer region. */
+ Array<Range> region;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("buffer", &buffer);
+ v->Visit("region", &region);
+ }
+
+ bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const {
+ return equal(buffer, other->buffer) && equal(region, other->region);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(buffer);
+ hash_reduce(region);
+ }
+
+ static constexpr const char* _type_key = "tir.BufferRegion";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
+ static constexpr const bool _type_has_method_shash_reduce = true;
+ TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, Object);
+};
+
+/*!
+ * \brief Managed reference to BufferRegionNode.
+ * \sa BufferRegionNode
+ */
+class BufferRegion : public ObjectRef {
+ public:
+ TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);
+
+ /*!
+ * \brief Create a BufferRegion which is the full region of the given buffer.
+ * \param buffer The buffer to generate full BufferRegion.
+ * \return The BufferRegion which covers all region of the given buffer
+ */
+ TVM_DLL static BufferRegion FullRegion(Buffer buffer);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode);
+};
+
+/*!
+ * \brief Match introduces a constraint that the source buffer region can be remapped to the data
+ * layout specified by the buffer field. The constraint can be checked in later part of lowering (or
+ * optionally during runtime).
+ *
+ * MatchBufferRegion provides a mechanism to represent data layout and compactness constraints in
+ * low-level hardware primitives in the IR and defer the check after the sequence of
+ * transformations.
+ */
+class MatchBufferRegionNode : public Object {
+ public:
+ /*! \brief The target buffer. */
+ Buffer buffer;
+ /*! \brief The source buffer region. */
+ BufferRegion source;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("buffer", &buffer);
+ v->Visit("source", &source);
+ }
+
+ bool SEqualReduce(const MatchBufferRegionNode* other, SEqualReducer equal) const {
+ return equal(buffer, other->buffer) && equal(source, other->source);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(buffer);
+ hash_reduce(source);
+ }
+
+ static constexpr const char* _type_key = "tir.MatchBufferRegion";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
+ static constexpr const bool _type_has_method_shash_reduce = true;
+ TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object);
+};
+
+/*!
+ * \brief Managed reference to MatchBufferRegionNode.
+ * \sa MatchBufferRegionNode
+ */
+class MatchBufferRegion : public ObjectRef {
+ public:
+ TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode);
+};
+
+/*!
+ * \brief A block is a basic schedule unit in TIR.
+ * \note Block's body is parameterized by iter vars.
+ * \code
+ *
+ * with tir.block([extent0, extent1, ...], name) as [v0, v1, ...]:
+ * tir.bind(v0, value0)
+ * tir.bind(v1, value1)
+ * ...
+ * tir.reads([buffer0[start:end, ...], ...])
+ * tir.writes([buffer1[start:end, ...], ...])
+ * tir.where(predicate)
+ * buffer2 = tir.alloc_buffer(shape, dtype)
+ * buffer3 = tir.match_buffer(source_buffer[start:end, ...])
+ * tir.attr({attr_key: attr_value, ...})
+ * with tir.init():
+ * // init body
+ * // body
+ *
+ * \endcode
+ */
+class BlockNode : public StmtNode {
+ public:
+ /*! \brief The variables of the block. */
+ Array<IterVar> iter_vars;
+ /*! \brief The read buffer regions of the block. */
+ Array<BufferRegion> reads;
+ /*! \brief The write buffer regions of the block. */
+ Array<BufferRegion> writes;
+ /*! \brief The name_hint of the block. */
+ String name_hint;
+ /*! \brief The body of the block. */
+ Stmt body;
+ /*!
+ * \brief The init statement is executed during the first iteration of reduction loops in a
+ * reduction block. The optional init field allows us to represent initialization and
+ * reduction update in a single block and transform them collectively.
+ * We also provide primitives to decompose the init into a separate block during scheduling.
+ * Init field is `NullOpt` if there is no reduction iter_vars
+ */
+ Optional<Stmt> init;
+ /*! \brief The buffer allocated in the block. */
+ Array<Buffer> alloc_buffers;
+ /*! \brief The match buffer regions. */
+ Array<MatchBufferRegion> match_buffers;
+ /*! \brief The annotation of the block. */
+ Map<String, ObjectRef> annotations;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("iter_vars", &iter_vars);
+ v->Visit("reads", &reads);
+ v->Visit("writes", &writes);
+ v->Visit("name_hint", &name_hint);
+ v->Visit("body", &body);
+ v->Visit("init", &init);
+ v->Visit("alloc_buffers", &alloc_buffers);
+ v->Visit("match_buffers", &match_buffers);
+ v->Visit("annotations", &annotations);
+ }
+
+ bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const {
+ // Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars
+ return equal.DefEqual(iter_vars, other->iter_vars) &&
+ equal(alloc_buffers, other->alloc_buffers) &&
+ equal(match_buffers, other->match_buffers) && equal(reads, other->reads) &&
+ equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) &&
+ equal(annotations, other->annotations);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce.DefHash(iter_vars);
+ hash_reduce(alloc_buffers);
+ hash_reduce(match_buffers);
+ hash_reduce(reads);
+ hash_reduce(writes);
+ hash_reduce(body);
+ hash_reduce(init);
+ hash_reduce(annotations);
+ }
+
+ static constexpr const char* _type_key = "tir.Block";
+ TVM_DECLARE_FINAL_OBJECT_INFO(BlockNode, StmtNode);
+};
+
+/*!
+ * \brief Managed reference to BlockNode.
+ * \sa BlockNode
+ */
+class Block : public Stmt {
+ public:
+ TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
+ Array<BufferRegion> writes, String name_hint, Stmt body,
+ Optional<Stmt> init = NullOpt,
+ Array<Buffer> alloc_buffers = Array<Buffer>(),
+ Array<MatchBufferRegion> match_buffers = Array<MatchBufferRegion>(),
+ Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
+ Span span = Span());
+
+ TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode);
+};
+
+/*!
+ * \brief A block realization node represents execution of the block at the binding values.
+ */
+class BlockRealizeNode : public StmtNode {
+ public:
+ /*! \brief The corresponding values of the iter vars. */
+ Array<PrimExpr> iter_values;
+ /*!
+ * \brief The predicate of the block realization, the block will only be executed when the
+ * predicate is true.
+ */
+ PrimExpr predicate;
+ /*! \brief The block to be realized. */
+ Block block;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("iter_values", &iter_values);
+ v->Visit("predicate", &predicate);
+ v->Visit("block", &block);
+ }
+
+ bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const {
+ return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) &&
+ equal(block, other->block);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(iter_values);
+ hash_reduce(predicate);
+ hash_reduce(block);
+ }
+
+ static constexpr const char* _type_key = "tir.BlockRealize";
+ TVM_DECLARE_FINAL_OBJECT_INFO(BlockRealizeNode, StmtNode);
+};
+
+/*!
+ * \brief Managed reference to BlockRealizeNode
+ * \sa BlockRealizeNode
+ */
+class BlockRealize : public Stmt {
+ public:
+ TVM_DLL explicit BlockRealize(Array<PrimExpr> iter_values, PrimExpr predicate, Block block,
+ Span span = Span());
+
+ TVM_DEFINE_OBJECT_REF_METHODS(BlockRealize, Stmt, BlockRealizeNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode);
+};
+
/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
namespace attr {
// The above attr does not pass to ir stage.
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index 0f4238d..e53b02d 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -96,6 +96,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
return R();
@@ -119,6 +121,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
+ IR_STMT_FUNCTOR_DISPATCH(BlockNode);
+ IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode);
return vtable;
}
};
@@ -158,6 +162,8 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const PrefetchNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
+ void VisitStmt_(const BlockNode* op) override;
+ void VisitStmt_(const BlockRealizeNode* op) override;
};
/*!
@@ -249,6 +255,8 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const PrefetchNode* op) override;
Stmt VisitStmt_(const SeqStmtNode* op) override;
Stmt VisitStmt_(const EvaluateNode* op) override;
+ Stmt VisitStmt_(const BlockNode* op) override;
+ Stmt VisitStmt_(const BlockRealizeNode* op) override;
/*!
* \brief Alternative advance method for SeqStmtNode.
*
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 324c4da..ad91eab 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -31,6 +31,7 @@ from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
from .stmt import ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
+from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
from .function import PrimFunc
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 5882dca..e4f1ac9 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -26,11 +26,15 @@ Each statement node have subfields that can be visited from python side.
assert isinstance(st, tvm.tir.stmt.Store)
assert(st.buffer_var == a)
"""
+from typing import List, Optional, Mapping
from enum import IntEnum
import tvm._ffi
from tvm.runtime import Object
+from tvm.ir import Span, PrimExpr, Range
from . import _ffi_api
+from .buffer import Buffer
+from .expr import IterVar
class Stmt(Object):
@@ -429,6 +433,164 @@ class Prefetch(Stmt):
self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds, span)
+@tvm._ffi.register_object("tir.BufferRegion")
+class BufferRegion(Object):
+ """BufferRegion node.
+
+ Parameters
+ ----------
+ buffer : Buffer
+ The buffer of the buffer region
+
+ region : List[Range]
+ The region array of the buffer region
+ """
+
+ buffer: Buffer
+ region: List[Range]
+
+ def __init__(self, buffer: Buffer, region: List[Range]):
+ self.__init_handle_by_constructor__(_ffi_api.BufferRegion, buffer, region)
+
+
+@tvm._ffi.register_object("tir.MatchBufferRegion")
+class MatchBufferRegion(Object):
+ """MatchBufferRegion node.
+
+ Parameters
+ ----------
+ buffer : Buffer
+ The target buffer
+
+ source : BufferRegion
+ The region of source buffer
+ """
+
+ buffer: Buffer
+ source: BufferRegion
+
+ def __init__(self, buffer: Buffer, source: BufferRegion):
+ self.__init_handle_by_constructor__(_ffi_api.MatchBufferRegion, buffer, source)
+
+
+@tvm._ffi.register_object("tir.Block")
+class Block(Stmt):
+ """Block node.
+
+ Parameters
+ ----------
+ iter_vars : List[IterVar]
+ The block Variable.
+
+ reads : List[BufferRegion]
+ The read buffer regions of the block.
+
+ writes: List[BufferRegion]
+ The write buffer regions of the block.
+
+ name_hint: str
+ the name_hint of the block.
+
+ body: Stmt
+ The body of the block.
+
+ init: Optional[Stmt]
+ The init block of the reduction block
+
+ alloc_buffers: Optional[List[Buffer]]
+ The buffer allocations
+
+ match_buffers: Optional[List[MatchBufferRegion]]
+ The subregion buffer match
+
+ annotations: Optional[Mapping[str, Object]]
+ Additional annotation hints.
+
+ span : Optional[Span]
+ The location of this block in the source code.
+ """
+
+ iter_vars: List[IterVar]
+ reads: List[BufferRegion]
+ writes: List[BufferRegion]
+ name_hint: str
+ body: Stmt
+ init: Optional[Stmt]
+ alloc_buffers: Optional[List[Buffer]]
+ match_buffers: Optional[List[MatchBufferRegion]]
+ annotations: Optional[Mapping[str, Object]]
+ span: Optional[Span]
+
+ def __init__(
+ self,
+ iter_vars: List[IterVar],
+ reads: List[BufferRegion],
+ writes: List[BufferRegion],
+ name_hint: str,
+ body: Stmt,
+ init: Optional[Stmt] = None,
+ alloc_buffers: Optional[List[Buffer]] = None,
+ match_buffers: Optional[List[MatchBufferRegion]] = None,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ if alloc_buffers is None:
+ alloc_buffers = []
+ if match_buffers is None:
+ match_buffers = []
+ if annotations is None:
+ annotations = {}
+ self.__init_handle_by_constructor__(
+ _ffi_api.Block,
+ iter_vars,
+ reads,
+ writes,
+ name_hint,
+ body,
+ init,
+ alloc_buffers,
+ match_buffers,
+ annotations,
+ span,
+ )
+
+
+@tvm._ffi.register_object("tir.BlockRealize")
+class BlockRealize(Stmt):
+ """BlockRealize node.
+
+ Parameters
+ ----------
+ iter_values : List[PrimExpr]
+ The binding values of the block var.
+
+ predicate : PrimExpr
+ The predicate of the block.
+
+ block : Block
+ The block to realize
+
+ span : Optional[Span]
+ The location of this block_realize in the source code.
+ """
+
+ iter_values: List[PrimExpr]
+ predicate: PrimExpr
+ block: Block
+ span: Optional[Span]
+
+ def __init__(
+ self,
+ iter_values: List[PrimExpr],
+ predicate: PrimExpr,
+ block: Block,
+ span: Optional[Span] = None,
+ ):
+ self.__init_handle_by_constructor__(
+ _ffi_api.BlockRealize, iter_values, predicate, block, span
+ )
+
+
def stmt_seq(*args):
"""Make sequence of statements
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 92dc387..e54be43 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -598,6 +598,225 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "}\n";
});
+// BufferRegion
+BufferRegion::BufferRegion(Buffer buffer, Array<Range> region) {
+ ObjectPtr<BufferRegionNode> node = make_object<BufferRegionNode>();
+ node->buffer = std::move(buffer);
+ node->region = std::move(region);
+ data_ = std::move(node);
+}
+
+BufferRegion BufferRegion::FullRegion(Buffer buffer) {
+ Array<Range> region;
+ for (PrimExpr extent : buffer->shape) {
+ region.push_back(Range::FromMinExtent(0, extent));
+ }
+ return BufferRegion(buffer, region);
+}
+
+TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array<Range> region) {
+ return BufferRegion(buffer, region);
+});
+
+TVM_REGISTER_NODE_TYPE(BufferRegionNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<BufferRegionNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BufferRegionNode*>(node.get());
+ p->stream << op->buffer->name;
+ p->stream << "[";
+ for (size_t i = 0; i < op->region.size(); ++i) {
+ const auto& range = op->region[i];
+ p->Print(range->min);
+ if (!is_one(range->extent)) {
+ p->stream << ":";
+ p->Print(range->min + range->extent);
+ }
+ if (i != op->region.size() - 1) p->stream << ", ";
+ }
+ p->stream << "]";
+ });
+
+// MatchBufferRegion
+MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) {
+ ObjectPtr<MatchBufferRegionNode> node = make_object<MatchBufferRegionNode>();
+ node->buffer = std::move(buffer);
+ node->source = std::move(source);
+ data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.MatchBufferRegion").set_body_typed([](Buffer buffer, BufferRegion source) {
+ return MatchBufferRegion(buffer, source);
+});
+
+TVM_REGISTER_NODE_TYPE(MatchBufferRegionNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<MatchBufferRegionNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const MatchBufferRegionNode*>(node.get());
+ p->PrintIndent();
+ p->stream << op->buffer->name << " = match_buffer_region(";
+ p->Print(op->source);
+ p->stream << ")\n";
+ });
+
+// Block
+Block::Block(Array<IterVar> iter_vars, Array<BufferRegion> reads, Array<BufferRegion> writes,
+ String name_hint, Stmt body, Optional<Stmt> init, Array<Buffer> alloc_buffers,
+ Array<MatchBufferRegion> match_buffers, Map<String, ObjectRef> annotations,
+ Span span) {
+ ObjectPtr<BlockNode> node = make_object<BlockNode>();
+ node->iter_vars = std::move(iter_vars);
+ node->reads = std::move(reads);
+ node->writes = std::move(writes);
+ node->name_hint = std::move(name_hint);
+ node->body = std::move(body);
+ node->init = std::move(init);
+ node->alloc_buffers = std::move(alloc_buffers);
+ node->match_buffers = std::move(match_buffers);
+ node->annotations = std::move(annotations);
+ node->span = std::move(span);
+ data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.Block")
+ .set_body_typed([](Array<IterVar> iter_vars, Array<BufferRegion> reads,
+ Array<BufferRegion> writes, String name_hint, Stmt body, Optional<Stmt> init,
+ Array<Buffer> alloc_buffers, Array<MatchBufferRegion> match_buffers,
+ Map<String, ObjectRef> annotations, Span span) {
+ return Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers, match_buffers,
+ annotations, span);
+ });
+
+TVM_REGISTER_NODE_TYPE(BlockNode);
+
+void PrintBlockTitle(const BlockNode* op, ReprPrinter* p) {
+ p->stream << "block " << op->name_hint << "(";
+ for (size_t i = 0; i < op->iter_vars.size(); i++) {
+ p->Print(op->iter_vars[i]);
+ if (i < op->iter_vars.size() - 1) p->stream << ", ";
+ }
+ p->stream << ")";
+}
+
+void PrintBlockSignature(const BlockNode* op, ReprPrinter* p) {
+ // print read/write regions
+ p->PrintIndent();
+ p->stream << "reads(";
+ p->Print(op->reads);
+ p->stream << ")\n";
+ p->PrintIndent();
+ p->stream << "writes(";
+ p->Print(op->writes);
+ p->stream << ")\n";
+ // Print alloc_buffers
+ for (const auto& alloc_buf : op->alloc_buffers) {
+ p->PrintIndent();
+ p->stream << alloc_buf->name << " = alloc_buffer(" << alloc_buf->dtype << "[";
+ for (size_t i = 0; i < alloc_buf->shape.size(); ++i) {
+ if (i > 0) p->stream << ", ";
+ p->Print(alloc_buf->shape[i]);
+ }
+ p->stream << "])\n";
+ }
+ // Print match_buffer_regions
+ for (const auto& match_buf : op->match_buffers) {
+ p->Print(match_buf);
+ }
+ if (!op->annotations.empty()) {
+ p->PrintIndent();
+ p->stream << "annotations(" << op->annotations << ")\n";
+ }
+}
+
+void PrintBlockBody(const BlockNode* op, ReprPrinter* p) {
+ // Print init
+ if (op->init.defined()) {
+ p->PrintIndent();
+ p->stream << "with init() {\n";
+ p->indent += 2;
+ p->Print(op->init.value());
+ p->indent -= 2;
+ p->PrintIndent();
+ p->stream << "}\n";
+ }
+ // Print body
+ p->Print(op->body);
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<BlockNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BlockNode*>(node.get());
+ p->PrintIndent();
+ PrintBlockTitle(op, p);
+ p->stream << "{\n";
+ p->indent += 2;
+
+ // Print block elements (e.g. reads/writes, etc)
+ PrintBlockSignature(op, p);
+ // Print block init and body
+ PrintBlockBody(op, p);
+
+ p->indent -= 2;
+ p->PrintIndent();
+ p->stream << "}\n";
+ });
+
+// BlockRealize
+BlockRealize::BlockRealize(Array<PrimExpr> values, PrimExpr predicate, Block block, Span span) {
+ CHECK_EQ(block->iter_vars.size(), values.size())
+ << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values";
+ CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression";
+ ObjectPtr<BlockRealizeNode> node = make_object<BlockRealizeNode>();
+ node->iter_values = std::move(values);
+ node->predicate = std::move(predicate);
+ node->block = std::move(block);
+ node->span = std::move(span);
+ data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.BlockRealize")
+ .set_body_typed([](Array<PrimExpr> iter_values, PrimExpr predicate, Block block, Span span) {
+ return BlockRealize(iter_values, predicate, block, span);
+ });
+
+TVM_REGISTER_NODE_TYPE(BlockRealizeNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<BlockRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BlockRealizeNode*>(node.get());
+ auto* block_op = op->block.get();
+ p->PrintIndent();
+ PrintBlockTitle(block_op, p);
+ p->stream << "{\n";
+ p->indent += 2;
+
+ // Print binding iter_values
+ for (size_t i = 0; i < block_op->iter_vars.size(); ++i) {
+ p->PrintIndent();
+ p->stream << "bind(";
+ p->Print(block_op->iter_vars[i]->var);
+ p->stream << ", ";
+ p->Print(op->iter_values[i]);
+ p->stream << ")\n";
+ }
+ // Print predicate
+ if (!is_one(op->predicate)) {
+ p->PrintIndent();
+ p->stream << "where(";
+ p->Print(op->predicate);
+ p->stream << ")\n";
+ }
+ // Print block elements (e.g. reads/writes, etc)
+ PrintBlockSignature(block_op, p);
+ // Print block init and body
+ PrintBlockBody(block_op, p);
+
+ p->indent -= 2;
+ p->PrintIndent();
+ p->stream << "}\n";
+ });
+
PrimExpr TypeAnnotation(DataType dtype, Span span) {
static auto op = Op::Get("tir.type_annotation");
return tir::Call(dtype, op, {}, span);
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index e4cc1b7..f05dc71 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -112,6 +112,35 @@ void StmtVisitor::VisitStmt_(const SeqStmtNode* op) {
void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); }
+void StmtVisitor::VisitStmt_(const BlockNode* op) {
+ auto fvisit_buffer_region = [this](const BufferRegion& s) {
+ for (const auto& range : s->region) {
+ this->VisitExpr(range->min);
+ this->VisitExpr(range->extent);
+ }
+ };
+ VisitArray(op->iter_vars, [this](const IterVar& iter_var) {
+ this->VisitExpr(iter_var->dom->min);
+ this->VisitExpr(iter_var->dom->extent);
+ });
+ VisitArray(op->reads, fvisit_buffer_region);
+ VisitArray(op->writes, fvisit_buffer_region);
+ VisitArray(op->match_buffers,
+ [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) {
+ fvisit_buffer_region(match_buffer_region->source);
+ });
+ if (op->init.defined()) {
+ this->VisitStmt(op->init.value());
+ }
+ this->VisitStmt(op->body);
+}
+
+void StmtVisitor::VisitStmt_(const BlockRealizeNode* op) {
+ VisitArray(op->iter_values, [this](const PrimExpr& e) { this->VisitExpr(e); });
+ this->VisitExpr(op->predicate);
+ this->VisitStmt(op->block);
+}
+
class StmtMutator::Internal {
public:
/*!
@@ -150,6 +179,20 @@ class StmtMutator::Internal {
}
}
+ static Array<IterVar> Mutate(StmtMutator* self, const Array<IterVar>& arr) {
+ auto fmutate = [self](const IterVar& iter_var) {
+ PrimExpr min = self->VisitExpr(iter_var->dom->min);
+ PrimExpr extent = self->VisitExpr(iter_var->dom->extent);
+ if (min.same_as(iter_var->dom->min) && extent.same_as(iter_var->dom->extent)) {
+ return iter_var;
+ } else {
+ return IterVar(Range(min, extent), iter_var->var, iter_var->iter_type,
+ iter_var->thread_tag);
+ }
+ };
+ return MutateArray(self, arr, fmutate);
+ }
+
static Array<PrimExpr> Mutate(StmtMutator* self, const Array<PrimExpr>& arr) {
auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); };
return MutateArray(self, arr, fmutate);
@@ -172,6 +215,31 @@ class StmtMutator::Internal {
};
return MutateArray(self, arr, fmutate);
}
+
+ static Array<BufferRegion> Mutate(StmtMutator* self, const Array<BufferRegion>& arr) {
+ auto fmutate = [self](const BufferRegion& buffer_region) {
+ Array<Range> region = Mutate(self, buffer_region->region);
+ if (region.same_as(buffer_region->region)) {
+ return buffer_region;
+ } else {
+ return BufferRegion(buffer_region->buffer, region);
+ }
+ };
+ return MutateArray(self, arr, fmutate);
+ }
+
+ static Array<MatchBufferRegion> Mutate(StmtMutator* self, const Array<MatchBufferRegion>& arr) {
+ auto fmutate = [self](const MatchBufferRegion& match_buffer_region) {
+ Array<Range> region = Mutate(self, match_buffer_region->source->region);
+ if (region.same_as(match_buffer_region->source->region)) {
+ return match_buffer_region;
+ } else {
+ return MatchBufferRegion(match_buffer_region->buffer,
+ BufferRegion(match_buffer_region->source->buffer, region));
+ }
+ };
+ return MutateArray(self, arr, fmutate);
+ }
};
Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) {
@@ -415,6 +483,47 @@ Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
}
}
+Stmt StmtMutator::VisitStmt_(const BlockNode* op) {
+ Array<IterVar> iter_vars = Internal::Mutate(this, op->iter_vars);
+ Array<BufferRegion> reads = Internal::Mutate(this, op->reads);
+ Array<BufferRegion> writes = Internal::Mutate(this, op->writes);
+ Array<MatchBufferRegion> match_buffers = Internal::Mutate(this, op->match_buffers);
+ Optional<Stmt> init = NullOpt;
+ if (op->init.defined()) {
+ init = VisitStmt(op->init.value());
+ }
+ Stmt body = VisitStmt(op->body);
+ if (iter_vars.same_as(op->iter_vars) && reads.same_as(op->reads) && writes.same_as(op->writes) &&
+ body.same_as(op->body) && init.same_as(op->init) &&
+ match_buffers.same_as(op->match_buffers)) {
+ return GetRef<Block>(op);
+ } else {
+ auto n = CopyOnWrite(op);
+ n->iter_vars = std::move(iter_vars);
+ n->reads = std::move(reads);
+ n->writes = std::move(writes);
+ n->body = std::move(body);
+ n->init = std::move(init);
+ n->match_buffers = std::move(match_buffers);
+ return Stmt(n);
+ }
+}
+
+Stmt StmtMutator::VisitStmt_(const BlockRealizeNode* op) {
+ Array<PrimExpr> v = Internal::Mutate(this, op->iter_values);
+ PrimExpr pred = this->VisitExpr(op->predicate);
+ Stmt block = this->VisitStmt(op->block);
+ if (v.same_as(op->iter_values) && pred.same_as(op->predicate) && block.same_as(op->block)) {
+ return GetRef<Stmt>(op);
+ } else {
+ auto n = CopyOnWrite(op);
+ n->iter_values = std::move(v);
+ n->predicate = std::move(pred);
+ n->block = Downcast<Block>(block);
+ return Stmt(n);
+ }
+}
+
// Implementations of IRTransform, PostOrderVisit and Substitute
class IRApplyVisit : public StmtExprVisitor {
public:
diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc
index d242b20..237dc46 100644
--- a/tests/cpp/ir_functor_test.cc
+++ b/tests/cpp/ir_functor_test.cc
@@ -120,6 +120,25 @@ TEST(IRF, StmtVisitor) {
};
v(fmaketest());
ICHECK_EQ(v.count, 3);
+
+ {
+ // tests for block and block_realize
+ Stmt body = fmaketest();
+ DataType dtype = DataType::Float(32);
+ Var buf_var("b", PointerType(PrimType(dtype)));
+ Buffer buffer = decl_buffer({16});
+ BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)});
+ MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region);
+
+ // construct block and block_realize
+ Block block =
+ Block({}, {buffer_region}, {buffer_region}, "block", body, body, {}, {match_buffer_region});
+ Stmt block_realize = BlockRealize({}, const_true(), block);
+
+ v.count = 0;
+ v(block_realize);
+ ICHECK_EQ(v.count, 9);
+ }
}
TEST(IRF, StmtMutator) {
@@ -229,6 +248,28 @@ TEST(IRF, StmtMutator) {
// the seq get flattened
ICHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() != extentptr);
}
+
+ {
+ // tests for block and block_realize
+ Stmt body = fmakealloc();
+ DataType dtype = DataType::Float(32);
+ Var buf_var("b", PointerType(PrimType(dtype)));
+ Buffer buffer = decl_buffer({16});
+ BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)});
+ MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region);
+ // construct block and block_realize
+ Block block =
+ Block({}, {buffer_region}, {buffer_region}, "block", body, body, {}, {match_buffer_region});
+ Stmt block_realize = BlockRealize({}, const_true(), block);
+ body = v(std::move(block_realize));
+ // the body should be changed
+ Block new_block = body.as<BlockRealizeNode>()->block;
+ ICHECK(new_block->body.as<AllocateNode>()->extents[1].same_as(x));
+ ICHECK(new_block->init.as<AllocateNode>()->extents[1].same_as(x));
+ ICHECK(new_block->reads[0]->region[0]->min.same_as(x));
+ ICHECK(new_block->writes[0]->region[0]->min.same_as(x));
+ ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x));
+ }
}
int main(int argc, char** argv) {
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index bff60f7..6e338d6 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -364,6 +364,87 @@ def test_intimm_cond():
assert x == 1
+def test_block_blockrealize():
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
+ vx = tvm.tir.IterVar((16, 16), "vx", 0)
+ vx_var = vx.var
+ vy = tvm.tir.IterVar((16, 16), "vy", 2)
+ vy_var = vy.var
+ A = tvm.tir.decl_buffer((16), "float32")
+ B = tvm.tir.decl_buffer((16, 16), "float32")
+ alloc_buffer = tvm.tir.decl_buffer((16, 16), "float32")
+ match_buffer = tvm.tir.decl_buffer((16, 16), "float32")
+ init_body = tvm.tir.BufferStore(A, 0.0, [vx_var])
+ body = tvm.tir.BufferStore(
+ A,
+ tvm.tir.BufferLoad(A, [vx_var]) + tvm.tir.BufferLoad(B, [vx_var, vy_var]),
+ [vx_var],
+ )
+ reads = [
+ tvm.tir.BufferRegion(
+ B, [tvm.ir.Range.from_min_extent(vx_var, 1), tvm.ir.Range.from_min_extent(vy_var, 1)]
+ )
+ ]
+ writes = [tvm.tir.BufferRegion(A, [tvm.ir.Range.from_min_extent(vx_var, 1)])]
+ match_buffer_region = tvm.tir.MatchBufferRegion(
+ match_buffer, tvm.tir.BufferRegion(B, [tvm.ir.Range(0, 16), tvm.ir.Range(0, 16)])
+ )
+
+ block = tvm.tir.Block(
+ [vx, vy],
+ reads,
+ writes,
+ "block",
+ body,
+ init=init_body,
+ alloc_buffers=[alloc_buffer],
+ match_buffers=[match_buffer_region],
+ annotations={"attr_key": "attr_value"},
+ )
+
+ # Checking Block
+ assert isinstance(block, tvm.tir.Block)
+ # Checking iter_vars
+ assert block.iter_vars[0] == vx
+ assert block.iter_vars[1] == vy
+ # Checking reads/writes region
+ assert isinstance(block.reads[0], tvm.tir.BufferRegion)
+ assert block.reads[0].buffer == B
+ assert block.reads[0].region[0].min == vx_var
+ assert block.reads[0].region[1].min == vy_var
+ assert isinstance(block.writes[0], tvm.tir.BufferRegion)
+ assert block.writes[0].buffer == A
+ assert block.writes[0].region[0].min == vx_var
+ assert block.writes[0].region[0].extent == 1
+ # Checking name_hint
+ assert block.name_hint == "block"
+ # Checking body
+ assert block.body == body
+ # Checking init
+ assert block.init == init_body
+ # Checking alloc_buffers
+ assert block.alloc_buffers[0] == alloc_buffer
+ # Checking match_buffers
+ assert block.match_buffers[0].buffer == match_buffer
+ assert isinstance(block.match_buffers[0].source, tvm.tir.BufferRegion)
+ assert block.match_buffers[0].source.buffer == B
+ assert block.match_buffers[0].source.region[0].min == 0
+ assert block.match_buffers[0].source.region[0].extent == 16
+
+ # Checking BlockRealize
+ block_realize = tvm.tir.BlockRealize([x, y], tvm.tir.const(True, "bool"), block)
+ assert isinstance(block_realize, tvm.tir.BlockRealize)
+ assert block_realize.iter_values[0] == x
+ assert block_realize.iter_values[1] == y
+ assert block_realize.predicate == tvm.tir.const(True, "bool")
+ assert block_realize.block == block
+
+ # make sure we can print
+ str(block)
+ str(block_realize)
+
+
if __name__ == "__main__":
test_intimm_cond()
test_buffer_load_store()
@@ -389,3 +470,4 @@ if __name__ == "__main__":
test_isnan()
test_equality()
test_equality_string_imm()
+ test_block_blockrealize()