You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/08/26 11:43:07 UTC

[tvm] branch main updated: [TIR][Schedule] enhance compute_at and reverse_compute_at primitive to choose possible position (#12450)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e02f2f9fdd [TIR][Schedule] enhance compute_at and reverse_compute_at primitive to choose possible position (#12450)
e02f2f9fdd is described below

commit e02f2f9fddd8cd38589e3569c41de9f7af39971c
Author: yin.changsheng <yi...@intellif.com>
AuthorDate: Fri Aug 26 19:42:57 2022 +0800

    [TIR][Schedule] enhance compute_at and reverse_compute_at primitive to choose possible position (#12450)
    
    The current TIR "compute_at" primitive computes a block at its closest consumer. When a block has multiple producers, whichever block is computed-at later ends up positioned behind the others. But for some special hardware, we usually want to keep a certain fixed order regardless of whether a block is computed-at early or late.
    E.g., block A and block B are producers of block C. Block A is computed at block C first, and block B is computed at block C later. We want the result to be block B -> block A -> block C under some loop var.
---
 include/tvm/tir/schedule/schedule.h                |  14 +-
 python/tvm/tir/schedule/schedule.py                |  16 +++
 src/tir/schedule/concrete_schedule.cc              |   8 +-
 src/tir/schedule/concrete_schedule.h               |   7 +-
 src/tir/schedule/primitive.h                       |  13 +-
 src/tir/schedule/primitive/compute_at.cc           |  67 ++++++---
 src/tir/schedule/traced_schedule.cc                |  19 +--
 src/tir/schedule/traced_schedule.h                 |   7 +-
 ...chedule_schedule_rule_cross_thread_reduction.py |  16 +--
 ...ta_schedule_schedule_rule_multi_level_tiling.py |  86 ++++++------
 ...hedule_schedule_rule_random_compute_location.py |   2 +-
 .../unittest/test_tir_schedule_compute_at.py       | 152 +++++++++++++++++++++
 12 files changed, 308 insertions(+), 99 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 11fec642c7..da399ab976 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -432,9 +432,13 @@ class ScheduleNode : public runtime::Object {
    * \param block_rv The block to be moved
    * \param loop_rv The loop where the block to be moved under
    * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+   * \param index The block index of the loop body subtree blocks:
+   * - `index = -1` means inserted into the last possible insertion point;
+   * - `index = -2` means inserted into the first possible insertion point;
+   * - Otherwise, `index` is a nonnegative number that indicates the insertion point
    */
-  virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
-                         bool preserve_unit_loops) = 0;
+  virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
+                         int index = -1) = 0;
   /*!
    * \brief Move a consumer block under the specific loop, and regenerate the
    * loops induced by the block so that the buffer region consumed by the consumer block could
@@ -449,9 +453,13 @@ class ScheduleNode : public runtime::Object {
    * \param block_rv The block to be moved
    * \param loop_rv The loop where the block to be moved under
    * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+   * \param index The block index of the loop body subtree blocks:
+   * - `index = -1` means inserted into the last possible insertion point;
+   * - `index = -2` means inserted into the first possible insertion point;
+   * - Otherwise, `index` is a nonnegative number that indicates the insertion point
    */
   virtual void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
-                                bool preserve_unit_loops) = 0;
+                                bool preserve_unit_loops, int index = -1) = 0;
   /*!
    * \brief Inline a block into its consumer(s). It requires:
    * 1) The block is a complete non-root block, which only produces one buffer
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index e18bee35a5..04cc1bc26a 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1274,6 +1274,7 @@ class Schedule(Object):
         block: Union[BlockRV, str],
         loop: LoopRV,
         preserve_unit_loops: bool = False,
+        index: int = -1,
     ) -> None:
         """Compute-At. Move a producer block under the specific loop, and regenerate the
         loops induced by the block so that the buffer region produced by the producer block could
@@ -1303,6 +1304,12 @@ class Schedule(Object):
         preserve_unit_loops: bool
             Whether to keep the trivial loops whose extents are 1
 
+        index: int
+            The block index of the loop body subtree blocks:
+            - `index = -1` means inserted into the last possible insertion point;
+            - `index = -2` means inserted into the first possible insertion point;
+            - Otherwise, `index` is a nonnegative number that indicates the insertion point
+
         Examples
         --------
 
@@ -1360,6 +1367,7 @@ class Schedule(Object):
             block,
             loop,
             preserve_unit_loops,
+            index,
         )
 
     @type_checked
@@ -1368,6 +1376,7 @@ class Schedule(Object):
         block: Union[BlockRV, str],
         loop: LoopRV,
         preserve_unit_loops: bool = False,
+        index: int = -1,
     ) -> None:
         """Reverse-Compute-At. Move a consumer block under the specific loop, and regenerate the
         loops induced by the block so that the buffer region consumed by the consumer block could
@@ -1394,6 +1403,12 @@ class Schedule(Object):
         preserve_unit_loops: bool
             Whether to keep the trivial loops whose extents are 1
 
+        index: int
+            The block index of the loop body subtree blocks:
+            - `index = -1` means inserted into the last possible insertion point;
+            - `index = -2` means inserted into the first possible insertion point;
+            - Otherwise, `index` is a nonnegative number that indicates the insertion point
+
         Examples
         --------
 
@@ -1451,6 +1466,7 @@ class Schedule(Object):
             block,
             loop,
             preserve_unit_loops,
+            index,
         )
 
     @type_checked
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index c16638f748..5f773a02d6 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -574,7 +574,7 @@ BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
 /******** Schedule: Compute location ********/
 
 void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
