You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/07/18 21:58:40 UTC
[tvm] 04/07: [Unity][Dlight] Rule matmul avoiding blockIdx.z (#15333)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 6294aada465143d7927c2d91a677567071b5d9bc
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Mon Jul 17 02:59:46 2023 -0700
[Unity][Dlight] Rule matmul avoiding blockIdx.z (#15333)
Prior to this PR, the matmul rule of dlight binds loops to `blockIdx.z`.
However, not every device supports this blockIdx dimension (for example,
WebGPU does not support `blockIdx.z`), which makes dlight fail to
apply and build.
Therefore, this PR fuses the `blockIdx.z` loop with the other `blockIdx`
loop.
---
python/tvm/dlight/gpu/matmul.py | 4 +-
tests/python/dlight/test_gpu_matmul.py | 225 ++++++++++++++++-----------------
2 files changed, 114 insertions(+), 115 deletions(-)
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index be5e4b02d7..b9977d08b9 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -327,8 +327,8 @@ class Matmul(ScheduleRule):
bx, vx, tx, xi = sch.split(x, [None, vthread_x, block_size_x, micro_size_x])
by, vy, ty, yi = sch.split(y, [None, vthread_y, block_size_y, micro_size_y])
ko, ki = sch.split(k, factors=[None, micro_size_k])
- sch.reorder(bx, by, vy, vx, ty, tx, ko, ki, yi, xi)
- sch.bind(batch, "blockIdx.z")
+ sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi)
+ by = sch.fuse(batch, by)
sch.bind(bx, "blockIdx.x")
sch.bind(by, "blockIdx.y")
sch.bind(vy, "vthread.y")
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
index f3d9a7089d..318a3e833c 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -57,65 +57,64 @@ class TestMatmul(BaseBeforeAfter):
matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="shared")
inp1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
- for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+ for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"):
for ax1_0 in T.thread_binding((m + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
- for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.y"):
- for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
- for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
- for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
- with T.block("matmul_init"):
- v0 = T.axis.spatial(T.int64(1), ax0)
- v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
- v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
- T.reads()
+ for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+ for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
+ for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
+ with T.block("matmul_init"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
+ v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
+ T.reads()
+ T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
+ matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
+ for ax3_0 in range(T.int64(256)):
+ for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
+ for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+ with T.block("inp0_reindex_pad_shared"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(inp0[v0, v1, v2])
+ T.writes(inp0_reindex_pad_shared[v0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+ inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
+ for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
+ for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+ with T.block("inp1_reindex_shared"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(inp1[v2, v1])
+ T.writes(inp1_reindex_shared[v0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+ inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
+ for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
+ with T.block("matmul_update"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
+ v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
+ v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
+ T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0), v2, v3])
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
- matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
- for ax3_0 in range(T.int64(256)):
- for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
- for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
- for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
- with T.block("inp0_reindex_pad_shared"):
- v0 = T.axis.spatial(T.int64(1), T.int64(0))
- v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
- v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
- T.reads(inp0[v0, v1, v2])
- T.writes(inp0_reindex_pad_shared[v0, v1, v2])
- T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
- inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
- for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
- for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
- for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
- with T.block("inp1_reindex_shared"):
- v0 = T.axis.spatial(T.int64(1), T.int64(0))
- v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
- v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
- T.reads(inp1[v2, v1])
- T.writes(inp1_reindex_shared[v0, v1, v2])
- T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
- inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
- for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
- with T.block("matmul_update"):
- v0 = T.axis.spatial(T.int64(1), ax0)
- v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
- v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
- v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
- T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0), v2, v3])
- T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
- matmul_reindex_pad_local[T.int64(0), v1, v2] = matmul_reindex_pad_local[T.int64(0), v1, v2] + inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0), v2, v3]
- for ax0_1, ax1, ax2_0_1 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
- for ax2_1_1 in T.vectorized(T.int64(2)):
- with T.block("matmul_reindex_pad_local"):
- v0 = T.axis.spatial(T.int64(1), ax0_1)
- v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
- v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2_0_1 * T.int64(2) + ax2_1_1)
- T.reads(matmul_reindex_pad_local[v0, v1, v2])
- T.writes(matmul[T.int64(0), v1, v2])
- if v1 < m:
- matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
+ matmul_reindex_pad_local[T.int64(0), v1, v2] = matmul_reindex_pad_local[T.int64(0), v1, v2] + inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0), v2, v3]
+ for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
+ for ax2_1_1 in T.vectorized(T.int64(2)):
+ with T.block("matmul_reindex_pad_local"):
+ v0 = T.axis.spatial(T.int64(1), ax0)
+ v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
+ v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
+ T.reads(matmul_reindex_pad_local[v0, v1, v2])
+ T.writes(matmul[T.int64(0), v1, v2])
+ if v1 < m:
+ matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
# fmt: on
@@ -146,6 +145,7 @@ class TestFusedMatmul(BaseBeforeAfter):
T.reads(C[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
T.writes(Out[v_ax0, v_ax1, v_ax2])
Out[v_ax0, v_ax1, v_ax2] = C[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
@T.prim_func
def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32")):
T.func_attr({"tir.is_scheduled": 1})
@@ -153,64 +153,63 @@ class TestFusedMatmul(BaseBeforeAfter):
var_matmul_intermediate_reindex_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local")
A_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared")
var_decode_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
- for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+ for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"):
for ax1_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
- for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.y"):
- for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
- for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
- for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
- with T.block("matmul_init"):
- v0 = T.axis.spatial(T.int64(1), ax0)
- v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
- v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
- T.reads()
+ for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+ for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
+ for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
+ with T.block("matmul_init"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
+ v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
+ T.reads()
+ T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
+ var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = T.float32(0)
+ for ax3_0 in range(T.int64(256)):
+ for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
+ for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+ with T.block("A_reindex_shared"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(A[v0, v1, v2])
+ T.writes(A_reindex_shared[v0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+ A_reindex_shared[v0, v1, v2] = A[v0, v1, v2]
+ for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
+ for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+ with T.block("var_decode_intermediate_reindex_shared"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(W[v2 // T.int64(8), v1], S[v2 // T.int64(32), v1])
+ T.writes(var_decode_intermediate_reindex_shared[v0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+ var_decode_intermediate_reindex_shared[v0, v1, v2] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v2 // T.int64(8), v1], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v2 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v2 // T.int64(32), v1], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+ for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
+ with T.block("matmul_update"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
+ v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
+ v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
+ T.reads(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2], A_reindex_shared[T.int64(0), v1, v3], var_decode_intermediate_reindex_shared[T.int64(0), v2, v3])
T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
- var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = T.float32(0)
- for ax3_0 in range(T.int64(256)):
- for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
- for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
- for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
- with T.block("A_reindex_shared"):
- v0 = T.axis.spatial(T.int64(1), T.int64(0))
- v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
- v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
- T.reads(A[v0, v1, v2])
- T.writes(A_reindex_shared[v0, v1, v2])
- T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
- A_reindex_shared[v0, v1, v2] = A[v0, v1, v2]
- for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
- for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
- for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
- with T.block("var_decode_intermediate_reindex_shared"):
- v0 = T.axis.spatial(T.int64(1), T.int64(0))
- v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
- v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
- T.reads(W[v2 // T.int64(8), v1], S[v2 // T.int64(32), v1])
- T.writes(var_decode_intermediate_reindex_shared[v0, v1, v2])
- T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
-                                            var_decode_intermediate_reindex_shared[v0, v1, v2] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v2 // T.int64(8), v1], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v2 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v2 // T.int64(32), v1], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
- for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
- with T.block("matmul_update"):
- v0 = T.axis.spatial(T.int64(1), ax0)
- v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
- v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
- v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
- T.reads(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2], A_reindex_shared[T.int64(0), v1, v3], var_decode_intermediate_reindex_shared[T.int64(0), v2, v3])
- T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
- var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] + A_reindex_shared[T.int64(0), v1, v3] * var_decode_intermediate_reindex_shared[T.int64(0), v2, v3]
- for ax0_1, ax1, ax2_0_1 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
- for ax2_1_1 in T.vectorized(T.int64(2)):
- with T.block("var_matmul_intermediate_reindex_local"):
- v0 = T.axis.spatial(T.int64(1), ax0_1)
- v1 = T.axis.spatial(T.int64(32), ax1_2 * T.int64(4) + ax1)
- v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2_0_1 * T.int64(2) + ax2_1_1)
- T.reads(C[T.int64(0), v1, v2], var_matmul_intermediate_reindex_local[v0, v1, v2])
- T.writes(Out[T.int64(0), v1, v2])
- Out[T.int64(0), v1, v2] = C[T.int64(0), v1, v2] + var_matmul_intermediate_reindex_local[v0, v1, v2]
+ var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] + A_reindex_shared[T.int64(0), v1, v3] * var_decode_intermediate_reindex_shared[T.int64(0), v2, v3]
+ for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
+ for ax2_1_1 in T.vectorized(T.int64(2)):
+ with T.block("var_matmul_intermediate_reindex_local"):
+ v0 = T.axis.spatial(T.int64(1), ax0)
+ v1 = T.axis.spatial(T.int64(32), ax1_2 * T.int64(4) + ax1)
+ v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
+ T.reads(C[T.int64(0), v1, v2], var_matmul_intermediate_reindex_local[v0, v1, v2])
+ T.writes(Out[T.int64(0), v1, v2])
+ Out[T.int64(0), v1, v2] = C[T.int64(0), v1, v2] + var_matmul_intermediate_reindex_local[v0, v1, v2]
# fmt: on