Posted to commits@tvm.apache.org by sy...@apache.org on 2023/07/21 14:19:40 UTC
[tvm] branch unity updated: [Unity][CUTLASS] Support `out_dtype = "float32"` for FasterTransformer kernel (#15377)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e38eb9003a [Unity][CUTLASS] Support `out_dtype = "float32"` for FasterTransformer kernel (#15377)
e38eb9003a is described below
commit e38eb9003a492abf1a2e8a3a079c060ea9329951
Author: masahi <ma...@gmail.com>
AuthorDate: Fri Jul 21 23:19:33 2023 +0900
[Unity][CUTLASS] Support `out_dtype = "float32"` for FasterTransformer kernel (#15377)
Support out_dtype = "float32" for FasterTransformer kernel
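
For illustration, the graph shape the new "cast" patterns target is a matmul that
accumulates in float32 and is immediately cast back to float16 before any epilogue.
A minimal sketch (assuming a TVM build with Relax; names and shapes here are
illustrative, and the int4 weight-decode step that precedes the matmul in the real
patterns is omitted):

    from tvm.script import ir as I, relax as R

    @I.ir_module
    class CastExample:
        @R.function
        def main(
            x: R.Tensor((64, 64), dtype="float16"),
            w: R.Tensor((64, 128), dtype="float16"),
            bias: R.Tensor((1, 128), dtype="float16"),
        ) -> R.Tensor((64, 128), dtype="float16"):
            with R.dataflow():
                acc = R.matmul(x, w, out_dtype="float32")  # fp32 accumulation
                cast = R.astype(acc, dtype="float16")      # the newly matched cast
                out = R.add(cast, bias)                    # optional bias epilogue
                R.output(out)
            return out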
---
 python/tvm/relax/backend/contrib/cutlass.py |  8 +++++-
 tests/python/relax/test_codegen_cutlass.py  | 41 ++++++++++++++++++++++++++---
 2 files changed, 45 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index cdce3225b5..fef6a1ec03 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -237,7 +237,7 @@ def _check_decode_matmul(ctx):
     if not _check_residual(root, ctx):
         return False
 
-    # out_dtype = "float32" not supported.
+    # out_dtype = "float32" not supported unless matmul is followed by cast to fp16.
     if root.struct_info.dtype == "float32":
         return False
 
@@ -299,6 +299,9 @@ def decode_matmul_patterns():
         )
         matmul = is_op("relax.matmul")(x, w)
 
+        if "cast" in name:
+            matmul = is_op("relax.astype")(matmul)
+
         annotations = {
             "root": matmul,
             "lhs": x,
@@ -321,7 +324,10 @@ def decode_matmul_patterns():
     return [
         _decode_matmul_pattern("cutlass.decode_matmul"),
         _decode_matmul_pattern("cutlass.decode_matmul_bias"),
+        _decode_matmul_pattern("cutlass.decode_matmul_cast"),
+        _decode_matmul_pattern("cutlass.decode_matmul_cast_bias"),
         _decode_matmul_pattern("cutlass.decode_matmul_bias_gelu"),
+        _decode_matmul_pattern("cutlass.decode_matmul_cast_bias_gelu"),
     ]
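
For context, these registered patterns are picked up by the standard CUTLASS
offload flow. A hedged sketch of that flow (partition_for_cutlass and RunCodegen
are the existing public entry points, also exercised by the test below; the option
values are copied from that test, not a recommendation):

    from tvm import relax
    from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

    def offload_to_cutlass(mod):
        # Fuse subgraphs matching the registered patterns (including the new
        # decode_matmul_cast* variants) into composite functions.
        mod = partition_for_cutlass(mod)
        # Lower the composite functions to compiled CUTLASS kernels.
        return relax.transform.RunCodegen(
            {"cutlass": {"sm": 80, "find_first_valid": False}}
        )(mod)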
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 02f15ad3d7..30286a597a 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -1412,6 +1412,38 @@ def test_fp16A_int4B_gemm():
                 R.output(lv2_1)
             return lv2_1
 
+        @R.function
+        def main_cast_bias(
+            x: R.Tensor((64, 64), dtype="float16"),
+            y: R.Tensor((128, 64), dtype="float16"),
+            bias: R.Tensor((1, 128), dtype="float16"),
+        ) -> R.Tensor((64, 128), dtype="float16"):
+            R.func_attr({"num_input": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.encode,
+                    (y,),
+                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((128,), dtype="float16")],
+                )
+                lv1 = lv[0]
+                lv2 = R.call_pure_packed(
+                    "cutlass.ft_preprocess_weight",
+                    lv1,
+                    80,
+                    True,
+                    sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
+                )
+                lv3: R.Tensor((128,), dtype="float16") = lv[1]
+                lv6 = R.call_tir(
+                    cls.decode, (lv2, lv3), out_sinfo=R.Tensor((64, 128), dtype="float16")
+                )
+                lv1_1: R.Tensor((64, 128), dtype="float32") = R.matmul(x, lv6, out_dtype="float32")
+                cast: R.Tensor((64, 128), dtype="float16") = R.astype(lv1_1, dtype="float16")
+                lv2_1: R.Tensor((64, 128), dtype="float16") = R.add(cast, bias)
+                R.output(lv2_1)
+            return lv2_1
+
         @R.function
         def main_residual(
             x: R.Tensor((64, 64), dtype="float16"),
@@ -1452,10 +1484,11 @@ def test_fp16A_int4B_gemm():
     func_names = [name.name_hint for (name, _) in mod.functions.items()]
     assert "fused_decode_relax_matmul_relax_add_cutlass" in func_names
     assert "fused_decode_relax_matmul_relax_add_relax_add_cutlass" in func_names
+    assert "fused_decode_relax_matmul_relax_astype_relax_add_cutlass" in func_names
 
     mod = relax.transform.RunCodegen(
         {"cutlass": {"sm": 80, "find_first_valid": False}},
-        entry_functions=["main_bias", "main_residual"],
+        entry_functions=["main_bias", "main_residual", "main_cast_bias"],
     )(mod)
 
     x = np.random.randn(*x_shape).astype("float16")
@@ -1483,13 +1516,15 @@ def test_fp16A_int4B_gemm():
     residual_nd = tvm.nd.array(residual, dev)
     params = (packed_weight.copyto(dev), scales.copyto(dev), bias_trans.copyto(dev))
 
-    for with_residual in [False, True]:
+    for f_name in ["main_bias", "main_cast_bias", "main_residual"]:
+        with_residual = "residual" in f_name
+
         if with_residual:
             inp = [x_nd, residual_nd, params]
         else:
             inp = [x_nd, params]
 
-        out = vm["main_residual" if with_residual else "main_bias"](*inp).numpy()
+        out = vm[f_name](*inp).numpy()
         ref = np.dot(x, y.transpose()) + bias
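
The quoted diff truncates here. A typical closing check for a test like this (an
assumption about the elided remainder, not text from the commit) compares the
fp16 GEMM output against the numpy reference with a loose tolerance:

    # Assumed continuation, not part of the quoted diff: fp16 accumulation and
    # rounding make an exact match impossible, so compare with relaxed tolerances.
    np.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)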