Posted to commits@tvm.apache.org by tq...@apache.org on 2023/07/02 17:54:08 UTC

[tvm] branch unity updated: [Unity][Dlight] Matmul Rules (#15191)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 918fc4ecf7 [Unity][Dlight] Matmul Rules (#15191)
918fc4ecf7 is described below

commit 918fc4ecf7803bdf89466f82e0b634951a47f64c
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Mon Jul 3 01:54:02 2023 +0800

    [Unity][Dlight] Matmul Rules (#15191)
    
    This PR introduces default schedule rules for matmul kernels. Note that
    GEMV-like kernels are skipped, as they will be covered by a separate rule.
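    
    A minimal usage sketch (mirroring the included test; the target string is
    taken from the test below, and `mod` is a placeholder IRModule, not part of
    this patch):
    
        from tvm import dlight as dl
        from tvm.target import Target
    
        # `mod` is assumed to be an IRModule containing unscheduled matmul PrimFuncs.
        with Target("nvidia/geforce-rtx-3090-ti"):
            mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)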
---
 python/tvm/dlight/base/analysis.py                 |  40 +++
 python/tvm/dlight/base/transform.py                |   2 +-
 python/tvm/dlight/gpu/__init__.py                  |   1 +
 python/tvm/dlight/gpu/fallback.py                  |  14 +-
 python/tvm/dlight/gpu/matmul.py                    | 366 +++++++++++++++++++++
 .../schedule/primitive/layout_transformation.cc    |  11 +-
 tests/python/dlight/test_gpu_matmul.py             | 252 ++++++++++++++
 7 files changed, 670 insertions(+), 16 deletions(-)

diff --git a/python/tvm/dlight/base/analysis.py b/python/tvm/dlight/base/analysis.py
index a5d70c6c0c..d11e29a8ad 100644
--- a/python/tvm/dlight/base/analysis.py
+++ b/python/tvm/dlight/base/analysis.py
@@ -21,6 +21,9 @@ from typing_extensions import Literal
 
 from tvm import tir
 from tvm._ffi import get_global_func
+from tvm.target.target import Target
+from tvm.tir import Schedule
+from tvm.tir.schedule import BlockRV
 
 
 class IterInfo:
@@ -146,3 +149,40 @@ def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]:
             )
         )
     return blocks
+
+
+def _assert_gpu_target(target: Target):
+    if "gpu" not in target.keys:
+        raise ValueError(f"Expect a GPU target, but got {target}")
+
+
+def get_max_threads_per_block(target: Target) -> int:
+    _assert_gpu_target(target)
+    max_threads_per_block = None
+    for name in ["max_threads_per_block", "max_num_threads"]:
+        if max_threads_per_block is None:
+            max_threads_per_block = target.attrs.get(name, None)
+    if max_threads_per_block is None:
+        max_threads_per_block = 64
+    return int(max_threads_per_block)
+
+
+def get_max_shared_memory_per_block(target: Target) -> int:
+    _assert_gpu_target(target)
+    max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None)
+    if max_shared_memory_per_block is None:
+        raise ValueError(
+            f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually"
+        )
+    return int(max_shared_memory_per_block)
+
+
+def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
+    try:
+        block = sch.mod[func_name].body.block
+    except:  # pylint: disable=bare-except
+        raise ValueError(
+            f"The function body is expected to be the root block, but got:\n"
+            f"{sch.mod[func_name].body}"
+        )
+    return sch.get_block(block.name_hint)
diff --git a/python/tvm/dlight/base/transform.py b/python/tvm/dlight/base/transform.py
index c11a4ae060..dc02c3dc0f 100644
--- a/python/tvm/dlight/base/transform.py
+++ b/python/tvm/dlight/base/transform.py
@@ -60,7 +60,7 @@ class ApplyDefaultSchedule:  # pylint: disable=too-few-public-methods
         target = Target.current(allow_none=False)
         updated_functions = {}
         for g_var, func in mod.functions.items():
-            if not _is_scheduled(func):
+            if isinstance(func, tir.PrimFunc) and not _is_scheduled(func):
                 sch = _apply_rules(func, target, self.rules, tunable=False)
                 if sch is not None:
                     assert len(sch) == 1
diff --git a/python/tvm/dlight/gpu/__init__.py b/python/tvm/dlight/gpu/__init__.py
index b689bef381..79090d400b 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -21,3 +21,4 @@ For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/meta
 from .fallback import Fallback
 from .decode_gemv import DecodeGEMV
 from .reduction import Reduction
+from .matmul import Matmul
diff --git a/python/tvm/dlight/gpu/fallback.py b/python/tvm/dlight/gpu/fallback.py
index 63033aa7c7..6b120b1648 100644
--- a/python/tvm/dlight/gpu/fallback.py
+++ b/python/tvm/dlight/gpu/fallback.py
@@ -21,17 +21,7 @@ from typing import List
 from tvm import tir
 from tvm.target import Target
 
-from ..base import ScheduleRule, normalize_prim_func, try_inline
-
-
-def _max_threads_per_block(target: Target) -> int:
-    max_threads_per_block = None
-    for name in ["max_threads_per_block", "max_num_threads"]:
-        if max_threads_per_block is None:
-            max_threads_per_block = target.attrs.get(name, None)
-    if max_threads_per_block is None:
-        max_threads_per_block = 64
-    return int(max_threads_per_block)
+from ..base import ScheduleRule, analysis, normalize_prim_func, try_inline
 
 
 class Fallback(ScheduleRule):
@@ -46,7 +36,7 @@ class Fallback(ScheduleRule):
         target: Target,
         _: bool,
     ) -> tir.Schedule:
-        max_threads_per_block = _max_threads_per_block(target)
+        max_threads_per_block = analysis.get_max_threads_per_block(target)
 
         sch = tir.Schedule(func)
         block_infos = try_inline(sch, normalize_prim_func(sch))
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
new file mode 100644
index 0000000000..e66eaa3222
--- /dev/null
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -0,0 +1,366 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, invalid-name
+"""A GEMM schedule rule for GPU operators."""
+from enum import Enum
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Set, Tuple
+
+from tvm import tir
+from tvm.ir import Range
+from tvm.target import Target
+from tvm.tir import PrimExpr, Var, IterVar
+from tvm.tir.analysis import undefined_vars
+from tvm.tir.schedule.schedule import BlockRV
+
+from ..base import ScheduleRule, analysis
+
+
+def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV):
+    result = []
+    for producer in sch.get_producers(block):
+        result.append(producer)
+        result.extend(_collect_producers(sch, producer))
+    return result
+
+
+def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV):
+    result = []
+    for consumer in sch.get_consumers(block):
+        result.append(consumer)
+        result.extend(_collect_consumers(sch, consumer))
+    return result
+
+
+def auto_inline_producers(
+    sch: tir.Schedule,
+    block: tir.schedule.BlockRV,
+):
+    while True:
+        inlined_cnt = 0
+        producers = _collect_producers(sch, block)
+        for producer in producers:
+            try:
+                sch.compute_inline(producer)
+                inlined_cnt += 1
+            except:  # pylint: disable=bare-except
+                continue
+        if inlined_cnt == 0:
+            return
+
+
+def auto_inline_consumers(
+    sch: tir.Schedule,
+    block: tir.schedule.BlockRV,
+):
+    while True:
+        inlined_cnt = 0
+        consumers = _collect_consumers(sch, block)
+        for consumer in consumers:
+            try:
+                sch.compute_inline(consumer)
+                inlined_cnt += 1
+            except:  # pylint: disable=bare-except
+                continue
+        for consumer in consumers:
+            try:
+                sch.reverse_compute_inline(consumer)
+                inlined_cnt += 1
+            except:  # pylint: disable=bare-except
+                continue
+        if inlined_cnt == 0:
+            return
+
+
+class IterKind(Enum):
+    """Iter kinds for GEMM-liked programs.
+    We can simplify the computation to C[S, I, J] += A[S, I, K] * B[S, J, K],
+    where `I, J, K` are fundamental axes for gemm and `S` represents all
+    other spatial axes (e.g. batches)
+    kIter_S: spatial axes
+    kIter_I: I axes
+    kIter_J: J axes
+    kIter_K: K axes
+    kIter_T: trivial axes (i.e. with extent 1)
+    """
+
+    kIter_S = 0
+    kIter_I = 1
+    kIter_J = 2
+    kIter_K = 3
+    kIter_T = 4
+
+
+@dataclass
+class IterTrait:
+    kind: IterKind
+    extent: PrimExpr
+
+
+def _is_one(x: PrimExpr) -> bool:
+    return isinstance(x, tir.IntImm) and x.value == 1
+
+
+def make_iter_fusion_index_map(
+    traits: List[IterTrait],
+    kind_order: List[IterKind],
+) -> tir.IndexMap:
+    fused_iters: Dict[IterKind, PrimExpr] = {}
+    input_iters: List[tir.Var] = []
+    for i, trait in enumerate(traits):
+        v_i = tir.Var(f"i{i}", "int64")
+        input_iters.append(v_i)
+        if trait.kind == IterKind.kIter_T:
+            continue
+        if trait.kind not in kind_order:
+            raise ValueError(f"Unknown iter kind {trait.kind}")
+        if trait.kind in fused_iters:
+            fused_iters[trait.kind] = fused_iters[trait.kind] * trait.extent + v_i
+        else:
+            fused_iters[trait.kind] = v_i
+
+    final_indices: List[tir.PrimExpr] = [
+        fused_iters.get(kind, tir.IntImm("int64", 0)) for kind in kind_order
+    ]
+
+    return tir.IndexMap(input_iters, final_indices, None)
+
+
+def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
+    """Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K]
+
+    Parameters
+    ----------
+    block : tir.Block
+        The block to be analyzed
+
+    Returns
+    -------
+    traits : Optional[Tuple[List[IterTrait]]]
+        The detected iter traits for axes in A, B and C. None if the block
+        does not match the pattern.
+
+    """
+
+    if len(block.reads) != 2 or len(block.writes) != 1:
+        return None
+
+    def get_access_axes(region: List[Range]) -> Set[Var]:
+        axes: Set[Var] = set()
+        for r in region:
+            if not _is_one(r.extent):
+                raise ValueError("Expect elemwise block access")
+            axes = axes.union(set(undefined_vars(r.min)))
+        return axes
+
+    try:
+        A_axes = get_access_axes(block.reads[0].region)
+        B_axes = get_access_axes(block.reads[1].region)
+        C_axes = get_access_axes(block.writes[0].region)
+    except ValueError:
+        return None
+
+    traits: Dict[Var, IterTrait] = {}
+    for iter_var in block.iter_vars:
+        var = iter_var.var
+        kind: IterKind
+        if _is_one(iter_var.dom.extent):
+            kind = IterKind.kIter_T
+        elif iter_var.iter_type == iter_var.DataPar:
+            if var in A_axes and var in B_axes and var in C_axes:
+                kind = IterKind.kIter_S
+            elif var in A_axes and var in C_axes:
+                kind = IterKind.kIter_I
+            elif var in B_axes and var in C_axes:
+                kind = IterKind.kIter_J
+            else:
+                return None
+        elif iter_var.iter_type == tir.IterVar.CommReduce:
+            if var in A_axes and var in B_axes and var not in C_axes:
+                kind = IterKind.kIter_K
+            else:
+                return None
+        else:
+            return None
+        traits[var] = IterTrait(kind, iter_var.dom.extent)
+
+    # A GEMM kernel requires I, J, and K axes
+    gemm_traits = {IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K}
+    if {x.kind for x in traits.values()}.intersection(gemm_traits) != gemm_traits:
+        return None
+
+    A_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes]
+    B_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes]
+    C_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes]
+    block_traits = [traits[i.var] for i in block.iter_vars]
+    return A_traits, B_traits, C_traits, block_traits
+
+
+def get_index_map(block: tir.Block) -> Optional[Tuple[tir.IndexMap, ...]]:
+    """Get index maps for the block
+
+    Parameters
+    ----------
+    block : tir.Block
+        The block to be analyzed
+
+    Returns
+    -------
+    index_maps : Optional[Tuple[tir.IndexMap]]
+        The index maps for the block, or None if the block is not a GEMM-like kernel
+    """
+    traits = detect_iter_traits(block)
+    if traits is None:
+        return None
+    A_traits, B_traits, C_traits, block_traits = traits
+
+    A_index_map = make_iter_fusion_index_map(
+        A_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_K]
+    )
+    B_index_map = make_iter_fusion_index_map(
+        B_traits, [IterKind.kIter_S, IterKind.kIter_J, IterKind.kIter_K]
+    )
+    C_index_map = make_iter_fusion_index_map(
+        C_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J]
+    )
+    matmul_index_map = make_iter_fusion_index_map(
+        block_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K]
+    )
+
+    return (
+        matmul_index_map,
+        A_index_map,
+        B_index_map,
+        C_index_map,
+    )
+
+
+class Matmul(ScheduleRule):
+    """The schedule rule for matmul-like computation"""
+
+    def apply(  # pylint: disable=too-many-locals,missing-docstring
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> Optional[tir.Schedule]:
+        sch = tir.Schedule(func)
+        root_block = analysis.get_root_block(sch)
+        blocks = sch.get_child_blocks(root_block)
+
+        # Get the main computation block
+        def is_reduction(block: BlockRV) -> bool:
+            block_stmt = sch.get(block)
+            iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+            return iter_types == {IterVar.CommReduce, IterVar.DataPar}
+
+        def is_spatial(block: BlockRV) -> bool:
+            block_stmt = sch.get(block)
+            iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+            return iter_types == {IterVar.DataPar}
+
+        # NOTE: We assume there is only one reduction block in the function,
+        # and all blocks are required to be either spatial or reduction.
+        if not all([is_reduction(block) or is_spatial(block) for block in blocks]):
+            return None
+
+        # There is only one reduction block
+        reduction_blocks = [block for block in blocks if is_reduction(block)]
+        if len(reduction_blocks) != 1:
+            return None
+
+        main_block = reduction_blocks[0]
+        block_stmt = sch.get(main_block)
+        index_maps = get_index_map(block_stmt)
+        if index_maps is None:
+            return None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        # Start Schedule
+        # Step 0. Get schedule config.
+        # NOTE: we can analyze the config by the hardware spec in the future
+        block_size_x = 8
+        block_size_y = 16
+        vthread_x = 1
+        vthread_y = 1
+        micro_size_x = 2
+        micro_size_y = 4
+        micro_size_k = 16
+        vector_size = 2
+
+        # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
+        block = sch.reindex(main_block, ("read", 0))
+        sch.transform_layout(block, ("write", 0), a_index_map)
+        block = sch.reindex(main_block, ("read", 1))
+        sch.transform_layout(block, ("write", 0), b_index_map)
+        block = sch.reindex(main_block, ("write", 0))
+        sch.transform_layout(block, ("read", 0), c_index_map)
+        sch.transform_block_layout(main_block, matmul_index_map)
+
+        # Step 2. Padding for dynamic shape kernels
+        sch.pad_einsum(
+            main_block,
+            [
+                1,
+                vthread_x * block_size_x * micro_size_x,
+                vthread_y * block_size_y * micro_size_y,
+                micro_size_k,
+            ],
+        )
+
+        # Step 3. Schedule matmul
+        batch, x, y, k = sch.get_loops(main_block)
+        bx, vx, tx, xi = sch.split(x, [None, vthread_x, block_size_x, micro_size_x])
+        by, vy, ty, yi = sch.split(y, [None, vthread_y, block_size_y, micro_size_y])
+        ko, ki = sch.split(k, factors=[None, micro_size_k])
+        sch.reorder(bx, by, vy, vx, ty, tx, ko, ki, yi, xi)
+        sch.bind(batch, "blockIdx.z")
+        sch.bind(bx, "blockIdx.x")
+        sch.bind(by, "blockIdx.y")
+        sch.bind(vy, "vthread.y")
+        sch.bind(vx, "vthread.x")
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+        sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=256)
+        sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1)
+
+        l2g = sch.cache_write(main_block, 0, "local")
+        sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True)
+
+        def _cooperative_fetch(index, vec_len):
+            block = sch.cache_read(main_block, index, "shared")
+            num_loops = len(sch.get_loops(block))
+            sch.compute_at(block, ko, preserve_unit_loops=True)
+            loops = sch.get_loops(block)[-num_loops:]
+            _, ty, tx, vec = sch.split(
+                sch.fuse(*loops),
+                factors=[None, block_size_y, block_size_x, vec_len],
+            )
+            sch.vectorize(vec)
+            sch.bind(ty, "threadIdx.y")
+            sch.bind(tx, "threadIdx.x")
+            sch.storage_align(block, 0, axis=1, factor=32, offset=vec_len)
+            return block
+
+        a_g2s = _cooperative_fetch(0, vec_len=vector_size)
+        b_g2s = _cooperative_fetch(1, vec_len=vector_size)
+
+        auto_inline_producers(sch, a_g2s)
+        auto_inline_producers(sch, b_g2s)
+        auto_inline_consumers(sch, l2g)
+        sch.decompose_reduction(main_block, ko)
+        return sch
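
A hedged sketch of the iterator fusion performed by `get_index_map` above, using
the axis layout of the `TestMatmul` case added in this patch (a trivial batch axis
of extent 1, I = m, J = K = 4096). Importing `IterKind`, `IterTrait` and
`make_iter_fusion_index_map` directly from `tvm.dlight.gpu.matmul` assumes these
helpers stay module-level, as they are in this commit:

    from tvm import tir
    from tvm.dlight.gpu.matmul import IterKind, IterTrait, make_iter_fusion_index_map

    m = tir.Var("m", "int64")
    block_traits = [
        IterTrait(IterKind.kIter_T, tir.IntImm("int64", 1)),     # unit batch axis
        IterTrait(IterKind.kIter_I, m),                          # rows of A and C
        IterTrait(IterKind.kIter_J, tir.IntImm("int64", 4096)),  # columns of B and C
        IterTrait(IterKind.kIter_K, tir.IntImm("int64", 4096)),  # reduction axis
    ]
    index_map = make_iter_fusion_index_map(
        block_traits,
        [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K],
    )
    # Maps (i0, i1, i2, i3) -> (0, i1, i2, i3): the unit batch iterator is dropped,
    # the empty S position becomes a constant zero, and I, J, K keep their order.
    print(index_map)
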
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index bb2abc559d..c6b9ea6a3e 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -1215,7 +1215,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
 }
 
 /*!
- * \brief Detect the block iter type assoicated with the expression
+ * \brief Detect the block iter type associated with the expression
  *
  * This function collects block iters in the expression and check if the block iters have the same
  * iter type. The detected iter type is the iter type of the block iters in the expression
@@ -1387,13 +1387,18 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
   for (size_t i = 0; i < transformed_block_iters.size(); ++i) {
     Var new_block_var{"v" + std::to_string(i), transformed_block_iters[i]->dtype};
     new_block_vars.push_back(new_block_var);
-    IterVarType iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type);
+    IterVarType iter_type;
+    if (is_one(new_block_iter_range[i])) {
+      iter_type = kDataPar;
+    } else {
+      iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type);
+    }
     if (iter_type == kOpaque) {
       throw OpaqueNewIterTypeError(self->mod, GetRef<Block>(block_ptr), transformed_block_iters[i]);
     }
     auto dtype = new_block_var.dtype();
     new_block_iters.push_back(IterVar(
-        /*dom=*/Range::FromMinExtent(make_zero(dtype), new_block_iter_range[i]),
+        /*dom=*/Range::FromMinExtent(make_zero(dtype), cast(dtype, new_block_iter_range[i])),
         /*var=*/std::move(new_block_var), /*iter_type=*/iter_type));
   }
 
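
A hedged sketch of what the `TransformBlockLayout` change above enables: a
transformed block iterator whose extent is one (here, the constant leading
position left after dropping a unit batch axis) is now treated as data-parallel
instead of being rejected as opaque, which the matmul rule's
`transform_block_layout` call appears to rely on. The PrimFunc below is a
hypothetical minimal example, not taken from this patch:

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def func(A: T.Buffer((1, 16, 32), "float32"),
             B: T.Buffer((16, 32), "float32"),
             C: T.Buffer((1, 16, 16), "float32")):
        for b, i, j, k in T.grid(1, 16, 16, 32):
            with T.block("matmul"):
                vb, vi, vj, vk = T.axis.remap("SSSR", [b, i, j, k])
                with T.init():
                    C[vb, vi, vj] = T.float32(0)
                C[vb, vi, vj] = C[vb, vi, vj] + A[vb, vi, vk] * B[vj, vk]

    sch = tvm.tir.Schedule(func)
    block = sch.get_block("matmul")
    # The unit-extent `vb` maps to a constant zero in the leading output position;
    # with this change that position is treated as a data-parallel iter of extent 1.
    sch.transform_block_layout(block, lambda vb, vi, vj, vk: (0, vi, vj, vk))
    print(sch.mod.script())
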
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
new file mode 100644
index 0000000000..38ecbfee94
--- /dev/null
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -0,0 +1,252 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import pytest
+
+import tvm.testing
+from tvm import dlight as dl
+from tvm.ir import assert_structural_equal
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+    @pytest.fixture
+    def transform(self):
+        def transform(mod):
+            with Target("nvidia/geforce-rtx-3090-ti"):
+                return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
+
+        return transform
+
+
+class TestMatmul(BaseBeforeAfter):
+    # fmt: off
+    @T.prim_func
+    def before(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle):
+        m = T.int64()
+        inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
+        matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
+        for i0, i1, i2, k in T.grid(T.int64(1), m, T.int64(4096), T.int64(4096)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                with T.init():
+                    matmul[v_i0, v_i1, v_i2] = T.float32(0)
+                matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + inp0[v_i0, v_i1, v_k] * inp1[v_k, v_i2]
+
+    @T.prim_func
+    def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle):
+        T.func_attr({"tir.is_scheduled": 1})
+        m = T.int64()
+        inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
+        matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
+        # with T.block("root"):
+        matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="local")
+        inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="shared")
+        inp1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
+        for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+            for ax1_0 in T.thread_binding((m + T.int64(15)) // T.int64(16), thread="blockIdx.x"):
+                for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.y"):
+                    for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+                        for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
+                            for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                                    for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(2)):
+                                        with T.block("matmul_init"):
+                                            v0 = T.axis.spatial(T.int64(1), ax0)
+                                            v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3_init)
+                                            v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
+                                            T.reads()
+                                            T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
+                                            matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
+                                    for ax3_0 in range(T.int64(256)):
+                                        for ax0_ax1_ax2_fused_0 in range(T.int64(1)):
+                                            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                                for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                                    for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                        with T.block("inp0_reindex_pad_shared"):
+                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                            v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                            v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                            T.reads(inp0[v0, v1, v2])
+                                                            T.writes(inp0_reindex_pad_shared[v0, v1, v2])
+                                                            T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+                                                            inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
+                                        for ax0_ax1_ax2_fused_0 in range(T.int64(4)):
+                                            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                                for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                                    for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                        with T.block("inp1_reindex_shared"):
+                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                            v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                            v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                            T.reads(inp1[v2, v1])
+                                                            T.writes(inp1_reindex_shared[v0, v1, v2])
+                                                            T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+                                                            inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
+                                        for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(2)):
+                                            with T.block("matmul_update"):
+                                                v0 = T.axis.spatial(T.int64(1), ax0)
+                                                v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3)
+                                                v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
+                                                v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
+                                                T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0), v2, v3])
+                                                T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
+                                                matmul_reindex_pad_local[T.int64(0), v1, v2] = matmul_reindex_pad_local[T.int64(0), v1, v2] + inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0), v2, v3]
+                                    for ax0_1, ax1, ax2 in T.grid(T.int64(1), T.int64(2), T.int64(4)):
+                                        with T.block("matmul_reindex_pad_local"):
+                                            v0 = T.axis.spatial(T.int64(1), ax0_1)
+                                            v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + ax1_2 * T.int64(2) + ax1)
+                                            v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2)
+                                            T.reads(matmul_reindex_pad_local[v0, v1, v2])
+                                            T.writes(matmul[T.int64(0), v1, v2])
+                                            if v1 < m:
+                                                matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
+    # fmt: on
+
+
+class TestFusedMatmul(BaseBeforeAfter):
+    # fmt: off
+
+    @T.prim_func
+    def before(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32")):
+        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)))
+        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)))
+        for i, j in T.grid(T.int64(4096), T.int64(4096)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(W[v_i // T.int64(8), v_j], S[v_i // T.int64(32), v_j])
+                T.writes(var_decode_intermediate[v_i, v_j])
+                var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(4096), T.int64(4096)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(A[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
+                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(4096)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(C[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+                T.writes(Out[v_ax0, v_ax1, v_ax2])
+                Out[v_ax0, v_ax1, v_ax2] = C[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+    @T.prim_func
+    def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32")):
+        T.func_attr({"tir.is_scheduled": 1})
+        var_matmul_intermediate_reindex_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local")
+        A_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared")
+        var_decode_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
+        for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+            for ax1_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"):
+                for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.y"):
+                    for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+                        for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
+                            for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                                    for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(2)):
+                                        with T.block("matmul_init"):
+                                            v0 = T.axis.spatial(T.int64(1), ax0)
+                                            v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3_init)
+                                            v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
+                                            T.reads()
+                                            T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
+                                            var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = T.float32(0)
+                                    for ax3_0 in range(T.int64(256)):
+                                        for ax0_ax1_ax2_fused_0 in range(T.int64(1)):
+                                            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                                for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                                    for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                        with T.block("A_reindex_shared"):
+                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                            v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                            v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                            T.reads(A[v0, v1, v2])
+                                                            T.writes(A_reindex_shared[v0, v1, v2])
+                                                            T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+                                                            A_reindex_shared[v0, v1, v2] = A[v0, v1, v2]
+                                        for ax0_ax1_ax2_fused_0 in range(T.int64(4)):
+                                            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                                for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                                    for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                        with T.block("var_decode_intermediate_reindex_shared"):
+                                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                            v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                            v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                            T.reads(W[v2 // T.int64(8), v1], S[v2 // T.int64(32), v1])
+                                                            T.writes(var_decode_intermediate_reindex_shared[v0, v1, v2])
+                                                            T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+                                                            var_decode_intermediate_reindex_shared[v0, v1, v2] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v2 // T.int64(8), v1], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v2 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v2 // T.int64(32), v1], T.uint32(16)), T.uint32(65535)), T.ui [...]
+                                        for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(2)):
+                                            with T.block("matmul_update"):
+                                                v0 = T.axis.spatial(T.int64(1), ax0)
+                                                v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3)
+                                                v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
+                                                v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
+                                                T.reads(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2], A_reindex_shared[T.int64(0), v1, v3], var_decode_intermediate_reindex_shared[T.int64(0), v2, v3])
+                                                T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
+                                                var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] + A_reindex_shared[T.int64(0), v1, v3] * var_decode_intermediate_reindex_shared[T.int64(0), v2, v3]
+                                    for ax0_1, ax1, ax2 in T.grid(T.int64(1), T.int64(2), T.int64(4)):
+                                        with T.block("var_matmul_intermediate_reindex_local"):
+                                            v0 = T.axis.spatial(T.int64(1), ax0_1)
+                                            v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + ax1_2 * T.int64(2) + ax1)
+                                            v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2)
+                                            T.reads(C[T.int64(0), v1, v2], var_matmul_intermediate_reindex_local[v0, v1, v2])
+                                            T.writes(Out[T.int64(0), v1, v2])
+                                            Out[T.int64(0), v1, v2] = C[T.int64(0), v1, v2] + var_matmul_intermediate_reindex_local[v0, v1, v2]
+    # fmt: on
+
+
+class TestSkipGEMV(BaseBeforeAfter):
+    # fmt: off
+
+    @T.prim_func
+    def before(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)))
+        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)))
+        for i, j in T.grid(T.int64(4096), T.int64(4096)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(W[v_i // T.int64(8), v_j], S[v_i // T.int64(32), v_j])
+                T.writes(var_decode_intermediate[v_i, v_j])
+                var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(A[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
+                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(C[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+                T.writes(Out[v_ax0, v_ax1, v_ax2])
+                Out[v_ax0, v_ax1, v_ax2] = C[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+    # fmt: on
+
+    expected = before
+
+
+if __name__ == "__main__":
+    tvm.testing.main()