You are viewing a plain-text version of this content; the hyperlink to the canonical version is not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/03/07 10:46:21 UTC

[tvm] branch unity updated: [Unity][BYOC] Add dynamic shape support to CUTLASS matmul (#14216)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 1bece4bd28 [Unity][BYOC] Add dynamic shape support to CUTLASS matmul (#14216)
1bece4bd28 is described below

commit 1bece4bd282a1c9a63b6d6afb8f74d5dc199ef08
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Tue Mar 7 05:46:06 2023 -0500

    [Unity][BYOC] Add dynamic shape support to CUTLASS matmul (#14216)
    
    Add symbolic (dynamic) shape support for matmul and batch matmul in the CUTLASS BYOC backend.
---
 python/tvm/contrib/cutlass/build.py         |  35 +++++--
 python/tvm/contrib/cutlass/gen_tensor_op.py |  32 ++++--
 python/tvm/relax/backend/contrib/cutlass.py |  20 ++--
 tests/python/relax/test_codegen_cutlass.py  | 152 ++++++++++++++++++----------
 4 files changed, 156 insertions(+), 83 deletions(-)

diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 0e8d419bae..95d9363bc6 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -22,7 +22,7 @@ import multiprocessing
 import operator
 import os
 from functools import reduce
-from typing import Optional, Tuple
+from typing import Optional, Sequence
 
 import tvm
 from tvm import relax, relay, runtime
@@ -549,7 +549,10 @@ def _extract_relax_function_signature(f):
         signature["arg%d_dtype" % i] = sinfo.dtype
 
     ret_sinfo = f.ret_struct_info
-    signature["ret_shape"] = list(ret_sinfo.shape)
+    if ret_sinfo.shape is not None:
+        signature["ret_shape"] = list(ret_sinfo.shape)
+    else:
+        signature["ret_shape"] = None
     signature["ret_dtype"] = ret_sinfo.dtype
 
     return signature
@@ -574,7 +577,10 @@ def _extract_arg_idx(pattern_name, f):
     return arg_idx
 
 
-def is_valid_for_cutlass_matmul(lhs_shape: Tuple[int], rhs_shape: Tuple[int]) -> bool:
+def is_shape_valid_for_cutlass_matmul(
+    lhs_shape: Sequence[tvm.ir.PrimExpr],
+    rhs_shape: Sequence[tvm.ir.PrimExpr],
+) -> bool:
     """
     Check whether the shape of inputs can be handled by CUTLASS GEMM.
 
@@ -584,19 +590,26 @@ def is_valid_for_cutlass_matmul(lhs_shape: Tuple[int], rhs_shape: Tuple[int]) ->
     as well as ND x 2D and 2D x ND. For example, it cannot handle matmul with shape
     (2, 1, 4, 8) x (2, 3, 8, 16), because the batch stride of lhs is not constant.
     """
+    if not isinstance(lhs_shape[-1], (tvm.tir.expr.IntImm, int)):
+        # Reduction axis must be constant
+        return False
+
     lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
     rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
     if lhs_batches == 1 or rhs_batches == 1:
         # This could be regular matmul or batch matmul with shape ND x 2D or 2D x ND
         return True
 
+    analyzer = tvm.arith.Analyzer()
     # If one side has less dimensions, use 1 to fill the gap
-    batch_dim_pairs = itertools.zip_longest(
-        lhs_shape[-3::-1],  # Remove the last two dimensions and reverse
-        rhs_shape[-3::-1],
-        fillvalue=1,
+    batch_dim_pairs = list(
+        itertools.zip_longest(
+            list(lhs_shape)[-3::-1],  # Remove the last two dimensions and reverse
+            list(rhs_shape)[-3::-1],
+            fillvalue=1,
+        )
     )
-    return all(p[0] == p[1] for p in batch_dim_pairs)
+    return all(analyzer.can_prove_equal(p[0], p[1]) for p in batch_dim_pairs)
 
 
 @relax.expr_functor.mutator
@@ -689,7 +702,7 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
         rhs_dtype = signature[f"{rhs_arg}_dtype"]
         out_dtype = signature["ret_dtype"]
 
-        if not is_valid_for_cutlass_matmul(lhs_shape, rhs_shape):
+        if not is_shape_valid_for_cutlass_matmul(lhs_shape, rhs_shape):
             raise ValueError(f"Cannot handle the input shapes, lhs: {lhs_shape}, rhs: {rhs_shape}")
 
         MM = lhs_shape[-2]
@@ -712,7 +725,9 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
         else:
             is_batched = True
             batch_attrs = {
-                "batch": max(lhs_batches, rhs_batches),
+                # If both lhs_batches and rhs_batches are greater than 1,
+                # they must be equal. This is checked by is_shape_valid_for_cutlass_matmul.
+                "batch": lhs_batches if rhs_batches == 1 else rhs_batches,
                 "batch_stride_A": 0 if lhs_batches == 1 else MM * KK,
                 "batch_stride_B": 0 if rhs_batches == 1 else KK * NN,
                 "batch_stride_C": MM * NN,
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 78e2b489c6..62c06cabd1 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -561,20 +561,34 @@ def instantiate_template(func_name, annotations, func_args):
 
         if batched:
             headers.append("cutlass/gemm/device/gemm_batched.h")
-            attrs["batch"] = get_dim(annotations["batch"], lhs_arg, 0)
+
+            def get_batch_on_arg(arg_name, arg_shape):
+                return " * ".join(
+                    "{}->shape[{}]".format(arg_name, i) for i in range(len(arg_shape) - 2)
+                )
+
+            if isinstance(annotations["batch"], IntImm):
+                attrs["batch"] = str(int(annotations["batch"]))
+            elif annotations["batch_stride_A"] == 0:
+                # 2D x ND
+                attrs["batch"] = get_batch_on_arg(rhs_arg, rhs_shape)
+            else:
+                # ND x 2D or ND x ND
+                attrs["batch"] = get_batch_on_arg(lhs_arg, lhs_shape)
+
             attrs["batch_stride_A"] = get_batch_stride(
                 annotations["batch_stride_A"],
                 lhs_arg_idx,
                 lhs_arg_idx,
-                1,
-                2,
+                lhs_batched_offset,
+                lhs_batched_offset + 1,
             )
             attrs["batch_stride_B"] = get_batch_stride(
                 annotations["batch_stride_B"],
                 rhs_arg_idx,
                 rhs_arg_idx,
-                1,
-                2,
+                rhs_batched_offset,
+                rhs_batched_offset + 1,
             )
 
             if transposed:
@@ -582,16 +596,16 @@ def instantiate_template(func_name, annotations, func_args):
                     annotations["batch_stride_C"],
                     lhs_arg_idx,
                     rhs_arg_idx,
-                    1,
-                    1,
+                    lhs_batched_offset,
+                    rhs_batched_offset,
                 )
             else:
                 attrs["batch_stride_C"] = get_batch_stride(
                     annotations["batch_stride_C"],
                     lhs_arg_idx,
                     rhs_arg_idx,
-                    1,
-                    2,
+                    lhs_batched_offset,
+                    rhs_batched_offset + 1,
                 )
         else:
             headers.append("cutlass/gemm/device/gemm.h")
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 19165fa832..17af2e0597 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -20,9 +20,9 @@
 from typing import Mapping, Optional, Tuple
 
 import tvm
-from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul
+from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
 from tvm.relax import Call, Expr, ShapeExpr, transform
-from tvm.relax.dpl import DFPattern
+from tvm.relax.dpl import CallPattern, DFPattern
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
 from ..patterns import (
@@ -56,9 +56,10 @@ def _check_matmul(
     _: Expr,
 ) -> bool:
     matmul_call: Call = None
-    for _, expr in match_result.items():
+    for pattern, expr in match_result.items():
         if (
             isinstance(expr, Call)
+            and isinstance(pattern, CallPattern)
             and isinstance(expr.op, tvm.ir.Op)
             and expr.op.name == "relax.matmul"
         ):
@@ -66,17 +67,16 @@ def _check_matmul(
     if matmul_call is None:
         raise ValueError("Cannot find call to matmul from match_result.")
 
-    lhs_shape = _get_static_shape(matmul_call.args[0].struct_info.shape)
-    rhs_shape = _get_static_shape(matmul_call.args[1].struct_info.shape)
-    if len(lhs_shape) < 2 or len(rhs_shape) < 2:
-        return False
+    lhs, rhs, *_ = matmul_call.args
 
-    lhs_dtype = matmul_call.args[0].struct_info.dtype
-    rhs_dtype = matmul_call.args[1].struct_info.dtype
+    lhs_dtype = lhs.struct_info.dtype
+    rhs_dtype = rhs.struct_info.dtype
     if not _is_supported_dtype(lhs_dtype, rhs_dtype):
         return False
 
-    return is_valid_for_cutlass_matmul(lhs_shape, rhs_shape)
+    lhs_shape = lhs.struct_info.shape.values
+    rhs_shape = rhs.struct_info.shape.values
+    return is_shape_valid_for_cutlass_matmul(lhs_shape, rhs_shape)
 
 
 register_patterns(
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 36a1c4cd16..f2d2da15b5 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -16,17 +16,17 @@
 # under the License.
 import numpy as np
 import pytest
-import scipy
 
 import tvm
 import tvm.testing
 import tvm.topi.testing
+from tvm import relax
+from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
 from tvm.contrib.pickle_memoize import memoize
-from tvm import relax, relay
-from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul
 from tvm.relax.backend import get_patterns_with_prefix
 from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
 from tvm.script import relax as R
+from tvm.script import tir as T
 
 
 @pytest.fixture(autouse=True)
@@ -134,14 +134,13 @@ def test_kernel_sharing():
     np.testing.assert_equal(out, ref)
 
 
-def get_relax_matmul_module(x, y, transposed_y=False, with_bias=False, activation=None):
-    m, k = x.shape[-2:]
+def get_relax_matmul_module(
+    x_shape, y_shape, dtype, transposed_y=False, with_bias=False, activation=None
+):
     if transposed_y:
-        n = y.shape[-2]
+        n = y_shape[-2]
     else:
-        n = y.shape[-1]
-    dtype = str(x.dtype)
-    y_shape = y.shape
+        n = y_shape[-1]
 
     from tvm.script.ir_builder import IRBuilder
     from tvm.script.ir_builder import relax as relax_builder
@@ -149,8 +148,8 @@ def get_relax_matmul_module(x, y, transposed_y=False, with_bias=False, activatio
     with IRBuilder() as builder:
         with relax_builder.function():
             R.func_name("main")
-            x = R.arg("x", R.Tensor(x.shape, dtype))
-            y = R.arg("y", R.Tensor(y.shape, dtype))
+            x = R.arg("x", R.Tensor(x_shape, dtype))
+            y = R.arg("y", R.Tensor(y_shape, dtype))
             if with_bias:
                 bias = R.arg("bias", R.Tensor((n,), dtype))
 
@@ -171,46 +170,63 @@ def get_relax_matmul_module(x, y, transposed_y=False, with_bias=False, activatio
     return tvm.IRModule({"main": func})
 
 
+def _to_concrete_shape(symbolic_shape, var_table):
+    result = []
+    for dim in symbolic_shape:
+        if not isinstance(dim, tvm.tir.expr.Var):
+            result.append(dim)
+            continue
+
+        if dim not in var_table:
+            var_table[dim] = np.random.randint(10, 50)
+        result.append(var_table[dim])
+
+    return tuple(result)
+
+
+_vars = {
+    "a": tvm.tir.expr.Var("a", "int64"),
+    "b": tvm.tir.expr.Var("b", "int64"),
+}
+
+
+_epilogue_table = {
+    "none": (False, None),
+    "bias": (True, None),
+    "relu": (True, R.nn.relu),
+    "gelu": (True, R.nn.gelu),
+}
+
+
 @pytest.mark.parametrize(
-    "x_shape, y_shape, transpose_y",
+    "x_shape, y_shape, transpose_y, epilogue",
     [
         # Regular
-        ((32, 6), (6, 16), False),
+        ((32, 6), (6, 16), False, "none"),
+        ((_vars["a"], 6), (6, 16), False, "bias"),
         # Transposed
-        ((4, 16), (16, 128), True),
-        ((35, 8), (8, 8), True),
+        ((4, 16), (16, 128), True, "relu"),
+        ((35, 8), (8, 8), True, "gelu"),
         # 3D x 3D
-        ((6, 32, 8), (6, 8, 10), False),
-        ((6, 32, 8), (6, 8, 10), True),
+        ((6, 32, 8), (6, 8, 10), False, "bias"),
+        ((6, 32, 8), (6, 8, 10), True, "none"),
+        ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"),
         # 3D x 2D
-        ((6, 32, 8), (8, 10), False),
-        ((10, 16, 8), (8, 10), True),
+        ((6, 32, 8), (8, 10), False, "none"),
+        ((_vars["a"], 32, 8), (8, 10), False, "bias"),
+        ((10, 16, 8), (8, 10), True, "relu"),
         # 2D x 3D
-        ((32, 8), (10, 8, 10), False),
-        ((32, 8), (10, 8, 10), True),
+        ((32, 8), (10, 8, 10), False, "relu"),
+        ((32, 8), (_vars["a"], 8, 10), True, "gelu"),
         # ND x 2D
-        ((3, 6, 32, 8), (8, 10), False),
+        ((3, 6, 32, 8), (8, 10), False, "bias"),
+        ((_vars["a"], _vars["b"], 6, 32, 8), (8, 10), False, "none"),
         # 2D x ND
-        ((32, 8), (5, 3, 8, 10), False),
+        ((32, 8), (5, 3, 8, 10), False, "gelu"),
         # ND x ND
-        ((5, 3, 32, 8), (5, 3, 8, 10), True),
-        ((3, 2, 4, 16, 15), (1, 1, 15, 2), True),
-        ((1, 1, 16, 15), (3, 2, 4, 15, 2), False),
-    ],
-)
-@pytest.mark.parametrize(
-    "with_bias, activation",
-    [
-        (True, None),
-        (False, None),
-        (True, R.nn.relu),
-        (True, R.nn.gelu),
-    ],
-    ids=[
-        "no_bias",
-        "biased",
-        "biased_relu",
-        "biased_gelu",
+        ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"),
+        ((3, 2, 4, 16, 15), (1, 1, 15, 2), True, "gelu"),
+        ((1, 1, 16, 15), (3, 2, _vars["a"], 15, 2), False, "none"),
     ],
 )
 @pytest.mark.parametrize(
@@ -223,25 +239,34 @@ def test_matmul_offload(
     x_shape,
     y_shape,
     transpose_y,
-    with_bias,
-    activation,
+    epilogue,
     dtype,
 ):
-    x = np.random.randn(*x_shape).astype(dtype)
-    y = np.random.randn(*y_shape).astype(dtype)
+    with_bias, activation = _epilogue_table[epilogue]
+    var_table = {}
+    concrete_x_shape = _to_concrete_shape(x_shape, var_table)
+    concrete_y_shape = _to_concrete_shape(y_shape, var_table)
+    x = np.random.randn(*concrete_x_shape).astype(dtype)
+    y = np.random.randn(*concrete_y_shape).astype(dtype)
 
     if transpose_y:
         y = np.swapaxes(y, -2, -1)
+        y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
 
     if with_bias:
-        bias = np.random.randn(y_shape[-1]).astype(dtype)
+        bias = np.random.randn(concrete_y_shape[-1]).astype(dtype)
         args = (x, y, bias)
     else:
         bias = None
         args = (x, y)
 
     mod = get_relax_matmul_module(
-        x, y, with_bias=with_bias, transposed_y=transpose_y, activation=activation
+        x_shape,
+        y_shape,
+        dtype,
+        with_bias=with_bias,
+        transposed_y=transpose_y,
+        activation=activation,
     )
     out = get_result_with_relax_cutlass_offload(mod, *args)
     ref = build_and_run(mod, args, "llvm", legalize=True)
@@ -256,15 +281,24 @@ def test_matmul_offload(
         ((3, 4), (4, 5), True),
         # Batch matmul without stretching
         ((3, 16, 15), (3, 15, 2), True),
+        ((_vars["a"], 16, 15), (_vars["a"], 15, 2), True),
         # Broadcast 2D to 3D
         ((3, 16, 15), (15, 2), True),
+        ((_vars["a"], 16, 15), (15, 2), True),
         ((16, 15), (3, 15, 2), True),
         # Broadcast one-length dimension
         ((1, 16, 15), (3, 15, 2), True),
         ((3, 16, 15), (1, 15, 2), True),
         ((1, 1, 16, 15), (3, 2, 4, 15, 2), True),
+        ((1, 1, 16, 15), (3, _vars["a"], 4, 15, 2), True),
         # ND x ND
         ((3, 2, 4, 16, 15), (3, 2, 4, 15, 2), True),
+        ((_vars["a"], 2, 4, 16, 15), (_vars["a"], 2, 4, 15, 2), True),
+        (
+            (_vars["a"], _vars["b"], 4, 16, 15),
+            (_vars["a"], _vars["b"], 4, 15, 2),
+            True,
+        ),
         # ND x ND with one-length dimension
         ((1, 2, 4, 16, 15), (1, 2, 4, 15, 2), True),
         ((3, 2, 1, 16, 15), (3, 2, 1, 15, 2), True),
@@ -275,10 +309,16 @@ def test_matmul_offload(
         ((3, 2, 4, 16, 15), (2, 4, 15, 2), False),
         # Different shape
         ((3, 4, 16, 15), (3, 2, 15, 2), False),
+        ((3, _vars["a"], 16, 15), (3, _vars["b"], 15, 2), False),
+        # Cannot prove that broadcast dimensions are equal
+        ((_vars["a"], 16, 15), (3, 15, 2), False),
+        ((3, _vars["a"], 1, 16, 15), (1, 1, 3, 2, 1, 15, 2), False),
+        # Reduction axis must be constant
+        ((3, _vars["a"]), (_vars["a"], 5), False),
     ],
 )
-def test_is_valid_for_cutlass_matmul(x_shape, y_shape, expected):
-    assert is_valid_for_cutlass_matmul(x_shape, y_shape) == expected
+def test_is_shape_valid_for_cutlass_matmul(x_shape, y_shape, expected):
+    assert is_shape_valid_for_cutlass_matmul(x_shape, y_shape) == expected
 
 
 @pytest.mark.parametrize(
@@ -286,20 +326,24 @@ def test_is_valid_for_cutlass_matmul(x_shape, y_shape, expected):
     [
         # Not broadcasting all dims. Cannot be computed by stride-based batch gemm
         ((3, 1, 1, 16, 15), (3, 2, 4, 15, 2), False, "float16"),
+        ((3, 2, _vars["a"], 16, 15), (3, 2, 4, 15, 2), False, "float16"),
         ((1, 2, 1, 16, 15), (2, 1, 4, 15, 2), False, "float16"),
         ((3, 2, 4, 16, 15), (2, 4, 15, 2), True, "float16"),
         ((3, 16, 15), (2, 1, 3, 15, 2), True, "float16"),
+        ((3, 16, 15), (_vars["a"], 1, 3, 15, 2), True, "float16"),
+        ((_vars["a"], 1, 3, 16, 15), (_vars["b"], 1, 3, 15, 2), True, "float16"),
+        ((_vars["a"], _vars["b"], 3, 16, 15), (_vars["a"], 1, 3, 15, 2), True, "float16"),
     ],
 )
 def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, dtype):
-    x = np.random.randn(*x_shape).astype(dtype)
-    y = np.random.randn(*y_shape).astype(dtype)
     if transpose_y:
-        y = np.swapaxes(y, -2, -1)
+        y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
 
-    mod = get_relax_matmul_module(x, y, with_bias=False, transposed_y=transpose_y)
+    mod = get_relax_matmul_module(
+        x_shape, y_shape, dtype, with_bias=False, transposed_y=transpose_y
+    )
 
-    tvm.ir.assert_structural_equal(mod, partition_for_cutlass(mod))
+    assert len(mod.functions) == 1
 
 
 @pytest.fixture(params=["float16", "float32"])