You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/03/01 20:59:51 UTC

[tvm] branch main updated: [TensorIR] introduce Block and BlockRealize (#312) (#7553)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 057a673  [TensorIR] introduce Block and BlockRealize (#312) (#7553)
057a673 is described below

commit 057a673986f0ab50c6be8335339b2beb01e3a1f4
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Tue Mar 2 04:59:32 2021 +0800

    [TensorIR] introduce Block and BlockRealize (#312) (#7553)
    
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Tianqi Chen <tq...@users.noreply.github.com>
    Co-authored-by: Ruihang Lai <la...@qq.com>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
    
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Tianqi Chen <tq...@users.noreply.github.com>
    Co-authored-by: Ruihang Lai <la...@qq.com>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
---
 include/tvm/tir/stmt.h                  | 248 +++++++++++++++++++++++++++++++-
 include/tvm/tir/stmt_functor.h          |   8 ++
 python/tvm/tir/__init__.py              |   1 +
 python/tvm/tir/stmt.py                  | 162 +++++++++++++++++++++
 src/tir/ir/stmt.cc                      | 219 ++++++++++++++++++++++++++++
 src/tir/ir/stmt_functor.cc              | 109 ++++++++++++++
 tests/cpp/ir_functor_test.cc            |  41 ++++++
 tests/python/unittest/test_tir_nodes.py |  82 +++++++++++
 8 files changed, 869 insertions(+), 1 deletion(-)

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 093d49c..074bcdd 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -862,7 +862,7 @@ class For : public Stmt {
 };
 
 /*!
- * \brief A prefetch hint for abuffer
+ * \brief A prefetch hint for a buffer
  */
 class PrefetchNode : public StmtNode {
  public:
@@ -905,6 +905,252 @@ class Prefetch : public Stmt {
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
 };
 
+/*!
+ * \brief Represents the region of a multi-dimensional buffer access.
+ */
+class BufferRegionNode : public Object {
+ public:
+  /*! \brief The buffer of the buffer region. */
+  Buffer buffer;
+  /*! \brief The region array of the buffer region. */
+  Array<Range> region;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("buffer", &buffer);
+    v->Visit("region", &region);
+  }
+
+  bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const {
+    return equal(buffer, other->buffer) && equal(region, other->region);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(buffer);
+    hash_reduce(region);
+  }
+
+  static constexpr const char* _type_key = "tir.BufferRegion";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const bool _type_has_method_shash_reduce = true;
+  TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, Object);
+};
+
+/*!
+ * \brief Managed reference to BufferRegionNode.
+ * \sa BufferRegionNode
+ */
+class BufferRegion : public ObjectRef {
+ public:
+  TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);
+
+  /*!
+   * \brief Create a BufferRegion which covers the full region of the given buffer.
+   * \param buffer The buffer to generate full BufferRegion.
+   * \return The BufferRegion which covers all region of the given buffer
+   */
+  TVM_DLL static BufferRegion FullRegion(Buffer buffer);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode);
+};
+
+/*!
+ * \brief Match introduces a constraint that the source buffer region can be remapped to the data
+ * layout specified by the buffer field. The constraint can be checked in a later part of lowering
+ * (or optionally during runtime).
+ *
+ * MatchBufferRegion provides a mechanism to represent data layout and compactness constraints in
+ * low-level hardware primitives in the IR and defer the check until after the sequence of
+ * transformations.
+ */
+class MatchBufferRegionNode : public Object {
+ public:
+  /*! \brief The target buffer. */
+  Buffer buffer;
+  /*! \brief The source buffer region. */
+  BufferRegion source;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("buffer", &buffer);
+    v->Visit("source", &source);
+  }
+
+  bool SEqualReduce(const MatchBufferRegionNode* other, SEqualReducer equal) const {
+    return equal(buffer, other->buffer) && equal(source, other->source);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(buffer);
+    hash_reduce(source);
+  }
+
+  static constexpr const char* _type_key = "tir.MatchBufferRegion";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const bool _type_has_method_shash_reduce = true;
+  TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object);
+};
+
+/*!
+ * \brief Managed reference to MatchBufferRegionNode.
+ * \sa MatchBufferRegionNode
+ */
+class MatchBufferRegion : public ObjectRef {
+ public:
+  TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode);
+};
+
+/*!
+ * \brief A block is a basic schedule unit in TIR.
+ * \note Block's body is parameterized by iter vars.
+ * \code
+ *
+ *  with tir.block([extent0, extent1, ...], name) as [v0, v1, ...]:
+ *      tir.bind(v0, value0)
+ *      tir.bind(v1, value1)
+ *      ...
+ *      tir.reads([buffer0[start:end, ...], ...])
+ *      tir.writes([buffer1[start:end, ...], ...])
+ *      tir.where(predicate)
+ *      buffer2 = tir.alloc_buffer(shape, dtype)
+ *      buffer3 = tir.match_buffer(source_buffer[start:end, ...])
+ *      tir.attr({attr_key: attr_value, ...})
+ *      with tir.init():
+ *          // init body
+ *      // body
+ *
+ * \endcode
+ */
+class BlockNode : public StmtNode {
+ public:
+  /*! \brief The variables of the block. */
+  Array<IterVar> iter_vars;
+  /*! \brief The read buffer regions of the block. */
+  Array<BufferRegion> reads;
+  /*! \brief The write buffer regions of the block. */
+  Array<BufferRegion> writes;
+  /*! \brief The name_hint of the block. */
+  String name_hint;
+  /*! \brief The body of the block. */
+  Stmt body;
+  /*!
+   * \brief The init statement is executed during the first iteration of reduction loops in a
+   *  reduction block. The optional init field allows us to represent initialization and
+   *  reduction update in a single block and transform them collectively.
+   *  We also provide primitives to decompose the init into a separate block during scheduling.
+   *  The init field is `NullOpt` if there are no reduction iter_vars.
+   */
+  Optional<Stmt> init;
+  /*! \brief The buffers allocated in the block. */
+  Array<Buffer> alloc_buffers;
+  /*! \brief The match buffer regions. */
+  Array<MatchBufferRegion> match_buffers;
+  /*! \brief The annotation of the block. */
+  Map<String, ObjectRef> annotations;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("iter_vars", &iter_vars);
+    v->Visit("reads", &reads);
+    v->Visit("writes", &writes);
+    v->Visit("name_hint", &name_hint);
+    v->Visit("body", &body);
+    v->Visit("init", &init);
+    v->Visit("alloc_buffers", &alloc_buffers);
+    v->Visit("match_buffers", &match_buffers);
+    v->Visit("annotations", &annotations);
+  }
+
+  bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const {
+    // Reduce iter_vars, alloc_buffers and match_buffers first: they define new vars
+    return equal.DefEqual(iter_vars, other->iter_vars) &&
+           equal(alloc_buffers, other->alloc_buffers) &&
+           equal(match_buffers, other->match_buffers) && equal(reads, other->reads) &&
+           equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) &&
+           equal(annotations, other->annotations);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce.DefHash(iter_vars);
+    hash_reduce(alloc_buffers);
+    hash_reduce(match_buffers);
+    hash_reduce(reads);
+    hash_reduce(writes);
+    hash_reduce(body);
+    hash_reduce(init);
+    hash_reduce(annotations);
+  }
+
+  static constexpr const char* _type_key = "tir.Block";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockNode, StmtNode);
+};
+
+/*!
+ * \brief Managed reference to BlockNode.
+ * \sa BlockNode
+ */
+class Block : public Stmt {
+ public:
+  TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
+                         Array<BufferRegion> writes, String name_hint, Stmt body,
+                         Optional<Stmt> init = NullOpt,
+                         Array<Buffer> alloc_buffers = Array<Buffer>(),
+                         Array<MatchBufferRegion> match_buffers = Array<MatchBufferRegion>(),
+                         Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
+                         Span span = Span());
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode);
+};
+
+/*!
+ * \brief A block realization node represents execution of the block at the binding values.
+ */
+class BlockRealizeNode : public StmtNode {
+ public:
+  /*! \brief The corresponding values of the iter vars. */
+  Array<PrimExpr> iter_values;
+  /*!
+   * \brief The predicate of the block realization, the block will only be executed when the
+   * predicate is true.
+   */
+  PrimExpr predicate;
+  /*! \brief The block to be realized. */
+  Block block;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("iter_values", &iter_values);
+    v->Visit("predicate", &predicate);
+    v->Visit("block", &block);
+  }
+
+  bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const {
+    return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) &&
+           equal(block, other->block);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(iter_values);
+    hash_reduce(predicate);
+    hash_reduce(block);
+  }
+
+  static constexpr const char* _type_key = "tir.BlockRealize";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockRealizeNode, StmtNode);
+};
+
+/*!
+ * \brief Managed reference to BlockRealizeNode
+ * \sa BlockRealizeNode
+ */
+class BlockRealize : public Stmt {
+ public:
+  TVM_DLL explicit BlockRealize(Array<PrimExpr> iter_values, PrimExpr predicate, Block block,
+                                Span span = Span());
+
+  TVM_DEFINE_OBJECT_REF_METHODS(BlockRealize, Stmt, BlockRealizeNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode);
+};
+
 /*! \brief namespace of possible attribute sin AttrStmt.attr_key */
 namespace attr {
 // The above attr does not pass to ir stage.
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index 0f4238d..e53b02d 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -96,6 +96,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
   virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmtDefault_(const Object* op, Args...) {
     LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
     return R();
@@ -119,6 +121,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
     IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
     IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
     IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
+    IR_STMT_FUNCTOR_DISPATCH(BlockNode);
+    IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode);
     return vtable;
   }
 };
