You are viewing a plain-text version of this content; the link to its canonical version is not included in this text.
Posted to commits@tvm.apache.org by bo...@apache.org on 2022/04/07 19:01:40 UTC
[tvm] branch main updated: relax reorder primitive's affineness check (#10887)
This is an automated email from the ASF dual-hosted git repository.
bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 00c830ece0 relax reorder primitive's affineness check (#10887)
00c830ece0 is described below
commit 00c830ece0feb9455fe0045b85ed01d6d363a495
Author: wrongtest <wr...@gmail.com>
AuthorDate: Fri Apr 8 03:01:33 2022 +0800
relax reorder primitive's affineness check (#10887)
---
src/tir/schedule/analysis.h | 11 +++
src/tir/schedule/analysis/analysis.cc | 50 ++++++++++--
src/tir/schedule/primitive/loop_transformation.cc | 21 +++--
tests/python/unittest/test_tir_schedule_reorder.py | 89 ++++++++++++++++++++++
4 files changed, 156 insertions(+), 15 deletions(-)
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index e74b9ea264..b76d41326f 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -231,6 +231,17 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
*/
void CheckAffineBinding(const ScheduleState& self, Block block);
+/*!
+ * \brief Check whether a block has an affine binding under the high exclusive sref node,
+ * throw an exception if the block does not have an affine binding.
+ * \param self The schedule state
+ * \param block The block to be checked
+ * \param high_exclusive The highest sref node
+ * \throw ScheduleError If the input block does not have an affine binding
+ */
+void CheckPartialAffineBinding(const ScheduleState& self, Block block,
+ const Optional<StmtSRef>& high_exclusive);
+
/*!
* \brief Extracts the ranges of loop variables in a path of the sref tree
* \param low_inclusive The lowest node in the path
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 435870471f..4a7ac401dd 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -544,26 +544,62 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
return true;
}
-void CheckAffineBinding(const ScheduleState& self, Block block) {
+void CheckPartialAffineBinding(const ScheduleState& self, Block block,
+ const Optional<StmtSRef>& high_exclusive) {
class NotAffineBindingError : public ScheduleError {
public:
- explicit NotAffineBindingError(IRModule mod, Block block)
- : mod_(std::move(mod)), block_(std::move(block)) {}
+ explicit NotAffineBindingError(IRModule mod, Block block, Optional<StmtSRef> high_exclusive)
+ : mod_(std::move(mod)), block_(std::move(block)) {
+ if (high_exclusive.defined()) {
+ high_exclusive_loop_ = high_exclusive.value()->StmtAs<ForNode>();
+ }
+ }
String FastErrorString() const final {
- return "ScheduleError: The block is required to have an affine binding";
+ std::ostringstream ss;
+ if (high_exclusive_loop_) {
+ ss << "ScheduleError: The block is required to have an partial affine binding under "
+ << high_exclusive_loop_->loop_var;
+ } else {
+ ss << "ScheduleError: The block is required to have an affine binding";
+ }
+ return ss.str();
}
String DetailRenderTemplate() const final {
- return "The block {0} is required to have an affine binding";
+ std::ostringstream ss;
+ if (high_exclusive_loop_) {
+ ss << "The block {0} is required to have an partial affine binding under "
+ << high_exclusive_loop_->loop_var;
+ } else {
+ ss << "The block {0} is required to have an affine binding";
+ }
+ return ss.str();
}
IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
IRModule mod_;
Block block_;
+ const ForNode* high_exclusive_loop_{nullptr};
};
- if (!self->IsAffineBlockBinding(self->stmt2ref.at(block.get()))) {
- throw NotAffineBindingError(self->mod, std::move(block));
+ StmtSRef block_sref = self->stmt2ref.at(block.get());
+ if (self->IsAffineBlockBinding(block_sref)) {
+ // check block cached state for global affineness
+ return;
+ }
+ if (block_sref->parent && high_exclusive.defined()) {
+ // if it is not of global affine binding, check affineness under high_exclusive,
+ arith::Analyzer analyzer;
+ Map<Var, Range> dom_map =
+ LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent), high_exclusive);
+ if (IsAffineBinding(GetBlockRealize(self, block_sref), dom_map, &analyzer)) {
+ return;
+ }
}
+ throw NotAffineBindingError(self->mod, std::move(block), high_exclusive);
+}
+
+void CheckAffineBinding(const ScheduleState& self, Block block) {
+ CheckPartialAffineBinding(self, std::move(block), NullOpt);
}
Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive,
diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc
index fa2a4469b8..d64a72ed34 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -134,16 +134,18 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
class BlockPropertyError : public ScheduleError {
public:
/*!
- * \brief Check that all the blocks under the specific stmt have affine bindings and only have
- * data-parallel or reduction block iters
+ * \brief Check that all the blocks under the specific stmt have affine bindings
+ * wrt top loop sref and only have data-parallel or reduction block iters
* \param self The state of the schedule
* \param sref The sref to the specific stmt
*/
- static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self,
+ static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self, const StmtSRefNode* top,
const StmtSRefNode* sref) {
class BlockIterTypeAndAffineBindingChecker : public StmtVisitor {
public:
- explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state) : state_(state) {}
+ explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state,
+ const StmtSRefNode* top)
+ : state_(state), top_(top) {}
private:
void VisitStmt_(const BlockNode* op) final {
@@ -151,13 +153,16 @@ class BlockPropertyError : public ScheduleError {
if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) {
throw BlockPropertyError(state_->mod, GetRef<Block>(op));
}
- CheckAffineBinding(state_, GetRef<Block>(op));
+ Optional<StmtSRef> high_exclusive =
+ top_->parent ? GetRef<StmtSRef>(top_->parent) : Optional<StmtSRef>(NullOpt);
+ CheckPartialAffineBinding(state_, GetRef<Block>(op), high_exclusive);
}
}
const ScheduleState& state_;
+ const StmtSRefNode* top_;
};
- BlockIterTypeAndAffineBindingChecker checker(self);
+ BlockIterTypeAndAffineBindingChecker checker(self, top);
checker(GetRef<Stmt>(sref->stmt));
}
@@ -708,8 +713,8 @@ void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
// Step 3. Collect all loops in the chain and check the loops are single-branch
std::vector<const StmtSRefNode*> chain = GetLoopsInReorderRange(self, top, bottom);
// Step 4. Check the block below has all its block_var to be data-parallel or reduction,
- // and the block has an affine binding.
- BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom);
+ // and the block has an affine binding wrt top of the loop range.
+ BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, top, bottom);
// Step 5. Replace the original loops with the reordered loops and check that outer loop is
// not dependent on inner loop
For new_loop = ConstructNewLoopChain(self, std::move(chain), ordered_loop_srefs, loop_srefs);
diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py
index f62a316f80..462099e6fe 100644
--- a/tests/python/unittest/test_tir_schedule_reorder.py
+++ b/tests/python/unittest/test_tir_schedule_reorder.py
@@ -213,6 +213,95 @@ def test_reorder_with_opaque_access():
verify_trace_roundtrip(sch=sch, mod=opaque_access)
+def test_reorder_with_partial_affineness():
+ @T.prim_func
+ def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
+ # example to write first axis multiple times
+ for v0, v1, v2 in T.grid(6, 4, 4):
+ with T.block("block"):
+ i = T.axis.spatial(14, v0 * 2 + v1)
+ j = T.axis.spatial(4, v2)
+ B[i, j] = A[i, j] + 1.0
+
+ @T.prim_func
+ def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
+ # example to write first axis multiple times
+ for v0, v2, v1 in T.grid(6, 4, 4):
+ with T.block("block"):
+ i = T.axis.spatial(14, v0 * 2 + v1)
+ j = T.axis.spatial(4, v2)
+ B[i, j] = A[i, j] + 1.0
+
+ sch = tir.Schedule(non_affine_func, debug_mask="all")
+ v0, v1, v2 = sch.get_loops(sch.get_block("block"))
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.reorder(v0, v2, v1)
+
+ sch.reorder(v2, v1)
+ tvm.ir.assert_structural_equal(non_affine_func_reorder, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=non_affine_func)
+
+
+def test_reorder_with_cascade_tiled_ops():
+ @T.prim_func
+ def cascade_pool_ops(
+ x: T.Buffer[(1, 16, 112, 112), "float32"], y2: T.Buffer[(1, 16, 108, 108), "float32"]
+ ) -> None:
+ y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32")
+ for n, c, h, w, kh, kw in T.grid(1, 16, 110, 110, 3, 3):
+ with T.block("pool_0"):
+ ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
+ with T.init():
+ y1[ax0, ax1, ax2, ax3] = 0.0
+ y1[ax0, ax1, ax2, ax3] = y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1]
+ for n, c, h, w, kh, kw in T.grid(1, 16, 108, 108, 3, 3):
+ with T.block("pool_1"):
+ ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
+ with T.init():
+ y2[ax0, ax1, ax2, ax3] = 0.0
+ y2[ax0, ax1, ax2, ax3] = y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1]
+
+ @T.prim_func
+ def cascade_pool_ops_tile_reordered(
+ x: T.Buffer[(1, 16, 112, 112), "float32"], y2: T.Buffer[(1, 16, 108, 108), "float32"]
+ ) -> None:
+ y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32")
+ for n, c, h_o in T.grid(1, 16, 27):
+ for w, h_i, kh, kw in T.grid(110, 6, 3, 3):
+ with T.block("pool_0"):
+ ax0 = T.axis.spatial(1, 0)
+ ax1 = T.axis.spatial(16, c)
+ ax2 = T.axis.spatial(110, h_o * 4 + h_i)
+ ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw])
+ with T.init():
+ y1[ax0, ax1, ax2, ax3] = 0.0
+ y1[ax0, ax1, ax2, ax3] = (
+ y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1]
+ )
+ for h_i, w, kh, kw in T.grid(4, 108, 3, 3):
+ with T.block("pool_1"):
+ ax0 = T.axis.spatial(1, 0)
+ ax1 = T.axis.spatial(16, c)
+ ax2 = T.axis.spatial(108, h_o * 4 + h_i)
+ ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw])
+ with T.init():
+ y2[ax0, ax1, ax2, ax3] = 0.0
+ y2[ax0, ax1, ax2, ax3] = (
+ y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1]
+ )
+
+ sch = tvm.tir.schedule.Schedule(cascade_pool_ops)
+ pool_0 = sch.get_block("pool_0")
+ pool_1 = sch.get_block("pool_1")
+ _, _, h, w, _, _ = sch.get_loops(pool_1)
+ ho, _ = sch.split(h, factors=[None, 4])
+ sch.compute_at(pool_0, ho)
+ _, _, _, h_i, w, _, _ = sch.get_loops(pool_0)
+ sch.reorder(w, h_i)
+ tvm.ir.assert_structural_equal(cascade_pool_ops_tile_reordered, sch.mod["main"], True)
+ verify_trace_roundtrip(sch=sch, mod=cascade_pool_ops)
+
+
def test_reorder_with_predicate():
sch = tir.Schedule(elementwise_predicate, debug_mask="all")
block_b = sch.get_block("B")