You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/01/04 02:23:30 UTC

[tvm] branch main updated: [Schedule][Bugfix] Fix decompose padding wrt the single child subtree (#13646)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 49ed54478b [Schedule][Bugfix] Fix decompose padding wrt the single child subtree (#13646)
49ed54478b is described below

commit 49ed54478b255f1c39a9548eee284987efcd008c
Author: wrongtest <wr...@gmail.com>
AuthorDate: Wed Jan 4 10:23:24 2023 +0800

    [Schedule][Bugfix] Fix decompose padding wrt the single child subtree (#13646)
    
    Fix a bug when decomposing padding with respect to a single-child subtree.
---
 src/tir/schedule/primitive/decompose_padding.cc    | 17 +++---
 .../test_tir_schedule_decompose_padding.py         | 63 ++++++++++++++++++++++
 2 files changed, 74 insertions(+), 6 deletions(-)

diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc
index c417608767..e657b4f466 100644
--- a/src/tir/schedule/primitive/decompose_padding.cc
+++ b/src/tir/schedule/primitive/decompose_padding.cc
@@ -114,6 +114,10 @@ class PaddingInfoAnalyzer {
 
     // Step 3. Analyze in-bound write region.
     PrimExpr in_bound_predicate = RewritePredicate(pad_predicate && realize->predicate);
+    if (analyzer_->CanProveEqual(in_bound_predicate, 1)) {
+      SetError("The in-bound predicate is trivial");
+      return false;
+    }
     Array<Range> in_bound_region = this->EstimateInBoundRegion(
         /*iter_values=*/realize->iter_values, /*dom_map=*/dom_map,
         /*in_bound_predicate=*/in_bound_predicate);
@@ -439,13 +443,14 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
     analyzer.Bind(cur_loop->loop_var, range);
     loops.push_back(cur_loop);
 
-    if (!found_const_filling_pos) {
-      if (cur_loop.same_as(const_filling_pos)) {
-        found_const_filling_pos = true;
+    if (cur_loop.same_as(const_filling_pos)) {
+      ICHECK(!found_const_filling_pos);
+      found_const_filling_pos = true;
+      if (!found_in_bound_filling_pos) {
+        found_in_bound_filling_pos = true;
+        in_bound_filling_pos = cur_loop;
       }
-    }
-
-    if (!found_in_bound_filling_pos) {
+    } else if (!found_in_bound_filling_pos) {
       if (!cur_loop->body->IsInstance<ForNode>() &&
           !cur_loop->body->IsInstance<BlockRealizeNode>()) {
         found_in_bound_filling_pos = true;
diff --git a/tests/python/unittest/test_tir_schedule_decompose_padding.py b/tests/python/unittest/test_tir_schedule_decompose_padding.py
index a3fc4326a3..ead8b0b332 100644
--- a/tests/python/unittest/test_tir_schedule_decompose_padding.py
+++ b/tests/python/unittest/test_tir_schedule_decompose_padding.py
@@ -309,5 +309,68 @@ def test_decompose_hw_padding_non_perfect_tiled():
     check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_3, check_run=True)
 
 
+def test_decompose_wrt_single_child_subtree():
+    """Test the case when the decompose position is under the single child subtree"""
+
+    @T.prim_func
+    def pad_op(
+        x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 231, 231], dtype="int8")
+    ):
+        for i0, i1, i2, i3 in T.grid(1, 16, 231, 231):
+            with T.block("pad_temp"):
+                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                y[ax0, ax1, ax2, ax3] = T.if_then_else(
+                    3 <= ax2 and ax2 < 228 and 3 <= ax3 and ax3 < 228,
+                    x[ax0, ax1, ax2 - 3, ax3 - 3],
+                    T.int8(0),
+                    dtype="int8",
+                )
+
+    @T.prim_func
+    def pad_op_after(
+        x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer[(1, 16, 231, 231), "int8"]
+    ):
+        for i0, i1 in T.grid(1, 16):
+            for i2, i3 in T.grid(231, 231):
+                with T.block("pad_temp_pad_const"):
+                    ax0 = T.axis.spatial(1, 0)
+                    ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
+                    y[ax0, ax1, ax2, ax3] = T.int8(0)
+            for i2, i3 in T.grid(225, 225):
+                with T.block("pad_temp"):
+                    ax0 = T.axis.spatial(1, 0)
+                    ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
+                    y[ax0, ax1, ax2 + 3, ax3 + 3] = x[ax0, ax1, ax2, ax3]
+
+    sch = tir.Schedule(pad_op, debug_mask="all")
+    pad = sch.get_block("pad_temp")
+    _, _, h, _ = sch.get_loops(pad)
+    sch.decompose_padding(pad, h)
+    check_decompose_padding(pad_op, sch.mod["main"], pad_op_after, check_run=True)
+
+
+def test_not_to_decompose_trivial_predicate():
+    """Test the case when the padding condition is trivial"""
+
+    @T.prim_func
+    def trivial_pad(
+        x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 225, 225], dtype="int8")
+    ):
+        for i0, i1, i2, i3 in T.grid(1, 16, 225, 225):
+            with T.block("pad_temp"):
+                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                y[ax0, ax1, ax2, ax3] = T.if_then_else(
+                    0 <= ax2 and ax2 < 225 and 0 <= ax3 and ax3 < 225,
+                    x[ax0, ax1, ax2, ax3],
+                    T.int8(0),
+                    dtype="int8",
+                )
+
+    sch = tir.Schedule(trivial_pad, debug_mask="all")
+    pad = sch.get_block("pad_temp")
+    _, _, h, _ = sch.get_loops(pad)
+    assert not sch.can_decompose_padding(pad, h)
+
+
 if __name__ == "__main__":
     tvm.testing.main()