You are viewing a plain-text version of this content. The canonical link for it ("here") was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by xi...@apache.org on 2022/07/31 04:24:26 UTC

[tvm] branch main updated: [MetaSchedule][Test] Add unittests for GMM (#12243)

This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 42dd6afa97 [MetaSchedule][Test] Add unittests for GMM (#12243)
42dd6afa97 is described below

commit 42dd6afa970e8948584d5474691673c32e2c3457
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Sat Jul 30 21:24:19 2022 -0700

    [MetaSchedule][Test] Add unittests for GMM (#12243)
---
 .../unittest/test_meta_schedule_space_cpu.py       | 123 +++++++++++++++++++++
 .../unittest/test_meta_schedule_space_cuda.py      |  82 ++++++++++++++
 2 files changed, 205 insertions(+)

diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py b/tests/python/unittest/test_meta_schedule_space_cpu.py
index 12aa150f57..7d601a7b0b 100644
--- a/tests/python/unittest/test_meta_schedule_space_cpu.py
+++ b/tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -1079,6 +1079,128 @@ def test_cpu_dil():
     )
 
 
+def test_cpu_gmm():  # expected CPU design-space sketches for the GMM workload
+    # fmt: off
+    @T.prim_func
+    def gmm_0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
+        # function attr dict: exported as "main", with tir.noalias set
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body: sketch 0 -- Z is accumulated in Z_global and copied back to Z under loop i2_1 (1x32x8 tiles)
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
+            Z_global = T.alloc_buffer([1, 128, 128], dtype="float32")
+            for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1 in T.grid(1, 4, 2, 1, 1, 8):
+                for i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(128, 1, 16, 1, 1, 1, 2, 8):
+                    with T.block("Z"):
+                        b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+                        i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3)
+                        j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3)
+                        k = T.axis.reduce(128, i3_1 + i3_0)
+                        T.reads(X[b, i, k], Y[b, k, j])
+                        T.writes(Z_global[b, i, j])
+                        T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                        with T.init():
+                            Z_global[b, i, j] = T.float32(0)
+                        Z_global[b, i, j] = Z_global[b, i, j] + X[b, i, k] * Y[b, k, j]
+                for ax0, ax1, ax2 in T.grid(1, 32, 8):
+                    with T.block("Z_global"):
+                        v0 = T.axis.spatial(1, ax0)
+                        v1 = T.axis.spatial(128, i1_0 * 32 + ax1)
+                        v2 = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + ax2)
+                        T.reads(Z_global[v0, v1, v2])
+                        T.writes(Z[v0, v1, v2])
+                        Z[v0, v1, v2] = Z_global[v0, v1, v2]
+    @T.prim_func
+    def gmm_1(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
+        # function attr dict: exported as "main", with tir.noalias set
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body: sketch 1 -- same "Z" block as sketch 0, but the Z_global -> Z copy sits under loop i2_0 (1x32x64 tiles)
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
+            Z_global = T.alloc_buffer([1, 128, 128], dtype="float32")
+            for i0_0, i1_0, i2_0 in T.grid(1, 4, 2):
+                for i0_1, i1_1, i2_1, i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8):
+                    with T.block("Z"):
+                        b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+                        i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3)
+                        j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3)
+                        k = T.axis.reduce(128, i3_1 + i3_0)
+                        T.reads(X[b, i, k], Y[b, k, j])
+                        T.writes(Z_global[b, i, j])
+                        T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                        with T.init():
+                            Z_global[b, i, j] = T.float32(0)
+                        Z_global[b, i, j] = Z_global[b, i, j] + X[b, i, k] * Y[b, k, j]
+                for ax0, ax1, ax2 in T.grid(1, 32, 64):
+                    with T.block("Z_global"):
+                        v0 = T.axis.spatial(1, ax0)
+                        v1 = T.axis.spatial(128, i1_0 * 32 + ax1)
+                        v2 = T.axis.spatial(128, i2_0 * 64 + ax2)
+                        T.reads(Z_global[v0, v1, v2])
+                        T.writes(Z[v0, v1, v2])
+                        Z[v0, v1, v2] = Z_global[v0, v1, v2]
+    @T.prim_func
+    def gmm_2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
+        # function attr dict: exported as "main", with tir.noalias set
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body: sketch 2 -- no cache-write stage; the reduction writes directly into Z
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
+            for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(1, 4, 2, 1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8):
+                with T.block("Z"):
+                    b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+                    i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3)
+                    j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3)
+                    k = T.axis.reduce(128, i3_1 + i3_0)
+                    T.reads(X[b, i, k], Y[b, k, j])
+                    T.writes(Z[b, i, j])
+                    T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                    with T.init():
+                        Z[b, i, j] = T.float32(0)
+                    Z[b, i, j] = Z[b, i, j] + X[b, i, k] * Y[b, k, j]
+    # fmt: on
+    decision_0 = [  # sampled decisions expected for gmm_0
+        ("SamplePerfectTile", [1, 1, 1, 1]),  # tiles of spatial axis b
+        ("SamplePerfectTile", [4, 1, 16, 2]),  # tiles of spatial axis i
+        ("SamplePerfectTile", [2, 8, 1, 8]),  # tiles of spatial axis j
+        ("SamplePerfectTile", [128, 1]),  # tiles of reduction axis k
+        ("SampleCategorical", 1),  # presumably the unroll_explicit choice (16); TODO confirm
+    ]
+    decision_1 = [  # same sampled values as decision_0, expected for gmm_1
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [4, 1, 16, 2]),
+        ("SamplePerfectTile", [2, 8, 1, 8]),
+        ("SamplePerfectTile", [128, 1]),
+        ("SampleCategorical", 1),
+    ]
+    decision_2 = [  # same sampled values as decision_0, expected for gmm_2
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [4, 1, 16, 2]),
+        ("SamplePerfectTile", [2, 8, 1, 8]),
+        ("SamplePerfectTile", [128, 1]),
+        ("SampleCategorical", 1),
+    ]
+    mod = create_te_workload("GMM", 0)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(  # presumably asserts the sketches match expected_mods / expected_decisions
+        mod,
+        sketches=actual,
+        expected_mods=[gmm_0, gmm_1, gmm_2],
+        expected_decisions=[decision_0, decision_1, decision_2],
+    )
+
+
 if __name__ == "__main__":
     test_cpu_c1d()
     test_cpu_c2d()
