Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/05 00:48:24 UTC

[tvm] branch main updated: [TIR] Schedule Primitive: Add-Unit-Loop (#11575)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9d2c9a7f64 [TIR] Schedule Primitive: Add-Unit-Loop (#11575)
9d2c9a7f64 is described below

commit 9d2c9a7f6457fb98156a722625c95bf3383dec42
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Sat Jun 4 17:48:19 2022 -0700

    [TIR] Schedule Primitive: Add-Unit-Loop (#11575)
    
    In TE, a unit loop can be introduced by fusing an empty list of loops on a stage. This PR adds its counterpart in TIR, in a slightly more explicit form: a new schedule primitive that adds a unit loop without impacting any existing functionality.
---
 include/tvm/tir/schedule/schedule.h                | 12 ++++
 python/tvm/tir/schedule/schedule.py                | 64 ++++++++++++++++++--
 src/tir/schedule/concrete_schedule.cc              | 18 ++++++
 src/tir/schedule/concrete_schedule.h               |  2 +
 src/tir/schedule/primitive.h                       | 10 ++++
 src/tir/schedule/primitive/loop_transformation.cc  | 69 ++++++++++++++++++++++
 src/tir/schedule/schedule.cc                       | 12 ++++
 src/tir/schedule/traced_schedule.cc                | 22 +++++++
 src/tir/schedule/traced_schedule.h                 |  2 +
 .../unittest/test_tir_schedule_split_fuse.py       | 58 ++++++++++++++++++
 10 files changed, 265 insertions(+), 4 deletions(-)
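
For a quick sense of the new user-facing API, here is a minimal usage sketch,
condensed from the docstring example this commit adds to
python/tvm/tir/schedule/schedule.py (the zero-dim PrimFunc is only an
illustrative input):

    from tvm import tir
    from tvm.script import tir as T

    @T.prim_func
    def before_add_unit_loop(
        A: T.Buffer[(), "int32"],
        B: T.Buffer[(), "int32"],
        C: T.Buffer[(), "int32"],
    ) -> None:
        with T.block("C"):
            vi = T.axis.spatial(1, 0)
            C[()] = A[()] + B[()]

    # Wrap block "C" in a new length-1 serial loop. The returned LoopRV can
    # feed into later schedule steps, just like the output of split/fuse.
    sch = tir.Schedule(before_add_unit_loop)
    new_loop = sch.add_unit_loop(sch.get_block("C"))
    print(sch.mod["main"].script())  # block "C" now sits under "for u in T.serial(1)"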

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 68900e107d..d3ecd8a113 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -303,6 +303,18 @@ class ScheduleNode : public runtime::Object {
    * \param ordered_loop_rvs The loops in the new order
    */
   virtual void Reorder(const Array<LoopRV>& ordered_loop_rvs) = 0;
+  /*!
+   * \brief Create a new unit loop on top of the specified block.
+   * \param block_rv The block above which the new loop is created
+   * \return The new loop created
+   */
+  virtual LoopRV AddUnitLoop(const BlockRV& block_rv) = 0;
+  /*!
+   * \brief Create a new unit loop on top of the specified loop.
+   * \param loop_rv The loop above which the new loop is created
+   * \return The new loop created
+   */
+  virtual LoopRV AddUnitLoop(const LoopRV& loop_rv) = 0;
   /******** Schedule: Manipulate ForKind ********/
   /*!
    * \brief Parallelize the input loop. It requires:
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 4179088aa5..d225280b65 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -15,19 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 """The TensorIR schedule class"""
-from typing import Callable, Dict, List, Optional, Union, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 from tvm._ffi import register_object as _register_object
 from tvm.error import TVMError, register_error
 from tvm.ir import IRModule, PrimExpr
 from tvm.runtime import Object, String
-from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, Buffer
-from ..function import IndexMap
+from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
 
+from ..function import IndexMap
 from . import _ffi_api
+from ._type_checker import type_checked
 from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
 from .trace import Trace
-from ._type_checker import type_checked
 
 
 @register_error
@@ -685,6 +685,62 @@ class Schedule(Object):
         """
         _ffi_api.ScheduleReorder(self, ordered_loops)  # type: ignore # pylint: disable=no-member
 
+    @type_checked
+    def add_unit_loop(self, block_or_loop: Union[LoopRV, BlockRV]) -> LoopRV:
+        """Create a new unit loop on top of the specific block or loop.
+
+        Parameters
+        ----------
+        block_or_loop : Union[LoopRV, BlockRV]
+            The block or loop above which the new loop is created
+
+        Returns
+        -------
+        new_loop : LoopRV
+            The new unit loop
+
+        Examples
+        --------
+
+        Before add_unit_loop, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_add_unit_loop(
+                A: T.Buffer[(), "int32"],
+                B: T.Buffer[(), "int32"],
+                C: T.Buffer[(), "int32"],
+            ) -> None:
+                with T.block("C"):
+                    vi = T.axis.spatial(1, 0)
+                    C[()] = A[()] + B[()]
+
+        Create the schedule and do add-unit-loop:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_add_unit_loop)
+            sch.add_unit_loop(sch.get_block("C"))
+            print(sch.mod["main"].script())
+
+        After applying add-unit-loop, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_add_unit_loop(
+                A: T.Buffer[(), "int32"],
+                B: T.Buffer[(), "int32"],
+                C: T.Buffer[(), "int32"],
+            ) -> None:
+                for u in T.serial(1):
+                    with T.block("C"):
+                        vi = T.axis.spatial(1, 0)
+                        C[()] = A[()] + B[()]
+        """
+        return _ffi_api.ScheduleAddUnitLoop(self, block_or_loop)  # type: ignore # pylint: disable=no-member
+
     ########## Schedule: Manipulate ForKind ##########
 
     @type_checked
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 590a0f0025..051bd42506 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -453,6 +453,24 @@ void ConcreteScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) {
   this->state_->DebugVerify();
 }
 
+LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) {
+  LoopRV result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = CreateRV<LoopRV>(tir::AddUnitLoop(state_, GetSRef(block_rv)));
+  TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_);
+  this->state_->DebugVerify();
+  return result;
+}
+
+LoopRV ConcreteScheduleNode::AddUnitLoop(const LoopRV& loop_rv) {
+  LoopRV result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = CreateRV<LoopRV>(tir::AddUnitLoop(state_, GetSRef(loop_rv)));
+  TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_);
+  this->state_->DebugVerify();
+  return result;
+}
+
 /******** Schedule: Manipulate ForKind ********/
 
 void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) {
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 70c0265611..11d68694a1 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -99,6 +99,8 @@ class ConcreteScheduleNode : public ScheduleNode {
   LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
   Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
   void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;
+  LoopRV AddUnitLoop(const BlockRV& block_rv) override;
+  LoopRV AddUnitLoop(const LoopRV& loop_rv) override;
   /******** Schedule: Manipulate ForKind ********/
   void Parallel(const LoopRV& loop_rv) override;
   void Vectorize(const LoopRV& loop_rv) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index f4dba69c6b..af0f417e4c 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -186,6 +186,16 @@ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs);
  */
 TVM_DLL void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs);
 
+/*!
+ * \brief Create a new unit loop on top of the specified block or loop.
+ * \param self The state of the schedule
+ * \param sref The block/loop above which the new unit loop is created
+ * \return The sref to the newly created unit loop
+ */
+TVM_DLL StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref);
+
 /******** Schedule: Manipulate ForKind ********/
 /*!
  * \brief Parallelize the input loop. It requires:
diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc
index 5315b139f0..66e29518ca 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -698,6 +698,43 @@ void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
   self->Replace(GetRef<StmtSRef>(top), new_loop, {});
 }
 
+StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) {
+  if (sref->stmt->IsInstance<ForNode>()) {
+    For new_loop(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<Stmt>(sref->stmt));
+    self->Replace(sref, new_loop, {});
+    return self->stmt2ref.at(new_loop.get());
+  }
+  class NewLoopCreator : public StmtMutator {
+   public:
+    explicit NewLoopCreator(const StmtNode* src_block) : src_block_(src_block) {}
+
+    Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+      if (realize->block.get() == src_block_) {
+        new_loop_ =
+            For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<BlockRealize>(realize));
+        return new_loop_;
+      }
+      return StmtMutator::VisitStmt_(realize);
+    }
+
+    const StmtNode* src_block_;
+    For new_loop_{nullptr};
+  };
+
+  CHECK(sref->parent != nullptr) << "ValueError: Cannot add loops on top of the root block";
+  StmtSRef parent_sref = GetRef<StmtSRef>(sref->parent);
+  NewLoopCreator creator(sref->stmt);
+  Stmt new_stmt = creator(GetRef<Stmt>(parent_sref->stmt));
+  if (new_stmt->IsInstance<ForNode>()) {
+    self->Replace(parent_sref, std::move(new_stmt), {});
+  } else {
+    Block old_parent_block = GetRef<Block>(parent_sref->StmtAs<BlockNode>());
+    Block new_parent_block = Downcast<Block>(new_stmt);
+    self->Replace(parent_sref, new_stmt, {{old_parent_block, new_parent_block}});
+  }
+  return self->stmt2ref.at(creator.new_loop_.get());
+}
+
 /******** InstructionKind Registration ********/
 
 struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
@@ -800,9 +837,41 @@ struct ReorderTraits : public UnpackedInstTraits<ReorderTraits> {
   friend struct ::tvm::tir::UnpackedInstTraits;
 };
 
+struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {
+  static constexpr const char* kName = "AddUnitLoop";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  static LoopRV UnpackedApplyToSchedule(Schedule sch, ObjectRef rv) {
+    if (const auto* block = rv.as<BlockRVNode>()) {
+      return sch->AddUnitLoop(GetRef<BlockRV>(block));
+    } else if (const auto* loop = rv.as<LoopRVNode>()) {
+      return sch->AddUnitLoop(GetRef<LoopRV>(loop));
+    } else {
+      LOG(FATAL) << "TypeError: AddUnitLoop expects a loop or block";
+      throw;
+    }
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String rv) {
+    PythonAPICall py("add_unit_loop");
+    py.Input("block_or_loop", rv);
+    py.SingleOutput(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
 TVM_REGISTER_INST_KIND_TRAITS(SplitTraits);
 TVM_REGISTER_INST_KIND_TRAITS(FuseTraits);
 TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits);
+TVM_REGISTER_INST_KIND_TRAITS(AddUnitLoopTraits);
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 3880d0b19e..372d94a150 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -153,6 +153,18 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&Sche
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder")
     .set_body_method<Schedule>(&ScheduleNode::Reorder);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop")
+    .set_body_typed([](Schedule self, ObjectRef rv) -> LoopRV {
+      if (const auto* loop_rv = rv.as<LoopRVNode>()) {
+        return self->AddUnitLoop(GetRef<LoopRV>(loop_rv));
+      } else if (const auto* block_rv = rv.as<BlockRVNode>()) {
+        return self->AddUnitLoop(GetRef<BlockRV>(block_rv));
+      } else {
+        LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
+                   << ". Its value is: " << rv;
+        throw;
+      }
+    });
 /******** (FFI) Manipulate ForKind ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel")
     .set_body_method<Schedule>(&ScheduleNode::Parallel);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index d2f627edfd..95a10e26ac 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -198,6 +198,28 @@ void TracedScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) {
                                       /*outputs=*/{}));
 }
 
+LoopRV TracedScheduleNode::AddUnitLoop(const BlockRV& block_rv) {
+  LoopRV result = ConcreteScheduleNode::AddUnitLoop(block_rv);
+
+  static const InstructionKind& kind = InstructionKind::Get("AddUnitLoop");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+                                      /*inputs=*/{block_rv},
+                                      /*attrs=*/{},
+                                      /*outputs=*/{result}));
+  return result;
+}
+
+LoopRV TracedScheduleNode::AddUnitLoop(const LoopRV& loop_rv) {
+  LoopRV result = ConcreteScheduleNode::AddUnitLoop(loop_rv);
+
+  static const InstructionKind& kind = InstructionKind::Get("AddUnitLoop");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+                                      /*inputs=*/{loop_rv},
+                                      /*attrs=*/{},
+                                      /*outputs=*/{result}));
+  return result;
+}
+
 /******** Schedule: Manipulate ForKind ********/
 
 void TracedScheduleNode::Parallel(const LoopRV& loop_rv) {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index ba4a4b99cb..25bf3d4871 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -63,6 +63,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
   Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
   void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;
+  LoopRV AddUnitLoop(const BlockRV& block_rv) final;
+  LoopRV AddUnitLoop(const LoopRV& loop_rv) final;
   /******** Schedule: Manipulate ForKind ********/
   void Parallel(const LoopRV& loop_rv) final;
   void Vectorize(const LoopRV& loop_rv) final;
diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py
index 16eef57c47..d70748bc8a 100644
--- a/tests/python/unittest/test_tir_schedule_split_fuse.py
+++ b/tests/python/unittest/test_tir_schedule_split_fuse.py
@@ -524,5 +524,63 @@ def test_fuse_not_affine():
     verify_trace_roundtrip(sch=sch, mod=elementwise_not_affine)
 
 
+def test_add_unit_loop_above_block():
+    @T.prim_func
+    def zero_dim(
+        A: T.Buffer[(), "int32"],
+        B: T.Buffer[(), "int32"],
+        C: T.Buffer[(), "int32"],
+    ) -> None:
+        with T.block("C"):
+            vi = T.axis.spatial(1, 0)
+            C[()] = A[()] + B[()]
+
+    @T.prim_func
+    def zero_dim_added(
+        A: T.Buffer[(), "int32"],
+        B: T.Buffer[(), "int32"],
+        C: T.Buffer[(), "int32"],
+    ) -> None:
+        for u in range(1):
+            with T.block("C"):
+                vi = T.axis.spatial(1, 0)
+                C[()] = A[()] + B[()]
+
+    sch = tir.Schedule(zero_dim, debug_mask="all")
+    block = sch.get_block("C")
+    sch.add_unit_loop(block)
+    tvm.ir.assert_structural_equal(zero_dim_added, sch.mod["main"])
+
+
+def test_add_unit_loop_above_loop():
+    @T.prim_func
+    def zero_dim(
+        A: T.Buffer[(), "int32"],
+        B: T.Buffer[(), "int32"],
+        C: T.Buffer[(), "int32"],
+    ) -> None:
+        for u in range(1):
+            with T.block("C"):
+                vi = T.axis.spatial(1, 0)
+                C[()] = A[()] + B[()]
+
+    @T.prim_func
+    def zero_dim_added(
+        A: T.Buffer[(), "int32"],
+        B: T.Buffer[(), "int32"],
+        C: T.Buffer[(), "int32"],
+    ) -> None:
+        for u1, u2 in T.grid(1, 1):
+            with T.block("C"):
+                vi = T.axis.spatial(1, 0)
+                C[()] = A[()] + B[()]
+
+    sch = tir.Schedule(zero_dim, debug_mask="all")
+    block = sch.get_block("C")
+    (loop,) = sch.get_loops(block)
+    sch.add_unit_loop(loop)
+    tvm.ir.assert_structural_equal(zero_dim_added, sch.mod["main"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()