You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by bo...@apache.org on 2022/04/02 17:32:56 UTC
[tvm] branch main updated: [BugFix][MetaSchedule] Fuse only serial loops in rewrite-unbound-block (#10883)
This is an automated email from the ASF dual-hosted git repository.
bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 264ee08 [BugFix][MetaSchedule] Fuse only serial loops in rewrite-unbound-block (#10883)
264ee08 is described below
commit 264ee08aef985eddebc633bae356f3931060b6d9
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Sat Apr 2 10:32:19 2022 -0700
[BugFix][MetaSchedule] Fuse only serial loops in rewrite-unbound-block (#10883)
---
.../postproc/rewrite_unbound_block.cc | 5 ++
...meta_schedule_postproc_rewrite_unbound_block.py | 97 ++++++++++++++++++++++
2 files changed, 102 insertions(+)
diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc
index f06df72..73dc89d 100644
--- a/src/meta_schedule/postproc/rewrite_unbound_block.cc
+++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc
@@ -61,6 +61,11 @@ BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) {
i_thread_idx = i;
}
}
+ if (loop->kind != tir::ForKind::kSerial) {
+ if (i_multi_child == -1) {
+ i_multi_child = i;
+ }
+ }
if (!IsSingleStmt(loop->body)) {
if (i_multi_child == -1) {
i_multi_child = i + 1;
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py
index 70ae070..61bd0e34 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py
@@ -269,6 +269,92 @@ class Bert_fused_reshape_transpose_reshape_after_rub_large:
]
+@T.prim_func  # input fixture consumed by test_rewrite_cuda_loop_split_for_kind
+def before_unrolled_loop(
+ placeholder: T.Buffer[(1, 56, 56, 64), "float32"],  # not referenced anywhere in the body
+) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32")
+ inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32")
+ for i2_0, i3_0, i2_1, i3_1 in T.grid(98, 4, 2, 16):  # four outer loops without an unroll annotation; 98 * 4 * 2 * 16 = 12544 iterations
+ for i0 in T.unroll(4):  # this loop and the three below it are explicitly marked unroll
+ for i1 in T.unroll(4):
+ for i4 in T.unroll(6):
+ for i5 in T.unroll(6):
+ with T.block("inverse"):
+ vh, vw = T.axis.remap("SS", [i0, i1])  # spatial axes taken from the first two unrolled loops
+ p = T.axis.spatial(196, i2_0 * 2 + i2_1)  # 196 = 98 * 2
+ co = T.axis.spatial(64, i3_0 * 16 + i3_1)  # 64 = 4 * 16
+ r_a, r_b = T.axis.remap("RR", [i4, i5])  # reduction axes taken from the last two unrolled loops
+ T.reads(bgemm[r_a, r_b, p, co])
+ T.writes(inverse[vh, vw, p, co])
+ with T.init():
+ inverse[vh, vw, p, co] = T.float32(0)
+ inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co]  # accumulates bgemm over (r_a, r_b)
+
+
+@T.prim_func  # expected result that test_rewrite_cuda_loop_split_for_kind compares against
+def after_unrolled_loop(
+ placeholder: T.Buffer[(1, 56, 56, 64), "float32"],  # not referenced anywhere in the body
+) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body
+ # with T.block("root")
+ bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32")
+ inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32")
+ for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(13, thread="blockIdx.x"):  # the four outer loops of before_unrolled_loop appear fused here, then split as 13 x 1024
+ for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):  # 13 * 1024 = 13312 >= 12544 fused iterations
+ for i0 in T.unroll(4):  # the four unroll loops are identical to those in before_unrolled_loop
+ for i1 in T.unroll(4):
+ for i4 in T.unroll(6):
+ for i5 in T.unroll(6):
+ with T.block("inverse"):
+ vh, vw = T.axis.remap("SS", [i0, i1])
+ p = T.axis.spatial(  # equals i2_0 * 2 + i2_1 when decoded from the fused index
+ 196,
+ (
+ i2_0_i3_0_i2_1_i3_1_fused_0 * 1024
+ + i2_0_i3_0_i2_1_i3_1_fused_1
+ )
+ // 128  # 128 = 4 * 2 * 16, the combined extent of the three inner fused loops
+ * 2
+ + (
+ i2_0_i3_0_i2_1_i3_1_fused_0 * 1024
+ + i2_0_i3_0_i2_1_i3_1_fused_1
+ )
+ % 32  # 32 = 2 * 16
+ // 16,
+ )
+ co = T.axis.spatial(  # equals i3_0 * 16 + i3_1 when decoded from the fused index
+ 64,
+ (
+ i2_0_i3_0_i2_1_i3_1_fused_0 * 1024
+ + i2_0_i3_0_i2_1_i3_1_fused_1
+ )
+ % 128
+ // 32
+ * 16
+ + (
+ i2_0_i3_0_i2_1_i3_1_fused_0 * 1024
+ + i2_0_i3_0_i2_1_i3_1_fused_1
+ )
+ % 16,
+ )
+ r_a, r_b = T.axis.remap("RR", [i4, i5])
+ T.where(  # predicate excluding the 13312 - 12544 = 768 surplus fused indices
+ i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1
+ < 12544
+ )
+ T.reads(bgemm[r_a, r_b, p, co])
+ T.writes(inverse[vh, vw, p, co])
+ with T.init():
+ inverse[vh, vw, p, co] = T.float32(0)
+ inverse[vh, vw, p, co] = (
+ inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co]
+ )
+
+
# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on
@@ -313,8 +399,19 @@ def test_rewrite_cuda_loop_split_no_reduction_large():
tvm.ir.assert_structural_equal(sch.mod, Bert_fused_reshape_transpose_reshape_after_rub_large)
+def test_rewrite_cuda_loop_split_for_kind():  # applies the context's first postproc to before_unrolled_loop and expects after_unrolled_loop
+ mod = before_unrolled_loop
+ target = Target("nvidia/nvidia-v100", host="llvm")
+ ctx = _create_context(mod, target)  # helper defined outside this hunk; presumably registers the rewrite-unbound-block postproc first - TODO confirm
+ sch = tir.Schedule(mod, debug_mask="all")
+ sch.enter_postproc()
+ assert ctx.postprocs[0].apply(sch)  # apply() must return a truthy value
+ tvm.ir.assert_structural_equal(sch.mod["main"], after_unrolled_loop)  # the rewritten function is read back under the key "main"
+
+
if __name__ == "__main__":  # direct-run entry point: calls the test functions listed below in order
test_rewrite_cooperative_fetch()
test_rewrite_norm_bmn()
test_rewrite_cuda_loop_split_no_reduction()
test_rewrite_cuda_loop_split_no_reduction_large()
+ test_rewrite_cuda_loop_split_for_kind()  # the test added by this commit