-                                     bool preserve_unit_loops) {
+                                     bool preserve_unit_loops, int index) {
   static StmtSRef inline_mark = StmtSRef::InlineMark();
   static StmtSRef root_mark = StmtSRef::RootMark();
   StmtSRef loop_sref = this->GetSRef(loop_rv);
@@ -586,14 +586,14 @@ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop
     TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
   } else {
     TVM_TIR_SCHEDULE_BEGIN();
-    tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops);
+    tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index);
     TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
   }
   this->state_->DebugVerify();
 }
 
 void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
-                                            bool preserve_unit_loops) {
+                                            bool preserve_unit_loops, int index) {
   static StmtSRef inline_mark = StmtSRef::InlineMark();
   static StmtSRef root_mark = StmtSRef::RootMark();
   StmtSRef loop_sref = this->GetSRef(loop_rv);
@@ -605,7 +605,7 @@ void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopR
     TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_);
   } else {
     TVM_TIR_SCHEDULE_BEGIN();
-    tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops);
+    tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index);
     TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_);
   }
   this->state_->DebugVerify();
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index cdd0a5b7b0..92b9de4088 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -119,9 +119,10 @@ class ConcreteScheduleNode : public ScheduleNode {
   BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
                   BufferIndexType buffer_index_type) override;
   /******** Schedule: Compute location ********/
-  void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override;
-  void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
-                        bool preserve_unit_loops) override;
+  void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
+                 int index = -1) override;
+  void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
+                        int index = -1) override;
   void ComputeInline(const BlockRV& block) override;
   void ReverseComputeInline(const BlockRV& block) override;
   /******** Schedule: Reduction ********/
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 14203a0d16..05d9e4cf94 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -299,10 +299,13 @@ TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buf
  * \param self The schedule state
  * \param block_sref The block to be moved
  * \param loop_sref The loop where the block to be moved to
- * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+ * \param index The block index of the loop body subtree blocks:
+ * - `index = -1` means inserted into the last possible insertion point;
+ * - `index = -2` means inserted into the first possible insertion point;
+ * - Otherwise, `index` is a nonnegative number that indicates the insertion point
  */
 TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
-                       bool preserve_unit_loops);
+                       bool preserve_unit_loops, int index = -1);
 /*!
  * \brief Move a consumer block under the specific loop, and regenerate the
  * loops induced by the block so that the buffer region consumed by the consumer block could
@@ -318,9 +321,13 @@ TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const Stm
  * \param block_sref The block to be moved
  * \param loop_sref The loop where the block to be moved to
  * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+ * \param index The block index of the loop body subtree blocks:
+ * - `index = -1` means inserted into the last possible insertion point;
+ * - `index = -2` means inserted into the first possible insertion point;
+ * - Otherwise, `index` is a nonnegative number that indicates the insertion point
  */
 TVM_DLL void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref,
-                              const StmtSRef& loop_sref, bool preserve_unit_loops);
+                              const StmtSRef& loop_sref, bool preserve_unit_loops, int index = -1);
 /*!
  * \brief Inline a block into its consumer(s). It requires:
  * 1) The block is a complete non-root block, which only produces one buffer
diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc
index 98a6b2400e..8baedfd70d 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -129,15 +129,19 @@ class NotInSameScopeError : public ScheduleError {
  * \param producer_srefs The producer blocks
  * \param consumer_srefs The consumer blocks
  * \param block2realize A cache that maps a block to its realize
- * \return The last position the new block can be inserted onto, and the
+ * \param index The block index of the loop body subtree blocks:
+ * - `index = -1` means inserted into the last possible insertion point;
+ * - `index = -2` means inserted into the first possible insertion point;
+ * - Otherwise, `index` is a nonnegative number that indicates the insertion point
+ * \return The possible position the new block can be inserted into, and the
  * producer-consumer-relationship is still satisfied.
  * \throws ScheduleError if there is no such insertion point found
  */
 template <bool require_all_producers_visited, bool require_all_consumers_visited>
-int FindInsertionPoint(
-    const ScheduleState& self, const Array<Stmt>& subtrees, const Array<StmtSRef>& producer_srefs,
-    const Array<StmtSRef>& consumer_srefs,
-    std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize) {
+int FindInsertionPoint(const ScheduleState& self, const Array<Stmt>& subtrees,
+                       const Array<StmtSRef>& producer_srefs, const Array<StmtSRef>& consumer_srefs,
+                       std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize,
+                       int index) {
   ProducerConsumerSplit split =
       ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize);
   // Step 1. Check if all the producers are visited in the subtrees, if required to
@@ -159,8 +163,22 @@ int FindInsertionPoint(
   // Step 3. Check if there is at least one index of the position can be inserted into
   // The valid indices are: (last_producer_position, first_consumer_position]
   ICHECK(split.last_producer_position < split.first_consumer_position);
-  // Step 4. Return the last valid insertion point
-  return split.first_consumer_position;
+  // Step 4. Return the possible insertion point according to index
+  int insert_position;
+  if (index == -1) {
+    insert_position = split.first_consumer_position;
+  } else if (index == -2) {
+    insert_position = split.last_producer_position + 1;
+  } else if (index >= 0 && index >= split.last_producer_position + 1 &&
+             index <= split.first_consumer_position) {
+    insert_position = index;
+  } else {
+    LOG(FATAL) << "Valid index:(-1, -2, [" << split.last_producer_position + 1 << ", "
+               << split.first_consumer_position << "]), "
+               << "current index=" << index;
+    throw;
+  }
+  return insert_position;
 }
 
 /*!
@@ -556,7 +574,8 @@ void CalculateProvidedRequiredRegions(
 template <bool is_compute_at>
 void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
                                      const StmtSRef& loop_sref, bool preserve_unit_loops,
-                                     arith::Analyzer* analyzer, bool check_only = false) {
+                                     arith::Analyzer* analyzer, bool check_only = false,
+                                     int index = -1) {
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
   // Step 1. Bunch of checks
@@ -588,7 +607,8 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
       /*self=*/self,
       /*subtrees=*/AsArray(loop->body),
       /*producer_srefs=*/producer_srefs,
-      /*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize);
+      /*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize,
+      /*index=*/index);
   // Step 4. Calculate the region provided by a single execution instance of `block`,
   // as well as the region required by dependent blocks under `loop`.
   // Here is the definition of `provide` and `require`:
@@ -626,17 +646,17 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
 }
 
 void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
-               bool preserve_unit_loops) {
+               bool preserve_unit_loops, int index) {
   arith::Analyzer analyzer;
-  ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
-                                        &analyzer);
+  ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops, &analyzer,
+                                        false, index);
 }
 
 void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
-                      bool preserve_unit_loops) {
+                      bool preserve_unit_loops, int index) {
   arith::Analyzer analyzer;
   ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
-                                         &analyzer);
+                                         &analyzer, false, index);
 }
 
 bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
@@ -671,20 +691,21 @@ struct ComputeAtTraits : public UnpackedInstTraits<ComputeAtTraits> {
 
  private:
   static constexpr size_t kNumInputs = 2;
-  static constexpr size_t kNumAttrs = 1;
+  static constexpr size_t kNumAttrs = 2;
   static constexpr size_t kNumDecisions = 0;
 
   static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
-                                      Bool preserve_unit_loops) {
-    return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool());
+                                      Bool preserve_unit_loops, IntImm index) {
+    return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), index->value);
   }
 
   static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
-                                 Bool preserve_unit_loops) {
+                                 Bool preserve_unit_loops, IntImm index) {
     PythonAPICall py("compute_at");
     py.Input("block", block_rv);
     py.Input("loop", loop_rv);
     py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
+    py.Input("index", index);
     return py.Str();
   }
 
@@ -698,20 +719,22 @@ struct ReverseComputeAtTraits : public UnpackedInstTraits<ReverseComputeAtTraits
 
  private:
   static constexpr size_t kNumInputs = 2;
-  static constexpr size_t kNumAttrs = 1;
+  static constexpr size_t kNumAttrs = 2;
   static constexpr size_t kNumDecisions = 0;
 
   static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
