You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/18 02:18:07 UTC
[tvm] branch main updated: [TVMScript] IRBuilder methods for `Stmt` (#12830)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b2c5addbb4 [TVMScript] IRBuilder methods for `Stmt` (#12830)
b2c5addbb4 is described below
commit b2c5addbb4e92aa770f0cd0847eabb43400ac9d2
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Sat Sep 17 19:18:01 2022 -0700
[TVMScript] IRBuilder methods for `Stmt` (#12830)
This PR introduces IRBuilder methods for `Assert`, `Let`, `Realize`, `Evaluate`, `LaunchThread`, `EnvThread`.
Co-authored-by: yongwww <yo...@gmail.com>
---
include/tvm/script/ir_builder/tir/frame.h | 132 +++++++++++++++++++++
include/tvm/script/ir_builder/tir/ir.h | 40 +++++++
python/tvm/script/ir_builder/tir/frame.py | 20 ++++
python/tvm/script/ir_builder/tir/ir.py | 131 ++++++++++++++++++++
src/script/ir_builder/tir/frame.cc | 27 +++++
src/script/ir_builder/tir/ir.cc | 67 +++++++++++
.../unittest/test_tvmscript_ir_builder_tir.py | 69 +++++++++++
7 files changed, 486 insertions(+)
diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
index c76b400d96..38fe9009dd 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -303,6 +303,138 @@ class AssertFrameNode : public TIRFrameNode {
void ExitWithScope() final;
};
+/*!
+ * \brief Managed reference to AssertFrameNode.
+ *
+ * \sa AssertFrameNode
+ */
+class AssertFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode);
+};
+
+/*!
+ * \brief A frame represents the let binding expression, which binds a var.
+ *
+ * \sa LetFrame
+ */
+class LetFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief The variable we bind to */
+ tvm::tir::Var var;
+ /*! \brief The value we bind var to */
+ PrimExpr value;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("var", &var);
+ v->Visit("value", &value);
+ }
+
+ static constexpr const char* _type_key = "script.ir_builder.tir.LetFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode);
+
+ public:
+ /*!
+ * \brief The method called when exiting RAII scope.
+ * \sa tvm::support::With
+ */
+ void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to LetFrameNode.
+ *
+ * \sa LetFrameNode
+ */
+class LetFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode);
+};
+
+/*!
+ * \brief A frame that represents the launch of an environment thread.
+ * \note It is used only inside a PrimFunc.
+ */
+class LaunchThreadFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief The extent of environment thread. */
+ PrimExpr extent;
+ /*! \brief The attribute key, could be either virtual_thread or thread_extent. */
+ String attr_key;
+ /*! \brief The iteration variable. */
+ tvm::tir::IterVar iter_var;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("extent", &extent);
+ v->Visit("attr_key", &attr_key);
+ v->Visit("iter_var", &iter_var);
+ }
+
+ static constexpr const char* _type_key = "script.ir_builder.tir.LaunchThreadFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode);
+
+ public:
+ /*!
+ * \brief The method called when exiting RAII scope.
+ * \sa tvm::support::With
+ */
+ void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to LaunchThreadFrameNode.
+ *
+ * \sa LaunchThreadFrameNode
+ */
+class LaunchThreadFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame,
+ LaunchThreadFrameNode);
+};
+
+/*!
+ * \brief A frame that represents realization.
+ *
+ * \sa RealizeFrame
+ */
+class RealizeFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief The region of buffer access. */
+ tvm::tir::BufferRegion buffer_slice;
+ /*! \brief The storage scope associated with this realization. */
+ String storage_scope;
+ /*! \brief The condition expression. */
+ PrimExpr condition;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("buffer_slice", &buffer_slice);
+ v->Visit("storage_scope", &storage_scope);
+ v->Visit("condition", &condition);
+ }
+
+ static constexpr const char* _type_key = "script.ir_builder.tir.RealizeFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode);
+
+ public:
+ /*!
+ * \brief The method called when exiting RAII scope.
+ * \sa tvm::support::With
+ */
+ void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to RealizeFrameNode.
+ *
+ * \sa RealizeFrameNode
+ */
+class RealizeFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
+};
} // namespace tir
} // namespace ir_builder
} // namespace script
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 191887648d..ec1f7f3753 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -292,6 +292,46 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
*/
ForFrame Grid(Array<PrimExpr> extents);
+/*!
+ * \brief The assertion statement.
+ * \param condition The assertion condition.
+ * \param message The error message when the assertion fails.
+ * \return The AssertFrame.
+ */
+AssertFrame Assert(PrimExpr condition, String message);
+
+/*!
+ * \brief The let binding.
+ * \param var The variable to bind.
+ * \param value The value to be bound.
+ * \return The created LetFrame.
+ */
+LetFrame Let(Var var, PrimExpr value);
+
+/*!
+ * \brief The realization.
+ * \param buffer_slice The region of buffer access.
+ * \param storage_scope The storage scope associated with this realization.
+ * \param condition The condition expression.
+ * \return The result RealizeFrame.
+ */
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition);
+
+/*!
+ * \brief Launch a thread.
+ * \param var The iteration variable.
+ * \param extent The extent of environment thread.
+ * \return The result LaunchThreadFrame.
+ */
+LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);
+
+/*!
+ * \brief Bind a var to thread env.
+ * \param thread_tag The thread type tag.
+ * \return The result variable which gets bound to the thread env.
+ */
+Var EnvThread(String thread_tag);
+
/*!
* \brief Evaluate the input expression.
* \param value The input expression to evaluate.
diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py
index 2ad08f3516..69bc5bfc96 100644
--- a/python/tvm/script/ir_builder/tir/frame.py
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -48,3 +48,23 @@ class ForFrame(TIRFrame):
def __enter__(self) -> Union[Var, List[Var]]: # type: ignore[override]
super().__enter__()
return self.vars if len(self.vars) > 1 else self.vars[0]
+
+
+@_register_object("script.ir_builder.tir.AssertFrame")
+class AssertFrame(TIRFrame):
+ ...
+
+
+@_register_object("script.ir_builder.tir.LetFrame")
+class LetFrame(TIRFrame):
+ ...
+
+
+@_register_object("script.ir_builder.tir.RealizeFrame")
+class RealizeFrame(TIRFrame):
+ ...
+
+
+@_register_object("script.ir_builder.tir.LaunchThreadFrame")
+class LaunchThreadFrame(TIRFrame):
+ ...
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index d1dc1c8960..6db8f40c32 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -26,6 +26,8 @@ from tvm.tir import (
BufferLoad,
BufferRegion,
IntImm,
+ IterVar,
+ Let,
PrimExpr,
StringImm,
Var,
@@ -813,6 +815,130 @@ def grid(*extents: PrimExpr) -> frame.ForFrame:
return _ffi_api.Grid(extents) # type: ignore[attr-defined] # pylint: disable=no-member
+def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: disable=invalid-name
+ """Create an assertion statement.
+
+ Parameters
+ ----------
+ condition : PrimExpr
+ The PrimExpr to test.
+
+ message : str
+ The output error message when the assertion fails.
+
+ Returns
+ -------
+ res : frame.AssertFrame
+ The result AssertFrame.
+ """
+ return _ffi_api.Assert(condition, message) # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def let(
+ v: Var,
+ value: PrimExpr,
+ body: PrimExpr = None,
+) -> frame.LetFrame:
+ """Create a new let binding.
+
+ Parameters
+ ----------
+ v : Var
+ The variable to bind.
+
+ value : PrimExpr
+ The value to be bound.
+
+ body : PrimExpr
+ The body expression. If None, a LetFrame is created instead of a Let expression.
+
+ Returns
+ -------
+ res : frame.LetFrame or Let
+ The result LetFrame when `body` is None, otherwise the expression `Let(v, value, body)`.
+ """
+ if body is None:
+ return _ffi_api.Let(v, value) # type: ignore[attr-defined] # pylint: disable=no-member
+ return Let(v, value, body)
+
+
+def realize(
+ buffer_slice: BufferRegion,
+ storage_scope: str,
+ condition: PrimExpr = True,
+) -> frame.RealizeFrame:
+ """Create a realization.
+
+ Parameters
+ ----------
+ buffer_slice : BufferRegion
+ The region of buffer access.
+
+ storage_scope : str
+ The storage scope associated with this realization.
+
+ condition : PrimExpr
+ The condition expression, the default is True.
+
+ Returns
+ -------
+ res : frame.RealizeFrame
+ The result RealizeFrame.
+ """
+ return _ffi_api.Realize( # type: ignore[attr-defined] # pylint: disable=no-member
+ buffer_slice, storage_scope, condition
+ )
+
+
+def launch_thread(
+ iter_var: IterVar, # pylint: disable=redefined-outer-name
+ extent: PrimExpr,
+) -> frame.LaunchThreadFrame:
+ """Launch a thread.
+
+ Parameters
+ ----------
+ iter_var : IterVar
+ The iteration variable; it must have been created by `env_thread` in the enclosing PrimFunc.
+
+ extent : PrimExpr
+ The extent of environment thread.
+
+ Returns
+ -------
+ res : frame.LaunchThreadFrame
+ The result LaunchThreadFrame.
+
+ Examples
+ --------
+
+ .. code-block:: python
+
+ from tvm.script.ir_builder import tir as T
+ brow = T.env_thread("blockIdx.y")
+ T.launch_thread(brow, 1)
+
+ """
+ return _ffi_api.LaunchThread(iter_var, extent) # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def env_thread(thread_tag: str) -> IterVar:
+ """Bind a var to thread env.
+
+ Parameters
+ ----------
+ thread_tag : str
+ The thread type tag.
+
+ Returns
+ -------
+ res : IterVar
+ The resulting iteration variable, which is bound to the thread env.
+
+ """
+ return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member
+
+
def evaluate(value: PrimExpr) -> None:
"""Evaluate the input expression.
@@ -1159,6 +1285,11 @@ __all__ = [
"unroll",
"thread_binding",
"grid",
+ "Assert",
+ "let",
+ "realize",
+ "launch_thread",
+ "env_thread",
"evaluate",
"int8",
"int16",
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index 8b8b2a4d80..6c9459e638 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -92,11 +92,38 @@ void ForFrameNode::ExitWithScope() {
AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts)));
}
+void AssertFrameNode::ExitWithScope() {
+ TIRFrameNode::ExitWithScope();
+ AddToParent(tvm::tir::AssertStmt(condition, message, AsStmt(stmts)));
+}
+
+void LetFrameNode::ExitWithScope() {
+ TIRFrameNode::ExitWithScope();
+ AddToParent(tvm::tir::LetStmt(var, value, AsStmt(stmts)));
+}
+
+void RealizeFrameNode::ExitWithScope() {
+ TIRFrameNode::ExitWithScope();
+ AddToParent(tvm::tir::AttrStmt(buffer_slice->buffer, "realize_scope",
+ tvm::tir::StringImm(storage_scope),
+ tvm::tir::BufferRealize(buffer_slice->buffer, buffer_slice->region,
+ condition, AsStmt(stmts))));
+}
+
+void LaunchThreadFrameNode::ExitWithScope() {
+ TIRFrameNode::ExitWithScope();
+ AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts)));
+}
+
TVM_REGISTER_NODE_TYPE(TIRFrameNode);
TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
TVM_REGISTER_NODE_TYPE(BlockFrameNode);
TVM_REGISTER_NODE_TYPE(BlockInitFrameNode);
TVM_REGISTER_NODE_TYPE(ForFrameNode);
+TVM_REGISTER_NODE_TYPE(AssertFrameNode);
+TVM_REGISTER_NODE_TYPE(LetFrameNode);
+TVM_REGISTER_NODE_TYPE(RealizeFrameNode);
+TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode);
} // namespace tir
} // namespace ir_builder
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 75e7592626..5951af298f 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -395,6 +395,67 @@ ForFrame Grid(Array<PrimExpr> extents) {
return ForFrame(n);
}
+AssertFrame Assert(PrimExpr condition, String message) {
+ ObjectPtr<AssertFrameNode> n = make_object<AssertFrameNode>();
+ n->condition = condition;
+ n->message = tvm::tir::StringImm(message);
+ return AssertFrame(n);
+}
+
+LetFrame Let(Var var, PrimExpr value) {
+ ObjectPtr<LetFrameNode> n = make_object<LetFrameNode>();
+ n->var = var;
+ n->value = value;
+ return LetFrame(n);
+}
+
+LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
+ IterVar iter_var{nullptr};
+
+ if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
+ if (Optional<IterVar> opt_iter_var = opt_frame.value()->env_threads.Get(var)) {
+ iter_var = opt_iter_var.value();
+ } else {
+ LOG(FATAL) << "ValueError: " << var->name_hint
+ << " is not an env_thread created using T.env_thread.";
+ }
+ } else {
+ LOG(FATAL) << "LaunchThread can only be used inside a PrimFunc";
+ }
+ ObjectPtr<LaunchThreadFrameNode> n = make_object<LaunchThreadFrameNode>();
+ if (!iter_var->dom.defined()) {
+ const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom = Range(0, extent);
+ } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
+ LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
+ << iter_var->dom->extent << " vs " << extent;
+ }
+ n->iter_var = iter_var;
+ n->extent = extent;
+ n->attr_key = iter_var->thread_tag == "vthread" ? "virtual_thread" : "thread_extent";
+ return LaunchThreadFrame(n);
+}
+
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
+ PrimExpr condition) {
+ ObjectPtr<RealizeFrameNode> n = make_object<RealizeFrameNode>();
+ n->buffer_slice = buffer_slice;
+ n->storage_scope = storage_scope;
+ n->condition = condition;
+ return RealizeFrame(n);
+}
+
+Var EnvThread(String thread_tag) {
+ IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex,
+ thread_tag);
+ Var var = iter_var->var;
+ if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
+ opt_frame.value()->env_threads.Set(var, iter_var);
+ } else {
+ LOG(FATAL) << "EnvThread can only be used inside a PrimFunc";
+ }
+ return var;
+}
+
void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); }
using tvm::script::ir_builder::details::Namer;
@@ -477,6 +538,12 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Unroll").set_body_typed(Unroll);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Let").set_body_typed(Let);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread);
+
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8);
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index a5d8c10680..7f2e6e1a47 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -260,5 +260,74 @@ def test_ir_builder_tir_for():
assert_structural_equal(for_actual, for_expected, map_free_vars=True)
+def test_ir_builder_tir_assert():
+ with IRBuilder() as ib:
+ with T.Assert(T.var("int32", name="a") == 0, message="a is 0"):
+ T.evaluate(0)
+ # the assert generated by IRBuilder
+ assert_actual = ib.get()
+
+ # the expected assert statement
+ assert_expected = tir.AssertStmt(
+ T.var("int32", name="a") == 0, tir.StringImm("a is 0"), tir.Evaluate(0)
+ )
+ # Check if the generated ir is expected
+ assert_structural_equal(assert_actual, assert_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_evaluate():
+ with IRBuilder() as ib:
+ T.evaluate(0)
+ # the evaluate generated by IRBuilder
+ eval_actual = ib.get()
+
+ # the expected evaluate
+ eval_expected = tir.Evaluate(0)
+ # Check if the generated ir is expected
+ assert_structural_equal(eval_actual, eval_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_let():
+ with IRBuilder() as ib:
+ with T.let(T.var("int32", name="a"), tir.IntImm("int32", 2)):
+ T.evaluate(0)
+ # the let binding generated by IRBuilder
+ let_actual = ib.get()
+
+ # the expected Let statement
+ let_expected = tir.LetStmt(T.var("int32", name="a"), tir.IntImm("int32", 2), tir.Evaluate(0))
+ assert_structural_equal(let_actual, let_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_realize():
+ buffer_a = T.buffer_decl((128, 128), "float32")
+ with IRBuilder() as ib:
+ with T.realize(buffer_a[0:128, 0:128], "test_storage_scope", True):
+ T.evaluate(0)
+ realize_actual = ib.get()
+
+ # the expected buffer realization
+ buffer_realize = tir.BufferRealize(
+ buffer_a, [tvm.ir.Range(0, 128), tvm.ir.Range(0, 128)], True, tir.Evaluate(0)
+ )
+ expected_realize = tir.AttrStmt(
+ buffer_a, "realize_scope", tir.StringImm("test_storage_scope"), buffer_realize
+ )
+ assert_structural_equal(realize_actual, expected_realize, map_free_vars=True)
+
+
+def test_ir_builder_tir_thread():
+ with IRBuilder() as ib:
+ with T.prim_func():
+ brow = T.env_thread("blockIdx.y")
+ with T.launch_thread(brow, 1):
+ T.evaluate(0)
+ ir_actual = ib.get()
+ iter_var = tir.IterVar((0, 1), "v", iter_type=1, thread_tag="blockIdx.y")
+ attr_stmt = tir.AttrStmt(iter_var, "thread_extent", 1, tir.Evaluate(0))
+ func = tir.PrimFunc([], attr_stmt)
+ assert_structural_equal(ir_actual, func, map_free_vars=True)
+
+
if __name__ == "__main__":
tvm.testing.main()