Posted to commits@tvm.apache.org by ju...@apache.org on 2023/09/06 20:29:36 UTC

[tvm] branch unity updated: [Unity][Dlight] Matmul rule on int32 workloads (#15486)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new f0869fecc4 [Unity][Dlight] Matmul rule on int32 workloads (#15486)
f0869fecc4 is described below

commit f0869fecc426bff496c95e4d25ca5b987d1d0b49
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Wed Sep 6 16:29:28 2023 -0400

    [Unity][Dlight] Matmul rule on int32 workloads (#15486)
    
    Prior to this PR, the dlight matmul rule used a hardcoded `int64`
    as the index dtype. When the input workload is `int32`-dtyped and
    has symbolic vars like `n: T.int32`, the PrimFunc generated by the
    matmul rule may contain a PrimExpr like
    ```python
    T.Cast("int32", T.Cast("int64", n))
    ```
    which, while not incorrect, carries a redundant round-trip cast.
    
    This PR replaces the hardcoded `int64` with the dtype of the input
    block iters in the rule, resolving the issue above.
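
    A minimal sketch of the idea (standalone, not the dlight rule
    itself; the names below are illustrative):
    ```python
    from tvm import tir

    n = tir.Var("n", "int32")  # symbolic extent of an int32 workload

    # Before this change: fused iter vars were always created as int64,
    # forcing a round-trip cast wherever they meet int32 buffer indices.
    v_old = tir.Var("i0", "int64")

    # After: the var inherits the extent's dtype, so no cast is needed.
    v_new = tir.Var("i0", n.dtype)
    assert v_new.dtype == "int32"
    ```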
---
 python/tvm/dlight/gpu/matmul.py                  |  4 +-
 tests/python/dlight/test_gpu_matmul.py           | 92 ++++++++++++++++++++++++
 tests/python/dlight/test_gpu_matmul_tensorize.py | 26 +++----
 3 files changed, 107 insertions(+), 15 deletions(-)

diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index 552adfa141..273fecaf41 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -122,7 +122,7 @@ def make_iter_fusion_index_map(
     fused_iters: Dict[IterKind, PrimExpr] = {}
     input_iters: List[tir.Var] = []
     for i, trait in enumerate(traits):
-        v_i = tir.Var(f"i{i}", "int64")
+        v_i = tir.Var(f"i{i}", trait.extent.dtype)
         input_iters.append(v_i)
         if trait.kind == IterKind.kIter_T:
             continue
@@ -134,7 +134,7 @@ def make_iter_fusion_index_map(
             fused_iters[trait.kind] = v_i
 
     final_indices: List[tir.PrimExpr] = [
-        fused_iters.get(kind, tir.IntImm("int64", 0)) for kind in kind_order
+        fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order
     ]
 
     return tir.IndexMap(input_iters, final_indices, None)
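
As a hedged aside, a usage sketch of the `tir.IndexMap` this hunk builds,
with hypothetical iter names and one absent iter kind filled by a
dtype-matched zero:

```python
from tvm import tir

i0 = tir.Var("i0", "int32")
i1 = tir.Var("i1", "int32")
# The filler constant now matches the iters' dtype instead of being a
# hardcoded int64 zero.
index_map = tir.IndexMap([i0, i1], [i1, tir.IntImm("int32", 0), i0], None)
assert index_map.final_indices[1].dtype == "int32"
```
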
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
index 4b1b61a5e1..550e30e6e7 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -122,6 +122,98 @@ class TestMatmul(BaseBeforeAfter):
     # fmt: on
 
 
+def test_matmul_int32():
+    # fmt: off
+    @T.prim_func(private=True)
+    def func(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_matmul: T.handle):
+        m = T.int32()
+        inp0 = T.match_buffer(var_inp0, (1, m, 4096))
+        matmul = T.match_buffer(var_matmul, (1, m, 4096))
+        for i0, i1, i2, k in T.grid(1, m, 4096, 4096):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                with T.init():
+                    matmul[v_i0, v_i1, v_i2] = T.float32(0)
+                matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + inp0[v_i0, v_i1, v_k] * inp1[v_k, v_i2]
+
+    @T.prim_func(private=True)
+    def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_matmul: T.handle):
+        T.func_attr({"tir.is_scheduled": 1})
+        m = T.int32()
+        inp0 = T.match_buffer(var_inp0, (1, m, 4096))
+        matmul = T.match_buffer(var_matmul, (1, m, 4096))
+        # with T.block("root"):
+        matmul_reindex_pad_local = T.alloc_buffer((1, (m + 31) // 32 * 32, 4096), scope="local")
+        inp0_reindex_pad_shared = T.alloc_buffer((1, (m + 31) // 32 * 32, 4096), scope="shared")
+        inp1_reindex_shared = T.alloc_buffer((1, 4096, 4096), scope="shared")
+        for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"):
+            for ax1_0 in T.thread_binding((m + 31) // 32, thread="blockIdx.x"):
+                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
+                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
+                        for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
+                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                                for ax2_3_init, ax1_3_0_init in T.grid(4, 2):
+                                    for ax1_3_1_init in T.vectorized(2):
+                                        with T.block("matmul_init"):
+                                            v0 = T.axis.spatial(1, 0)
+                                            v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init * 2 + ax1_3_1_init)
+                                            v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
+                                            T.reads()
+                                            T.writes(matmul_reindex_pad_local[0, v1, v2])
+                                            matmul_reindex_pad_local[0, v1, v2] = T.float32(0)
+                                for ax3_0 in range(256):
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(2):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
+                                                    with T.block("inp0_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(1, 0)
+                                                        v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
+                                                        v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
+                                                        T.reads(inp0[v0, v1, v2])
+                                                        T.writes(inp0_reindex_pad_shared[v0, v1, v2])
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(4):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
+                                                    with T.block("inp1_reindex_shared"):
+                                                        v0 = T.axis.spatial(1, 0)
+                                                        v1 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
+                                                        v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
+                                                        T.reads(inp1[v2, v1])
+                                                        T.writes(inp1_reindex_shared[v0, v1, v2])
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
+                                    for ax3_1, ax2_3, ax1_3_0 in T.grid(16, 4, 2):
+                                        for ax1_3_1 in T.vectorized(2):
+                                            with T.block("matmul_update"):
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 * 2 + ax1_3_1)
+                                                v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
+                                                v3 = T.axis.reduce(4096, ax3_0 * 16 + ax3_1)
+                                                T.reads(matmul_reindex_pad_local[0, v1, v2], inp0_reindex_pad_shared[0, v1, v3], inp1_reindex_shared[0, v2, v3])
+                                                T.writes(matmul_reindex_pad_local[0, v1, v2])
+                                                matmul_reindex_pad_local[0, v1, v2] = matmul_reindex_pad_local[0, v1, v2] + inp0_reindex_pad_shared[0, v1, v3] * inp1_reindex_shared[0, v2, v3]
+                                for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
+                                    for ax2_1_1 in T.vectorized(2):
+                                        with T.block("matmul_reindex_pad_local"):
+                                            v0 = T.axis.spatial(1, ax0)
+                                            v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
+                                            v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
+                                            T.reads(matmul_reindex_pad_local[v0, v1, v2])
+                                            T.writes(matmul[0, v1, v2])
+                                            if v1 < m:
+                                                matmul[0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
+    # fmt: on
+
+    mod = tvm.IRModule({"main": func})
+    with Target("nvidia/geforce-gtx-1080-ti"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
 class TestFusedMatmul(BaseBeforeAfter):
     # fmt: off
 
diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py b/tests/python/dlight/test_gpu_matmul_tensorize.py
index 09a0ccf662..026a6fa624 100644
--- a/tests/python/dlight/test_gpu_matmul_tensorize.py
+++ b/tests/python/dlight/test_gpu_matmul_tensorize.py
@@ -60,13 +60,13 @@ class TestMatmulTensorize(BaseBeforeAfter):
         W_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 256, 256), "float16", scope="wmma.matrix_b")
         compute_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", scope="shared.dyn")
         compute_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((1, 256, 256), "float16", scope="wmma.accumulator")
-        for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+        for ax0 in T.thread_binding(1, thread="blockIdx.z"):
             for ax1_0_0_ax2_0_0_fused in T.thread_binding(2, thread="blockIdx.x"):
                 for ax1_0_1_ax2_0_1_fused in T.thread_binding(2, thread="blockIdx.y"):
                     for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"):
                         for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2):
                             with T.block("compute_o_init"):
-                                v0_o = T.axis.spatial(T.int64(1), ax0)
+                                v0_o = T.axis.spatial(1, ax0)
                                 v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init)
                                 v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init)
                                 T.reads()
@@ -128,7 +128,7 @@ class TestMatmulTensorize(BaseBeforeAfter):
                                             T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major")
                                 for ax1_0_3, ax2_0_3 in T.grid(2, 2):
                                     with T.block("compute_o_update"):
-                                        v0_o = T.axis.spatial(T.int64(1), ax0)
+                                        v0_o = T.axis.spatial(1, ax0)
                                         v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3)
                                         v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3)
                                         v3_o = T.axis.reduce(16, ax3_0_0 * 4 + ax3_0_1)
@@ -195,11 +195,11 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
         X = T.match_buffer(var_X, (m, 256), "float16")
         compute = T.match_buffer(var_compute, (m, 15))
         # with T.block("root"):
-        compute_reindex_pad_local = T.alloc_buffer((1, (T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, 64), scope="local")
-        X_reindex_pad_shared = T.alloc_buffer((1, (T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, 256), "float16", scope="shared")
+        compute_reindex_pad_local = T.alloc_buffer((1, (m + 31) // 32 * 32, 64), scope="local")
+        X_reindex_pad_shared = T.alloc_buffer((1, (m + 31) // 32 * 32, 256), "float16", scope="shared")
         W_reindex_pad_shared = T.alloc_buffer((1, 64, 256), "float16", scope="shared")
-        for ax0_ax2_0_fused in T.thread_binding(T.int64(1), thread="blockIdx.y"):
-            for ax1_0 in T.thread_binding((T.Cast("int32", T.Cast("int64", m)) + 31) // 32, thread="blockIdx.x"):
+        for ax0_ax2_0_fused in T.thread_binding(1, thread="blockIdx.y"):
+            for ax1_0 in T.thread_binding((m + 31) // 32, thread="blockIdx.x"):
                 for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                     for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                         for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
@@ -207,8 +207,8 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
                                 for ax2_3_init, ax1_3_0_init in T.grid(4, 2):
                                     for ax1_3_1_init in T.vectorized(2):
                                         with T.block("compute_init"):
-                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                            v1 = T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init * 2 + ax1_3_1_init)
+                                            v0 = T.axis.spatial(1, 0)
+                                            v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init * 2 + ax1_3_1_init)
                                             v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
                                             T.reads()
                                             T.writes(compute_reindex_pad_local[0, v1, v2])
@@ -220,7 +220,7 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
                                                 for ax0_ax1_ax2_fused_3 in T.vectorized(2):
                                                     with T.block("X_reindex_pad_shared"):
                                                         v0 = T.axis.spatial(1, 0)
-                                                        v1 = T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
+                                                        v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
                                                         v2 = T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                         T.reads(X[v1, v2])
                                                         T.writes(X_reindex_pad_shared[v0, v1, v2])
@@ -241,8 +241,8 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
                                     for ax3_1, ax2_3, ax1_3_0 in T.grid(16, 4, 2):
                                         for ax1_3_1 in T.vectorized(2):
                                             with T.block("compute_update"):
-                                                v0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                                v1 = T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 * 2 + ax1_3_1)
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 * 2 + ax1_3_1)
                                                 v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3)
                                                 v3 = T.axis.reduce(256, ax3_0 * 16 + ax3_1)
                                                 T.reads(compute_reindex_pad_local[0, v1, v2], X_reindex_pad_shared[0, v1, v3], W_reindex_pad_shared[0, v2, v3])
@@ -252,7 +252,7 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
                                     for ax2_1_1 in T.vectorized(2):
                                         with T.block("compute_reindex_pad_local"):
                                             v0 = T.axis.spatial(1, ax0)
-                                            v1 = T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
+                                            v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                             v2 = T.axis.spatial(64, ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
                                             T.reads(compute_reindex_pad_local[v0, v1, v2])
                                             T.writes(compute[v1, v2])