-                                      Bool preserve_unit_loops) {
-    return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool());
+                                      Bool preserve_unit_loops, IntImm index) {
+    return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(),
+                                 index->value);
   }
 
   static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
-                                 Bool preserve_unit_loops) {
+                                 Bool preserve_unit_loops, IntImm index) {
     PythonAPICall py("reverse_compute_at");
     py.Input("block", block_rv);
     py.Input("loop", loop_rv);
     py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
+    py.Input("index", index);
     return py.Str();
   }
 
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 07d4da54d7..04ddc0507d 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -322,24 +322,25 @@ BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
 /******** Schedule: Compute location ********/
 
 void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
-                                   bool preserve_unit_loops) {
-  ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops);
+                                   bool preserve_unit_loops, int index) {
+  ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops, index);
 
   static const InstructionKind& kind = InstructionKind::Get("ComputeAt");
-  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
-                                      /*inputs=*/{block_rv, loop_rv},
-                                      /*attrs=*/{Integer(preserve_unit_loops)},
-                                      /*outputs=*/{}));
+  trace_->Append(
+      /*inst=*/Instruction(/*kind=*/kind,
+                           /*inputs=*/{block_rv, loop_rv},
+                           /*attrs=*/{Integer(preserve_unit_loops), Integer(index)},
+                           /*outputs=*/{}));
 }
 
 void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
-                                          bool preserve_unit_loops) {
-  ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops);
+                                          bool preserve_unit_loops, int index) {
+  ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops, index);
 
   static const InstructionKind& kind = InstructionKind::Get("ReverseComputeAt");
   trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
                                       /*inputs=*/{block_rv, loop_rv},
-                                      /*attrs=*/{Integer(preserve_unit_loops)},
+                                      /*attrs=*/{Integer(preserve_unit_loops), Integer(index)},
                                       /*outputs=*/{}));
 }
 
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 865a216879..d98e4ba4bb 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -79,9 +79,10 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
                   BufferIndexType buffer_index_type) final;
   /******** Schedule: Compute location ********/
-  void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final;
-  void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
-                        bool preserve_unit_loops) final;
+  void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
+                 int index = -1) final;
+  void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
+                        int index = -1) final;
   void ComputeInline(const BlockRV& block_rv) final;
   void ReverseComputeInline(const BlockRV& block_rv) final;
   /******** Schedule: Reduction ********/
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
index 5f76e77592..592d32d624 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
@@ -80,7 +80,7 @@ def test_gpu_softmax_mn():
             "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
             "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
             'sch.bind(loop=l6, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
+            "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)",
             'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
             "l7, l8, l9 = sch.get_loops(block=b0)",
             "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
@@ -93,7 +93,7 @@ def test_gpu_softmax_mn():
             "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
             "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
             'sch.bind(loop=l6, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
+            "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)",
             'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
             "l7, l8, l9 = sch.get_loops(block=b0)",
             "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
@@ -107,7 +107,7 @@ def test_gpu_softmax_mn():
             "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
             "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
             'sch.bind(loop=l7, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
+            "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True, index=-1)",
             'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
             "l8, l9, l10 = sch.get_loops(block=b1)",
             "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
@@ -117,7 +117,7 @@ def test_gpu_softmax_mn():
             "v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
             "l17, l18 = sch.split(loop=l15, factors=[None, v16], preserve_unit_iters=True)",
             'sch.bind(loop=l18, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True)",
+            "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True, index=-1)",
             'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
             "l19, l20, l21 = sch.get_loops(block=b0)",
             "l22, l23 = sch.split(loop=l21, factors=[None, v16], preserve_unit_iters=True)",
@@ -157,7 +157,7 @@ def test_gpu_softmax_mn_after_inline():
             "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
             "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
             'sch.bind(loop=l6, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
+            "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)",
             'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
             "l7, l8, l9 = sch.get_loops(block=b0)",
             "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
@@ -171,14 +171,14 @@ def test_gpu_softmax_mn_after_inline():
             "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
             "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
             'sch.bind(loop=l7, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
+            "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True, index=-1)",
             'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
             "l8, l9, l10 = sch.get_loops(block=b1)",
             "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
             'sch.bind(loop=l12, thread_axis="threadIdx.x")',
             "b13, b14 = sch.get_consumers(block=b0)",
             "l15, l16, l17, l18 = sch.get_loops(block=b13)",
-            "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)",
+            "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True, index=-1)",
             'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
             "l19, l20, l21 = sch.get_loops(block=b0)",
             "l22, l23 = sch.split(loop=l21, factors=[None, v5], preserve_unit_iters=True)",
