You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/05 00:48:24 UTC
[tvm] branch main updated: [TIR] Schedule Primitive: Add-Unit-Loop (#11575)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9d2c9a7f64 [TIR] Schedule Primitive: Add-Unit-Loop (#11575)
9d2c9a7f64 is described below
commit 9d2c9a7f6457fb98156a722625c95bf3383dec42
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Sat Jun 4 17:48:19 2022 -0700
[TIR] Schedule Primitive: Add-Unit-Loop (#11575)
In TE, a unit loop could be introduced by fusing an empty list of loops on a stage. This PR adds the TIR counterpart but makes it more explicit: a new schedule primitive that adds a unit loop without affecting any existing functionality.
---
include/tvm/tir/schedule/schedule.h | 12 ++++
python/tvm/tir/schedule/schedule.py | 64 ++++++++++++++++++--
src/tir/schedule/concrete_schedule.cc | 18 ++++++
src/tir/schedule/concrete_schedule.h | 2 +
src/tir/schedule/primitive.h | 10 ++++
src/tir/schedule/primitive/loop_transformation.cc | 69 ++++++++++++++++++++++
src/tir/schedule/schedule.cc | 12 ++++
src/tir/schedule/traced_schedule.cc | 22 +++++++
src/tir/schedule/traced_schedule.h | 2 +
.../unittest/test_tir_schedule_split_fuse.py | 58 ++++++++++++++++++
10 files changed, 265 insertions(+), 4 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 68900e107d..d3ecd8a113 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -303,6 +303,18 @@ class ScheduleNode : public runtime::Object {
* \param ordered_loop_rvs The loops in the new order
*/
virtual void Reorder(const Array<LoopRV>& ordered_loop_rvs) = 0;
+ /*!
+ * \brief Create a new unit loop (extent 1) directly above the specified block.
+ * \param block_rv The block above which the new loop is created; must not be the root block
+ * \return The loop random variable referring to the newly created unit loop
+ */
+ virtual LoopRV AddUnitLoop(const BlockRV& block_rv) = 0;
+ /*!
+ * \brief Create a new unit loop (extent 1) directly above the specified loop.
+ * \param loop_rv The loop above which the new loop is created; it becomes the new loop's body
+ * \return The loop random variable referring to the newly created unit loop
+ */
+ virtual LoopRV AddUnitLoop(const LoopRV& loop_rv) = 0;
/******** Schedule: Manipulate ForKind ********/
/*!
* \brief Parallelize the input loop. It requires:
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 4179088aa5..d225280b65 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -15,19 +15,19 @@
# specific language governing permissions and limitations
# under the License.
"""The TensorIR schedule class"""
-from typing import Callable, Dict, List, Optional, Union, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Union
from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object, String
-from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, Buffer
-from ..function import IndexMap
+from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
+from ..function import IndexMap
from . import _ffi_api
+from ._type_checker import type_checked
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
from .trace import Trace
-from ._type_checker import type_checked
@register_error
@@ -685,6 +685,62 @@ class Schedule(Object):
"""
_ffi_api.ScheduleReorder(self, ordered_loops) # type: ignore # pylint: disable=no-member
+ @type_checked
+ def add_unit_loop(self, block_or_loop: Union[LoopRV, BlockRV]) -> LoopRV:
+ """Create a new unit loop on top of the specific block or loop.
+
+ Parameters
+ ----------
+ block_or_loop : Union[LoopRV, BlockRV]
+ The block above which the new loop is created
+
+ Returns
+ -------
+ new_loop : LoopRV
+ The new unit loop
+
+ Examples
+ --------
+
+ Before add_unit_loop, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def before_add_unit_loop(
+ A: T.Buffer[(), "int32"],
+ B: T.Buffer[(), "int32"],
+ C: T.Buffer[(), "int32"],
+ ) -> None:
+ with T.block("C"):
+ vi = T.axis.spatial(1, 0)
+ C[()] = A[()] + B[()]
+
+ Create the schedule and do add-unit-loop:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_add_unit_loop)
+ sch.add_unit_loop(sch.get_block("C"))
+ print(sch.mod["main"].script())
+
+ After applying add-unit-loop, the IR becomes:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def after_add_unit_loop(
+ A: T.Buffer[(), "int32"],
+ B: T.Buffer[(), "int32"],
+ C: T.Buffer[(), "int32"],
+ ) -> None:
+ for u in T.serial(1):
+ with T.block("C"):
+ vi = T.axis.spatial(1, 0)
+ C[()] = A[()] + B[()]
+ """
+ return _ffi_api.ScheduleAddUnitLoop(self, block_or_loop) # type: ignore # pylint: disable=no-member
+
########## Schedule: Manipulate ForKind ##########
@type_checked
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 590a0f0025..051bd42506 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -453,6 +453,24 @@ void ConcreteScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) {
this->state_->DebugVerify();
}
+LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) {  // Block overload: add a unit loop above `block_rv`
+ LoopRV result{nullptr};  // assigned inside the guarded region below
+ TVM_TIR_SCHEDULE_BEGIN();
+ result = CreateRV<LoopRV>(tir::AddUnitLoop(state_, GetSRef(block_rv)));  // run the primitive on the block's sref; wrap the returned loop sref in a fresh LoopRV
+ TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_);  // "add-unit-loop" presumably labels errors raised by the primitive
+ this->state_->DebugVerify();
+ return result;
+}
+
+LoopRV ConcreteScheduleNode::AddUnitLoop(const LoopRV& loop_rv) {  // Loop overload: identical, except the input sref is resolved from a LoopRV
+ LoopRV result{nullptr};
+ TVM_TIR_SCHEDULE_BEGIN();
+ result = CreateRV<LoopRV>(tir::AddUnitLoop(state_, GetSRef(loop_rv)));
+ TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_);
+ this->state_->DebugVerify();
+ return result;
+}
+
/******** Schedule: Manipulate ForKind ********/
void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) {
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 70c0265611..11d68694a1 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -99,6 +99,8 @@ class ConcreteScheduleNode : public ScheduleNode {
LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;
+ LoopRV AddUnitLoop(const BlockRV& block_rv) override;  // adds a unit loop above a block
+ LoopRV AddUnitLoop(const LoopRV& loop_rv) override;  // adds a unit loop above a loop
/******** Schedule: Manipulate ForKind ********/
void Parallel(const LoopRV& loop_rv) override;
void Vectorize(const LoopRV& loop_rv) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index f4dba69c6b..af0f417e4c 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -186,6 +186,16 @@ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs);
*/
TVM_DLL void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs);
+/*!
+ * \brief Create a new unit loop on top of the specific block or loop.
+ * \param self The schedule state
+ * \param sref The block/loop above which the new unit loop is created
+ * \return The sref to the new unit loop, a serial loop with extent 1
+ * \note An error is raised when `sref` refers to the root block
+ * \note The new loop variable is named "u" and has dtype int32
+ */
+TVM_DLL StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref);
+
/******** Schedule: Manipulate ForKind ********/
/*!
* \brief Parallelize the input loop. It requires:
diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc
index 5315b139f0..66e29518ca 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -698,6 +698,43 @@ void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
self->Replace(GetRef<StmtSRef>(top), new_loop, {});
}
+StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) {  // wraps the block/loop at `sref` in a new serial loop `u` of extent 1; returns the new loop's sref
+ if (sref->stmt->IsInstance<ForNode>()) {  // loop case: the old loop becomes the body of the new loop
+ For new_loop(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<Stmt>(sref->stmt));
+ self->Replace(sref, new_loop, {});
+ return self->stmt2ref.at(new_loop.get());  // sref of the newly inserted loop
+ }
+ class NewLoopCreator : public StmtMutator {  // block case: rewrites the parent, wrapping the target block's BlockRealize in a unit loop
+ public:
+ explicit NewLoopCreator(const StmtNode* src_block) : src_block_(src_block) {}
+
+ Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+ if (realize->block.get() == src_block_) {  // target identified by node identity
+ new_loop_ =
+ For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<BlockRealize>(realize));
+ return new_loop_;
+ }
+ return StmtMutator::VisitStmt_(realize);  // not the target: keep descending
+ }
+
+ const StmtNode* src_block_;  // the block to wrap
+ For new_loop_{nullptr};  // the inserted loop, set once the target is found
+ };
+
+ CHECK(sref->parent != nullptr) << "ValueError: Cannot add loops on top of the root block";  // a block without a parent sref is the root block
+ StmtSRef parent_sref = GetRef<StmtSRef>(sref->parent);
+ NewLoopCreator creator(sref->stmt);
+ Stmt new_stmt = creator(GetRef<Stmt>(parent_sref->stmt));  // the rewritten parent loop/block
+ if (new_stmt->IsInstance<ForNode>()) {  // parent is a loop: replace it directly
+ self->Replace(parent_sref, std::move(new_stmt), {});
+ } else {  // parent is a block: also map the old parent block to its rewritten version
+ Block old_parent_block = GetRef<Block>(parent_sref->StmtAs<BlockNode>());
+ Block new_parent_block = Downcast<Block>(new_stmt);
+ self->Replace(parent_sref, new_stmt, {{old_parent_block, new_parent_block}});
+ }
+ return self->stmt2ref.at(creator.new_loop_.get());
+}
+
/******** InstructionKind Registration ********/
struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
@@ -800,9 +837,41 @@ struct ReorderTraits : public UnpackedInstTraits<ReorderTraits> {
friend struct ::tvm::tir::UnpackedInstTraits;
};
+struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {  // instruction traits for recording/replaying/printing AddUnitLoop in a trace
+ static constexpr const char* kName = "AddUnitLoop";  // must match InstructionKind::Get("AddUnitLoop") in traced_schedule.cc
+ static constexpr bool kIsPure = false;  // the instruction mutates the IR
+
+ private:
+ static constexpr size_t kNumInputs = 1;  // the block or loop RV
+ static constexpr size_t kNumAttrs = 0;
+ static constexpr size_t kNumDecisions = 0;
+
+ static LoopRV UnpackedApplyToSchedule(Schedule sch, ObjectRef rv) {  // replay: choose the overload from the input's runtime type
+ if (const auto* block = rv.as<BlockRVNode>()) {
+ return sch->AddUnitLoop(GetRef<BlockRV>(block));
+ } else if (const auto* loop = rv.as<LoopRVNode>()) {
+ return sch->AddUnitLoop(GetRef<LoopRV>(loop));
+ } else {
+ LOG(FATAL) << "TypeError: AddUnitLoop expects a loop or block";
+ throw;  // NOTE(review): presumably unreachable after LOG(FATAL); appears to exist only to satisfy the missing-return check
+ }
+ }
+
+ static String UnpackedAsPython(Array<String> outputs, String rv) {  // renders an `add_unit_loop(block_or_loop=...)` call; exact text is up to PythonAPICall
+ PythonAPICall py("add_unit_loop");
+ py.Input("block_or_loop", rv);
+ py.SingleOutput(outputs);
+ return py.Str();
+ }
+
+ template <typename>
+ friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
TVM_REGISTER_INST_KIND_TRAITS(SplitTraits);
TVM_REGISTER_INST_KIND_TRAITS(FuseTraits);
TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits);
+TVM_REGISTER_INST_KIND_TRAITS(AddUnitLoopTraits);  // registers the "AddUnitLoop" instruction kind looked up by TracedScheduleNode
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 3880d0b19e..372d94a150 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -153,6 +153,18 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&Sche
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder")
.set_body_method<Schedule>(&ScheduleNode::Reorder);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop")  // backs Schedule.add_unit_loop on the Python side
+ .set_body_typed([](Schedule self, ObjectRef rv) -> LoopRV {  // dispatch to the LoopRV/BlockRV overload by runtime type
+ if (const auto* loop_rv = rv.as<LoopRVNode>()) {
+ return self->AddUnitLoop(GetRef<LoopRV>(loop_rv));
+ } else if (const auto* block_rv = rv.as<BlockRVNode>()) {
+ return self->AddUnitLoop(GetRef<BlockRV>(block_rv));
+ } else {
+ LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()  // NOTE(review): message does not mention AddUnitLoop; AddUnitLoopTraits' wording ("expects a loop or block") would be clearer
+ << ". Its value is: " << rv;
+ throw;  // NOTE(review): presumably unreachable after LOG(FATAL); appears to exist only to satisfy the missing-return check
+ }
+ });
/******** (FFI) Manipulate ForKind ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel")
.set_body_method<Schedule>(&ScheduleNode::Parallel);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index d2f627edfd..95a10e26ac 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -198,6 +198,28 @@ void TracedScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) {
/*outputs=*/{}));
}
+LoopRV TracedScheduleNode::AddUnitLoop(const BlockRV& block_rv) {  // apply the primitive, then record it; nothing is recorded if the primitive throws
+ LoopRV result = ConcreteScheduleNode::AddUnitLoop(block_rv);
+
+ static const InstructionKind& kind = InstructionKind::Get("AddUnitLoop");  // looked up once; name matches AddUnitLoopTraits::kName
+ trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+ /*inputs=*/{block_rv},
+ /*attrs=*/{},
+ /*outputs=*/{result}));
+ return result;
+}
+
+LoopRV TracedScheduleNode::AddUnitLoop(const LoopRV& loop_rv) {  // same as the BlockRV overload, with the LoopRV as the recorded input
+ LoopRV result = ConcreteScheduleNode::AddUnitLoop(loop_rv);
+
+ static const InstructionKind& kind = InstructionKind::Get("AddUnitLoop");
+ trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+ /*inputs=*/{loop_rv},
+ /*attrs=*/{},
+ /*outputs=*/{result}));
+ return result;
+}
+
/******** Schedule: Manipulate ForKind ********/
void TracedScheduleNode::Parallel(const LoopRV& loop_rv) {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index ba4a4b99cb..25bf3d4871 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -63,6 +63,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;
+ LoopRV AddUnitLoop(const BlockRV& block_rv) final;  // also appends an AddUnitLoop instruction to the trace
+ LoopRV AddUnitLoop(const LoopRV& loop_rv) final;  // also appends an AddUnitLoop instruction to the trace
/******** Schedule: Manipulate ForKind ********/
void Parallel(const LoopRV& loop_rv) final;
void Vectorize(const LoopRV& loop_rv) final;
diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py
index 16eef57c47..d70748bc8a 100644
--- a/tests/python/unittest/test_tir_schedule_split_fuse.py
+++ b/tests/python/unittest/test_tir_schedule_split_fuse.py
@@ -524,5 +524,63 @@ def test_fuse_not_affine():
verify_trace_roundtrip(sch=sch, mod=elementwise_not_affine)
+def test_add_unit_loop_above_block():  # a block with no enclosing loop gets a single unit loop
+ @T.prim_func
+ def zero_dim(
+ A: T.Buffer[(), "int32"],
+ B: T.Buffer[(), "int32"],
+ C: T.Buffer[(), "int32"],
+ ) -> None:
+ with T.block("C"):
+ vi = T.axis.spatial(1, 0)
+ C[()] = A[()] + B[()]
+
+ @T.prim_func
+ def zero_dim_added(
+ A: T.Buffer[(), "int32"],
+ B: T.Buffer[(), "int32"],
+ C: T.Buffer[(), "int32"],
+ ) -> None:
+ for u in range(1):
+ with T.block("C"):
+ vi = T.axis.spatial(1, 0)
+ C[()] = A[()] + B[()]
+
+ sch = tir.Schedule(zero_dim, debug_mask="all")
+ block = sch.get_block("C")
+ sch.add_unit_loop(block)  # insert a unit loop directly above block "C"
+ tvm.ir.assert_structural_equal(zero_dim_added, sch.mod["main"])  # NOTE(review): no verify_trace_roundtrip(...) here, so trace replay of AddUnitLoop is not exercised
+
+
+def test_add_unit_loop_above_loop():  # adding above an existing unit loop yields two nested unit loops, the new one outermost
+ @T.prim_func
+ def zero_dim(
+ A: T.Buffer[(), "int32"],
+ B: T.Buffer[(), "int32"],
+ C: T.Buffer[(), "int32"],
+ ) -> None:
+ for u in range(1):
+ with T.block("C"):
+ vi = T.axis.spatial(1, 0)
+ C[()] = A[()] + B[()]
+
+ @T.prim_func
+ def zero_dim_added(
+ A: T.Buffer[(), "int32"],
+ B: T.Buffer[(), "int32"],
+ C: T.Buffer[(), "int32"],
+ ) -> None:
+ for u1, u2 in T.grid(1, 1):
+ with T.block("C"):
+ vi = T.axis.spatial(1, 0)
+ C[()] = A[()] + B[()]
+
+ sch = tir.Schedule(zero_dim, debug_mask="all")
+ block = sch.get_block("C")
+ (loop,) = sch.get_loops(block)
+ sch.add_unit_loop(loop)  # insert a unit loop directly above the existing loop
+ tvm.ir.assert_structural_equal(zero_dim_added, sch.mod["main"])  # NOTE(review): no verify_trace_roundtrip(...) here, so trace replay of AddUnitLoop is not exercised
+
+
if __name__ == "__main__":
tvm.testing.main()