Posted to commits@tvm.apache.org by sy...@apache.org on 2022/09/16 01:49:30 UTC

[tvm] branch main updated: [TIR] Add extra simplification in region cover analysis (#12800)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e6525a30e6 [TIR] Add extra simplification in region cover analysis (#12800)
e6525a30e6 is described below

commit e6525a30e6de3bc3f95564beeead8e9e8b1f9efc
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Thu Sep 15 18:49:22 2022 -0700

    [TIR] Add extra simplification in region cover analysis (#12800)
    
    Added an extra simplification step to eliminate false-negative cases in region cover analysis.
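
For context, here is a minimal Python sketch of the idea behind the fix. It is
not part of the commit (the actual change is the C++ hunk in
src/tir/schedule/state.cc below), and the variable, its range, and the bounds
are invented for illustration. After a produced/consumed region is intersected
with the buffer extent, the interval bounds carry min/max wrappers; simplifying
each bound first lets the difference-based proof go through.

    import tvm
    from tvm import arith, tir

    ana = arith.Analyzer()
    i = tir.Var("i", "int32")
    # Tell the analyzer that i ranges over [0, 8), e.g. a tile index.
    ana.bind(i, tvm.ir.Range(0, 8))

    # Intersecting the region [i*16, i*16 + 16) with the buffer extent adds a
    # max() wrapper to the lower bound, similar to what arith::Intersect does.
    produced_min = tir.max(i * 16, 0)
    consumed_min = i * 16

    # The extra step this commit adds: simplify each bound first, so the
    # wrapper folds away under the known range of i (max(i*16, 0) -> i*16).
    produced_min = ana.simplify(produced_min)

    # The existing difference-based check then proves coverage.
    assert ana.can_prove(ana.canonical_simplify(produced_min - consumed_min) <= 0)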
---
 src/tir/schedule/state.cc                          |  5 ++
 .../test_tir_schedule_state_cached_flags.py        | 86 +++++++++++++++++++++-
 2 files changed, 90 insertions(+), 1 deletion(-)

diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 15d0e08ddc..6d4a42236f 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -108,6 +108,11 @@ bool ProducerCoversConsumer(const Array<PrimExpr>& buffer_shape,
     produced = arith::Intersect({produced, buffer_size});
     consumed = arith::Intersect({consumed, buffer_size});
 
+    produced = arith::IntSet::Interval(analyzer->Simplify(produced.min()),
+                                       analyzer->Simplify(produced.max()));
+    consumed = arith::IntSet::Interval(analyzer->Simplify(consumed.min()),
+                                       analyzer->Simplify(consumed.max()));
+
     if (!analyzer->CanProve((analyzer->canonical_simplify(produced.min() - consumed.min()) <= 0) &&
                             (analyzer->canonical_simplify(consumed.max() - produced.max()) <= 0))) {
       return false;
diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
index bbeb8d8760..9878217140 100644
--- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py
+++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
@@ -26,7 +26,7 @@ from tvm.tir.schedule.state import CachedFlags
 from tvm.tir.stmt_functor import post_order_visit
 
 # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
-
+# fmt: off
 
 @T.prim_func
 def elementwise(a: T.handle, c: T.handle) -> None:
@@ -366,7 +366,80 @@ def uncovered_producer_region(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,)
             B[vi] = A[vi]
 
 
+@T.prim_func
+def matmul_relu_padding(A: T.Buffer[(127, 127), "float16"], B: T.Buffer[(127, 127), "float16"], compute: T.Buffer[(127, 127), "float32"]) -> None:
+    # function attr dict
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    # body
+    # with T.block("root")
+    C = T.alloc_buffer([127, 127], dtype="float32")
+    A_reindex = T.alloc_buffer([128, 128], dtype="float16")
+    B_reindex = T.alloc_buffer([128, 128], dtype="float16")
+    C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+    C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator")
+    for ax0, ax1, ax2 in T.grid(128, 1, 128):
+        with T.block("A_reindex"):
+            v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+            T.reads(A[v0, v2])
+            T.writes(A_reindex[v0, v2])
+            A_reindex[v0, v2] = T.if_then_else(v0 < 127 and v2 < 127, A[v0, v2], T.float16(0), dtype="float16")
+    for ax0, ax1, ax2 in T.grid(1, 128, 128):
+        with T.block("B_reindex"):
+            v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+            T.reads(B[v2, v1])
+            T.writes(B_reindex[v2, v1])
+            B_reindex[v2, v1] = T.if_then_else(v2 < 127 and v1 < 127, B[v2, v1], T.float16(0), dtype="float16")
+    for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"):
+        for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
+            for ax0_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"):
+                for ax2_0_0, ax2_0_1, ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(2, 2, 1, 2, 2, 1, 1):
+                    with T.block("C_o"):
+                        v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0_3 + ax0_0_4)
+                        v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3)
+                        v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax2_0_2)
+                        T.reads(A_reindex[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                        T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                        T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1})
+                        with T.init():
+                            for ax0_1, ax1_1 in T.grid(16, 16):
+                                with T.block("C_init"):
+                                    v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
+                                    T.reads()
+                                    T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init])
+                                    C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0)
+                        for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
+                            with T.block("C"):
+                                v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+                                T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
+                                C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32")
+                for ax0, ax1 in T.grid(16, 32):
+                    with T.block("C_reindex_shared_wmma.accumulator"):
+                        v0 = T.axis.spatial(128, ax0_0_2_ax1_0_2_fused // 2 * 16 + ax0)
+                        v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_2_ax1_0_2_fused % 2 * 32 + ax1)
+                        T.reads(C_reindex_shared_wmma_accumulator[v0, v1])
+                        T.writes(C_reindex_shared[v0, v1])
+                        C_reindex_shared[v0, v1] = C_reindex_shared_wmma_accumulator[v0, v1]
+            for ax0, ax1 in T.grid(128, 64):
+                with T.block("C_reindex_shared"):
+                    v0 = T.axis.spatial(128, ax0)
+                    v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax1)
+                    T.where(ax0 < 127 and ax0_0_0_ax1_0_0_fused * 64 + ax1 < 127)
+                    T.reads(C_reindex_shared[v0, v1])
+                    T.writes(C[v0, v1])
+                    T.block_attr({"meta_schedule.cooperative_fetch":3})
+                    C[v0, v1] = C_reindex_shared[v0, v1]
+    for i0, i1 in T.grid(127, 127):
+        with T.block("compute"):
+            i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
+            T.reads(C[i0_1, i1_1])
+            T.writes(compute[i0_1, i1_1])
+            compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
+
+
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+# fmt: on
 
 
 def _get_block(s: tir.ScheduleState, name_hint: str) -> tir.StmtSRef:
@@ -781,5 +854,16 @@ def test_uncovered_producer_region():
     # pylint: enable=protected-access
 
 
+def test_matmul_relu_padding():
+    s = tir.ScheduleState(matmul_relu_padding, debug_mask="all")
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "C_reindex_shared")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
 if __name__ == "__main__":
     tvm.testing.main()
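

To reproduce the new check locally (assuming a TVM build with pytest available),
the added test can be run on its own:

    pytest tests/python/unittest/test_tir_schedule_state_cached_flags.py::test_matmul_relu_padding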