Posted to commits@tvm.apache.org by ma...@apache.org on 2023/04/18 08:23:24 UTC
[tvm] branch unity updated: [Unity][BYOC] Fuse attention pattern with `strided_slice` (#14649)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ccca0f5ecf [Unity][BYOC] Fuse attention pattern with `strided_slice` (#14649)
ccca0f5ecf is described below
commit ccca0f5ecf75f3bbc38c38d4b559dc8f9ac3fc48
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Tue Apr 18 01:23:17 2023 -0700
[Unity][BYOC] Fuse attention pattern with `strided_slice` (#14649)
* [Unity][BYOC] Fuse attention pattern with `strided_slice`
This PR adds support for fused stacked attention patterns starting with `strided_slice`. Previously, only the fused stacked attention pattern starting with `split` was supported (#14608). With the help of #14583, similar patterns starting with `strided_slice` can now be matched as well; a sketch of such a pattern follows this commit message.
* remove useless code
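For reference, the following sketch (not part of this commit) illustrates the kind of Relax module the new `strided_slice` variant of the pattern is meant to match: the stacked QKV tensor is cut into query, key and value with `R.strided_slice` along the last axis, each slice is reshaped, and the results feed `R.nn.attention`. The module name and the concrete sizes (b=4, s=8, n=2, h=h_v=16) are illustrative only; the op calls mirror the test helper added in this diff.

import tvm
from tvm.script import relax as R


# Illustrative sizes: b=4, s=8, n=2, h=h_v=16, so the stacked QKV width along
# the last axis is n*h + n*h + n*h_v = 32 + 32 + 32 = 96.
@tvm.script.ir_module
class StackedAttentionStridedSlice:
    @R.function
    def main(qkv: R.Tensor((4, 8, 96), "float32")) -> R.Tensor((4, 8, 2, 16), "float32"):
        with R.dataflow():
            # Slice Q, K and V out of the stacked tensor along axis 2,
            # then reshape each slice to (b, s, n, h) / (b, s, n, h_v).
            q_raw = R.strided_slice(qkv, [2], [0], [32], [1])
            k_raw = R.strided_slice(qkv, [2], [32], [64], [1])
            v_raw = R.strided_slice(qkv, [2], [64], [96], [1])
            q = R.reshape(q_raw, [4, 8, 2, 16])
            k = R.reshape(k_raw, [4, 8, 2, 16])
            v = R.reshape(v_raw, [4, 8, 2, 16])
            out = R.nn.attention(q, k, v)
            R.output(out)
        return out

With the two new pattern entries registered below, running `partition_for_cutlass` over such a module should offload this whole dataflow region to CUTLASS as a single `cutlass.stacked_attention` composite, just as the `split`-based variant already does.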
---
python/tvm/relax/backend/contrib/cutlass.py | 12 ++++++--
python/tvm/relax/backend/patterns.py | 23 +++++++++++----
tests/python/relax/test_codegen_cutlass.py | 44 +++++++++++++++++++++++------
3 files changed, 64 insertions(+), 15 deletions(-)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 4515118f58..06edd9febf 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -247,11 +247,19 @@ def attention_patterns():
),
(
"cutlass.stacked_attention",
- *make_stacked_attention_pattern(),
+ *make_stacked_attention_pattern(start_op="split"),
),
(
"cutlass.stacked_attention",
- *make_stacked_attention_pattern(with_bias=True),
+ *make_stacked_attention_pattern(start_op="split", with_bias=True),
+ ),
+ (
+ "cutlass.stacked_attention",
+ *make_stacked_attention_pattern(start_op="strided_slice"),
+ ),
+ (
+ "cutlass.stacked_attention",
+ *make_stacked_attention_pattern(start_op="strided_slice", with_bias=True),
),
]
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 9e34b0c964..6197fe44ca 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -197,12 +197,15 @@ def make_attention_pattern(with_bias: bool = False):
return out, annotations
-def make_stacked_attention_pattern(with_bias: bool = False):
+def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
"""
Create pattern for fused multi head attention with stacked input.
Parameters
----------
+ start_op: str
+ The starting op of the pattern, either `R.split` or `R.strided_slice`.
+
with_bias: bool
Whether or not to include bias addition
@@ -217,13 +220,23 @@ def make_stacked_attention_pattern(with_bias: bool = False):
check function and codegen.
"""
stacked_qkv = wildcard()
- qkv_tuple = is_op("relax.split")(stacked_qkv)
+ if start_op == "split":
+ qkv_tuple = is_op("relax.split")(stacked_qkv)
+ query_raw = is_tuple_get_item(qkv_tuple, 0)
+ key_raw = is_tuple_get_item(qkv_tuple, 1)
+ value_raw = is_tuple_get_item(qkv_tuple, 2)
+ elif start_op == "strided_slice":
+ query_raw = is_op("relax.strided_slice")(stacked_qkv)
+ key_raw = is_op("relax.strided_slice")(stacked_qkv)
+ value_raw = is_op("relax.strided_slice")(stacked_qkv)
+ else:
+ raise NotImplementedError()
query_reshape_list = wildcard()
key_reshape_list = wildcard()
value_reshape_list = wildcard()
- query = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 0), query_reshape_list)
- key = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 1), key_reshape_list)
- value = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 2), value_reshape_list)
+ query = is_op("relax.reshape")(query_raw, query_reshape_list)
+ key = is_op("relax.reshape")(key_raw, key_reshape_list)
+ value = is_op("relax.reshape")(value_raw, value_reshape_list)
annotations = {
"stacked_qkv": stacked_qkv,
"query_reshape_list": query_reshape_list,
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 4309627bf0..db8abf34c2 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -660,7 +660,7 @@ def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, bias_reshape, q
return qkv, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
-def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias=None, qk_scale=None):
+def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None):
dtype = str(qkv.dtype)
from tvm.script.ir_builder import IRBuilder
@@ -676,10 +676,22 @@ def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias=None, qk_scale
if bias is not None:
bias = R.arg("bias", R.Tensor(bias.shape, dtype))
with R.dataflow() as frame:
- qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
- q = R.reshape(qkv_tuple[0], [b, s, n, h])
- k = R.reshape(qkv_tuple[1], [b, s, n, h])
- v = R.reshape(qkv_tuple[2], [b, s, n, h_v])
+ if op == "split":
+ qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
+ q = R.reshape(qkv_tuple[0], [b, s, n, h])
+ k = R.reshape(qkv_tuple[1], [b, s, n, h])
+ v = R.reshape(qkv_tuple[2], [b, s, n, h_v])
+ elif op == "strided_slice":
+ q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], [1]), [b, s, n, h])
+ k = R.reshape(
+ R.strided_slice(qkv, [2], [n * h], [n * h * 2], [1]), [b, s, n, h]
+ )
+ v = R.reshape(
+ R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n * h_v], [1]),
+ [b, s, n, h_v],
+ )
+ else:
+ raise NotImplementedError()
result = R.emit(R.nn.attention(q, k, v, bias, qk_scale))
R.output(result)
@@ -700,15 +712,31 @@ def stacked_attention_size(request):
return request.param
-def test_stacked_attention_offload(stacked_attention_size):
+def test_stacked_attention_split_offload(stacked_attention_size):
+ b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size
+ qkv, bias, ref = get_numpy_stacked_attention_ref(
+ b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32"
+ )
+ if scale == "none":
+ mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias)
+ else:
+ mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias, scale)
+ if bias is None:
+ out = get_result_with_relax_cutlass_offload(mod, qkv)
+ else:
+ out = get_result_with_relax_cutlass_offload(mod, qkv, bias)
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+def test_stacked_attention_strided_slice_offload(stacked_attention_size):
b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size
qkv, bias, ref = get_numpy_stacked_attention_ref(
b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32"
)
if scale == "none":
- mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias)
+ mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias)
else:
- mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias, scale)
+ mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias, scale)
if bias is None:
out = get_result_with_relax_cutlass_offload(mod, qkv)
else: