Posted to commits@tvm.apache.org by "cyx-6 (via GitHub)" <gi...@apache.org> on 2023/03/01 00:02:21 UTC

[GitHub] [tvm] cyx-6 opened a new pull request, #14150: [Unity][OP] Add an operator for fused multi head attention

cyx-6 opened a new pull request, #14150:
URL: https://github.com/apache/tvm/pull/14150

   This PR introduces a new Relax operator, `R.nn.attention`, for fused multi-head attention, and adds support for it to the Relax CUTLASS BYOC backend. The operator takes query, key, and value tensors in `BSNH` layout, namely `[batch size, sequence length, number of heads, dimension of heads]`. The output uses the same layout as the inputs.
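
For readers unfamiliar with the `BSNH` convention, the computation described above can be sketched in NumPy. This is an illustrative reference only, not the code from the PR: the function name `attention_ref` is made up here, and the `1 / sqrt(head_dim)` scaling is an assumption taken from the `p.scale` line in the kernel template quoted later in this thread.

```python
import numpy as np


def attention_ref(q, k, v):
    """Reference attention for BSNH inputs (illustrative sketch).

    q: [B, S, N, H], k: [B, S', N, H], v: [B, S', N, H']
    returns: [B, S, N, H']
    """
    head_dim = q.shape[-1]
    # Move the head axis forward: BSNH -> BNSH.
    q_t = q.transpose(0, 2, 1, 3)
    k_t = k.transpose(0, 2, 3, 1)  # BNHS', ready for Q @ K^T
    v_t = v.transpose(0, 2, 1, 3)
    # Scaled scores, shape [B, N, S, S'] (scaling is an assumption, see above).
    scores = (q_t @ k_t) / np.sqrt(head_dim)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    # Weighted sum of values, then back to BSNH.
    out = weights @ v_t  # [B, N, S, H']
    return out.transpose(0, 2, 1, 3)


q = np.random.randn(4, 8, 32, 8).astype("float32")
k = np.random.randn(4, 4, 32, 8).astype("float32")
v = np.random.randn(4, 4, 32, 16).astype("float32")
out = attention_ref(q, k, v)
print(out.shape)
```

The query and key/value sequence lengths (8 and 4) and the two head dimensions (8 and 16) differ on purpose, mirroring the `(S, S')` and `(H, H')` pairs used in the test parameters of this PR.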


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] MasterJH5574 commented on pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "MasterJH5574 (via GitHub)" <gi...@apache.org>.
MasterJH5574 commented on PR #14150:
URL: https://github.com/apache/tvm/pull/14150#issuecomment-1452483031

   Just note that 3rdparty/cutlass is changed by this PR. Could you revert this change?




[GitHub] [tvm] tvm-bot commented on pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "tvm-bot (via GitHub)" <gi...@apache.org>.
tvm-bot commented on PR #14150:
URL: https://github.com/apache/tvm/pull/14150#issuecomment-1449109384

   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.

    * cc @quic-sanirudh (see [#10317](https://github.com/apache/tvm/issues/10317) for details)

   Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)




[GitHub] [tvm] masahi commented on pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "masahi (via GitHub)" <gi...@apache.org>.
masahi commented on PR #14150:
URL: https://github.com/apache/tvm/pull/14150#issuecomment-1449589729

   cc @hwu36




[GitHub] [tvm] cyx-6 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "cyx-6 (via GitHub)" <gi...@apache.org>.
cyx-6 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1121918157


##########
tests/python/relax/test_codegen_cutlass.py:
##########
@@ -438,5 +439,79 @@ def test_matmul_transposed_bias_gelu_offload(matmul_x, matmul_y, matmul_bias):
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-3)
 
 
+@pytest.fixture(
+    params=[
+        # B, S, N, H
+        (32, (4, 4), 16, (8, 8)),
+        (4, (8, 4), 32, (8, 8)),
+        (4, (8, 4), 32, (8, 16)),
+    ]
+)
+def attention_size(request):
+    return request.param
+
+
+@pytest.fixture
+def attention_q(attention_size, target_dtype):
+    b, (s, _), n, (h, _) = attention_size
+    return np.random.randn(b, s, n, h).astype(target_dtype)
+
+
+@pytest.fixture
+def attention_k(attention_size, target_dtype):
+    b, (_, s), n, (h, _) = attention_size
+    return np.random.randn(b, s, n, h).astype(target_dtype)
+
+
+@pytest.fixture
+def attention_v(attention_size, target_dtype):
+    b, (_, s), n, (_, h) = attention_size
+    return np.random.randn(b, s, n, h).astype(target_dtype)
+
+
+def get_relax_attention_module(q, k, v):
+    dtype = str(q.dtype)
+
+    from tvm.script.ir_builder import IRBuilder
+    from tvm.script.ir_builder import relax as relax_builder
+
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            q = R.arg("q", R.Tensor(q.shape, dtype))
+            k = R.arg("k", R.Tensor(k.shape, dtype))
+            v = R.arg("v", R.Tensor(v.shape, dtype))
+
+            with R.dataflow() as frame:
+                result = R.emit(R.nn.attention(q, k, v))
+                R.output(result)
+
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
+def get_numpy_attention_ref(q, k, v):

Review Comment:
   Sure! I have refactored the unit test code and added the memoize decorator.





[GitHub] [tvm] vinx13 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "vinx13 (via GitHub)" <gi...@apache.org>.
vinx13 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1120998952


##########
python/tvm/contrib/cutlass/attention_operation.py:
##########
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):

Review Comment:
   This only implements `softmax(Q*K.T) * V` in BSNH layout. We can support other layouts and fused QKV in the future by changing the strides.
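
To illustrate the point about strides, here is a small sketch of how the batch, sequence, and head strides could be derived for different layouts of a contiguous tensor. The helper names are hypothetical and not part of TVM or CUTLASS; the sketch also assumes the head dimension stays innermost with unit stride.

```python
def row_major_strides(shape):
    """Element strides of a contiguous row-major tensor."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def attention_strides(layout, batch, seq, heads, head_dim):
    """Return (strideB, strideM, strideH) for a tensor stored in `layout`.

    `layout` is a permutation of the letters B, S, N, H, e.g. "BSNH".
    """
    extent = {"B": batch, "S": seq, "N": heads, "H": head_dim}
    strides = row_major_strides([extent[axis] for axis in layout])
    by_axis = dict(zip(layout, strides))
    # Assumption: the head_dim axis must be innermost (unit stride).
    if by_axis["H"] != 1:
        raise ValueError("head_dim must be the innermost axis")
    return by_axis["B"], by_axis["S"], by_axis["N"]


print(attention_strides("BSNH", 4, 8, 32, 16))  # (4096, 512, 16)
print(attention_strides("BNSH", 4, 8, 32, 16))  # (4096, 16, 128)
```

For `BSNH` the result matches the values in the quoted template (`strideH = H`, `strideM = H * N`, `strideB = H * N * S`); only the numbers change for `BNSH`, which is why a different layout would not need a different kernel.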





[GitHub] [tvm] vinx13 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "vinx13 (via GitHub)" <gi...@apache.org>.
vinx13 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1120974691


##########
python/tvm/contrib/cutlass/attention_operation.py:
##########
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):
+    """Return CUTLASS host code for fused multi head attention
+    based on a template and the provided attribute map."""
+
+    template = """
+  using T = cutlass::half_t;
+
+  CHECK(${arg0}->ndim == 4); // B, S, N, H
+  CHECK(${arg1}->ndim == 4); // B, S', N, H
+  CHECK(${arg2}->ndim == 4); // B, S', N, H'
+  CHECK(out0->ndim == 4); // B, S, N, H'
+
+  using Attention =
+      AttentionKernel<T,
+                      /*ArchTag=*/${arch},
+                      /*is_aligned=*/true,

Review Comment:
   We also need to dispatch `is_aligned` based on the input shape.
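
One possible way to make that dispatch decision is sketched below. Everything here is an assumption for illustration: the function name, the 128-bit access width, and the rule that both head dimensions must be multiples of the vector width are not taken from the PR or from CUTLASS.

```python
def is_aligned(head_dim, head_dim_value, dtype_bits=16, access_bits=128):
    """Hypothetical check: True when both head dims fill whole vector accesses."""
    elems_per_access = access_bits // dtype_bits  # 8 elements for float16
    return (
        head_dim % elems_per_access == 0
        and head_dim_value % elems_per_access == 0
    )


def render_cpp_bool(flag):
    """Render a Python bool as the C++ literal used in the kernel template."""
    return "true" if flag else "false"


print(render_cpp_bool(is_aligned(8, 16)))   # true
print(render_cpp_bool(is_aligned(8, 12)))   # false
```

The rendered string could then replace the hard-coded `true` in the `/*is_aligned=*/` template argument.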



##########
python/tvm/contrib/cutlass/attention_operation.py:
##########
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):
+    """Return CUTLASS host code for fused multi head attention
+    based on a template and the provided attribute map."""
+
+    template = """
+  using T = cutlass::half_t;
+
+  CHECK(${arg0}->ndim == 4); // B, S, N, H
+  CHECK(${arg1}->ndim == 4); // B, S', N, H
+  CHECK(${arg2}->ndim == 4); // B, S', N, H'
+  CHECK(out0->ndim == 4); // B, S, N, H'
+
+  using Attention =
+      AttentionKernel<T,
+                      /*ArchTag=*/${arch},
+                      /*is_aligned=*/true,
+                      /*queries_per_block=*/${kQueriesPerBlock},
+                      /*keys_per_block=*/${kKeysPerBlock},
+                      /*single_value_iteration=*/${kSingleValueIteration}
+      >;
+
+  typename Attention::Params p;
+
+  p.query_ptr = reinterpret_cast<T *>(${arg0}->data);
+  p.key_ptr = reinterpret_cast<T *>(${arg1}->data);
+  p.value_ptr = reinterpret_cast<T *>(${arg2}->data);
+  p.logsumexp_ptr = nullptr;
+  p.output_ptr = reinterpret_cast<T *>(out0->data);
+  static_assert(!Attention::kNeedsOutputAccumulatorBuffer);
+  p.output_accum_ptr = nullptr;
+
+  p.num_heads = ${num_heads}; // N
+  p.num_batches = ${num_batches}; // B
+  p.head_dim = ${head_dim}; // H
+  p.head_dim_value = ${head_dim_value}; // H'
+  p.num_queries = ${num_queries}; // S
+  p.num_keys = ${num_keys}; // S'
+  p.scale = 1.0f / sqrt(float(${head_dim}));
+  // p.causal = false;
+
+  // stride for N
+  p.q_strideH = p.head_dim; // H
+  p.k_strideH = p.head_dim; // H
+  p.v_strideH = p.head_dim_value; // H'
+  // p.o_strideH = p.head_dim_value; // H'
+
+  // stride for S
+  p.q_strideM = p.q_strideH * p.num_heads; // H * N
+  p.k_strideM = p.k_strideH * p.num_heads; // H * N
+  p.v_strideM = p.v_strideH * p.num_heads; // H' * N
+  p.o_strideM = p.head_dim_value * p.num_heads; // H' * N
+
+  // stride for B
+  p.q_strideB = p.q_strideM * p.num_queries; // H * N * S
+  p.k_strideB = p.k_strideM * p.num_keys; // H * N * S'
+  p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S'
+  // p.o_strideB = p.o_strideM * p.num_queries; // H'* N * S

Review Comment:
   Remove if not needed.





[GitHub] [tvm] tqchen commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1120970995


##########
tests/python/relax/test_codegen_cutlass.py:
##########
@@ -438,5 +439,79 @@ def test_matmul_transposed_bias_gelu_offload(matmul_x, matmul_y, matmul_bias):
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-3)
 
 
+@pytest.fixture(
+    params=[
+        # B, S, N, H
+        (32, (4, 4), 16, (8, 8)),
+        (4, (8, 4), 32, (8, 8)),
+        (4, (8, 4), 32, (8, 16)),
+    ]
+)
+def attention_size(request):
+    return request.param
+
+
+@pytest.fixture
+def attention_q(attention_size, target_dtype):
+    b, (s, _), n, (h, _) = attention_size
+    return np.random.randn(b, s, n, h).astype(target_dtype)
+
+
+@pytest.fixture
+def attention_k(attention_size, target_dtype):
+    b, (_, s), n, (h, _) = attention_size
+    return np.random.randn(b, s, n, h).astype(target_dtype)
+
+
+@pytest.fixture
+def attention_v(attention_size, target_dtype):
+    b, (_, s), n, (_, h) = attention_size
+    return np.random.randn(b, s, n, h).astype(target_dtype)
+
+
+def get_relax_attention_module(q, k, v):
+    dtype = str(q.dtype)
+
+    from tvm.script.ir_builder import IRBuilder
+    from tvm.script.ir_builder import relax as relax_builder
+
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            q = R.arg("q", R.Tensor(q.shape, dtype))
+            k = R.arg("k", R.Tensor(k.shape, dtype))
+            v = R.arg("v", R.Tensor(v.shape, dtype))
+
+            with R.dataflow() as frame:
+                result = R.emit(R.nn.attention(q, k, v))
+                R.output(result)
+
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
+def get_numpy_attention_ref(q, k, v):

Review Comment:
   Use `tvm.contrib.pickle_memoize.memoize` so we don't need to run it multiple times.
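
The idea behind this suggestion is to compute the reference once per input size and reuse it. The sketch below shows that pattern with the standard library's `functools.lru_cache` as a stand-in, so it runs without a TVM install; it is not TVM's memoize helper. The size-tuple signature, the fixed seed, and the call counter are assumptions made for the example.

```python
import functools

import numpy as np

calls = {"count": 0}  # counts real computations, for illustration only


@functools.lru_cache(maxsize=None)
def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, dtype="float32"):
    """Generate inputs and the reference output once per size tuple."""
    calls["count"] += 1
    rng = np.random.default_rng(0)
    q = rng.standard_normal((b, s, n, h)).astype(dtype)
    k = rng.standard_normal((b, s_kv, n, h)).astype(dtype)
    v = rng.standard_normal((b, s_kv, n, h_v)).astype(dtype)
    # Scores in BNSS' layout, softmax over keys, then back to BSNH.
    scores = np.einsum("bsnh,btnh->bnst", q, k) / np.sqrt(h)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    ref = np.einsum("bnst,btnh->bsnh", weights, v)
    return q, k, v, ref


first = get_numpy_attention_ref(4, 8, 4, 32, 8, 16)
second = get_numpy_attention_ref(4, 8, 4, 32, 8, 16)
print(calls["count"])  # 1
```

The cache key is the tuple of sizes rather than the arrays themselves, because NumPy arrays are not hashable. The cached arrays are shared between callers, so tests should treat them as read-only.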





[GitHub] [tvm] vinx13 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "vinx13 (via GitHub)" <gi...@apache.org>.
vinx13 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1124879759


##########
python/tvm/relax/op/nn/nn.py:
##########
@@ -577,3 +577,41 @@ def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr:
       The computed result.
     """
     return _ffi_api.cross_entropy_with_logits(predictions, labels)  # type: ignore
+
+
+def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None) -> Expr:
+    r"""Computes fused multi head attention.
+
+    All input tensors are of 4-D tensors with BSNH layout.
+
+    .. math::
+        FMA(Q, K, V) = \text{Softmax}(Q @ K^T) @ V
+
+    .. note::
+        The input tensor is required to have float16 dtype
+
+    Parameters
+    ----------
+    query: relax.Expr
+        The input query to the operator. The layout of the input query should be
+        (batch_size, seq_len, num_head, head_dim).
+
+    key: relax.Expr
+        The input key to the operator. The layout of the input key should be
+        (batch_size, seq_len_kv, num_head, head_dim).
+
+    value: relax.Expr
+        The input value to the operator. The layout of the input value should be
+        (batch_size, seq_len_kv, num_head, head_dim_v).
+
+    bias: Optional[Expr]
+        The optional attention bias to the operator. The layout of the attention bias should be

Review Comment:
   It's more common to use BNSS', and it allows broadcastable shapes. Let's say:
   ```
   The shape of the attention bias should be able to broadcast to (batch_size, num_head, seq_len, seq_len_kv)
   ```
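
A shape check matching the wording proposed above could look like the following sketch. The helper name `check_bias_shape` is hypothetical; it relies on `np.broadcast_shapes`, which is available in NumPy 1.20 and later.

```python
import numpy as np


def check_bias_shape(bias_shape, batch_size, num_head, seq_len, seq_len_kv):
    """Return True if `bias_shape` broadcasts to (B, N, S, S')."""
    target = (batch_size, num_head, seq_len, seq_len_kv)
    try:
        # The bias may only be expanded, never the target.
        return np.broadcast_shapes(bias_shape, target) == target
    except ValueError:
        # Incompatible extents, e.g. a bias still laid out as BSNS'.
        return False


print(check_bias_shape((4, 32, 8, 4), 4, 32, 8, 4))  # True
print(check_bias_shape((1, 1, 8, 4), 4, 32, 8, 4))   # True
print(check_bias_shape((4, 8, 32, 4), 4, 32, 8, 4))  # False
```

Comparing the broadcast result against the target rejects biases with more dimensions or larger extents than the attention scores.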





[GitHub] [tvm] hwu36 commented on pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "hwu36 (via GitHub)" <gi...@apache.org>.
hwu36 commented on PR #14150:
URL: https://github.com/apache/tvm/pull/14150#issuecomment-1450279539

   @danthe3rd @mnicely




[GitHub] [tvm] cyx-6 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "cyx-6 (via GitHub)" <gi...@apache.org>.
cyx-6 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1123862519


##########
python/tvm/contrib/cutlass/gen_tensor_op.py:
##########
@@ -652,4 +653,25 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_
         code = instantiate_conv2d_template(attrs, func_args)
         return CodegenResult(code, headers)
 
+    elif "attention" in func_name:
+        headers.append("kernel_forward.h")
+        attrs["num_batches"] = str(int(annotations["num_batches"]))
+        attrs["num_queries"] = str(int(annotations["num_queries"]))
+        attrs["num_keys"] = str(int(annotations["num_keys"]))
+        attrs["num_heads"] = str(int(annotations["num_heads"]))
+        attrs["head_dim"] = str(int(annotations["head_dim"]))

Review Comment:
   `substitute_template` now accepts `int`, `IntImm`, and `bool` values, and all uses of `str(int(...))` have been removed, including those in conv2d and matmul.
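
As a rough illustration of what such a substitution helper might do, the sketch below converts values to C++ source text while filling `${name}` placeholders. It is not the TVM implementation: `substitute_template_sketch` and `FakeIntImm` are hypothetical names, and treating an `IntImm`-like value as anything convertible with `int()` is an assumption.

```python
import re


def substitute_template_sketch(template, values):
    """Replace ${name} placeholders, accepting str, bool, and int-like values."""

    def to_text(value):
        if isinstance(value, str):
            return value
        # Check bool before int: bool is a subclass of int in Python.
        if isinstance(value, bool):
            return "true" if value else "false"
        return str(int(value))  # plain int, or any object defining __int__

    def replace(match):
        return to_text(values[match.group(1)])

    return re.sub(r"\$\{(\w+)\}", replace, template)


class FakeIntImm:
    """Hypothetical stand-in for an IntImm-like value."""

    def __init__(self, value):
        self.value = value

    def __int__(self):
        return self.value


code = substitute_template_sketch(
    "p.num_heads = ${num_heads}; /*is_aligned=*/${is_aligned}; p.head_dim = ${head_dim};",
    {"num_heads": 32, "is_aligned": True, "head_dim": FakeIntImm(8)},
)
print(code)  # p.num_heads = 32; /*is_aligned=*/true; p.head_dim = 8;
```

Doing the conversion inside the helper is what lets call sites drop the `str(int(...))` wrappers.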





[GitHub] [tvm] spectrometerHBH commented on pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "spectrometerHBH (via GitHub)" <gi...@apache.org>.
spectrometerHBH commented on PR #14150:
URL: https://github.com/apache/tvm/pull/14150#issuecomment-1449434066

   Should we consider adding it to the FX translator, since PyTorch has an MHA op?
   https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html




[GitHub] [tvm] slyubomirsky commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "slyubomirsky (via GitHub)" <gi...@apache.org>.
slyubomirsky commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1122219014


##########
python/tvm/contrib/cutlass/attention_operation.py:
##########
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):

Review Comment:
   Is that the difference between Bahdanau and Luong attention? (I don't have a ton of experience with attention mechanisms)





[GitHub] [tvm] cyx-6 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "cyx-6 (via GitHub)" <gi...@apache.org>.
cyx-6 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1121920405


##########
include/tvm/relax/attrs/nn.h:
##########
@@ -184,6 +184,15 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
   }
 };  // struct DropoutAttrs
 
+/*! \brief Attributes used in fuse multi head attention operator */
+struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
+  DataType out_dtype;

Review Comment:
   Yes, `AttentionAttrs` is indeed redundant. I have removed it.





[GitHub] [tvm] cyx-6 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "cyx-6 (via GitHub)" <gi...@apache.org>.
cyx-6 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1124041621


##########
python/tvm/contrib/cutlass/attention_operation.py:
##########
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):
+    """Return CUTLASS host code for fused multi head attention
+    based on a template and the provided attribute map."""
+
+    template = """
+  using T = cutlass::half_t;
+
+  CHECK(${arg0}->ndim == 4); // B, S, N, H
+  CHECK(${arg1}->ndim == 4); // B, S', N, H
+  CHECK(${arg2}->ndim == 4); // B, S', N, H'
+  CHECK(out0->ndim == 4); // B, S, N, H'
+
+  using Attention =
+      AttentionKernel<T,
+                      /*ArchTag=*/${arch},
+                      /*is_aligned=*/true,
+                      /*queries_per_block=*/${kQueriesPerBlock},
+                      /*keys_per_block=*/${kKeysPerBlock},
+                      /*single_value_iteration=*/${kSingleValueIteration}
+      >;

Review Comment:
   Thanks for your valuable advice! I have set these two fields to `false` by default. :)





[GitHub] [tvm] cyx-6 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "cyx-6 (via GitHub)" <gi...@apache.org>.
cyx-6 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1121923108


##########
python/tvm/contrib/cutlass/attention_operation.py:
##########
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):
+    """Return CUTLASS host code for fused multi head attention
+    based on a template and the provided attribute map."""
+
+    template = """
+  using T = cutlass::half_t;
+
+  CHECK(${arg0}->ndim == 4); // B, S, N, H
+  CHECK(${arg1}->ndim == 4); // B, S', N, H
+  CHECK(${arg2}->ndim == 4); // B, S', N, H'
+  CHECK(out0->ndim == 4); // B, S, N, H'
+
+  using Attention =
+      AttentionKernel<T,
+                      /*ArchTag=*/${arch},
+                      /*is_aligned=*/true,
+                      /*queries_per_block=*/${kQueriesPerBlock},
+                      /*keys_per_block=*/${kKeysPerBlock},
+                      /*single_value_iteration=*/${kSingleValueIteration}
+      >;
+
+  typename Attention::Params p;
+
+  p.query_ptr = reinterpret_cast<T *>(${arg0}->data);
+  p.key_ptr = reinterpret_cast<T *>(${arg1}->data);
+  p.value_ptr = reinterpret_cast<T *>(${arg2}->data);
+  p.logsumexp_ptr = nullptr;
+  p.output_ptr = reinterpret_cast<T *>(out0->data);
+  static_assert(!Attention::kNeedsOutputAccumulatorBuffer);
+  p.output_accum_ptr = nullptr;
+
+  p.num_heads = ${num_heads}; // N
+  p.num_batches = ${num_batches}; // B
+  p.head_dim = ${head_dim}; // H
+  p.head_dim_value = ${head_dim_value}; // H'
+  p.num_queries = ${num_queries}; // S
+  p.num_keys = ${num_keys}; // S'
+  p.scale = 1.0f / sqrt(float(${head_dim}));
+  // p.causal = false;
+
+  // stride for N
+  p.q_strideH = p.head_dim; // H
+  p.k_strideH = p.head_dim; // H
+  p.v_strideH = p.head_dim_value; // H'
+  // p.o_strideH = p.head_dim_value; // H'
+
+  // stride for S
+  p.q_strideM = p.q_strideH * p.num_heads; // H * N
+  p.k_strideM = p.k_strideH * p.num_heads; // H * N
+  p.v_strideM = p.v_strideH * p.num_heads; // H' * N
+  p.o_strideM = p.head_dim_value * p.num_heads; // H' * N
+
+  // stride for B
+  p.q_strideB = p.q_strideM * p.num_queries; // H * N * S
+  p.k_strideB = p.k_strideM * p.num_keys; // H * N * S'
+  p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S'
+  // p.o_strideB = p.o_strideM * p.num_queries; // H'* N * S

Review Comment:
   Sure. The commented-out code has been removed.





[GitHub] [tvm] cyx-6 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "cyx-6 (via GitHub)" <gi...@apache.org>.
cyx-6 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1124896803


##########
python/tvm/relax/op/nn/nn.py:
##########
@@ -577,3 +577,41 @@ def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr:
       The computed result.
     """
     return _ffi_api.cross_entropy_with_logits(predictions, labels)  # type: ignore
+
+
+def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None) -> Expr:
+    r"""Computes fused multi head attention.
+
+    All input tensors are of 4-D tensors with BSNH layout.
+
+    .. math::
+        FMA(Q, K, V) = \text{Softmax}(Q @ K^T) @ V
+
+    .. note::
+        The input tensor is required to have float16 dtype
+
+    Parameters
+    ----------
+    query: relax.Expr
+        The input query to the operator. The layout of the input query should be
+        (batch_size, seq_len, num_head, head_dim).
+
+    key: relax.Expr
+        The input key to the operator. The layout of the input key should be
+        (batch_size, seq_len_kv, num_head, head_dim).
+
+    value: relax.Expr
+        The input value to the operator. The layout of the input value should be
+        (batch_size, seq_len_kv, num_head, head_dim_v).
+
+    bias: Optional[Expr]
+        The optional attention bias to the operator. The layout of the attention bias should be

Review Comment:
   Sure! I have updated the layout of the bias to `BNSS'`.





[GitHub] [tvm] vinx13 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "vinx13 (via GitHub)" <gi...@apache.org>.
vinx13 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1124919305


##########
python/tvm/contrib/cutlass/attention_operation.py:
##########
@@ -24,12 +24,12 @@ def instantiate_attention_template(attrs, func_args):
     based on a template and the provided attribute map."""
 
     bias_template = """
-  CHECK(${arg3}->ndim == 4); // B, S, N, S'
+  CHECK(${arg3}->ndim == 4); // B, N, S, S'
 
   p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
-  p.bias_strideH = p.num_keys; // S'
-  p.bias_strideM = p.bias_strideH * p.num_heads; // S' * N
-  p.bias_strideB = p.bias_strideM * p.num_queries; // S' * N * S
+  p.bias_strideM = p.num_keys; // S'
+  p.bias_strideH = p.bias_strideM * p.num_queries; // S' * S

Review Comment:
   We should also handle the broadcast case.
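For reference, the strides in the hunk above are what a contiguous `B, N, S, S'` tensor gives you when counted in elements. A minimal NumPy sketch, assuming a dense bias; the helper name `bias_strides` is hypothetical, and the batch stride is inferred from the pattern rather than taken from the visible part of the diff:

```python
import numpy as np


def bias_strides(num_heads, num_queries, num_keys):
    """Element strides of a contiguous bias tensor with layout B, N, S, S'."""
    stride_m = num_keys                # next query row: S'
    stride_h = stride_m * num_queries  # next head: S' * S
    stride_b = stride_h * num_heads    # next batch: S' * S * N (assumed, not in the hunk)
    return stride_b, stride_h, stride_m


# Cross-check against NumPy's own strides, converted from bytes to elements.
bias = np.zeros((2, 3, 4, 5), dtype="float16")  # B=2, N=3, S=4, S'=5
numpy_strides = tuple(s // bias.itemsize for s in bias.strides[:3])
assert bias_strides(3, 4, 5) == numpy_strides
```

A broadcast dimension would presumably need a different stride for that axis, which is the case this comment asks about.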





[GitHub] [tvm] cyx-6 commented on pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "cyx-6 (via GitHub)" <gi...@apache.org>.
cyx-6 commented on PR #14150:
URL: https://github.com/apache/tvm/pull/14150#issuecomment-1454330960

   @masahi Yes, as long as it is a Relax op, we should have it fully tested. I will add the necessary tests in a follow-up PR.




[GitHub] [tvm] tqchen commented on pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen commented on PR #14150:
URL: https://github.com/apache/tvm/pull/14150#issuecomment-1449133074

   Thanks @cyx-6. Can we also add a note about the existing attention operator layouts in libraries (HF and PT)? That would help us write down the rationale for the operator.




[GitHub] [tvm] danthe3rd commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "danthe3rd (via GitHub)" <gi...@apache.org>.
danthe3rd commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1121993334


##########
python/tvm/contrib/cutlass/attention_operation.py:
##########
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):
+    """Return CUTLASS host code for fused multi head attention
+    based on a template and the provided attribute map."""
+
+    template = """
+  using T = cutlass::half_t;
+
+  CHECK(${arg0}->ndim == 4); // B, S, N, H
+  CHECK(${arg1}->ndim == 4); // B, S', N, H
+  CHECK(${arg2}->ndim == 4); // B, S', N, H'
+  CHECK(out0->ndim == 4); // B, S, N, H'
+
+  using Attention =
+      AttentionKernel<T,
+                      /*ArchTag=*/${arch},
+                      /*is_aligned=*/true,
+                      /*queries_per_block=*/${kQueriesPerBlock},
+                      /*keys_per_block=*/${kKeysPerBlock},
+                      /*single_value_iteration=*/${kSingleValueIteration}
+      >;

Review Comment:
   Note: if you don't need them, you should set `kSupportsDropout` and `kSupportsBias` to false for better performance.
   https://github.com/NVIDIA/cutlass/blob/f396cdd15ccf873d5c92ae73b73ed680c77f4400/examples/41_fused_multi_head_attention/kernel_forward.h#L107-L108





[GitHub] [tvm] vinx13 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "vinx13 (via GitHub)" <gi...@apache.org>.
vinx13 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1123792677


##########
python/tvm/contrib/cutlass/attention_operation.py:
##########
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):

Review Comment:
   It's mostly different variants of mask/bias (e.g. 2D, 3D, causal) applied to Q*K.T.





[GitHub] [tvm] vinx13 merged pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "vinx13 (via GitHub)" <gi...@apache.org>.
vinx13 merged PR #14150:
URL: https://github.com/apache/tvm/pull/14150




[GitHub] [tvm] cyx-6 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "cyx-6 (via GitHub)" <gi...@apache.org>.
cyx-6 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1125160077


##########
python/tvm/contrib/cutlass/attention_operation.py:
##########
@@ -24,12 +24,12 @@ def instantiate_attention_template(attrs, func_args):
     based on a template and the provided attribute map."""
 
     bias_template = """
-  CHECK(${arg3}->ndim == 4); // B, S, N, S'
+  CHECK(${arg3}->ndim == 4); // B, N, S, S'
 
   p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
-  p.bias_strideH = p.num_keys; // S'
-  p.bias_strideM = p.bias_strideH * p.num_heads; // S' * N
-  p.bias_strideB = p.bias_strideM * p.num_queries; // S' * N * S
+  p.bias_strideM = p.num_keys; // S'
+  p.bias_strideH = p.bias_strideM * p.num_queries; // S' * S

Review Comment:
   Broadcasting for the bias has been added. We now accept a bias with layout `BS'`, `BSS'` or `BNSS'`.
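To illustrate the three accepted layouts, here is a minimal NumPy sketch of how each could be expanded to the full `B, N, S, S'` shape. It assumes the missing axes are broadcast (over heads for `BSS'`, over heads and queries for `BS'`); the helper name `expand_bias` is hypothetical and is not part of the PR:

```python
import numpy as np


def expand_bias(bias, num_heads, num_queries):
    """Broadcast a bias with layout BS', BSS' or BNSS' to B, N, S, S'."""
    batch, num_keys = bias.shape[0], bias.shape[-1]
    if bias.ndim == 2:    # B, S'    -> B, 1, 1, S'
        bias = bias[:, None, None, :]
    elif bias.ndim == 3:  # B, S, S' -> B, 1, S, S'
        bias = bias[:, None, :, :]
    elif bias.ndim != 4:  # B, N, S, S' is used as-is
        raise ValueError(f"unsupported bias rank: {bias.ndim}")
    return np.broadcast_to(bias, (batch, num_heads, num_queries, num_keys))
```

`np.broadcast_to` returns a read-only view, so no data is copied for the broadcast axes.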





[GitHub] [tvm] cyx-6 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "cyx-6 (via GitHub)" <gi...@apache.org>.
cyx-6 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1124035776


##########
src/relax/op/nn/attention.h:
##########
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file attention.h
+ * \brief The functions to make Relax attention operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_NN_ATTENTION_H_
+#define TVM_RELAX_OP_NN_ATTENTION_H_
+
+#include <tvm/relax/attrs/nn.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief fused multi head attention */
+Expr attention(Expr query, Expr key, Expr value, DataType out_dtype);

Review Comment:
   The bias for `Q @ K + bias` has been added. The kernel currently does not support the bias for `softmax(Q @ K) @ V + bias`. We may find a solution in the future if needed.
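For clarity, the supported placement of the bias (inside the softmax, not after the second matmul) can be written as a NumPy reference. This is only a sketch under stated assumptions: inputs use the `BSNH` layout from the PR description, the bias is already broadcastable to `B, N, S, S'`, and no `1/sqrt(head_dim)` scaling is applied, following the docstring formula as quoted. The function name `attention_reference` is hypothetical:

```python
import numpy as np


def attention_reference(q, k, v, bias=None):
    """q: B,S,N,H; k: B,S',N,H; v: B,S',N,H'; bias broadcastable to B,N,S,S'."""
    q = q.transpose(0, 2, 1, 3)  # B, N, S, H
    k = k.transpose(0, 2, 3, 1)  # B, N, H, S'
    v = v.transpose(0, 2, 1, 3)  # B, N, S', H'
    scores = q @ k               # B, N, S, S'
    if bias is not None:
        scores = scores + bias   # bias is added before the softmax
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return (probs @ v).transpose(0, 2, 1, 3)  # back to B, S, N, H'
```

A large negative bias entry acts as a mask for the corresponding key, which is how the mask variants mentioned elsewhere in this thread are usually expressed.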





[GitHub] [tvm] masahi commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "masahi (via GitHub)" <gi...@apache.org>.
masahi commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1123866733


##########
python/tvm/contrib/cutlass/gen_tensor_op.py:
##########
@@ -652,4 +653,25 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_
         code = instantiate_conv2d_template(attrs, func_args)
         return CodegenResult(code, headers)
 
+    elif "attention" in func_name:
+        headers.append("kernel_forward.h")
+        attrs["num_batches"] = str(int(annotations["num_batches"]))
+        attrs["num_queries"] = str(int(annotations["num_queries"]))
+        attrs["num_keys"] = str(int(annotations["num_keys"]))
+        attrs["num_heads"] = str(int(annotations["num_heads"]))
+        attrs["head_dim"] = str(int(annotations["head_dim"]))

Review Comment:
   Thanks!





[GitHub] [tvm] masahi commented on pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "masahi (via GitHub)" <gi...@apache.org>.
masahi commented on PR #14150:
URL: https://github.com/apache/tvm/pull/14150#issuecomment-1454327790

   The test was added only for cutlass BYOC. Shouldn't we have a standalone test for Relax op + legalize?




[GitHub] [tvm] masahi commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "masahi (via GitHub)" <gi...@apache.org>.
masahi commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1123793756


##########
python/tvm/contrib/cutlass/gen_tensor_op.py:
##########
@@ -652,4 +653,25 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_
         code = instantiate_conv2d_template(attrs, func_args)
         return CodegenResult(code, headers)
 
+    elif "attention" in func_name:
+        headers.append("kernel_forward.h")
+        attrs["num_batches"] = str(int(annotations["num_batches"]))
+        attrs["num_queries"] = str(int(annotations["num_queries"]))
+        attrs["num_keys"] = str(int(annotations["num_keys"]))
+        attrs["num_heads"] = str(int(annotations["num_heads"]))
+        attrs["head_dim"] = str(int(annotations["head_dim"]))

Review Comment:
   Feel free to update `substitute_template(...)` to get rid of the `str(int(...))` stuff. Otherwise, I can do it later.
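One way this could look is sketched below. It is not the actual `substitute_template` from the CUTLASS contrib code; the name `substitute_template_sketch` and the choice to render Python booleans as C++ literals are assumptions made for illustration, and it only considers plain Python ints, bools and strings as values:

```python
import re


def substitute_template_sketch(template, values):
    """Replace ${key} placeholders, converting non-string values on the fly."""

    def render(value):
        if isinstance(value, bool):
            return "true" if value else "false"  # C++ boolean literal
        return str(value)

    def lookup(match):
        key = match.group(1)
        # Leave unknown placeholders untouched so a later pass can fill them.
        return render(values[key]) if key in values else match.group(0)

    return re.sub(r"\$\{(\w+)\}", lookup, template)
```

With the conversion done in one place, call sites could assign `attrs["num_heads"] = annotations["num_heads"]` directly, provided the annotation values convert cleanly through `str`.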





[GitHub] [tvm] sunggg commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "sunggg (via GitHub)" <gi...@apache.org>.
sunggg commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1120986730


##########
include/tvm/relax/attrs/nn.h:
##########
@@ -184,6 +184,15 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
   }
 };  // struct DropoutAttrs
 
+/*! \brief Attributes used in fuse multi head attention operator */
+struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
+  DataType out_dtype;

Review Comment:
   I'm wondering why we need this as an attribute. Can we use struct_info instead?



##########
python/tvm/contrib/cutlass/attention_operation.py:
##########
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):

Review Comment:
   Correct me if I'm wrong, but I heard there are various ways to implement attention.
   Is there a specific (maybe standard) form of attention we target?
   Or can we also support a set of variants?





[GitHub] [tvm] vinx13 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "vinx13 (via GitHub)" <gi...@apache.org>.
vinx13 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1122320266


##########
src/relax/op/nn/attention.h:
##########
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file attention.h
+ * \brief The functions to make Relax attention operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_NN_ATTENTION_H_
+#define TVM_RELAX_OP_NN_ATTENTION_H_
+
+#include <tvm/relax/attrs/nn.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief fused multi head attention */
+Expr attention(Expr query, Expr key, Expr value, DataType out_dtype);

Review Comment:
   Yeah, bias can be useful in some cases. @cyx-6 can we add an optional input `bias` for the attention op?





[GitHub] [tvm] cyx-6 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "cyx-6 (via GitHub)" <gi...@apache.org>.
cyx-6 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1122321051


##########
src/relax/op/nn/attention.h:
##########
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file attention.h
+ * \brief The functions to make Relax attention operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_NN_ATTENTION_H_
+#define TVM_RELAX_OP_NN_ATTENTION_H_
+
+#include <tvm/relax/attrs/nn.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief fused multi head attention */
+Expr attention(Expr query, Expr key, Expr value, DataType out_dtype);

Review Comment:
   yes, I am adding it





[GitHub] [tvm] masahi commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "masahi (via GitHub)" <gi...@apache.org>.
masahi commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1120971664


##########
src/relax/op/nn/attention.h:
##########
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file attention.h
+ * \brief The functions to make Relax attention operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_NN_ATTENTION_H_
+#define TVM_RELAX_OP_NN_ATTENTION_H_
+
+#include <tvm/relax/attrs/nn.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief fused multi head attention */
+Expr attention(Expr query, Expr key, Expr value, DataType out_dtype);

Review Comment:
   what about bias?





[GitHub] [tvm] cyx-6 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "cyx-6 (via GitHub)" <gi...@apache.org>.
cyx-6 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1123860265


##########
python/tvm/contrib/cutlass/attention_operation.py:
##########
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):
+    """Return CUTLASS host code for fused multi head attention
+    based on a template and the provided attribute map."""
+
+    template = """
+  using T = cutlass::half_t;
+
+  CHECK(${arg0}->ndim == 4); // B, S, N, H
+  CHECK(${arg1}->ndim == 4); // B, S', N, H
+  CHECK(${arg2}->ndim == 4); // B, S', N, H'
+  CHECK(out0->ndim == 4); // B, S, N, H'
+
+  using Attention =
+      AttentionKernel<T,
+                      /*ArchTag=*/${arch},
+                      /*is_aligned=*/true,

Review Comment:
   The dispatch for alignment has been added.
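The dispatch rule, as it appears in the `gen_tensor_op.py` hunk quoted elsewhere in this thread, can be read as a standalone function. The sketch below mirrors that condition; the function name `select_is_aligned` and the default of 16 bits (float16) are assumptions made for illustration:

```python
def select_is_aligned(head_dim, head_dim_value, dtype_bits=16):
    """Return the kIsAligned flag, or raise if no kernel variant applies."""

    def row_bytes(dim):
        return dtype_bits * dim // 8

    # Aligned path: both head dimensions span a multiple of 16 bytes.
    if row_bytes(head_dim) % 16 == 0 and row_bytes(head_dim_value) % 16 == 0:
        return True
    # Unaligned path: both head dimensions are multiples of 4 elements.
    if head_dim % 4 == 0 and head_dim_value % 4 == 0:
        return False
    raise NotImplementedError(
        f"unsupported head dims: {head_dim}, {head_dim_value}"
    )
```

For float16, a head dimension of 64 is 128 bytes and takes the aligned path, while a head dimension of 4 is 8 bytes and falls back to the unaligned one.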





[GitHub] [tvm] cyx-6 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "cyx-6 (via GitHub)" <gi...@apache.org>.
cyx-6 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1123870949


##########
python/tvm/contrib/cutlass/gen_tensor_op.py:
##########
@@ -653,4 +655,34 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_
         code = instantiate_conv2d_template(attrs, func_args)
         return CodegenResult(code, headers)
 
+    elif "attention" in func_name:
+        headers.append("kernel_forward.h")
+        data_type = dtype_map[annotations["arg0_dtype"]]
+        attrs["data_type"] = DataTypeTag[data_type]
+        attrs["num_batches"] = b = annotations["num_batches"]
+        attrs["num_queries"] = s = annotations["num_queries"]
+        attrs["num_keys"] = annotations["num_keys"]
+        attrs["num_heads"] = n = annotations["num_heads"]
+        attrs["head_dim"] = h = annotations["head_dim"]
+        attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
+        data_type_size = DataTypeSize[data_type]
+        if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
+            attrs["kIsAligned"] = True
+        elif (h % 4 == 0) and (h_v % 4 == 0):
+            attrs["kIsAligned"] = False
+        else:
+            raise NotImplementedError()
+        if h_v > 64:
+            attrs["kQueriesPerBlock"] = "32"
+            attrs["kKeysPerBlock"] = "128"
+            attrs["kSingleValueIteration"] = h_v <= 128
+        else:
+            attrs["kQueriesPerBlock"] = "64"
+            attrs["kKeysPerBlock"] = "64"

Review Comment:
   Oh, I missed that. Fixed already.





[GitHub] [tvm] masahi commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "masahi (via GitHub)" <gi...@apache.org>.
masahi commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1123868036


##########
python/tvm/contrib/cutlass/gen_tensor_op.py:
##########
@@ -653,4 +655,34 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_
         code = instantiate_conv2d_template(attrs, func_args)
         return CodegenResult(code, headers)
 
+    elif "attention" in func_name:
+        headers.append("kernel_forward.h")
+        data_type = dtype_map[annotations["arg0_dtype"]]
+        attrs["data_type"] = DataTypeTag[data_type]
+        attrs["num_batches"] = b = annotations["num_batches"]
+        attrs["num_queries"] = s = annotations["num_queries"]
+        attrs["num_keys"] = annotations["num_keys"]
+        attrs["num_heads"] = n = annotations["num_heads"]
+        attrs["head_dim"] = h = annotations["head_dim"]
+        attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
+        data_type_size = DataTypeSize[data_type]
+        if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
+            attrs["kIsAligned"] = True
+        elif (h % 4 == 0) and (h_v % 4 == 0):
+            attrs["kIsAligned"] = False
+        else:
+            raise NotImplementedError()
+        if h_v > 64:
+            attrs["kQueriesPerBlock"] = "32"
+            attrs["kKeysPerBlock"] = "128"
+            attrs["kSingleValueIteration"] = h_v <= 128
+        else:
+            attrs["kQueriesPerBlock"] = "64"
+            attrs["kKeysPerBlock"] = "64"

Review Comment:
   We can also remove the quotation marks now.





[GitHub] [tvm] vinx13 commented on a diff in pull request #14150: [Unity][OP] Add an operator for fused multi head attention

Posted by "vinx13 (via GitHub)" <gi...@apache.org>.
vinx13 commented on code in PR #14150:
URL: https://github.com/apache/tvm/pull/14150#discussion_r1123792677


##########
python/tvm/contrib/cutlass/attention_operation.py:
##########
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS attention kernels."""
+from .library import *
+
+
+def instantiate_attention_template(attrs, func_args):

Review Comment:
   It's mostly different variants of mask/bias applied to Q*K.T.


