You are viewing a plain-text version of this content. The canonical link for it (originally a hyperlink on "here") was not preserved in this copy.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/07/21 14:19:40 UTC

[tvm] branch unity updated: [Unity][CUTLASS] Support `out_dtype = "float32"` for FasterTransformer kernel (#15377)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e38eb9003a [Unity][CUTLASS] Support `out_dtype = "float32"` for FasterTransformer kernel (#15377)
e38eb9003a is described below

commit e38eb9003a492abf1a2e8a3a079c060ea9329951
Author: masahi <ma...@gmail.com>
AuthorDate: Fri Jul 21 23:19:33 2023 +0900

    [Unity][CUTLASS] Support `out_dtype = "float32"` for FasterTransformer kernel (#15377)
    
    Support out_dtype = "float32" for FasterTransformer kernel
---
 python/tvm/relax/backend/contrib/cutlass.py |  8 +++++-
 tests/python/relax/test_codegen_cutlass.py  | 41 ++++++++++++++++++++++++++---
 2 files changed, 45 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index cdce3225b5..fef6a1ec03 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -237,7 +237,7 @@ def _check_decode_matmul(ctx):
     if not _check_residual(root, ctx):
         return False
 
-    # out_dtype = "float32" not supported.
+    # out_dtype = "float32" not supported unless matmul is followed by cast to fp16.
     if root.struct_info.dtype == "float32":
         return False
 
@@ -299,6 +299,9 @@ def decode_matmul_patterns():
         )
         matmul = is_op("relax.matmul")(x, w)
 
+        if "cast" in name:
+            matmul = is_op("relax.astype")(matmul)
+
         annotations = {
             "root": matmul,
             "lhs": x,
@@ -321,7 +324,10 @@ def decode_matmul_patterns():
     return [
         _decode_matmul_pattern("cutlass.decode_matmul"),
         _decode_matmul_pattern("cutlass.decode_matmul_bias"),
+        _decode_matmul_pattern("cutlass.decode_matmul_cast"),
+        _decode_matmul_pattern("cutlass.decode_matmul_cast_bias"),
         _decode_matmul_pattern("cutlass.decode_matmul_bias_gelu"),
+        _decode_matmul_pattern("cutlass.decode_matmul_cast_bias_gelu"),
     ]
 
 
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 02f15ad3d7..30286a597a 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -1412,6 +1412,38 @@ def test_fp16A_int4B_gemm():
                 R.output(lv2_1)
             return lv2_1
 
+        @R.function
+        def main_cast_bias(
+            x: R.Tensor((64, 64), dtype="float16"),
+            y: R.Tensor((128, 64), dtype="float16"),
+            bias: R.Tensor((1, 128), dtype="float16"),
+        ) -> R.Tensor((64, 128), dtype="float16"):
+            R.func_attr({"num_input": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.encode,
+                    (y,),
+                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((128,), dtype="float16")],
+                )
+                lv1 = lv[0]
+                lv2 = R.call_pure_packed(
+                    "cutlass.ft_preprocess_weight",
+                    lv1,
+                    80,
+                    True,
+                    sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
+                )
+                lv3: R.Tensor((128,), dtype="float16") = lv[1]
+                lv6 = R.call_tir(
+                    cls.decode, (lv2, lv3), out_sinfo=R.Tensor((64, 128), dtype="float16")
+                )
+                lv1_1: R.Tensor((64, 128), dtype="float32") = R.matmul(x, lv6, out_dtype="float32")
+                cast: R.Tensor((64, 128), dtype="float16") = R.astype(lv1_1, dtype="float16")
+                lv2_1: R.Tensor((64, 128), dtype="float16") = R.add(cast, bias)
+                R.output(lv2_1)
+            return lv2_1
+
         @R.function
         def main_residual(
             x: R.Tensor((64, 64), dtype="float16"),
@@ -1452,10 +1484,11 @@ def test_fp16A_int4B_gemm():
     func_names = [name.name_hint for (name, _) in mod.functions.items()]
     assert "fused_decode_relax_matmul_relax_add_cutlass" in func_names
     assert "fused_decode_relax_matmul_relax_add_relax_add_cutlass" in func_names
+    assert "fused_decode_relax_matmul_relax_astype_relax_add_cutlass" in func_names
 
     mod = relax.transform.RunCodegen(
         {"cutlass": {"sm": 80, "find_first_valid": False}},
-        entry_functions=["main_bias", "main_residual"],
+        entry_functions=["main_bias", "main_residual", "main_cast_bias"],
     )(mod)
 
     x = np.random.randn(*x_shape).astype("float16")
@@ -1483,13 +1516,15 @@ def test_fp16A_int4B_gemm():
     residual_nd = tvm.nd.array(residual, dev)
     params = (packed_weight.copyto(dev), scales.copyto(dev), bias_trans.copyto(dev))
 
-    for with_residual in [False, True]:
+    for f_name in ["main_bias", "main_cast_bias", "main_residual"]:
+        with_residual = "residual" in f_name
+
         if with_residual:
             inp = [x_nd, residual_nd, params]
         else:
             inp = [x_nd, params]
 
-        out = vm["main_residual" if with_residual else "main_bias"](*inp).numpy()
+        out = vm[f_name](*inp).numpy()
 
         ref = np.dot(x, y.transpose()) + bias