Posted to commits@tvm.apache.org by ju...@apache.org on 2023/07/18 21:58:40 UTC

[tvm] 04/07: [Unity][Dlight] Rule matmul avoiding blockIdx.z (#15333)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 6294aada465143d7927c2d91a677567071b5d9bc
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Mon Jul 17 02:59:46 2023 -0700

    [Unity][Dlight] Rule matmul avoiding blockIdx.z (#15333)
    
    Prior to this PR, the matmul rule of dlight binds the batch loop to
    `blockIdx.z`. However, not every device supports this blockIdx
    dimension (for example, WebGPU does not support `blockIdx.z`), which
    makes dlight fail to apply and build.
    
    Therefore, this PR fuses the former `blockIdx.z` (batch) loop into the
    `blockIdx.y` loop.
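    
    As an illustration (not part of this commit), a minimal sketch of the
    rebinding strategy on a toy TIR schedule; the function, buffer, and
    loop names below are hypothetical:
    
        import tvm
        from tvm.script import tir as T
    
        # Hypothetical batched elementwise workload standing in for matmul.
        @T.prim_func
        def copy(A: T.Buffer((2, 64, 64), "float32"),
                 B: T.Buffer((2, 64, 64), "float32")):
            for b, i, j in T.grid(2, 64, 64):
                with T.block("copy"):
                    vb, vi, vj = T.axis.remap("SSS", [b, i, j])
                    B[vb, vi, vj] = A[vb, vi, vj]
    
        sch = tvm.tir.Schedule(copy)
        batch, y, x = sch.get_loops(sch.get_block("copy"))
        # Previously the batch loop was bound to "blockIdx.z", which WebGPU
        # rejects. Fusing it into the y-dimension block loop keeps the grid
        # two-dimensional:
        y = sch.fuse(batch, y)
        sch.bind(y, "blockIdx.y")
        sch.bind(x, "blockIdx.x")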
---
 python/tvm/dlight/gpu/matmul.py        |   4 +-
 tests/python/dlight/test_gpu_matmul.py | 225 ++++++++++++++++-----------------
 2 files changed, 114 insertions(+), 115 deletions(-)

diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index be5e4b02d7..b9977d08b9 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -327,8 +327,8 @@ class Matmul(ScheduleRule):
         bx, vx, tx, xi = sch.split(x, [None, vthread_x, block_size_x, micro_size_x])
         by, vy, ty, yi = sch.split(y, [None, vthread_y, block_size_y, micro_size_y])
         ko, ki = sch.split(k, factors=[None, micro_size_k])
-        sch.reorder(bx, by, vy, vx, ty, tx, ko, ki, yi, xi)
-        sch.bind(batch, "blockIdx.z")
+        sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi)
+        by = sch.fuse(batch, by)
         sch.bind(bx, "blockIdx.x")
         sch.bind(by, "blockIdx.y")
         sch.bind(vy, "vthread.y")
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
index f3d9a7089d..318a3e833c 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -57,65 +57,64 @@ class TestMatmul(BaseBeforeAfter):
         matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
         inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="shared")
         inp1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
-        for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+        for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"):
             for ax1_0 in T.thread_binding((m + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
-                for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.y"):
-                    for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
-                        for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
-                            for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
-                                for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                                    for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
-                                        with T.block("matmul_init"):
-                                            v0 = T.axis.spatial(T.int64(1), ax0)
-                                            v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
-                                            v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
-                                            T.reads()
+                for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+                    for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
+                        for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                            for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                                for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
+                                    with T.block("matmul_init"):
+                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                        v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
+                                        v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
+                                        T.reads()
+                                        T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
+                                        matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
+                                for ax3_0 in range(T.int64(256)):
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                    with T.block("inp0_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                        v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                        v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                        T.reads(inp0[v0, v1, v2])
+                                                        T.writes(inp0_reindex_pad_shared[v0, v1, v2])
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                    with T.block("inp1_reindex_shared"):
+                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                        v1 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                        v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                        T.reads(inp1[v2, v1])
+                                                        T.writes(inp1_reindex_shared[v0, v1, v2])
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
+                                    for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
+                                        with T.block("matmul_update"):
+                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                            v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
+                                            v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
+                                            v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
+                                            T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0), v2, v3])
                                             T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
-                                            matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
-                                    for ax3_0 in range(T.int64(256)):
-                                        for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
-                                            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
-                                                for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
-                                                    for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
-                                                        with T.block("inp0_reindex_pad_shared"):
-                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                                            v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
-                                                            v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
-                                                            T.reads(inp0[v0, v1, v2])
-                                                            T.writes(inp0_reindex_pad_shared[v0, v1, v2])
-                                                            T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
-                                                            inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
-                                        for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
-                                            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
-                                                for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
-                                                    for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
-                                                        with T.block("inp1_reindex_shared"):
-                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                                            v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
-                                                            v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
-                                                            T.reads(inp1[v2, v1])
-                                                            T.writes(inp1_reindex_shared[v0, v1, v2])
-                                                            T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
-                                                            inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
-                                        for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
-                                            with T.block("matmul_update"):
-                                                v0 = T.axis.spatial(T.int64(1), ax0)
-                                                v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
-                                                v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
-                                                v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
-                                                T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0), v2, v3])
-                                                T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
-                                                matmul_reindex_pad_local[T.int64(0), v1, v2] = matmul_reindex_pad_local[T.int64(0), v1, v2] + inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0), v2, v3]
-                                    for ax0_1, ax1, ax2_0_1 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
-                                        for ax2_1_1 in T.vectorized(T.int64(2)):
-                                            with T.block("matmul_reindex_pad_local"):
-                                                v0 = T.axis.spatial(T.int64(1), ax0_1)
-                                                v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
-                                                v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2_0_1 * T.int64(2) + ax2_1_1)
-                                                T.reads(matmul_reindex_pad_local[v0, v1, v2])
-                                                T.writes(matmul[T.int64(0), v1, v2])
-                                                if v1 < m:
-                                                    matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
+                                            matmul_reindex_pad_local[T.int64(0), v1, v2] = matmul_reindex_pad_local[T.int64(0), v1, v2] + inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0), v2, v3]
+                                for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
+                                    for ax2_1_1 in T.vectorized(T.int64(2)):
+                                        with T.block("matmul_reindex_pad_local"):
+                                            v0 = T.axis.spatial(T.int64(1), ax0)
+                                            v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
+                                            v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
+                                            T.reads(matmul_reindex_pad_local[v0, v1, v2])
+                                            T.writes(matmul[T.int64(0), v1, v2])
+                                            if v1 < m:
+                                                matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
     # fmt: on
 
 
@@ -146,6 +145,7 @@ class TestFusedMatmul(BaseBeforeAfter):
                 T.reads(C[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                 T.writes(Out[v_ax0, v_ax1, v_ax2])
                 Out[v_ax0, v_ax1, v_ax2] = C[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
     @T.prim_func
     def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32")):
         T.func_attr({"tir.is_scheduled": 1})
@@ -153,64 +153,63 @@ class TestFusedMatmul(BaseBeforeAfter):
         var_matmul_intermediate_reindex_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local")
         A_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared")
         var_decode_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
-        for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+        for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"):
             for ax1_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
-                for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.y"):
-                    for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
-                        for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
-                            for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
-                                for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                                    for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
-                                        with T.block("matmul_init"):
-                                            v0 = T.axis.spatial(T.int64(1), ax0)
-                                            v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
-                                            v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
-                                            T.reads()
+                for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+                    for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
+                        for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                            for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                                for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
+                                    with T.block("matmul_init"):
+                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                        v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
+                                        v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
+                                        T.reads()
+                                        T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
+                                        var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = T.float32(0)
+                                for ax3_0 in range(T.int64(256)):
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                    with T.block("A_reindex_shared"):
+                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                        v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                        v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                        T.reads(A[v0, v1, v2])
+                                                        T.writes(A_reindex_shared[v0, v1, v2])
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        A_reindex_shared[v0, v1, v2] = A[v0, v1, v2]
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                    with T.block("var_decode_intermediate_reindex_shared"):
+                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                        v1 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                        v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                        T.reads(W[v2 // T.int64(8), v1], S[v2 // T.int64(32), v1])
+                                                        T.writes(var_decode_intermediate_reindex_shared[v0, v1, v2])
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        var_decode_intermediate_reindex_shared[v0, v1, v2] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v2 // T.int64(8), v1], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v2 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v2 // T.int64(32), v1], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+                                    for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
+                                        with T.block("matmul_update"):
+                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                            v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
+                                            v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
+                                            v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
+                                            T.reads(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2], A_reindex_shared[T.int64(0), v1, v3], var_decode_intermediate_reindex_shared[T.int64(0), v2, v3])
                                             T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
-                                            var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = T.float32(0)
-                                    for ax3_0 in range(T.int64(256)):
-                                        for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
-                                            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
-                                                for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
-                                                    for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
-                                                        with T.block("A_reindex_shared"):
-                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                                            v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
-                                                            v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
-                                                            T.reads(A[v0, v1, v2])
-                                                            T.writes(A_reindex_shared[v0, v1, v2])
-                                                            T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
-                                                            A_reindex_shared[v0, v1, v2] = A[v0, v1, v2]
-                                        for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
-                                            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
-                                                for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
-                                                    for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
-                                                        with T.block("var_decode_intermediate_reindex_shared"):
-                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                                            v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
-                                                            v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
-                                                            T.reads(W[v2 // T.int64(8), v1], S[v2 // T.int64(32), v1])
-                                                            T.writes(var_decode_intermediate_reindex_shared[v0, v1, v2])
-                                                            T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
-                                                            var_decode_intermediate_reindex_shared[v0, v1, v2] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v2 // T.int64(8), v1], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v2 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v2 // T.int64(32), v1], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
-                                        for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
-                                            with T.block("matmul_update"):
-                                                v0 = T.axis.spatial(T.int64(1), ax0)
-                                                v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
-                                                v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
-                                                v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
-                                                T.reads(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2], A_reindex_shared[T.int64(0), v1, v3], var_decode_intermediate_reindex_shared[T.int64(0), v2, v3])
-                                                T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
-                                                var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] + A_reindex_shared[T.int64(0), v1, v3] * var_decode_intermediate_reindex_shared[T.int64(0), v2, v3]
-                                    for ax0_1, ax1, ax2_0_1 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
-                                        for ax2_1_1 in T.vectorized(T.int64(2)):
-                                            with T.block("var_matmul_intermediate_reindex_local"):
-                                                v0 = T.axis.spatial(T.int64(1), ax0_1)
-                                                v1 = T.axis.spatial(T.int64(32), ax1_2 * T.int64(4) + ax1)
-                                                v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2_0_1 * T.int64(2) + ax2_1_1)
-                                                T.reads(C[T.int64(0), v1, v2], var_matmul_intermediate_reindex_local[v0, v1, v2])
-                                                T.writes(Out[T.int64(0), v1, v2])
-                                                Out[T.int64(0), v1, v2] = C[T.int64(0), v1, v2] + var_matmul_intermediate_reindex_local[v0, v1, v2]
+                                            var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] + A_reindex_shared[T.int64(0), v1, v3] * var_decode_intermediate_reindex_shared[T.int64(0), v2, v3]
+                                for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
+                                    for ax2_1_1 in T.vectorized(T.int64(2)):
+                                        with T.block("var_matmul_intermediate_reindex_local"):
+                                            v0 = T.axis.spatial(T.int64(1), ax0)
+                                            v1 = T.axis.spatial(T.int64(32), ax1_2 * T.int64(4) + ax1)
+                                            v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
+                                            T.reads(C[T.int64(0), v1, v2], var_matmul_intermediate_reindex_local[v0, v1, v2])
+                                            T.writes(Out[T.int64(0), v1, v2])
+                                            Out[T.int64(0), v1, v2] = C[T.int64(0), v1, v2] + var_matmul_intermediate_reindex_local[v0, v1, v2]
     # fmt: on
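
A hedged usage sketch: with this change, the rule can be applied on targets
that lack `blockIdx.z`, such as WebGPU. The module `mod` below is a
placeholder for an IRModule containing matmul workloads, and the example
assumes a WebGPU-enabled TVM build:

    import tvm
    from tvm import dlight as dl

    # Apply the dlight matmul rule under a target without blockIdx.z.
    with tvm.target.Target("webgpu"):
        mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)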