Posted to commits@tvm.apache.org by sy...@apache.org on 2022/06/16 05:42:17 UTC
[tvm] branch main updated: [TIR] Add preserve-unit-iters (#11585)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 89e1a6c3f2 [TIR] Add preserve-unit-iters (#11585)
89e1a6c3f2 is described below
commit 89e1a6c3f2bbdaa3f585459cefbc7612ae46b1ad
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Jun 15 22:42:12 2022 -0700
[TIR] Add preserve-unit-iters (#11585)
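    The change threads a single boolean, preserve_unit_iters (default True), through the
    Fuse and Split schedule primitives so that unit-extent iterators are kept in block
    bindings rather than simplified away. A minimal sketch of the resulting Python surface,
    assuming a small illustrative workload (the PrimFunc, buffer shapes, and block name
    below are hypothetical and not part of this commit):

        import tvm
        from tvm.script import tir as T

        @T.prim_func
        def main(a: T.handle, b: T.handle) -> None:
            A = T.match_buffer(a, (1, 128), "float32")
            B = T.match_buffer(b, (1, 128), "float32")
            for i, j in T.grid(1, 128):
                with T.block("C"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] * 2.0

        sch = tvm.tir.Schedule(main)
        i, j = sch.get_loops(sch.get_block("C"))
        # With the new default (preserve_unit_iters=True) the unit-extent iterator vi
        # is kept in the fused block binding; passing False restores the previous
        # behavior of letting IterMapSimplify drop trivial iterators.
        fused = sch.fuse(i, j, preserve_unit_iters=True)
        outer, inner = sch.split(fused, factors=[None, 32], preserve_unit_iters=True)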
---
include/tvm/tir/schedule/schedule.h | 7 +-
python/tvm/tir/schedule/schedule.py | 21 ++-
src/tir/schedule/concrete_schedule.cc | 9 +-
src/tir/schedule/concrete_schedule.h | 5 +-
src/tir/schedule/primitive.h | 7 +-
src/tir/schedule/primitive/loop_transformation.cc | 64 +++++----
src/tir/schedule/traced_schedule.cc | 13 +-
src/tir/schedule/traced_schedule.h | 5 +-
.../unittest/test_meta_schedule_integration.py | 9 +-
.../test_meta_schedule_post_order_apply.py | 8 +-
...test_meta_schedule_schedule_rule_add_rfactor.py | 4 +-
.../test_meta_schedule_schedule_rule_auto_bind.py | 12 +-
...chedule_schedule_rule_cross_thread_reduction.py | 38 ++---
...ta_schedule_schedule_rule_multi_level_tiling.py | 158 ++++++++++-----------
tests/python/unittest/test_tir_schedule_trace.py | 4 +-
15 files changed, 202 insertions(+), 162 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index d3ecd8a113..d95a9d4e7e 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -277,9 +277,10 @@ class ScheduleNode : public runtime::Object {
* 3) All loops must start with 0.
* 4) The domain of a loop to be fused cannot depend on another loop to be fused.
* \param loop_rvs The loops to be fused
+ * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return The new loop after fusion
*/
- virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs) = 0;
+ virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters = true) = 0;
/*!
* \brief Split a loop into a list of consecutive loops. It requires:
* 1) The loop can't have annotation or thread binding.
@@ -287,9 +288,11 @@ class ScheduleNode : public runtime::Object {
* \param loop_rv The loop to be split
* \param factors The positive tiling factors, and at most one of which is `NullOpt`, which means
* that factor is inferred.
+ * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return The new loops after split
*/
- virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
+ virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
+ bool preserve_unit_iters = true) = 0;
/*!
* \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
* It requires:
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index d29495c430..7a1e244604 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -495,7 +495,11 @@ class Schedule(Object):
########## Schedule: Transform loops ##########
@type_checked
- def fuse(self, *loops: List[LoopRV]) -> LoopRV:
+ def fuse(
+ self,
+ *loops: List[LoopRV],
+ preserve_unit_iters: bool = True,
+ ) -> LoopRV:
"""Fuse a list of consecutive loops into one. It requires:
1) The loops can't have annotations or thread bindings.
2) The (i+1)-th loop must be the only child of the i-th loop.
@@ -553,13 +557,14 @@ class Schedule(Object):
B[vi, vj] = A[vi, vj] * 2.0
"""
- return _ffi_api.ScheduleFuse(self, loops) # type: ignore # pylint: disable=no-member
+ return _ffi_api.ScheduleFuse(self, loops, preserve_unit_iters) # type: ignore # pylint: disable=no-member
@type_checked
def split(
self,
loop: LoopRV,
factors: List[Union[int, ExprRV, None]],
+ preserve_unit_iters: bool = True,
) -> List[LoopRV]:
"""Split a loop into a list of consecutive loops. It requires:
1) The loop can't have annotation or thread binding.
@@ -580,6 +585,9 @@ class Schedule(Object):
- ExprRV
- Positive constant integers
+ preserve_unit_iters : bool
+ Whether or not to preserve unit iterators in block bindings
+
Returns
-------
split_loops : List[LoopRV]
@@ -628,7 +636,14 @@ class Schedule(Object):
"""
# it will be checked later in C++ implementation
# that there is at most one None in `factors`
- return list(_ffi_api.ScheduleSplit(self, loop, factors)) # type: ignore # pylint: disable=no-member
+ return list(
+ _ffi_api.ScheduleSplit( # type: ignore # pylint: disable=no-member
+ self,
+ loop,
+ factors,
+ preserve_unit_iters,
+ )
+ )
@type_checked
def reorder(self, *ordered_loops: List[LoopRV]) -> None:
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 051bd42506..b2f48753b5 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -333,19 +333,20 @@ Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {
/******** Schedule: Transform loops ********/
-LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
+LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) {
CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
- result = tir::Fuse(state_, loop_srefs);
+ result = tir::Fuse(state_, loop_srefs, preserve_unit_iters);
TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<LoopRV>(result);
}
Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
- const Array<Optional<ExprRV>>& factor_rvs) {
+ const Array<Optional<ExprRV>>& factor_rvs,
+ bool preserve_unit_iters) {
class NotSingleInferFactorError : public ScheduleError {
public:
explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}
@@ -440,7 +441,7 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
} else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) {
throw WrongFactorProductError(state_->mod, GetRef<For>(loop));
}
- results = tir::Split(state_, loop_sref, factors);
+ results = tir::Split(state_, loop_sref, factors, preserve_unit_iters);
TVM_TIR_SCHEDULE_END("split", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<LoopRV>(results);
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 11d68694a1..dfbacb530a 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -96,8 +96,9 @@ class ConcreteScheduleNode : public ScheduleNode {
Array<BlockRV> GetProducers(const BlockRV& block_rv) override;
Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
/******** Schedule: Transform loops ********/
- LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
- Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
+ LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) override;
+ Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
+ bool preserve_unit_iters) override;
void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;
LoopRV AddUnitLoop(const BlockRV& block_rv) override;
LoopRV AddUnitLoop(const LoopRV& loop_rv) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index af0f417e4c..212571df10 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -156,10 +156,11 @@ Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sr
* \param self The state of the schedule
* \param loop_sref The sref to the loop being split
* \param factors The splitting factors
+ * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return An array of srefs to the loops after splitting
*/
TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
- const Array<PrimExpr>& factors);
+ const Array<PrimExpr>& factors, bool preserve_unit_iters);
/*!
* \brief Fuse a list of consecutive loops into one. It requires:
* 1) The loops can't have annotations or thread bindings.
@@ -168,9 +169,11 @@ TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
* 4) The domain of a loop to be fused cannot depend on another loop to be fused.
* \param self The state of the schedule
* \param loop_srefs An array of srefs to the loops to be fused
+ * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return The sref to the fused loop
*/
-TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs);
+TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs,
+ bool preserve_unit_iters);
/*!
* \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
* It requires:
diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc
index bb505bca33..f1b6f46e1b 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -77,18 +77,21 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator {
/*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */
class IterMapSimplifyBlockBinding : public StmtExprMutator {
public:
- explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map<Var, Range> loop_var2extent)
- : opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent) {}
-
- static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs,
- MapNode* opaque_blocks) {
+ explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map<Var, Range> loop_var2extent,
+ bool preserve_unit_iters)
+ : opaque_blocks_(opaque_blocks),
+ loop_var2extent_(loop_var2extent),
+ preserve_unit_iters_(preserve_unit_iters) {}
+
+ static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs, MapNode* opaque_blocks,
+ bool preserve_unit_iters) {
Map<Var, Range> loop_var2extent;
for (const StmtSRef& sref : loop_srefs) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
}
- return Downcast<For>(
- IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent))(std::move(stmt)));
+ return Downcast<For>(IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent),
+ preserve_unit_iters)(std::move(stmt)));
}
private:
@@ -112,11 +115,12 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
}
return std::move(realize);
}
- Array<PrimExpr> v = arith::IterMapSimplify(/*indices=*/op->iter_values,
- /*input_iters=*/loop_var2extent_,
- /*input_pred=*/op->predicate,
- /*check_level=*/arith::IterMapLevel::Surjective,
- /*simplify_trivial_iterators=*/false);
+ Array<PrimExpr> v =
+ arith::IterMapSimplify(/*indices=*/op->iter_values,
+ /*input_iters=*/loop_var2extent_,
+ /*input_pred=*/op->predicate,
+ /*check_level=*/arith::IterMapLevel::Surjective,
+ /*simplify_trivial_iterators=*/!preserve_unit_iters_);
if (v.same_as(op->iter_values)) {
return GetRef<Stmt>(op);
} else {
@@ -130,6 +134,8 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
MapNode* opaque_blocks_;
/*! \brief The range of loops */
Map<Var, Range> loop_var2extent_;
+ /*! \brief Whether or not to preserve unit iterators in block bindings */
+ bool preserve_unit_iters_;
};
class BlockPropertyError : public ScheduleError {
@@ -376,8 +382,8 @@ class DependentLoopError : public ScheduleError {
PrimitiveKind kind_;
};
-Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
- const Array<PrimExpr>& factors) {
+Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array<PrimExpr>& factors,
+ bool preserve_unit_iters) {
// Invariance
// - The total repeat number has not changed for each direct child block with updating predicate.
// - The execution order has not changed. (The block executes with the same args and the same
@@ -432,7 +438,8 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt);
}
new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref),
- opaque_block_reuse.CopyOnWrite());
+ opaque_block_reuse.CopyOnWrite(),
+ preserve_unit_iters);
self->Replace(loop_sref, new_stmt, opaque_block_reuse);
Array<StmtSRef> result_srefs;
result_srefs.reserve(n);
@@ -444,7 +451,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
return result_srefs;
}
-StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
+StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
// Invariance
// - The total repeat number has not changed for each direct child block.
// - The execution order has not changed. (The block executes with the same
@@ -527,7 +534,8 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
fused_extent = analyzer.Simplify(fused_extent);
new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt);
new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(
- std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite());
+ std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite(),
+ preserve_unit_iters);
self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
return self->stmt2ref.at(new_stmt.get());
}
@@ -755,7 +763,7 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
private:
static constexpr size_t kNumInputs = 2;
- static constexpr size_t kNumAttrs = 0;
+ static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumDecisions = 0;
template <size_t delta>
@@ -770,14 +778,17 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
}
static Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv,
- Array<Optional<ExprRV>> factors) {
- return sch->Split(loop_rv, factors);
+ Array<Optional<ExprRV>> factors,
+ Bool preserve_unit_iters) {
+ return sch->Split(loop_rv, factors, preserve_unit_iters.operator bool());
}
- static String UnpackedAsPython(Array<String> outputs, String loop_rv, Array<ObjectRef> factors) {
+ static String UnpackedAsPython(Array<String> outputs, String loop_rv, Array<ObjectRef> factors,
+ Bool preserve_unit_iters) {
PythonAPICall py("split");
py.Input("loop", loop_rv);
py.Input("factors", factors);
+ py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
py.OutputList(outputs);
return py.Str();
}
@@ -792,7 +803,7 @@ struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
private:
static constexpr size_t kNumInputs = 1;
- static constexpr size_t kNumAttrs = 0;
+ static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumDecisions = 0;
template <size_t delta>
@@ -801,15 +812,18 @@ struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
setter(delta, inputs);
}
- static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) {
- return sch->Fuse(loop_rvs);
+ static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs,
+ Bool preserve_unit_iters) {
+ return sch->Fuse(loop_rvs, preserve_unit_iters.operator bool());
}
- static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) {
+ static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs,
+ Bool preserve_unit_iters) {
PythonAPICall py("fuse");
for (const String& loop_rv : loop_rvs) {
py.Input("", loop_rv);
}
+ py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
py.SingleOutput(outputs);
return py.Str();
}
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 95a10e26ac..733b5d872f 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -158,20 +158,21 @@ Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {
/******** Schedule: Transform loops ********/
-LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
- LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs);
+LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) {
+ LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs, preserve_unit_iters);
static const InstructionKind& kind = InstructionKind::Get("Fuse");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{loop_rvs.begin(), loop_rvs.end()},
- /*attrs=*/{},
+ /*attrs=*/{Integer(preserve_unit_iters)},
/*outputs=*/{result}));
return result;
}
Array<LoopRV> TracedScheduleNode::Split(const LoopRV& loop_rv,
- const Array<Optional<ExprRV>>& factor_rvs) {
- Array<LoopRV> results = ConcreteScheduleNode::Split(loop_rv, factor_rvs);
+ const Array<Optional<ExprRV>>& factor_rvs,
+ bool preserve_unit_iters) {
+ Array<LoopRV> results = ConcreteScheduleNode::Split(loop_rv, factor_rvs, preserve_unit_iters);
std::vector<ObjectRef> inputs;
inputs.reserve(1 + factor_rvs.size());
@@ -183,7 +184,7 @@ Array<LoopRV> TracedScheduleNode::Split(const LoopRV& loop_rv,
static const InstructionKind& kind = InstructionKind::Get("Split");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/inputs,
- /*attrs=*/{},
+ /*attrs=*/{Integer(preserve_unit_iters)},
/*outputs=*/{results.begin(), results.end()}));
return results;
}
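    Since kNumAttrs for both SplitTraits and FuseTraits moves from 0 to 1, every traced
    Split/Fuse instruction now records the flag as an attribute, which is why the expected
    strings in the tests below gain an explicit preserve_unit_iters=True. A hedged way to
    observe this from Python, assuming the same imports and the illustrative `main`
    PrimFunc from the sketch in the commit message above (the default tir.Schedule
    records a trace):

        sch = tvm.tir.Schedule(main)
        block = sch.get_block("C")
        i, j = sch.get_loops(block)
        sch.split(j, factors=[None, 32])
        # The recorded split is expected to spell the new attribute out, e.g.
        #   "l3, l4 = sch.split(loop=l2, factors=[None, 32], preserve_unit_iters=True)"
        print("\n".join(sch.trace.as_python(remove_postproc=False)))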
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 25bf3d4871..178026d9ea 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -60,8 +60,9 @@ class TracedScheduleNode : public ConcreteScheduleNode {
Array<BlockRV> GetProducers(const BlockRV& block_rv) final;
Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
/******** Schedule: Transform loops ********/
- LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
- Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
+ LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) final;
+ Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs,
+ bool preserve_unit_iters) final;
void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;
LoopRV AddUnitLoop(const BlockRV& block_rv) final;
LoopRV AddUnitLoop(const LoopRV& loop_rv) final;
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index f2802b41eb..6d5016cd81 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -193,6 +193,7 @@ def test_meta_schedule_integration_extract_from_bert_base():
@requires_torch
def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
def filter_func(args) -> bool:
+ from tvm.te import create_prim_func # pylint: disable=import-outside-toplevel
has_complex_op = False
visited = set()
@@ -205,16 +206,16 @@ def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
if isinstance(t.op, te.PlaceholderOp):
pass
elif isinstance(t.op, te.ComputeOp):
- has_complex_op = has_complex_op or any(
- [isinstance(e, tir.Reduce) for e in t.op.body]
- )
+ has_complex_op = has_complex_op or any(isinstance(e, tir.Reduce) for e in t.op.body)
for x in t.op.input_tensors:
traverse(x)
visited.add(t.handle.value)
for t in args:
traverse(t)
- return has_complex_op
+ if not has_complex_op:
+ return None
+ return create_prim_func(args)
mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
extracted_tasks = ms.extract_task_from_relay(
diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py
index 2609d2be9d..21d29ac74d 100644
--- a/tests/python/unittest/test_meta_schedule_post_order_apply.py
+++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -326,12 +326,12 @@ def test_meta_schedule_post_order_apply_remove_block():
'b2 = sch.get_block(name="C", func_name="main")',
"sch.compute_inline(block=b1)",
"l3, l4 = sch.get_loops(block=b2)",
- "l5, l6 = sch.split(loop=l3, factors=" + str(a) + ")",
- "l7, l8 = sch.split(loop=l4, factors=" + str(b) + ")",
+ "l5, l6 = sch.split(loop=l3, factors=" + str(a) + ", preserve_unit_iters=True)",
+ "l7, l8 = sch.split(loop=l4, factors=" + str(b) + ", preserve_unit_iters=True)",
"sch.reorder(l5, l7, l6, l8)",
"l9, l10 = sch.get_loops(block=b0)",
- "l11, l12 = sch.split(loop=l9, factors=" + str(c) + ")",
- "l13, l14 = sch.split(loop=l10, factors=" + str(d) + ")",
+ "l11, l12 = sch.split(loop=l9, factors=" + str(c) + ", preserve_unit_iters=True)",
+ "l13, l14 = sch.split(loop=l10, factors=" + str(d) + ", preserve_unit_iters=True)",
"sch.reorder(l11, l13, l12, l14)",
]
)
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
index 09daea0945..a39c8aea5f 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
@@ -43,7 +43,7 @@ def test_cpu_matmul():
'b0 = sch.get_block(name="C", func_name="main")',
"l1, l2, l3 = sch.get_loops(block=b0)",
"v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
- "l6, l7 = sch.split(loop=l3, factors=[v4, v5])",
+ "l6, l7 = sch.split(loop=l3, factors=[v4, v5], preserve_unit_iters=True)",
"b8 = sch.rfactor(loop=l7, factor_axis=2)",
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)',
],
@@ -51,7 +51,7 @@ def test_cpu_matmul():
'b0 = sch.get_block(name="C", func_name="main")',
"l1, l2, l3 = sch.get_loops(block=b0)",
"v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
- "l6, l7 = sch.split(loop=l3, factors=[v4, v5])",
+ "l6, l7 = sch.split(loop=l3, factors=[v4, v5], preserve_unit_iters=True)",
"b8 = sch.rfactor(loop=l6, factor_axis=2)",
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)',
],
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
index 9c43c23a3e..a89cca72e1 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
@@ -76,9 +76,9 @@ def test_cuda_element_wise():
[
'b0 = sch.get_block(name="C", func_name="main")',
"l1, l2 = sch.get_loops(block=b0)",
- "l3 = sch.fuse(l1, l2)",
+ "l3 = sch.fuse(l1, l2, preserve_unit_iters=True)",
"v4 = sch.sample_categorical(candidates=[32, 64, 128, 256, 512, 1024], probs=[0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666])",
- "l5, l6 = sch.split(loop=l3, factors=[None, v4])",
+ "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
'sch.bind(loop=l5, thread_axis="blockIdx.x")',
'sch.bind(loop=l6, thread_axis="threadIdx.x")',
]
@@ -100,8 +100,8 @@ def test_cuda_reduction_loop_only():
'b0 = sch.get_block(name="C", func_name="main")',
"l1, = sch.get_loops(block=b0)",
"l2 = sch.add_unit_loop(block_or_loop=l1)",
- "l3 = sch.fuse(l2)",
- "l4, l5 = sch.split(loop=l3, factors=[None, 1])",
+ "l3 = sch.fuse(l2, preserve_unit_iters=True)",
+ "l4, l5 = sch.split(loop=l3, factors=[None, 1], preserve_unit_iters=True)",
'sch.bind(loop=l4, thread_axis="blockIdx.x")',
'sch.bind(loop=l5, thread_axis="threadIdx.x")',
]
@@ -122,8 +122,8 @@ def test_cuda_zero_dim_add():
[
'b0 = sch.get_block(name="C", func_name="main")',
"l1 = sch.add_unit_loop(block_or_loop=b0)",
- "l2 = sch.fuse(l1)",
- "l3, l4 = sch.split(loop=l2, factors=[None, 1])",
+ "l2 = sch.fuse(l1, preserve_unit_iters=True)",
+ "l3, l4 = sch.split(loop=l2, factors=[None, 1], preserve_unit_iters=True)",
'sch.bind(loop=l3, thread_axis="blockIdx.x")',
'sch.bind(loop=l4, thread_axis="threadIdx.x")',
]
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
index 8b21d11a37..5f76e77592 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
@@ -78,12 +78,12 @@ def test_gpu_softmax_mn():
"b1, = sch.get_consumers(block=b0)",
"l2, l3 = sch.get_loops(block=b1)",
"v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l5, l6 = sch.split(loop=l3, factors=[None, v4])",
+ "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
'sch.bind(loop=l6, thread_axis="threadIdx.x")',
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l7, l8, l9 = sch.get_loops(block=b0)",
- "l10, l11 = sch.split(loop=l9, factors=[None, v4])",
+ "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
'sch.bind(loop=l11, thread_axis="threadIdx.x")',
],
[
@@ -91,12 +91,12 @@ def test_gpu_softmax_mn():
"b1, = sch.get_consumers(block=b0)",
"l2, l3 = sch.get_loops(block=b1)",
"v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l5, l6 = sch.split(loop=l3, factors=[None, v4])",
+ "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
'sch.bind(loop=l6, thread_axis="threadIdx.x")',
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l7, l8, l9 = sch.get_loops(block=b0)",
- "l10, l11 = sch.split(loop=l9, factors=[None, v4])",
+ "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
'sch.bind(loop=l11, thread_axis="threadIdx.x")',
],
[
@@ -105,22 +105,22 @@ def test_gpu_softmax_mn():
"b2, = sch.get_consumers(block=b1)",
"l3, l4 = sch.get_loops(block=b2)",
"v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l6, l7 = sch.split(loop=l4, factors=[None, v5])",
+ "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
'sch.bind(loop=l7, thread_axis="threadIdx.x")',
"sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
"l8, l9, l10 = sch.get_loops(block=b1)",
- "l11, l12 = sch.split(loop=l10, factors=[None, v5])",
+ "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
'sch.bind(loop=l12, thread_axis="threadIdx.x")',
"b13, = sch.get_consumers(block=b0)",
"l14, l15 = sch.get_loops(block=b13)",
"v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l17, l18 = sch.split(loop=l15, factors=[None, v16])",
+ "l17, l18 = sch.split(loop=l15, factors=[None, v16], preserve_unit_iters=True)",
'sch.bind(loop=l18, thread_axis="threadIdx.x")',
"sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l19, l20, l21 = sch.get_loops(block=b0)",
- "l22, l23 = sch.split(loop=l21, factors=[None, v16])",
+ "l22, l23 = sch.split(loop=l21, factors=[None, v16], preserve_unit_iters=True)",
'sch.bind(loop=l23, thread_axis="threadIdx.x")',
],
]
@@ -147,7 +147,7 @@ def test_gpu_softmax_mn_after_inline():
'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
"v1 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l2, l3 = sch.get_loops(block=b0)",
- "l4, l5 = sch.split(loop=l3, factors=[None, v1])",
+ "l4, l5 = sch.split(loop=l3, factors=[None, v1], preserve_unit_iters=True)",
'sch.bind(loop=l5, thread_axis="threadIdx.x")',
],
[
@@ -155,12 +155,12 @@ def test_gpu_softmax_mn_after_inline():
"b1, = sch.get_consumers(block=b0)",
"l2, l3 = sch.get_loops(block=b1)",
"v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l5, l6 = sch.split(loop=l3, factors=[None, v4])",
+ "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
'sch.bind(loop=l6, thread_axis="threadIdx.x")',
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l7, l8, l9 = sch.get_loops(block=b0)",
- "l10, l11 = sch.split(loop=l9, factors=[None, v4])",
+ "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
'sch.bind(loop=l11, thread_axis="threadIdx.x")',
],
[
@@ -169,19 +169,19 @@ def test_gpu_softmax_mn_after_inline():
"b2, = sch.get_consumers(block=b1)",
"l3, l4 = sch.get_loops(block=b2)",
"v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l6, l7 = sch.split(loop=l4, factors=[None, v5])",
+ "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
'sch.bind(loop=l7, thread_axis="threadIdx.x")',
"sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
"l8, l9, l10 = sch.get_loops(block=b1)",
- "l11, l12 = sch.split(loop=l10, factors=[None, v5])",
+ "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
'sch.bind(loop=l12, thread_axis="threadIdx.x")',
"b13, b14 = sch.get_consumers(block=b0)",
"l15, l16, l17, l18 = sch.get_loops(block=b13)",
"sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l19, l20, l21 = sch.get_loops(block=b0)",
- "l22, l23 = sch.split(loop=l21, factors=[None, v5])",
+ "l22, l23 = sch.split(loop=l21, factors=[None, v5], preserve_unit_iters=True)",
'sch.bind(loop=l23, thread_axis="threadIdx.x")',
],
]
@@ -204,13 +204,13 @@ def test_gpu_batch_norm_bmn():
"b1, = sch.get_consumers(block=b0)",
"l2, = sch.get_loops(block=b1)",
"v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l4, l5 = sch.split(loop=l2, factors=[None, v3])",
+ "l4, l5 = sch.split(loop=l2, factors=[None, v3], preserve_unit_iters=True)",
'sch.bind(loop=l5, thread_axis="threadIdx.x")',
"sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l6, l7, l8, l9 = sch.get_loops(block=b0)",
- "l10 = sch.fuse(l8, l9)",
- "l11, l12 = sch.split(loop=l10, factors=[None, v3])",
+ "l10 = sch.fuse(l8, l9, preserve_unit_iters=True)",
+ "l11, l12 = sch.split(loop=l10, factors=[None, v3], preserve_unit_iters=True)",
'sch.bind(loop=l12, thread_axis="threadIdx.x")',
],
]
@@ -232,6 +232,6 @@ def test_gpu_batch_norm_bmn():
if __name__ == "__main__":
- test_gpu_softmax_mn()
- test_gpu_softmax_mn_after_inline()
+ # test_gpu_softmax_mn()
+ # test_gpu_softmax_mn_after_inline()
test_gpu_batch_norm_bmn()
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index 51f62f8bd8..30511d6690 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -48,11 +48,11 @@ def test_cpu_matmul():
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
"l1, l2, l3 = sch.get_loops(block=b0)",
"v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
- "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
+ "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
"v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
- "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
+ "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
"v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
- "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
+ "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
"sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)",
@@ -62,11 +62,11 @@ def test_cpu_matmul():
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
"l1, l2, l3 = sch.get_loops(block=b0)",
"v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
- "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
+ "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
"v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
- "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
+ "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
"v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
- "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
+ "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
"sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)",
@@ -76,11 +76,11 @@ def test_cpu_matmul():
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
"l1, l2, l3 = sch.get_loops(block=b0)",
"v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
- "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
+ "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
"v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
- "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
+ "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
"v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
- "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
+ "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
],
]
@@ -109,11 +109,11 @@ def test_cpu_matmul_relu():
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
"l1, l2, l3 = sch.get_loops(block=b0)",
"v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
- "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
+ "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
"v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
- "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
+ "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
"v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
- "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
+ "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
"b24, = sch.get_consumers(block=b0)",
"sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)",
@@ -123,11 +123,11 @@ def test_cpu_matmul_relu():
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
"l1, l2, l3 = sch.get_loops(block=b0)",
"v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
- "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
+ "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
"v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
- "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
+ "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
"v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
- "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
+ "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
"b24, = sch.get_consumers(block=b0)",
"sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)",
@@ -137,11 +137,11 @@ def test_cpu_matmul_relu():
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
"l1, l2, l3 = sch.get_loops(block=b0)",
"v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
- "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
+ "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)",
"v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
- "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
+ "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)",
"v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
- "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
+ "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
],
]
@@ -171,17 +171,17 @@ def test_cuda_matmul():
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")',
"l1, l2, l3 = sch.get_loops(block=b0)",
"v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)",
- "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8])",
+ "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8], preserve_unit_iters=True)",
"v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)",
- "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18])",
+ "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18], preserve_unit_iters=True)",
"v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)",
- "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26])",
+ "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26], preserve_unit_iters=True)",
"sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)",
- "l30 = sch.fuse(l9, l19)",
+ "l30 = sch.fuse(l9, l19, preserve_unit_iters=True)",
'sch.bind(loop=l30, thread_axis="blockIdx.x")',
- "l31 = sch.fuse(l10, l20)",
+ "l31 = sch.fuse(l10, l20, preserve_unit_iters=True)",
'sch.bind(loop=l31, thread_axis="vthread.x")',
- "l32 = sch.fuse(l11, l21)",
+ "l32 = sch.fuse(l11, l21, preserve_unit_iters=True)",
'sch.bind(loop=l32, thread_axis="threadIdx.x")',
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)',
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)',
@@ -190,13 +190,13 @@ def test_cuda_matmul():
'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")',
"sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)",
"l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
- "l41 = sch.fuse(l39, l40)",
+ "l41 = sch.fuse(l39, l40, preserve_unit_iters=True)",
"v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
'b43 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
"sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)",
"l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
- "l50 = sch.fuse(l48, l49)",
+ "l50 = sch.fuse(l48, l49, preserve_unit_iters=True)",
"v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)',
]
@@ -227,30 +227,30 @@ def test_cuda_matmul_relu():
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")',
"l1, l2, l3 = sch.get_loops(block=b0)",
"v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)",
- "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8])",
+ "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8], preserve_unit_iters=True)",
"v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)",
- "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18])",
+ "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18], preserve_unit_iters=True)",
"v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)",
- "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26])",
+ "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26], preserve_unit_iters=True)",
"sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)",
- "l30 = sch.fuse(l9, l19)",
+ "l30 = sch.fuse(l9, l19, preserve_unit_iters=True)",
'sch.bind(loop=l30, thread_axis="blockIdx.x")',
- "l31 = sch.fuse(l10, l20)",
+ "l31 = sch.fuse(l10, l20, preserve_unit_iters=True)",
'sch.bind(loop=l31, thread_axis="vthread.x")',
- "l32 = sch.fuse(l11, l21)",
+ "l32 = sch.fuse(l11, l21, preserve_unit_iters=True)",
'sch.bind(loop=l32, thread_axis="threadIdx.x")',
'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
"sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)",
'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")',
"sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)",
"l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
- "l41 = sch.fuse(l39, l40)",
+ "l41 = sch.fuse(l39, l40, preserve_unit_iters=True)",
"v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
'b43 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
"sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)",
"l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
- "l50 = sch.fuse(l48, l49)",
+ "l50 = sch.fuse(l48, l49, preserve_unit_iters=True)",
"v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)',
]
@@ -366,33 +366,33 @@ def test_multi_level_tiling_conv2d_nchwc_vnni():
"""b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
-l11, l12 = sch.split(loop=l10, factors=[1, 4])
-l13, l14 = sch.split(loop=l5, factors=[1, 16])
+l11, l12 = sch.split(loop=l10, factors=[1, 4], preserve_unit_iters=True)
+l13, l14 = sch.split(loop=l5, factors=[1, 16], preserve_unit_iters=True)
l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
sch.reorder(l21, l22, l23, l24, l25, l14, l12)
b27 = sch.blockize(loop=l14)
sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni")
l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27)
v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64)
-l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41])
+l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41], preserve_unit_iters=True)
v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64)
-l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49])
+l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49], preserve_unit_iters=True)
v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64)
-l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57])
+l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57], preserve_unit_iters=True)
v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64)
-l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65])
+l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65], preserve_unit_iters=True)
v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64)
-l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73])
+l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73], preserve_unit_iters=True)
v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64)
-l80, l81 = sch.split(loop=l33, factors=[v78, v79])
+l80, l81 = sch.split(loop=l33, factors=[v78, v79], preserve_unit_iters=True)
v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64)
-l84, l85 = sch.split(loop=l34, factors=[v82, v83])
+l84, l85 = sch.split(loop=l34, factors=[v82, v83], preserve_unit_iters=True)
v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64)
-l88, l89 = sch.split(loop=l35, factors=[v86, v87])
+l88, l89 = sch.split(loop=l35, factors=[v86, v87], preserve_unit_iters=True)
v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64)
-l92, l93 = sch.split(loop=l36, factors=[v90, v91])
+l92, l93 = sch.split(loop=l36, factors=[v90, v91], preserve_unit_iters=True)
v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
-l96, l97 = sch.split(loop=l37, factors=[v94, v95])
+l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True)
sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)
b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global")
sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True)""".split(
@@ -401,33 +401,33 @@ sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True)""".split(
"""b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
-l11, l12 = sch.split(loop=l10, factors=[1, 4])
-l13, l14 = sch.split(loop=l5, factors=[1, 16])
+l11, l12 = sch.split(loop=l10, factors=[1, 4], preserve_unit_iters=True)
+l13, l14 = sch.split(loop=l5, factors=[1, 16], preserve_unit_iters=True)
l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
sch.reorder(l21, l22, l23, l24, l25, l14, l12)
b27 = sch.blockize(loop=l14)
sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni")
l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27)
v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64)
-l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41])
+l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41], preserve_unit_iters=True)
v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64)
-l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49])
+l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49], preserve_unit_iters=True)
v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64)
-l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57])
+l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57], preserve_unit_iters=True)
v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64)
-l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65])
+l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65], preserve_unit_iters=True)
v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64)
-l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73])
+l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73], preserve_unit_iters=True)
v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64)
-l80, l81 = sch.split(loop=l33, factors=[v78, v79])
+l80, l81 = sch.split(loop=l33, factors=[v78, v79], preserve_unit_iters=True)
v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64)
-l84, l85 = sch.split(loop=l34, factors=[v82, v83])
+l84, l85 = sch.split(loop=l34, factors=[v82, v83], preserve_unit_iters=True)
v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64)
-l88, l89 = sch.split(loop=l35, factors=[v86, v87])
+l88, l89 = sch.split(loop=l35, factors=[v86, v87], preserve_unit_iters=True)
v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64)
-l92, l93 = sch.split(loop=l36, factors=[v90, v91])
+l92, l93 = sch.split(loop=l36, factors=[v90, v91], preserve_unit_iters=True)
v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
-l96, l97 = sch.split(loop=l37, factors=[v94, v95])
+l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True)
sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)
b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global")
sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True)""".split(
@@ -436,33 +436,33 @@ sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True)""".split(
"""b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
-l11, l12 = sch.split(loop=l10, factors=[1, 4])
-l13, l14 = sch.split(loop=l5, factors=[1, 16])
+l11, l12 = sch.split(loop=l10, factors=[1, 4], preserve_unit_iters=True)
+l13, l14 = sch.split(loop=l5, factors=[1, 16], preserve_unit_iters=True)
l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
sch.reorder(l21, l22, l23, l24, l25, l14, l12)
b27 = sch.blockize(loop=l14)
sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni")
l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27)
v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64)
-l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41])
+l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41], preserve_unit_iters=True)
v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64)
-l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49])
+l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49], preserve_unit_iters=True)
v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64)
-l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57])
+l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57], preserve_unit_iters=True)
v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64)
-l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65])
+l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65], preserve_unit_iters=True)
v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64)
-l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73])
+l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73], preserve_unit_iters=True)
v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64)
-l80, l81 = sch.split(loop=l33, factors=[v78, v79])
+l80, l81 = sch.split(loop=l33, factors=[v78, v79], preserve_unit_iters=True)
v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64)
-l84, l85 = sch.split(loop=l34, factors=[v82, v83])
+l84, l85 = sch.split(loop=l34, factors=[v82, v83], preserve_unit_iters=True)
v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64)
-l88, l89 = sch.split(loop=l35, factors=[v86, v87])
+l88, l89 = sch.split(loop=l35, factors=[v86, v87], preserve_unit_iters=True)
v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64)
-l92, l93 = sch.split(loop=l36, factors=[v90, v91])
+l92, l93 = sch.split(loop=l36, factors=[v90, v91], preserve_unit_iters=True)
v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
-l96, l97 = sch.split(loop=l37, factors=[v94, v95])
+l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True)
sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)""".split(
"\n"
),
@@ -517,36 +517,36 @@ def test_multi_level_tiling_dense_dpa4():
"""b0 = sch.get_block(name="compute", func_name="main")
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
l1, l2, l3 = sch.get_loops(block=b0)
-l4, l5 = sch.split(loop=l3, factors=[32, 4])
+l4, l5 = sch.split(loop=l3, factors=[32, 4], preserve_unit_iters=True)
sch.reorder(l5)
b6 = sch.blockize(loop=l5)
sch.annotate(block_or_loop=b6, ann_key="meta_schedule.auto_tensorize", ann_val="dp4a")
l7, l8, l9 = sch.get_loops(block=b6)
v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64)
-l15, l16, l17, l18, l19 = sch.split(loop=l7, factors=[v10, v11, v12, v13, v14])
+l15, l16, l17, l18, l19 = sch.split(loop=l7, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True)
v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64)
-l25, l26, l27, l28, l29 = sch.split(loop=l8, factors=[v20, v21, v22, v23, v24])
+l25, l26, l27, l28, l29 = sch.split(loop=l8, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True)
v30, v31, v32 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64)
-l33, l34, l35 = sch.split(loop=l9, factors=[v30, v31, v32])
+l33, l34, l35 = sch.split(loop=l9, factors=[v30, v31, v32], preserve_unit_iters=True)
sch.reorder(l15, l25, l16, l26, l17, l27, l33, l34, l18, l28, l35, l19, l29)
-l36 = sch.fuse(l15, l25)
+l36 = sch.fuse(l15, l25, preserve_unit_iters=True)
sch.bind(loop=l36, thread_axis="blockIdx.x")
-l37 = sch.fuse(l16, l26)
+l37 = sch.fuse(l16, l26, preserve_unit_iters=True)
sch.bind(loop=l37, thread_axis="vthread.x")
-l38 = sch.fuse(l17, l27)
+l38 = sch.fuse(l17, l27, preserve_unit_iters=True)
sch.bind(loop=l38, thread_axis="threadIdx.x")
b39 = sch.cache_write(block=b6, write_buffer_index=0, storage_scope="local")
sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True)
b40 = sch.cache_read(block=b6, read_buffer_index=0, storage_scope="shared")
sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True)
l41, l42, l43, l44, l45, l46 = sch.get_loops(block=b40)
-l47 = sch.fuse(l45, l46)
+l47 = sch.fuse(l45, l46, preserve_unit_iters=True)
v48 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b40, ann_key="meta_schedule.cooperative_fetch", ann_val=v48)
b49 = sch.cache_read(block=b6, read_buffer_index=1, storage_scope="shared")
sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True)
l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b49)
-l56 = sch.fuse(l54, l55)
+l56 = sch.fuse(l54, l55, preserve_unit_iters=True)
v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_val=v57)""".split(
"\n"
diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py
index 6fc573b1a8..d1d87b60b7 100644
--- a/tests/python/unittest/test_tir_schedule_trace.py
+++ b/tests/python/unittest/test_tir_schedule_trace.py
@@ -87,7 +87,7 @@ def _make_split(inputs, outputs): # pylint: disable=redefined-builtin
return Instruction(
kind=InstructionKind.get("Split"),
inputs=inputs,
- attrs=[],
+ attrs=[True],
outputs=outputs,
)
@@ -262,7 +262,7 @@ def test_trace_simplified_3():
(
'b0 = sch.get_block(name="B", func_name="main")',
"l1, = sch.get_loops(block=b0)",
- "l2, l3 = sch.split(loop=l1, factors=[None, 32])",
+ "l2, l3 = sch.split(loop=l1, factors=[None, 32], preserve_unit_iters=True)",
)
)