You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/07/02 17:54:08 UTC
[tvm] branch unity updated: [Unity][Dlight] Matmul Rules (#15191)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 918fc4ecf7 [Unity][Dlight] Matmul Rules (#15191)
918fc4ecf7 is described below
commit 918fc4ecf7803bdf89466f82e0b634951a47f64c
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Mon Jul 3 01:54:02 2023 +0800
[Unity][Dlight] Matmul Rules (#15191)
This PR introduces default schedule rules for matmul kernels. Note that
we skip GEMV-like kernels as they will be handled by a separate rule.
---
python/tvm/dlight/base/analysis.py | 40 +++
python/tvm/dlight/base/transform.py | 2 +-
python/tvm/dlight/gpu/__init__.py | 1 +
python/tvm/dlight/gpu/fallback.py | 14 +-
python/tvm/dlight/gpu/matmul.py | 366 +++++++++++++++++++++
.../schedule/primitive/layout_transformation.cc | 11 +-
tests/python/dlight/test_gpu_matmul.py | 252 ++++++++++++++
7 files changed, 670 insertions(+), 16 deletions(-)
diff --git a/python/tvm/dlight/base/analysis.py b/python/tvm/dlight/base/analysis.py
index a5d70c6c0c..d11e29a8ad 100644
--- a/python/tvm/dlight/base/analysis.py
+++ b/python/tvm/dlight/base/analysis.py
@@ -21,6 +21,9 @@ from typing_extensions import Literal
from tvm import tir
from tvm._ffi import get_global_func
+from tvm.target.target import Target
+from tvm.tir import Schedule
+from tvm.tir.schedule import BlockRV
class IterInfo:
@@ -146,3 +149,40 @@ def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]:
)
)
return blocks
+
+
+def _assert_gpu_target(target: Target):
+ if "gpu" not in target.keys:
+ raise ValueError(f"Expect a GPU target, but got {target}")
+
+
+def get_max_threads_per_block(target: Target) -> int:
+ _assert_gpu_target(target)
+ max_threads_per_block = None
+ for name in ["max_threads_per_block", "max_num_threads"]:
+ if max_threads_per_block is None:
+ max_threads_per_block = target.attrs.get(name, None)
+ if max_threads_per_block is None:
+ max_threads_per_block = 64
+ return int(max_threads_per_block)
+
+
+def get_max_shared_memory_per_block(target: Target) -> int:
+ _assert_gpu_target(target)
+ max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None)
+ if max_shared_memory_per_block is None:
+ raise ValueError(
+ f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually"
+ )
+ return int(max_shared_memory_per_block)
+
+
+def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
+ try:
+ block = sch.mod[func_name].body.block
+ except:
+ raise ValueError(
+ f"The function body is expected to be the root block, but got:\n"
+ f"{sch.mod[func_name].body}"
+ )
+ return sch.get_block(block.name_hint)
diff --git a/python/tvm/dlight/base/transform.py b/python/tvm/dlight/base/transform.py
index c11a4ae060..dc02c3dc0f 100644
--- a/python/tvm/dlight/base/transform.py
+++ b/python/tvm/dlight/base/transform.py
@@ -60,7 +60,7 @@ class ApplyDefaultSchedule: # pylint: disable=too-few-public-methods
target = Target.current(allow_none=False)
updated_functions = {}
for g_var, func in mod.functions.items():
- if not _is_scheduled(func):
+ if isinstance(func, tir.PrimFunc) and not _is_scheduled(func):
sch = _apply_rules(func, target, self.rules, tunable=False)
if sch is not None:
assert len(sch) == 1
diff --git a/python/tvm/dlight/gpu/__init__.py b/python/tvm/dlight/gpu/__init__.py
index b689bef381..79090d400b 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -21,3 +21,4 @@ For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/meta
from .fallback import Fallback
from .decode_gemv import DecodeGEMV
from .reduction import Reduction
+from .matmul import Matmul
diff --git a/python/tvm/dlight/gpu/fallback.py b/python/tvm/dlight/gpu/fallback.py
index 63033aa7c7..6b120b1648 100644
--- a/python/tvm/dlight/gpu/fallback.py
+++ b/python/tvm/dlight/gpu/fallback.py
@@ -21,17 +21,7 @@ from typing import List
from tvm import tir
from tvm.target import Target
-from ..base import ScheduleRule, normalize_prim_func, try_inline
-
-
-def _max_threads_per_block(target: Target) -> int:
- max_threads_per_block = None
- for name in ["max_threads_per_block", "max_num_threads"]:
- if max_threads_per_block is None:
- max_threads_per_block = target.attrs.get(name, None)
- if max_threads_per_block is None:
- max_threads_per_block = 64
- return int(max_threads_per_block)
+from ..base import ScheduleRule, analysis, normalize_prim_func, try_inline
class Fallback(ScheduleRule):
@@ -46,7 +36,7 @@ class Fallback(ScheduleRule):
target: Target,
_: bool,
) -> tir.Schedule:
- max_threads_per_block = _max_threads_per_block(target)
+ max_threads_per_block = analysis.get_max_threads_per_block(target)
sch = tir.Schedule(func)
block_infos = try_inline(sch, normalize_prim_func(sch))
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
new file mode 100644
index 0000000000..e66eaa3222
--- /dev/null
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -0,0 +1,366 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, invalid-name
+"""A GEMM schedule rule for GPU operators."""
+from enum import Enum
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Set, Tuple
+
+from tvm import tir
+from tvm.ir import Range
+from tvm.target import Target
+from tvm.tir import PrimExpr, Var, IterVar
+from tvm.tir.analysis import undefined_vars
+from tvm.tir.schedule.schedule import BlockRV
+
+from ..base import ScheduleRule, analysis
+
+
+def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV):
+ result = []
+ for producer in sch.get_producers(block):
+ result.append(producer)
+ result.extend(_collect_producers(sch, producer))
+ return result
+
+
+def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV):
+ result = []
+ for consumer in sch.get_consumers(block):
+ result.append(consumer)
+ result.extend(_collect_consumers(sch, consumer))
+ return result
+
+
+def auto_inline_producers(
+ sch: tir.Schedule,
+ block: tir.schedule.BlockRV,
+):
+ while True:
+ inlined_cnt = 0
+ producers = _collect_producers(sch, block)
+ for producer in producers:
+ try:
+ sch.compute_inline(producer)
+ inlined_cnt += 1
+ except: # pylint: disable=bare-except
+ continue
+ if inlined_cnt == 0:
+ return
+
+
+def auto_inline_consumers(
+ sch: tir.Schedule,
+ block: tir.schedule.BlockRV,
+):
+ while True:
+ inlined_cnt = 0
+ consumers = _collect_consumers(sch, block)
+ for consumer in consumers:
+ try:
+ sch.compute_inline(consumer)
+ inlined_cnt += 1
+ except: # pylint: disable=bare-except
+ continue
+ for consumer in consumers:
+ try:
+ sch.reverse_compute_inline(consumer)
+ inlined_cnt += 1
+ except: # pylint: disable=bare-except
+ continue
+ if inlined_cnt == 0:
+ return
+
+
+class IterKind(Enum):
+ """Iter kinds for GEMM-like programs.
+ We can simplify the computation to C[S, I, J] += A[S, I, K] * B[S, J, K],
+ where `I, J, K` are fundamental axes for gemm and `S` represents all
+ other spatial axes (e.g. batches)
+ kIter_S: spatial axes
+ kIter_I: I axes
+ kIter_J: J axes
+ kIter_K: K axes
+ kIter_T: trivial axes (i.e. with extent 1)
+ """
+
+ kIter_S = 0
+ kIter_I = 1
+ kIter_J = 2
+ kIter_K = 3
+ kIter_T = 4
+
+
+@dataclass
+class IterTrait:
+ kind: IterKind
+ extent: PrimExpr
+
+
+def _is_one(x: PrimExpr) -> bool:
+ return isinstance(x, tir.IntImm) and x.value == 1
+
+
+def make_iter_fusion_index_map(
+ traits: List[IterTrait],
+ kind_order: List[IterKind],
+) -> tir.IndexMap:
+ fused_iters: Dict[IterKind, PrimExpr] = {}
+ input_iters: List[tir.Var] = []
+ for i, trait in enumerate(traits):
+ v_i = tir.Var(f"i{i}", "int64")
+ input_iters.append(v_i)
+ if trait.kind == IterKind.kIter_T:
+ continue
+ if trait.kind not in kind_order:
+ raise ValueError(f"Unknown iter kind {trait.kind}")
+ if trait.kind in fused_iters:
+ fused_iters[trait.kind] = fused_iters[trait.kind] * trait.extent + v_i
+ else:
+ fused_iters[trait.kind] = v_i
+
+ final_indices: List[tir.PrimExpr] = [
+ fused_iters.get(kind, tir.IntImm("int64", 0)) for kind in kind_order
+ ]
+
+ return tir.IndexMap(input_iters, final_indices, None)
+
+
+def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
+ """Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K]
+
+ Parameters
+ ----------
+ block : tir.Block
+ The block to be analyzed
+
+ Returns
+ -------
+ traits : Optional[Tuple[List[IterTrait]]]
+ The detected iter traits for axes in A, B and C. None if the block
+ does not match the pattern.
+
+ """
+
+ if len(block.reads) != 2 or len(block.writes) != 1:
+ return None
+
+ def get_access_axes(region: List[Range]) -> Set[Var]:
+ axes: Set[Var] = set()
+ for r in region:
+ if not _is_one(r.extent):
+ raise ValueError("Expect elemwise block access")
+ axes = axes.union(set(undefined_vars(r.min)))
+ return axes
+
+ try:
+ A_axes = get_access_axes(block.reads[0].region)
+ B_axes = get_access_axes(block.reads[1].region)
+ C_axes = get_access_axes(block.writes[0].region)
+ except ValueError:
+ return None
+
+ traits: Dict[Var, IterTrait] = {}
+ for iter_var in block.iter_vars:
+ var = iter_var.var
+ kind: IterKind
+ if _is_one(iter_var.dom.extent):
+ kind = IterKind.kIter_T
+ elif iter_var.iter_type == iter_var.DataPar:
+ if var in A_axes and var in B_axes and var in C_axes:
+ kind = IterKind.kIter_S
+ elif var in A_axes and var in C_axes:
+ kind = IterKind.kIter_I
+ elif var in B_axes and var in C_axes:
+ kind = IterKind.kIter_J
+ else:
+ return None
+ elif iter_var.iter_type == tir.IterVar.CommReduce:
+ if var in A_axes and var in B_axes and var not in C_axes:
+ kind = IterKind.kIter_K
+ else:
+ return None
+ else:
+ return None
+ traits[var] = IterTrait(kind, iter_var.dom.extent)
+
+ # A GEMM kernel requires I, J and K axes
+ gemm_traits = {IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K}
+ if {x.kind for x in traits.values()}.intersection(gemm_traits) != gemm_traits:
+ return None
+
+ A_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes]
+ B_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes]
+ C_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes]
+ block_traits = [traits[i.var] for i in block.iter_vars]
+ return A_traits, B_traits, C_traits, block_traits
+
+
+def get_index_map(block: tir.Block) -> Optional[Tuple[tir.IndexMap, ...]]:
+ """Get index maps for the block
+
+ Parameters
+ ----------
+ block : tir.Block
+ The block to be analyzed
+
+ Returns
+ -------
+ index_maps : Optional[Tuple[tir.IndexMap]]
+ The index maps for the block, or None if the block is not a gemm-liked kernel
+ """
+ traits = detect_iter_traits(block)
+ if traits is None:
+ return None
+ A_traits, B_traits, C_traits, block_traits = traits
+
+ A_index_map = make_iter_fusion_index_map(
+ A_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_K]
+ )
+ B_index_map = make_iter_fusion_index_map(
+ B_traits, [IterKind.kIter_S, IterKind.kIter_J, IterKind.kIter_K]
+ )
+ C_index_map = make_iter_fusion_index_map(
+ C_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J]
+ )
+ matmul_index_map = make_iter_fusion_index_map(
+ block_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K]
+ )
+
+ return (
+ matmul_index_map,
+ A_index_map,
+ B_index_map,
+ C_index_map,
+ )
+
+
+class Matmul(ScheduleRule):
+ """The schedule rule for matmul-like computation"""
+
+ def apply( # pylint: disable=too-many-locals,missing-docstring
+ self,
+ func: tir.PrimFunc,
+ target: Target,
+ _: bool,
+ ) -> Optional[tir.Schedule]:
+ sch = tir.Schedule(func)
+ root_block = analysis.get_root_block(sch)
+ blocks = sch.get_child_blocks(root_block)
+
+ # Get the main computation block
+ def is_reduction(block: BlockRV) -> bool:
+ block_stmt = sch.get(block)
+ iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+ return iter_types == {IterVar.CommReduce, IterVar.DataPar}
+
+ def is_spatial(block: BlockRV) -> bool:
+ block_stmt = sch.get(block)
+ iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+ return iter_types == {IterVar.DataPar}
+
+ # NOTE: We assume there is only one reduction block in the function
+ # all blocks are required to be spatial or reduction
+ if not all([is_reduction(block) or is_spatial(block) for block in blocks]):
+ return None
+
+ # There is only one reduction block
+ reduction_blocks = [block for block in blocks if is_reduction(block)]
+ if len(reduction_blocks) != 1:
+ return None
+
+ main_block = reduction_blocks[0]
+ block_stmt = sch.get(main_block)
+ index_maps = get_index_map(block_stmt)
+ if index_maps is None:
+ return None
+ matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+ # Start Schedule
+ # Step 0. Get schedule config.
+ # NOTE: we can analyze the config by the hardware spec in the future
+ block_size_x = 8
+ block_size_y = 16
+ vthread_x = 1
+ vthread_y = 1
+ micro_size_x = 2
+ micro_size_y = 4
+ micro_size_k = 16
+ vector_size = 2
+
+ # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
+ block = sch.reindex(main_block, ("read", 0))
+ sch.transform_layout(block, ("write", 0), a_index_map)
+ block = sch.reindex(main_block, ("read", 1))
+ sch.transform_layout(block, ("write", 0), b_index_map)
+ block = sch.reindex(main_block, ("write", 0))
+ sch.transform_layout(block, ("read", 0), c_index_map)
+ sch.transform_block_layout(main_block, matmul_index_map)
+
+ # Step 2. Padding for dynamic shape kernels
+ sch.pad_einsum(
+ main_block,
+ [
+ 1,
+ vthread_x * block_size_x * micro_size_x,
+ vthread_y * block_size_y * micro_size_y,
+ micro_size_k,
+ ],
+ )
+
+ # Step 3. Schedule matmul
+ batch, x, y, k = sch.get_loops(main_block)
+ bx, vx, tx, xi = sch.split(x, [None, vthread_x, block_size_x, micro_size_x])
+ by, vy, ty, yi = sch.split(y, [None, vthread_y, block_size_y, micro_size_y])
+ ko, ki = sch.split(k, factors=[None, micro_size_k])
+ sch.reorder(bx, by, vy, vx, ty, tx, ko, ki, yi, xi)
+ sch.bind(batch, "blockIdx.z")
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(by, "blockIdx.y")
+ sch.bind(vy, "vthread.y")
+ sch.bind(vx, "vthread.x")
+ sch.bind(ty, "threadIdx.y")
+ sch.bind(tx, "threadIdx.x")
+ sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=256)
+ sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1)
+
+ l2g = sch.cache_write(main_block, 0, "local")
+ sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True)
+
+ def _cooperative_fetch(index, vec_len):
+ block = sch.cache_read(main_block, index, "shared")
+ num_loops = len(sch.get_loops(block))
+ sch.compute_at(block, ko, preserve_unit_loops=True)
+ loops = sch.get_loops(block)[-num_loops:]
+ _, ty, tx, vec = sch.split(
+ sch.fuse(*loops),
+ factors=[None, block_size_y, block_size_x, vec_len],
+ )
+ sch.vectorize(vec)
+ sch.bind(ty, "threadIdx.y")
+ sch.bind(tx, "threadIdx.x")
+ sch.storage_align(block, 0, axis=1, factor=32, offset=vec_len)
+ return block
+
+ a_g2s = _cooperative_fetch(0, vec_len=vector_size)
+ b_g2s = _cooperative_fetch(1, vec_len=vector_size)
+
+ auto_inline_producers(sch, a_g2s)
+ auto_inline_producers(sch, b_g2s)
+ auto_inline_consumers(sch, l2g)
+ sch.decompose_reduction(main_block, ko)
+ return sch
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index bb2abc559d..c6b9ea6a3e 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -1215,7 +1215,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
}
/*!
- * \brief Detect the block iter type assoicated with the expression
+ * \brief Detect the block iter type associated with the expression
*
* This function collects block iters in the expression and check if the block iters have the same
* iter type. The detected iter type is the iter type of the block iters in the expression
@@ -1387,13 +1387,18 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
for (size_t i = 0; i < transformed_block_iters.size(); ++i) {
Var new_block_var{"v" + std::to_string(i), transformed_block_iters[i]->dtype};
new_block_vars.push_back(new_block_var);
- IterVarType iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type);
+ IterVarType iter_type;
+ if (is_one(new_block_iter_range[i])) {
+ iter_type = kDataPar;
+ } else {
+ iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type);
+ }
if (iter_type == kOpaque) {
throw OpaqueNewIterTypeError(self->mod, GetRef<Block>(block_ptr), transformed_block_iters[i]);
}
auto dtype = new_block_var.dtype();
new_block_iters.push_back(IterVar(
- /*dom=*/Range::FromMinExtent(make_zero(dtype), new_block_iter_range[i]),
+ /*dom=*/Range::FromMinExtent(make_zero(dtype), cast(dtype, new_block_iter_range[i])),
/*var=*/std::move(new_block_var), /*iter_type=*/iter_type));
}
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
new file mode 100644
index 0000000000..38ecbfee94
--- /dev/null
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -0,0 +1,252 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import pytest
+
+import tvm.testing
+from tvm import dlight as dl
+from tvm.ir import assert_structural_equal
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+ @pytest.fixture
+ def transform(self):
+ def transform(mod):
+ with Target("nvidia/geforce-rtx-3090-ti"):
+ return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
+
+ return transform
+
+
+class TestMatmul(BaseBeforeAfter):
+ # fmt: off
+ @T.prim_func
+ def before(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle):
+ m = T.int64()
+ inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
+ matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
+ for i0, i1, i2, k in T.grid(T.int64(1), m, T.int64(4096), T.int64(4096)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ with T.init():
+ matmul[v_i0, v_i1, v_i2] = T.float32(0)
+ matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + inp0[v_i0, v_i1, v_k] * inp1[v_k, v_i2]
+
+ @T.prim_func
+ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle):
+ T.func_attr({"tir.is_scheduled": 1})
+ m = T.int64()
+ inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
+ matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
+ # with T.block("root"):
+ matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="local")
+ inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="shared")
+ inp1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
+ for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+ for ax1_0 in T.thread_binding((m + T.int64(15)) // T.int64(16), thread="blockIdx.x"):
+ for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.y"):
+ for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+ for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
+ for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(2)):
+ with T.block("matmul_init"):
+ v0 = T.axis.spatial(T.int64(1), ax0)
+ v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3_init)
+ v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
+ T.reads()
+ T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
+ matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
+ for ax3_0 in range(T.int64(256)):
+ for ax0_ax1_ax2_fused_0 in range(T.int64(1)):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+ with T.block("inp0_reindex_pad_shared"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(inp0[v0, v1, v2])
+ T.writes(inp0_reindex_pad_shared[v0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+ inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
+ for ax0_ax1_ax2_fused_0 in range(T.int64(4)):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+ with T.block("inp1_reindex_shared"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(inp1[v2, v1])
+ T.writes(inp1_reindex_shared[v0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+ inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
+ for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(2)):
+ with T.block("matmul_update"):
+ v0 = T.axis.spatial(T.int64(1), ax0)
+ v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3)
+ v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
+ v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
+ T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0), v2, v3])
+ T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
+ matmul_reindex_pad_local[T.int64(0), v1, v2] = matmul_reindex_pad_local[T.int64(0), v1, v2] + inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0), v2, v3]
+ for ax0_1, ax1, ax2 in T.grid(T.int64(1), T.int64(2), T.int64(4)):
+ with T.block("matmul_reindex_pad_local"):
+ v0 = T.axis.spatial(T.int64(1), ax0_1)
+ v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + ax1_2 * T.int64(2) + ax1)
+ v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2)
+ T.reads(matmul_reindex_pad_local[v0, v1, v2])
+ T.writes(matmul[T.int64(0), v1, v2])
+ if v1 < m:
+ matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
+ # fmt: on
+
+
+class TestFusedMatmul(BaseBeforeAfter):
+ # fmt: off
+
+ @T.prim_func
+ def before(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32")):
+ var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)))
+ var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)))
+ for i, j in T.grid(T.int64(4096), T.int64(4096)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(W[v_i // T.int64(8), v_j], S[v_i // T.int64(32), v_j])
+ T.writes(var_decode_intermediate[v_i, v_j])
+ var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(4096), T.int64(4096)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(A[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+ with T.init():
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(4096)):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(C[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+ T.writes(Out[v_ax0, v_ax1, v_ax2])
+ Out[v_ax0, v_ax1, v_ax2] = C[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+ @T.prim_func
+ def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32")):
+ T.func_attr({"tir.is_scheduled": 1})
+ var_matmul_intermediate_reindex_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local")
+ A_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared")
+ var_decode_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
+ for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+ for ax1_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"):
+ for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.y"):
+ for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+ for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
+ for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(2)):
+ with T.block("matmul_init"):
+ v0 = T.axis.spatial(T.int64(1), ax0)
+ v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3_init)
+ v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
+ T.reads()
+ T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
+ var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = T.float32(0)
+ for ax3_0 in range(T.int64(256)):
+ for ax0_ax1_ax2_fused_0 in range(T.int64(1)):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+ with T.block("A_reindex_shared"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(A[v0, v1, v2])
+ T.writes(A_reindex_shared[v0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+ A_reindex_shared[v0, v1, v2] = A[v0, v1, v2]
+ for ax0_ax1_ax2_fused_0 in range(T.int64(4)):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+ with T.block("var_decode_intermediate_reindex_shared"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(W[v2 // T.int64(8), v1], S[v2 // T.int64(32), v1])
+ T.writes(var_decode_intermediate_reindex_shared[v0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+ var_decode_intermediate_reindex_shared[v0, v1, v2] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v2 // T.int64(8), v1], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v2 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v2 // T.int64(32), v1], T.uint32(16)), T.uint32(65535)), T.ui [...]
+ for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(2)):
+ with T.block("matmul_update"):
+ v0 = T.axis.spatial(T.int64(1), ax0)
+ v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3)
+ v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
+ v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
+ T.reads(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2], A_reindex_shared[T.int64(0), v1, v3], var_decode_intermediate_reindex_shared[T.int64(0), v2, v3])
+ T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
+ var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] + A_reindex_shared[T.int64(0), v1, v3] * var_decode_intermediate_reindex_shared[T.int64(0), v2, v3]
+ for ax0_1, ax1, ax2 in T.grid(T.int64(1), T.int64(2), T.int64(4)):
+ with T.block("var_matmul_intermediate_reindex_local"):
+ v0 = T.axis.spatial(T.int64(1), ax0_1)
+ v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + ax1_2 * T.int64(2) + ax1)
+ v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2)
+ T.reads(C[T.int64(0), v1, v2], var_matmul_intermediate_reindex_local[v0, v1, v2])
+ T.writes(Out[T.int64(0), v1, v2])
+ Out[T.int64(0), v1, v2] = C[T.int64(0), v1, v2] + var_matmul_intermediate_reindex_local[v0, v1, v2]
+ # fmt: on
+
+
+class TestSkipGEMV(BaseBeforeAfter):
+ # fmt: off
+
+ @T.prim_func
+ def before(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)))
+ var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)))
+ for i, j in T.grid(T.int64(4096), T.int64(4096)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(W[v_i // T.int64(8), v_j], S[v_i // T.int64(32), v_j])
+ T.writes(var_decode_intermediate[v_i, v_j])
+ var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(A[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+ with T.init():
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(C[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+ T.writes(Out[v_ax0, v_ax1, v_ax2])
+ Out[v_ax0, v_ax1, v_ax2] = C[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+ # fmt: on
+
+ expected = before
+
+
+if __name__ == "__main__":
+ tvm.testing.main()