Posted to commits@tvm.apache.org by wu...@apache.org on 2023/09/27 23:02:26 UTC
[tvm] branch unity updated: [Unity][BYOC] Support offloading multi-query attention by Flash Attention (#15831)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7494bc4580 [Unity][BYOC] Support offloading multi-query attention by Flash Attention (#15831)
7494bc4580 is described below
commit 7494bc4580054e0a783d74c3d60fb0c177d1de23
Author: masahi <ma...@gmail.com>
AuthorDate: Thu Sep 28 08:02:16 2023 +0900
[Unity][BYOC] Support offloading multi-query attention by Flash Attention (#15831)
* Squashed commit of the following:
commit 99c2a59a1226f372c50c347c961d0c1201680a3e
Author: Masahiro Masuda <ma...@gmail.com>
Date: Wed Sep 27 09:57:21 2023 +0900
Revert "Revert "[Unity] Avoid trivial `var2 = var1` bindings in pattern matcher (#15578)""
This reverts commit 0a6a617e1315f3bc1550e5dc0e4630495e7fe70d.
commit 9a3ca64cfa2152628f5704a68383d36949403900
Author: Masahiro Masuda <ma...@gmail.com>
Date: Wed Sep 27 09:55:02 2023 +0900
wip
commit be01900d59db94bbccc3d8142d95c302dade7ca2
Author: Masahiro Masuda <ma...@gmail.com>
Date: Tue Sep 26 19:55:29 2023 +0900
fix test
commit a026b650002b07833808d078ede41243796f9a95
Author: Masahiro Masuda <ma...@gmail.com>
Date: Thu Aug 31 22:24:38 2023 +0000
wip
commit 233d2d0fa7bb1a981f792645e8394d95e8d31cb4
Author: Masahiro Masuda <ma...@gmail.com>
Date: Tue Aug 29 17:42:11 2023 +0000
wip
commit 0a6a617e1315f3bc1550e5dc0e4630495e7fe70d
Author: Masahiro Masuda <ma...@gmail.com>
Date: Tue Aug 29 17:28:25 2023 +0000
Revert "[Unity] Avoid trivial `var2 = var1` bindings in pattern matcher (#15578)"
This reverts commit 567848e3a08a3bcb1ed69344050bc648a101d9b9.
commit 6c5a4355e4cd487434c598909b7338159161624e
Author: Masahiro Masuda <ma...@gmail.com>
Date: Tue Aug 29 17:28:16 2023 +0000
wip
commit 7926cbc9d890c0c07376176a26144b8603bb9732
Author: Masahiro Masuda <ma...@gmail.com>
Date: Mon Aug 28 06:17:01 2023 +0000
wip
commit 9828698ca3d808da8a77a432686c8be5dd4dab38
Author: Masahiro Masuda <ma...@gmail.com>
Date: Mon Aug 28 15:11:47 2023 +0900
wip
commit 5d01fd1310fd5df98bf5fd56986056e20352ad3d
Author: Masahiro Masuda <ma...@gmail.com>
Date: Mon Aug 28 06:05:56 2023 +0000
wip
commit ae657b7aed678fa7f7727aebcc9221940e45de26
Author: Masahiro Masuda <ma...@gmail.com>
Date: Mon Aug 28 14:49:21 2023 +0900
wip
commit ddcab3887fef5c851689714f2f3924201165591d
Author: Masahiro Masuda <ma...@gmail.com>
Date: Mon Aug 28 05:42:41 2023 +0000
wip
commit ab3572d852e21af3d4b349afd999654f491dcee8
Author: Masahiro Masuda <ma...@gmail.com>
Date: Mon Aug 28 10:40:34 2023 +0900
wip
commit 690b88ef2380fc3ab5e3e02fed61cdf2936e0811
Author: Masahiro Masuda <ma...@gmail.com>
Date: Mon Aug 28 10:25:33 2023 +0900
update rev
* black
* clean
* add doc
* update rev
* update test
* fix
---------
Co-authored-by: Masahiro Masuda <ma...@MasahironoMacBook-Pro.local>
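
For context, a minimal usage sketch of the offload path this commit enables, adapted from the test added at the bottom of this diff (the module and the sm value are illustrative; `Module` stands for a Relax IRModule in which K and V have a single head broadcast to all query heads with R.repeat, as in test_attention_rewrite_multi_query below):

    # Sketch: offloading multi-query attention to Flash Attention via the CUTLASS BYOC path.
    from tvm import relax
    from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

    # use_flash_mqa=True enables the rewrite pattern that matches R.repeat on K/V,
    # so the broadcast K/V tensors are never materialized before the fused kernel.
    mod = partition_for_cutlass(Module, use_flash_mqa=True)
    mod = relax.transform.RunCodegen({"cutlass": {"sm": 80}})(mod)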
---
3rdparty/libflash_attn | 2 +-
python/tvm/contrib/cutlass/attention_operation.py | 28 ++++++-------
python/tvm/contrib/cutlass/build.py | 7 ++--
python/tvm/contrib/cutlass/gen_tensor_op.py | 23 +++++++---
python/tvm/relax/backend/contrib/cutlass.py | 13 +++++-
python/tvm/relax/backend/patterns.py | 15 +++++--
src/relax/op/nn/attention.cc | 13 +++++-
tests/python/relax/test_codegen_cutlass.py | 51 ++++++++++++++++++++++-
8 files changed, 120 insertions(+), 32 deletions(-)
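
The key layout change in the Flash Attention template below is that the Q and O row strides now scale with the number of query heads, while the K and V row strides scale with the (possibly smaller) number of KV heads. A worked example of the stride arithmetic, assuming packed BSNH tensors and the shapes from the new test (head_dim=16, 32 query heads, 1 KV head, 16 queries and keys):

    # Illustrative row/batch stride arithmetic for the multi-query case.
    head_dim, num_q_heads, num_kv_heads = 16, 32, 1
    q_row_stride = head_dim * num_q_heads   # 512: a (batch, seq) row packs all query heads
    k_row_stride = head_dim * num_kv_heads  # 16:  K rows pack only the KV heads
    v_row_stride = head_dim * num_kv_heads  # 16
    o_row_stride = head_dim * num_q_heads   # 512: the output keeps one vector per query head
    num_queries = num_keys = 16
    q_batch_stride = q_row_stride * num_queries  # 8192
    k_batch_stride = k_row_stride * num_keys     # 256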
diff --git a/3rdparty/libflash_attn b/3rdparty/libflash_attn
index 58b343e575..63cce0ca8f 160000
--- a/3rdparty/libflash_attn
+++ b/3rdparty/libflash_attn
@@ -1 +1 @@
-Subproject commit 58b343e57571fe5e0a5b43b5eb721acef8b35dff
+Subproject commit 63cce0ca8fa6bfca1982b342588273641cc5b86b
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 67a68df442..e59dbf032e 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -169,10 +169,10 @@ def instantiate_flash_attention_template(attrs):
int k_head_stride = ${head_dim};
int v_head_stride = ${head_dim};
int o_head_stride = ${head_dim};
- int q_row_stride = q_head_stride * ${num_heads};
- int k_row_stride = k_head_stride * ${num_heads};
- int v_row_stride = v_head_stride * ${num_heads};
- int o_row_stride = o_head_stride * ${num_heads};
+ int q_row_stride = q_head_stride * ${num_q_heads};
+ int k_row_stride = k_head_stride * ${num_kv_heads};
+ int v_row_stride = v_head_stride * ${num_kv_heads};
+ int o_row_stride = o_head_stride * ${num_q_heads};
int q_batch_stride = q_row_stride * ${num_queries};
int k_batch_stride = k_row_stride * ${num_keys};
int v_batch_stride = v_row_stride * ${num_keys};
@@ -190,8 +190,8 @@ def instantiate_flash_attention_template(attrs):
${num_batches},
${num_queries},
${num_keys},
- ${num_heads},
- ${num_heads},
+ ${num_q_heads},
+ ${num_kv_heads},
${head_dim},
q_batch_stride,
k_batch_stride,
@@ -215,13 +215,13 @@ def instantiate_flash_attention_template(attrs):
int k_head_stride = ${head_dim};
int v_head_stride = ${head_dim};
int o_head_stride = ${head_dim};
- int row_stride = q_head_stride * ${num_heads} +
- k_head_stride * ${num_heads} +
- v_head_stride * ${num_heads};
+ int row_stride = q_head_stride * ${num_q_heads} +
+ k_head_stride * ${num_kv_heads} +
+ v_head_stride * ${num_kv_heads};
int q_row_stride = row_stride;
int k_row_stride = row_stride;
int v_row_stride = row_stride;
- int o_row_stride = o_head_stride * ${num_heads};
+ int o_row_stride = o_head_stride * ${num_q_heads};
int q_batch_stride = q_row_stride * ${num_queries};
int k_batch_stride = k_row_stride * ${num_keys};
@@ -234,14 +234,14 @@ def instantiate_flash_attention_template(attrs):
flash_attn::flash_attention_forward(
static_cast<const cutlass::half_t*>(${qkv}->data),
- static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads},
- static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads} * 2,
+ static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_q_heads},
+ static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * (${num_q_heads} + ${num_kv_heads}),
static_cast<cutlass::half_t*>(out0->data),
${num_batches},
${num_queries},
${num_keys},
- ${num_heads},
- ${num_heads},
+ ${num_q_heads},
+ ${num_kv_heads},
${head_dim},
q_batch_stride,
k_batch_stride,
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 0c57c4750e..b97fc20008 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -909,8 +909,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
out_shape = signature["ret_shape"]
out_dtype = signature["ret_dtype"]
- num_batches, num_queries, num_heads, head_dim = q_shape
- _, num_keys, _, _ = k_shape
+ num_batches, num_queries, num_q_heads, head_dim = q_shape
+ _, num_keys, num_kv_heads, _ = k_shape
_, _, _, head_dim_value = v_shape
scale = op_attrs.scale
@@ -931,7 +931,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
"num_batches": num_batches,
"num_queries": num_queries,
"num_keys": num_keys,
- "num_heads": num_heads,
+ "num_q_heads": num_q_heads,
+ "num_kv_heads": num_kv_heads,
"head_dim": head_dim,
"head_dim_value": head_dim_value,
"scale": scale,
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 58bc91863d..62e64549c2 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -745,7 +745,6 @@ def instantiate_template(func_name, annotations, func_args):
attrs["data_type"] = DataTypeTag[data_type]
attrs["num_batches"] = b = annotations["num_batches"]
- attrs["num_heads"] = n = annotations["num_heads"]
attrs["head_dim"] = h = annotations["head_dim"]
attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"]))
@@ -753,26 +752,40 @@ def instantiate_template(func_name, annotations, func_args):
float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
)
+ is_mqa = annotations["num_q_heads"] != annotations["num_kv_heads"]
+
use_flash = (
annotations["ret_dtype"] == "float16"
and "bias" not in attrs
and int(attrs["head_dim"]) <= 256
and int(attrs["head_dim"]) % 8 == 0
and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
- # We have not thoroughly validated flash with causal mask yet, so for now we support
- # only non-causal cases.
- and int(annotations["custom_mask_type"]) == 0
+ # For the causal case (custom mask = "BottomRight"), only use flash for multi-query
+ # attention workloads. Otherwise, CUTLASS fMHA seems faster for causal attention
+ # with a single query.
+ and (
+ int(annotations["custom_mask_type"]) == 0
+ or (int(annotations["custom_mask_type"]) == 2 and is_mqa)
+ )
# Flash v2 is currently not supported for sm < 80
and int(annotations["arch"]) >= 80
)
if use_flash:
headers.append("flash.h")
- attrs["is_causal"] = int(annotations["custom_mask_type"]) > 0
+ attrs["is_causal"] = int(annotations["custom_mask_type"]) == 2
+ attrs["num_q_heads"] = annotations["num_q_heads"]
+ attrs["num_kv_heads"] = annotations["num_kv_heads"]
code = instantiate_flash_attention_template(attrs)
else:
headers.append("kernel_forward.h")
+ assert (
+ not is_mqa
+ ), "The number of query and KV heads need to be the same for CUTLASS fMHA."
+
+ attrs["num_heads"] = n = annotations["num_q_heads"]
+
data_type_size = DataTypeSize[data_type]
if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
attrs["kIsAligned"] = True
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index fef6a1ec03..9efea3a0dc 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -576,7 +576,7 @@ def annotate_workspace(mod, _):
return mod
-def partition_for_cutlass(mod, annotate_codegen=True):
+def partition_for_cutlass(mod, annotate_codegen=True, use_flash_mqa=True):
"""
Partition the input module into CUTLASS-supported subgraphs.
@@ -590,6 +590,10 @@ def partition_for_cutlass(mod, annotate_codegen=True):
body consists only of a call to the composite function. See the doc of FuseOpsByPattern
for more detail.
+ use_flash_mqa: bool
+ Whether to consider a rewrite pattern for multi-query attention, which is supported by
+ the Flash Attention kernel.
+
Returns
-------
mod: tvm.IRModule
@@ -598,8 +602,15 @@ def partition_for_cutlass(mod, annotate_codegen=True):
"""
for func_name, func in mod.functions.items():
if isinstance(func, Function):
+ if use_flash_mqa:
+ mqa_pattern, rewriter = make_attention_rewrite_pattern(
+ "BSNH", "BSNH", with_bias=False, with_cast=True, with_kv_repeat=True
+ )
+ func = rewrite_call(mqa_pattern, rewriter, func)
+
for pattern, rewriter in _REWRITE_PATTERNS:
func = rewrite_call(pattern, rewriter, func)
+
mod[func_name] = func
patterns = get_patterns_with_prefix("cutlass")
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 24edd0e7c9..10a075647b 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -318,7 +318,7 @@ def make_rms_norm_pattern():
def make_attention_rewrite_pattern(
- qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool
+ qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False
):
"""
Create pattern for implicit fused multi head attention rewriting.
@@ -338,6 +338,10 @@ def make_attention_rewrite_pattern(
Whether or not rewriting is intended to be applied to a module after the FP16 conversion
pass.
+ with_kv_repeat: bool
+ Whether or not to include the Relax repeat op in the pattern, which is typically used
+ in a Relax module to support multi-query attention.
+
Returns
-------
pattern: DFPattern
@@ -350,7 +354,10 @@ def make_attention_rewrite_pattern(
"""
# pylint: disable=invalid-name
- def handle_input(tensor, layout, transpose):
+ def handle_input(tensor, layout, transpose, repeat=False):
+ if repeat:
+ tensor = is_op("relax.repeat")(tensor)
+
if layout == "BSNH":
permuted = is_op("relax.permute_dims")(tensor)
shape = wildcard()
@@ -434,8 +441,8 @@ def make_attention_rewrite_pattern(
q_raw, k_raw, v_raw = wildcard(), wildcard(), wildcard()
q, q_rewriter = handle_input(q_raw, qkv_layout, False)
- k, k_rewriter = handle_input(k_raw, qkv_layout, True)
- v, v_rewriter = handle_input(v_raw, qkv_layout, False)
+ k, k_rewriter = handle_input(k_raw, qkv_layout, True, repeat=with_kv_repeat)
+ v, v_rewriter = handle_input(v_raw, qkv_layout, False, repeat=with_kv_repeat)
matmul_1 = is_op("relax.matmul")(q, k)
scale = is_const()
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index 4f37e3a33c..484137fecc 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -77,10 +77,19 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
<< v1 << " while the " << dim << " of " << m2 << " is " << v2);
}
};
+ auto multiple_of = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) {
+ if (analyzer->CanProve(indexmod(v1, v2) != 0)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "The " << m1 << " " << dim << " should be a multiple of " << m2 << " "
+ << dim << ". However, the " << dim << " of " << m1 << " is " << v1
+ << " while the " << dim << " of " << m2 << " is " << v2);
+ }
+ };
+
diag_equal(num_batches, k_shape->values[0], "query", "key", "batch size");
diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size");
- diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads");
- diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads");
+ multiple_of(num_heads, k_shape->values[2], "query", "key", "number of heads");
+ multiple_of(num_heads, v_shape->values[2], "query", "value", "number of heads");
diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length");
diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads");
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index e8d4e83521..83936ef9c9 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -746,7 +746,7 @@ def attention_causal(request):
def test_attention_causal_offload(attention_causal_size, attention_causal):
b, (s, s_kv), n, (h, h_v), bias_shape = attention_causal_size
q, k, v, bias, ref = get_numpy_attention_ref(
- b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float32"
+ b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float16"
)
q_shape = (b, s, n, h)
@@ -757,10 +757,11 @@ def test_attention_causal_offload(attention_causal_size, attention_causal):
q_shape,
k_shape,
v_shape,
- dtype="float32",
+ dtype="float16",
bias_shape=bias_shape,
causal_mask=attention_causal,
)
+
if bias is None:
out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
else:
@@ -1945,5 +1946,51 @@ def test_fp16A_int8B_gemm_batched():
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+def test_attention_rewrite_multi_query():
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(
+ q: R.Tensor((4, 16, 32, 16), dtype="float16"),
+ k_single: R.Tensor((4, 16, 1, 16), dtype="float16"),
+ v_single: R.Tensor((4, 16, 1, 16), dtype="float16"),
+ ) -> R.Tensor((4, 16, 32, 8), dtype="float16"):
+ with R.dataflow():
+ k = R.repeat(k_single, 32, axis=2)
+ v = R.repeat(v_single, 32, axis=2)
+
+ lv = R.permute_dims(q, axes=[0, 2, 1, 3])
+ lv1 = R.reshape(lv, R.shape([128, 16, 16]))
+ lv2 = R.permute_dims(k, axes=[0, 2, 1, 3])
+ lv3 = R.reshape(lv2, R.shape([128, 16, 16]))
+ lv4 = R.permute_dims(v, axes=[0, 2, 1, 3])
+ lv5 = R.reshape(lv4, R.shape([128, 16, 16]))
+
+ lv6 = R.permute_dims(lv3, axes=[0, 2, 1])
+ lv7 = R.matmul(lv1, lv6, out_dtype="float16")
+ lv3_1 = R.astype(R.const(0.25, "float32"), "float16")
+ lv8 = R.multiply(lv7, lv3_1)
+ lv11 = R.astype(R.nn.softmax(R.astype(lv8, "float32"), axis=2), "float16")
+ lv12 = R.matmul(lv11, lv5, out_dtype="float16")
+ lv13 = R.reshape(lv12, R.shape([4, 32, 16, 16]))
+ lv6_1 = R.permute_dims(lv13, axes=[0, 2, 1, 3])
+ R.output(lv6_1)
+ return lv6_1
+
+ q_np = np.random.randn(4, 16, 32, 16).astype("float16")
+ k_np = np.random.randn(4, 16, 1, 16).astype("float16")
+ v_np = np.random.randn(4, 16, 1, 16).astype("float16")
+ args = [q_np, k_np, v_np]
+ ref = build_and_run(Module, args, "llvm", legalize=True)
+
+ mod = partition_for_cutlass(Module, use_flash_mqa=True)
+ codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80}})
+ mod = codegen_pass(mod)
+
+ out = build_and_run(mod, args, "cuda")
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
if __name__ == "__main__":
tvm.testing.main()
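
Finally, a note on the relaxed shape check in src/relax/op/nn/attention.cc: R.nn.attention previously required the query and key/value head counts to be equal, whereas the new multiple_of diagnostic only requires the query head count to be divisible by the KV head count. Illustrative head-count combinations, following the (batch, seq, heads, head_dim) layout from the test above (whether a given combination is actually offloaded still depends on the backend checks; only the divisibility rule is shown here):

    # Head-count combinations under the relaxed InferStructInfoAttention check.
    q_shape = (4, 16, 32, 16)
    kv_mqa = (4, 16, 1, 16)  # accepted: 32 % 1 == 0 (the multi-query case in this commit)
    kv_gqa = (4, 16, 8, 16)  # accepted by the same rule: 32 % 8 == 0 (grouped-query shape)
    kv_bad = (4, 16, 3, 16)  # rejected: 32 % 3 != 0 triggers the new diagnostic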