You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/04/18 08:23:24 UTC

[tvm] branch unity updated: [Unity][BYOC] Fuse attention pattern with `strided_slice` (#14649)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ccca0f5ecf [Unity][BYOC] Fuse attention pattern with `strided_slice` (#14649)
ccca0f5ecf is described below

commit ccca0f5ecf75f3bbc38c38d4b559dc8f9ac3fc48
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Tue Apr 18 01:23:17 2023 -0700

    [Unity][BYOC] Fuse attention pattern with `strided_slice` (#14649)
    
    * [Unity][BYOC] Fuse attention pattern with `strided_slice`
    
    This PR extends support for fused stacked attention patterns to those starting with `strided_slice`. Previously, #14608 only supported fused stacked attention patterns starting with `split`; with #14583 in place, equivalent patterns starting with `strided_slice` can also appear.
    
    * Remove unused code
---
 python/tvm/relax/backend/contrib/cutlass.py | 12 ++++++--
 python/tvm/relax/backend/patterns.py        | 23 +++++++++++----
 tests/python/relax/test_codegen_cutlass.py  | 44 +++++++++++++++++++++++------
 3 files changed, 64 insertions(+), 15 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 4515118f58..06edd9febf 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -247,11 +247,19 @@ def attention_patterns():
         ),
         (
             "cutlass.stacked_attention",
-            *make_stacked_attention_pattern(),
+            *make_stacked_attention_pattern(start_op="split"),
         ),
         (
             "cutlass.stacked_attention",
-            *make_stacked_attention_pattern(with_bias=True),
+            *make_stacked_attention_pattern(start_op="split", with_bias=True),
+        ),
+        (
+            "cutlass.stacked_attention",
+            *make_stacked_attention_pattern(start_op="strided_slice"),
+        ),
+        (
+            "cutlass.stacked_attention",
+            *make_stacked_attention_pattern(start_op="strided_slice", with_bias=True),
         ),
     ]
 
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 9e34b0c964..6197fe44ca 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -197,12 +197,15 @@ def make_attention_pattern(with_bias: bool = False):
     return out, annotations
 
 
-def make_stacked_attention_pattern(with_bias: bool = False):
+def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
     """
     Create pattern for fused multi head attention with stacked input.
 
     Parameters
     ----------
+    start_op: str
+        The starting op for pattern, i.e. `R.split` or `R.strided_slice`.
+
     with_bias: bool
         Whether or not to include bias addition
 
@@ -217,13 +220,23 @@ def make_stacked_attention_pattern(with_bias: bool = False):
         check function and codegen.
     """
     stacked_qkv = wildcard()
-    qkv_tuple = is_op("relax.split")(stacked_qkv)
+    if start_op == "split":
+        qkv_tuple = is_op("relax.split")(stacked_qkv)
+        query_raw = is_tuple_get_item(qkv_tuple, 0)
+        key_raw = is_tuple_get_item(qkv_tuple, 1)
+        value_raw = is_tuple_get_item(qkv_tuple, 2)
+    elif start_op == "strided_slice":
+        query_raw = is_op("relax.strided_slice")(stacked_qkv)
+        key_raw = is_op("relax.strided_slice")(stacked_qkv)
+        value_raw = is_op("relax.strided_slice")(stacked_qkv)
+    else:
+        raise NotImplementedError()
     query_reshape_list = wildcard()
     key_reshape_list = wildcard()
     value_reshape_list = wildcard()
-    query = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 0), query_reshape_list)
-    key = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 1), key_reshape_list)
-    value = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 2), value_reshape_list)
+    query = is_op("relax.reshape")(query_raw, query_reshape_list)
+    key = is_op("relax.reshape")(key_raw, key_reshape_list)
+    value = is_op("relax.reshape")(value_raw, value_reshape_list)
     annotations = {
         "stacked_qkv": stacked_qkv,
         "query_reshape_list": query_reshape_list,
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 4309627bf0..db8abf34c2 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -660,7 +660,7 @@ def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, bias_reshape, q
     return qkv, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
 
 
-def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias=None, qk_scale=None):
+def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None):
     dtype = str(qkv.dtype)
 
     from tvm.script.ir_builder import IRBuilder
@@ -676,10 +676,22 @@ def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias=None, qk_scale
             if bias is not None:
                 bias = R.arg("bias", R.Tensor(bias.shape, dtype))
             with R.dataflow() as frame:
-                qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
-                q = R.reshape(qkv_tuple[0], [b, s, n, h])
-                k = R.reshape(qkv_tuple[1], [b, s, n, h])
-                v = R.reshape(qkv_tuple[2], [b, s, n, h_v])
+                if op == "split":
+                    qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
+                    q = R.reshape(qkv_tuple[0], [b, s, n, h])
+                    k = R.reshape(qkv_tuple[1], [b, s, n, h])
+                    v = R.reshape(qkv_tuple[2], [b, s, n, h_v])
+                elif op == "strided_slice":
+                    q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], [1]), [b, s, n, h])
+                    k = R.reshape(
+                        R.strided_slice(qkv, [2], [n * h], [n * h * 2], [1]), [b, s, n, h]
+                    )
+                    v = R.reshape(
+                        R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n * h_v], [1]),
+                        [b, s, n, h_v],
+                    )
+                else:
+                    raise NotImplementedError()
                 result = R.emit(R.nn.attention(q, k, v, bias, qk_scale))
                 R.output(result)
 
@@ -700,15 +712,31 @@ def stacked_attention_size(request):
     return request.param
 
 
-def test_stacked_attention_offload(stacked_attention_size):
+def test_stacked_attention_split_offload(stacked_attention_size):
+    b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size
+    qkv, bias, ref = get_numpy_stacked_attention_ref(
+        b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32"
+    )
+    if scale == "none":
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias)
+    else:
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias, scale)
+    if bias is None:
+        out = get_result_with_relax_cutlass_offload(mod, qkv)
+    else:
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+def test_stacked_attention_strided_slice_offload(stacked_attention_size):
     b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size
     qkv, bias, ref = get_numpy_stacked_attention_ref(
         b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32"
     )
     if scale == "none":
-        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias)
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias)
     else:
-        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias, scale)
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias, scale)
     if bias is None:
         out = get_result_with_relax_cutlass_offload(mod, qkv)
     else: