Posted to commits@tvm.apache.org by sy...@apache.org on 2023/09/26 00:39:22 UTC
[tvm] branch unity updated: [Unity][Dlight] Fix inline consumer in matmul tensorize rule (#15781)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e809d64658 [Unity][Dlight] Fix inline consumer in matmul tensorize rule (#15781)
e809d64658 is described below
commit e809d64658481035f79ddba3317171cf1abf1810
Author: Hongyi Jin <ji...@gmail.com>
AuthorDate: Mon Sep 25 17:39:14 2023 -0700
[Unity][Dlight] Fix inline consumer in matmul tensorize rule (#15781)
* fix dlight
* fix
* format
* pass ci
---
python/tvm/dlight/gpu/matmul.py | 43 +++---
tests/python/dlight/test_gpu_matmul_tensorize.py | 164 +++++++++++++++++++++++
2 files changed, 188 insertions(+), 19 deletions(-)
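For context, these dlight rules are normally driven through the ApplyDefaultSchedule transform. The following sketch (not part of the commit) shows one way to exercise the Matmul rule on a toy workload; the workload, target tag, and rule list are illustrative assumptions rather than anything taken from this change.

import tvm
from tvm import dlight as dl
from tvm.script import tir as T

# Toy NT-matmul workload; the new test below instead uses a larger
# quantized-decode + matmul + divide/add epilogue pattern.
@T.prim_func
def matmul(a: T.Buffer((128, 128), "float16"),
           b: T.Buffer((128, 128), "float16"),
           c: T.Buffer((128, 128), "float16")):
    for i, j, k in T.grid(128, 128, 128):
        with T.block("NT_matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                c[vi, vj] = T.float16(0)
            c[vi, vj] = c[vi, vj] + a[vi, vk] * b[vj, vk]

mod = tvm.IRModule({"main": matmul})
with tvm.target.Target("nvidia/nvidia-a100"):  # assumed tensor-core capable target
    mod = dl.ApplyDefaultSchedule(  # applies the first matching rule to each unscheduled PrimFunc
        dl.gpu.Matmul(),
        dl.gpu.Fallback(),
    )(mod)
print(mod.script())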
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index 273fecaf41..1d951854d7 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -86,6 +86,28 @@ def auto_inline_consumers(
return
+def auto_inline_consumer_chain(
+    sch: tir.Schedule,
+    block: tir.schedule.BlockRV,
+):
+    auto_inline_consumers(sch, block)
+    remaining_consumers = sch.get_consumers(block)
+
+    if len(remaining_consumers) != 0:
+        # Some blocks have failed to be inlined to the producer cache-write stage.
+        # This could be due to another producer block that has not been scheduled.
+        for c in remaining_consumers:
+            for p in sch.get_producers(c):
+                if sch.get(p) != sch.get(block):
+                    sch.compute_inline(p)
+
+        # Try inlining into the cache-write stage again, this time it should succeed.
+        auto_inline_consumers(sch, block)
+
+    msg = "There are some consumers of the cache-write stage that are not properly inlined."
+    assert len(sch.get_consumers(block)) == 0, msg
+
+
class IterKind(Enum):
"""Iter kinds for GEMM-liked programs.
We can simplify the computation to C[S, I, J] += A[S, I, K] * B[S, J, K],
@@ -482,8 +504,7 @@ class MatmulTensorization(ScheduleRule):
tensorize_success = True
except: # pylint: disable=bare-except
return None
-
- auto_inline_consumers(sch, accumulator_shared_to_global)
+ auto_inline_consumer_chain(sch, accumulator_shared_to_global)
fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:])
_, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size])
@@ -669,23 +690,7 @@ class Matmul(ScheduleRule):
else:
auto_inline_producers(sch, main_block)
- auto_inline_consumers(sch, l2g)
-
- remaining_consumers = sch.get_consumers(l2g)
-
- if len(remaining_consumers) != 0:
- # Some blocks have failed to be inlined to the producer cache-write stage.
- # This could be due to another producer block that has not been scheduled.
- for c in remaining_consumers:
- for p in sch.get_producers(c):
- if sch.get(p) != sch.get(l2g):
- sch.compute_inline(p)
-
- # Try inlining into the cache-write stage again, this time it should succeed.
- auto_inline_consumers(sch, l2g)
-
- msg = "There are some consumers of the cache-write stage that are not properly inlined."
- assert len(sch.get_consumers(l2g)) == 0, msg
+ auto_inline_consumer_chain(sch, l2g)
sch.decompose_reduction(main_block, ko)
return sch
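The failure mode the new auto_inline_consumer_chain helper handles can be reproduced with raw schedule primitives: an epilogue block with two producers (the cache-write stage and a not-yet-scheduled elementwise block) cannot be reverse-inlined until that sibling producer is inlined into it first, which is the ordering the helper enforces before retrying. The sketch below is illustrative only; buffer names and shapes are assumptions, and the final reverse_compute_inline relies on epilogues being allowed to read extra buffers (here res), the same behavior the fused copy block in the expected test function below depends on.

import tvm
from tvm.script import tir as T

@T.prim_func
def workload(a: T.Buffer((128, 128), "float16"),
             b: T.Buffer((128, 128), "float16"),
             res: T.Buffer((128, 128), "float16"),
             out: T.Buffer((128, 128), "float16")):
    matmul = T.alloc_buffer((128, 128), "float16")
    scaled = T.alloc_buffer((128, 128), "float16")
    for i, j, k in T.grid(128, 128, 128):
        with T.block("NT_matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                matmul[vi, vj] = T.float16(0)
            matmul[vi, vj] = matmul[vi, vj] + a[vi, vk] * b[vj, vk]
    for i, j in T.grid(128, 128):
        with T.block("T_divide"):  # plays the role of the divide block in the new test
            vi, vj = T.axis.remap("SS", [i, j])
            scaled[vi, vj] = res[vi, vj] * T.float16(0.5)
    for i, j in T.grid(128, 128):
        with T.block("T_add"):  # consumes both the matmul result and the divide
            vi, vj = T.axis.remap("SS", [i, j])
            out[vi, vj] = scaled[vi, vj] + matmul[vi, vj]

sch = tvm.tir.Schedule(workload)
l2g = sch.cache_write(sch.get_block("NT_matmul"), 0, "shared.dyn")
# "T_add" still has two producers here (the cache-write block and "T_divide"),
# so inlining it into the cache-write stage fails.  Inline the sibling producer
# into the epilogue first; the epilogue then folds into the cache-write block.
sch.compute_inline(sch.get_block("T_divide"))
sch.reverse_compute_inline(sch.get_block("T_add"))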
diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py b/tests/python/dlight/test_gpu_matmul_tensorize.py
index 026a6fa624..ebb1f0cd97 100644
--- a/tests/python/dlight/test_gpu_matmul_tensorize.py
+++ b/tests/python/dlight/test_gpu_matmul_tensorize.py
@@ -261,5 +261,169 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
# fmt: on
+class TestMatmulTensorizeEpilogue(BaseBeforeAfter):
+ # fmt: off
+
+ @T.prim_func
+ def before(lv686: T.Buffer((T.int32(4096), T.int32(256)), "uint32"), lv687: T.Buffer((T.int32(4096), T.int32(64)), "float16"), p_lv42: T.handle, p_lv3: T.handle, p_output0: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ n = T.int32()
+ lv42 = T.match_buffer(p_lv42, (T.int32(1), n, T.int32(2048)), "float16")
+ lv3 = T.match_buffer(p_lv3, (T.int32(1), n, T.int32(4096)), "float16")
+ p_output0_intermediate = T.match_buffer(p_output0, (T.int32(1), n, T.int32(4096)), "float16")
+ # with T.block("root"):
+ p_output0_intermediate_1 = T.alloc_buffer((T.int32(4096), T.int32(2048)), "float16")
+ var_NT_matmul_intermediate = T.alloc_buffer((T.int32(1), n, T.int32(4096)), "float16")
+ var_T_divide_intermediate = T.alloc_buffer((T.int32(1), n, T.int32(4096)), "float16")
+ for i, j in T.grid(T.int32(4096), T.int32(2048)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(lv686[v_i, v_j // T.int32(8)], lv687[v_i, v_j // T.int32(32)])
+ T.writes(p_output0_intermediate_1[v_i, v_j])
+ p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv686[v_i, v_j // T.int32(8)], T.Cast("uint32", v_j % T.int32(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv687[v_i, v_j // T.int32(32)]
+ for i0, i1, i2, k in T.grid(T.int32(1), n, T.int32(4096), T.int32(2048)):
+ with T.block("NT_matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(lv42[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k])
+ T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
+ with T.init():
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv42[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k]
+ for ax0, ax1, ax2 in T.grid(T.int32(1), n, T.int32(4096)):
+ with T.block("T_divide"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(lv3[v_ax0, v_ax1, v_ax2])
+ T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2])
+ var_T_divide_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] * T.float16(0.5)
+ for ax0, ax1, ax2 in T.grid(T.int32(1), n, T.int32(4096)):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+ T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+ p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_T_divide_intermediate[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+ @T.prim_func
+ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), "float16"), p_lv42: T.handle, p_lv3: T.handle, p_output0: T.handle):
+ T.func_attr({"global_symbol": "fused_fused_decode3_fused_NT_matmul6_divide1_add1", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ n = T.int32()
+ lv42 = T.match_buffer(p_lv42, (1, n, 2048), "float16")
+ lv3 = T.match_buffer(p_lv3, (1, n, 4096), "float16")
+ p_output0_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16")
+ # with T.block("root"):
+ lv42_reindex_pad_shared_dyn = T.alloc_buffer((1, (n + 127) // 128 * 128, 2048), "float16", scope="shared.dyn")
+ p_output0_intermediate_1_reindex_shared_dyn = T.alloc_buffer((1, 4096, 2048), "float16", scope="shared.dyn")
+ lv42_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, (n + 127) // 128 * 128, 2048), "float16", scope="wmma.matrix_a")
+ p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 4096, 2048), "float16", scope="wmma.matrix_b")
+ var_NT_matmul_intermediate_reindex_pad_shared_dyn = T.alloc_buffer((1, (n + 127) // 128 * 128, 4096), "float16", scope="shared.dyn")
+ var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator = T.alloc_buffer((1, (n + 127) // 128 * 128, 4096), "float16", scope="wmma.accumulator")
+ for ax0 in T.thread_binding(1, thread="blockIdx.z"):
+ for ax1_0_0_ax2_0_0_fused in T.thread_binding((n + 127) // 128, thread="blockIdx.x"):
+ for ax1_0_1_ax2_0_1_fused in T.thread_binding(32, thread="blockIdx.y"):
+ for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"):
+ for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2):
+ with T.block("NT_matmul_o_init"):
+ v0_o = T.axis.spatial(1, ax0)
+ v1_o = T.axis.spatial((n + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init)
+ v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init)
+ T.reads()
+ T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+ with T.block("NT_matmul_init_o"):
+ v1_i_init_o = T.axis.spatial(1, 0)
+ v2_i_init_o = T.axis.spatial(1, 0)
+ T.reads()
+ T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+ C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16)
+ T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0))
+ for ax3_0_0 in range(32):
+ for ax0_ax1_fused_0 in range(4):
+ for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
+ for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
+ for ax0_ax1_fused_3 in T.vectorized(4):
+ with T.block("lv42_reindex_pad_shared.dyn"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64)
+ v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64)
+ T.reads(lv42[v0, v1, v2])
+ T.writes(lv42_reindex_pad_shared_dyn[v0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
+ lv42_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < n, lv42[v0, v1, v2], T.float16(0))
+ for ax0_ax1_fused_0 in range(4):
+ for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
+ for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
+ for ax0_ax1_fused_3 in T.vectorized(4):
+ with T.block("p_output0_intermediate_1_reindex_shared.dyn"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64)
+ v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64)
+ T.reads(lv686[v1, v2 // 8], lv687[v1, v2 // 32])
+ T.writes(p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
+ p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv686[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv687[v1, v2 // 32]
+ for ax3_0_1 in range(4):
+ for ax0_0 in T.unroll(2):
+ for ax1_0 in T.unroll(1):
+ with T.block("lv42_reindex_pad_shared.dyn_wmma.matrix_a_o"):
+ v0_o = T.axis.spatial(1, 0)
+ v1_o = T.axis.spatial(8 * ((n + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0)
+ v2_o = T.axis.spatial(128, ax3_0_0 * 4 + ax3_0_1 + ax1_0)
+ T.reads(lv42_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+ T.writes(lv42_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+ A = T.match_buffer(lv42_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16)
+ C = T.match_buffer(lv42_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16)
+ T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major")
+ for ax0_0 in T.unroll(2):
+ for ax1_0 in T.unroll(1):
+ with T.block("p_output0_intermediate_1_reindex_shared.dyn_wmma.matrix_b_o"):
+ v0_o = T.axis.spatial(1, 0)
+ v1_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0)
+ v2_o = T.axis.spatial(128, ax3_0_0 * 4 + ax3_0_1 + ax1_0)
+ T.reads(p_output0_intermediate_1_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+ T.writes(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+ A = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16)
+ C = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16)
+ T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major")
+ for ax1_0_3, ax2_0_3 in T.grid(2, 2):
+ with T.block("NT_matmul_o_update"):
+ v0_o = T.axis.spatial(1, ax0)
+ v1_o = T.axis.spatial((n + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3)
+ v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3)
+ v3_o = T.axis.reduce(128, ax3_0_0 * 4 + ax3_0_1)
+ T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
+ T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+ with T.block("NT_matmul_o"):
+ v1_i_o = T.axis.spatial(1, 0)
+ v2_i_o = T.axis.spatial(1, 0)
+ v3_i_o = T.axis.reduce(1, 0)
+ T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
+ T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+ A = T.match_buffer(lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16)
+ B = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16)
+ C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16)
+ T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16)
+ for ax0_0, ax1_0 in T.grid(2, 2):
+ with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn_wmma.accumulator_o"):
+ v0_o = T.axis.spatial(1, 0)
+ v1_o = T.axis.spatial(8 * ((n + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0)
+ v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0)
+ T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+ T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+ A = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16)
+ C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16)
+ T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major")
+ for ax0_ax1_fused_0 in range(8):
+ for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
+ for ax0_ax1_fused_2 in T.vectorized(4):
+ with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
+ v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
+ T.reads(lv3[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2])
+ T.writes(p_output0_intermediate[0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]})
+ if v1 < n:
+ p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]
+ # fmt: on
+
+
if __name__ == "__main__":
tvm.testing.main()