You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/09/26 00:39:22 UTC

[tvm] branch unity updated: [Unity][Dlight] Fix inline consumer in matmul tensorize rule (#15781)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e809d64658 [Unity][Dlight] Fix inline consumer in matmul tensorize rule (#15781)
e809d64658 is described below

commit e809d64658481035f79ddba3317171cf1abf1810
Author: Hongyi Jin <ji...@gmail.com>
AuthorDate: Mon Sep 25 17:39:14 2023 -0700

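    [Unity][Dlight] Fix inline consumer in matmul tensorize rule (#15781)
    
    Factor the consumer-inlining fallback of the `Matmul` rule into a shared
    `auto_inline_consumer_chain` helper and call it from `MatmulTensorization`
    as well. If some consumers of the cache-write stage remain after the first
    inlining pass (for example a `T_add` epilogue that also reads a separate
    `T_divide` producer), those other producers are inlined and the pass is
    retried. Adds the `TestMatmulTensorizeEpilogue` regression test.
    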
    * fix dlight
    
    * fix
    
    * format
    
    * pass ci
---
 python/tvm/dlight/gpu/matmul.py                  |  43 +++---
 tests/python/dlight/test_gpu_matmul_tensorize.py | 164 +++++++++++++++++++++++
 2 files changed, 188 insertions(+), 19 deletions(-)

diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index 273fecaf41..1d951854d7 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -86,6 +86,28 @@ def auto_inline_consumers(
             return
 
 
+def auto_inline_consumer_chain(
+    sch: tir.Schedule,
+    block: tir.schedule.BlockRV,
+):
+    auto_inline_consumers(sch, block)
+    remaining_consumers = sch.get_consumers(block)
+
+    if len(remaining_consumers) != 0:
+        # Some blocks have failed to be inlined to the producer cache-write stage.
+        # This could be due to another producer block that has not been scheduled.
+        for c in remaining_consumers:
+            for p in sch.get_producers(c):
+                if sch.get(p) != sch.get(block):
+                    sch.compute_inline(p)
+
+        # Try inlining into the cache-write stage again, this time it should succeed.
+        auto_inline_consumers(sch, block)
+
+    msg = "There are some consumers of the cache-write stage that are not properly inlined."
+    assert len(sch.get_consumers(block)) == 0, msg
+
+
 class IterKind(Enum):
     """Iter kinds for GEMM-liked programs.
     We can simplify the computation to C[S, I, J] += A[S, I, K] * B[S, J, K],
@@ -482,8 +504,7 @@ class MatmulTensorization(ScheduleRule):
                 tensorize_success = True
             except:  # pylint: disable=bare-except
                 return None
-
-        auto_inline_consumers(sch, accumulator_shared_to_global)
+        auto_inline_consumer_chain(sch, accumulator_shared_to_global)
 
         fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:])
         _, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size])
@@ -669,23 +690,7 @@ class Matmul(ScheduleRule):
         else:
             auto_inline_producers(sch, main_block)
 
-        auto_inline_consumers(sch, l2g)
-
-        remaining_consumers = sch.get_consumers(l2g)
-
-        if len(remaining_consumers) != 0:
-            # Some blocks have failed to be inlined to the producer cache-write stage.
-            # This could be due to another producer block that has not been scheduled.
-            for c in remaining_consumers:
-                for p in sch.get_producers(c):
-                    if sch.get(p) != sch.get(l2g):
-                        sch.compute_inline(p)
-
-            # Try inlining into the cache-write stage again, this time it should succeed.
-            auto_inline_consumers(sch, l2g)
-
-        msg = "There are some consumers of the cache-write stage that are not properly inlined."
-        assert len(sch.get_consumers(l2g)) == 0, msg
+        auto_inline_consumer_chain(sch, l2g)
 
         sch.decompose_reduction(main_block, ko)
         return sch
diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py b/tests/python/dlight/test_gpu_matmul_tensorize.py
index 026a6fa624..ebb1f0cd97 100644
--- a/tests/python/dlight/test_gpu_matmul_tensorize.py
+++ b/tests/python/dlight/test_gpu_matmul_tensorize.py
@@ -261,5 +261,169 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
     # fmt: on
 
 
+class TestMatmulTensorizeEpilogue(BaseBeforeAfter):
+    # fmt: off
+
+    @T.prim_func
+    def before(lv686: T.Buffer((T.int32(4096), T.int32(256)), "uint32"), lv687: T.Buffer((T.int32(4096), T.int32(64)), "float16"), p_lv42: T.handle, p_lv3: T.handle, p_output0: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        n = T.int32()
+        lv42 = T.match_buffer(p_lv42, (T.int32(1), n, T.int32(2048)), "float16")
+        lv3 = T.match_buffer(p_lv3, (T.int32(1), n, T.int32(4096)), "float16")
+        p_output0_intermediate = T.match_buffer(p_output0, (T.int32(1), n, T.int32(4096)), "float16")
+        # with T.block("root"):
+        p_output0_intermediate_1 = T.alloc_buffer((T.int32(4096), T.int32(2048)), "float16")
+        var_NT_matmul_intermediate = T.alloc_buffer((T.int32(1), n, T.int32(4096)), "float16")
+        var_T_divide_intermediate = T.alloc_buffer((T.int32(1), n, T.int32(4096)), "float16")
+        for i, j in T.grid(T.int32(4096), T.int32(2048)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(lv686[v_i, v_j // T.int32(8)], lv687[v_i, v_j // T.int32(32)])
+                T.writes(p_output0_intermediate_1[v_i, v_j])
+                p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv686[v_i, v_j // T.int32(8)], T.Cast("uint32", v_j % T.int32(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv687[v_i, v_j // T.int32(32)]
+        for i0, i1, i2, k in T.grid(T.int32(1), n, T.int32(4096), T.int32(2048)):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv42[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k])
+                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv42[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k]
+        for ax0, ax1, ax2 in T.grid(T.int32(1), n, T.int32(4096)):
+            with T.block("T_divide"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(lv3[v_ax0, v_ax1, v_ax2])
+                T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2])
+                var_T_divide_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] * T.float16(0.5)
+        for ax0, ax1, ax2 in T.grid(T.int32(1), n, T.int32(4096)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_T_divide_intermediate[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+    @T.prim_func
+    def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), "float16"), p_lv42: T.handle, p_lv3: T.handle, p_output0: T.handle):
+        T.func_attr({"global_symbol": "fused_fused_decode3_fused_NT_matmul6_divide1_add1", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+        n = T.int32()
+        lv42 = T.match_buffer(p_lv42, (1, n, 2048), "float16")
+        lv3 = T.match_buffer(p_lv3, (1, n, 4096), "float16")
+        p_output0_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16")
+        # with T.block("root"):
+        lv42_reindex_pad_shared_dyn = T.alloc_buffer((1, (n + 127) // 128 * 128, 2048), "float16", scope="shared.dyn")
+        p_output0_intermediate_1_reindex_shared_dyn = T.alloc_buffer((1, 4096, 2048), "float16", scope="shared.dyn")
+        lv42_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, (n + 127) // 128 * 128, 2048), "float16", scope="wmma.matrix_a")
+        p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 4096, 2048), "float16", scope="wmma.matrix_b")
+        var_NT_matmul_intermediate_reindex_pad_shared_dyn = T.alloc_buffer((1, (n + 127) // 128 * 128, 4096), "float16", scope="shared.dyn")
+        var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator = T.alloc_buffer((1, (n + 127) // 128 * 128, 4096), "float16", scope="wmma.accumulator")
+        for ax0 in T.thread_binding(1, thread="blockIdx.z"):
+            for ax1_0_0_ax2_0_0_fused in T.thread_binding((n + 127) // 128, thread="blockIdx.x"):
+                for ax1_0_1_ax2_0_1_fused in T.thread_binding(32, thread="blockIdx.y"):
+                    for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"):
+                        for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2):
+                            with T.block("NT_matmul_o_init"):
+                                v0_o = T.axis.spatial(1, ax0)
+                                v1_o = T.axis.spatial((n + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init)
+                                v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init)
+                                T.reads()
+                                T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                with T.block("NT_matmul_init_o"):
+                                    v1_i_init_o = T.axis.spatial(1, 0)
+                                    v2_i_init_o = T.axis.spatial(1, 0)
+                                    T.reads()
+                                    T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                    C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16)
+                                    T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0))
+                        for ax3_0_0 in range(32):
+                            for ax0_ax1_fused_0 in range(4):
+                                for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
+                                    for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
+                                        for ax0_ax1_fused_3 in T.vectorized(4):
+                                            with T.block("lv42_reindex_pad_shared.dyn"):
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64)
+                                                v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64)
+                                                T.reads(lv42[v0, v1, v2])
+                                                T.writes(lv42_reindex_pad_shared_dyn[v0, v1, v2])
+                                                T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
+                                                lv42_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < n, lv42[v0, v1, v2], T.float16(0))
+                            for ax0_ax1_fused_0 in range(4):
+                                for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
+                                    for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
+                                        for ax0_ax1_fused_3 in T.vectorized(4):
+                                            with T.block("p_output0_intermediate_1_reindex_shared.dyn"):
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64)
+                                                v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64)
+                                                T.reads(lv686[v1, v2 // 8], lv687[v1, v2 // 32])
+                                                T.writes(p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2])
+                                                T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
+                                                p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv686[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv687[v1, v2 // 32]
+                            for ax3_0_1 in range(4):
+                                for ax0_0 in T.unroll(2):
+                                    for ax1_0 in T.unroll(1):
+                                        with T.block("lv42_reindex_pad_shared.dyn_wmma.matrix_a_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(8 * ((n + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(128, ax3_0_0 * 4 + ax3_0_1 + ax1_0)
+                                            T.reads(lv42_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                            T.writes(lv42_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                            A = T.match_buffer(lv42_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16)
+                                            C = T.match_buffer(lv42_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16)
+                                            T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major")
+                                for ax0_0 in T.unroll(2):
+                                    for ax1_0 in T.unroll(1):
+                                        with T.block("p_output0_intermediate_1_reindex_shared.dyn_wmma.matrix_b_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(128, ax3_0_0 * 4 + ax3_0_1 + ax1_0)
+                                            T.reads(p_output0_intermediate_1_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                            T.writes(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                            A = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16)
+                                            C = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16)
+                                            T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major")
+                                for ax1_0_3, ax2_0_3 in T.grid(2, 2):
+                                    with T.block("NT_matmul_o_update"):
+                                        v0_o = T.axis.spatial(1, ax0)
+                                        v1_o = T.axis.spatial((n + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3)
+                                        v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3)
+                                        v3_o = T.axis.reduce(128, ax3_0_0 * 4 + ax3_0_1)
+                                        T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
+                                        T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                        with T.block("NT_matmul_o"):
+                                            v1_i_o = T.axis.spatial(1, 0)
+                                            v2_i_o = T.axis.spatial(1, 0)
+                                            v3_i_o = T.axis.reduce(1, 0)
+                                            T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
+                                            T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                            A = T.match_buffer(lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16)
+                                            B = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16)
+                                            C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16)
+                                            T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16)
+                        for ax0_0, ax1_0 in T.grid(2, 2):
+                            with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn_wmma.accumulator_o"):
+                                v0_o = T.axis.spatial(1, 0)
+                                v1_o = T.axis.spatial(8 * ((n + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0)
+                                v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0)
+                                T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                A = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16)
+                                C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16)
+                                T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major")
+                        for ax0_ax1_fused_0 in range(8):
+                            for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
+                                for ax0_ax1_fused_2 in T.vectorized(4):
+                                    with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
+                                        v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
+                                        T.reads(lv3[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2])
+                                        T.writes(p_output0_intermediate[0, v1, v2])
+                                        T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]})
+                                        if v1 < n:
+                                            p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]
+    # fmt: on
+
+
 if __name__ == "__main__":
     tvm.testing.main()