You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/18 18:51:30 UTC
[tvm] branch main updated: [TVMScript] IRBuilder methods for `Stmt` (#12831)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 052e702827 [TVMScript] IRBuilder methods for `Stmt` (#12831)
052e702827 is described below
commit 052e7028271be2aa2932e8721faf847940d28429
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Sun Sep 18 11:51:23 2022 -0700
[TVMScript] IRBuilder methods for `Stmt` (#12831)
This PR introduces IRBuilder methods for
`allocate`, `Let`, `allocate_const`, `attr`, `While`, `If/Then/Else`, `decl_buffer`, `buffer_store`, `prefetch`.
Co-authored-by: yongwww <yo...@gmail.com>
---
include/tvm/script/ir_builder/tir/frame.h | 307 +++++++++++++++++++++
include/tvm/script/ir_builder/tir/ir.h | 97 +++++++
python/tvm/script/ir_builder/tir/frame.py | 48 +++-
python/tvm/script/ir_builder/tir/ir.py | 271 ++++++++++++++++++
src/script/ir_builder/tir/frame.cc | 78 ++++++
src/script/ir_builder/tir/ir.cc | 86 ++++++
src/script/ir_builder/tir/utils.h | 15 +
.../unittest/test_tvmscript_ir_builder_tir.py | 173 +++++++++++-
8 files changed, 1061 insertions(+), 14 deletions(-)
diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
index 38fe9009dd..aa2386e7f1 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -435,6 +435,313 @@ class RealizeFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
};
+
+/*!
+ * \brief A frame that represents the allocate statement.
+ *
+ * The statements built inside this frame become the body of a `tvm::tir::Allocate`,
+ * which is added to the parent frame when the scope exits.
+ *
+ * \sa AllocateFrame
+ */
+class AllocateFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief The extents of the allocate; also used as the shape of `buffer`. */
+ Array<PrimExpr> extents;
+ /*! \brief The data type of the buffer. */
+ DataType dtype;
+ /*! \brief The storage scope; forwarded to `buffer` when it is declared. */
+ String storage_scope;
+ /*! \brief The allocation condition; `Allocate()` sets it to true when none is given. */
+ PrimExpr condition;
+ /*! \brief Additional annotation hints attached to the emitted `tvm::tir::Allocate`. */
+ Map<String, ObjectRef> annotations;
+ /*!
+ * \brief The buffer describing the allocation. Its `data`, `dtype` and `shape` are the
+ * values used to build the `tvm::tir::Allocate`; the Python frame yields it from `with`.
+ */
+ tvm::tir::Buffer buffer;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("extents", &extents);
+ v->Visit("dtype", &dtype);
+ v->Visit("storage_scope", &storage_scope);
+ v->Visit("condition", &condition);
+ v->Visit("annotations", &annotations);
+ v->Visit("buffer", &buffer);
+ }
+
+ static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode);
+
+ public:
+ /*!
+ * \brief The method called when exiting RAII scope.
+ *
+ * Wraps the collected statements into a `tvm::tir::Allocate` over `buffer->data`
+ * (with `buffer->dtype`, `buffer->shape`, `condition` and `annotations`) and adds it
+ * to the parent frame.
+ * \sa tvm::support::With
+ */
+ void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to AllocateFrameNode.
+ *
+ * \sa AllocateFrameNode
+ */
+class AllocateFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode);
+};
+
+/*!
+ * \brief A frame that represents the allocate-constant statement.
+ *
+ * The statements built inside this frame become the body of a `tvm::tir::AllocateConst`,
+ * which is added to the parent frame when the scope exits.
+ *
+ * \sa AllocateConstFrame
+ */
+class AllocateConstFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief The data type of the buffer. */
+ DataType dtype;
+ /*! \brief The extents of the allocate; also used as the shape of `buffer`. */
+ Array<PrimExpr> extents;
+ /*! \brief The data associated with the constant. */
+ tvm::runtime::NDArray data;
+ /*! \brief The buffer; only its `data` var is used by the emitted statement. */
+ tvm::tir::Buffer buffer;
+ /*! \brief Additional annotations about the allocation. */
+ Map<String, ObjectRef> annotations;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("dtype", &dtype);
+ v->Visit("extents", &extents);
+ v->Visit("data", &data);
+ v->Visit("buffer", &buffer);
+ v->Visit("annotations", &annotations);
+ }
+
+ static constexpr const char* _type_key = "script.ir_builder.tir.AllocateConstFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode);
+
+ public:
+ /*!
+ * \brief The method called when exiting RAII scope.
+ *
+ * Wraps the collected statements into a `tvm::tir::AllocateConst` over `buffer->data`,
+ * using this frame's `dtype`, `extents`, `data` and `annotations`, and adds it to the
+ * parent frame.
+ * \sa tvm::support::With
+ */
+ void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to AllocateConstFrameNode.
+ *
+ * \sa AllocateConstFrameNode
+ */
+class AllocateConstFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame,
+ AllocateConstFrameNode);
+};
+
+/*!
+ * \brief A frame that represents an attribute statement.
+ *
+ * The statements built inside this frame become the body of a `tvm::tir::AttrStmt`,
+ * which is added to the parent frame when the scope exits.
+ *
+ * \sa AttrFrame
+ */
+class AttrFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief The node to annotate the attribute. */
+ ObjectRef node;
+ /*! \brief Attribute type key. */
+ String attr_key;
+ /*! \brief The value of the attribute. */
+ PrimExpr value;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("node", &node);
+ v->Visit("attr_key", &attr_key);
+ v->Visit("value", &value);
+ }
+
+ static constexpr const char* _type_key = "script.ir_builder.tir.AttrFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode);
+
+ public:
+ /*!
+ * \brief The method called when exiting RAII scope.
+ *
+ * Emits `tvm::tir::AttrStmt(node, attr_key, value, body)` into the parent frame,
+ * where `body` is built from the collected statements.
+ * \sa tvm::support::With
+ */
+ void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to AttrFrameNode.
+ *
+ * \sa AttrFrameNode
+ */
+class AttrFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode);
+};
+
+/*!
+ * \brief A frame that represents a while loop.
+ *
+ * The statements built inside this frame become the loop body of a `tvm::tir::While`,
+ * which is added to the parent frame when the scope exits.
+ *
+ * \sa WhileFrame
+ */
+class WhileFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief The loop condition passed to `tvm::tir::While`. */
+ PrimExpr condition;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("condition", &condition);
+ }
+
+ static constexpr const char* _type_key = "script.ir_builder.tir.WhileFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode);
+
+ public:
+ /*!
+ * \brief The method called when exiting RAII scope.
+ *
+ * Emits `tvm::tir::While(condition, body)` into the parent frame.
+ * \sa tvm::support::With
+ */
+ void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to WhileFrameNode.
+ *
+ * \sa WhileFrameNode
+ */
+class WhileFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode);
+};
+
+/*!
+ * \brief A frame that represents an if statement.
+ *
+ * Statements must not be emitted directly into this frame; they go into a nested
+ * ThenFrame and, optionally, an ElseFrame. On exit a `tvm::tir::IfThenElse` is added
+ * to the parent frame.
+ *
+ * \sa IfFrame
+ */
+class IfFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief The condition of the if statement. */
+ PrimExpr condition;
+ /*! \brief The statements in the true branch; set when the nested ThenFrame exits. */
+ Optional<Array<tvm::tir::Stmt>> then_stmts;
+ /*! \brief The statements in the false branch; set when the nested ElseFrame exits. */
+ Optional<Array<tvm::tir::Stmt>> else_stmts;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("condition", &condition);
+ v->Visit("then_stmts", &then_stmts);
+ v->Visit("else_stmts", &else_stmts);
+ }
+
+ static constexpr const char* _type_key = "script.ir_builder.tir.IfFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode);
+
+ public:
+ /*!
+ * \brief The method called when exiting RAII scope.
+ *
+ * Fails if statements were emitted outside a Then/Else frame or if no then branch was
+ * given; otherwise emits the `tvm::tir::IfThenElse` (else case left undefined when
+ * there is no else branch).
+ * \sa tvm::support::With
+ */
+ void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to IfFrameNode.
+ *
+ * \sa IfFrameNode
+ */
+class IfFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode);
+};
+
+/*!
+ * \brief A frame that represents the then branch of an if statement.
+ *
+ * Must be entered while an IfFrame is the innermost frame.
+ *
+ * \sa ThenFrame
+ */
+class ThenFrameNode : public TIRFrameNode {
+ public:
+ static constexpr const char* _type_key = "script.ir_builder.tir.ThenFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode);
+
+ public:
+ /*!
+ * \brief The method called when entering RAII scope.
+ *
+ * Fails if the innermost frame is not an IfFrame or if that IfFrame already has a
+ * then branch.
+ * \sa tvm::support::With
+ */
+ void EnterWithScope() final;
+ /*!
+ * \brief The method called when exiting RAII scope.
+ *
+ * Stores the collected statements as the enclosing IfFrame's `then_stmts`.
+ * \sa tvm::support::With
+ */
+ void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to ThenFrameNode.
+ *
+ * \sa ThenFrameNode
+ */
+class ThenFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode);
+};
+
+/*!
+ * \brief A frame that represents the else branch of an if statement.
+ *
+ * Must be entered while an IfFrame is the innermost frame, after its then branch.
+ *
+ * \sa ElseFrame
+ */
+class ElseFrameNode : public TIRFrameNode {
+ public:
+ static constexpr const char* _type_key = "script.ir_builder.tir.ElseFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode);
+
+ public:
+ /*!
+ * \brief The method called when entering RAII scope.
+ *
+ * Fails if the innermost frame is not an IfFrame, if that IfFrame has no then branch
+ * yet, or if it already has an else branch.
+ * \sa tvm::support::With
+ */
+ void EnterWithScope() final;
+ /*!
+ * \brief The method called when exiting RAII scope.
+ *
+ * Stores the collected statements as the enclosing IfFrame's `else_stmts`.
+ * \sa tvm::support::With
+ */
+ void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to ElseFrameNode.
+ *
+ * \sa ElseFrameNode
+ */
+class ElseFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode);
+};
+
+/*!
+ * \brief A frame that represents a buffer declaration.
+ *
+ * The statements built inside this frame become the body of a `tvm::tir::DeclBuffer`,
+ * which is added to the parent frame when the scope exits.
+ *
+ * \sa DeclBufferFrame
+ */
+class DeclBufferFrameNode : public TIRFrameNode {
+ public:
+  /*! \brief The declared buffer; the Python frame yields it from `with`. */
+  tvm::tir::Buffer buffer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("buffer", &buffer);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief The method called when exiting RAII scope.
+   *
+   * Emits `tvm::tir::DeclBuffer(buffer, body)` into the parent frame.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to DeclBufferFrameNode.
+ *
+ * \sa DeclBufferFrameNode
+ */
+class DeclBufferFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame,
+                                                    DeclBufferFrameNode);
+};
+
} // namespace tir
} // namespace ir_builder
} // namespace script
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index ec1f7f3753..dd289b6915 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -28,6 +28,7 @@ namespace script {
namespace ir_builder {
namespace tir {
+using tvm::runtime::NDArray;
using tvm::tir::Buffer;
using tvm::tir::Var;
@@ -317,6 +318,87 @@ LetFrame Let(Var var, PrimExpr value);
*/
RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition);
+/*!
+ * \brief Create a frame that allocates storage for the statements built inside it.
+ * \param extents The extents of the allocation; they also form the shape of the frame's buffer.
+ * \param dtype The element data type of the allocated buffer.
+ * \param storage_scope The storage scope; empty by default.
+ * \param condition The allocation condition; `true` is used when it is omitted.
+ * \param annotations Additional annotation hints; an empty map is used when omitted.
+ * \return The created AllocateFrame.
+ */
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope = "",
+                       Optional<PrimExpr> condition = NullOpt,
+                       Optional<Map<String, ObjectRef>> annotations = NullOpt);
+
+/*!
+ * \brief Create a frame that allocates a constant buffer for the statements built inside it.
+ * \param data The constant contents of the allocation.
+ * \param dtype The element data type of the buffer.
+ * \param extents The extents of the allocation.
+ * \param annotations Additional annotations about the allocation.
+ * \return The created AllocateConstFrame.
+ */
+AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array<PrimExpr> extents,
+                                 Map<String, ObjectRef> annotations =
+                                     NullValue<Map<String, ObjectRef>>());
+
+/*!
+ * \brief Create an attribute frame; the statements built inside it form the body of a
+ *  `tvm::tir::AttrStmt`.
+ * \param node The node the attribute is attached to.
+ * \param attr_key The attribute key.
+ * \param value The attribute value.
+ * \return The created AttrFrame.
+ */
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value);
+
+/*!
+ * \brief Create a while-loop frame.
+ * \param condition The loop condition.
+ * \return The created WhileFrame.
+ */
+WhileFrame While(PrimExpr condition);
+
+/*!
+ * \brief Create an if-then-else frame. Its branches are supplied by a nested `Then()`
+ *  frame and, optionally, a following `Else()` frame.
+ * \param condition The branch condition.
+ * \return The created IfFrame.
+ */
+IfFrame If(PrimExpr condition);
+
+/*!
+ * \brief Create the then-branch frame; it must be entered directly inside an IfFrame.
+ * \return The created ThenFrame.
+ */
+ThenFrame Then();
+
+/*!
+ * \brief Create the else-branch frame; it must follow the then branch of the same IfFrame.
+ * \return The created ElseFrame.
+ */
+ElseFrame Else();
+
+/*!
+ * \brief Create a frame that declares a buffer for the statements built inside it.
+ * \param shape The shape of the buffer prior to flattening.
+ * \param dtype The data type in the content of the buffer.
+ * \param buffer_name The name of the buffer.
+ * \param data The pointer to the head of the data.
+ * \param strides The strides of each dimension.
+ * \param elem_offset The offset in terms of number of dtype elements (including lanes).
+ * \param storage_scope The optional storage scope of buffer data pointer.
+ * \param align The alignment requirement of data pointer in bytes.
+ * \param offset_factor The factor of elem_offset field.
+ * \param buffer_type The buffer type.
+ * \param axis_separators The separators between input axes when generating flattened output axes.
+ * \return The created DeclBufferFrame; its `buffer` member holds the declared buffer.
+ */
+DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_name,
+                           Optional<Var> data, Optional<Array<PrimExpr>> strides,
+                           Optional<PrimExpr> elem_offset, String storage_scope, int align,
+                           int offset_factor, String buffer_type,
+                           Optional<Array<IntImm>> axis_separators);
+
/*!
* \brief Launch a thread.
* \param var The iteration variable.
@@ -332,6 +414,21 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);
*/
Var EnvThread(String thread_tag);
+/*!
+ * \brief Add a `tvm::tir::BufferStore` statement to the enclosing frame.
+ * \param buffer The buffer written to.
+ * \param value The value to be stored.
+ * \param indices The indices at which the value is stored.
+ */
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
+
+/*!
+ * \brief Add a `tvm::tir::Prefetch` hint for a buffer to the enclosing frame.
+ * \param buffer The buffer to be prefetched.
+ * \param bounds The bounds of the region to be prefetched.
+ */
+void Prefetch(Buffer buffer, Array<Range> bounds);
+
/*!
* \brief Evaluate the input expression.
* \param value The input expression to evaluate.
diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py
index 69bc5bfc96..b9b50dfa98 100644
--- a/python/tvm/script/ir_builder/tir/frame.py
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -18,7 +18,7 @@
from typing import List, Union
from tvm._ffi import register_object as _register_object
-from tvm.tir import Var
+from tvm.tir import Buffer, Var
from ..base import IRBuilderFrame
@@ -65,6 +65,52 @@ class RealizeFrame(TIRFrame):
...
+@_register_object("script.ir_builder.tir.AllocateFrame")
+class AllocateFrame(TIRFrame):
+ def __enter__(self) -> Buffer:
+ super().__enter__()
+ return self.buffer
+
+
+@_register_object("script.ir_builder.tir.AllocateConstFrame")
+class AllocateConstFrame(TIRFrame):
+ def __enter__(self) -> Buffer:
+ super().__enter__()
+ return self.buffer
+
+
+@_register_object("script.ir_builder.tir.AttrFrame")
+class AttrFrame(TIRFrame):
+ ...
+
+
+@_register_object("script.ir_builder.tir.WhileFrame")
+class WhileFrame(TIRFrame):
+ ...
+
+
+@_register_object("script.ir_builder.tir.IfFrame")
+class IfFrame(TIRFrame):
+ ...
+
+
+@_register_object("script.ir_builder.tir.ThenFrame")
+class ThenFrame(TIRFrame):
+ ...
+
+
+@_register_object("script.ir_builder.tir.ElseFrame")
+class ElseFrame(TIRFrame):
+ ...
+
+
+@_register_object("script.ir_builder.tir.DeclBufferFrame")
+class DeclBufferFrame(TIRFrame):
+ def __enter__(self) -> Buffer:
+ super().__enter__()
+ return self.buffer
+
+
@_register_object("script.ir_builder.tir.LaunchThreadFrame")
class LaunchThreadFrame(TIRFrame):
...
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 6db8f40c32..625e1291ff 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -19,8 +19,10 @@
from numbers import Integral
from typing import Any, Dict, List, Optional, Union, Tuple
+import numpy as np # type: ignore
from tvm.ir import Range, Type
+from tvm.runtime import convert, ndarray
from tvm.tir import (
Buffer,
BufferLoad,
@@ -32,6 +34,7 @@ from tvm.tir import (
StringImm,
Var,
)
+from tvm.tir import Ramp as ramp
from . import _ffi_api, frame
@@ -890,6 +893,217 @@ def realize(
)
+def allocate(
+ extents: List[PrimExpr],
+ dtype: str,
+ scope: str = "",
+ condition: PrimExpr = None,
+ annotations=None,
+) -> frame.AllocateFrame:
+ """Allocate node.
+
+ Parameters
+ ----------
+ extents : List[PrimExpr]
+ The extents of the allocate.
+
+ dtype : str
+ The data type of the buffer.
+
+ scope : str
+ The storage scope.
+
+ condition : PrimExpr
+ The condition.
+
+ annotations: Optional[Mapping[str, Object]]
+ Additional annotation hints.
+ """
+ if isinstance(condition, bool):
+ condition = IntImm("bool", condition)
+ return _ffi_api.Allocate( # type: ignore[attr-defined] # pylint: disable=no-member
+ extents, dtype, scope, condition, annotations
+ )
+
+
+def allocate_const(
+ data: List[PrimExpr],
+ dtype: str,
+ extents: List[PrimExpr],
+ annotations=None,
+) -> frame.AllocateConstFrame:
+ """Allocate constant node.
+
+ Parameters
+ ----------
+ data : List[PrimExpr]
+ The data associated with the constant.
+
+ dtype : str
+ The data type of the buffer.
+
+ extents : List[PrimExpr]
+ The extents of the allocate.
+
+ annotations : Optional[Map]
+ Additional annotations about the allocation.
+ """
+
+ return _ffi_api.AllocateConst( # type: ignore[attr-defined] # pylint: disable=no-member
+ ndarray.array(np.asarray(data, dtype)), dtype, extents, annotations
+ )
+
+
+def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame:
+ """Create an attribute node.
+
+ Parameters
+ ----------
+ node : Any
+ The node to annotate the attribute.
+
+ attr_key : str
+ Attribute type key.
+
+ value : Union[PrimExpr, str]
+ The value of the attribute.
+
+ Returns
+ -------
+ res : frame.AttrFrame
+ The result AttrFrame.
+ """
+ node = convert(node)
+ value = convert(value)
+ return _ffi_api.Attr(node, attr_key, value) # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def While(condition: PrimExpr) -> frame.WhileFrame: # pylint: disable=invalid-name
+ """Create a while node.
+
+ Parameters
+ ----------
+ condition : PrimExpr
+ The termination condition of the loop.
+
+ Returns
+ -------
+ res : frame.WhileFrame
+ The result WhileFrame.
+ """
+ if isinstance(condition, bool):
+ condition = IntImm("bool", condition)
+ return _ffi_api.While(condition) # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def If(condition: PrimExpr) -> frame.IfFrame: # pylint: disable=invalid-name
+ """Create an if node.
+
+ Parameters
+ ----------
+ condition : PrimExpr
+ The condition of if statement, executes the true branch if the condition is true,
+ otherwise jump into the false branch.
+
+ Returns
+ -------
+ res : frame.IfFrame
+ The result IfFrame.
+ """
+ if isinstance(condition, bool):
+ condition = IntImm("bool", condition)
+ return _ffi_api.If(condition) # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def Then() -> frame.ThenFrame: # pylint: disable=invalid-name
+ """Create a then.
+
+ Returns
+ -------
+ res : frame.ThenFrame
+ The result ThenFrame.
+ """
+ return _ffi_api.Then() # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def Else() -> frame.ElseFrame: # pylint: disable=invalid-name
+ """Create an else.
+
+ Returns
+ -------
+ res : frame.ElseFrame
+ The result ElseFrame.
+ """
+ return _ffi_api.Else() # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def decl_buffer(
+ shape,
+ dtype="float32",
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="",
+ align=0,
+ offset_factor=0,
+ buffer_type="",
+ axis_separators=None,
+) -> frame.DeclBufferFrame:
+ """Create a buffer declaration node.
+
+ Parameters
+ ----------
+ shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral]
+ The type of the buffer prior to flattening.
+
+ dtype : str
+ The data type in the content of the buffer.
+
+ data : Var
+ The pointer to the head of the data.
+
+ strides : List[PrimExpr]
+ The strides of each dimension.
+
+ elem_offset : PrimExpr
+ The offset in terms of number of dtype elements (including lanes).
+
+ scope : str
+ The optional storage scope of buffer data pointer.
+
+ align : int
+ The alignment requirement of data pointer in bytes.
+
+ offset_factor : int
+ The factor of elem_offset field.
+
+ buffer_type : str
+ The buffer type.
+
+ axis_separators : List[int]
+ The separators between input axes when generating flattened output axes.
+
+ Returns
+ -------
+ res : frame.DeclBufferFrame
+ The result DeclBufferFrame.
+ """
+ shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+ return _ffi_api.DeclBuffer( # type: ignore[attr-defined] # pylint: disable=no-member
+ shape,
+ dtype,
+ "",
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ axis_separators,
+ )
+
+
def launch_thread(
iter_var: IterVar, # pylint: disable=redefined-outer-name
extent: PrimExpr,
@@ -939,6 +1153,53 @@ def env_thread(thread_tag: str) -> IterVar:
return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member
+def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None:
+ """Buffer store node.
+
+ Parameters
+ ----------
+ buffer : Buffer
+ The buffer.
+
+ value : PrimExpr
+ The value to be stored.
+
+ indices : List[Union[PrimExpr, slice]]
+ The indices location to be stored.
+ """
+ from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel
+
+ expr_indices = []
+ for index in indices:
+ if isinstance(index, slice):
+ step = 1 if index.step is None else index.step
+ lanes = Analyzer().simplify((index.stop - index.start + step - 1) // step)
+ if lanes == 1:
+ expr_indices.append(index.start)
+ else:
+ expr_indices.append(ramp(index.start, step, int(lanes)))
+ else:
+ expr_indices.append(index)
+ if isinstance(value, bool) and buffer.dtype == "bool":
+ value = IntImm("bool", value)
+ return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member
+ buffer, value, expr_indices
+ )
+
+
+def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:
+ """The prefetch hint for a buffer.
+
+ Parameters
+ ----------
+ buffer : Buffer
+ The buffer to be prefetched.
+ indices : List[PrimExpr]
+ The indices of the buffer to extract.
+ """
+ return _ffi_api.Prefetch(buffer, indices) # type: ignore[attr-defined] # pylint: disable=no-member
+
+
def evaluate(value: PrimExpr) -> None:
"""Evaluate the input expression.
@@ -1288,8 +1549,18 @@ __all__ = [
"Assert",
"let",
"realize",
+ "allocate",
+ "allocate_const",
+ "attr",
+ "While",
+ "If",
+ "Then",
+ "Else",
+ "decl_buffer",
"launch_thread",
"env_thread",
+ "buffer_store",
+ "prefetch",
"evaluate",
"int8",
"int16",
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index 6c9459e638..aa9efa653f 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -115,6 +115,76 @@ void LaunchThreadFrameNode::ExitWithScope() {
AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts)));
}
+void AllocateFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // The emitted allocation takes its data var, dtype and shape from the frame's buffer.
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape, condition, body,
+                                 annotations));
+}
+
+void AllocateConstFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Only the data var comes from the buffer; dtype and extents come from the frame itself.
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::AllocateConst(buffer->data, dtype, extents, data, body, annotations));
+}
+
+void AttrFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::AttrStmt(node, attr_key, value, body));
+}
+
+void WhileFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt loop_body = AsStmt(stmts);
+  AddToParent(tvm::tir::While(condition, loop_body));
+}
+
+void IfFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Every statement must have been routed into a ThenFrame or an ElseFrame.
+  if (!stmts.empty()) {
+    LOG(FATAL) << "stmt within IfThenElse frame should be either in ThenFrame or ElseFrame";
+  }
+  if (!then_stmts.defined()) {
+    LOG(FATAL) << "IfThenElse frame should have at least one then branch";
+  }
+  tvm::tir::Stmt then_case = AsStmt(then_stmts.value());
+  // Without an else branch the else case stays undefined.
+  tvm::tir::Stmt else_case(nullptr);
+  if (else_stmts.defined()) {
+    else_case = AsStmt(else_stmts.value());
+  }
+  AddToParent(tvm::tir::IfThenElse(condition, then_case, else_case));
+}
+
+void ThenFrameNode::EnterWithScope() {
+  IfFrame if_frame = FindIfFrame("T.then_");
+  // An IfFrame accepts a single then branch.
+  const Optional<Array<tvm::tir::Stmt>>& previous = if_frame->then_stmts;
+  if (previous.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate then branch declaration, previous one is "
+               << previous.value();
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+void ThenFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Hand the collected statements over to the enclosing IfFrame.
+  IfFrame if_frame = FindIfFrame("T.then_");
+  if_frame->then_stmts = stmts;
+}
+
+void ElseFrameNode::EnterWithScope() {
+  IfFrame if_frame = FindIfFrame("T.else_");
+  // The else branch is only legal after the then branch, and only once.
+  if (!if_frame->then_stmts.defined()) {
+    LOG(FATAL) << "The else branch should follow then branch";
+  }
+  const Optional<Array<tvm::tir::Stmt>>& previous = if_frame->else_stmts;
+  if (previous.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate else branch declaration, previous one is "
+               << previous.value();
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+void ElseFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Hand the collected statements over to the enclosing IfFrame.
+  IfFrame if_frame = FindIfFrame("T.else_");
+  if_frame->else_stmts = stmts;
+}
+
+void DeclBufferFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::DeclBuffer(buffer, body));
+}
+
TVM_REGISTER_NODE_TYPE(TIRFrameNode);
TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
TVM_REGISTER_NODE_TYPE(BlockFrameNode);
@@ -124,6 +194,14 @@ TVM_REGISTER_NODE_TYPE(AssertFrameNode);
TVM_REGISTER_NODE_TYPE(LetFrameNode);
TVM_REGISTER_NODE_TYPE(RealizeFrameNode);
TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode);
+// Register the frame node types added for statement construction.
+TVM_REGISTER_NODE_TYPE(AllocateFrameNode);
+TVM_REGISTER_NODE_TYPE(AllocateConstFrameNode);
+TVM_REGISTER_NODE_TYPE(AttrFrameNode);
+TVM_REGISTER_NODE_TYPE(WhileFrameNode);
+TVM_REGISTER_NODE_TYPE(IfFrameNode);
+TVM_REGISTER_NODE_TYPE(ThenFrameNode);
+TVM_REGISTER_NODE_TYPE(ElseFrameNode);
+TVM_REGISTER_NODE_TYPE(DeclBufferFrameNode);
} // namespace tir
} // namespace ir_builder
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 5951af298f..28c3d69861 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -444,6 +444,63 @@ RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
return RealizeFrame(n);
}
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope,
+                       Optional<PrimExpr> condition, Optional<Map<String, ObjectRef>> annotations) {
+  // The frame owns a fresh, anonymous buffer describing the allocated region.
+  tvm::tir::Buffer buffer = BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt,
+                                       storage_scope, 0, 0, "default", NullOpt);
+  ObjectPtr<AllocateFrameNode> node = make_object<AllocateFrameNode>();
+  node->extents = extents;
+  node->dtype = dtype;
+  node->storage_scope = storage_scope;
+  // Missing optionals fall back to an unconditional allocation with no annotations.
+  node->condition = condition.value_or(tvm::Bool(true));
+  node->annotations = annotations.value_or(Map<String, ObjectRef>());
+  node->buffer = buffer;
+  return AllocateFrame(node);
+}
+
+/*!
+ * \brief Create an AllocateConstFrame over a fresh anonymous buffer.
+ *
+ * An undefined `annotations` map (the header default, or `None` from Python) is
+ * normalized to an empty map, matching how `Allocate` treats missing annotations, so the
+ * emitted `tvm::tir::AllocateConst` never carries a null Map.
+ */
+AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype,
+                                 Array<PrimExpr> extents, Map<String, ObjectRef> annotations) {
+  ObjectPtr<AllocateConstFrameNode> n = make_object<AllocateConstFrameNode>();
+  n->dtype = dtype;
+  n->extents = extents;
+  n->data = data;
+  n->annotations = annotations.defined() ? annotations : Map<String, ObjectRef>();
+  n->buffer =
+      BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, "", 0, 0, "default", NullOpt);
+  return AllocateConstFrame(n);
+}
+
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value) {
+  ObjectPtr<AttrFrameNode> frame_node = make_object<AttrFrameNode>();
+  frame_node->node = node;
+  frame_node->attr_key = attr_key;
+  frame_node->value = value;
+  return AttrFrame(frame_node);
+}
+
+WhileFrame While(PrimExpr condition) {
+  ObjectPtr<WhileFrameNode> node = make_object<WhileFrameNode>();
+  node->condition = condition;
+  return WhileFrame(node);
+}
+
+IfFrame If(PrimExpr condition) {
+  ObjectPtr<IfFrameNode> node = make_object<IfFrameNode>();
+  node->condition = condition;
+  // Both branches start empty; ThenFrame / ElseFrame fill them in when they exit.
+  node->then_stmts = NullOpt;
+  node->else_stmts = NullOpt;
+  return IfFrame(node);
+}
+
+ThenFrame Then() { return ThenFrame(make_object<ThenFrameNode>()); }
+
+ElseFrame Else() { return ElseFrame(make_object<ElseFrameNode>()); }
+
Var EnvThread(String thread_tag) {
IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex,
thread_tag);
@@ -456,6 +513,25 @@ Var EnvThread(String thread_tag) {
return var;
}
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
+  tvm::tir::Stmt store = tvm::tir::BufferStore(buffer, value, indices);
+  AddToParent(store);
+}
+
+void Prefetch(Buffer buffer, Array<Range> bounds) {
+  tvm::tir::Stmt hint = tvm::tir::Prefetch(buffer, bounds);
+  AddToParent(hint);
+}
+
+DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_name,
+                           Optional<Var> data, Optional<Array<PrimExpr>> strides,
+                           Optional<PrimExpr> elem_offset, String storage_scope, int align,
+                           int offset_factor, String buffer_type,
+                           Optional<Array<IntImm>> axis_separators) {
+  // Buffer construction is delegated to BufferDecl; the frame only carries the result.
+  tvm::tir::Buffer declared = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset,
+                                         storage_scope, align, offset_factor, buffer_type,
+                                         axis_separators);
+  ObjectPtr<DeclBufferFrameNode> node = make_object<DeclBufferFrameNode>();
+  node->buffer = declared;
+  return DeclBufferFrame(node);
+}
+
void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); }
using tvm::script::ir_builder::details::Namer;
@@ -540,10 +616,20 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Let").set_body_typed(Let);
+// FFI entry points for the statement builders; names match the Python `_ffi_api` calls.
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Allocate").set_body_typed(Allocate);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst").set_body_typed(AllocateConst);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Attr").set_body_typed(Attr);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.While").set_body_typed(While);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread);
+// Leaf statements that are added directly to the enclosing frame.
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8);
diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h
index c29fae1c65..733c975fad 100644
--- a/src/script/ir_builder/tir/utils.h
+++ b/src/script/ir_builder/tir/utils.h
@@ -88,6 +88,21 @@ inline BlockFrame FindBlockFrame(const String& method) {
throw;
}
+/*!
+ * \brief Get the last frame of the IRBuilder frame stack, which must be an IfFrame.
+ * \param method The method name to be printed when throwing exception.
+ * \return The last frame, as an IfFrame.
+ * \note Reports a ValueError through LOG(FATAL) when the last frame is not an IfFrame.
+ */
+inline IfFrame FindIfFrame(const String& method) {
+  if (Optional<IfFrame> frame = IRBuilder::Current()->GetLastFrame<IfFrame>()) {
+    return frame.value();
+  } else {
+    LOG(FATAL) << "ValueError: IfThenElse frame not found. Please ensure '" << method
+               << "' is called under T.if_()";
+  }
+  // Unreachable: LOG(FATAL) throws. Kept, as in FindBlockFrame, to satisfy the return path.
+  throw;
+}
+
/*!
* \brief Convert BufferLoad to BufferRegion.
* \param buffer_load The BufferLoad.
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index 7f2e6e1a47..40e13a2fbe 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -17,9 +17,11 @@
# pylint: disable=invalid-name, missing-docstring
"""Unittests for tvm.script.ir_builder.tir"""
import pytest
-import tvm.testing
+import numpy as np
import tvm
+import tvm.testing
from tvm import tir
+from tvm.runtime import ndarray
from tvm.script.ir_builder import tir as T
from tvm.script.ir_builder import IRBuilder
from tvm.ir.base import assert_structural_equal
@@ -29,6 +31,7 @@ def test_ir_builder_tir_primfunc_base():
with IRBuilder() as ib:
with T.prim_func():
T.evaluate(0)
+
# the prim_func generated by IRBuilder
prim_func_actual = ib.get()
@@ -41,6 +44,7 @@ def test_ir_builder_tir_primfunc_base():
preflattened_buffer_map=None,
attrs=None,
)
+
# Check if the generated ir is expected
assert_structural_equal(prim_func_actual, prim_func_expected, map_free_vars=True)
@@ -58,6 +62,7 @@ def test_ir_builder_tir_primfunc_complete():
buffer_d = T.match_buffer(d, (64, 64), "int64")
T.preflattened_buffer(e, (32, 32), "int8", data=e.data)
T.evaluate(0)
+
# the prim_func generated by IRBuilder
prim_func_actual = ib.get()
@@ -83,6 +88,7 @@ def test_ir_builder_tir_primfunc_complete():
},
attrs=tvm.ir.make_node("DictAttrs", key="value"),
)
+
# Check if the generated ir is expected
assert_structural_equal(prim_func_actual, prim_func_expected, map_free_vars=True)
@@ -91,6 +97,7 @@ def test_ir_builder_tir_block_base():
with IRBuilder() as ib:
with T.block("block"):
T.evaluate(0)
+
# the block generated by IRBuilder
block_realize_actual = ib.get()
@@ -110,6 +117,7 @@ def test_ir_builder_tir_block_base():
predicate=True,
block=block_expected,
)
+
# Check if the generated ir is expected
assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)
@@ -131,6 +139,7 @@ def test_ir_builder_tir_block_complete():
T.match_buffer(e[0:32, 0:32], (32, 32), "float32")
T.axis.spatial(128, f)
T.evaluate(0)
+
# the block generated by IRBuilder
block_realize_actual = ib.get()
@@ -158,6 +167,7 @@ def test_ir_builder_tir_block_complete():
predicate=var_a > 1,
block=block_expected,
)
+
# Check if the generated ir is expected
assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)
@@ -201,6 +211,7 @@ def test_ir_builder_tir_axis():
predicate=True,
block=block_expected,
)
+
# Check if the generated ir is expected
assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)
@@ -256,6 +267,7 @@ def test_ir_builder_tir_for():
kind=tir.ForKind.SERIAL,
body=parallel_expected,
)
+
# Check if the generated ir is expected
assert_structural_equal(for_actual, for_expected, map_free_vars=True)
@@ -271,20 +283,9 @@ def test_ir_builder_tir_assert():
assert_expected = tir.AssertStmt(
T.var("int32", name="a") == 0, tir.StringImm("a is 0"), tir.Evaluate(0)
)
- # Check if the generated ir is expected
- assert_structural_equal(assert_actual, assert_expected, map_free_vars=True)
-
-def test_ir_builder_tir_evaluate():
- with IRBuilder() as ib:
- T.evaluate(0)
- # the evaluate generated by IRBuilder
- eval_actual = ib.get()
-
- # the expected evaluate
- eval_expected = tir.Evaluate(0)
# Check if the generated ir is expected
- assert_structural_equal(eval_actual, eval_expected, map_free_vars=True)
+ assert_structural_equal(assert_actual, assert_expected, map_free_vars=True)
def test_ir_builder_tir_let():
@@ -296,6 +297,8 @@ def test_ir_builder_tir_let():
# the expected Let statement
let_expected = tir.LetStmt(T.var("int32", name="a"), tir.IntImm("int32", 2), tir.Evaluate(0))
+
+ # Check if the generated ir is expected
assert_structural_equal(let_actual, let_expected, map_free_vars=True)
@@ -304,6 +307,8 @@ def test_ir_builder_tir_realize():
with IRBuilder() as ib:
with T.realize(buffer_a[0:128, 0:128], "test_storage_scope", True):
T.evaluate(0)
+
+ # the buffer realization generated by IRBuilder
realize_actual = ib.get()
# the expected buffer realization
@@ -313,6 +318,8 @@ def test_ir_builder_tir_realize():
expected_realize = tir.AttrStmt(
buffer_a, "realize_scope", tir.StringImm("test_storage_scope"), buffer_realize
)
+
+ # Check if the generated ir is expected
assert_structural_equal(realize_actual, expected_realize, map_free_vars=True)
@@ -322,12 +329,152 @@ def test_ir_builder_tir_thread():
brow = T.env_thread("blockIdx.y")
with T.launch_thread(brow, 1):
T.evaluate(0)
+
+ # the prim_func generated by IRBuilder
ir_actual = ib.get()
+
+ # the expected prim_func
iter_var = tir.IterVar((0, 1), "v", iter_type=1, thread_tag="blockIdx.y")
attr_stmt = tir.AttrStmt(iter_var, "thread_extent", 1, tir.Evaluate(0))
func = tir.PrimFunc([], attr_stmt)
+
+ # Check if the generated ir is expected
assert_structural_equal(ir_actual, func, map_free_vars=True)
+def test_ir_builder_tir_allocate():
+ with IRBuilder() as ib:
+ with T.allocate([10], "float32", scope="local"):
+ T.evaluate(1)
+
+ # the Allocate statement generated by IRBuilder
+ ir_actual = ib.get()
+
+ # the expected Allocate of a 10-element float32 buffer in "local" scope
+ buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "local"))
+ ir_expected = tir.Allocate(
+ buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), tir.Evaluate(1)
+ )
+
+ # Check if the generated ir is expected
+ assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_allocate_const():
+ data = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ with IRBuilder() as ib:
+ with T.allocate_const(data, "int32", [10]):
+ T.evaluate(1)
+
+ # the AllocateConst statement generated by IRBuilder
+ ir_actual = ib.get()
+
+ # the expected AllocateConst, with `data` materialized as an int32 NDArray
+ buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("int32")))
+ ir_expected = tir.AllocateConst(
+ buffer_var, "int32", [10], ndarray.array(np.asarray(data, "int32")), tir.Evaluate(1)
+ )
+
+ # Check if the generated ir is expected
+ assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_while():
+ with IRBuilder() as ib:
+ with T.While(T.var("int32", "x") > 0):
+ T.evaluate(0)
+
+ # the While statement generated by IRBuilder
+ ir_actual = ib.get()
+
+ # the expected While loop, running while x > 0
+ ir_expected = tir.While(tir.Var("x", "int32") > 0, tir.Evaluate(0))
+
+ # Check if the generated ir is expected
+ assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_if_then_else():
+ with IRBuilder() as ib:
+ with T.If(T.var("int32", "c") < 12):
+ with T.Then():
+ T.evaluate(T.int32(0))
+ with T.Else():
+ T.evaluate(T.int32(1))
+
+ # the IfThenElse statement generated by IRBuilder
+ ir_actual = ib.get()
+
+ # the expected IfThenElse: evaluate 0 if c < 12, otherwise evaluate 1
+ ir_expected = tir.IfThenElse(
+ tir.Var("c", "int32") < 12,
+ tir.Evaluate(tir.IntImm("int32", 0)),
+ tir.Evaluate(tir.IntImm("int32", 1)),
+ )
+
+ # Check if the generated ir is expected
+ assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_buffer_store():
+ buffer_a = T.buffer_decl((10, 10), "float32")
+ i = T.var("int32", "x")
+ with IRBuilder() as ib:
+ T.buffer_store(buffer_a, 0.1, [0, i])
+
+ # the BufferStore statement generated by IRBuilder
+ ir_actual = ib.get()
+
+ # the expected BufferStore of 0.1 into buffer_a[0, i]
+ ir_expected = tir.BufferStore(buffer_a, 0.1, [0, i])
+
+ # Check if the generated ir is expected
+ assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_prefetch():
+ with IRBuilder() as ib:
+ buffer_a = T.buffer_decl((128, 128), "float32")
+ T.prefetch(buffer_a, [])
+
+ # the Prefetch statement generated by IRBuilder
+ ir_actual = ib.get()
+
+ # the expected Prefetch of buffer_a with an empty bounds list
+ ir_expected = tir.Prefetch(buffer_a, [])
+
+ # Check if the generated ir is expected
+ assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_evaluate():
+ with IRBuilder() as ib:
+ T.evaluate(0)
+ # the Evaluate statement generated by IRBuilder
+ eval_actual = ib.get()
+
+ # the expected Evaluate(0) statement
+ eval_expected = tir.Evaluate(0)
+
+ # Check if the generated ir is expected
+ assert_structural_equal(eval_actual, eval_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_decl_buffer():
+ with IRBuilder() as ib:
+ with T.decl_buffer([128, 128], "float32"):
+ T.evaluate(0)
+
+ # the DeclBuffer statement generated by IRBuilder
+ ir_actual = ib.get()
+
+ # the expected DeclBuffer of a 128x128 float32 buffer
+ buffer = T.buffer_decl((128, 128), "float32")
+ ir_expected = tir.DeclBuffer(buffer, tir.Evaluate(0))
+
+ # Check if the generated ir is expected
+ assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
if __name__ == "__main__":
tvm.testing.main()