@@ -158,6 +162,8 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
   void VisitStmt_(const PrefetchNode* op) override;
   void VisitStmt_(const SeqStmtNode* op) override;
   void VisitStmt_(const EvaluateNode* op) override;
+  void VisitStmt_(const BlockNode* op) override;
+  void VisitStmt_(const BlockRealizeNode* op) override;
 };
 
 /*!
@@ -249,6 +255,8 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
   Stmt VisitStmt_(const PrefetchNode* op) override;
   Stmt VisitStmt_(const SeqStmtNode* op) override;
   Stmt VisitStmt_(const EvaluateNode* op) override;
+  Stmt VisitStmt_(const BlockNode* op) override;
+  Stmt VisitStmt_(const BlockRealizeNode* op) override;
   /*!
    * \brief Alternative advance method for SeqStmtNode.
    *
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 324c4da..ad91eab 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -31,6 +31,7 @@ from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For
 from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
 from .stmt import ProducerRealize, SeqStmt
 from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
+from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
 
 from .function import PrimFunc
 
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 5882dca..e4f1ac9 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -26,11 +26,15 @@ Each statement node have subfields that can be visited from python side.
     assert isinstance(st, tvm.tir.stmt.Store)
     assert(st.buffer_var == a)
 """
+from typing import List, Optional, Mapping
 from enum import IntEnum
 import tvm._ffi
 
 from tvm.runtime import Object
+from tvm.ir import Span, PrimExpr, Range
 from . import _ffi_api
+from .buffer import Buffer
+from .expr import IterVar
 
 
 class Stmt(Object):
@@ -429,6 +433,164 @@ class Prefetch(Stmt):
         self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds, span)
 
 
+@tvm._ffi.register_object("tir.BufferRegion")
+class BufferRegion(Object):
+    """BufferRegion node.
+
+    Parameters
+    ----------
+    buffer : Buffer
+        The buffer of the buffer region
+
+    region : List[Range]
+        The region array of the buffer region
+    """
+
+    buffer: Buffer
+    region: List[Range]
+
+    def __init__(self, buffer: Buffer, region: List[Range]):
+        self.__init_handle_by_constructor__(_ffi_api.BufferRegion, buffer, region)
+
+
+@tvm._ffi.register_object("tir.MatchBufferRegion")
+class MatchBufferRegion(Object):
+    """MatchBufferRegion node.
+
+    Parameters
+    ----------
+    buffer : Buffer
+        The target buffer
+
+    source : BufferRegion
+        The region of source buffer
+    """
+
+    buffer: Buffer
+    source: BufferRegion
+
+    def __init__(self, buffer: Buffer, source: BufferRegion):
+        self.__init_handle_by_constructor__(_ffi_api.MatchBufferRegion, buffer, source)
+
+
+@tvm._ffi.register_object("tir.Block")
+class Block(Stmt):
+    """Block node.
+
+    Parameters
+    ----------
+    iter_vars : List[IterVar]
+        The iteration variables of the block.
+
+    reads : List[BufferRegion]
+        The read buffer regions of the block.
+
+    writes: List[BufferRegion]
+        The write buffer regions of the block.
+
+    name_hint: str
+        The name_hint of the block.
+
+    body: Stmt
+        The body of the block.
+
+    init: Optional[Stmt]
+        The init block of the reduction block
+
+    alloc_buffers: Optional[List[Buffer]]
+        The buffer allocations
+
+    match_buffers: Optional[List[MatchBufferRegion]]
+        The subregion buffer match
+
+    annotations: Optional[Mapping[str, Object]]
+        Additional annotation hints.
+
+    span : Optional[Span]
+        The location of this block in the source code.
+    """
+
+    iter_vars: List[IterVar]
+    reads: List[BufferRegion]
+    writes: List[BufferRegion]
+    name_hint: str
+    body: Stmt
+    init: Optional[Stmt]
+    alloc_buffers: Optional[List[Buffer]]
+    match_buffers: Optional[List[MatchBufferRegion]]
+    annotations: Optional[Mapping[str, Object]]
+    span: Optional[Span]
+
+    def __init__(
+        self,
+        iter_vars: List[IterVar],
+        reads: List[BufferRegion],
+        writes: List[BufferRegion],
+        name_hint: str,
+        body: Stmt,
+        init: Optional[Stmt] = None,
+        alloc_buffers: Optional[List[Buffer]] = None,
+        match_buffers: Optional[List[MatchBufferRegion]] = None,
+        annotations: Optional[Mapping[str, Object]] = None,
+        span: Optional[Span] = None,
+    ):
+        if alloc_buffers is None:
+            alloc_buffers = []
+        if match_buffers is None:
+            match_buffers = []
+        if annotations is None:
+            annotations = {}
+        self.__init_handle_by_constructor__(
+            _ffi_api.Block,
+            iter_vars,
+            reads,
+            writes,
+            name_hint,
+            body,
+            init,
+            alloc_buffers,
+            match_buffers,
+            annotations,
+            span,
+        )
+
+
+@tvm._ffi.register_object("tir.BlockRealize")
+class BlockRealize(Stmt):
+    """BlockRealize node.
+
+    Parameters
+    ----------
+    iter_values : List[PrimExpr]
+        The binding values of the block iter vars.
+
+    predicate : PrimExpr
+        The predicate of the block.
+
+    block : Block
+        The block to realize
+
+    span : Optional[Span]
+        The location of this block_realize in the source code.
+    """
+
+    iter_values: List[PrimExpr]
+    predicate: PrimExpr
+    block: Block
+    span: Optional[Span]
+
+    def __init__(
+        self,
+        iter_values: List[PrimExpr],
+        predicate: PrimExpr,
+        block: Block,
+        span: Optional[Span] = None,
+    ):
+        self.__init_handle_by_constructor__(
+            _ffi_api.BlockRealize, iter_values, predicate, block, span
+        )
+
+
 def stmt_seq(*args):
     """Make sequence of statements
 
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 92dc387..e54be43 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -598,6 +598,225 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "}\n";
     });
 
+// BufferRegion
+BufferRegion::BufferRegion(Buffer buffer, Array<Range> region) {
+  ObjectPtr<BufferRegionNode> node = make_object<BufferRegionNode>();
+  node->buffer = std::move(buffer);
+  node->region = std::move(region);
+  data_ = std::move(node);
+}
+
+BufferRegion BufferRegion::FullRegion(Buffer buffer) {
+  Array<Range> region;
+  for (PrimExpr extent : buffer->shape) {
+    region.push_back(Range::FromMinExtent(0, extent));
+  }
+  return BufferRegion(buffer, region);
+}
+
+TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array<Range> region) {
+  return BufferRegion(buffer, region);
+});
+
+TVM_REGISTER_NODE_TYPE(BufferRegionNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<BufferRegionNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const BufferRegionNode*>(node.get());
+      p->stream << op->buffer->name;
+      p->stream << "[";
+      for (size_t i = 0; i < op->region.size(); ++i) {
+        const auto& range = op->region[i];
+        p->Print(range->min);
+        if (!is_one(range->extent)) {
+          p->stream << ":";
+          p->Print(range->min + range->extent);
+        }
+        if (i != op->region.size() - 1) p->stream << ", ";
+      }
+      p->stream << "]";
+    });
+
+// MatchBufferRegion
+MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) {
+  ObjectPtr<MatchBufferRegionNode> node = make_object<MatchBufferRegionNode>();
+  node->buffer = std::move(buffer);
+  node->source = std::move(source);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.MatchBufferRegion").set_body_typed([](Buffer buffer, BufferRegion source) {
+  return MatchBufferRegion(buffer, source);
+});
+
+TVM_REGISTER_NODE_TYPE(MatchBufferRegionNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<MatchBufferRegionNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const MatchBufferRegionNode*>(node.get());
+      p->PrintIndent();
+      p->stream << op->buffer->name << " = match_buffer_region(";
+      p->Print(op->source);
+      p->stream << ")\n";
+    });
+
+// Block
+Block::Block(Array<IterVar> iter_vars, Array<BufferRegion> reads, Array<BufferRegion> writes,
+             String name_hint, Stmt body, Optional<Stmt> init, Array<Buffer> alloc_buffers,
+             Array<MatchBufferRegion> match_buffers, Map<String, ObjectRef> annotations,
+             Span span) {
+  ObjectPtr<BlockNode> node = make_object<BlockNode>();
+  node->iter_vars = std::move(iter_vars);
+  node->reads = std::move(reads);
+  node->writes = std::move(writes);
+  node->name_hint = std::move(name_hint);
+  node->body = std::move(body);
+  node->init = std::move(init);
+  node->alloc_buffers = std::move(alloc_buffers);
+  node->match_buffers = std::move(match_buffers);
+  node->annotations = std::move(annotations);
+  node->span = std::move(span);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.Block")
+    .set_body_typed([](Array<IterVar> iter_vars, Array<BufferRegion> reads,
+                       Array<BufferRegion> writes, String name_hint, Stmt body, Optional<Stmt> init,
+                       Array<Buffer> alloc_buffers, Array<MatchBufferRegion> match_buffers,
+                       Map<String, ObjectRef> annotations, Span span) {
+      return Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers, match_buffers,
+                   annotations, span);
+    });
+
+TVM_REGISTER_NODE_TYPE(BlockNode);
+
+void PrintBlockTitle(const BlockNode* op, ReprPrinter* p) {
+  p->stream << "block " << op->name_hint << "(";
+  for (size_t i = 0; i < op->iter_vars.size(); i++) {
+    p->Print(op->iter_vars[i]);
+    if (i < op->iter_vars.size() - 1) p->stream << ", ";
+  }
+  p->stream << ")";
+}
+
+void PrintBlockSignature(const BlockNode* op, ReprPrinter* p) {
+  // print read/write regions
+  p->PrintIndent();
+  p->stream << "reads(";
+  p->Print(op->reads);
+  p->stream << ")\n";
+  p->PrintIndent();
+  p->stream << "writes(";
+  p->Print(op->writes);
+  p->stream << ")\n";
+  // Print alloc_buffers
+  for (const auto& alloc_buf : op->alloc_buffers) {
+    p->PrintIndent();
+    p->stream << alloc_buf->name << " = alloc_buffer(" << alloc_buf->dtype << "[";
+    for (size_t i = 0; i < alloc_buf->shape.size(); ++i) {
+      if (i > 0) p->stream << ", ";
+      p->Print(alloc_buf->shape[i]);
+    }
+    p->stream << "])\n";
+  }
+  // Print match_buffer_regions
+  for (const auto& match_buf : op->match_buffers) {
+    p->Print(match_buf);
+  }
+  if (!op->annotations.empty()) {
+    p->PrintIndent();
+    p->stream << "annotations(" << op->annotations << ")\n";
+  }
+}
+
+void PrintBlockBody(const BlockNode* op, ReprPrinter* p) {
+  // Print init
+  if (op->init.defined()) {
+    p->PrintIndent();
+    p->stream << "with init() {\n";
+    p->indent += 2;
+    p->Print(op->init.value());
+    p->indent -= 2;
+    p->PrintIndent();
+    p->stream << "}\n";
+  }
+  // Print body
+  p->Print(op->body);
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<BlockNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const BlockNode*>(node.get());
+      p->PrintIndent();
+      PrintBlockTitle(op, p);
+      p->stream << "{\n";
+      p->indent += 2;
+
+      // Print block elements (e.g. reads/writes, etc)
+      PrintBlockSignature(op, p);
+      // Print block init and body
+      PrintBlockBody(op, p);
+
+      p->indent -= 2;
+      p->PrintIndent();
+      p->stream << "}\n";
+    });
+
+// BlockRealize
+BlockRealize::BlockRealize(Array<PrimExpr> values, PrimExpr predicate, Block block, Span span) {
+  CHECK_EQ(block->iter_vars.size(), values.size())
+      << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values";
+  CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression";
+  ObjectPtr<BlockRealizeNode> node = make_object<BlockRealizeNode>();
+  node->iter_values = std::move(values);
+  node->predicate = std::move(predicate);
+  node->block = std::move(block);
+  node->span = std::move(span);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.BlockRealize")
+    .set_body_typed([](Array<PrimExpr> iter_values, PrimExpr predicate, Block block, Span span) {
+      return BlockRealize(iter_values, predicate, block, span);
+    });
+
+TVM_REGISTER_NODE_TYPE(BlockRealizeNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<BlockRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const BlockRealizeNode*>(node.get());
+      auto* block_op = op->block.get();
+      p->PrintIndent();
+      PrintBlockTitle(block_op, p);
+      p->stream << "{\n";
+      p->indent += 2;
+
+      // Print binding iter_values
+      for (size_t i = 0; i < block_op->iter_vars.size(); ++i) {
+        p->PrintIndent();
+        p->stream << "bind(";
+        p->Print(block_op->iter_vars[i]->var);
+        p->stream << ", ";
+        p->Print(op->iter_values[i]);
+        p->stream << ")\n";
+      }
+      // Print predicate
+      if (!is_one(op->predicate)) {
+        p->PrintIndent();
+        p->stream << "where(";
+        p->Print(op->predicate);
+        p->stream << ")\n";
+      }
+      // Print block elements (e.g. reads/writes, etc)
+      PrintBlockSignature(block_op, p);
+      // Print block init and body
+      PrintBlockBody(block_op, p);
+
+      p->indent -= 2;
+      p->PrintIndent();
+      p->stream << "}\n";
+    });
+
 PrimExpr TypeAnnotation(DataType dtype, Span span) {
   static auto op = Op::Get("tir.type_annotation");
   return tir::Call(dtype, op, {}, span);
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index e4cc1b7..f05dc71 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -112,6 +112,35 @@ void StmtVisitor::VisitStmt_(const SeqStmtNode* op) {
 
 void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); }
 
+void StmtVisitor::VisitStmt_(const BlockNode* op) {
+  auto fvisit_buffer_region = [this](const BufferRegion& s) {
+    for (const auto& range : s->region) {
+      this->VisitExpr(range->min);
+      this->VisitExpr(range->extent);
+    }
+  };
+  VisitArray(op->iter_vars, [this](const IterVar& iter_var) {
+    this->VisitExpr(iter_var->dom->min);
+    this->VisitExpr(iter_var->dom->extent);
+  });
+  VisitArray(op->reads, fvisit_buffer_region);
+  VisitArray(op->writes, fvisit_buffer_region);
+  VisitArray(op->match_buffers,
+             [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) {
+               fvisit_buffer_region(match_buffer_region->source);
+             });
+  if (op->init.defined()) {
+    this->VisitStmt(op->init.value());
+  }
+  this->VisitStmt(op->body);
+}
+
+void StmtVisitor::VisitStmt_(const BlockRealizeNode* op) {
+  VisitArray(op->iter_values, [this](const PrimExpr& e) { this->VisitExpr(e); });
+  this->VisitExpr(op->predicate);
+  this->VisitStmt(op->block);
+}
+
 class StmtMutator::Internal {
  public:
   /*!
@@ -150,6 +179,20 @@ class StmtMutator::Internal {
     }
   }
 
+  static Array<IterVar> Mutate(StmtMutator* self, const Array<IterVar>& arr) {
+    auto fmutate = [self](const IterVar& iter_var) {
+      PrimExpr min = self->VisitExpr(iter_var->dom->min);
+      PrimExpr extent = self->VisitExpr(iter_var->dom->extent);
+      if (min.same_as(iter_var->dom->min) && extent.same_as(iter_var->dom->extent)) {
+        return iter_var;
+      } else {
+        return IterVar(Range(min, extent), iter_var->var, iter_var->iter_type,
+                       iter_var->thread_tag);
+      }
+    };
+    return MutateArray(self, arr, fmutate);
+  }
+
   static Array<PrimExpr> Mutate(StmtMutator* self, const Array<PrimExpr>& arr) {
     auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); };
     return MutateArray(self, arr, fmutate);
@@ -172,6 +215,31 @@ class StmtMutator::Internal {
     };
     return MutateArray(self, arr, fmutate);
   }
+
+  static Array<BufferRegion> Mutate(StmtMutator* self, const Array<BufferRegion>& arr) {
+    auto fmutate = [self](const BufferRegion& buffer_region) {
+      Array<Range> region = Mutate(self, buffer_region->region);
+      if (region.same_as(buffer_region->region)) {
+        return buffer_region;
+      } else {
+        return BufferRegion(buffer_region->buffer, region);
+      }
+    };
+    return MutateArray(self, arr, fmutate);
+  }
+
+  static Array<MatchBufferRegion> Mutate(StmtMutator* self, const Array<MatchBufferRegion>& arr) {
+    auto fmutate = [self](const MatchBufferRegion& match_buffer_region) {
+      Array<Range> region = Mutate(self, match_buffer_region->source->region);
+      if (region.same_as(match_buffer_region->source->region)) {
+        return match_buffer_region;
+      } else {
+        return MatchBufferRegion(match_buffer_region->buffer,
+                                 BufferRegion(match_buffer_region->source->buffer, region));
+      }
+    };
+    return MutateArray(self, arr, fmutate);
+  }
 };
 
 Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) {
@@ -415,6 +483,47 @@ Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
   }
 }
 
+Stmt StmtMutator::VisitStmt_(const BlockNode* op) {
+  Array<IterVar> iter_vars = Internal::Mutate(this, op->iter_vars);
+  Array<BufferRegion> reads = Internal::Mutate(this, op->reads);
+  Array<BufferRegion> writes = Internal::Mutate(this, op->writes);
+  Array<MatchBufferRegion> match_buffers = Internal::Mutate(this, op->match_buffers);
+  Optional<Stmt> init = NullOpt;
+  if (op->init.defined()) {
+    init = VisitStmt(op->init.value());
+  }
+  Stmt body = VisitStmt(op->body);
+  if (iter_vars.same_as(op->iter_vars) && reads.same_as(op->reads) && writes.same_as(op->writes) &&
+      body.same_as(op->body) && init.same_as(op->init) &&
+      match_buffers.same_as(op->match_buffers)) {
+    return GetRef<Block>(op);
+  } else {
+    auto n = CopyOnWrite(op);
+    n->iter_vars = std::move(iter_vars);
+    n->reads = std::move(reads);
+    n->writes = std::move(writes);
+    n->body = std::move(body);
+    n->init = std::move(init);
+    n->match_buffers = std::move(match_buffers);
+    return Stmt(n);
+  }
+}
+
+Stmt StmtMutator::VisitStmt_(const BlockRealizeNode* op) {
+  Array<PrimExpr> v = Internal::Mutate(this, op->iter_values);
+  PrimExpr pred = this->VisitExpr(op->predicate);
+  Stmt block = this->VisitStmt(op->block);
+  if (v.same_as(op->iter_values) && pred.same_as(op->predicate) && block.same_as(op->block)) {
+    return GetRef<Stmt>(op);
+  } else {
+    auto n = CopyOnWrite(op);
+    n->iter_values = std::move(v);
+    n->predicate = std::move(pred);
+    n->block = Downcast<Block>(block);
+    return Stmt(n);
+  }
+}
+
 // Implementations of IRTransform, PostOrderVisit and Substitute
 class IRApplyVisit : public StmtExprVisitor {
  public:
diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc
index d242b20..237dc46 100644
--- a/tests/cpp/ir_functor_test.cc
+++ b/tests/cpp/ir_functor_test.cc
@@ -120,6 +120,25 @@ TEST(IRF, StmtVisitor) {
   };
   v(fmaketest());
   ICHECK_EQ(v.count, 3);
+
+  {
+    // Visitor coverage for Block and BlockRealize: build one of each and run the visitor v over it.
+    Stmt body = fmaketest();
+    DataType dtype = DataType::Float(32);
+    Var buf_var("b", PointerType(PrimType(dtype)));  // NOTE(review): not referenced again in this scope
+    Buffer buffer = decl_buffer({16});
+    BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)});
+    MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region);
+
+    // Build the Block (the same region and the same statement are each passed twice) and wrap it.
+    Block block =
+        Block({}, {buffer_region}, {buffer_region}, "block", body, body, {}, {match_buffer_region});
+    Stmt block_realize = BlockRealize({}, const_true(), block);
+
+    v.count = 0;  // reset the counter left over from the earlier check in this test
+    v(block_realize);
+    ICHECK_EQ(v.count, 9);  // expected total for v, which is defined earlier in this test
+  }
 }
 
 TEST(IRF, StmtMutator) {
@@ -229,6 +248,28 @@ TEST(IRF, StmtMutator) {
     // the seq get flattened
     ICHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() != extentptr);
   }
+
+  {
+    // Mutator coverage for Block and BlockRealize: build one of each and run the mutator v over it.
+    Stmt body = fmakealloc();
+    DataType dtype = DataType::Float(32);
+    Var buf_var("b", PointerType(PrimType(dtype)));  // NOTE(review): not referenced again in this scope
+    Buffer buffer = decl_buffer({16});
+    BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)});
+    MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region);
+    // Build the Block (the same region and the same statement are each passed twice) and wrap it.
+    Block block =
+        Block({}, {buffer_region}, {buffer_region}, "block", body, body, {}, {match_buffer_region});
+    Stmt block_realize = BlockRealize({}, const_true(), block);
+    body = v(std::move(block_realize));  // body now holds the mutated BlockRealize
+    // body, init, reads, writes and match_buffers are all expected to have been rewritten by v
+    Block new_block = body.as<BlockRealizeNode>()->block;
+    ICHECK(new_block->body.as<AllocateNode>()->extents[1].same_as(x));
+    ICHECK(new_block->init.as<AllocateNode>()->extents[1].same_as(x));
+    ICHECK(new_block->reads[0]->region[0]->min.same_as(x));  // region min was built as x + 1 above
+    ICHECK(new_block->writes[0]->region[0]->min.same_as(x));
+    ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x));
+  }
 }
 
 int main(int argc, char** argv) {
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index bff60f7..6e338d6 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -364,6 +364,87 @@ def test_intimm_cond():
     assert x == 1
 
 
+def test_block_blockrealize():  # Build a Block and a BlockRealize, then check each field reads back as given.
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    vx = tvm.tir.IterVar((16, 16), "vx", 0)  # third argument is the iter-type code; presumably 0 = data-parallel, confirm
+    vx_var = vx.var
+    vy = tvm.tir.IterVar((16, 16), "vy", 2)  # presumably 2 = reduction; confirm against the IterVar docs
+    vy_var = vy.var
+    A = tvm.tir.decl_buffer((16), "float32")  # NOTE(review): (16) is the int 16, not a 1-tuple; presumably accepted as a scalar shape
+    B = tvm.tir.decl_buffer((16, 16), "float32")
+    alloc_buffer = tvm.tir.decl_buffer((16, 16), "float32")
+    match_buffer = tvm.tir.decl_buffer((16, 16), "float32")
+    init_body = tvm.tir.BufferStore(A, 0.0, [vx_var])  # stores 0.0 into A at index [vx]
+    body = tvm.tir.BufferStore(  # stores A[vx] + B[vx, vy] into A at index [vx]
+        A,
+        tvm.tir.BufferLoad(A, [vx_var]) + tvm.tir.BufferLoad(B, [vx_var, vy_var]),
+        [vx_var],
+    )
+    reads = [
+        tvm.tir.BufferRegion(
+            B, [tvm.ir.Range.from_min_extent(vx_var, 1), tvm.ir.Range.from_min_extent(vy_var, 1)]
+        )
+    ]
+    writes = [tvm.tir.BufferRegion(A, [tvm.ir.Range.from_min_extent(vx_var, 1)])]
+    match_buffer_region = tvm.tir.MatchBufferRegion(  # pairs match_buffer with a 16x16 region of B
+        match_buffer, tvm.tir.BufferRegion(B, [tvm.ir.Range(0, 16), tvm.ir.Range(0, 16)])
+    )
+
+    block = tvm.tir.Block(
+        [vx, vy],
+        reads,
+        writes,
+        "block",
+        body,
+        init=init_body,
+        alloc_buffers=[alloc_buffer],
+        match_buffers=[match_buffer_region],
+        annotations={"attr_key": "attr_value"},
+    )
+
+    # Checking Block
+    assert isinstance(block, tvm.tir.Block)
+    # Checking iter_vars
+    assert block.iter_vars[0] == vx
+    assert block.iter_vars[1] == vy
+    # Checking reads/writes region
+    assert isinstance(block.reads[0], tvm.tir.BufferRegion)
+    assert block.reads[0].buffer == B
+    assert block.reads[0].region[0].min == vx_var
+    assert block.reads[0].region[1].min == vy_var
+    assert isinstance(block.writes[0], tvm.tir.BufferRegion)
+    assert block.writes[0].buffer == A
+    assert block.writes[0].region[0].min == vx_var
+    assert block.writes[0].region[0].extent == 1
+    # Checking name_hint
+    assert block.name_hint == "block"
+    # Checking body
+    assert block.body == body
+    # Checking init
+    assert block.init == init_body
+    # Checking alloc_buffers
+    assert block.alloc_buffers[0] == alloc_buffer
+    # Checking match_buffers
+    assert block.match_buffers[0].buffer == match_buffer
+    assert isinstance(block.match_buffers[0].source, tvm.tir.BufferRegion)
+    assert block.match_buffers[0].source.buffer == B
+    assert block.match_buffers[0].source.region[0].min == 0
+    assert block.match_buffers[0].source.region[0].extent == 16
+
+    # Checking BlockRealize
+    block_realize = tvm.tir.BlockRealize([x, y], tvm.tir.const(True, "bool"), block)  # binds x, y as the iter values
+    assert isinstance(block_realize, tvm.tir.BlockRealize)
+    assert block_realize.iter_values[0] == x
+    assert block_realize.iter_values[1] == y
+    assert block_realize.predicate == tvm.tir.const(True, "bool")
+    assert block_realize.block == block
+
+    # make sure we can print
+    str(block)  # results are discarded; these calls only need to not raise
+    str(block_realize)
+
+
 if __name__ == "__main__":
     test_intimm_cond()
     test_buffer_load_store()
@@ -389,3 +470,4 @@ if __name__ == "__main__":
     test_isnan()
     test_equality()
     test_equality_string_imm()
+    test_block_blockrealize()