You are viewing a plain text version of this content. The canonical link for it was a hyperlink in the original message and is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by bo...@apache.org on 2022/04/02 17:32:56 UTC

[tvm] branch main updated: [BugFix][MetaSchedule] Fuse only serial loops in rewrite-unbound-block (#10883)

This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 264ee08  [BugFix][MetaSchedule] Fuse only serial loops in rewrite-unbound-block (#10883)
264ee08 is described below

commit 264ee08aef985eddebc633bae356f3931060b6d9
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Sat Apr 2 10:32:19 2022 -0700

    [BugFix][MetaSchedule] Fuse only serial loops in rewrite-unbound-block (#10883)
---
 .../postproc/rewrite_unbound_block.cc              |  5 ++
 ...meta_schedule_postproc_rewrite_unbound_block.py | 97 ++++++++++++++++++++++
 2 files changed, 102 insertions(+)

diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc
index f06df72..73dc89d 100644
--- a/src/meta_schedule/postproc/rewrite_unbound_block.cc
+++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc
@@ -61,6 +61,11 @@ BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) {
         i_thread_idx = i;
       }
     }
+    if (loop->kind != tir::ForKind::kSerial) {
+      if (i_multi_child == -1) {
+        i_multi_child = i;
+      }
+    }
     if (!IsSingleStmt(loop->body)) {
       if (i_multi_child == -1) {
         i_multi_child = i + 1;
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py
index 70ae070..61bd0e34 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py
@@ -269,6 +269,92 @@ class Bert_fused_reshape_transpose_reshape_after_rub_large:
                         ]
 
 
+@T.prim_func
+def before_unrolled_loop(
+    placeholder: T.Buffer[(1, 56, 56, 64), "float32"],
+) -> None:
+    # function attr dict
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32")
+    inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32")
+    for i2_0, i3_0, i2_1, i3_1 in T.grid(98, 4, 2, 16):
+        for i0 in T.unroll(4):
+            for i1 in T.unroll(4):
+                for i4 in T.unroll(6):
+                    for i5 in T.unroll(6):
+                        with T.block("inverse"):
+                            vh, vw = T.axis.remap("SS", [i0, i1])
+                            p = T.axis.spatial(196, i2_0 * 2 + i2_1)
+                            co = T.axis.spatial(64, i3_0 * 16 + i3_1)
+                            r_a, r_b = T.axis.remap("RR", [i4, i5])
+                            T.reads(bgemm[r_a, r_b, p, co])
+                            T.writes(inverse[vh, vw, p, co])
+                            with T.init():
+                                inverse[vh, vw, p, co] = T.float32(0)
+                            inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co]
+
+
+@T.prim_func
+def after_unrolled_loop(
+    placeholder: T.Buffer[(1, 56, 56, 64), "float32"],
+) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    # body
+    # with T.block("root")
+    bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32")
+    inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32")
+    for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(13, thread="blockIdx.x"):
+        for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
+            for i0 in T.unroll(4):
+                for i1 in T.unroll(4):
+                    for i4 in T.unroll(6):
+                        for i5 in T.unroll(6):
+                            with T.block("inverse"):
+                                vh, vw = T.axis.remap("SS", [i0, i1])
+                                p = T.axis.spatial(
+                                    196,
+                                    (
+                                        i2_0_i3_0_i2_1_i3_1_fused_0 * 1024
+                                        + i2_0_i3_0_i2_1_i3_1_fused_1
+                                    )
+                                    // 128
+                                    * 2
+                                    + (
+                                        i2_0_i3_0_i2_1_i3_1_fused_0 * 1024
+                                        + i2_0_i3_0_i2_1_i3_1_fused_1
+                                    )
+                                    % 32
+                                    // 16,
+                                )
+                                co = T.axis.spatial(
+                                    64,
+                                    (
+                                        i2_0_i3_0_i2_1_i3_1_fused_0 * 1024
+                                        + i2_0_i3_0_i2_1_i3_1_fused_1
+                                    )
+                                    % 128
+                                    // 32
+                                    * 16
+                                    + (
+                                        i2_0_i3_0_i2_1_i3_1_fused_0 * 1024
+                                        + i2_0_i3_0_i2_1_i3_1_fused_1
+                                    )
+                                    % 16,
+                                )
+                                r_a, r_b = T.axis.remap("RR", [i4, i5])
+                                T.where(
+                                    i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1
+                                    < 12544
+                                )
+                                T.reads(bgemm[r_a, r_b, p, co])
+                                T.writes(inverse[vh, vw, p, co])
+                                with T.init():
+                                    inverse[vh, vw, p, co] = T.float32(0)
+                                inverse[vh, vw, p, co] = (
+                                    inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co]
+                                )
+
+
 # pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
 # fmt: on
 
@@ -313,8 +399,19 @@ def test_rewrite_cuda_loop_split_no_reduction_large():
     tvm.ir.assert_structural_equal(sch.mod, Bert_fused_reshape_transpose_reshape_after_rub_large)
 
 
+def test_rewrite_cuda_loop_split_for_kind():
+    mod = before_unrolled_loop
+    target = Target("nvidia/nvidia-v100", host="llvm")
+    ctx = _create_context(mod, target)
+    sch = tir.Schedule(mod, debug_mask="all")
+    sch.enter_postproc()
+    assert ctx.postprocs[0].apply(sch)
+    tvm.ir.assert_structural_equal(sch.mod["main"], after_unrolled_loop)
+
+
 if __name__ == "__main__":
     test_rewrite_cooperative_fetch()
     test_rewrite_norm_bmn()
     test_rewrite_cuda_loop_split_no_reduction()
     test_rewrite_cuda_loop_split_no_reduction_large()
+    test_rewrite_cuda_loop_split_for_kind()