You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/18 02:18:07 UTC

[tvm] branch main updated: [TVMScript] IRBuilder methods for `Stmt` (#12830)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b2c5addbb4 [TVMScript] IRBuilder methods for `Stmt` (#12830)
b2c5addbb4 is described below

commit b2c5addbb4e92aa770f0cd0847eabb43400ac9d2
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Sat Sep 17 19:18:01 2022 -0700

    [TVMScript] IRBuilder methods for `Stmt` (#12830)
    
    This PR introduces IRBuilder methods for `Assert`, `Let`, `Realize`, `Evaluate`, `LaunchThread`, `EnvThread`.
    
    Co-authored-by: yongwww <yo...@gmail.com>
---
 include/tvm/script/ir_builder/tir/frame.h          | 132 +++++++++++++++++++++
 include/tvm/script/ir_builder/tir/ir.h             |  40 +++++++
 python/tvm/script/ir_builder/tir/frame.py          |  20 ++++
 python/tvm/script/ir_builder/tir/ir.py             | 131 ++++++++++++++++++++
 src/script/ir_builder/tir/frame.cc                 |  27 +++++
 src/script/ir_builder/tir/ir.cc                    |  67 +++++++++++
 .../unittest/test_tvmscript_ir_builder_tir.py      |  69 +++++++++++
 7 files changed, 486 insertions(+)

diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
index c76b400d96..38fe9009dd 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -303,6 +303,138 @@ class AssertFrameNode : public TIRFrameNode {
   void ExitWithScope() final;
 };
 
+/*!
+ * \brief Managed reference to AssertFrameNode.
+ *
+ * \sa AssertFrameNode
+ */
+class AssertFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode);
+};
+
+/*!
+ * \brief A frame that represents the let binding expression, which binds a var.
+ *
+ * \sa LetFrameNode
+ */
+class LetFrameNode : public TIRFrameNode {
+ public:
+  /*! \brief The variable we bind to */
+  tvm::tir::Var var;
+  /*! \brief The value we bind var to */
+  PrimExpr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("var", &var);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.LetFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief The method called when exiting RAII scope.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to LetFrameNode.
+ *
+ * \sa LetFrameNode
+ */
+class LetFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode);
+};
+
+/*!
+ * \brief The LaunchThreadFrameNode.
+ * \note It is used only inside a PrimFunc.
+ */
+class LaunchThreadFrameNode : public TIRFrameNode {
+ public:
+  /*! \brief The extent of environment thread. */
+  PrimExpr extent;
+  /*! \brief The attribute key, could be either virtual_thread or thread_extent. */
+  String attr_key;
+  /*! \brief The iteration variable. */
+  tvm::tir::IterVar iter_var;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("extent", &extent);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("iter_var", &iter_var);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.LaunchThreadFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief The method called when exiting RAII scope.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to LaunchThreadFrameNode.
+ *
+ * \sa LaunchThreadFrameNode
+ */
+class LaunchThreadFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame,
+                                                    LaunchThreadFrameNode);
+};
+
+/*!
+ * \brief A frame that represents realization.
+ *
+ * \sa RealizeFrame
+ */
+class RealizeFrameNode : public TIRFrameNode {
+ public:
+  /*! \brief The region of buffer access. */
+  tvm::tir::BufferRegion buffer_slice;
+  /*! \brief The storage scope associated with this realization. */
+  String storage_scope;
+  /*! \brief The condition expression. */
+  PrimExpr condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("buffer_slice", &buffer_slice);
+    v->Visit("storage_scope", &storage_scope);
+    v->Visit("condition", &condition);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.RealizeFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief The method called when exiting RAII scope.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to RealizeFrameNode.
+ *
+ * \sa RealizeFrameNode
+ */
+class RealizeFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
+};
 }  // namespace tir
 }  // namespace ir_builder
 }  // namespace script
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 191887648d..ec1f7f3753 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -292,6 +292,46 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
  */
 ForFrame Grid(Array<PrimExpr> extents);
 
+/*!
+ * \brief The assertion statement.
+ * \param condition The assertion condition.
+ * \param message The error message when the assertion fails.
+ * \return The AssertFrame.
+ */
+AssertFrame Assert(PrimExpr condition, String message);
+
+/*!
+ * \brief The let binding.
+ * \param var The variable to bind.
+ * \param value The value to be bound.
+ * \return The created LetFrame.
+ */
+LetFrame Let(Var var, PrimExpr value);
+
+/*!
+ * \brief The realization.
+ * \param buffer_slice The region of buffer access.
+ * \param storage_scope The storage scope associated with this realization.
+ * \param condition The condition expression.
+ * \return The result RealizeFrame.
+ */
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition);
+
+/*!
+ * \brief Launch a thread.
+ * \param var The iteration variable.
+ * \param extent The extent of environment thread.
+ * \return The result LaunchThreadFrame.
+ */
+LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);
+
+/*!
+ * \brief Bind a var to thread env.
+ * \param thread_tag The thread type tag.
+ * \return The result variable which gets bound to the thread env.
+ */
+Var EnvThread(String thread_tag);
+
 /*!
  * \brief Evaluate the input expression.
  * \param value The input expression to evaluate.
diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py
index 2ad08f3516..69bc5bfc96 100644
--- a/python/tvm/script/ir_builder/tir/frame.py
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -48,3 +48,23 @@ class ForFrame(TIRFrame):
     def __enter__(self) -> Union[Var, List[Var]]:  # type: ignore[override]
         super().__enter__()
         return self.vars if len(self.vars) > 1 else self.vars[0]
+
+
+@_register_object("script.ir_builder.tir.AssertFrame")
+class AssertFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.LetFrame")
+class LetFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.RealizeFrame")
+class RealizeFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.LaunchThreadFrame")
+class LaunchThreadFrame(TIRFrame):
+    ...
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index d1dc1c8960..6db8f40c32 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -26,6 +26,8 @@ from tvm.tir import (
     BufferLoad,
     BufferRegion,
     IntImm,
+    IterVar,
+    Let,
     PrimExpr,
     StringImm,
     Var,
@@ -813,6 +815,130 @@ def grid(*extents: PrimExpr) -> frame.ForFrame:
     return _ffi_api.Grid(extents)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
+def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame:  # pylint: disable=invalid-name
+    """Create an assertion statement.
+
+    Parameters
+    ----------
+    condition : PrimExpr
+        The PrimExpr to test.
+
+    message : str
+        The output error message when the assertion fails.
+
+    Returns
+    -------
+    res : frame.AssertFrame
+        The result AssertFrame.
+    """
+    return _ffi_api.Assert(condition, message)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def let(
+    v: Var,
+    value: PrimExpr,
+    body: PrimExpr = None,
+) -> frame.LetFrame:
+    """Create a new let binding.
+
+    Parameters
+    ----------
+    v : Var
+        The variable to bind.
+
+    value : PrimExpr
+        The value to be bound.
+
+    body : PrimExpr
+        The body expression, None will be used if it was not specified.
+
+    Returns
+    -------
+    res : frame.LetFrame
+        The result LetFrame, or a `Let` expression if `body` is specified.
+    """
+    if body is None:
+        return _ffi_api.Let(v, value)  # type: ignore[attr-defined] # pylint: disable=no-member
+    return Let(v, value, body)
+
+
+def realize(
+    buffer_slice: BufferRegion,
+    storage_scope: str,
+    condition: PrimExpr = True,
+) -> frame.RealizeFrame:
+    """Create a realization.
+
+    Parameters
+    ----------
+    buffer_slice : BufferRegion
+        The region of buffer access.
+
+    storage_scope : str
+        The storage scope associated with this realization.
+
+    condition : PrimExpr
+        The condition expression, the default is True.
+
+    Returns
+    -------
+    res : frame.RealizeFrame
+        The result RealizeFrame.
+    """
+    return _ffi_api.Realize(  # type: ignore[attr-defined] # pylint: disable=no-member
+        buffer_slice, storage_scope, condition
+    )
+
+
+def launch_thread(
+    iter_var: IterVar,  # pylint: disable=redefined-outer-name
+    extent: PrimExpr,
+) -> frame.LaunchThreadFrame:
+    """Launch a thread.
+
+    Parameters
+    ----------
+    iter_var : IterVar
+        The iteration variable.
+
+    extent : PrimExpr
+        The extent of environment thread.
+
+    Returns
+    -------
+    res : frame.LaunchThreadFrame
+        The result LaunchThreadFrame.
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        from tvm.script.ir_builder import tir as T
+        brow = T.env_thread("blockIdx.y")
+        T.launch_thread(brow, 1)
+
+    """
+    return _ffi_api.LaunchThread(iter_var, extent)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def env_thread(thread_tag: str) -> IterVar:
+    """Bind a var to thread env.
+
+    Parameters
+    ----------
+    thread_tag : str
+        The thread type tag.
+
+    Returns
+    -------
+    res : IterVar
+        The result iteration variable which gets bound to the thread env.
+
+    """
+    return _ffi_api.EnvThread(thread_tag)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
 def evaluate(value: PrimExpr) -> None:
     """Evaluate the input expression.
 
@@ -1159,6 +1285,11 @@ __all__ = [
     "unroll",
     "thread_binding",
     "grid",
+    "Assert",
+    "let",
+    "realize",
+    "launch_thread",
+    "env_thread",
     "evaluate",
     "int8",
     "int16",
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index 8b8b2a4d80..6c9459e638 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -92,11 +92,38 @@ void ForFrameNode::ExitWithScope() {
   AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts)));
 }
 
+void AssertFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(tvm::tir::AssertStmt(condition, message, AsStmt(stmts)));
+}
+
+void LetFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(tvm::tir::LetStmt(var, value, AsStmt(stmts)));
+}
+
+void RealizeFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(tvm::tir::AttrStmt(buffer_slice->buffer, "realize_scope",
+                                 tvm::tir::StringImm(storage_scope),
+                                 tvm::tir::BufferRealize(buffer_slice->buffer, buffer_slice->region,
+                                                         condition, AsStmt(stmts))));
+}
+
+void LaunchThreadFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts)));
+}
+
 TVM_REGISTER_NODE_TYPE(TIRFrameNode);
 TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
 TVM_REGISTER_NODE_TYPE(BlockFrameNode);
 TVM_REGISTER_NODE_TYPE(BlockInitFrameNode);
 TVM_REGISTER_NODE_TYPE(ForFrameNode);
+TVM_REGISTER_NODE_TYPE(AssertFrameNode);
+TVM_REGISTER_NODE_TYPE(LetFrameNode);
+TVM_REGISTER_NODE_TYPE(RealizeFrameNode);
+TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode);
 
 }  // namespace tir
 }  // namespace ir_builder
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 75e7592626..5951af298f 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -395,6 +395,67 @@ ForFrame Grid(Array<PrimExpr> extents) {
   return ForFrame(n);
 }
 
+AssertFrame Assert(PrimExpr condition, String message) {
+  ObjectPtr<AssertFrameNode> n = make_object<AssertFrameNode>();
+  n->condition = condition;
+  n->message = tvm::tir::StringImm(message);
+  return AssertFrame(n);
+}
+
+LetFrame Let(Var var, PrimExpr value) {
+  ObjectPtr<LetFrameNode> n = make_object<LetFrameNode>();
+  n->var = var;
+  n->value = value;
+  return LetFrame(n);
+}
+
+LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
+  IterVar iter_var{nullptr};
+
+  if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
+    if (Optional<IterVar> opt_iter_var = opt_frame.value()->env_threads.Get(var)) {
+      iter_var = opt_iter_var.value();
+    } else {
+      LOG(FATAL) << "ValueError: " << var->name_hint
+                 << " is not an env_thread created using T.env_thread.";
+    }
+  } else {
+    LOG(FATAL) << "LaunchThread can only be used inside a PrimFunc";
+  }
+  ObjectPtr<LaunchThreadFrameNode> n = make_object<LaunchThreadFrameNode>();
+  if (!iter_var->dom.defined()) {
+    const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom = Range(0, extent);
+  } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
+    LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
+               << iter_var->dom->extent << " vs " << extent;
+  }
+  n->iter_var = iter_var;
+  n->extent = extent;
+  n->attr_key = iter_var->thread_tag == "vthread" ? "virtual_thread" : "thread_extent";
+  return LaunchThreadFrame(n);
+}
+
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
+                     PrimExpr condition) {
+  ObjectPtr<RealizeFrameNode> n = make_object<RealizeFrameNode>();
+  n->buffer_slice = buffer_slice;
+  n->storage_scope = storage_scope;
+  n->condition = condition;
+  return RealizeFrame(n);
+}
+
+Var EnvThread(String thread_tag) {
+  IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex,
+                   thread_tag);
+  Var var = iter_var->var;
+  if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
+    opt_frame.value()->env_threads.Set(var, iter_var);
+  } else {
+    LOG(FATAL) << "EnvThread can only be used inside a PrimFunc";
+  }
+  return var;
+}
+
 void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); }
 
 using tvm::script::ir_builder::details::Namer;
@@ -477,6 +538,12 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Unroll").set_body_typed(Unroll);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid);
 
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Let").set_body_typed(Let);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread);
+
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate);
 
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8);
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index a5d8c10680..7f2e6e1a47 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -260,5 +260,74 @@ def test_ir_builder_tir_for():
     assert_structural_equal(for_actual, for_expected, map_free_vars=True)
 
 
+def test_ir_builder_tir_assert():
+    with IRBuilder() as ib:
+        with T.Assert(T.var("int32", name="a") == 0, message="a is 0"):
+            T.evaluate(0)
+    # the assert generated by IRBuilder
+    assert_actual = ib.get()
+
+    # the expected assert statement
+    assert_expected = tir.AssertStmt(
+        T.var("int32", name="a") == 0, tir.StringImm("a is 0"), tir.Evaluate(0)
+    )
+    # Check if the generated ir is expected
+    assert_structural_equal(assert_actual, assert_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_evaluate():
+    with IRBuilder() as ib:
+        T.evaluate(0)
+    # the evaluate generated by IRBuilder
+    eval_actual = ib.get()
+
+    # the expected evaluate
+    eval_expected = tir.Evaluate(0)
+    # Check if the generated ir is expected
+    assert_structural_equal(eval_actual, eval_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_let():
+    with IRBuilder() as ib:
+        with T.let(T.var("int32", name="a"), tir.IntImm("int32", 2)):
+            T.evaluate(0)
+    # the let binding generated by IRBuilder
+    let_actual = ib.get()
+
+    # the expected Let statement
+    let_expected = tir.LetStmt(T.var("int32", name="a"), tir.IntImm("int32", 2), tir.Evaluate(0))
+    assert_structural_equal(let_actual, let_expected, map_free_vars=True)
+
+
+def test_ir_builder_tir_realize():
+    buffer_a = T.buffer_decl((128, 128), "float32")
+    with IRBuilder() as ib:
+        with T.realize(buffer_a[0:128, 0:128], "test_storage_scope", True):
+            T.evaluate(0)
+    realize_actual = ib.get()
+
+    # the expected buffer realization
+    buffer_realize = tir.BufferRealize(
+        buffer_a, [tvm.ir.Range(0, 128), tvm.ir.Range(0, 128)], True, tir.Evaluate(0)
+    )
+    expected_realize = tir.AttrStmt(
+        buffer_a, "realize_scope", tir.StringImm("test_storage_scope"), buffer_realize
+    )
+    assert_structural_equal(realize_actual, expected_realize, map_free_vars=True)
+
+
+def test_ir_builder_tir_thread():
+    with IRBuilder() as ib:
+        with T.prim_func():
+            brow = T.env_thread("blockIdx.y")
+            with T.launch_thread(brow, 1):
+                T.evaluate(0)
+    ir_actual = ib.get()
+    iter_var = tir.IterVar((0, 1), "v", iter_type=1, thread_tag="blockIdx.y")
+    attr_stmt = tir.AttrStmt(iter_var, "thread_extent", 1, tir.Evaluate(0))
+    func = tir.PrimFunc([], attr_stmt)
+    assert_structural_equal(ir_actual, func, map_free_vars=True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()