Posted to commits@tvm.apache.org by bo...@apache.org on 2022/04/07 19:01:40 UTC

[tvm] branch main updated: relax reorder primitive's affineness check (#10887)

This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 00c830ece0 relax reorder primitive's affineness check (#10887)
00c830ece0 is described below

commit 00c830ece0feb9455fe0045b85ed01d6d363a495
Author: wrongtest <wr...@gmail.com>
AuthorDate: Fri Apr 8 03:01:33 2022 +0800

    relax reorder primitive's affineness check (#10887)
---
 src/tir/schedule/analysis.h                        | 11 +++
 src/tir/schedule/analysis/analysis.cc              | 50 ++++++++++--
 src/tir/schedule/primitive/loop_transformation.cc  | 21 +++--
 tests/python/unittest/test_tir_schedule_reorder.py | 89 ++++++++++++++++++++++
 4 files changed, 156 insertions(+), 15 deletions(-)
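
In effect, reorder now accepts loops whose block bindings are affine only
within the reorder range, instead of requiring a globally affine binding.
A minimal sketch of the relaxed behavior (mirroring the non_affine_func
test added at the bottom of this patch):

    import tvm
    from tvm import tir
    from tvm.script import tir as T

    @T.prim_func
    def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
        # i = v0 * 2 + v1 writes some rows twice, so the binding is not affine
        # over the whole loop nest, but it is affine in (v1, v2) once v0 is fixed
        for v0, v1, v2 in T.grid(6, 4, 4):
            with T.block("block"):
                i = T.axis.spatial(14, v0 * 2 + v1)
                j = T.axis.spatial(4, v2)
                B[i, j] = A[i, j] + 1.0

    sch = tir.Schedule(non_affine_func, debug_mask="all")
    v0, v1, v2 = sch.get_loops(sch.get_block("block"))
    sch.reorder(v2, v1)  # accepted now: the binding is affine under v0
    # sch.reorder(v0, v2, v1) still raises tvm.tir.ScheduleError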

diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index e74b9ea264..b76d41326f 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -231,6 +231,17 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
  */
 void CheckAffineBinding(const ScheduleState& self, Block block);
 
+/*!
+ * \brief Check whether a block has an affine binding under the given high_exclusive sref
+ *     node, and throw an exception if it does not.
+ * \param self The schedule state
+ * \param block The block to be checked
+ * \param high_exclusive The highest (exclusive) sref node; loops at or above it are ignored
+ * \throw ScheduleError If the input block does not have such a (partial) affine binding
+ */
+void CheckPartialAffineBinding(const ScheduleState& self, Block block,
+                               const Optional<StmtSRef>& high_exclusive);
+
 /*!
  * \brief Extracts the ranges of loop variables in a path of the sref tree
  * \param low_inclusive The lowest node in the path
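
Concretely, when high_exclusive is given, only the loops strictly below it
contribute ranges to the domain passed to IsAffineBinding. A hand-worked
illustration on the binding i = v0 * 2 + v1 from the new test (the dict
names below are for exposition only, not an actual API):

    # Loop domains considered by the two checks:
    global_dom = {"v0": (0, 6), "v1": (0, 4), "v2": (0, 4)}
    # i = v0 * 2 + v1 is not a quasi-affine split/fuse pattern over global_dom:
    # (v0=1, v1=0) and (v0=0, v1=2) both produce i = 2, so rows overlap.
    partial_dom = {"v1": (0, 4), "v2": (0, 4)}  # high_exclusive = sref of loop v0
    # With v0 treated as a free symbol, i = v0 * 2 + v1 is affine in v1 -> accepted.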
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 435870471f..4a7ac401dd 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -544,26 +544,62 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
   return true;
 }
 
-void CheckAffineBinding(const ScheduleState& self, Block block) {
+void CheckPartialAffineBinding(const ScheduleState& self, Block block,
+                               const Optional<StmtSRef>& high_exclusive) {
   class NotAffineBindingError : public ScheduleError {
    public:
-    explicit NotAffineBindingError(IRModule mod, Block block)
-        : mod_(std::move(mod)), block_(std::move(block)) {}
+    explicit NotAffineBindingError(IRModule mod, Block block, Optional<StmtSRef> high_exclusive)
+        : mod_(std::move(mod)), block_(std::move(block)) {
+      if (high_exclusive.defined()) {
+        high_exclusive_loop_ = high_exclusive.value()->StmtAs<ForNode>();
+      }
+    }
     String FastErrorString() const final {
-      return "ScheduleError: The block is required to have an affine binding";
+      std::ostringstream ss;
+      if (high_exclusive_loop_) {
+        ss << "ScheduleError: The block is required to have an partial affine binding under "
+           << high_exclusive_loop_->loop_var;
+      } else {
+        ss << "ScheduleError: The block is required to have an affine binding";
+      }
+      return ss.str();
     }
     String DetailRenderTemplate() const final {
-      return "The block {0} is required to have an affine binding";
+      std::ostringstream ss;
+      if (high_exclusive_loop_) {
+        ss << "The block {0} is required to have an partial affine binding under "
+           << high_exclusive_loop_->loop_var;
+      } else {
+        ss << "The block {0} is required to have an affine binding";
+      }
+      return ss.str();
     }
     IRModule mod() const final { return mod_; }
     Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
     IRModule mod_;
     Block block_;
+    const ForNode* high_exclusive_loop_{nullptr};
   };
 
-  if (!self->IsAffineBlockBinding(self->stmt2ref.at(block.get()))) {
-    throw NotAffineBindingError(self->mod, std::move(block));
+  StmtSRef block_sref = self->stmt2ref.at(block.get());
+  if (self->IsAffineBlockBinding(block_sref)) {
+    // fast path: the cached block flag already certifies a global affine binding
+    return;
+  }
+  if (block_sref->parent && high_exclusive.defined()) {
+    // the binding is not globally affine; check affineness under high_exclusive
+    arith::Analyzer analyzer;
+    Map<Var, Range> dom_map =
+        LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent), high_exclusive);
+    if (IsAffineBinding(GetBlockRealize(self, block_sref), dom_map, &analyzer)) {
+      return;
+    }
   }
+  throw NotAffineBindingError(self->mod, std::move(block), high_exclusive);
+}
+
+void CheckAffineBinding(const ScheduleState& self, Block block) {
+  CheckPartialAffineBinding(self, std::move(block), NullOpt);
 }
 
 Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive,
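
The control flow above, restated as a Python-style sketch (the snake_case
names transliterate the C++ helpers for exposition; this is not a real
Python API):

    def check_partial_affine_binding(state, block, high_exclusive=None):
        block_sref = state.stmt2ref[block]
        if state.is_affine_block_binding(block_sref):
            return  # fast path: the cached flag certifies a global affine binding
        if block_sref.parent is not None and high_exclusive is not None:
            # re-derive the loop domain only for loops strictly below high_exclusive
            dom_map = loop_domain_of_sref_tree_path(block_sref.parent, high_exclusive)
            if is_affine_binding(get_block_realize(state, block_sref), dom_map):
                return
        raise ScheduleError("the block is required to have a (partial) affine binding")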
diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc
index fa2a4469b8..d64a72ed34 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -134,16 +134,18 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
 class BlockPropertyError : public ScheduleError {
  public:
   /*!
-   * \brief Check that all the blocks under the specific stmt have affine bindings and only have
-   *     data-parallel or reduction block iters
+   * \brief Check that all the blocks under the specific stmt have affine bindings
+   *     w.r.t. the top loop sref, and have only data-parallel or reduction block iters
    * \param self The state of the schedule
    * \param sref The sref to the specific stmt
    */
-  static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self,
+  static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self, const StmtSRefNode* top,
                                                  const StmtSRefNode* sref) {
     class BlockIterTypeAndAffineBindingChecker : public StmtVisitor {
      public:
-      explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state) : state_(state) {}
+      explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state,
+                                                    const StmtSRefNode* top)
+          : state_(state), top_(top) {}
 
      private:
       void VisitStmt_(const BlockNode* op) final {
@@ -151,13 +153,16 @@ class BlockPropertyError : public ScheduleError {
           if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) {
             throw BlockPropertyError(state_->mod, GetRef<Block>(op));
           }
-          CheckAffineBinding(state_, GetRef<Block>(op));
+          Optional<StmtSRef> high_exclusive =
+              top_->parent ? GetRef<StmtSRef>(top_->parent) : Optional<StmtSRef>(NullOpt);
+          CheckPartialAffineBinding(state_, GetRef<Block>(op), high_exclusive);
         }
       }
       const ScheduleState& state_;
+      const StmtSRefNode* top_;
     };
 
-    BlockIterTypeAndAffineBindingChecker checker(self);
+    BlockIterTypeAndAffineBindingChecker checker(self, top);
     checker(GetRef<Stmt>(sref->stmt));
   }
 
@@ -708,8 +713,8 @@ void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
   // Step 3. Collect all loops in the chain and check the loops are single-branch
   std::vector<const StmtSRefNode*> chain = GetLoopsInReorderRange(self, top, bottom);
   // Step 4. Check the block below has all its block_var to be data-parallel or reduction,
-  // and the block has an affine binding.
-  BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom);
+  // and the block has an affine binding w.r.t. the top of the reorder range.
+  BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, top, bottom);
   // Step 5. Replace the original loops with the reordered loops and check that outer loop is
   // not dependent on inner loop
   For new_loop = ConstructNewLoopChain(self, std::move(chain), ordered_loop_srefs, loop_srefs);
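
Because Reorder now threads top through the check, blocks produced by
compute_at, whose bindings such as h_o * 4 + h_i are affine only below the
attachment loop, can be reordered as well. Condensed from the
test_reorder_with_cascade_tiled_ops test below (cascade_pool_ops as
defined there):

    import tvm

    sch = tvm.tir.schedule.Schedule(cascade_pool_ops)
    pool_0 = sch.get_block("pool_0")
    _, _, h, w, _, _ = sch.get_loops(sch.get_block("pool_1"))
    h_o, _ = sch.split(h, factors=[None, 4])
    sch.compute_at(pool_0, h_o)   # pool_0's h binding becomes h_o * 4 + h_i
    _, _, _, h_i, w, _, _ = sch.get_loops(pool_0)
    sch.reorder(w, h_i)           # checked only under h_o -> accepted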
diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py
index f62a316f80..462099e6fe 100644
--- a/tests/python/unittest/test_tir_schedule_reorder.py
+++ b/tests/python/unittest/test_tir_schedule_reorder.py
@@ -213,6 +213,95 @@ def test_reorder_with_opaque_access():
     verify_trace_roundtrip(sch=sch, mod=opaque_access)
 
 
+def test_reorder_with_partial_affineness():
+    @T.prim_func
+    def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
+        # i = v0 * 2 + v1 writes the first axis multiple times: not globally affine
+        for v0, v1, v2 in T.grid(6, 4, 4):
+            with T.block("block"):
+                i = T.axis.spatial(14, v0 * 2 + v1)
+                j = T.axis.spatial(4, v2)
+                B[i, j] = A[i, j] + 1.0
+
+    @T.prim_func
+    def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
+        # i = v0 * 2 + v1 writes the first axis multiple times: not globally affine
+        for v0, v2, v1 in T.grid(6, 4, 4):
+            with T.block("block"):
+                i = T.axis.spatial(14, v0 * 2 + v1)
+                j = T.axis.spatial(4, v2)
+                B[i, j] = A[i, j] + 1.0
+
+    sch = tir.Schedule(non_affine_func, debug_mask="all")
+    v0, v1, v2 = sch.get_loops(sch.get_block("block"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reorder(v0, v2, v1)
+
+    sch.reorder(v2, v1)
+    tvm.ir.assert_structural_equal(non_affine_func_reorder, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=non_affine_func)
+
+
+def test_reorder_with_cascade_tiled_ops():
+    @T.prim_func
+    def cascade_pool_ops(
+        x: T.Buffer[(1, 16, 112, 112), "float32"], y2: T.Buffer[(1, 16, 108, 108), "float32"]
+    ) -> None:
+        y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32")
+        for n, c, h, w, kh, kw in T.grid(1, 16, 110, 110, 3, 3):
+            with T.block("pool_0"):
+                ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
+                with T.init():
+                    y1[ax0, ax1, ax2, ax3] = 0.0
+                y1[ax0, ax1, ax2, ax3] = y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1]
+        for n, c, h, w, kh, kw in T.grid(1, 16, 108, 108, 3, 3):
+            with T.block("pool_1"):
+                ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
+                with T.init():
+                    y2[ax0, ax1, ax2, ax3] = 0.0
+                y2[ax0, ax1, ax2, ax3] = y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1]
+
+    @T.prim_func
+    def cascade_pool_ops_tile_reordered(
+        x: T.Buffer[(1, 16, 112, 112), "float32"], y2: T.Buffer[(1, 16, 108, 108), "float32"]
+    ) -> None:
+        y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32")
+        for n, c, h_o in T.grid(1, 16, 27):
+            for w, h_i, kh, kw in T.grid(110, 6, 3, 3):
+                with T.block("pool_0"):
+                    ax0 = T.axis.spatial(1, 0)
+                    ax1 = T.axis.spatial(16, c)
+                    ax2 = T.axis.spatial(110, h_o * 4 + h_i)
+                    ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw])
+                    with T.init():
+                        y1[ax0, ax1, ax2, ax3] = 0.0
+                    y1[ax0, ax1, ax2, ax3] = (
+                        y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1]
+                    )
+            for h_i, w, kh, kw in T.grid(4, 108, 3, 3):
+                with T.block("pool_1"):
+                    ax0 = T.axis.spatial(1, 0)
+                    ax1 = T.axis.spatial(16, c)
+                    ax2 = T.axis.spatial(108, h_o * 4 + h_i)
+                    ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw])
+                    with T.init():
+                        y2[ax0, ax1, ax2, ax3] = 0.0
+                    y2[ax0, ax1, ax2, ax3] = (
+                        y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1]
+                    )
+
+    sch = tvm.tir.schedule.Schedule(cascade_pool_ops)
+    pool_0 = sch.get_block("pool_0")
+    pool_1 = sch.get_block("pool_1")
+    _, _, h, w, _, _ = sch.get_loops(pool_1)
+    ho, _ = sch.split(h, factors=[None, 4])
+    sch.compute_at(pool_0, ho)
+    _, _, _, h_i, w, _, _ = sch.get_loops(pool_0)
+    sch.reorder(w, h_i)
+    tvm.ir.assert_structural_equal(cascade_pool_ops_tile_reordered, sch.mod["main"], True)
+    verify_trace_roundtrip(sch=sch, mod=cascade_pool_ops)
+
+
 def test_reorder_with_predicate():
     sch = tir.Schedule(elementwise_predicate, debug_mask="all")
     block_b = sch.get_block("B")