You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/18 18:51:30 UTC

[tvm] branch main updated: [TVMScript] IRBuilder methods for `Stmt` (#12831)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 052e702827 [TVMScript] IRBuilder methods for `Stmt` (#12831)
052e702827 is described below

commit 052e7028271be2aa2932e8721faf847940d28429
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Sun Sep 18 11:51:23 2022 -0700

    [TVMScript] IRBuilder methods for `Stmt` (#12831)
    
    This PR introduces IRBuilder methods for
    `allocate`, `Let`, `allocate_const`, `attr`, `While`, `If/Then/Else`, `decl_buffer`, `buffer_store`, `prefetch`.
    
    Co-authored-by: yongwww <yo...@gmail.com>
---
 include/tvm/script/ir_builder/tir/frame.h          | 307 +++++++++++++++++++++
 include/tvm/script/ir_builder/tir/ir.h             |  97 +++++++
 python/tvm/script/ir_builder/tir/frame.py          |  48 +++-
 python/tvm/script/ir_builder/tir/ir.py             | 271 ++++++++++++++++++
 src/script/ir_builder/tir/frame.cc                 |  78 ++++++
 src/script/ir_builder/tir/ir.cc                    |  86 ++++++
 src/script/ir_builder/tir/utils.h                  |  15 +
 .../unittest/test_tvmscript_ir_builder_tir.py      | 173 +++++++++++-
 8 files changed, 1061 insertions(+), 14 deletions(-)

diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
index 38fe9009dd..aa2386e7f1 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -435,6 +435,313 @@ class RealizeFrame : public TIRFrame {
  public:
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
 };
+
+/*!
+ * \brief A frame that represents an allocate statement.
+ *
+ * \sa AllocateFrame
+ */
+class AllocateFrameNode : public TIRFrameNode {
+ public:
+  /*! \brief The extents of the allocation. */
+  Array<PrimExpr> extents;
+  /*! \brief The element data type of the allocated buffer. */
+  DataType dtype;
+  /*! \brief The storage scope of the allocation. */
+  String storage_scope;
+  /*! \brief The condition of the allocation. */
+  PrimExpr condition;
+  /*! \brief Additional annotation hints attached to the allocation. */
+  Map<String, ObjectRef> annotations;
+  /*! \brief The buffer that describes the allocated region. */
+  tvm::tir::Buffer buffer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("extents", &extents);
+    v->Visit("dtype", &dtype);
+    v->Visit("storage_scope", &storage_scope);
+    v->Visit("condition", &condition);
+    v->Visit("annotations", &annotations);
+    v->Visit("buffer", &buffer);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief Invoked when leaving the RAII scope; emits an Allocate statement whose body is
+   * the statements collected in this frame.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed (non-nullable) reference to AllocateFrameNode.
+ *
+ * \sa AllocateFrameNode
+ */
+class AllocateFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode);
+};
+
+/*!
+ * \brief A frame that represents a constant allocation.
+ *
+ * \sa AllocateConstFrame
+ */
+class AllocateConstFrameNode : public TIRFrameNode {
+ public:
+  /*! \brief The element data type of the constant buffer. */
+  DataType dtype;
+  /*! \brief The extents of the constant allocation. */
+  Array<PrimExpr> extents;
+  /*! \brief The constant data backing the allocation. */
+  tvm::runtime::NDArray data;
+  /*! \brief The buffer that describes the constant region. */
+  tvm::tir::Buffer buffer;
+  /*! \brief Additional annotations about the allocation. */
+  Map<String, ObjectRef> annotations;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("dtype", &dtype);
+    v->Visit("extents", &extents);
+    v->Visit("data", &data);
+    v->Visit("buffer", &buffer);
+    v->Visit("annotations", &annotations);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.AllocateConstFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief Invoked when leaving the RAII scope; emits an AllocateConst statement whose body
+   * is the statements collected in this frame.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed (non-nullable) reference to AllocateConstFrameNode.
+ *
+ * \sa AllocateConstFrameNode
+ */
+class AllocateConstFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame,
+                                                    AllocateConstFrameNode);
+};
+
+/*!
+ * \brief A frame that represents an attribute statement.
+ *
+ * \sa AttrFrame
+ */
+class AttrFrameNode : public TIRFrameNode {
+ public:
+  /*! \brief The node the attribute is attached to. */
+  ObjectRef node;
+  /*! \brief The attribute type key. */
+  String attr_key;
+  /*! \brief The attribute value. */
+  PrimExpr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("node", &node);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.AttrFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief Invoked when leaving the RAII scope; emits an AttrStmt whose body is the
+   * statements collected in this frame.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed (non-nullable) reference to AttrFrameNode.
+ *
+ * \sa AttrFrameNode
+ */
+class AttrFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode);
+};
+
+/*!
+ * \brief A frame that represents a while loop.
+ *
+ * \sa WhileFrame
+ */
+class WhileFrameNode : public TIRFrameNode {
+ public:
+  /*! \brief The loop condition of the while statement. */
+  PrimExpr condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.WhileFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief Invoked when leaving the RAII scope; emits a While statement whose body is the
+   * statements collected in this frame.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed (non-nullable) reference to WhileFrameNode.
+ *
+ * \sa WhileFrameNode
+ */
+class WhileFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode);
+};
+
+/*!
+ * \brief A frame that represents an if-then-else statement.
+ *
+ * Statements may not be added directly to this frame; they must be placed inside a nested
+ * ThenFrame or ElseFrame.
+ *
+ * \sa IfFrame
+ */
+class IfFrameNode : public TIRFrameNode {
+ public:
+  /*! \brief The condition of the if statement. */
+  PrimExpr condition;
+  /*! \brief The statements in the true branch; unset until a ThenFrame exits. */
+  Optional<Array<tvm::tir::Stmt>> then_stmts;
+  /*! \brief The statements in the false branch; unset until an ElseFrame exits. */
+  Optional<Array<tvm::tir::Stmt>> else_stmts;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+    v->Visit("then_stmts", &then_stmts);
+    v->Visit("else_stmts", &else_stmts);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.IfFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief The method called when exiting RAII scope; emits an IfThenElse statement.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to IfFrameNode.
+ *
+ * \sa IfFrameNode
+ */
+class IfFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode);
+};
+
+/*!
+ * \brief A frame that represents the then branch of an if statement.
+ *
+ * \sa ThenFrame
+ */
+class ThenFrameNode : public TIRFrameNode {
+ public:
+  static constexpr const char* _type_key = "script.ir_builder.tir.ThenFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief Invoked when entering the RAII scope; rejects a duplicate then branch.
+   * \sa tvm::support::With
+   */
+  void EnterWithScope() final;
+  /*!
+   * \brief Invoked when leaving the RAII scope; stores the collected statements as the
+   * then branch of the enclosing IfFrame.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed (non-nullable) reference to ThenFrameNode.
+ *
+ * \sa ThenFrameNode
+ */
+class ThenFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode);
+};
+
+/*!
+ * \brief A frame that represents the else branch of an if statement.
+ *
+ * \sa ElseFrame
+ */
+class ElseFrameNode : public TIRFrameNode {
+ public:
+  static constexpr const char* _type_key = "script.ir_builder.tir.ElseFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief Invoked when entering the RAII scope; requires an existing then branch and
+   * rejects a duplicate else branch.
+   * \sa tvm::support::With
+   */
+  void EnterWithScope() final;
+  /*!
+   * \brief Invoked when leaving the RAII scope; stores the collected statements as the
+   * else branch of the enclosing IfFrame.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed (non-nullable) reference to ElseFrameNode.
+ *
+ * \sa ElseFrameNode
+ */
+class ElseFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode);
+};
+
+/*!
+ * \brief A frame that represents a buffer declaration.
+ *
+ * \sa DeclBufferFrame
+ */
+class DeclBufferFrameNode : public TIRFrameNode {
+ public:
+  /*! \brief The declared buffer. */
+  tvm::tir::Buffer buffer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("buffer", &buffer);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief The method called when exiting RAII scope; emits a DeclBuffer statement whose
+   * body is the statements collected in this frame.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to DeclBufferFrameNode.
+ *
+ * \sa DeclBufferFrameNode
+ */
+class DeclBufferFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame,
+                                                    DeclBufferFrameNode);
+};
+
 }  // namespace tir
 }  // namespace ir_builder
 }  // namespace script
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index ec1f7f3753..dd289b6915 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -28,6 +28,7 @@ namespace script {
 namespace ir_builder {
 namespace tir {
 
+using tvm::runtime::NDArray;
 using tvm::tir::Buffer;
 using tvm::tir::Var;
 
@@ -317,6 +318,87 @@ LetFrame Let(Var var, PrimExpr value);
  */
 RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition);
 
+/*!
+ * \brief Create an allocate frame.
+ * \param extents The extents of the allocation.
+ * \param dtype The element data type of the allocated buffer.
+ * \param storage_scope The storage scope; empty by default.
+ * \param condition The condition of the allocation; when omitted, the constant true is used.
+ * \param annotations Additional annotation hints; when omitted, an empty map is used.
+ * \return The created AllocateFrame.
+ */
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope = "",
+                       Optional<PrimExpr> condition = NullOpt,
+                       Optional<Map<String, ObjectRef>> annotations = NullOpt);
+
+/*!
+ * \brief The allocate constant node.
+ * \param data The data associated with the constant.
+ * \param dtype The data type of the buffer.
+ * \param extents The extents of the allocate.
+ * \param annotations Additional annotation hints; defaults to an empty (non-null) map, matching
+ *        the behavior of Allocate when no annotations are given.
+ * \return The created AllocateConstFrame.
+ */
+AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array<PrimExpr> extents,
+                                 Map<String, ObjectRef> annotations = Map<String, ObjectRef>());
+
+/*!
+ * \brief Create an attribute frame.
+ * \param node The node the attribute is attached to.
+ * \param attr_key The attribute type key.
+ * \param value The attribute value.
+ * \return The created AttrFrame.
+ */
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value);
+
+/*!
+ * \brief Create a while-loop frame.
+ * \param condition The loop condition.
+ * \return The created WhileFrame.
+ */
+WhileFrame While(PrimExpr condition);
+
+/*!
+ * \brief Create an if-then-else frame.
+ * \param condition The branch condition.
+ * \return The created IfFrame.
+ */
+IfFrame If(PrimExpr condition);
+
+/*!
+ * \brief Create a frame for the then branch; it must be entered directly inside an IfFrame.
+ * \return The created ThenFrame.
+ */
+ThenFrame Then();
+
+/*!
+ * \brief Create a frame for the else branch; it must be entered directly inside an IfFrame.
+ * \return The created ElseFrame.
+ */
+ElseFrame Else();
+
+/*!
+ * \brief The buffer declaration frame.
+ * \param shape The shape of the buffer prior to flattening.
+ * \param dtype The data type in the content of the buffer.
+ * \param buffer_name The name of the buffer.
+ * \param data The pointer to the head of the data.
+ * \param strides The strides of each dimension.
+ * \param elem_offset The offset in terms of number of dtype elements (including lanes).
+ * \param storage_scope The optional storage scope of buffer data pointer.
+ * \param align The alignment requirement of data pointer in bytes.
+ * \param offset_factor The factor of elem_offset field.
+ * \param buffer_type The buffer type.
+ * \param axis_separators The separators between input axes when generating flattened output axes.
+ * \return The created DeclBufferFrame holding the declared buffer.
+ */
+DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_name,
+                           Optional<Var> data, Optional<Array<PrimExpr>> strides,
+                           Optional<PrimExpr> elem_offset, String storage_scope, int align,
+                           int offset_factor, String buffer_type,
+                           Optional<Array<IntImm>> axis_separators);
+
 /*!
  * \brief Launch a thread.
  * \param var The iteration variable.
@@ -332,6 +414,21 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);
  */
 Var EnvThread(String thread_tag);
 