@@ -206,7 +206,7 @@ def test_gpu_batch_norm_bmn():
             "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
             "l4, l5 = sch.split(loop=l2, factors=[None, v3], preserve_unit_iters=True)",
             'sch.bind(loop=l5, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True)",
+            "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True, index=-1)",
             'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
             "l6, l7, l8, l9 = sch.get_loops(block=b0)",
             "l10 = sch.fuse(l8, l9, preserve_unit_iters=True)",
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index 87159fcb31..fe1220c509 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -62,7 +62,7 @@ def test_cpu_matmul():
             "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
             "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
             'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
-            "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)",
+            "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True, index=-1)",
         ],
         [
             'b0 = sch.get_block(name="C", func_name="main")',
@@ -76,7 +76,7 @@ def test_cpu_matmul():
             "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
             "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
             'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
-            "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)",
+            "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True, index=-1)",
         ],
         [
             'b0 = sch.get_block(name="C", func_name="main")',
@@ -123,7 +123,7 @@ def test_cpu_matmul_relu():
             "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
             "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
             "b24, = sch.get_consumers(block=b0)",
-            "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)",
+            "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True, index=-1)",
         ],
         [
             'b0 = sch.get_block(name="C", func_name="main")',
@@ -137,7 +137,7 @@ def test_cpu_matmul_relu():
             "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
             "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
             "b24, = sch.get_consumers(block=b0)",
-            "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)",
+            "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True, index=-1)",
         ],
         [
             'b0 = sch.get_block(name="C", func_name="main")',
@@ -193,15 +193,15 @@ def test_cuda_matmul():
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)',
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)',
             'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
-            "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)",
+            "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True, index=-1)",
             'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")',
-            "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)",
+            "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True, index=-1)",
             "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
             "l41 = sch.fuse(l39, l40, preserve_unit_iters=True)",
             "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
             'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
             'b43 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
-            "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)",
+            "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True, index=-1)",
             "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
             "l50 = sch.fuse(l48, l49, preserve_unit_iters=True)",
             "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
@@ -247,15 +247,15 @@ def test_cuda_matmul_relu():
             "l32 = sch.fuse(l11, l21, preserve_unit_iters=True)",
             'sch.bind(loop=l32, thread_axis="threadIdx.x")',
             'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
-            "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)",
+            "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True, index=-1)",
             'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")',
-            "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)",
+            "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True, index=-1)",
             "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
             "l41 = sch.fuse(l39, l40, preserve_unit_iters=True)",
             "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
             'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
             'b43 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
-            "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)",
+            "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True, index=-1)",
             "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
             "l50 = sch.fuse(l48, l49, preserve_unit_iters=True)",
             "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
@@ -402,7 +402,7 @@ v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
 l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True)
 sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)
 b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global")
-sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True)""".split(
+sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True, index=-1)""".split(
             "\n"
         ),
         """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
@@ -437,7 +437,7 @@ v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
 l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True)
 sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)
 b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global")
-sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True)""".split(
+sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True, index=-1)""".split(
             "\n"
         ),
         """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
@@ -546,15 +546,15 @@ sch.bind(loop=l37, thread_axis="vthread.x")
 l38 = sch.fuse(l17, l27, preserve_unit_iters=True)
 sch.bind(loop=l38, thread_axis="threadIdx.x")
 b39 = sch.cache_write(block=b6, write_buffer_index=0, storage_scope="local")
-sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True, index=-1)
 b40 = sch.cache_read(block=b6, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True)
+sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True, index=-1)
 l41, l42, l43, l44, l45, l46 = sch.get_loops(block=b40)
 l47 = sch.fuse(l45, l46, preserve_unit_iters=True)
 v48 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b40, ann_key="meta_schedule.cooperative_fetch", ann_val=v48)
 b49 = sch.cache_read(block=b6, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True)
+sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True, index=-1)
 l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b49)
 l56 = sch.fuse(l54, l55, preserve_unit_iters=True)
 v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
@@ -632,9 +632,9 @@ sch.bind(loop=l51, thread_axis="blockIdx.x")
 l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
 sch.bind(loop=l52, thread_axis="threadIdx.y")
 b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared")
-sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True, index=-1)
 b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True, index=-1)
 v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)
 sch.reverse_compute_inline(block=b2)
@@ -646,19 +646,19 @@ sch.reorder(l70, l64, l62)
 b72 = sch.blockize(loop=l64)
 sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
 b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True, index=-1)
 l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73)
 l80 = sch.fuse(l78, l79, preserve_unit_iters=True)
 v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
 b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True, index=-1)
 l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82)
 l89 = sch.fuse(l87, l88, preserve_unit_iters=True)
 v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90)
 b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True)
+sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True, index=-1)
 l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91)
 l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True)
 l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True)
@@ -667,7 +667,7 @@ sch.reorder(l110, l102, l100)
 b112 = sch.blockize(loop=l102)
 sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
 b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True)
+sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True, index=-1)
 l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113)
 l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True)
 l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True)
@@ -772,9 +772,9 @@ sch.bind(loop=l51, thread_axis="blockIdx.x")
 l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
 sch.bind(loop=l52, thread_axis="threadIdx.y")
 b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared")
-sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True, index=-1)
 b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True, index=-1)
 v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)
 sch.reverse_compute_inline(block=b2)
@@ -786,19 +786,19 @@ sch.reorder(l70, l64, l62)
 b72 = sch.blockize(loop=l64)
 sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
 b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True, index=-1)
 l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73)
 l80 = sch.fuse(l78, l79, preserve_unit_iters=True)
 v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
 b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True, index=-1)
 l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82)
 l89 = sch.fuse(l87, l88, preserve_unit_iters=True)
 v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90)
 b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True)
+sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True, index=-1)
 l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91)
 l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True)
 l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True)
@@ -807,7 +807,7 @@ sch.reorder(l110, l102, l100)
 b112 = sch.blockize(loop=l102)
 sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
 b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True)
+sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True, index=-1)
 l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113)
 l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True)
 l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True)
@@ -895,7 +895,7 @@ sch.bind(loop=l50, thread_axis="blockIdx.x")
 l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
 sch.bind(loop=l51, thread_axis="threadIdx.y")
 b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True, index=-1)
 sch.reverse_compute_inline(block=b1)
 l53, l54, l55, l56, l57 = sch.get_loops(block=b52)
 l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True)
@@ -905,19 +905,19 @@ sch.reorder(l67, l61, l59)
 b69 = sch.blockize(loop=l61)
 sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global")
 b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True)
+sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True, index=-1)
 l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70)
 l77 = sch.fuse(l75, l76, preserve_unit_iters=True)
 v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78)
 b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True)
+sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True, index=-1)
 l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
 l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
 v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
 b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True, index=-1)
 l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88)
 l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True)
 l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True)
@@ -926,7 +926,7 @@ sch.reorder(l107, l99, l97)
 b109 = sch.blockize(loop=l99)
 sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
 b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True, index=-1)
 l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110)
 l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True)
 l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True)
@@ -995,7 +995,7 @@ sch.bind(loop=l50, thread_axis="blockIdx.x")
 l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
 sch.bind(loop=l51, thread_axis="threadIdx.y")
 b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True, index=-1)
 sch.reverse_compute_inline(block=b1)
 l53, l54, l55, l56, l57 = sch.get_loops(block=b52)
 l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True)
@@ -1005,19 +1005,19 @@ sch.reorder(l67, l61, l59)
 b69 = sch.blockize(loop=l61)
 sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global")
 b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True)
+sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True, index=-1)
 l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70)
 l77 = sch.fuse(l75, l76, preserve_unit_iters=True)
 v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78)
 b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True)
+sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True, index=-1)
 l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
 l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
 v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
 b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True, index=-1)
 l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88)
 l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True)
 l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True)
@@ -1026,7 +1026,7 @@ sch.reorder(l107, l99, l97)
 b109 = sch.blockize(loop=l99)
 sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
 b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True)
+sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True, index=-1)
 l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110)
 l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True)
 l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True)
@@ -1133,9 +1133,9 @@ sch.bind(loop=l63, thread_axis="blockIdx.x")
 l64 = sch.fuse(l33, l43, l53, preserve_unit_iters=True)
 sch.bind(loop=l64, thread_axis="threadIdx.y")
 b65 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="shared")
-sch.reverse_compute_at(block=b65, loop=l63, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b65, loop=l63, preserve_unit_loops=True, index=-1)
 b66 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="wmma.accumulator")
-sch.reverse_compute_at(block=b66, loop=l64, preserve_unit_loops=True)
+sch.reverse_compute_at(block=b66, loop=l64, preserve_unit_loops=True, index=-1)
 v67 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b65, ann_key="meta_schedule.cooperative_fetch", ann_val=v67)
 sch.reverse_compute_inline(block=b1)
@@ -1147,19 +1147,19 @@ sch.reorder(l82, l76, l74)
 b84 = sch.blockize(loop=l76)
 sch.annotate(block_or_loop=b84, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
 b85 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="shared")
-sch.compute_at(block=b85, loop=l59, preserve_unit_loops=True)
+sch.compute_at(block=b85, loop=l59, preserve_unit_loops=True, index=-1)
 l86, l87, l88, l89, l90, l91 = sch.get_loops(block=b85)
 l92 = sch.fuse(l90, l91, preserve_unit_iters=True)
 v93 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v93)
 b94 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="shared")
-sch.compute_at(block=b94, loop=l59, preserve_unit_loops=True)
+sch.compute_at(block=b94, loop=l59, preserve_unit_loops=True, index=-1)
 l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b94)
 l101 = sch.fuse(l99, l100, preserve_unit_iters=True)
 v102 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b94, ann_key="meta_schedule.cooperative_fetch", ann_val=v102)
 b103 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="wmma.matrix_a")
-sch.compute_at(block=b103, loop=l60, preserve_unit_loops=True)
+sch.compute_at(block=b103, loop=l60, preserve_unit_loops=True, index=-1)
 l104, l105, l106, l107, l108, l109, l110 = sch.get_loops(block=b103)
 l111, l112 = sch.split(loop=l110, factors=[None, 16], preserve_unit_iters=True)
 l113, l114 = sch.split(loop=l109, factors=[None, 16], preserve_unit_iters=True)
@@ -1168,7 +1168,7 @@ sch.reorder(l122, l114, l112)
 b124 = sch.blockize(loop=l114)
 sch.annotate(block_or_loop=b124, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
 b125 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b125, loop=l60, preserve_unit_loops=True)
+sch.compute_at(block=b125, loop=l60, preserve_unit_loops=True, index=-1)
 l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b125)
 l133, l134 = sch.split(loop=l132, factors=[None, 16], preserve_unit_iters=True)
 l135, l136 = sch.split(loop=l131, factors=[None, 16], preserve_unit_iters=True)
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py
index b2df408e9d..c951a5adf3 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py
@@ -71,7 +71,7 @@ def test_random_compute_location():
         [
             'b0 = sch.get_block(name="move", func_name="main")',
             "l1 = sch.sample_compute_location(block=b0)",
-            "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True)",
+            "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True, index=-1)",
         ]
     ]
     mod = Add
diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py
index 0c20a4783c..72cba1a8fd 100644
--- a/tests/python/unittest/test_tir_schedule_compute_at.py
+++ b/tests/python/unittest/test_tir_schedule_compute_at.py
@@ -1353,5 +1353,157 @@ def test_compute_at_int64_loop(use_block_name):
     verify_trace_roundtrip(sch=sch, mod=mod)
 
 
+def test_compute_at_to_index():
+    @T.prim_func
+    def multi_producers_conv(
+        data: T.Buffer[(1, 3, 224, 224), "int8"],
+        w: T.Buffer[(16, 3, 7, 7), "int8"],
+        conv: T.Buffer[(1, 16, 112, 112), "int32"],
+    ) -> None:
+        pad = T.alloc_buffer([1, 3, 230, 230], dtype="int8")
+        wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8")
+        for i0, i1, i2, i3 in T.grid(1, 3, 230, 230):
+            with T.block("pad"):
+                i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(data[i0_1, i1_1, i2_1 - 3, i3_1 - 3])
+                T.writes(pad[i0_1, i1_1, i2_1, i3_1])
+                pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+                    3 <= i2_1 and i2_1 < 227 and 3 <= i3_1 and i3_1 < 227,
+                    data[i0_1, i1_1, i2_1 - 3, i3_1 - 3],
+                    T.int8(0),
+                    dtype="int8",
+                )
+        for i0 in T.serial(1):
+            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7):
+                with T.block("wbuf"):
+                    v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(w[v0, v1, v2, v3])
+                    T.writes(wbuf[v0, v1, v2, v3])
+                    wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3]
+            for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7):
+                with T.block("conv"):
+                    nn, ff, yy, xx, rc, ry, rx = T.axis.remap(
+                        "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]
+                    )
+                    T.reads(pad[nn, rc, yy * 2 + ry, xx * 2 + rx], wbuf[ff, rc, ry, rx])
+                    T.writes(conv[nn, ff, yy, xx])
+                    with T.init():
+                        conv[nn, ff, yy, xx] = 0
+                    conv[nn, ff, yy, xx] = conv[nn, ff, yy, xx] + T.cast(
+                        pad[nn, rc, yy * 2 + ry, xx * 2 + rx], "int32"
+                    ) * T.cast(wbuf[ff, rc, ry, rx], "int32")
+
+    @T.prim_func
+    def multi_producers_after_compute_at(
+        data: T.Buffer[(1, 3, 224, 224), "int8"],
+        w: T.Buffer[(16, 3, 7, 7), "int8"],
+        conv: T.Buffer[(1, 16, 112, 112), "int32"],
+    ) -> None:
+        pad = T.alloc_buffer([1, 3, 230, 230], dtype="int8")
+        wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8")
+        for i0 in T.serial(1):
+            for ax0, ax1, ax2 in T.grid(3, 229, 229):
+                with T.block("pad"):
+                    i0_1 = T.axis.spatial(1, 0)
+                    i1_1 = T.axis.spatial(3, ax0)
+                    i2_1 = T.axis.spatial(230, ax1)
+                    i3_1 = T.axis.spatial(230, ax2)
+                    T.reads(data[i0_1, i1_1, i2_1 - 3, i3_1 - 3])
+                    T.writes(pad[i0_1, i1_1, i2_1, i3_1])
+                    pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+                        3 <= i2_1 and i2_1 < 227 and 3 <= i3_1 and i3_1 < 227,
+                        data[i0_1, i1_1, i2_1 - 3, i3_1 - 3],
+                        T.int8(0),
+                        dtype="int8",
+                    )
+            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7):
+                with T.block("wbuf"):
+                    v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(w[v0, v1, v2, v3])
+                    T.writes(wbuf[v0, v1, v2, v3])
+                    wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3]
+            for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7):
+                with T.block("conv"):
+                    nn, ff, yy, xx, rc, ry, rx = T.axis.remap(
+                        "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]
+                    )
+                    T.reads(pad[nn, rc, yy * 2 + ry, xx * 2 + rx], wbuf[ff, rc, ry, rx])
+                    T.writes(conv[nn, ff, yy, xx])
+                    with T.init():
+                        conv[nn, ff, yy, xx] = 0
+                    conv[nn, ff, yy, xx] = conv[nn, ff, yy, xx] + T.cast(
+                        pad[nn, rc, yy * 2 + ry, xx * 2 + rx], "int32"
+                    ) * T.cast(wbuf[ff, rc, ry, rx], "int32")
+
+    sch = tir.Schedule(multi_producers_conv, debug_mask="all")
+    block_c = sch.get_block("pad")
+    axis = sch.get_loops("conv")[0]
+    sch.compute_at(block_c, axis, index=-2)
+    tvm.ir.assert_structural_equal(multi_producers_after_compute_at, sch.mod["main"])
+
+
+def test_reverse_compute_at_to_index():
+    @T.prim_func
+    def main(A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(128, 128), "float32"]) -> None:
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        C = T.alloc_buffer([128, 128], dtype="float32")
+        for i_0, j_0, i_1 in T.grid(8, 8, 16):
+            for j_1 in T.serial(16):
+                with T.block("B"):
+                    vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                    vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                    T.reads(A[vi, vj])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] * T.float32(2)
+            for ax0 in T.serial(16):
+                with T.block("C"):
+                    vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                    vj = T.axis.spatial(128, j_0 * 16 + ax0)
+                    T.reads(B[vi, vj])
+                    T.writes(C[vi, vj])
+                    C[vi, vj] = B[vi, vj] + T.float32(1)
+        for i, j in T.grid(128, 128):
+            with T.block("D"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(B[vi, vj])
+                T.writes(D[vi, vj])
+                D[vi, vj] = B[vi, vj] + T.float32(1)
+
+    @T.prim_func
+    def main_reverse_compute_at(
+        A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(128, 128), "float32"]
+    ) -> None:
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        C = T.alloc_buffer([128, 128], dtype="float32")
+        for i_0, j_0, i_1 in T.grid(8, 8, 16):
+            for j_1 in T.serial(16):
+                with T.block("B"):
+                    vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                    vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                    T.reads(A[vi, vj])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] * T.float32(2)
+            for ax0 in T.serial(16):
+                with T.block("D"):
+                    vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                    vj = T.axis.spatial(128, j_0 * 16 + ax0)
+                    T.reads(B[vi, vj])
+                    T.writes(D[vi, vj])
+                    D[vi, vj] = B[vi, vj] + T.float32(1)
+            for ax0 in T.serial(16):
+                with T.block("C"):
+                    vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                    vj = T.axis.spatial(128, j_0 * 16 + ax0)
+                    T.reads(B[vi, vj])
+                    T.writes(C[vi, vj])
+                    C[vi, vj] = B[vi, vj] + T.float32(1)
+
+    sch = tir.Schedule(main, debug_mask="all")
+    block_c = sch.get_block("D")
+    axis = sch.get_loops("B")[2]
+    sch.reverse_compute_at(block_c, axis, index=1)
+    tvm.ir.assert_structural_equal(main_reverse_compute_at, sch.mod["main"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()