You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/01/04 02:23:30 UTC
[tvm] branch main updated: [Schedule][Bugfix] Fix decompose padding wrt the single child subtree (#13646)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 49ed54478b [Schedule][Bugfix] Fix decompose padding wrt the single child subtree (#13646)
49ed54478b is described below
commit 49ed54478b255f1c39a9548eee284987efcd008c
Author: wrongtest <wr...@gmail.com>
AuthorDate: Wed Jan 4 10:23:24 2023 +0800
[Schedule][Bugfix] Fix decompose padding wrt the single child subtree (#13646)
Fix a bug when decomposing padding with respect to a single-child subtree.
---
src/tir/schedule/primitive/decompose_padding.cc | 17 +++---
.../test_tir_schedule_decompose_padding.py | 63 ++++++++++++++++++++++
2 files changed, 74 insertions(+), 6 deletions(-)
diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc
index c417608767..e657b4f466 100644
--- a/src/tir/schedule/primitive/decompose_padding.cc
+++ b/src/tir/schedule/primitive/decompose_padding.cc
@@ -114,6 +114,10 @@ class PaddingInfoAnalyzer {
// Step 3. Analyze in-bound write region.
PrimExpr in_bound_predicate = RewritePredicate(pad_predicate && realize->predicate);
+ if (analyzer_->CanProveEqual(in_bound_predicate, 1)) {
+ SetError("The in-bound predicate is trivial");
+ return false;
+ }
Array<Range> in_bound_region = this->EstimateInBoundRegion(
/*iter_values=*/realize->iter_values, /*dom_map=*/dom_map,
/*in_bound_predicate=*/in_bound_predicate);
@@ -439,13 +443,14 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
analyzer.Bind(cur_loop->loop_var, range);
loops.push_back(cur_loop);
- if (!found_const_filling_pos) {
- if (cur_loop.same_as(const_filling_pos)) {
- found_const_filling_pos = true;
+ if (cur_loop.same_as(const_filling_pos)) {
+ ICHECK(!found_const_filling_pos);
+ found_const_filling_pos = true;
+ if (!found_in_bound_filling_pos) {
+ found_in_bound_filling_pos = true;
+ in_bound_filling_pos = cur_loop;
}
- }
-
- if (!found_in_bound_filling_pos) {
+ } else if (!found_in_bound_filling_pos) {
if (!cur_loop->body->IsInstance<ForNode>() &&
!cur_loop->body->IsInstance<BlockRealizeNode>()) {
found_in_bound_filling_pos = true;
diff --git a/tests/python/unittest/test_tir_schedule_decompose_padding.py b/tests/python/unittest/test_tir_schedule_decompose_padding.py
index a3fc4326a3..ead8b0b332 100644
--- a/tests/python/unittest/test_tir_schedule_decompose_padding.py
+++ b/tests/python/unittest/test_tir_schedule_decompose_padding.py
@@ -309,5 +309,68 @@ def test_decompose_hw_padding_non_perfect_tiled():
check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_3, check_run=True)
+def test_decompose_wrt_single_child_subtree():
+ """Test the case when the decompose position is under the single child subtree"""
+
+ @T.prim_func
+ def pad_op(
+ x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 231, 231], dtype="int8")
+ ):
+ for i0, i1, i2, i3 in T.grid(1, 16, 231, 231):
+ with T.block("pad_temp"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ y[ax0, ax1, ax2, ax3] = T.if_then_else(
+ 3 <= ax2 and ax2 < 228 and 3 <= ax3 and ax3 < 228,
+ x[ax0, ax1, ax2 - 3, ax3 - 3],
+ T.int8(0),
+ dtype="int8",
+ )
+
+ @T.prim_func
+ def pad_op_after(
+ x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer[(1, 16, 231, 231), "int8"]
+ ):
+ for i0, i1 in T.grid(1, 16):
+ for i2, i3 in T.grid(231, 231):
+ with T.block("pad_temp_pad_const"):
+ ax0 = T.axis.spatial(1, 0)
+ ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
+ y[ax0, ax1, ax2, ax3] = T.int8(0)
+ for i2, i3 in T.grid(225, 225):
+ with T.block("pad_temp"):
+ ax0 = T.axis.spatial(1, 0)
+ ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
+ y[ax0, ax1, ax2 + 3, ax3 + 3] = x[ax0, ax1, ax2, ax3]
+
+ sch = tir.Schedule(pad_op, debug_mask="all")
+ pad = sch.get_block("pad_temp")
+ _, _, h, _ = sch.get_loops(pad)
+ sch.decompose_padding(pad, h)
+ check_decompose_padding(pad_op, sch.mod["main"], pad_op_after, check_run=True)
+
+
+def test_not_to_decompose_trivial_predicate():
+ """Test the case when the padding condition is trivial"""
+
+ @T.prim_func
+ def trivial_pad(
+ x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 225, 225], dtype="int8")
+ ):
+ for i0, i1, i2, i3 in T.grid(1, 16, 225, 225):
+ with T.block("pad_temp"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ y[ax0, ax1, ax2, ax3] = T.if_then_else(
+ 0 <= ax2 and ax2 < 225 and 0 <= ax3 and ax3 < 225,
+ x[ax0, ax1, ax2, ax3],
+ T.int8(0),
+ dtype="int8",
+ )
+
+ sch = tir.Schedule(trivial_pad, debug_mask="all")
+ pad = sch.get_block("pad_temp")
+ _, _, h, _ = sch.get_loops(pad)
+ assert not sch.can_decompose_padding(pad, h)
+
+
if __name__ == "__main__":
tvm.testing.main()