+/*!
+ * \brief Emit a store of a value into a buffer.
+ * \param buffer The destination buffer.
+ * \param value The value to store.
+ * \param indices The buffer location the value is stored to.
+ */
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
+
+/*!
+ * \brief Emit a prefetch hint for a buffer.
+ * \param buffer The buffer to prefetch.
+ * \param bounds The region bounds to prefetch.
+ */
+void Prefetch(Buffer buffer, Array<Range> bounds);
+
 /*!
  * \brief Evaluate the input expression.
  * \param value The input expression to evaluate.
diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py
index 69bc5bfc96..b9b50dfa98 100644
--- a/python/tvm/script/ir_builder/tir/frame.py
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -18,7 +18,7 @@
 from typing import List, Union
 
 from tvm._ffi import register_object as _register_object
-from tvm.tir import Var
+from tvm.tir import Buffer, Var
 
 from ..base import IRBuilderFrame
 
@@ -65,6 +65,52 @@ class RealizeFrame(TIRFrame):
     ...
 
 
+@_register_object("script.ir_builder.tir.AllocateFrame")
+class AllocateFrame(TIRFrame):
+    """Frame for an allocate statement; entering it yields the allocated buffer."""
+
+    def __enter__(self) -> Buffer:
+        super().__enter__()
+        return self.buffer
+
+
+@_register_object("script.ir_builder.tir.AllocateConstFrame")
+class AllocateConstFrame(TIRFrame):
+    """Frame for a constant allocation; entering it yields the constant buffer."""
+
+    def __enter__(self) -> Buffer:
+        super().__enter__()
+        return self.buffer
+
+
+@_register_object("script.ir_builder.tir.AttrFrame")
+class AttrFrame(TIRFrame):
+    """Frame for an attribute statement."""
+
+
+@_register_object("script.ir_builder.tir.WhileFrame")
+class WhileFrame(TIRFrame):
+    """Frame for a while loop."""
+
+
+@_register_object("script.ir_builder.tir.IfFrame")
+class IfFrame(TIRFrame):
+    """Frame for an if-then-else statement."""
+
+
+@_register_object("script.ir_builder.tir.ThenFrame")
+class ThenFrame(TIRFrame):
+    """Frame for the then branch of an if statement."""
+
+
+@_register_object("script.ir_builder.tir.ElseFrame")
+class ElseFrame(TIRFrame):
+    """Frame for the else branch of an if statement."""
+
+
+@_register_object("script.ir_builder.tir.DeclBufferFrame")
+class DeclBufferFrame(TIRFrame):
+    """Frame for a buffer declaration; entering it yields the declared buffer."""
+
+    def __enter__(self) -> Buffer:
+        super().__enter__()
+        return self.buffer
+
+
 @_register_object("script.ir_builder.tir.LaunchThreadFrame")
 class LaunchThreadFrame(TIRFrame):
     ...
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 6db8f40c32..625e1291ff 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -19,8 +19,10 @@
 
 from numbers import Integral
 from typing import Any, Dict, List, Optional, Union, Tuple
+import numpy as np  # type: ignore
 
 from tvm.ir import Range, Type
+from tvm.runtime import convert, ndarray
 from tvm.tir import (
     Buffer,
     BufferLoad,
@@ -32,6 +34,7 @@ from tvm.tir import (
     StringImm,
     Var,
 )
+from tvm.tir import Ramp as ramp
 
 from . import _ffi_api, frame
 
@@ -890,6 +893,217 @@ def realize(
     )
 
 
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str = "",
+    condition: PrimExpr = None,
+    annotations=None,
+) -> frame.AllocateFrame:
+    """Create an allocate node.
+
+    Parameters
+    ----------
+    extents : List[PrimExpr]
+        The extents of the allocation.
+
+    dtype : str
+        The data type of the allocated buffer.
+
+    scope : str
+        The storage scope.
+
+    condition : PrimExpr
+        The condition of the allocation. A Python ``bool`` is converted to a
+        boolean ``IntImm``.
+
+    annotations : Optional[Mapping[str, Object]]
+        Additional annotation hints.
+
+    Returns
+    -------
+    res : frame.AllocateFrame
+        The result AllocateFrame.
+    """
+    cond = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    return _ffi_api.Allocate(  # type: ignore[attr-defined] # pylint: disable=no-member
+        extents, dtype, scope, cond, annotations
+    )
+
+
+def allocate_const(
+    data: List[PrimExpr],
+    dtype: str,
+    extents: List[PrimExpr],
+    annotations=None,
+) -> frame.AllocateConstFrame:
+    """Create an allocate-constant node.
+
+    Parameters
+    ----------
+    data : List[PrimExpr]
+        The data associated with the constant; it is materialized through
+        ``numpy.asarray`` with the given ``dtype``.
+
+    dtype : str
+        The data type of the buffer.
+
+    extents : List[PrimExpr]
+        The extents of the allocation.
+
+    annotations : Optional[Map]
+        Additional annotations about the allocation.
+
+    Returns
+    -------
+    res : frame.AllocateConstFrame
+        The result AllocateConstFrame.
+    """
+    const_array = ndarray.array(np.asarray(data, dtype))
+    return _ffi_api.AllocateConst(  # type: ignore[attr-defined] # pylint: disable=no-member
+        const_array, dtype, extents, annotations
+    )
+
+
+def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame:
+    """Create an attribute node.
+
+    Parameters
+    ----------
+    node : Any
+        The node the attribute is attached to; converted with ``tvm.runtime.convert``.
+
+    attr_key : str
+        The attribute type key.
+
+    value : Union[PrimExpr, str]
+        The attribute value; converted with ``tvm.runtime.convert``.
+
+    Returns
+    -------
+    res : frame.AttrFrame
+        The result AttrFrame.
+    """
+    return _ffi_api.Attr(  # type: ignore[attr-defined] # pylint: disable=no-member
+        convert(node), attr_key, convert(value)
+    )
+
+
+def While(condition: PrimExpr) -> frame.WhileFrame:  # pylint: disable=invalid-name
+    """Create a while node.
+
+    Parameters
+    ----------
+    condition : PrimExpr
+        The loop condition. A Python ``bool`` is converted to a boolean ``IntImm``.
+
+    Returns
+    -------
+    res : frame.WhileFrame
+        The result WhileFrame.
+    """
+    cond = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    return _ffi_api.While(cond)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def If(condition: PrimExpr) -> frame.IfFrame:  # pylint: disable=invalid-name
+    """Create an if node.
+
+    Parameters
+    ----------
+    condition : PrimExpr
+        The branch condition: the then branch runs when it holds, the else branch
+        otherwise. A Python ``bool`` is converted to a boolean ``IntImm``.
+
+    Returns
+    -------
+    res : frame.IfFrame
+        The result IfFrame.
+    """
+    cond = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    return _ffi_api.If(cond)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def Then() -> frame.ThenFrame:  # pylint: disable=invalid-name
+    """Create the then branch of the enclosing if node.
+
+    Returns
+    -------
+    res : frame.ThenFrame
+        The result ThenFrame.
+    """
+    return _ffi_api.Then()  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def Else() -> frame.ElseFrame:  # pylint: disable=invalid-name
+    """Create the else branch of the enclosing if node.
+
+    Returns
+    -------
+    res : frame.ElseFrame
+        The result ElseFrame.
+    """
+    return _ffi_api.Else()  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def decl_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> frame.DeclBufferFrame:
+    """Create a buffer declaration node.
+
+    Parameters
+    ----------
+    shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral]
+        The shape of the buffer prior to flattening. A single ``PrimExpr`` or integer
+        is treated as a one-dimensional shape.
+
+    dtype : str
+        The data type in the content of the buffer. Defaults to ``"float32"``.
+
+    data : Optional[Var]
+        The pointer to the head of the data.
+
+    strides : Optional[List[PrimExpr]]
+        The strides of each dimension.
+
+    elem_offset : Optional[PrimExpr]
+        The offset in terms of number of dtype elements (including lanes).
+
+    scope : str
+        The optional storage scope of buffer data pointer.
+
+    align : int
+        The alignment requirement of data pointer in bytes.
+
+    offset_factor : int
+        The factor of elem_offset field.
+
+    buffer_type : str
+        The buffer type.
+
+    axis_separators : Optional[List[int]]
+        The separators between input axes when generating flattened output axes.
+
+    Returns
+    -------
+    res : frame.DeclBufferFrame
+        The result DeclBufferFrame.
+    """
+    # Normalize a scalar shape into a one-element tuple.
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    return _ffi_api.DeclBuffer(  # type: ignore[attr-defined] # pylint: disable=no-member
+        shape,
+        dtype,
+        "",  # buffer_name is left empty
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
 def launch_thread(
     iter_var: IterVar,  # pylint: disable=redefined-outer-name
     extent: PrimExpr,
@@ -939,6 +1153,53 @@ def env_thread(thread_tag: str) -> IterVar:
     return _ffi_api.EnvThread(thread_tag)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
+def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None:
+    """Buffer store node.
+
+    Parameters
+    ----------
+    buffer : Buffer
+        The buffer.
+
+    value : PrimExpr
+        The value to be stored. A Python ``bool`` stored into a ``bool`` buffer is
+        converted to a boolean ``IntImm``.
+
+    indices : List[Union[PrimExpr, slice]]
+        The indices location to be stored. A slice ``start:stop[:step]`` becomes a
+        ``Ramp`` with ``(stop - start + step - 1) // step`` lanes, or just ``start``
+        when that lane count simplifies to 1. Both ``start`` and ``stop`` must be given.
+
+    Raises
+    ------
+    ValueError
+        If a slice index omits ``start`` or ``stop``.
+    """
+    from tvm.arith import Analyzer  # pylint: disable=import-outside-toplevel
+
+    expr_indices = []
+    for index in indices:
+        if isinstance(index, slice):
+            # An open-ended slice has no computable lane count; report it clearly instead
+            # of failing later on arithmetic with None.
+            if index.start is None or index.stop is None:
+                raise ValueError(
+                    f"buffer_store requires slice indices with explicit start and stop, got {index}"
+                )
+            step = 1 if index.step is None else index.step
+            lanes = Analyzer().simplify((index.stop - index.start + step - 1) // step)
+            if lanes == 1:
+                expr_indices.append(index.start)
+            else:
+                expr_indices.append(ramp(index.start, step, int(lanes)))
+        else:
+            expr_indices.append(index)
+    if isinstance(value, bool) and buffer.dtype == "bool":
+        value = IntImm("bool", value)
+    return _ffi_api.BufferStore(  # type: ignore[attr-defined] # pylint: disable=no-member
+        buffer, value, expr_indices
+    )
+
+
+def prefetch(buffer: Buffer, indices: List[Range]) -> None:
+    """The prefetch hint for a buffer.
+
+    Parameters
+    ----------
+    buffer : Buffer
+        The buffer to be prefetched.
+
+    indices : List[Range]
+        The bounds to be prefetched, given as ``Range`` objects (forwarded to the C++
+        ``Prefetch(Buffer, Array<Range>)`` builder).
+    """
+    return _ffi_api.Prefetch(buffer, indices)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
 def evaluate(value: PrimExpr) -> None:
     """Evaluate the input expression.
 
@@ -1288,8 +1549,18 @@ __all__ = [
     "Assert",
     "let",
     "realize",
+    "allocate",
+    "allocate_const",
+    "attr",
+    "While",
+    "If",
+    "Then",
+    "Else",
+    "decl_buffer",
     "launch_thread",
     "env_thread",
+    "buffer_store",
+    "prefetch",
     "evaluate",
     "int8",
     "int16",
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index 6c9459e638..aa9efa653f 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -115,6 +115,76 @@ void LaunchThreadFrameNode::ExitWithScope() {
   AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts)));
 }
 
+void AllocateFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Data var, dtype and shape all come from the frame's buffer.
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape, condition, body,
+                                 annotations));
+}
+
+void AllocateConstFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::AllocateConst(buffer->data, dtype, extents, data, body, annotations));
+}
+
+void AttrFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::AttrStmt(node, attr_key, value, body));
+}
+
+void WhileFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::While(condition, body));
+}
+
+void IfFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Statements must live in a Then/Else sub-frame, never directly in the If frame.
+  if (!stmts.empty()) {
+    LOG(FATAL) << "stmt within IfThenElse frame should be either in ThenFrame or ElseFrame";
+  }
+  if (!then_stmts.defined()) {
+    LOG(FATAL) << "IfThenElse frame should have at least one then branch";
+  }
+  tvm::tir::Stmt then_case = AsStmt(then_stmts.value());
+  // A missing else branch is represented by a null Stmt.
+  tvm::tir::Stmt else_case(nullptr);
+  if (else_stmts.defined()) {
+    else_case = AsStmt(else_stmts.value());
+  }
+  AddToParent(tvm::tir::IfThenElse(condition, then_case, else_case));
+}
+
+void ThenFrameNode::EnterWithScope() {
+  // Only one then branch is allowed per If frame.
+  IfFrame if_frame = FindIfFrame("T.then_");
+  if (if_frame->then_stmts.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate then branch declaration, previous one is "
+               << if_frame->then_stmts.value();
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+void ThenFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Hand the collected statements over to the If frame that is now on top of the stack.
+  IfFrame if_frame = FindIfFrame("T.then_");
+  if_frame->then_stmts = stmts;
+}
+
+void ElseFrameNode::EnterWithScope() {
+  IfFrame if_frame = FindIfFrame("T.else_");
+  // The else branch may only follow a completed then branch, and only once.
+  if (!if_frame->then_stmts.defined()) {
+    LOG(FATAL) << "The else branch should follow then branch";
+  }
+  if (if_frame->else_stmts.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate else branch declaration, previous one is "
+               << if_frame->else_stmts.value();
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+void ElseFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Hand the collected statements over to the If frame that is now on top of the stack.
+  IfFrame if_frame = FindIfFrame("T.else_");
+  if_frame->else_stmts = stmts;
+}
+
+void DeclBufferFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::DeclBuffer(buffer, body));
+}
+
 TVM_REGISTER_NODE_TYPE(TIRFrameNode);
 TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
 TVM_REGISTER_NODE_TYPE(BlockFrameNode);
@@ -124,6 +194,14 @@ TVM_REGISTER_NODE_TYPE(AssertFrameNode);
 TVM_REGISTER_NODE_TYPE(LetFrameNode);
 TVM_REGISTER_NODE_TYPE(RealizeFrameNode);
 TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode);
+// Register the frame node types added for allocate, allocate_const, attr, while,
+// if/then/else and decl_buffer.
+TVM_REGISTER_NODE_TYPE(AllocateFrameNode);
+TVM_REGISTER_NODE_TYPE(AllocateConstFrameNode);
+TVM_REGISTER_NODE_TYPE(AttrFrameNode);
+TVM_REGISTER_NODE_TYPE(WhileFrameNode);
+TVM_REGISTER_NODE_TYPE(IfFrameNode);
+TVM_REGISTER_NODE_TYPE(ThenFrameNode);
+TVM_REGISTER_NODE_TYPE(ElseFrameNode);
+TVM_REGISTER_NODE_TYPE(DeclBufferFrameNode);
 
 }  // namespace tir
 }  // namespace ir_builder
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 5951af298f..28c3d69861 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -444,6 +444,63 @@ RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
   return RealizeFrame(n);
 }
 
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope,
+                       Optional<PrimExpr> condition, Optional<Map<String, ObjectRef>> annotations) {
+  ObjectPtr<AllocateFrameNode> node = make_object<AllocateFrameNode>();
+  node->extents = extents;
+  node->dtype = dtype;
+  node->storage_scope = storage_scope;
+  // Missing condition means "always true"; missing annotations means an empty map.
+  node->condition = condition.value_or(tvm::Bool(true));
+  node->annotations = annotations.value_or(Map<String, ObjectRef>());
+  node->buffer = BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, storage_scope, 0, 0,
+                            "default", NullOpt);
+  return AllocateFrame(node);
+}
+
+AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype,
+                                 Array<PrimExpr> extents, Map<String, ObjectRef> annotations) {
+  ObjectPtr<AllocateConstFrameNode> n = make_object<AllocateConstFrameNode>();
+  n->dtype = dtype;
+  n->extents = extents;
+  n->data = data;
+  // `annotations` can arrive as a null Map (e.g. Python's `annotations=None` passed through
+  // the FFI, or a NullValue default). Normalize it to an empty map so the emitted
+  // AllocateConst never carries a null annotation map, matching Allocate() above.
+  n->annotations = annotations.defined() ? annotations : Map<String, ObjectRef>();
+  n->buffer =
+      BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, "", 0, 0, "default", NullOpt);
+  return AllocateConstFrame(n);
+}
+
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value) {
+  ObjectPtr<AttrFrameNode> frame_node = make_object<AttrFrameNode>();
+  frame_node->node = node;
+  frame_node->attr_key = attr_key;
+  frame_node->value = value;
+  return AttrFrame(frame_node);
+}
+
+WhileFrame While(PrimExpr condition) {
+  ObjectPtr<WhileFrameNode> node = make_object<WhileFrameNode>();
+  node->condition = condition;
+  return WhileFrame(node);
+}
+
+IfFrame If(PrimExpr condition) {
+  ObjectPtr<IfFrameNode> node = make_object<IfFrameNode>();
+  node->condition = condition;
+  // Both branches start unset; the Then/Else frames fill them in on exit.
+  node->then_stmts = NullOpt;
+  node->else_stmts = NullOpt;
+  return IfFrame(node);
+}
+
+ThenFrame Then() {
+  ObjectPtr<ThenFrameNode> node = make_object<ThenFrameNode>();
+  return ThenFrame(node);
+}
+
+ElseFrame Else() {
+  ObjectPtr<ElseFrameNode> node = make_object<ElseFrameNode>();
+  return ElseFrame(node);
+}
+
 Var EnvThread(String thread_tag) {
   IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex,
                    thread_tag);
@@ -456,6 +513,25 @@ Var EnvThread(String thread_tag) {
   return var;
 }
 
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
+  tvm::tir::BufferStore store(buffer, value, indices);
+  AddToParent(store);
+}
+
+void Prefetch(Buffer buffer, Array<Range> bounds) {
+  tvm::tir::Prefetch hint(buffer, bounds);
+  AddToParent(hint);
+}
+
+DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_name,
+                           Optional<Var> data, Optional<Array<PrimExpr>> strides,
+                           Optional<PrimExpr> elem_offset, String storage_scope, int align,
+                           int offset_factor, String buffer_type,
+                           Optional<Array<IntImm>> axis_separators) {
+  // Build the buffer up front so the frame can expose it when entered.
+  tvm::tir::Buffer declared =
+      BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope, align,
+                 offset_factor, buffer_type, axis_separators);
+  ObjectPtr<DeclBufferFrameNode> node = make_object<DeclBufferFrameNode>();
+  node->buffer = declared;
+  return DeclBufferFrame(node);
+}
+
 void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); }
 
 using tvm::script::ir_builder::details::Namer;
@@ -540,10 +616,20 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid);
 
+// Register each builder as a global function named "script.ir_builder.tir.<Name>".
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Let").set_body_typed(Let);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Allocate").set_body_typed(Allocate);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst").set_body_typed(AllocateConst);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Attr").set_body_typed(Attr);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.While").set_body_typed(While);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread);
 
+// Builders that append a single statement directly to the current parent frame.
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate);
 
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8);
diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h
index c29fae1c65..733c975fad 100644
--- a/src/script/ir_builder/tir/utils.h
+++ b/src/script/ir_builder/tir/utils.h
@@ -88,6 +88,21 @@ inline BlockFrame FindBlockFrame(const String& method) {
   throw;
 }
 
+/*!
+ * \brief Return the IfFrame at the top of the IRBuilder frame stack.
+ * \param method The method name to be printed when throwing exception.
+ * \return The IfFrame currently on top of the frame stack.
+ * \note Reports a fatal ValueError when the top frame is not an IfFrame.
+ */
+inline IfFrame FindIfFrame(const String& method) {
+  if (Optional<IfFrame> frame = IRBuilder::Current()->GetLastFrame<IfFrame>()) {
+    return frame.value();
+  } else {
+    LOG(FATAL) << "ValueError: IfThenElse frame not found. Please ensure '" << method
+               << "' is called under T.If()";
+  }
+  // Unreachable: LOG(FATAL) does not return; the bare rethrow only silences
+  // missing-return warnings (same pattern as FindBlockFrame).
+  throw;
+}
+
 /*!
  * \brief Convert BufferLoad to BufferRegion.
  * \param buffer_load The BufferLoad.
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index 7f2e6e1a47..40e13a2fbe 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -17,9 +17,11 @@
 # pylint: disable=invalid-name, missing-docstring
 """Unittests for tvm.script.ir_builder.tir"""
 import pytest
-import tvm.testing
+import numpy as np
 import tvm
+import tvm.testing
 from tvm import tir
+from tvm.runtime import ndarray
 from tvm.script.ir_builder import tir as T
 from tvm.script.ir_builder import IRBuilder
 from tvm.ir.base import assert_structural_equal
@@ -29,6 +31,7 @@ def test_ir_builder_tir_primfunc_base():
     with IRBuilder() as ib:
         with T.prim_func():
             T.evaluate(0)
+
     # the prim_func generated by IRBuilder
     prim_func_actual = ib.get()
 
@@ -41,6 +44,7 @@ def test_ir_builder_tir_primfunc_base():
         preflattened_buffer_map=None,
         attrs=None,
     )
+
     # Check if the generated ir is expected
     assert_structural_equal(prim_func_actual, prim_func_expected, map_free_vars=True)
 
@@ -58,6 +62,7 @@ def test_ir_builder_tir_primfunc_complete():
             buffer_d = T.match_buffer(d, (64, 64), "int64")
             T.preflattened_buffer(e, (32, 32), "int8", data=e.data)
             T.evaluate(0)
+
     # the prim_func generated by IRBuilder
     prim_func_actual = ib.get()
 
@@ -83,6 +88,7 @@ def test_ir_builder_tir_primfunc_complete():
         },
         attrs=tvm.ir.make_node("DictAttrs", key="value"),
     )
+
     # Check if the generated ir is expected
     assert_structural_equal(prim_func_actual, prim_func_expected, map_free_vars=True)
 
@@ -91,6 +97,7 @@ def test_ir_builder_tir_block_base():
     with IRBuilder() as ib:
         with T.block("block"):
             T.evaluate(0)
+
     # the block generated by IRBuilder
     block_realize_actual = ib.get()
 
@@ -110,6 +117,7 @@ def test_ir_builder_tir_block_base():
         predicate=True,
         block=block_expected,
     )
+
     # Check if the generated ir is expected
     assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)
 
@@ -131,6 +139,7 @@ def test_ir_builder_tir_block_complete():
             T.match_buffer(e[0:32, 0:32], (32, 32), "float32")
             T.axis.spatial(128, f)
             T.evaluate(0)
+
     # the block generated by IRBuilder
     block_realize_actual = ib.get()
 
@@ -158,6 +167,7 @@ def test_ir_builder_tir_block_complete():
         predicate=var_a > 1,
         block=block_expected,
     )
+
     # Check if the generated ir is expected
     assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)
 
@@ -201,6 +211,7 @@ def test_ir_builder_tir_axis():
         predicate=True,
         block=block_expected,
     )
+
     # Check if the generated ir is expected
     assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)
 
@@ -256,6 +267,7 @@ def test_ir_builder_tir_for():
         kind=tir.ForKind.SERIAL,
         body=parallel_expected,
     )
+
     # Check if the generated ir is expected
     assert_structural_equal(for_actual, for_expected, map_free_vars=True)
 
@@ -271,20 +283,9 @@ def test_ir_builder_tir_assert():
     assert_expected = tir.AssertStmt(
         T.var("int32", name="a") == 0, tir.StringImm("a is 0"), tir.Evaluate(0)
     )
-    # Check if the generated ir is expected
-    assert_structural_equal(assert_actual, assert_expected, map_free_vars=True)
 
-
-def test_ir_builder_tir_evaluate():
-    with IRBuilder() as ib:
-        T.evaluate(0)
-    # the evaluate generated by IRBuilder
-    eval_actual = ib.get()
-
-    # the expected evaluate
-    eval_expected = tir.Evaluate(0)
     # Check if the generated ir is expected
-    assert_structural_equal(eval_actual, eval_expected, map_free_vars=True)
+    assert_structural_equal(assert_actual, assert_expected, map_free_vars=True)
 
 
 def test_ir_builder_tir_let():
@@ -296,6 +297,8 @@ def test_ir_builder_tir_let():
 
     # the expected Let statement
     let_expected = tir.LetStmt(T.var("int32", name="a"), tir.IntImm("int32", 2), tir.Evaluate(0))
+
+    # Check if the generated ir is expected
     assert_structural_equal(let_actual, let_expected, map_free_vars=True)
 
 
@@ -304,6 +307,8 @@ def test_ir_builder_tir_realize():
     with IRBuilder() as ib:
         with T.realize(buffer_a[0:128, 0:128], "test_storage_scope", True):
             T.evaluate(0)
+
+    # the buffer realization generated by IRBuilder
     realize_actual = ib.get()
 
     # the expected buffer realization
@@ -313,6 +318,8 @@ def test_ir_builder_tir_realize():
     expected_realize = tir.AttrStmt(
         buffer_a, "realize_scope", tir.StringImm("test_storage_scope"), buffer_realize
     )
+
+    # Check if the generated ir is expected
     assert_structural_equal(realize_actual, expected_realize, map_free_vars=True)
 
 
@@ -322,12 +329,152 @@ def test_ir_builder_tir_thread():
             brow = T.env_thread("blockIdx.y")
             with T.launch_thread(brow, 1):
                 T.evaluate(0)
+
+    # the prim_func generated by IRBuilder
     ir_actual = ib.get()
+
+    # the expected prim_func
     iter_var = tir.IterVar((0, 1), "v", iter_type=1, thread_tag="blockIdx.y")
     attr_stmt = tir.AttrStmt(iter_var, "thread_extent", 1, tir.Evaluate(0))
     func = tir.PrimFunc([], attr_stmt)
+
+    # Check if the generated ir is expected
     assert_structural_equal(ir_actual, func, map_free_vars=True)
 
 
+def test_ir_builder_tir_allocate():
+    with IRBuilder() as ib:
+        with T.allocate([10], "float32", scope="local"):
+            T.evaluate(1)
+
+    # the allocate generated by IRBuilder
+    ir_actual = ib.get()
+
+    # the expected allocate
+    buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "local"))
+    ir_expected = tir.Allocate(
+        buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), tir.Evaluate(1)
+    )
+
+    # Check if the generated ir is expected
+    assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_allocate_const():
+    data = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+    with IRBuilder() as ib:
+        with T.allocate_const(data, "int32", [10]):
+            T.evaluate(1)
+
+    # the allocate const generated by IRBuilder
+    ir_actual = ib.get()
+
+    # the expected allocate const
+    buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("int32")))
+    ir_expected = tir.AllocateConst(
+        buffer_var, "int32", [10], ndarray.array(np.asarray(data, "int32")), tir.Evaluate(1)
+    )
+
+    # Check if the generated ir is expected
+    assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_while():
+    with IRBuilder() as ib:
+        with T.While(T.var("int32", "x") > 0):
+            T.evaluate(0)
+
+    # the while generated by IRBuilder
+    ir_actual = ib.get()
+
+    # the expected while
+    ir_expected = tir.While(tir.Var("x", "int32") > 0, tir.Evaluate(0))
+
+    # Check if the generated ir is expected
+    assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_if_then_else():
+    with IRBuilder() as ib:
+        with T.If(T.var("int32", "c") < 12):
+            with T.Then():
+                T.evaluate(T.int32(0))
+            with T.Else():
+                T.evaluate(T.int32(1))
+
+    # the if_then_else generated by IRBuilder
+    ir_actual = ib.get()
+
+    # the expected if_then_else
+    ir_expected = tir.IfThenElse(
+        tir.Var("c", "int32") < 12,
+        tir.Evaluate(tir.IntImm("int32", 0)),
+        tir.Evaluate(tir.IntImm("int32", 1)),
+    )
+
+    # Check if the generated ir is expected
+    assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_buffer_store():
+    buffer_a = T.buffer_decl((10, 10), "float32")
+    i = T.var("int32", "x")
+    with IRBuilder() as ib:
+        T.buffer_store(buffer_a, 0.1, [0, i])
+
+    # the buffer store generated by IRBuilder
+    ir_actual = ib.get()
+
+    # the expected buffer store
+    ir_expected = tir.BufferStore(buffer_a, 0.1, [0, i])
+
+    # Check if the generated ir is expected
+    assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_prefetch():
+    with IRBuilder() as ib:
+        buffer_a = T.buffer_decl((128, 128), "float32")
+        T.prefetch(buffer_a, [])
+
+    # the prefetch generated by IRBuilder
+    ir_actual = ib.get()
+
+    # the expected prefetch
+    ir_expected = tir.Prefetch(buffer_a, [])
+
+    # Check if the generated ir is expected
+    assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_evaluate():
+    with IRBuilder() as ib:
+        T.evaluate(0)
+    # the evaluate generated by IRBuilder
+    eval_actual = ib.get()
+
+    # the expected evaluate
+    eval_expected = tir.Evaluate(0)
+
+    # Check if the generated ir is expected
+    assert_structural_equal(eval_actual, eval_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_decl_buffer():
+    with IRBuilder() as ib:
+        with T.decl_buffer([128, 128], "float32"):
+            T.evaluate(0)
+
+    # the decl_buffer generated by IRBuilder
+    ir_actual = ib.get()
+
+    # the expected decl_buffer
+    buffer = T.buffer_decl((128, 128), "float32")
+    ir_expected = tir.DeclBuffer(buffer, tir.Evaluate(0))
+
+    # Check if the generated ir is expected
+    assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
 if __name__ == "__main__":
+    # Allow running this test module directly as a script.
     tvm.testing.main()