@@ -1086,3 +1208,4 @@ if __name__ == "__main__":
     test_cpu_cap()
     test_cpu_dep()
     test_cpu_dil()
+    test_cpu_gmm()
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 7323bc441f..3bf2666cdc 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -572,6 +572,87 @@ def test_cuda_dil():
     )
 
 
+def test_cuda_gmm():  # expected CUDA design-space sketch for the GMM workload
+    # fmt: off
+    @T.prim_func
+    def gmm_0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
+        # function attr dict: exported as "main", with tir.noalias set
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body: Z accumulates in local-scope Z_local; X and Y are staged through shared-scope buffers
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.unroll_explicit":1024})
+            Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
+            X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
+            Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
+            for i0_0_i1_0_i2_0_fused in T.thread_binding(1, thread="blockIdx.x"):
+                for i0_1_i1_1_i2_1_fused in T.thread_binding(32, thread="vthread.x"):
+                    for i0_2_i1_2_i2_2_fused in T.thread_binding(2, thread="threadIdx.x"):
+                        for i3_0 in T.serial(1):
+                            for ax0_ax1_ax2_fused in T.serial(16384):
+                                with T.block("X_shared"):
+                                    v0 = T.axis.spatial(1, 0)
+                                    v1 = T.axis.spatial(128, ax0_ax1_ax2_fused // 128)
+                                    v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128)
+                                    T.reads(X[v0, v1, v2])
+                                    T.writes(X_shared[v0, v1, v2])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":2})
+                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
+                            for ax0_ax1_ax2_fused in T.serial(16384):
+                                with T.block("Y_shared"):
+                                    v0 = T.axis.spatial(1, 0)
+                                    v1 = T.axis.spatial(128, ax0_ax1_ax2_fused // 128)
+                                    v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128)
+                                    T.reads(Y[v0, v1, v2])
+                                    T.writes(Y_shared[v0, v1, v2])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":1})
+                                    Y_shared[v0, v1, v2] = Y[v0, v1, v2]
+                            for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(32, 1, 2, 64, 4, 1, 2, 1):
+                                with T.block("Z"):
+                                    b = T.axis.spatial(1, i0_4 + i0_3)
+                                    i = T.axis.spatial(128, i0_1_i1_1_i2_1_fused * 4 + i1_3 * 2 + i1_4)
+                                    j = T.axis.spatial(128, i2_4 + i0_2_i1_2_i2_2_fused * 64 + i2_3)
+                                    k = T.axis.reduce(128, i3_0 * 128 + i3_1 * 4 + i3_2)
+                                    T.reads(X_shared[b, i, k], Y_shared[b, k, j])
+                                    T.writes(Z_local[b, i, j])
+                                    T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
+                                    with T.init():
+                                        Z_local[b, i, j] = T.float32(0)
+                                    Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
+                        for ax0, ax1, ax2 in T.grid(1, 4, 64):
+                            with T.block("Z_local"):
+                                v0 = T.axis.spatial(1, ax0)
+                                v1 = T.axis.spatial(128, i0_1_i1_1_i2_1_fused * 4 + ax1)
+                                v2 = T.axis.spatial(128, i0_2_i1_2_i2_2_fused * 64 + ax2)
+                                T.reads(Z_local[v0, v1, v2])
+                                T.writes(Z[v0, v1, v2])
+                                Z[v0, v1, v2] = Z_local[v0, v1, v2]
+    # fmt: on
+    decision_0 = [  # sampled decisions expected for gmm_0
+        ("SamplePerfectTile", [1, 1, 1, 1, 1]),  # tiles of spatial axis b
+        ("SamplePerfectTile", [1, 32, 1, 2, 2]),  # tiles of spatial axis i
+        ("SamplePerfectTile", [1, 1, 2, 64, 1]),  # tiles of spatial axis j
+        ("SamplePerfectTile", [1, 32, 4]),  # tiles of reduction axis k
+        ("SampleCategorical", 1),  # presumably X_shared cooperative_fetch (2); TODO confirm
+        ("SampleCategorical", 0),  # presumably Y_shared cooperative_fetch (1); TODO confirm
+        ("SampleCategorical", 4),  # presumably unroll_explicit (1024); TODO confirm
+    ]
+    mod = create_te_workload("GMM", 0)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(  # presumably asserts the sketches match expected_mods / expected_decisions
+        mod,
+        sketches=actual,
+        expected_mods=[gmm_0],
+        expected_decisions=[decision_0],
+    )
+
+
 if __name__ == "__main__":
     test_cuda_c1d()
     test_cuda_c2d()
@@ -579,3 +660,4 @@ if __name__ == "__main__":
     test_cuda_cap()
     test_cuda_dep()
     test_cuda_dil()
+    test_cuda_gmm()