You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/08/26 11:43:07 UTC
[tvm] branch main updated: [TIR][Schedule] enhance compute_at and reverse_compute_at primitive to choose possible position (#12450)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e02f2f9fdd [TIR][Schedule] enhance compute_at and reverse_compute_at primitive to choose possible position (#12450)
e02f2f9fdd is described below
commit e02f2f9fddd8cd38589e3569c41de9f7af39971c
Author: yin.changsheng <yi...@intellif.com>
AuthorDate: Fri Aug 26 19:42:57 2022 +0800
[TIR][Schedule] enhance compute_at and reverse_compute_at primitive to choose possible position (#12450)
Currently, the TIR "compute_at" primitive computes a block at its closest consumers. When a block has multiple producers, whichever producer is computed at later ends up positioned behind. But for some special hardware, we usually want to keep a certain order regardless of whether a producer is computed at early or late.
e.g.: block A and block B are producers of block C. Block A is computed at block C first, and block B is computed at block C later. We want the result to be block B -> block A -> block C under some loop var.
---
include/tvm/tir/schedule/schedule.h | 14 +-
python/tvm/tir/schedule/schedule.py | 16 +++
src/tir/schedule/concrete_schedule.cc | 8 +-
src/tir/schedule/concrete_schedule.h | 7 +-
src/tir/schedule/primitive.h | 13 +-
src/tir/schedule/primitive/compute_at.cc | 67 ++++++---
src/tir/schedule/traced_schedule.cc | 19 +--
src/tir/schedule/traced_schedule.h | 7 +-
...chedule_schedule_rule_cross_thread_reduction.py | 16 +--
...ta_schedule_schedule_rule_multi_level_tiling.py | 86 ++++++------
...hedule_schedule_rule_random_compute_location.py | 2 +-
.../unittest/test_tir_schedule_compute_at.py | 152 +++++++++++++++++++++
12 files changed, 308 insertions(+), 99 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 11fec642c7..da399ab976 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -432,9 +432,13 @@ class ScheduleNode : public runtime::Object {
* \param block_rv The block to be moved
* \param loop_rv The loop where the block to be moved under
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+ * \param index The block index of the loop body subtree blocks:
+ * - `index = -1` means inserted into the last possible insertion point;
+ * - `index = -2` means inserted into the first possible insertion point;
+ * - Otherwise, `index` is a nonnegative number that indicates the insertion point
*/
- virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
- bool preserve_unit_loops) = 0;
+ virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
+ int index = -1) = 0;
/*!
* \brief Move a consumer block under the specific loop, and regenerate the
* loops induced by the block so that the buffer region consumed by the consumer block could
@@ -449,9 +453,13 @@ class ScheduleNode : public runtime::Object {
* \param block_rv The block to be moved
* \param loop_rv The loop where the block to be moved under
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+ * \param index The block index of the loop body subtree blocks:
+ * - `index = -1` means inserted into the last possible insertion point;
+ * - `index = -2` means inserted into the first possible insertion point;
+ * - Otherwise, `index` is a nonnegative number that indicates the insertion point
*/
virtual void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
- bool preserve_unit_loops) = 0;
+ bool preserve_unit_loops, int index = -1) = 0;
/*!
* \brief Inline a block into its consumer(s). It requires:
* 1) The block is a complete non-root block, which only produces one buffer
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index e18bee35a5..04cc1bc26a 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1274,6 +1274,7 @@ class Schedule(Object):
block: Union[BlockRV, str],
loop: LoopRV,
preserve_unit_loops: bool = False,
+ index: int = -1,
) -> None:
"""Compute-At. Move a producer block under the specific loop, and regenerate the
loops induced by the block so that the buffer region produced by the producer block could
@@ -1303,6 +1304,12 @@ class Schedule(Object):
preserve_unit_loops: bool
Whether to keep the trivial loops whose extents are 1
+ index: int
+ The block index of the loop body subtree blocks:
+ - `index = -1` means inserted into the last possible insertion point;
+ - `index = -2` means inserted into the first possible insertion point;
+ - Otherwise, `index` is a nonnegative number that indicates the insertion point
+
Examples
--------
@@ -1360,6 +1367,7 @@ class Schedule(Object):
block,
loop,
preserve_unit_loops,
+ index,
)
@type_checked
@@ -1368,6 +1376,7 @@ class Schedule(Object):
block: Union[BlockRV, str],
loop: LoopRV,
preserve_unit_loops: bool = False,
+ index: int = -1,
) -> None:
"""Reverse-Compute-At. Move a consumer block under the specific loop, and regenerate the
loops induced by the block so that the buffer region consumed by the consumer block could
@@ -1394,6 +1403,12 @@ class Schedule(Object):
preserve_unit_loops: bool
Whether to keep the trivial loops whose extents are 1
+ index: int
+ The block index of the loop body subtree blocks:
+ - `index = -1` means inserted into the last possible insertion point;
+ - `index = -2` means inserted into the first possible insertion point;
+ - Otherwise, `index` is a nonnegative number that indicates the insertion point
+
Examples
--------
@@ -1451,6 +1466,7 @@ class Schedule(Object):
block,
loop,
preserve_unit_loops,
+ index,
)
@type_checked
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index c16638f748..5f773a02d6 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -574,7 +574,7 @@ BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
/******** Schedule: Compute location ********/
void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
- bool preserve_unit_loops) {
+ bool preserve_unit_loops, int index) {
static StmtSRef inline_mark = StmtSRef::InlineMark();
static StmtSRef root_mark = StmtSRef::RootMark();
StmtSRef loop_sref = this->GetSRef(loop_rv);
@@ -586,14 +586,14 @@ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop
TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
} else {
TVM_TIR_SCHEDULE_BEGIN();
- tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops);
+ tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index);
TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
}
this->state_->DebugVerify();
}
void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
- bool preserve_unit_loops) {
+ bool preserve_unit_loops, int index) {
static StmtSRef inline_mark = StmtSRef::InlineMark();
static StmtSRef root_mark = StmtSRef::RootMark();
StmtSRef loop_sref = this->GetSRef(loop_rv);
@@ -605,7 +605,7 @@ void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopR
TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_);
} else {
TVM_TIR_SCHEDULE_BEGIN();
- tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops);
+ tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index);
TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_);
}
this->state_->DebugVerify();
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index cdd0a5b7b0..92b9de4088 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -119,9 +119,10 @@ class ConcreteScheduleNode : public ScheduleNode {
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) override;
/******** Schedule: Compute location ********/
- void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override;
- void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
- bool preserve_unit_loops) override;
+ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
+ int index = -1) override;
+ void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
+ int index = -1) override;
void ComputeInline(const BlockRV& block) override;
void ReverseComputeInline(const BlockRV& block) override;
/******** Schedule: Reduction ********/
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 14203a0d16..05d9e4cf94 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -299,10 +299,13 @@ TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buf
* \param self The schedule state
* \param block_sref The block to be moved
* \param loop_sref The loop where the block to be moved to
- * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+ * \param index The block index of the loop body subtree blocks:
+ * - `index = -1` means inserted into the last possible insertion point;
+ * - `index = -2` means inserted into the first possible insertion point;
+ * - Otherwise, `index` is a nonnegative number that indicates the insertion point
*/
TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
- bool preserve_unit_loops);
+ bool preserve_unit_loops, int index = -1);
/*!
* \brief Move a consumer block under the specific loop, and regenerate the
* loops induced by the block so that the buffer region consumed by the consumer block could
@@ -318,9 +321,13 @@ TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const Stm
* \param block_sref The block to be moved
* \param loop_sref The loop where the block to be moved to
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+ * \param index The block index of the loop body subtree blocks:
+ * - `index = -1` means inserted into the last possible insertion point;
+ * - `index = -2` means inserted into the first possible insertion point;
+ * - Otherwise, `index` is a nonnegative number that indicates the insertion point
*/
TVM_DLL void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref,
- const StmtSRef& loop_sref, bool preserve_unit_loops);
+ const StmtSRef& loop_sref, bool preserve_unit_loops, int index = -1);
/*!
* \brief Inline a block into its consumer(s). It requires:
* 1) The block is a complete non-root block, which only produces one buffer
diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc
index 98a6b2400e..8baedfd70d 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -129,15 +129,19 @@ class NotInSameScopeError : public ScheduleError {
* \param producer_srefs The producer blocks
* \param consumer_srefs The consumer blocks
* \param block2realize A cache that maps a block to its realize
- * \return The last position the new block can be inserted onto, and the
+ * \param index The block index of the loop body subtree blocks:
+ * - `index = -1` means inserted into the last possible insertion point;
+ * - `index = -2` means inserted into the first possible insertion point;
+ * - Otherwise, `index` is a nonnegative number that indicates the insertion point
+ * \return The possible position the new block can be inserted into, and the
* producer-consumer-relationship is still satisfied.
* \throws ScheduleError if there is no such insertion point found
*/
template <bool require_all_producers_visited, bool require_all_consumers_visited>
-int FindInsertionPoint(
- const ScheduleState& self, const Array<Stmt>& subtrees, const Array<StmtSRef>& producer_srefs,
- const Array<StmtSRef>& consumer_srefs,
- std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize) {
+int FindInsertionPoint(const ScheduleState& self, const Array<Stmt>& subtrees,
+ const Array<StmtSRef>& producer_srefs, const Array<StmtSRef>& consumer_srefs,
+ std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize,
+ int index) {
ProducerConsumerSplit split =
ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize);
// Step 1. Check if all the producers are visited in the subtrees, if required to
@@ -159,8 +163,22 @@ int FindInsertionPoint(
// Step 3. Check if there is at least one index of the position can be inserted into
// The valid indices are: (last_producer_position, first_consumer_position]
ICHECK(split.last_producer_position < split.first_consumer_position);
- // Step 4. Return the last valid insertion point
- return split.first_consumer_position;
+ // Step 4. Return the possible insertion point according to index
+ int insert_position;
+ if (index == -1) {
+ insert_position = split.first_consumer_position;
+ } else if (index == -2) {
+ insert_position = split.last_producer_position + 1;
+ } else if (index >= 0 && index >= split.last_producer_position + 1 &&
+ index <= split.first_consumer_position) {
+ insert_position = index;
+ } else {
+ LOG(FATAL) << "Valid index:(-1, -2, [" << split.last_producer_position + 1 << ", "
+ << split.first_consumer_position << "]), "
+ << "current index=" << index;
+ throw;
+ }
+ return insert_position;
}
/*!
@@ -556,7 +574,8 @@ void CalculateProvidedRequiredRegions(
template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops,
- arith::Analyzer* analyzer, bool check_only = false) {
+ arith::Analyzer* analyzer, bool check_only = false,
+ int index = -1) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
// Step 1. Bunch of checks
@@ -588,7 +607,8 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
/*self=*/self,
/*subtrees=*/AsArray(loop->body),
/*producer_srefs=*/producer_srefs,
- /*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize);
+ /*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize,
+ /*index=*/index);
// Step 4. Calculate the region provided by a single execution instance of `block`,
// as well as the region required by dependent blocks under `loop`.
// Here is the definition of `provide` and `require`:
@@ -626,17 +646,17 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
}
void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
- bool preserve_unit_loops) {
+ bool preserve_unit_loops, int index) {
arith::Analyzer analyzer;
- ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
- &analyzer);
+ ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops, &analyzer,
+ false, index);
}
void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
- bool preserve_unit_loops) {
+ bool preserve_unit_loops, int index) {
arith::Analyzer analyzer;
ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
- &analyzer);
+ &analyzer, false, index);
}
bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
@@ -671,20 +691,21 @@ struct ComputeAtTraits : public UnpackedInstTraits<ComputeAtTraits> {
private:
static constexpr size_t kNumInputs = 2;
- static constexpr size_t kNumAttrs = 1;
+ static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;
static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
- Bool preserve_unit_loops) {
- return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool());
+ Bool preserve_unit_loops, IntImm index) {
+ return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), index->value);
}
static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
- Bool preserve_unit_loops) {
+ Bool preserve_unit_loops, IntImm index) {
PythonAPICall py("compute_at");
py.Input("block", block_rv);
py.Input("loop", loop_rv);
py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
+ py.Input("index", index);
return py.Str();
}
@@ -698,20 +719,22 @@ struct ReverseComputeAtTraits : public UnpackedInstTraits<ReverseComputeAtTraits
private:
static constexpr size_t kNumInputs = 2;
- static constexpr size_t kNumAttrs = 1;
+ static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;
static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
- Bool preserve_unit_loops) {
- return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool());
+ Bool preserve_unit_loops, IntImm index) {
+ return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(),
+ index->value);
}
static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
- Bool preserve_unit_loops) {
+ Bool preserve_unit_loops, IntImm index) {
PythonAPICall py("reverse_compute_at");
py.Input("block", block_rv);
py.Input("loop", loop_rv);
py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
+ py.Input("index", index);
return py.Str();
}
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 07d4da54d7..04ddc0507d 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -322,24 +322,25 @@ BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
/******** Schedule: Compute location ********/
void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
- bool preserve_unit_loops) {
- ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops);
+ bool preserve_unit_loops, int index) {
+ ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops, index);
static const InstructionKind& kind = InstructionKind::Get("ComputeAt");
- trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
- /*inputs=*/{block_rv, loop_rv},
- /*attrs=*/{Integer(preserve_unit_loops)},
- /*outputs=*/{}));
+ trace_->Append(
+ /*inst=*/Instruction(/*kind=*/kind,
+ /*inputs=*/{block_rv, loop_rv},
+ /*attrs=*/{Integer(preserve_unit_loops), Integer(index)},
+ /*outputs=*/{}));
}
void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
- bool preserve_unit_loops) {
- ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops);
+ bool preserve_unit_loops, int index) {
+ ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops, index);
static const InstructionKind& kind = InstructionKind::Get("ReverseComputeAt");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv, loop_rv},
- /*attrs=*/{Integer(preserve_unit_loops)},
+ /*attrs=*/{Integer(preserve_unit_loops), Integer(index)},
/*outputs=*/{}));
}
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 865a216879..d98e4ba4bb 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -79,9 +79,10 @@ class TracedScheduleNode : public ConcreteScheduleNode {
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) final;
/******** Schedule: Compute location ********/
- void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final;
- void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
- bool preserve_unit_loops) final;
+ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
+ int index = -1) final;
+ void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
+ int index = -1) final;
void ComputeInline(const BlockRV& block_rv) final;
void ReverseComputeInline(const BlockRV& block_rv) final;
/******** Schedule: Reduction ********/
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
index 5f76e77592..592d32d624 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
@@ -80,7 +80,7 @@ def test_gpu_softmax_mn():
"v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
'sch.bind(loop=l6, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
+ "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l7, l8, l9 = sch.get_loops(block=b0)",
"l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
@@ -93,7 +93,7 @@ def test_gpu_softmax_mn():
"v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
'sch.bind(loop=l6, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
+ "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l7, l8, l9 = sch.get_loops(block=b0)",
"l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
@@ -107,7 +107,7 @@ def test_gpu_softmax_mn():
"v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
'sch.bind(loop=l7, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
+ "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True, index=-1)",
'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
"l8, l9, l10 = sch.get_loops(block=b1)",
"l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
@@ -117,7 +117,7 @@ def test_gpu_softmax_mn():
"v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l17, l18 = sch.split(loop=l15, factors=[None, v16], preserve_unit_iters=True)",
'sch.bind(loop=l18, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True)",
+ "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True, index=-1)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l19, l20, l21 = sch.get_loops(block=b0)",
"l22, l23 = sch.split(loop=l21, factors=[None, v16], preserve_unit_iters=True)",
@@ -157,7 +157,7 @@ def test_gpu_softmax_mn_after_inline():
"v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
'sch.bind(loop=l6, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
+ "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l7, l8, l9 = sch.get_loops(block=b0)",
"l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
@@ -171,14 +171,14 @@ def test_gpu_softmax_mn_after_inline():
"v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
'sch.bind(loop=l7, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
+ "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True, index=-1)",
'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
"l8, l9, l10 = sch.get_loops(block=b1)",
"l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
'sch.bind(loop=l12, thread_axis="threadIdx.x")',
"b13, b14 = sch.get_consumers(block=b0)",
"l15, l16, l17, l18 = sch.get_loops(block=b13)",
- "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)",
+ "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True, index=-1)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l19, l20, l21 = sch.get_loops(block=b0)",
"l22, l23 = sch.split(loop=l21, factors=[None, v5], preserve_unit_iters=True)",
@@ -206,7 +206,7 @@ def test_gpu_batch_norm_bmn():
"v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l4, l5 = sch.split(loop=l2, factors=[None, v3], preserve_unit_iters=True)",
'sch.bind(loop=l5, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True)",
+ "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True, index=-1)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l6, l7, l8, l9 = sch.get_loops(block=b0)",
"l10 = sch.fuse(l8, l9, preserve_unit_iters=True)",
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index 87159fcb31..fe1220c509 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -62,7 +62,7 @@ def test_cpu_matmul():
"l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
- "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)",
+ "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True, index=-1)",
],
[
'b0 = sch.get_block(name="C", func_name="main")',
@@ -76,7 +76,7 @@ def test_cpu_matmul():
"l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
- "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)",
+ "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True, index=-1)",
],
[
'b0 = sch.get_block(name="C", func_name="main")',
@@ -123,7 +123,7 @@ def test_cpu_matmul_relu():
"l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
"b24, = sch.get_consumers(block=b0)",
- "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)",
+ "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True, index=-1)",
],
[
'b0 = sch.get_block(name="C", func_name="main")',
@@ -137,7 +137,7 @@ def test_cpu_matmul_relu():
"l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
"b24, = sch.get_consumers(block=b0)",
- "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)",
+ "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True, index=-1)",
],
[
'b0 = sch.get_block(name="C", func_name="main")',
@@ -193,15 +193,15 @@ def test_cuda_matmul():
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)',
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)',
'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
- "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)",
+ "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True, index=-1)",
'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")',
- "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)",
+ "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True, index=-1)",
"l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
"l41 = sch.fuse(l39, l40, preserve_unit_iters=True)",
"v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
'b43 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
- "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)",
+ "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True, index=-1)",
"l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
"l50 = sch.fuse(l48, l49, preserve_unit_iters=True)",
"v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
@@ -247,15 +247,15 @@ def test_cuda_matmul_relu():
"l32 = sch.fuse(l11, l21, preserve_unit_iters=True)",
'sch.bind(loop=l32, thread_axis="threadIdx.x")',
'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
- "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)",
+ "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True, index=-1)",
'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")',
- "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)",
+ "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True, index=-1)",
"l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
"l41 = sch.fuse(l39, l40, preserve_unit_iters=True)",
"v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
'b43 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
- "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)",
+ "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True, index=-1)",
"l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
"l50 = sch.fuse(l48, l49, preserve_unit_iters=True)",
"v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
@@ -402,7 +402,7 @@ v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True)
sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)
b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global")
-sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True)""".split(
+sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True, index=-1)""".split(
"\n"
),
"""b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
@@ -437,7 +437,7 @@ v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True)
sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)
b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global")
-sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True)""".split(
+sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True, index=-1)""".split(
"\n"
),
"""b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
@@ -546,15 +546,15 @@ sch.bind(loop=l37, thread_axis="vthread.x")
l38 = sch.fuse(l17, l27, preserve_unit_iters=True)
sch.bind(loop=l38, thread_axis="threadIdx.x")
b39 = sch.cache_write(block=b6, write_buffer_index=0, storage_scope="local")
-sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True, index=-1)
b40 = sch.cache_read(block=b6, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True)
+sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True, index=-1)
l41, l42, l43, l44, l45, l46 = sch.get_loops(block=b40)
l47 = sch.fuse(l45, l46, preserve_unit_iters=True)
v48 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b40, ann_key="meta_schedule.cooperative_fetch", ann_val=v48)
b49 = sch.cache_read(block=b6, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True)
+sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True, index=-1)
l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b49)
l56 = sch.fuse(l54, l55, preserve_unit_iters=True)
v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
@@ -632,9 +632,9 @@ sch.bind(loop=l51, thread_axis="blockIdx.x")
l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
sch.bind(loop=l52, thread_axis="threadIdx.y")
b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared")
-sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True, index=-1)
b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True, index=-1)
v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)
sch.reverse_compute_inline(block=b2)
@@ -646,19 +646,19 @@ sch.reorder(l70, l64, l62)
b72 = sch.blockize(loop=l64)
sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True, index=-1)
l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73)
l80 = sch.fuse(l78, l79, preserve_unit_iters=True)
v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True, index=-1)
l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82)
l89 = sch.fuse(l87, l88, preserve_unit_iters=True)
v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90)
b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True)
+sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True, index=-1)
l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91)
l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True)
l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True)
@@ -667,7 +667,7 @@ sch.reorder(l110, l102, l100)
b112 = sch.blockize(loop=l102)
sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True)
+sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True, index=-1)
l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113)
l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True)
l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True)
@@ -772,9 +772,9 @@ sch.bind(loop=l51, thread_axis="blockIdx.x")
l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
sch.bind(loop=l52, thread_axis="threadIdx.y")
b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared")
-sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True, index=-1)
b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True, index=-1)
v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)
sch.reverse_compute_inline(block=b2)
@@ -786,19 +786,19 @@ sch.reorder(l70, l64, l62)
b72 = sch.blockize(loop=l64)
sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True, index=-1)
l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73)
l80 = sch.fuse(l78, l79, preserve_unit_iters=True)
v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True, index=-1)
l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82)
l89 = sch.fuse(l87, l88, preserve_unit_iters=True)
v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90)
b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True)
+sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True, index=-1)
l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91)
l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True)
l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True)
@@ -807,7 +807,7 @@ sch.reorder(l110, l102, l100)
b112 = sch.blockize(loop=l102)
sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True)
+sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True, index=-1)
l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113)
l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True)
l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True)
@@ -895,7 +895,7 @@ sch.bind(loop=l50, thread_axis="blockIdx.x")
l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
sch.bind(loop=l51, thread_axis="threadIdx.y")
b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True, index=-1)
sch.reverse_compute_inline(block=b1)
l53, l54, l55, l56, l57 = sch.get_loops(block=b52)
l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True)
@@ -905,19 +905,19 @@ sch.reorder(l67, l61, l59)
b69 = sch.blockize(loop=l61)
sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global")
b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True)
+sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True, index=-1)
l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70)
l77 = sch.fuse(l75, l76, preserve_unit_iters=True)
v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78)
b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True)
+sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True, index=-1)
l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True, index=-1)
l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88)
l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True)
l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True)
@@ -926,7 +926,7 @@ sch.reorder(l107, l99, l97)
b109 = sch.blockize(loop=l99)
sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True, index=-1)
l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110)
l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True)
l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True)
@@ -995,7 +995,7 @@ sch.bind(loop=l50, thread_axis="blockIdx.x")
l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
sch.bind(loop=l51, thread_axis="threadIdx.y")
b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True, index=-1)
sch.reverse_compute_inline(block=b1)
l53, l54, l55, l56, l57 = sch.get_loops(block=b52)
l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True)
@@ -1005,19 +1005,19 @@ sch.reorder(l67, l61, l59)
b69 = sch.blockize(loop=l61)
sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global")
b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True)
+sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True, index=-1)
l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70)
l77 = sch.fuse(l75, l76, preserve_unit_iters=True)
v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78)
b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True)
+sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True, index=-1)
l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True, index=-1)
l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88)
l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True)
l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True)
@@ -1026,7 +1026,7 @@ sch.reorder(l107, l99, l97)
b109 = sch.blockize(loop=l99)
sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True, index=-1)
l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110)
l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True)
l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True)
@@ -1133,9 +1133,9 @@ sch.bind(loop=l63, thread_axis="blockIdx.x")
l64 = sch.fuse(l33, l43, l53, preserve_unit_iters=True)
sch.bind(loop=l64, thread_axis="threadIdx.y")
b65 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="shared")
-sch.reverse_compute_at(block=b65, loop=l63, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b65, loop=l63, preserve_unit_loops=True, index=-1)
b66 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b66, loop=l64, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b66, loop=l64, preserve_unit_loops=True, index=-1)
v67 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b65, ann_key="meta_schedule.cooperative_fetch", ann_val=v67)
sch.reverse_compute_inline(block=b1)
@@ -1147,19 +1147,19 @@ sch.reorder(l82, l76, l74)
b84 = sch.blockize(loop=l76)
sch.annotate(block_or_loop=b84, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
b85 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b85, loop=l59, preserve_unit_loops=True)
+sch.compute_at(block=b85, loop=l59, preserve_unit_loops=True, index=-1)
l86, l87, l88, l89, l90, l91 = sch.get_loops(block=b85)
l92 = sch.fuse(l90, l91, preserve_unit_iters=True)
v93 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v93)
b94 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b94, loop=l59, preserve_unit_loops=True)
+sch.compute_at(block=b94, loop=l59, preserve_unit_loops=True, index=-1)
l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b94)
l101 = sch.fuse(l99, l100, preserve_unit_iters=True)
v102 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b94, ann_key="meta_schedule.cooperative_fetch", ann_val=v102)
b103 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b103, loop=l60, preserve_unit_loops=True)
+sch.compute_at(block=b103, loop=l60, preserve_unit_loops=True, index=-1)
l104, l105, l106, l107, l108, l109, l110 = sch.get_loops(block=b103)
l111, l112 = sch.split(loop=l110, factors=[None, 16], preserve_unit_iters=True)
l113, l114 = sch.split(loop=l109, factors=[None, 16], preserve_unit_iters=True)
@@ -1168,7 +1168,7 @@ sch.reorder(l122, l114, l112)
b124 = sch.blockize(loop=l114)
sch.annotate(block_or_loop=b124, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
b125 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b125, loop=l60, preserve_unit_loops=True)
+sch.compute_at(block=b125, loop=l60, preserve_unit_loops=True, index=-1)
l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b125)
l133, l134 = sch.split(loop=l132, factors=[None, 16], preserve_unit_iters=True)
l135, l136 = sch.split(loop=l131, factors=[None, 16], preserve_unit_iters=True)
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py
index b2df408e9d..c951a5adf3 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py
@@ -71,7 +71,7 @@ def test_random_compute_location():
[
'b0 = sch.get_block(name="move", func_name="main")',
"l1 = sch.sample_compute_location(block=b0)",
- "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True)",
+ "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True, index=-1)",
]
]
mod = Add
diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py
index 0c20a4783c..72cba1a8fd 100644
--- a/tests/python/unittest/test_tir_schedule_compute_at.py
+++ b/tests/python/unittest/test_tir_schedule_compute_at.py
@@ -1353,5 +1353,157 @@ def test_compute_at_int64_loop(use_block_name):
verify_trace_roundtrip(sch=sch, mod=mod)
+def test_compute_at_to_index():
+ @T.prim_func
+ def multi_producers_conv(
+ data: T.Buffer[(1, 3, 224, 224), "int8"],
+ w: T.Buffer[(16, 3, 7, 7), "int8"],
+ conv: T.Buffer[(1, 16, 112, 112), "int32"],
+ ) -> None:
+ pad = T.alloc_buffer([1, 3, 230, 230], dtype="int8")
+ wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8")
+ for i0, i1, i2, i3 in T.grid(1, 3, 230, 230):
+ with T.block("pad"):
+ i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(data[i0_1, i1_1, i2_1 - 3, i3_1 - 3])
+ T.writes(pad[i0_1, i1_1, i2_1, i3_1])
+ pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+ 3 <= i2_1 and i2_1 < 227 and 3 <= i3_1 and i3_1 < 227,
+ data[i0_1, i1_1, i2_1 - 3, i3_1 - 3],
+ T.int8(0),
+ dtype="int8",
+ )
+ for i0 in T.serial(1):
+ for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7):
+ with T.block("wbuf"):
+ v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(w[v0, v1, v2, v3])
+ T.writes(wbuf[v0, v1, v2, v3])
+ wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3]
+ for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7):
+ with T.block("conv"):
+ nn, ff, yy, xx, rc, ry, rx = T.axis.remap(
+ "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]
+ )
+ T.reads(pad[nn, rc, yy * 2 + ry, xx * 2 + rx], wbuf[ff, rc, ry, rx])
+ T.writes(conv[nn, ff, yy, xx])
+ with T.init():
+ conv[nn, ff, yy, xx] = 0
+ conv[nn, ff, yy, xx] = conv[nn, ff, yy, xx] + T.cast(
+ pad[nn, rc, yy * 2 + ry, xx * 2 + rx], "int32"
+ ) * T.cast(wbuf[ff, rc, ry, rx], "int32")
+
+ @T.prim_func
+ def multi_producers_after_compute_at(
+ data: T.Buffer[(1, 3, 224, 224), "int8"],
+ w: T.Buffer[(16, 3, 7, 7), "int8"],
+ conv: T.Buffer[(1, 16, 112, 112), "int32"],
+ ) -> None:
+ pad = T.alloc_buffer([1, 3, 230, 230], dtype="int8")
+ wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8")
+ for i0 in T.serial(1):
+ for ax0, ax1, ax2 in T.grid(3, 229, 229):
+ with T.block("pad"):
+ i0_1 = T.axis.spatial(1, 0)
+ i1_1 = T.axis.spatial(3, ax0)
+ i2_1 = T.axis.spatial(230, ax1)
+ i3_1 = T.axis.spatial(230, ax2)
+ T.reads(data[i0_1, i1_1, i2_1 - 3, i3_1 - 3])
+ T.writes(pad[i0_1, i1_1, i2_1, i3_1])
+ pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+ 3 <= i2_1 and i2_1 < 227 and 3 <= i3_1 and i3_1 < 227,
+ data[i0_1, i1_1, i2_1 - 3, i3_1 - 3],
+ T.int8(0),
+ dtype="int8",
+ )
+ for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7):
+ with T.block("wbuf"):
+ v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(w[v0, v1, v2, v3])
+ T.writes(wbuf[v0, v1, v2, v3])
+ wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3]
+ for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7):
+ with T.block("conv"):
+ nn, ff, yy, xx, rc, ry, rx = T.axis.remap(
+ "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]
+ )
+ T.reads(pad[nn, rc, yy * 2 + ry, xx * 2 + rx], wbuf[ff, rc, ry, rx])
+ T.writes(conv[nn, ff, yy, xx])
+ with T.init():
+ conv[nn, ff, yy, xx] = 0
+ conv[nn, ff, yy, xx] = conv[nn, ff, yy, xx] + T.cast(
+ pad[nn, rc, yy * 2 + ry, xx * 2 + rx], "int32"
+ ) * T.cast(wbuf[ff, rc, ry, rx], "int32")
+
+ sch = tir.Schedule(multi_producers_conv, debug_mask="all")
+ block_c = sch.get_block("pad")
+ axis = sch.get_loops("conv")[0]
+ sch.compute_at(block_c, axis, index=-2)
+ tvm.ir.assert_structural_equal(multi_producers_after_compute_at, sch.mod["main"])
+
+
+def test_reverse_compute_at_to_index():
+ @T.prim_func
+ def main(A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(128, 128), "float32"]) -> None:
+ B = T.alloc_buffer([128, 128], dtype="float32")
+ C = T.alloc_buffer([128, 128], dtype="float32")
+ for i_0, j_0, i_1 in T.grid(8, 8, 16):
+ for j_1 in T.serial(16):
+ with T.block("B"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1)
+ vj = T.axis.spatial(128, j_0 * 16 + j_1)
+ T.reads(A[vi, vj])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] * T.float32(2)
+ for ax0 in T.serial(16):
+ with T.block("C"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1)
+ vj = T.axis.spatial(128, j_0 * 16 + ax0)
+ T.reads(B[vi, vj])
+ T.writes(C[vi, vj])
+ C[vi, vj] = B[vi, vj] + T.float32(1)
+ for i, j in T.grid(128, 128):
+ with T.block("D"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(B[vi, vj])
+ T.writes(D[vi, vj])
+ D[vi, vj] = B[vi, vj] + T.float32(1)
+
+ @T.prim_func
+ def main_reverse_compute_at(
+ A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(128, 128), "float32"]
+ ) -> None:
+ B = T.alloc_buffer([128, 128], dtype="float32")
+ C = T.alloc_buffer([128, 128], dtype="float32")
+ for i_0, j_0, i_1 in T.grid(8, 8, 16):
+ for j_1 in T.serial(16):
+ with T.block("B"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1)
+ vj = T.axis.spatial(128, j_0 * 16 + j_1)
+ T.reads(A[vi, vj])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] * T.float32(2)
+ for ax0 in T.serial(16):
+ with T.block("D"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1)
+ vj = T.axis.spatial(128, j_0 * 16 + ax0)
+ T.reads(B[vi, vj])
+ T.writes(D[vi, vj])
+ D[vi, vj] = B[vi, vj] + T.float32(1)
+ for ax0 in T.serial(16):
+ with T.block("C"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1)
+ vj = T.axis.spatial(128, j_0 * 16 + ax0)
+ T.reads(B[vi, vj])
+ T.writes(C[vi, vj])
+ C[vi, vj] = B[vi, vj] + T.float32(1)
+
+ sch = tir.Schedule(main, debug_mask="all")
+ block_c = sch.get_block("D")
+ axis = sch.get_loops("B")[2]
+ sch.reverse_compute_at(block_c, axis, index=1)
+ tvm.ir.assert_structural_equal(main_reverse_compute_at, sch.mod["main"])
+
+
if __name__ == "__main__":
tvm.testing.main()