Posted to commits@tvm.apache.org by wu...@apache.org on 2023/09/27 23:02:26 UTC

[tvm] branch unity updated: [Unity][BYOC] Support offloading multi-query attention by Flash Attention (#15831)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7494bc4580 [Unity][BYOC] Support offloading multi-query attention by Flash Attention (#15831)
7494bc4580 is described below

commit 7494bc4580054e0a783d74c3d60fb0c177d1de23
Author: masahi <ma...@gmail.com>
AuthorDate: Thu Sep 28 08:02:16 2023 +0900

    [Unity][BYOC] Support offloading multi-query attention by Flash Attention (#15831)
    
    * Squashed commit of the following:
    
    commit 99c2a59a1226f372c50c347c961d0c1201680a3e
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Wed Sep 27 09:57:21 2023 +0900
    
        Revert "Revert "[Unity] Avoid trivial `var2 = var1` bindings in pattern matcher (#15578)""
    
        This reverts commit 0a6a617e1315f3bc1550e5dc0e4630495e7fe70d.
    
    commit 9a3ca64cfa2152628f5704a68383d36949403900
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Wed Sep 27 09:55:02 2023 +0900
    
        wip
    
    commit be01900d59db94bbccc3d8142d95c302dade7ca2
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Tue Sep 26 19:55:29 2023 +0900
    
        fix test
    
    commit a026b650002b07833808d078ede41243796f9a95
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Thu Aug 31 22:24:38 2023 +0000
    
        wip
    
    commit 233d2d0fa7bb1a981f792645e8394d95e8d31cb4
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Tue Aug 29 17:42:11 2023 +0000
    
        wip
    
    commit 0a6a617e1315f3bc1550e5dc0e4630495e7fe70d
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Tue Aug 29 17:28:25 2023 +0000
    
        Revert "[Unity] Avoid trivial `var2 = var1` bindings in pattern matcher (#15578)"
    
        This reverts commit 567848e3a08a3bcb1ed69344050bc648a101d9b9.
    
    commit 6c5a4355e4cd487434c598909b7338159161624e
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Tue Aug 29 17:28:16 2023 +0000
    
        wip
    
    commit 7926cbc9d890c0c07376176a26144b8603bb9732
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Mon Aug 28 06:17:01 2023 +0000
    
        wip
    
    commit 9828698ca3d808da8a77a432686c8be5dd4dab38
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Mon Aug 28 15:11:47 2023 +0900
    
        wip
    
    commit 5d01fd1310fd5df98bf5fd56986056e20352ad3d
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Mon Aug 28 06:05:56 2023 +0000
    
        wip
    
    commit ae657b7aed678fa7f7727aebcc9221940e45de26
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Mon Aug 28 14:49:21 2023 +0900
    
        wip
    
    commit ddcab3887fef5c851689714f2f3924201165591d
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Mon Aug 28 05:42:41 2023 +0000
    
        wip
    
    commit ab3572d852e21af3d4b349afd999654f491dcee8
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Mon Aug 28 10:40:34 2023 +0900
    
        wip
    
    commit 690b88ef2380fc3ab5e3e02fed61cdf2936e0811
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Mon Aug 28 10:25:33 2023 +0900
    
        update rev
    
    * black
    
    * clean
    
    * add doc
    
    * update rev
    
    * update test
    
    * fix
    
    ---------
    
    Co-authored-by: Masahiro Masuda <ma...@MasahironoMacBook-Pro.local>
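
For context: in multi-query attention (MQA) all query heads share a single key/value
head, so K and V have shape (batch, seq_kv, 1, head_dim) while Q keeps
(batch, seq_q, num_q_heads, head_dim). A minimal NumPy sketch of the computation this
commit offloads to Flash Attention (illustrative only, not part of the commit; the
shapes mirror the new test added at the bottom of this patch):

    import numpy as np

    def mqa_reference(q, k_single, v_single, scale):
        # q: (batch, seq_q, num_q_heads, head_dim)
        # k_single, v_single: (batch, seq_kv, 1, head_dim), i.e. one shared KV head
        num_q_heads = q.shape[2]
        # R.repeat(k_single, num_q_heads, axis=2) expresses this broadcast in Relax.
        k = np.repeat(k_single, num_q_heads, axis=2)
        v = np.repeat(v_single, num_q_heads, axis=2)
        scores = np.einsum("bqnd,bknd->bnqk", q, k) * scale
        scores = scores - scores.max(axis=-1, keepdims=True)
        probs = np.exp(scores)
        probs = probs / probs.sum(axis=-1, keepdims=True)
        return np.einsum("bnqk,bknd->bqnd", probs, v)

    q = np.random.randn(4, 16, 32, 16)
    k = np.random.randn(4, 16, 1, 16)
    v = np.random.randn(4, 16, 1, 16)
    out = mqa_reference(q, k, v, scale=0.25)  # 0.25 == 1 / sqrt(head_dim) for head_dim=16
    assert out.shape == (4, 16, 32, 16)
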
---
 3rdparty/libflash_attn                            |  2 +-
 python/tvm/contrib/cutlass/attention_operation.py | 28 ++++++-------
 python/tvm/contrib/cutlass/build.py               |  7 ++--
 python/tvm/contrib/cutlass/gen_tensor_op.py       | 23 +++++++---
 python/tvm/relax/backend/contrib/cutlass.py       | 13 +++++-
 python/tvm/relax/backend/patterns.py              | 15 +++++--
 src/relax/op/nn/attention.cc                      | 13 +++++-
 tests/python/relax/test_codegen_cutlass.py        | 51 ++++++++++++++++++++++-
 8 files changed, 120 insertions(+), 32 deletions(-)

diff --git a/3rdparty/libflash_attn b/3rdparty/libflash_attn
index 58b343e575..63cce0ca8f 160000
--- a/3rdparty/libflash_attn
+++ b/3rdparty/libflash_attn
@@ -1 +1 @@
-Subproject commit 58b343e57571fe5e0a5b43b5eb721acef8b35dff
+Subproject commit 63cce0ca8fa6bfca1982b342588273641cc5b86b
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 67a68df442..e59dbf032e 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -169,10 +169,10 @@ def instantiate_flash_attention_template(attrs):
     int k_head_stride = ${head_dim};
     int v_head_stride = ${head_dim};
     int o_head_stride = ${head_dim};
-    int q_row_stride = q_head_stride * ${num_heads};
-    int k_row_stride = k_head_stride * ${num_heads};
-    int v_row_stride = v_head_stride * ${num_heads};
-    int o_row_stride = o_head_stride * ${num_heads};
+    int q_row_stride = q_head_stride * ${num_q_heads};
+    int k_row_stride = k_head_stride * ${num_kv_heads};
+    int v_row_stride = v_head_stride * ${num_kv_heads};
+    int o_row_stride = o_head_stride * ${num_q_heads};
     int q_batch_stride = q_row_stride * ${num_queries};
     int k_batch_stride = k_row_stride * ${num_keys};
     int v_batch_stride = v_row_stride * ${num_keys};
@@ -190,8 +190,8 @@ def instantiate_flash_attention_template(attrs):
     			    ${num_batches},
     			    ${num_queries},
     			    ${num_keys},
-    			    ${num_heads},
-    			    ${num_heads},
+    			    ${num_q_heads},
+    			    ${num_kv_heads},
     			    ${head_dim},
     			    q_batch_stride,
     			    k_batch_stride,
@@ -215,13 +215,13 @@ def instantiate_flash_attention_template(attrs):
     int k_head_stride = ${head_dim};
     int v_head_stride = ${head_dim};
     int o_head_stride = ${head_dim};
-    int row_stride = q_head_stride * ${num_heads} +
-                     k_head_stride * ${num_heads} +
-                     v_head_stride * ${num_heads};
+    int row_stride = q_head_stride * ${num_q_heads} +
+                     k_head_stride * ${num_kv_heads} +
+                     v_head_stride * ${num_kv_heads};
     int q_row_stride = row_stride;
     int k_row_stride = row_stride;
     int v_row_stride = row_stride;
-    int o_row_stride = o_head_stride * ${num_heads};
+    int o_row_stride = o_head_stride * ${num_q_heads};
 
     int q_batch_stride = q_row_stride * ${num_queries};
     int k_batch_stride = k_row_stride * ${num_keys};
@@ -234,14 +234,14 @@ def instantiate_flash_attention_template(attrs):
 
     flash_attn::flash_attention_forward(
                             static_cast<const cutlass::half_t*>(${qkv}->data),
-    			    static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads},
-    			    static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads} * 2,
+    			    static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_q_heads},
+    			    static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * (${num_q_heads} + ${num_kv_heads}),
     			    static_cast<cutlass::half_t*>(out0->data),
     			    ${num_batches},
     			    ${num_queries},
     			    ${num_keys},
-    			    ${num_heads},
-    			    ${num_heads},
+    			    ${num_q_heads},
+    			    ${num_kv_heads},
     			    ${head_dim},
     			    q_batch_stride,
     			    k_batch_stride,
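
The template changes above distinguish the query and KV head counts when computing row
and batch strides, both for separate Q/K/V tensors and for the packed QKV layout (where
the K and V pointer offsets now use num_q_heads and num_q_heads + num_kv_heads). A small
sketch of the resulting stride arithmetic, assuming row-major BSNH tensors; names are
illustrative, not part of the patch:

    def flash_mqa_strides(num_queries, num_keys, head_dim, num_q_heads, num_kv_heads):
        # Separate Q/K/V tensors in BSNH layout: K/V rows advance by
        # num_kv_heads * head_dim, Q and the output by num_q_heads * head_dim.
        q_row = head_dim * num_q_heads
        k_row = head_dim * num_kv_heads
        v_row = head_dim * num_kv_heads
        o_row = head_dim * num_q_heads
        return {
            "q_batch_stride": q_row * num_queries,
            "k_batch_stride": k_row * num_keys,
            "v_batch_stride": v_row * num_keys,
            "o_batch_stride": o_row * num_queries,
        }

    # Packed QKV layout: every row stores all Q, K and V heads back to back, so the
    # K pointer is offset by head_dim * num_q_heads elements and the V pointer by
    # head_dim * (num_q_heads + num_kv_heads) elements from the start of the row.
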
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 0c57c4750e..b97fc20008 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -909,8 +909,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
 
         out_shape = signature["ret_shape"]
         out_dtype = signature["ret_dtype"]
-        num_batches, num_queries, num_heads, head_dim = q_shape
-        _, num_keys, _, _ = k_shape
+        num_batches, num_queries, num_q_heads, head_dim = q_shape
+        _, num_keys, num_kv_heads, _ = k_shape
         _, _, _, head_dim_value = v_shape
         scale = op_attrs.scale
 
@@ -931,7 +931,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
                 "num_batches": num_batches,
                 "num_queries": num_queries,
                 "num_keys": num_keys,
-                "num_heads": num_heads,
+                "num_q_heads": num_q_heads,
+                "num_kv_heads": num_kv_heads,
                 "head_dim": head_dim,
                 "head_dim_value": head_dim_value,
                 "scale": scale,
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 58bc91863d..62e64549c2 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -745,7 +745,6 @@ def instantiate_template(func_name, annotations, func_args):
 
         attrs["data_type"] = DataTypeTag[data_type]
         attrs["num_batches"] = b = annotations["num_batches"]
-        attrs["num_heads"] = n = annotations["num_heads"]
         attrs["head_dim"] = h = annotations["head_dim"]
         attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
         attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"]))
@@ -753,26 +752,40 @@ def instantiate_template(func_name, annotations, func_args):
             float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
         )
 
+        is_mqa = annotations["num_q_heads"] != annotations["num_kv_heads"]
+
         use_flash = (
             annotations["ret_dtype"] == "float16"
             and "bias" not in attrs
             and int(attrs["head_dim"]) <= 256
             and int(attrs["head_dim"]) % 8 == 0
             and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
-            # We have not thoroughly validated flash with causal mask yet, so for now we support
-            # only non-causal cases.
-            and int(annotations["custom_mask_type"]) == 0
+            # For the causal case (custom mask = "BottomRight"), only use flash for multi-query
+            # attention workloads. Otherwise, CUTLASS fMHA seems faster for causal attention
+            # with a single query.
+            and (
+                int(annotations["custom_mask_type"]) == 0
+                or (int(annotations["custom_mask_type"]) == 2 and is_mqa)
+            )
             # Flash v2 is currently not supported for sm < 80
             and int(annotations["arch"]) >= 80
         )
 
         if use_flash:
             headers.append("flash.h")
-            attrs["is_causal"] = int(annotations["custom_mask_type"]) > 0
+            attrs["is_causal"] = int(annotations["custom_mask_type"]) == 2
+            attrs["num_q_heads"] = annotations["num_q_heads"]
+            attrs["num_kv_heads"] = annotations["num_kv_heads"]
             code = instantiate_flash_attention_template(attrs)
         else:
             headers.append("kernel_forward.h")
 
+            assert (
+                not is_mqa
+            ), "The number of query and KV heads need to be the same for CUTLASS fMHA."
+
+            attrs["num_heads"] = n = annotations["num_q_heads"]
+
             data_type_size = DataTypeSize[data_type]
             if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
                 attrs["kIsAligned"] = True
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index fef6a1ec03..9efea3a0dc 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -576,7 +576,7 @@ def annotate_workspace(mod, _):
     return mod
 
 
-def partition_for_cutlass(mod, annotate_codegen=True):
+def partition_for_cutlass(mod, annotate_codegen=True, use_flash_mqa=True):
     """
     Partition the input module into CUTLASS-supported subgraphs.
 
@@ -590,6 +590,10 @@ def partition_for_cutlass(mod, annotate_codegen=True):
         body consists only of a call to the composite function. See the doc of FuseOpsByPattern
         for more detail.
 
+    use_flash_mqa: bool
+        Whether to consider a rewrite pattern for multi-query attention, which is supported by
+        the Flash Attention kernel.
+
     Returns
     -------
     mod: tvm.IRModule
@@ -598,8 +602,15 @@ def partition_for_cutlass(mod, annotate_codegen=True):
     """
     for func_name, func in mod.functions.items():
         if isinstance(func, Function):
+            if use_flash_mqa:
+                mqa_pattern, rewriter = make_attention_rewrite_pattern(
+                    "BSNH", "BSNH", with_bias=False, with_cast=True, with_kv_repeat=True
+                )
+                func = rewrite_call(mqa_pattern, rewriter, func)
+
             for pattern, rewriter in _REWRITE_PATTERNS:
                 func = rewrite_call(pattern, rewriter, func)
+
         mod[func_name] = func
 
     patterns = get_patterns_with_prefix("cutlass")
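
With the new use_flash_mqa flag (enabled by default), a typical offload pipeline looks
roughly like this, mirroring the new test at the bottom of this patch (`mod` is assumed
to be a Relax IRModule containing the repeat-based MQA pattern):

    from tvm import relax
    from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

    # Rewrite the repeat-based MQA pattern to R.nn.attention and partition for CUTLASS.
    mod = partition_for_cutlass(mod, use_flash_mqa=True)
    # Generate the CUTLASS / Flash Attention kernels for the partitioned functions.
    mod = relax.transform.RunCodegen({"cutlass": {"sm": 80}})(mod)
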
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 24edd0e7c9..10a075647b 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -318,7 +318,7 @@ def make_rms_norm_pattern():
 
 
 def make_attention_rewrite_pattern(
-    qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool
+    qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False
 ):
     """
     Create pattern for implicit fused multi head attention rewriting.
@@ -338,6 +338,10 @@ def make_attention_rewrite_pattern(
         Whether or not rewriting is intended to be applied to a module after the FP16 conversion
         pass.
 
+    with_kv_repeat: bool
+        Whether or not to include the Relax repeat op in the pattern, which is typically used
+        in a Relax module to support multi-query attention.
+
     Returns
     -------
     pattern: DFPattern
@@ -350,7 +354,10 @@ def make_attention_rewrite_pattern(
     """
 
     # pylint: disable=invalid-name
-    def handle_input(tensor, layout, transpose):
+    def handle_input(tensor, layout, transpose, repeat=False):
+        if repeat:
+            tensor = is_op("relax.repeat")(tensor)
+
         if layout == "BSNH":
             permuted = is_op("relax.permute_dims")(tensor)
             shape = wildcard()
@@ -434,8 +441,8 @@ def make_attention_rewrite_pattern(
 
     q_raw, k_raw, v_raw = wildcard(), wildcard(), wildcard()
     q, q_rewriter = handle_input(q_raw, qkv_layout, False)
-    k, k_rewriter = handle_input(k_raw, qkv_layout, True)
-    v, v_rewriter = handle_input(v_raw, qkv_layout, False)
+    k, k_rewriter = handle_input(k_raw, qkv_layout, True, repeat=with_kv_repeat)
+    v, v_rewriter = handle_input(v_raw, qkv_layout, False, repeat=with_kv_repeat)
     matmul_1 = is_op("relax.matmul")(q, k)
     scale = is_const()
 
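make_attention_rewrite_pattern gains a with_kv_repeat flag so that K and V wrapped in
R.repeat, which is how Relax modules typically express the KV-head broadcast for MQA,
still match and get rewritten to R.nn.attention. A sketch of how the flag is exercised,
as done in partition_for_cutlass above (the rewrite_call helper is assumed to be
importable from tvm.relax.dpl; `func` is a Relax Function containing the MQA pattern):

    from tvm.relax.backend.patterns import make_attention_rewrite_pattern
    from tvm.relax.dpl import rewrite_call

    pattern, rewriter = make_attention_rewrite_pattern(
        "BSNH", "BSNH", with_bias=False, with_cast=True, with_kv_repeat=True
    )
    func = rewrite_call(pattern, rewriter, func)
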
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index 4f37e3a33c..484137fecc 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -77,10 +77,19 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
                        << v1 << " while the " << dim << " of " << m2 << " is " << v2);
     }
   };
+  auto multiple_of = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) {
+    if (analyzer->CanProve(indexmod(v1, v2) != 0)) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "The " << m1 << " " << dim << " should be a multiple of " << m2 << " "
+                       << dim << ". However, the " << dim << " of " << m1 << " is " << v1
+                       << " while the " << dim << " of " << m2 << " is " << v2);
+    }
+  };
+
   diag_equal(num_batches, k_shape->values[0], "query", "key", "batch size");
   diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size");
-  diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads");
-  diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads");
+  multiple_of(num_heads, k_shape->values[2], "query", "key", "number of heads");
+  multiple_of(num_heads, v_shape->values[2], "query", "value", "number of heads");
   diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length");
   diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads");
 
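The relaxed structure-info check above only requires the number of query heads to be a
multiple of the number of KV heads, which admits multi-query (one KV head) and
grouped-query layouts in addition to standard multi-head attention. The rule in Python
form (illustrative only):

    def check_num_heads(num_q_heads, num_kv_heads):
        # Mirrors the new multiple_of diagnostic in InferStructInfoAttention.
        assert num_q_heads % num_kv_heads == 0, (
            "query heads must be a multiple of KV heads"
        )

    check_num_heads(32, 1)   # multi-query attention
    check_num_heads(32, 8)   # grouped-query layout, also accepted by the shape check
    check_num_heads(32, 32)  # standard multi-head attention
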
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index e8d4e83521..83936ef9c9 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -746,7 +746,7 @@ def attention_causal(request):
 def test_attention_causal_offload(attention_causal_size, attention_causal):
     b, (s, s_kv), n, (h, h_v), bias_shape = attention_causal_size
     q, k, v, bias, ref = get_numpy_attention_ref(
-        b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float32"
+        b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float16"
     )
 
     q_shape = (b, s, n, h)
@@ -757,10 +757,11 @@ def test_attention_causal_offload(attention_causal_size, attention_causal):
         q_shape,
         k_shape,
         v_shape,
-        dtype="float32",
+        dtype="float16",
         bias_shape=bias_shape,
         causal_mask=attention_causal,
     )
+
     if bias is None:
         out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
     else:
@@ -1945,5 +1946,51 @@ def test_fp16A_int8B_gemm_batched():
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_attention_rewrite_multi_query():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            q: R.Tensor((4, 16, 32, 16), dtype="float16"),
+            k_single: R.Tensor((4, 16, 1, 16), dtype="float16"),
+            v_single: R.Tensor((4, 16, 1, 16), dtype="float16"),
+        ) -> R.Tensor((4, 16, 32, 8), dtype="float16"):
+            with R.dataflow():
+                k = R.repeat(k_single, 32, axis=2)
+                v = R.repeat(v_single, 32, axis=2)
+
+                lv = R.permute_dims(q, axes=[0, 2, 1, 3])
+                lv1 = R.reshape(lv, R.shape([128, 16, 16]))
+                lv2 = R.permute_dims(k, axes=[0, 2, 1, 3])
+                lv3 = R.reshape(lv2, R.shape([128, 16, 16]))
+                lv4 = R.permute_dims(v, axes=[0, 2, 1, 3])
+                lv5 = R.reshape(lv4, R.shape([128, 16, 16]))
+
+                lv6 = R.permute_dims(lv3, axes=[0, 2, 1])
+                lv7 = R.matmul(lv1, lv6, out_dtype="float16")
+                lv3_1 = R.astype(R.const(0.25, "float32"), "float16")
+                lv8 = R.multiply(lv7, lv3_1)
+                lv11 = R.astype(R.nn.softmax(R.astype(lv8, "float32"), axis=2), "float16")
+                lv12 = R.matmul(lv11, lv5, out_dtype="float16")
+                lv13 = R.reshape(lv12, R.shape([4, 32, 16, 16]))
+                lv6_1 = R.permute_dims(lv13, axes=[0, 2, 1, 3])
+                R.output(lv6_1)
+            return lv6_1
+
+    q_np = np.random.randn(4, 16, 32, 16).astype("float16")
+    k_np = np.random.randn(4, 16, 1, 16).astype("float16")
+    v_np = np.random.randn(4, 16, 1, 16).astype("float16")
+    args = [q_np, k_np, v_np]
+    ref = build_and_run(Module, args, "llvm", legalize=True)
+
+    mod = partition_for_cutlass(Module, use_flash_mqa=True)
+    codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80}})
+    mod = codegen_pass(mod)
+
+    out = build_and_run(mod, args, "cuda")
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
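
The new test builds a legalized LLVM reference and compares it against the CUDA build
with the CUTLASS/Flash offload; the 0.25 scale constant in the module is 1/sqrt(head_dim)
for head_dim = 16, matching the default scale used when none is given. To confirm that
the MQA rewrite fired before codegen, one can inspect the partitioned module (an
illustrative snippet using the Module and imports from the test above, not part of the
test itself):

    mod = partition_for_cutlass(Module, use_flash_mqa=True)
    mod.show()  # the fused "cutlass" function should contain a single R.nn.attention call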