Posted to commits@tvm.apache.org by sy...@apache.org on 2022/06/16 05:42:17 UTC

[tvm] branch main updated: [TIR] Add preserve-unit-iters (#11585)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 89e1a6c3f2 [TIR] Add preserve-unit-iters (#11585)
89e1a6c3f2 is described below

commit 89e1a6c3f2bbdaa3f585459cefbc7612ae46b1ad
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Jun 15 22:42:12 2022 -0700

    [TIR] Add preserve-unit-iters (#11585)
---
 include/tvm/tir/schedule/schedule.h                |   7 +-
 python/tvm/tir/schedule/schedule.py                |  21 ++-
 src/tir/schedule/concrete_schedule.cc              |   9 +-
 src/tir/schedule/concrete_schedule.h               |   5 +-
 src/tir/schedule/primitive.h                       |   7 +-
 src/tir/schedule/primitive/loop_transformation.cc  |  64 +++++----
 src/tir/schedule/traced_schedule.cc                |  13 +-
 src/tir/schedule/traced_schedule.h                 |   5 +-
 .../unittest/test_meta_schedule_integration.py     |   9 +-
 .../test_meta_schedule_post_order_apply.py         |   8 +-
 ...test_meta_schedule_schedule_rule_add_rfactor.py |   4 +-
 .../test_meta_schedule_schedule_rule_auto_bind.py  |  12 +-
 ...chedule_schedule_rule_cross_thread_reduction.py |  38 ++---
 ...ta_schedule_schedule_rule_multi_level_tiling.py | 158 ++++++++++-----------
 tests/python/unittest/test_tir_schedule_trace.py   |   4 +-
 15 files changed, 202 insertions(+), 162 deletions(-)
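
The commit threads a new preserve_unit_iters flag (default: true) through the
fuse and split schedule primitives, controlling whether unit-extent iterators
are kept in block bindings instead of being simplified away. Below is a
minimal sketch of the user-facing behavior; the elemwise function is
illustrative and not part of this commit.

    import tvm
    from tvm import tir
    from tvm.script import tir as T

    @T.prim_func
    def elemwise(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, (128,), "float32")
        B = T.match_buffer(b, (128,), "float32")
        for i in T.serial(0, 128):
            with T.block("B"):
                vi = T.axis.spatial(128, i)
                B[vi] = A[vi] + 1.0

    sch = tir.Schedule(elemwise)
    (i,) = sch.get_loops(sch.get_block("B"))
    # With preserve_unit_iters=True (the new default), the unit-extent outer
    # loop keeps its iterator in the block binding: vi = i_0 * 128 + i_1.
    # With False, IterMapSimplify folds it away, leaving vi = i_1.
    i0, i1 = sch.split(i, factors=[1, 128], preserve_unit_iters=True)
    print(sch.mod.script())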

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index d3ecd8a113..d95a9d4e7e 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -277,9 +277,10 @@ class ScheduleNode : public runtime::Object {
    * 3) All loops must start with 0.
    * 4) The domain of a loop to be fused cannot depend on another loop to be fused.
    * \param loop_rvs The loops to be fused
+   * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
    * \return The new loop after fusion
    */
-  virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs) = 0;
+  virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters = true) = 0;
   /*!
    * \brief Split a loop into a list of consecutive loops. It requires:
    * 1) The loop can't have annotation or thread binding.
@@ -287,9 +288,11 @@ class ScheduleNode : public runtime::Object {
    * \param loop_rv The loop to be split
   * \param factors The positive tiling factors, at most one of which may be `NullOpt`,
   * meaning that factor is inferred.
+   * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
    * \return The new loops after split
    */
-  virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
+  virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
+                              bool preserve_unit_iters = true) = 0;
   /*!
    * \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
    * It requires:
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index d29495c430..7a1e244604 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -495,7 +495,11 @@ class Schedule(Object):
 
     ########## Schedule: Transform loops ##########
     @type_checked
-    def fuse(self, *loops: List[LoopRV]) -> LoopRV:
+    def fuse(
+        self,
+        *loops: List[LoopRV],
+        preserve_unit_iters: bool = True,
+    ) -> LoopRV:
         """Fuse a list of consecutive loops into one. It requires:
         1) The loops can't have annotations or thread bindings.
         2) The (i+1)-th loop must be the only child of the i-th loop.
@@ -553,13 +557,14 @@ class Schedule(Object):
                         B[vi, vj] = A[vi, vj] * 2.0
 
         """
-        return _ffi_api.ScheduleFuse(self, loops)  # type: ignore # pylint: disable=no-member
+        return _ffi_api.ScheduleFuse(self, loops, preserve_unit_iters)  # type: ignore # pylint: disable=no-member
 
     @type_checked
     def split(
         self,
         loop: LoopRV,
         factors: List[Union[int, ExprRV, None]],
+        preserve_unit_iters: bool = True,
     ) -> List[LoopRV]:
         """Split a loop into a list of consecutive loops. It requires:
         1) The loop can't have annotation or thread binding.
@@ -580,6 +585,9 @@ class Schedule(Object):
             - ExprRV
             - Positive constant integers
 
+        preserve_unit_iters : bool
+            Whether or not to preserve unit iterators in block bindings
+
         Returns
         -------
         split_loops : List[LoopRV]
@@ -628,7 +636,14 @@ class Schedule(Object):
         """
         # it will be checked later in C++ implementation
         # that there is at most one None in `factors`
-        return list(_ffi_api.ScheduleSplit(self, loop, factors))  # type: ignore # pylint: disable=no-member
+        return list(
+            _ffi_api.ScheduleSplit(  # type: ignore # pylint: disable=no-member
+                self,
+                loop,
+                factors,
+                preserve_unit_iters,
+            )
+        )
 
     @type_checked
     def reorder(self, *ordered_loops: List[LoopRV]) -> None:
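
Note on the Python surface: for fuse() the new flag is keyword-only (it
follows *loops), while split() also accepts it positionally. Continuing the
sketch from above:

    # Fuse the two loops back and re-split, dropping unit iterators this time.
    f = sch.fuse(i0, i1, preserve_unit_iters=True)
    lo, li = sch.split(f, factors=[None, 32], preserve_unit_iters=False)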
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 051bd42506..b2f48753b5 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -333,19 +333,20 @@ Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {
 
 /******** Schedule: Transform loops ********/
 
-LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
+LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) {
   CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop";
   Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
   StmtSRef result{nullptr};
   TVM_TIR_SCHEDULE_BEGIN();
-  result = tir::Fuse(state_, loop_srefs);
+  result = tir::Fuse(state_, loop_srefs, preserve_unit_iters);
   TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_);
   this->state_->DebugVerify();
   return CreateRV<LoopRV>(result);
 }
 
 Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
-                                          const Array<Optional<ExprRV>>& factor_rvs) {
+                                          const Array<Optional<ExprRV>>& factor_rvs,
+                                          bool preserve_unit_iters) {
   class NotSingleInferFactorError : public ScheduleError {
    public:
     explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}
@@ -440,7 +441,7 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
   } else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) {
     throw WrongFactorProductError(state_->mod, GetRef<For>(loop));
   }
-  results = tir::Split(state_, loop_sref, factors);
+  results = tir::Split(state_, loop_sref, factors, preserve_unit_iters);
   TVM_TIR_SCHEDULE_END("split", this->error_render_level_);
   this->state_->DebugVerify();
   return CreateRV<LoopRV>(results);
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 11d68694a1..dfbacb530a 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -96,8 +96,9 @@ class ConcreteScheduleNode : public ScheduleNode {
   Array<BlockRV> GetProducers(const BlockRV& block_rv) override;
   Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
   /******** Schedule: Transform loops ********/
-  LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
-  Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
+  LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) override;
+  Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
+                      bool preserve_unit_iters) override;
   void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;
   LoopRV AddUnitLoop(const BlockRV& block_rv) override;
   LoopRV AddUnitLoop(const LoopRV& loop_rv) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index af0f417e4c..212571df10 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -156,10 +156,11 @@ Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sr
  * \param self The state of the schedule
  * \param loop_sref The sref to the loop being split
  * \param factors The splitting factors
+ * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
  * \return An array of srefs to the loops after splitting
  */
 TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
-                              const Array<PrimExpr>& factors);
+                              const Array<PrimExpr>& factors, bool preserve_unit_iters);
 /*!
  * \brief Fuse a list of consecutive loops into one. It requires:
  * 1) The loops can't have annotations or thread bindings.
@@ -168,9 +169,11 @@ TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
  * 4) The domain of a loop to be fused cannot depend on another loop to be fused.
  * \param self The state of the schedule
  * \param loop_srefs An array of srefs to the loops to be fused
+ * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
  * \return The sref to the fused loop
  */
-TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs);
+TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs,
+                      bool preserve_unit_iters);
 /*!
  * \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
  * It requires:
diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc
index bb505bca33..f1b6f46e1b 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -77,18 +77,21 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator {
 /*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */
 class IterMapSimplifyBlockBinding : public StmtExprMutator {
  public:
-  explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map<Var, Range> loop_var2extent)
-      : opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent) {}
-
-  static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs,
-                              MapNode* opaque_blocks) {
+  explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map<Var, Range> loop_var2extent,
+                                       bool preserve_unit_iters)
+      : opaque_blocks_(opaque_blocks),
+        loop_var2extent_(loop_var2extent),
+        preserve_unit_iters_(preserve_unit_iters) {}
+
+  static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs, MapNode* opaque_blocks,
+                              bool preserve_unit_iters) {
     Map<Var, Range> loop_var2extent;
     for (const StmtSRef& sref : loop_srefs) {
       const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
       loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
     }
-    return Downcast<For>(
-        IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent))(std::move(stmt)));
+    return Downcast<For>(IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent),
+                                                     preserve_unit_iters)(std::move(stmt)));
   }
 
  private:
@@ -112,11 +115,12 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
       }
       return std::move(realize);
     }
-    Array<PrimExpr> v = arith::IterMapSimplify(/*indices=*/op->iter_values,
-                                               /*input_iters=*/loop_var2extent_,
-                                               /*input_pred=*/op->predicate,
-                                               /*check_level=*/arith::IterMapLevel::Surjective,
-                                               /*simplify_trivial_iterators=*/false);
+    Array<PrimExpr> v =
+        arith::IterMapSimplify(/*indices=*/op->iter_values,
+                               /*input_iters=*/loop_var2extent_,
+                               /*input_pred=*/op->predicate,
+                               /*check_level=*/arith::IterMapLevel::Surjective,
+                               /*simplify_trivial_iterators=*/!preserve_unit_iters_);
     if (v.same_as(op->iter_values)) {
       return GetRef<Stmt>(op);
     } else {
@@ -130,6 +134,8 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
   MapNode* opaque_blocks_;
   /*! \brief The range of loops */
   Map<Var, Range> loop_var2extent_;
+  /*! \brief Whether to preserve unit iterators in block bindings */
+  bool preserve_unit_iters_;
 };
 
 class BlockPropertyError : public ScheduleError {
@@ -376,8 +382,8 @@ class DependentLoopError : public ScheduleError {
   PrimitiveKind kind_;
 };
 
-Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
-                      const Array<PrimExpr>& factors) {
+Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array<PrimExpr>& factors,
+                      bool preserve_unit_iters) {
   // Invariance
   // - The total repeat number has not changed for each direct child block with updating predicate.
   // - The execution order has not changed. (The block executes with the same args and the same
@@ -432,7 +438,8 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
     new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt);
   }
   new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref),
-                                                           opaque_block_reuse.CopyOnWrite());
+                                                           opaque_block_reuse.CopyOnWrite(),
+                                                           preserve_unit_iters);
   self->Replace(loop_sref, new_stmt, opaque_block_reuse);
   Array<StmtSRef> result_srefs;
   result_srefs.reserve(n);
@@ -444,7 +451,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
   return result_srefs;
 }
 
-StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
+StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
   // Invariance
   // - The total repeat number has not changed for each direct child block.
   // - The execution order has not changed. (The block executes with the same
@@ -527,7 +534,8 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
   fused_extent = analyzer.Simplify(fused_extent);
   new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt);
   new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(
-      std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite());
+      std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite(),
+      preserve_unit_iters);
   self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
   return self->stmt2ref.at(new_stmt.get());
 }
@@ -755,7 +763,7 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
 
  private:
   static constexpr size_t kNumInputs = 2;
-  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumAttrs = 1;
   static constexpr size_t kNumDecisions = 0;
 
   template <size_t delta>
@@ -770,14 +778,17 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
   }
 
   static Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv,
-                                               Array<Optional<ExprRV>> factors) {
-    return sch->Split(loop_rv, factors);
+                                               Array<Optional<ExprRV>> factors,
+                                               Bool preserve_unit_iters) {
+    return sch->Split(loop_rv, factors, preserve_unit_iters.operator bool());
   }
 
-  static String UnpackedAsPython(Array<String> outputs, String loop_rv, Array<ObjectRef> factors) {
+  static String UnpackedAsPython(Array<String> outputs, String loop_rv, Array<ObjectRef> factors,
+                                 Bool preserve_unit_iters) {
     PythonAPICall py("split");
     py.Input("loop", loop_rv);
     py.Input("factors", factors);
+    py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
     py.OutputList(outputs);
     return py.Str();
   }
@@ -792,7 +803,7 @@ struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
 
  private:
   static constexpr size_t kNumInputs = 1;
-  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumAttrs = 1;
   static constexpr size_t kNumDecisions = 0;
 
   template <size_t delta>
@@ -801,15 +812,18 @@ struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
     setter(delta, inputs);
   }
 
-  static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) {
-    return sch->Fuse(loop_rvs);
+  static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs,
+                                        Bool preserve_unit_iters) {
+    return sch->Fuse(loop_rvs, preserve_unit_iters.operator bool());
   }
 
-  static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) {
+  static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs,
+                                 Bool preserve_unit_iters) {
     PythonAPICall py("fuse");
     for (const String& loop_rv : loop_rvs) {
       py.Input("", loop_rv);
     }
+    py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
     py.SingleOutput(outputs);
     return py.Str();
   }
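
The mechanical core of the change is above: IterMapSimplifyBlockBinding now
forwards simplify_trivial_iterators = !preserve_unit_iters to
arith::IterMapSimplify. A hedged side-by-side sketch of the observable
effect, reusing the elemwise function from the first example:

    # Two schedules over the same workload, differing only in the flag.
    sch_keep = tir.Schedule(elemwise)
    (i,) = sch_keep.get_loops(sch_keep.get_block("B"))
    sch_keep.split(i, factors=[1, 128], preserve_unit_iters=True)

    sch_fold = tir.Schedule(elemwise)
    (i,) = sch_fold.get_loops(sch_fold.get_block("B"))
    sch_fold.split(i, factors=[1, 128], preserve_unit_iters=False)

    # sch_keep binds vi = i_0 * 128 + i_1, keeping the trivial iterator;
    # sch_fold simplifies the binding to vi = i_1.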
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 95a10e26ac..733b5d872f 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -158,20 +158,21 @@ Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {
 
 /******** Schedule: Transform loops ********/
 
-LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
-  LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs);
+LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) {
+  LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs, preserve_unit_iters);
 
   static const InstructionKind& kind = InstructionKind::Get("Fuse");
   trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
                                       /*inputs=*/{loop_rvs.begin(), loop_rvs.end()},
-                                      /*attrs=*/{},
+                                      /*attrs=*/{Integer(preserve_unit_iters)},
                                       /*outputs=*/{result}));
   return result;
 }
 
 Array<LoopRV> TracedScheduleNode::Split(const LoopRV& loop_rv,
-                                        const Array<Optional<ExprRV>>& factor_rvs) {
-  Array<LoopRV> results = ConcreteScheduleNode::Split(loop_rv, factor_rvs);
+                                        const Array<Optional<ExprRV>>& factor_rvs,
+                                        bool preserve_unit_iters) {
+  Array<LoopRV> results = ConcreteScheduleNode::Split(loop_rv, factor_rvs, preserve_unit_iters);
 
   std::vector<ObjectRef> inputs;
   inputs.reserve(1 + factor_rvs.size());
@@ -183,7 +184,7 @@ Array<LoopRV> TracedScheduleNode::Split(const LoopRV& loop_rv,
   static const InstructionKind& kind = InstructionKind::Get("Split");
   trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
                                       /*inputs=*/inputs,
-                                      /*attrs=*/{},
+                                      /*attrs=*/{Integer(preserve_unit_iters)},
                                       /*outputs=*/{results.begin(), results.end()}));
   return results;
 }
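
Because the flag is now an instruction attribute, traced schedules serialize
it, which is why every expected trace string in the test updates below gains
an explicit preserve_unit_iters=True. A small sketch, assuming the default
tir.Schedule constructor produces a traced schedule, reusing elemwise:

    sch = tir.Schedule(elemwise)
    (i,) = sch.get_loops(sch.get_block("B"))
    sch.split(i, factors=[1, 128])
    print(sch.trace)
    # The printed trace should include a line along the lines of:
    #   l2, l3 = sch.split(loop=l1, factors=[1, 128], preserve_unit_iters=True)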
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 25bf3d4871..178026d9ea 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -60,8 +60,9 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   Array<BlockRV> GetProducers(const BlockRV& block_rv) final;
   Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
   /******** Schedule: Transform loops ********/
-  LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
-  Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
+  LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) final;
+  Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs,
+                      bool preserve_unit_iters) final;
   void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;
   LoopRV AddUnitLoop(const BlockRV& block_rv) final;
   LoopRV AddUnitLoop(const LoopRV& loop_rv) final;
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index f2802b41eb..6d5016cd81 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -193,6 +193,7 @@ def test_meta_schedule_integration_extract_from_bert_base():
 @requires_torch
 def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
     def filter_func(args):
+        from tvm.te import create_prim_func  # pylint: disable=import-outside-toplevel
 
         has_complex_op = False
         visited = set()
@@ -205,16 +206,16 @@ def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
             if isinstance(t.op, te.PlaceholderOp):
                 pass
             elif isinstance(t.op, te.ComputeOp):
-                has_complex_op = has_complex_op or any(
-                    [isinstance(e, tir.Reduce) for e in t.op.body]
-                )
+                has_complex_op = has_complex_op or any(isinstance(e, tir.Reduce) for e in t.op.body)
                 for x in t.op.input_tensors:
                     traverse(x)
             visited.add(t.handle.value)
 
         for t in args:
             traverse(t)
-        return has_complex_op
+        if not has_complex_op:
+            return None
+        return create_prim_func(args)
 
     mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
     extracted_tasks = ms.extract_task_from_relay(
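
Aside from the trace strings, this test change also updates the filter
callback contract: it now returns None to skip a task, or a PrimFunc built
via te.create_prim_func to keep it, instead of a bool. A simplified sketch
of the new contract that only inspects the top-level tensors (the original
traverses producers transitively):

    from typing import List, Optional
    from tvm import te, tir
    from tvm.te import create_prim_func

    def keep_only_reductions(args: List[te.Tensor]) -> Optional[tir.PrimFunc]:
        # Keep the task only if some tensor is computed by a reduction.
        for t in args:
            if isinstance(t.op, te.ComputeOp) and any(
                isinstance(e, tir.Reduce) for e in t.op.body
            ):
                return create_prim_func(args)
        return None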
diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py
index 2609d2be9d..21d29ac74d 100644
--- a/tests/python/unittest/test_meta_schedule_post_order_apply.py
+++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -326,12 +326,12 @@ def test_meta_schedule_post_order_apply_remove_block():
                 'b2 = sch.get_block(name="C", func_name="main")',
                 "sch.compute_inline(block=b1)",
                 "l3, l4 = sch.get_loops(block=b2)",
-                "l5, l6 = sch.split(loop=l3, factors=" + str(a) + ")",
-                "l7, l8 = sch.split(loop=l4, factors=" + str(b) + ")",
+                "l5, l6 = sch.split(loop=l3, factors=" + str(a) + ", preserve_unit_iters=True)",
+                "l7, l8 = sch.split(loop=l4, factors=" + str(b) + ", preserve_unit_iters=True)",
                 "sch.reorder(l5, l7, l6, l8)",
                 "l9, l10 = sch.get_loops(block=b0)",
-                "l11, l12 = sch.split(loop=l9, factors=" + str(c) + ")",
-                "l13, l14 = sch.split(loop=l10, factors=" + str(d) + ")",
+                "l11, l12 = sch.split(loop=l9, factors=" + str(c) + ", preserve_unit_iters=True)",
+                "l13, l14 = sch.split(loop=l10, factors=" + str(d) + ", preserve_unit_iters=True)",
                 "sch.reorder(l11, l13, l12, l14)",
             ]
         )
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
index 09daea0945..a39c8aea5f 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
@@ -43,7 +43,7 @@ def test_cpu_matmul():
             'b0 = sch.get_block(name="C", func_name="main")',
             "l1, l2, l3 = sch.get_loops(block=b0)",
             "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l6, l7 = sch.split(loop=l3, factors=[v4, v5])",
+            "l6, l7 = sch.split(loop=l3, factors=[v4, v5], preserve_unit_iters=True)",
             "b8 = sch.rfactor(loop=l7, factor_axis=2)",
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)',
         ],
@@ -51,7 +51,7 @@ def test_cpu_matmul():
             'b0 = sch.get_block(name="C", func_name="main")',
             "l1, l2, l3 = sch.get_loops(block=b0)",
             "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l6, l7 = sch.split(loop=l3, factors=[v4, v5])",
+            "l6, l7 = sch.split(loop=l3, factors=[v4, v5], preserve_unit_iters=True)",
             "b8 = sch.rfactor(loop=l6, factor_axis=2)",
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)',
         ],
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
index 9c43c23a3e..a89cca72e1 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
@@ -76,9 +76,9 @@ def test_cuda_element_wise():
         [
             'b0 = sch.get_block(name="C", func_name="main")',
             "l1, l2 = sch.get_loops(block=b0)",
-            "l3 = sch.fuse(l1, l2)",
+            "l3 = sch.fuse(l1, l2, preserve_unit_iters=True)",
             "v4 = sch.sample_categorical(candidates=[32, 64, 128, 256, 512, 1024], probs=[0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666])",
-            "l5, l6 = sch.split(loop=l3, factors=[None, v4])",
+            "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
             'sch.bind(loop=l5, thread_axis="blockIdx.x")',
             'sch.bind(loop=l6, thread_axis="threadIdx.x")',
         ]
@@ -100,8 +100,8 @@ def test_cuda_reduction_loop_only():
             'b0 = sch.get_block(name="C", func_name="main")',
             "l1, = sch.get_loops(block=b0)",
             "l2 = sch.add_unit_loop(block_or_loop=l1)",
-            "l3 = sch.fuse(l2)",
-            "l4, l5 = sch.split(loop=l3, factors=[None, 1])",
+            "l3 = sch.fuse(l2, preserve_unit_iters=True)",
+            "l4, l5 = sch.split(loop=l3, factors=[None, 1], preserve_unit_iters=True)",
             'sch.bind(loop=l4, thread_axis="blockIdx.x")',
             'sch.bind(loop=l5, thread_axis="threadIdx.x")',
         ]
@@ -122,8 +122,8 @@ def test_cuda_zero_dim_add():
         [
             'b0 = sch.get_block(name="C", func_name="main")',
             "l1 = sch.add_unit_loop(block_or_loop=b0)",
-            "l2 = sch.fuse(l1)",
-            "l3, l4 = sch.split(loop=l2, factors=[None, 1])",
+            "l2 = sch.fuse(l1, preserve_unit_iters=True)",
+            "l3, l4 = sch.split(loop=l2, factors=[None, 1], preserve_unit_iters=True)",
             'sch.bind(loop=l3, thread_axis="blockIdx.x")',
             'sch.bind(loop=l4, thread_axis="threadIdx.x")',
         ]
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
index 8b21d11a37..5f76e77592 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
@@ -78,12 +78,12 @@ def test_gpu_softmax_mn():
             "b1, = sch.get_consumers(block=b0)",
             "l2, l3 = sch.get_loops(block=b1)",
             "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l5, l6 = sch.split(loop=l3, factors=[None, v4])",
+            "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
             'sch.bind(loop=l6, thread_axis="threadIdx.x")',
             "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
             'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
             "l7, l8, l9 = sch.get_loops(block=b0)",
-            "l10, l11 = sch.split(loop=l9, factors=[None, v4])",
+            "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
             'sch.bind(loop=l11, thread_axis="threadIdx.x")',
         ],
         [
@@ -91,12 +91,12 @@ def test_gpu_softmax_mn():
             "b1, = sch.get_consumers(block=b0)",
             "l2, l3 = sch.get_loops(block=b1)",
             "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l5, l6 = sch.split(loop=l3, factors=[None, v4])",
+            "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
             'sch.bind(loop=l6, thread_axis="threadIdx.x")',
             "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
             'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
             "l7, l8, l9 = sch.get_loops(block=b0)",
-            "l10, l11 = sch.split(loop=l9, factors=[None, v4])",
+            "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
             'sch.bind(loop=l11, thread_axis="threadIdx.x")',
         ],
         [
@@ -105,22 +105,22 @@ def test_gpu_softmax_mn():
             "b2, = sch.get_consumers(block=b1)",
             "l3, l4 = sch.get_loops(block=b2)",
             "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l6, l7 = sch.split(loop=l4, factors=[None, v5])",
+            "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
             'sch.bind(loop=l7, thread_axis="threadIdx.x")',
             "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
             'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
             "l8, l9, l10 = sch.get_loops(block=b1)",
-            "l11, l12 = sch.split(loop=l10, factors=[None, v5])",
+            "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
             'sch.bind(loop=l12, thread_axis="threadIdx.x")',
             "b13, = sch.get_consumers(block=b0)",
             "l14, l15 = sch.get_loops(block=b13)",
             "v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l17, l18 = sch.split(loop=l15, factors=[None, v16])",
+            "l17, l18 = sch.split(loop=l15, factors=[None, v16], preserve_unit_iters=True)",
             'sch.bind(loop=l18, thread_axis="threadIdx.x")',
             "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True)",
             'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
             "l19, l20, l21 = sch.get_loops(block=b0)",
-            "l22, l23 = sch.split(loop=l21, factors=[None, v16])",
+            "l22, l23 = sch.split(loop=l21, factors=[None, v16], preserve_unit_iters=True)",
             'sch.bind(loop=l23, thread_axis="threadIdx.x")',
         ],
     ]
@@ -147,7 +147,7 @@ def test_gpu_softmax_mn_after_inline():
             'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
             "v1 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
             "l2, l3 = sch.get_loops(block=b0)",
-            "l4, l5 = sch.split(loop=l3, factors=[None, v1])",
+            "l4, l5 = sch.split(loop=l3, factors=[None, v1], preserve_unit_iters=True)",
             'sch.bind(loop=l5, thread_axis="threadIdx.x")',
         ],
         [
@@ -155,12 +155,12 @@ def test_gpu_softmax_mn_after_inline():
             "b1, = sch.get_consumers(block=b0)",
             "l2, l3 = sch.get_loops(block=b1)",
             "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l5, l6 = sch.split(loop=l3, factors=[None, v4])",
+            "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
             'sch.bind(loop=l6, thread_axis="threadIdx.x")',
             "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
             'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
             "l7, l8, l9 = sch.get_loops(block=b0)",
-            "l10, l11 = sch.split(loop=l9, factors=[None, v4])",
+            "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
             'sch.bind(loop=l11, thread_axis="threadIdx.x")',
         ],
         [
@@ -169,19 +169,19 @@ def test_gpu_softmax_mn_after_inline():
             "b2, = sch.get_consumers(block=b1)",
             "l3, l4 = sch.get_loops(block=b2)",
             "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l6, l7 = sch.split(loop=l4, factors=[None, v5])",
+            "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
             'sch.bind(loop=l7, thread_axis="threadIdx.x")',
             "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
             'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
             "l8, l9, l10 = sch.get_loops(block=b1)",
-            "l11, l12 = sch.split(loop=l10, factors=[None, v5])",
+            "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
             'sch.bind(loop=l12, thread_axis="threadIdx.x")',
             "b13, b14 = sch.get_consumers(block=b0)",
             "l15, l16, l17, l18 = sch.get_loops(block=b13)",
             "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)",
             'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
             "l19, l20, l21 = sch.get_loops(block=b0)",
-            "l22, l23 = sch.split(loop=l21, factors=[None, v5])",
+            "l22, l23 = sch.split(loop=l21, factors=[None, v5], preserve_unit_iters=True)",
             'sch.bind(loop=l23, thread_axis="threadIdx.x")',
         ],
     ]
@@ -204,13 +204,13 @@ def test_gpu_batch_norm_bmn():
             "b1, = sch.get_consumers(block=b0)",
             "l2, = sch.get_loops(block=b1)",
             "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l4, l5 = sch.split(loop=l2, factors=[None, v3])",
+            "l4, l5 = sch.split(loop=l2, factors=[None, v3], preserve_unit_iters=True)",
             'sch.bind(loop=l5, thread_axis="threadIdx.x")',
             "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True)",
             'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
             "l6, l7, l8, l9 = sch.get_loops(block=b0)",
-            "l10 = sch.fuse(l8, l9)",
-            "l11, l12 = sch.split(loop=l10, factors=[None, v3])",
+            "l10 = sch.fuse(l8, l9, preserve_unit_iters=True)",
+            "l11, l12 = sch.split(loop=l10, factors=[None, v3], preserve_unit_iters=True)",
             'sch.bind(loop=l12, thread_axis="threadIdx.x")',
         ],
     ]
@@ -232,6 +232,6 @@ def test_gpu_batch_norm_bmn():
 
 
 if __name__ == "__main__":
-    test_gpu_softmax_mn()
-    test_gpu_softmax_mn_after_inline()
+    # test_gpu_softmax_mn()
+    # test_gpu_softmax_mn_after_inline()
     test_gpu_batch_norm_bmn()
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index 51f62f8bd8..30511d6690 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -48,11 +48,11 @@ def test_cpu_matmul():
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
             "l1, l2, l3 = sch.get_loops(block=b0)",
             "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
-            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
+            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
             "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
+            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
             "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
+            "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
             "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
             'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
             "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)",
@@ -62,11 +62,11 @@ def test_cpu_matmul():
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
             "l1, l2, l3 = sch.get_loops(block=b0)",
             "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
-            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
+            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
             "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
+            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
             "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
+            "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
             "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
             'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
             "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)",
@@ -76,11 +76,11 @@ def test_cpu_matmul():
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
             "l1, l2, l3 = sch.get_loops(block=b0)",
             "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
-            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
+            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
             "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
+            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
             "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
+            "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
             "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
         ],
     ]
@@ -109,11 +109,11 @@ def test_cpu_matmul_relu():
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
             "l1, l2, l3 = sch.get_loops(block=b0)",
             "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
-            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
+            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
             "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
+            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
             "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
+            "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
             "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
             "b24, = sch.get_consumers(block=b0)",
             "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)",
@@ -123,11 +123,11 @@ def test_cpu_matmul_relu():
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
             "l1, l2, l3 = sch.get_loops(block=b0)",
             "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
-            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
+            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
             "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
+            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
             "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
+            "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
             "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
             "b24, = sch.get_consumers(block=b0)",
             "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)",
@@ -137,11 +137,11 @@ def test_cpu_matmul_relu():
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
             "l1, l2, l3 = sch.get_loops(block=b0)",
             "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
-            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
+            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
             "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
+            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
             "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
-            "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
+            "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
             "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
         ],
     ]
@@ -171,17 +171,17 @@ def test_cuda_matmul():
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")',
             "l1, l2, l3 = sch.get_loops(block=b0)",
             "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)",
-            "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8])",
+            "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8], preserve_unit_iters=True)",
             "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)",
-            "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18])",
+            "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18], preserve_unit_iters=True)",
             "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)",
-            "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26])",
+            "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26], preserve_unit_iters=True)",
             "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)",
-            "l30 = sch.fuse(l9, l19)",
+            "l30 = sch.fuse(l9, l19, preserve_unit_iters=True)",
             'sch.bind(loop=l30, thread_axis="blockIdx.x")',
-            "l31 = sch.fuse(l10, l20)",
+            "l31 = sch.fuse(l10, l20, preserve_unit_iters=True)",
             'sch.bind(loop=l31, thread_axis="vthread.x")',
-            "l32 = sch.fuse(l11, l21)",
+            "l32 = sch.fuse(l11, l21, preserve_unit_iters=True)",
             'sch.bind(loop=l32, thread_axis="threadIdx.x")',
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)',
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)',
@@ -190,13 +190,13 @@ def test_cuda_matmul():
             'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")',
             "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)",
             "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
-            "l41 = sch.fuse(l39, l40)",
+            "l41 = sch.fuse(l39, l40, preserve_unit_iters=True)",
             "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
             'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
             'b43 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
             "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)",
             "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
-            "l50 = sch.fuse(l48, l49)",
+            "l50 = sch.fuse(l48, l49, preserve_unit_iters=True)",
             "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
             'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)',
         ]
@@ -227,30 +227,30 @@ def test_cuda_matmul_relu():
             'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")',
             "l1, l2, l3 = sch.get_loops(block=b0)",
             "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)",
-            "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8])",
+            "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8], preserve_unit_iters=True)",
             "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)",
-            "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18])",
+            "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18], preserve_unit_iters=True)",
             "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)",
-            "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26])",
+            "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26], preserve_unit_iters=True)",
             "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)",
-            "l30 = sch.fuse(l9, l19)",
+            "l30 = sch.fuse(l9, l19, preserve_unit_iters=True)",
             'sch.bind(loop=l30, thread_axis="blockIdx.x")',
-            "l31 = sch.fuse(l10, l20)",
+            "l31 = sch.fuse(l10, l20, preserve_unit_iters=True)",
             'sch.bind(loop=l31, thread_axis="vthread.x")',
-            "l32 = sch.fuse(l11, l21)",
+            "l32 = sch.fuse(l11, l21, preserve_unit_iters=True)",
             'sch.bind(loop=l32, thread_axis="threadIdx.x")',
             'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
             "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)",
             'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")',
             "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)",
             "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
-            "l41 = sch.fuse(l39, l40)",
+            "l41 = sch.fuse(l39, l40, preserve_unit_iters=True)",
             "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
             'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
             'b43 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
             "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)",
             "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
-            "l50 = sch.fuse(l48, l49)",
+            "l50 = sch.fuse(l48, l49, preserve_unit_iters=True)",
             "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
             'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)',
         ]
@@ -366,33 +366,33 @@ def test_multi_level_tiling_conv2d_nchwc_vnni():
         """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
 sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
-l11, l12 = sch.split(loop=l10, factors=[1, 4])
-l13, l14 = sch.split(loop=l5, factors=[1, 16])
+l11, l12 = sch.split(loop=l10, factors=[1, 4], preserve_unit_iters=True)
+l13, l14 = sch.split(loop=l5, factors=[1, 16], preserve_unit_iters=True)
 l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
 sch.reorder(l21, l22, l23, l24, l25, l14, l12)
 b27 = sch.blockize(loop=l14)
 sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni")
 l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27)
 v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64)
-l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41])
+l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41], preserve_unit_iters=True)
 v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64)
-l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49])
+l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49], preserve_unit_iters=True)
 v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64)
-l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57])
+l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57], preserve_unit_iters=True)
 v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64)
-l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65])
+l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65], preserve_unit_iters=True)
 v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64)
-l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73])
+l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73], preserve_unit_iters=True)
 v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64)
-l80, l81 = sch.split(loop=l33, factors=[v78, v79])
+l80, l81 = sch.split(loop=l33, factors=[v78, v79], preserve_unit_iters=True)
 v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64)
-l84, l85 = sch.split(loop=l34, factors=[v82, v83])
+l84, l85 = sch.split(loop=l34, factors=[v82, v83], preserve_unit_iters=True)
 v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64)
-l88, l89 = sch.split(loop=l35, factors=[v86, v87])
+l88, l89 = sch.split(loop=l35, factors=[v86, v87], preserve_unit_iters=True)
 v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64)
-l92, l93 = sch.split(loop=l36, factors=[v90, v91])
+l92, l93 = sch.split(loop=l36, factors=[v90, v91], preserve_unit_iters=True)
 v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
-l96, l97 = sch.split(loop=l37, factors=[v94, v95])
+l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True)
 sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)
 b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global")
 sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True)""".split(
@@ -401,33 +401,33 @@ sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True)""".split(
         """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
 sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
-l11, l12 = sch.split(loop=l10, factors=[1, 4])
-l13, l14 = sch.split(loop=l5, factors=[1, 16])
+l11, l12 = sch.split(loop=l10, factors=[1, 4], preserve_unit_iters=True)
+l13, l14 = sch.split(loop=l5, factors=[1, 16], preserve_unit_iters=True)
 l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
 sch.reorder(l21, l22, l23, l24, l25, l14, l12)
 b27 = sch.blockize(loop=l14)
 sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni")
 l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27)
 v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64)
-l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41])
+l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41], preserve_unit_iters=True)
 v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64)
-l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49])
+l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49], preserve_unit_iters=True)
 v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64)
-l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57])
+l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57], preserve_unit_iters=True)
 v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64)
-l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65])
+l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65], preserve_unit_iters=True)
 v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64)
-l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73])
+l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73], preserve_unit_iters=True)
 v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64)
-l80, l81 = sch.split(loop=l33, factors=[v78, v79])
+l80, l81 = sch.split(loop=l33, factors=[v78, v79], preserve_unit_iters=True)
 v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64)
-l84, l85 = sch.split(loop=l34, factors=[v82, v83])
+l84, l85 = sch.split(loop=l34, factors=[v82, v83], preserve_unit_iters=True)
 v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64)
-l88, l89 = sch.split(loop=l35, factors=[v86, v87])
+l88, l89 = sch.split(loop=l35, factors=[v86, v87], preserve_unit_iters=True)
 v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64)
-l92, l93 = sch.split(loop=l36, factors=[v90, v91])
+l92, l93 = sch.split(loop=l36, factors=[v90, v91], preserve_unit_iters=True)
 v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
-l96, l97 = sch.split(loop=l37, factors=[v94, v95])
+l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True)
 sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)
 b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global")
 sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True)""".split(
@@ -436,33 +436,33 @@ sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True)""".split(
         """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
 sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
-l11, l12 = sch.split(loop=l10, factors=[1, 4])
-l13, l14 = sch.split(loop=l5, factors=[1, 16])
+l11, l12 = sch.split(loop=l10, factors=[1, 4], preserve_unit_iters=True)
+l13, l14 = sch.split(loop=l5, factors=[1, 16], preserve_unit_iters=True)
 l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
 sch.reorder(l21, l22, l23, l24, l25, l14, l12)
 b27 = sch.blockize(loop=l14)
 sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni")
 l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27)
 v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64)
-l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41])
+l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41], preserve_unit_iters=True)
 v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64)
-l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49])
+l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49], preserve_unit_iters=True)
 v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64)
-l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57])
+l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57], preserve_unit_iters=True)
 v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64)
-l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65])
+l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65], preserve_unit_iters=True)
 v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64)
-l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73])
+l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73], preserve_unit_iters=True)
 v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64)
-l80, l81 = sch.split(loop=l33, factors=[v78, v79])
+l80, l81 = sch.split(loop=l33, factors=[v78, v79], preserve_unit_iters=True)
 v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64)
-l84, l85 = sch.split(loop=l34, factors=[v82, v83])
+l84, l85 = sch.split(loop=l34, factors=[v82, v83], preserve_unit_iters=True)
 v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64)
-l88, l89 = sch.split(loop=l35, factors=[v86, v87])
+l88, l89 = sch.split(loop=l35, factors=[v86, v87], preserve_unit_iters=True)
 v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64)
-l92, l93 = sch.split(loop=l36, factors=[v90, v91])
+l92, l93 = sch.split(loop=l36, factors=[v90, v91], preserve_unit_iters=True)
 v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
-l96, l97 = sch.split(loop=l37, factors=[v94, v95])
+l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True)
 sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)""".split(
             "\n"
         ),
@@ -517,36 +517,36 @@ def test_multi_level_tiling_dense_dpa4():
         """b0 = sch.get_block(name="compute", func_name="main")
 sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
 l1, l2, l3 = sch.get_loops(block=b0)
-l4, l5 = sch.split(loop=l3, factors=[32, 4])
+l4, l5 = sch.split(loop=l3, factors=[32, 4], preserve_unit_iters=True)
 sch.reorder(l5)
 b6 = sch.blockize(loop=l5)
 sch.annotate(block_or_loop=b6, ann_key="meta_schedule.auto_tensorize", ann_val="dp4a")
 l7, l8, l9 = sch.get_loops(block=b6)
 v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64)
-l15, l16, l17, l18, l19 = sch.split(loop=l7, factors=[v10, v11, v12, v13, v14])
+l15, l16, l17, l18, l19 = sch.split(loop=l7, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True)
 v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64)
-l25, l26, l27, l28, l29 = sch.split(loop=l8, factors=[v20, v21, v22, v23, v24])
+l25, l26, l27, l28, l29 = sch.split(loop=l8, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True)
 v30, v31, v32 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64)
-l33, l34, l35 = sch.split(loop=l9, factors=[v30, v31, v32])
+l33, l34, l35 = sch.split(loop=l9, factors=[v30, v31, v32], preserve_unit_iters=True)
 sch.reorder(l15, l25, l16, l26, l17, l27, l33, l34, l18, l28, l35, l19, l29)
-l36 = sch.fuse(l15, l25)
+l36 = sch.fuse(l15, l25, preserve_unit_iters=True)
 sch.bind(loop=l36, thread_axis="blockIdx.x")
-l37 = sch.fuse(l16, l26)
+l37 = sch.fuse(l16, l26, preserve_unit_iters=True)
 sch.bind(loop=l37, thread_axis="vthread.x")
-l38 = sch.fuse(l17, l27)
+l38 = sch.fuse(l17, l27, preserve_unit_iters=True)
 sch.bind(loop=l38, thread_axis="threadIdx.x")
 b39 = sch.cache_write(block=b6, write_buffer_index=0, storage_scope="local")
 sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True)
 b40 = sch.cache_read(block=b6, read_buffer_index=0, storage_scope="shared")
 sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True)
 l41, l42, l43, l44, l45, l46 = sch.get_loops(block=b40)
-l47 = sch.fuse(l45, l46)
+l47 = sch.fuse(l45, l46, preserve_unit_iters=True)
 v48 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b40, ann_key="meta_schedule.cooperative_fetch", ann_val=v48)
 b49 = sch.cache_read(block=b6, read_buffer_index=1, storage_scope="shared")
 sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True)
 l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b49)
-l56 = sch.fuse(l54, l55)
+l56 = sch.fuse(l54, l55, preserve_unit_iters=True)
 v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_val=v57)""".split(
             "\n"
diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py
index 6fc573b1a8..d1d87b60b7 100644
--- a/tests/python/unittest/test_tir_schedule_trace.py
+++ b/tests/python/unittest/test_tir_schedule_trace.py
@@ -87,7 +87,7 @@ def _make_split(inputs, outputs):  # pylint: disable=redefined-builtin
     return Instruction(
         kind=InstructionKind.get("Split"),
         inputs=inputs,
-        attrs=[],
+        attrs=[True],
         outputs=outputs,
     )
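
The helper change above reflects that a Split instruction now carries exactly
one attribute, the serialized preserve_unit_iters flag, hence the helper
update. A sketch of the raw form, with illustrative random-variable names and
factor values:

    from tvm.tir import IntImm
    from tvm.tir.schedule import Instruction, InstructionKind, LoopRV

    loop, outer, inner = LoopRV(), LoopRV(), LoopRV()
    split = Instruction(
        kind=InstructionKind.get("Split"),
        inputs=[loop, None, IntImm("int32", 32)],  # loop, then factors (None = inferred)
        attrs=[True],  # rendered as preserve_unit_iters=True when printed
        outputs=[outer, inner],
    )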
 
@@ -262,7 +262,7 @@ def test_trace_simplified_3():
         (
             'b0 = sch.get_block(name="B", func_name="main")',
             "l1, = sch.get_loops(block=b0)",
-            "l2, l3 = sch.split(loop=l1, factors=[None, 32])",
+            "l2, l3 = sch.split(loop=l1, factors=[None, 32], preserve_unit_iters=True)",
         )
     )
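
The updated expected string also shows where that attribute surfaces: the
trace printer renders it as a keyword argument, so replaying a trace passes
the flag back to the primitive. A round-trip sketch, again reusing the
illustrative `elemwise` function from the first note:

    # Applying the recorded trace to a fresh schedule reproduces the
    # same loop nest, unit iterator included.
    first = tvm.tir.Schedule(elemwise)
    (i,) = first.get_loops(first.get_block("B"))
    first.split(i, factors=[None, 32], preserve_unit_iters=True)

    second = tvm.tir.Schedule(elemwise)
    first.trace.apply_to_schedule(second, remove_postproc=False)
    tvm.ir.assert_structural_equal(first.mod, second.mod)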