You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/09/06 20:29:36 UTC
[tvm] branch unity updated: [Unity][Dlight] Matmul rule on int32 workloads (#15486)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new f0869fecc4 [Unity][Dlight] Matmul rule on int32 workloads (#15486)
f0869fecc4 is described below
commit f0869fecc426bff496c95e4d25ca5b987d1d0b49
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Wed Sep 6 16:29:28 2023 -0400
[Unity][Dlight] Matmul rule on int32 workloads (#15486)
Prior to this PR, the dlight matmul rule used a hardcoded `int64`
as the index dtype. When the input workload is `int32` dtyped and has
symbolic vars like `n: T.int32`, the PrimFunc generated by the
matmul rule may contain a PrimExpr like
```python
T.Cast("int32", T.Cast("int64", n))
```
which is functionally correct but contains a redundant cast round-trip.
This PR changes the hardcoded `int64` to the dtype of the input block
iters in the rule, and resolves the issue above.
---
python/tvm/dlight/gpu/matmul.py | 4 +-
tests/python/dlight/test_gpu_matmul.py | 92 ++++++++++++++++++++++++
tests/python/dlight/test_gpu_matmul_tensorize.py | 26 +++----
3 files changed, 107 insertions(+), 15 deletions(-)
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index 552adfa141..273fecaf41 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -122,7 +122,7 @@ def make_iter_fusion_index_map(
fused_iters: Dict[IterKind, PrimExpr] = {}
input_iters: List[tir.Var] = []
for i, trait in enumerate(traits):
- v_i = tir.Var(f"i{i}", "int64")
+ v_i = tir.Var(f"i{i}", trait.extent.dtype)
input_iters.append(v_i)
if trait.kind == IterKind.kIter_T:
continue
@@ -134,7 +134,7 @@ def make_iter_fusion_index_map(
fused_iters[trait.kind] = v_i
final_indices: List[tir.PrimExpr] = [
- fused_iters.get(kind, tir.IntImm("int64", 0)) for kind in kind_order
+ fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order
]
return tir.IndexMap(input_iters, final_indices, None)
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
index 4b1b61a5e1..550e30e6e7 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -122,6 +122,98 @@ class TestMatmul(BaseBeforeAfter):
# fmt: on
+def test_matmul_int32():
+ # fmt: off
+ @T.prim_func(private=True)
+ def func(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_matmul: T.handle):
+ m = T.int32()
+ inp0 = T.match_buffer(var_inp0, (1, m, 4096))
+ matmul = T.match_buffer(var_matmul, (1, m, 4096))
+ for i0, i1, i2, k in T.grid(1, m, 4096, 4096):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ with T.init():
+ matmul[v_i0, v_i1, v_i2] = T.float32(0)
+ matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + inp0[v_i0, v_i1, v_k] * inp1[v_k, v_i2]
+
+ @T.prim_func(private=True)
+ def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_matmul: T.handle):
+ T.func_attr({"tir.is_scheduled": 1})
+ m = T.int32()
+ inp0 = T.match_buffer(var_inp0, (1, m, 4096))
+ matmul = T.match_buffer(var_matmul, (1, m, 4096))
+ # with T.block("root"):
+ matmul_reindex_pad_local = T.alloc_buffer((1, (m + 31) // 32 * 32, 4096), scope="local")
+ inp0_reindex_pad_shared = T.alloc_buffer((1, (m + 31) // 32 * 32, 4096), scope="shared")
+ inp1_reindex_shared = T.alloc_buffer((1, 4096, 4096), scope="shared")
+ for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"):
+ for ax1_0 in T.thread_binding((m + 31) // 32, thread="blockIdx.x"):
+ for ax2_1 in T.thread_binding(1, thread="vthread.y"):
+ for ax1_1 in T.thread_binding(1, thread="vthread.x"):
+ for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
+ for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax2_3_init, ax1_3_0_init in T.grid(4, 2):
+ for ax1_3_1_init in T.vectorized(2):
+ with T.block("matmul_init"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init * 2 + ax1_3_1_init)
+ v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
+ T.reads()
+ T.writes(matmul_reindex_pad_local[0, v1, v2])
+ matmul_reindex_pad_local[0, v1, v2] = T.float32(0)
+ for ax3_0 in range(256):
+ for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in range(2):
+ for ax0_ax1_ax2_fused_3 in T.vectorized(2):
+ with T.block("inp0_reindex_pad_shared"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
+ v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
+ T.reads(inp0[v0, v1, v2])
+ T.writes(inp0_reindex_pad_shared[v0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+ inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
+ for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in range(4):
+ for ax0_ax1_ax2_fused_3 in T.vectorized(2):
+ with T.block("inp1_reindex_shared"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
+ v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
+ T.reads(inp1[v2, v1])
+ T.writes(inp1_reindex_shared[v0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+ inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
+ for ax3_1, ax2_3, ax1_3_0 in T.grid(16, 4, 2):
+ for ax1_3_1 in T.vectorized(2):
+ with T.block("matmul_update"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 * 2 + ax1_3_1)
+ v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
+ v3 = T.axis.reduce(4096, ax3_0 * 16 + ax3_1)
+ T.reads(matmul_reindex_pad_local[0, v1, v2], inp0_reindex_pad_shared[0, v1, v3], inp1_reindex_shared[0, v2, v3])
+ T.writes(matmul_reindex_pad_local[0, v1, v2])
+ matmul_reindex_pad_local[0, v1, v2] = matmul_reindex_pad_local[0, v1, v2] + inp0_reindex_pad_shared[0, v1, v3] * inp1_reindex_shared[0, v2, v3]
+ for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
+ for ax2_1_1 in T.vectorized(2):
+ with T.block("matmul_reindex_pad_local"):
+ v0 = T.axis.spatial(1, ax0)
+ v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
+ v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
+ T.reads(matmul_reindex_pad_local[v0, v1, v2])
+ T.writes(matmul[0, v1, v2])
+ if v1 < m:
+ matmul[0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
+ # fmt: on
+
+ mod = tvm.IRModule({"main": func})
+ with Target("nvidia/geforce-gtx-1080-ti"):
+ mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
+ tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
class TestFusedMatmul(BaseBeforeAfter):
# fmt: off
diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py b/tests/python/dlight/test_gpu_matmul_tensorize.py
index 09a0ccf662..026a6fa624 100644
--- a/tests/python/dlight/test_gpu_matmul_tensorize.py
+++ b/tests/python/dlight/test_gpu_matmul_tensorize.py
@@ -60,13 +60,13 @@ class TestMatmulTensorize(BaseBeforeAfter):
W_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 256, 256), "float16", scope="wmma.matrix_b")
compute_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", scope="shared.dyn")
compute_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((1, 256, 256), "float16", scope="wmma.accumulator")
- for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+ for ax0 in T.thread_binding(1, thread="blockIdx.z"):
for ax1_0_0_ax2_0_0_fused in T.thread_binding(2, thread="blockIdx.x"):
for ax1_0_1_ax2_0_1_fused in T.thread_binding(2, thread="blockIdx.y"):
for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"):
for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2):
with T.block("compute_o_init"):
- v0_o = T.axis.spatial(T.int64(1), ax0)
+ v0_o = T.axis.spatial(1, ax0)
v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init)
v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init)
T.reads()
@@ -128,7 +128,7 @@ class TestMatmulTensorize(BaseBeforeAfter):
T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major")
for ax1_0_3, ax2_0_3 in T.grid(2, 2):
with T.block("compute_o_update"):
- v0_o = T.axis.spatial(T.int64(1), ax0)
+ v0_o = T.axis.spatial(1, ax0)
v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3)
v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3)
v3_o = T.axis.reduce(16, ax3_0_0 * 4 + ax3_0_1)
@@ -195,11 +195,11 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
X = T.match_buffer(var_X, (m, 256), "float16")
compute = T.match_buffer(var_compute, (m, 15))
# with T.block("root"):
- compute_reindex_pad_local = T.alloc_buffer((1, (T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, 64), scope="local")
- X_reindex_pad_shared = T.alloc_buffer((1, (T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, 256), "float16", scope="shared")
+ compute_reindex_pad_local = T.alloc_buffer((1, (m + 31) // 32 * 32, 64), scope="local")
+ X_reindex_pad_shared = T.alloc_buffer((1, (m + 31) // 32 * 32, 256), "float16", scope="shared")
W_reindex_pad_shared = T.alloc_buffer((1, 64, 256), "float16", scope="shared")
- for ax0_ax2_0_fused in T.thread_binding(T.int64(1), thread="blockIdx.y"):
- for ax1_0 in T.thread_binding((T.Cast("int32", T.Cast("int64", m)) + 31) // 32, thread="blockIdx.x"):
+ for ax0_ax2_0_fused in T.thread_binding(1, thread="blockIdx.y"):
+ for ax1_0 in T.thread_binding((m + 31) // 32, thread="blockIdx.x"):
for ax2_1 in T.thread_binding(1, thread="vthread.y"):
for ax1_1 in T.thread_binding(1, thread="vthread.x"):
for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
@@ -207,8 +207,8 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
for ax2_3_init, ax1_3_0_init in T.grid(4, 2):
for ax1_3_1_init in T.vectorized(2):
with T.block("compute_init"):
- v0 = T.axis.spatial(T.int64(1), T.int64(0))
- v1 = T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init * 2 + ax1_3_1_init)
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init * 2 + ax1_3_1_init)
v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
T.reads()
T.writes(compute_reindex_pad_local[0, v1, v2])
@@ -220,7 +220,7 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("X_reindex_pad_shared"):
v0 = T.axis.spatial(1, 0)
- v1 = T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
+ v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(X[v1, v2])
T.writes(X_reindex_pad_shared[v0, v1, v2])
@@ -241,8 +241,8 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
for ax3_1, ax2_3, ax1_3_0 in T.grid(16, 4, 2):
for ax1_3_1 in T.vectorized(2):
with T.block("compute_update"):
- v0 = T.axis.spatial(T.int64(1), T.int64(0))
- v1 = T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 * 2 + ax1_3_1)
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 * 2 + ax1_3_1)
v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3)
v3 = T.axis.reduce(256, ax3_0 * 16 + ax3_1)
T.reads(compute_reindex_pad_local[0, v1, v2], X_reindex_pad_shared[0, v1, v3], W_reindex_pad_shared[0, v2, v3])
@@ -252,7 +252,7 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
for ax2_1_1 in T.vectorized(2):
with T.block("compute_reindex_pad_local"):
v0 = T.axis.spatial(1, ax0)
- v1 = T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
+ v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(64, ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
T.reads(compute_reindex_pad_local[v0, v1, v2])
T.writes(compute[v1, v2])