You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/03/17 23:07:27 UTC
[tvm] branch unity updated: [Unity][BYOC] Improve expressiveness of the pattern check function in FuseOpsByPattern (#14310)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 5513095b5a [Unity][BYOC] Improve expressiveness of the pattern check function in FuseOpsByPattern (#14310)
5513095b5a is described below
commit 5513095b5a28481ea4aea4979177b16bdc78dac6
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Fri Mar 17 19:07:19 2023 -0400
[Unity][BYOC] Improve expressiveness of the pattern check function in FuseOpsByPattern (#14310)
* Change the input of FuseOpsByPattern and add check for result dependency in cutlass conv2d residual block
* Rename FuseOpsPattern to FusionPattern and PatternCheckFunctionInput to PatternCheckContext
---
include/tvm/relax/transform.h | 103 ++++++++++++++++--
python/tvm/contrib/cutlass/build.py | 8 +-
python/tvm/relax/backend/contrib/cutlass.py | 80 +++++++-------
python/tvm/relax/backend/pattern_registry.py | 75 ++-----------
python/tvm/relax/backend/patterns.py | 49 +++++----
python/tvm/relax/transform/transform.py | 118 ++++++++++++++++-----
src/relax/backend/pattern_registry.cc | 39 ++-----
src/relax/backend/pattern_registry.h | 59 +----------
src/relax/transform/fuse_ops.cc | 87 ++++++++++++---
tests/python/relax/test_codegen_cutlass.py | 39 ++++++-
.../relax/test_transform_fuse_ops_by_pattern.py | 30 ++++--
11 files changed, 414 insertions(+), 273 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 3ff863dd09..e0fe226e83 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -215,17 +215,108 @@ TVM_DLL Pass AnnotateTIROpPattern();
*/
TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
+/*!
+ * \brief The pattern object used as the input of FuseOpsByPattern. For bindings to be
+ * fused, they need to be matched by `pattern` and the `check` function, if provided, needs
+ * to return true.
+ */
+class FusionPatternNode : public Object {
+ public:
+ /*!
+ * \brief The name of the pattern. It becomes the value of the kComposite attribute
+ * of a fused function after successful matching.
+ */
+ String name;
+
+ /*!
+ * \brief The dataflow pattern that will be used to match expressions in the DataflowBlock.
+ * All the call nodes covered by the pattern will be extracted into the fused function.
+ */
+ DFPattern pattern;
+
+ /*!
+ * \brief The map which is used to extract important expressions from the pattern match
+ * result. All DFPattern in this map should be part of the `pattern`.
+ */
+ Map<String, DFPattern> annotation_patterns;
+
+ /*!
+ * \brief The function to determine whether the match result is accepted. This can be
+ * NullOpt if a check function is not necessary for this pattern.
+ *
+ * It should have signature
+ * bool(const PatternCheckContext& context)
+ */
+ Optional<PackedFunc> check;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("name", &name);
+ v->Visit("pattern", &pattern);
+ v->Visit("annotation_patterns", &annotation_patterns);
+ v->Visit("check", &check);
+ }
+
+ static constexpr const char* _type_key = "relax.transform.FusionPattern";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FusionPatternNode, Object);
+};
+
+class FusionPattern : public ObjectRef {
+ public:
+ FusionPattern(String name, DFPattern pattern, Map<String, DFPattern> annotation_patterns,
+ Optional<PackedFunc> check);
+
+ FusionPattern(String name, DFPattern pattern) : FusionPattern(name, pattern, {}, NullOpt) {}
+
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FusionPattern, ObjectRef, FusionPatternNode);
+};
+
+/*!
+ * \brief The input of FusionPattern::check.
+ */
+class PatternCheckContextNode : public Object {
+ public:
+ /*!
+ * \brief A map which contains all expressions matched by the sub patterns in
+ * FusionPattern::annotation_patterns.
+ */
+ Map<String, Expr> annotated_expr;
+
+ /*!
+ * \brief A map from each variable definition to the variables that use it.
+ */
+ Map<Var, Array<Var>> var_usages;
+
+ /*!
+ * \brief Map from value to its bound variable.
+ */
+ Map<Expr, Var> value_to_bound_var;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("annotated_expr", &annotated_expr);
+ v->Visit("var_usages", &var_usages);
+ v->Visit("value_to_bound_var", &value_to_bound_var);
+ }
+
+ static constexpr const char* _type_key = "relax.transform.PatternCheckContext";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PatternCheckContextNode, Object);
+};
+
+class PatternCheckContext : public ObjectRef {
+ public:
+ PatternCheckContext(Map<String, Expr> annotated_expr, Map<Var, Array<Var>> var_usages,
+ Map<Expr, Var> value_to_bound_var);
+
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternCheckContext, ObjectRef,
+ PatternCheckContextNode);
+};
+
/*!
* \brief Apply pattern matching to each function in the given module, and group matched
* expressions into a new function. The end result is similar to FuseOps, but fusion is driven
* completely by the provided patterns.
*
- * \param pattern_names The name of each pattern. It becomes the value of the kComposite attribute
- * of a fused function after successful matching.
* \param patterns The patterns to detect. The order of the patterns determines the order
* of priority in which they are matched. Higher-priority patterns should come earlier in the list.
- * \param checks The callback functions with type (Map<DFPattern, Expr>, Expr) -> bool. It takes a
- * match result and returns a boolean value to indicate whether the match result is accepted.
* \param bind_constants Whether or not to keep bound constants of the grouped function.
* \param annotate_codegen If true, wrap each created composite function with another function,
* whose body consists only of a call to the composite function, and annotate the outer function
@@ -235,9 +326,7 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
* an external backend without using the MergeCompositeFunctions pass.
* \return The Pass.
*/
-TVM_DLL Pass FuseOpsByPattern(const tvm::Array<runtime::String>& pattern_names,
- const tvm::Array<DFPattern>& patterns,
- const tvm::Array<PackedFunc>& checks, bool bind_constants = true,
+TVM_DLL Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants = true,
bool annotate_codegen = false);
/*!
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 7e92e6a887..47bdcaa790 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -568,11 +568,11 @@ def _extract_arg_idx(pattern_name, f):
func_args = list(f.params)
arg_idx = {}
- for arg_name, arg_pattern in pattern_entry.arg_patterns.items():
- arg_expr = matched_expr[arg_pattern]
+ for name, annotation_pattern in pattern_entry.annotation_patterns.items():
+ arg_expr = matched_expr[annotation_pattern]
if arg_expr not in func_args:
- raise ValueError(f"Cannot find arg {arg_name} in the fused function parameters")
- arg_idx[arg_name] = func_args.index(arg_expr)
+ continue
+ arg_idx[name] = func_args.index(arg_expr)
return arg_idx
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index e1b9226d68..4d539928cf 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -17,12 +17,12 @@
"""Pattern table for CUTLASS backend"""
-from typing import Mapping, Optional, Tuple
+from typing import Mapping, Optional, Sequence, Tuple
import tvm
from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
-from tvm.relax import Call, Expr, ShapeExpr, transform
-from tvm.relax.dpl import CallPattern, DFPattern
+from tvm.relax import ShapeExpr, Var, transform
+from tvm.relax.transform import PatternCheckContext
from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import (
@@ -52,33 +52,27 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype):
)
-def _find_call(op_name: str, match_result: Mapping[DFPattern, Expr]) -> Optional[Expr]:
- result = None
+def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var, Sequence[Var]]):
+ if from_var == to_var:
+ return True
- for pattern, expr in match_result.items():
- if (
- isinstance(expr, Call)
- and isinstance(pattern, CallPattern)
- and isinstance(expr.op, tvm.ir.Op)
- and expr.op.name == op_name
- ):
- if result is not None:
- raise ValueError(f"Found multiple matched call node for {op_name}")
- result = expr
+ checked = set()
+ vars_to_check = [to_var]
+ while vars_to_check:
+ current_var = vars_to_check.pop()
+ for user in var_usages.get(current_var, []):
+ if user == from_var:
+ return True
+ if user not in checked:
+ checked.add(user)
+ vars_to_check.append(user)
- return result
+ return False
-def _check_conv2d(
- match_result: Mapping[DFPattern, Expr],
- _: Expr,
-):
+def _check_conv2d(context: PatternCheckContext) -> bool:
"""Check if the given conv2d workload can be offloaded to CUTLASS."""
-
- conv2d_call = _find_call("relax.nn.conv2d", match_result)
- if conv2d_call is None:
- return False
-
+ conv2d_call = context.annotated_expr["root"]
data_layout = conv2d_call.attrs.data_layout
kernel_layout = conv2d_call.attrs.kernel_layout
data, weight, *_ = conv2d_call.args
@@ -89,6 +83,15 @@ def _check_conv2d(
):
return False
+ if "residual" in context.annotated_expr:
+ residual = context.annotated_expr["residual"]
+ if not isinstance(residual, Var):
+ residual = context.value_to_bound_var[residual]
+ conv2d_var = context.value_to_bound_var[conv2d_call]
+ if _has_dependency(from_var=residual, to_var=conv2d_var, var_usages=context.var_usages):
+ # If residual depends on the result of conv2d, this cannot be handled by cutlass.
+ return False
+
# pylint: disable=invalid-name
IC = data.struct_info.shape.values[3]
OC = weight.struct_info.shape.values[0]
@@ -96,17 +99,10 @@ def _check_conv2d(
return not IC == OC == conv2d_call.attrs.groups
-def _check_matmul(
- match_result: Mapping[DFPattern, Expr],
- _: Expr,
-) -> bool:
+def _check_matmul(context: PatternCheckContext) -> bool:
"""Check if the given matmul workload can be offloaded to CUTLASS."""
-
- matmul_call: Call = _find_call("relax.matmul", match_result)
- if matmul_call is None:
- return False
-
- lhs, rhs, *_ = matmul_call.args
+ lhs = context.annotated_expr["lhs"]
+ rhs = context.annotated_expr["rhs"]
lhs_dtype = lhs.struct_info.dtype
rhs_dtype = rhs.struct_info.dtype
@@ -244,7 +240,7 @@ register_patterns(
)
-def partition_for_cutlass(mod):
+def partition_for_cutlass(mod, annotate_codegen=True):
"""
Partition the input module into CUTLASS-supported subgraphs.
@@ -253,6 +249,11 @@ def partition_for_cutlass(mod):
mod: tvm.IRModule
The IRModule to be partitioned.
+ annotate_codegen: bool
+ Whether to wrap each created composite function with another function, whose
+ body consists only of a call to the composite function. See the doc of FuseOpsByPattern
+ for more detail.
+
Returns
-------
mod: tvm.IRModule
@@ -260,6 +261,7 @@ def partition_for_cutlass(mod):
compiled by the CUTLASS backend.
"""
- cutlass_pattern_entries = get_patterns_with_prefix("cutlass")
- patterns = [(e.name, e.pattern, e.check) for e in cutlass_pattern_entries]
- return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod)
+ patterns = get_patterns_with_prefix("cutlass")
+ return transform.FuseOpsByPattern(
+ patterns, bind_constants=False, annotate_codegen=annotate_codegen
+ )(mod)
diff --git a/python/tvm/relax/backend/pattern_registry.py b/python/tvm/relax/backend/pattern_registry.py
index 5a35eba03d..5ec57164eb 100644
--- a/python/tvm/relax/backend/pattern_registry.py
+++ b/python/tvm/relax/backend/pattern_registry.py
@@ -20,55 +20,12 @@
import atexit
from typing import Callable, List, Mapping, Optional, Set, Tuple, Union
-import tvm
from tvm.relax.dpl import DFPattern
-from tvm.runtime import Object
+from tvm.relax.transform import FusionPattern
from ..expr import Expr
from . import _ffi_api
-
-@tvm._ffi.register_object("relax.backend.PatternRegistryEntry")
-class PatternRegistryEntry(Object):
- """
- An entry in the pattern registry. This represents a single pattern that
- can be used to identify expressions that can be handled by external
- backends, like CUTLASS and TensorRT.
-
- Parameters
- ----------
- name: str
- The name of pattern. Usually it starts with the name of backend, like 'cutlass.matmul'.
-
- pattern: DFPattern
- The dataflow pattern that will be used to match expressions that can be handled
- by external backends.
-
- arg_patterns: Mapping[str, DFPattern]
- The mapping from arg name to its pattern. It can be used to extract arg expression
- from match result. All DFPattern in this map should be part of the `pattern`.
-
- check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
- The function to check whether the match result is accepted.
- """
-
- name: str
- pattern: DFPattern
- arg_patterns: Mapping[str, DFPattern]
- check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
-
- def __init__(
- self,
- name: str,
- pattern: DFPattern,
- arg_patterns: Mapping[str, DFPattern],
- check: Callable[[Mapping[DFPattern, Expr], Expr], bool],
- ):
- self.__init_handle_by_constructor__(
- _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns, check # type: ignore
- )
-
-
_REGISTERED_PATTERN_NAMES: Set[str] = set()
@@ -96,7 +53,7 @@ def _ensure_cleanup_function_registered():
CheckFunc = Callable[[Mapping[DFPattern, Expr], Expr], bool]
Pattern = Union[
- PatternRegistryEntry,
+ FusionPattern,
Tuple[str, DFPattern],
Tuple[str, DFPattern, Mapping[str, DFPattern]],
Tuple[str, DFPattern, Mapping[str, DFPattern], CheckFunc],
@@ -118,29 +75,17 @@ def register_patterns(patterns: List[Pattern]):
entries = []
for item in patterns:
- if isinstance(item, PatternRegistryEntry):
+ if isinstance(item, FusionPattern):
entries.append(item)
elif isinstance(item, tuple):
- name, pattern, *rest = item
-
- if len(rest) > 0:
- arg_patterns = rest[0]
- else:
- arg_patterns = {}
-
- if len(rest) > 1:
- check = rest[1]
- else:
- check = lambda *_: True
-
- entries.append(PatternRegistryEntry(name, pattern, arg_patterns, check))
- _REGISTERED_PATTERN_NAMES.add(name)
+ entries.append(FusionPattern(*item))
+ _REGISTERED_PATTERN_NAMES.add(item[0])
else:
- raise TypeError(f"Cannot register type {type(pattern)} as pattern")
+ raise TypeError(f"Cannot register type {type(item)} as pattern")
_ffi_api.RegisterPatterns(entries)
-def get_patterns_with_prefix(prefix: str) -> List[PatternRegistryEntry]:
+def get_patterns_with_prefix(prefix: str) -> List[FusionPattern]:
"""
Get a list of patterns whose names startwith `prefix`.
@@ -151,13 +96,13 @@ def get_patterns_with_prefix(prefix: str) -> List[PatternRegistryEntry]:
Returns
-------
- patterns: PatternRegistryEntry
+ patterns: FusionPattern
Matched patterns, ordered by priority from high to low.
"""
return _ffi_api.GetPatternsWithPrefix(prefix)
-def get_pattern(name: str) -> Optional[PatternRegistryEntry]:
+def get_pattern(name: str) -> Optional[FusionPattern]:
"""
Find the pattern with a particular name.
@@ -168,7 +113,7 @@ def get_pattern(name: str) -> Optional[PatternRegistryEntry]:
Returns
-------
- pattern: Optional[PatternRegistryEntry]
+ pattern: Optional[FusionPattern]
The matched pattern. Returns None if such pattern is not found.
"""
return _ffi_api.GetPattern(name)
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index d770cc6faf..e27b91b3ea 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -24,18 +24,18 @@ from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard
def _with_bias_activation_pattern(
out: DFPattern,
- args: Dict[str, DFPattern],
+ annotations: Dict[str, DFPattern],
with_bias: bool = False,
activation: str = None,
) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
if with_bias:
- args["bias"] = bias = wildcard()
+ annotations["bias"] = bias = wildcard()
out = is_op("relax.add")(out, bias)
if activation:
out = is_op(activation)(out)
- return out, args
+ return out, annotations
def make_fused_bias_activation_pattern(
@@ -62,16 +62,17 @@ def make_fused_bias_activation_pattern(
pattern: DFPattern
The resulting pattern describing a fused operation
- args: Mapping[str, DFPattern]
- The mapping from arg name to its pattern. It can be used to extract
- arg expression from match result.
+ annotations: Mapping[str, DFPattern]
+ A mapping from name to sub pattern. It can be used to extract
+ important expressions from match result, to power the partition
+ check function and codegen.
"""
lhs = wildcard()
rhs = wildcard()
- args = {"lhs": lhs, "rhs": rhs}
out = is_op(op_name)(lhs, rhs)
+ annotations = {"lhs": lhs, "rhs": rhs, "root": out}
- return _with_bias_activation_pattern(out, args, with_bias, activation)
+ return _with_bias_activation_pattern(out, annotations, with_bias, activation)
def make_residual_block_pattern(
@@ -99,9 +100,10 @@ def make_residual_block_pattern(
pattern: DFPattern
The resulting pattern describing a matrix multiplication.
- args: Mapping[str, DFPattern]
- The mapping from arg name to its pattern. It can be used to extract
- arg expression from match result.
+ annotations: Mapping[str, DFPattern]
+ A mapping from name to sub pattern. It can be used to extract
+ important expressions from match result, to power the partition
+ check function and codegen.
"""
if isinstance(node_output, tuple):
@@ -143,21 +145,23 @@ def make_matmul_pattern(
pattern: DFPattern
The resulting pattern describing a matrix multiplication.
- args: Mapping[str, DFPattern]
- The mapping from arg name to its pattern. It can be used to extract
- arg expression from match result.
+ annotations: Mapping[str, DFPattern]
+ A mapping from name to sub pattern. It can be used to extract
+ important expressions from match result, to power the partition
+ check function and codegen.
"""
lhs = wildcard()
rhs = wildcard()
- args = {"lhs": lhs, "rhs": rhs}
+ annotations = {"lhs": lhs, "rhs": rhs}
if transposed_rhs:
rhs = is_op("relax.permute_dims")(rhs)
out = is_op("relax.matmul")(lhs, rhs)
+ annotations["root"] = out
- return _with_bias_activation_pattern(out, args, with_bias, activation)
+ return _with_bias_activation_pattern(out, annotations, with_bias, activation)
def make_attention_pattern(with_bias: bool = False):
@@ -169,19 +173,20 @@ def make_attention_pattern(with_bias: bool = False):
pattern: DFPattern
The resulting pattern describing a fused multi head attention.
- args: Mapping[str, DFPattern]
- The mapping from arg name to its pattern. It can be used to extract
- arg expression from match result.
+ annotations: Mapping[str, DFPattern]
+ A mapping from name to sub pattern. It can be used to extract
+ important expressions from match result, to power the partition
+ check function and codegen.
"""
query = wildcard()
key = wildcard()
value = wildcard()
- args = {"query": query, "key": key, "value": value}
+ annotations = {"query": query, "key": key, "value": value}
if with_bias:
bias = wildcard()
- args["bias"] = bias
+ annotations["bias"] = bias
out = is_op("relax.nn.attention_bias")(query, key, value, bias)
else:
out = is_op("relax.nn.attention")(query, key, value)
- return out, args
+ return out, annotations
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index c59104ca58..0df29dc093 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -19,11 +19,16 @@
import functools
import inspect
import types
-from typing import Callable, Dict, Union, Optional, List, Tuple
-from tvm.tir import PrimFunc, IndexMap
+from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+
import numpy as np # type: ignore
+
import tvm.ir
-from tvm.runtime import NDArray
+from tvm.relax import Expr, Var
+from tvm.relax.dpl import DFPattern
+from tvm.runtime import NDArray, Object
+from tvm.tir import IndexMap, PrimFunc
+
from . import _ffi_api
from .legalize_ops.common import LegalizeFunc
@@ -283,8 +288,75 @@ def FuseTIR() -> tvm.ir.transform.Pass:
return _ffi_api.FuseTIR() # type: ignore
+@tvm._ffi.register_object("relax.transform.PatternCheckContext")
+class PatternCheckContext(Object):
+ """
+ The input of check function `FusionPattern.check`.
+
+ Parameters
+ ----------
+ annotated_expr: Mapping[str, Expr]
+ A map which contains all expressions matched by the sub patterns in
+ FusionPattern.annotation_patterns.
+
+ var_usages: Mapping[Var, Sequence[Var]]
+ A map from each variable definition to the variables that use it.
+
+ value_to_bound_var: Mapping[Expr, Var]
+ Map from value to its bound variable.
+ """
+
+ annotated_expr: Mapping[str, Expr]
+ var_usages: Mapping[Var, Sequence[Var]]
+ value_to_bound_var: Mapping[Expr, Var]
+
+
+@tvm._ffi.register_object("relax.transform.FusionPattern")
+class FusionPattern(Object):
+ """
+ The pattern used by `FuseOpsByPattern`. It is mainly a DFPattern, plus additional
+ information that helps during the fusion pass.
+
+ Parameters
+ ----------
+ name: str
+ The name of the pattern. Usually it starts with the name of backend, like 'cutlass.matmul'.
+
+ pattern: DFPattern
+ The dataflow pattern that will be used to match expressions that can be handled
+ by external backends.
+
+ annotation_patterns: Mapping[str, DFPattern]
+ The map which is used to extract important expressions from the pattern match
+ result. All DFPattern in this map should be part of the `pattern`.
+
+ check: Callable[[PatternCheckContext], bool]
+ The function to check whether the match result is accepted. None means no check.
+ """
+
+ name: str
+ pattern: DFPattern
+ annotation_patterns: Mapping[str, DFPattern]
+ check: Callable[[PatternCheckContext], bool]
+
+ def __init__(
+ self,
+ name: str,
+ pattern: DFPattern,
+ annotation_patterns: Optional[Mapping[str, DFPattern]] = None,
+ check: Optional[Callable[[Mapping[str, Expr]], bool]] = None,
+ ):
+ if annotation_patterns is None:
+ annotation_patterns = {}
+ self.__init_handle_by_constructor__(
+ _ffi_api.FusionPattern, name, pattern, annotation_patterns, check # type: ignore
+ )
+
+
def FuseOpsByPattern(
- patterns: List[Tuple], bind_constants: bool = True, annotate_codegen: bool = False
+ patterns: List[Union[FusionPattern, Tuple]],
+ bind_constants: bool = True,
+ annotate_codegen: bool = False,
) -> tvm.ir.transform.Pass:
"""Apply pattern matching to each function in the given module, and group matched expressions
into a new function.
@@ -293,15 +365,12 @@ def FuseOpsByPattern(
Parameters
----------
- patterns : List[Union[Tuple[str, DFPattern], Tuple[str, DFPattern, Callable]]]
- A list of tuple of (name, pattern) or (name, pattern, predicate) to be matched.
- The predicate is a function with type (Map<DFPattern, Expr>, Expr) -> bool. It takes a
- match result and returns a boolean value to indicate whether the match result is accepted.
+ patterns : List[Union[FusionPattern, Tuple]]
+ A list of patterns to be matched. The order of the patterns determines the order of priority
+ in which they are matched. Higher-priority patterns should come earlier in the list.
- The patterns to detect. The order of the patterns determines the order of priority in which
- they are matched. Higher-priority patterns should come earlier in the list.
- The string is the name of the corresponding pattern. It becomes the value of the kComposite
- attribute of a fused function after a successful matching.
+ In addition to FusionPattern, a tuple can be passed as an item of this list. The pattern
+ will be constructed through FusionPattern(*item).
bind_constants : bool
Whether or not to keep bound constants in the grouped function.
@@ -321,22 +390,19 @@ def FuseOpsByPattern(
The registered pass for pattern-based fusion.
"""
- pattern_names = []
- df_patterns = []
- checks = []
- for tup in patterns:
- if len(tup) == 2:
- pattern_names.append(tup[0])
- df_patterns.append(tup[1])
- checks.append(lambda *_: True)
- elif len(tup) == 3:
- pattern_names.append(tup[0])
- df_patterns.append(tup[1])
- checks.append(tup[2])
+ converted_patterns = []
+ for pattern in patterns:
+ if isinstance(pattern, tuple):
+ converted_patterns.append(FusionPattern(*pattern))
+ elif isinstance(pattern, FusionPattern):
+ converted_patterns.append(pattern)
else:
- raise ValueError("Invalid pattern: {}".format(tup))
+ raise ValueError(f"Invalid pattern: {pattern}")
+
return _ffi_api.FuseOpsByPattern(
- pattern_names, df_patterns, checks, bind_constants, annotate_codegen
+ converted_patterns,
+ bind_constants,
+ annotate_codegen,
) # type: ignore
diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc
index 553018d690..34ebb4d6dd 100644
--- a/src/relax/backend/pattern_registry.cc
+++ b/src/relax/backend/pattern_registry.cc
@@ -24,25 +24,12 @@
namespace tvm {
namespace relax {
namespace backend {
-
-PatternRegistryEntry::PatternRegistryEntry(String name, DFPattern pattern,
- Map<String, DFPattern> arg_patterns, PackedFunc check) {
- ObjectPtr<PatternRegistryEntryNode> n = make_object<PatternRegistryEntryNode>();
- n->name = std::move(name);
- n->pattern = std::move(pattern);
- n->arg_patterns = std::move(arg_patterns);
- n->check = check;
- data_ = std::move(n);
-}
-
-TVM_REGISTER_NODE_TYPE(PatternRegistryEntryNode);
-
-static std::vector<PatternRegistryEntry>* GetRegistryTable() {
- static std::vector<PatternRegistryEntry> table;
+static std::vector<FusionPattern>* GetRegistryTable() {
+ static std::vector<FusionPattern> table;
return &table;
}
-void RegisterPatterns(Array<PatternRegistryEntry> entries) {
+void RegisterPatterns(Array<FusionPattern> entries) {
auto* table = GetRegistryTable();
for (const auto& entry : entries) {
table->push_back(entry);
@@ -53,16 +40,15 @@ void RemovePatterns(Array<String> names) {
std::unordered_set<String> name_set{names.begin(), names.end()};
auto* table = GetRegistryTable();
- table->erase(std::remove_if(table->begin(), table->end(),
- [&](const PatternRegistryEntry& entry) {
- return name_set.count(entry->name) > 0;
- }),
- table->end());
+ table->erase(
+ std::remove_if(table->begin(), table->end(),
+ [&](const FusionPattern& entry) { return name_set.count(entry->name) > 0; }),
+ table->end());
}
-Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix) {
+Array<FusionPattern> GetPatternsWithPrefix(const String& prefix) {
auto* table = GetRegistryTable();
- Array<PatternRegistryEntry> result;
+ Array<FusionPattern> result;
for (auto it = table->rbegin(); it != table->rend(); ++it) {
if (support::StartsWith((*it)->name, prefix.data())) {
result.push_back(*it);
@@ -71,7 +57,7 @@ Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix) {
return result;
}
-Optional<PatternRegistryEntry> GetPattern(const String& pattern_name) {
+Optional<FusionPattern> GetPattern(const String& pattern_name) {
auto* table = GetRegistryTable();
for (auto it = table->rbegin(); it != table->rend(); ++it) {
if ((*it)->name == pattern_name) {
@@ -81,11 +67,6 @@ Optional<PatternRegistryEntry> GetPattern(const String& pattern_name) {
return NullOpt;
}
-TVM_REGISTER_GLOBAL("relax.backend.PatternRegistryEntry")
- .set_body_typed([](String name, DFPattern pattern, Map<String, DFPattern> arg_patterns,
- PackedFunc check) {
- return PatternRegistryEntry(name, pattern, arg_patterns, check);
- });
TVM_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns);
TVM_REGISTER_GLOBAL("relax.backend.RemovePatterns").set_body_typed(RemovePatterns);
TVM_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix").set_body_typed(GetPatternsWithPrefix);
diff --git a/src/relax/backend/pattern_registry.h b/src/relax/backend/pattern_registry.h
index e765f56b4e..72eea1238d 100644
--- a/src/relax/backend/pattern_registry.h
+++ b/src/relax/backend/pattern_registry.h
@@ -28,6 +28,7 @@
#include <tvm/relax/dataflow_pattern.h>
#include <tvm/relax/expr.h>
+#include <tvm/relax/transform.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/object.h>
@@ -35,57 +36,7 @@ namespace tvm {
namespace relax {
namespace backend {
-/*!
- * \brief An entry in the pattern registry. This represents a single pattern that
- * can be used to identify expressions that can be handled by external
- * backends, like CUTLASS and TensorRT.
- */
-class PatternRegistryEntryNode : public Object {
- public:
- /*!
- * \brief The name of pattern. Usually it starts with the name of backend, like
- * 'cutlass.matmul'.
- */
- String name;
- /*!
- * \brief The dataflow pattern that will be used to match expressions that can
- * be handled by external backends.
- */
- DFPattern pattern;
- /*!
- * \brief The mapping from arg name to its pattern. It can be used to extract
- * arg expression from match result. All DFPattern in this map should be part of
- * the `pattern`.
- */
- Map<String, DFPattern> arg_patterns;
-
- /*!
- * \brief The function to check whether the match result is accepted.
- *
- * It should have signature
- * bool(const Map<DFPattern, Expr>& match_result, const Expr& matched_expr)
- */
- PackedFunc check;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("name", &name);
- v->Visit("pattern", &pattern);
- v->Visit("arg_patterns", &arg_patterns);
- v->Visit("check", &check);
- }
-
- static constexpr const char* _type_key = "relax.backend.PatternRegistryEntry";
- TVM_DECLARE_FINAL_OBJECT_INFO(PatternRegistryEntryNode, Object);
-};
-
-class PatternRegistryEntry : public ObjectRef {
- public:
- PatternRegistryEntry(String name, DFPattern pattern, Map<String, DFPattern> arg_patterns,
- PackedFunc check);
-
- TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternRegistryEntry, ObjectRef,
- PatternRegistryEntryNode);
-};
+using transform::FusionPattern;
/*!
* \brief Register patterns which will be used to partition the DataflowBlock
@@ -93,7 +44,7 @@ class PatternRegistryEntry : public ObjectRef {
* \param patterns Patterns to be registered. Patterns that appear later in the list have
* higher priority when partitioning DataflowBlock.
*/
-void RegisterPatterns(Array<PatternRegistryEntry> entries);
+void RegisterPatterns(Array<FusionPattern> patterns);
/*!
* \brief Remove patterns from the registry by their name.
@@ -106,14 +57,14 @@ void RemovePatterns(Array<String> names);
* \param prefx The pattern name prefix.
* \return Matched patterns, ordered by priority from high to low.
*/
-Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix);
+Array<FusionPattern> GetPatternsWithPrefix(const String& prefix);
/*!
* \brief Find the pattern with a particular name.
* \param name The pattern name.
* \return The matched pattern. NullOpt if not found.
*/
-Optional<PatternRegistryEntry> GetPattern(const String& name);
+Optional<FusionPattern> GetPattern(const String& name);
} // namespace backend
} // namespace relax
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 6d7c278d80..76f53eebc5 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -40,6 +40,7 @@
#include "../../relay/analysis/graph_partitioner.h"
#include "../../support/arena.h"
+#include "tvm/relax/expr.h"
namespace tvm {
namespace relax {
@@ -905,18 +906,30 @@ class PatternBasedPartitioner : ExprVisitor {
using Group = GraphPartitioner::Group;
using GroupMap = OperatorFusor::GroupMap;
using ExprVisitor::VisitExpr_;
- using FCheckMatch = runtime::TypedPackedFunc<bool(const Map<DFPattern, Expr>&, const Expr&)>;
+ using FCheckMatch = runtime::TypedPackedFunc<bool(const transform::PatternCheckContext&)>;
- static GroupMap Run(String pattern_name, DFPattern pattern, FCheckMatch check, Expr expr,
+ static GroupMap Run(String pattern_name, DFPattern pattern,
+ Map<String, DFPattern> annotation_patterns, FCheckMatch check, Expr expr,
support::Arena* arena) {
- PatternBasedPartitioner part(pattern_name, pattern, check, arena);
+ PatternBasedPartitioner part(pattern_name, pattern, annotation_patterns, check, arena);
part.VisitExpr(expr);
return part.group_map_;
}
- PatternBasedPartitioner(String pattern_name, DFPattern pattern, FCheckMatch check,
+ PatternBasedPartitioner(String pattern_name, DFPattern pattern,
+ Map<String, DFPattern> annotation_patterns, FCheckMatch check,
support::Arena* arena)
- : pat_name_(pattern_name), pat_(pattern), check_(check), arena_(arena) {}
+ : pat_name_(pattern_name),
+ pat_(pattern),
+ annotation_pat_(annotation_patterns),
+ check_(check),
+ arena_(arena) {}
+
+ void VisitBindingBlock_(const DataflowBlockNode* block) final {
+ current_block_use_def_ = DataflowBlockUseDef(GetRef<DataflowBlock>(block));
+ ExprVisitor::VisitBindingBlock_(block);
+ current_block_use_def_ = {};
+ }
void VisitVarDef(const Var& var) final { group_map_[var.get()] = arena_->make<Group>(); }
@@ -931,7 +944,9 @@ class PatternBasedPartitioner : ExprVisitor {
void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final {
VisitVarDef(binding->var);
if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef<Call>(call), bindings_)) {
- if (!check_(matches_opt.value(), GetRef<Call>(call))) {
+ if (check_ != nullptr &&
+ !check_(transform::PatternCheckContext(GetAnnotatedExpr(matches_opt.value()),
+ current_block_use_def_, value_to_bound_var_))) {
return;
}
// If a match is found, put all matching expressions into the same group.
@@ -975,12 +990,24 @@ class PatternBasedPartitioner : ExprVisitor {
return group_map_[bound_var.get()]->FindRoot();
}
+ Map<String, Expr> GetAnnotatedExpr(const Map<DFPattern, Expr> matched_result) {
+ Map<String, Expr> annotated_expr;
+ for (const auto& it : annotation_pat_) {
+ if (matched_result.count(it.second)) {
+ annotated_expr.Set(it.first, matched_result[it.second]);
+ }
+ }
+ return annotated_expr;
+ }
+
String pat_name_;
DFPattern pat_;
+ Map<String, DFPattern> annotation_pat_;
FCheckMatch check_;
support::Arena* arena_;
Map<Var, Expr> bindings_;
Map<Expr, Var> value_to_bound_var_;
+ Map<Var, Array<Var>> current_block_use_def_;
GroupMap group_map_;
};
@@ -1054,19 +1081,18 @@ class CompositeFunctionAnnotator : public ExprMutator {
std::unordered_map<const GlobalVarNode*, GlobalVar> gvar_map_;
};
-IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names,
- const tvm::Array<DFPattern>& patterns,
- const tvm::Array<runtime::PackedFunc>& checks, IRModule mod,
+IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns, IRModule mod,
bool bind_constants, bool annotate_codegen) {
support::Arena arena;
- for (size_t i = 0; i < pattern_names.size(); ++i) {
+ for (const auto& pattern : patterns) {
OperatorFusor::GroupMap group_map;
for (const auto& entry : mod->functions) {
if (entry.second->IsInstance<tir::PrimFuncNode>()) {
continue;
}
- auto map = PatternBasedPartitioner::Run(pattern_names[i], patterns[i], checks[i],
- entry.second, &arena);
+ auto map = PatternBasedPartitioner::Run(
+ pattern->name, pattern->pattern, pattern->annotation_patterns,
+ pattern->check.value_or(nullptr), entry.second, &arena);
group_map.insert(map.begin(), map.end());
}
mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants);
@@ -1079,6 +1105,36 @@ IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names,
namespace transform {
+FusionPattern::FusionPattern(String name, DFPattern pattern,
+ Map<String, DFPattern> annotation_patterns,
+ Optional<PackedFunc> check) {
+ ObjectPtr<FusionPatternNode> n = make_object<FusionPatternNode>();
+ n->name = std::move(name);
+ n->pattern = std::move(pattern);
+ n->annotation_patterns = std::move(annotation_patterns);
+ n->check = check;
+ data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(FusionPatternNode);
+TVM_REGISTER_GLOBAL("relax.transform.FusionPattern")
+ .set_body_typed([](String name, DFPattern pattern, Map<String, DFPattern> annotation_patterns,
+ Optional<PackedFunc> check) {
+ return FusionPattern(name, pattern, annotation_patterns, check);
+ });
+
+PatternCheckContext::PatternCheckContext(Map<String, Expr> annotated_expr,
+ Map<Var, Array<Var>> var_usages,
+ Map<Expr, Var> value_to_bound_var) {
+ ObjectPtr<PatternCheckContextNode> n = make_object<PatternCheckContextNode>();
+ n->annotated_expr = std::move(annotated_expr);
+ n->var_usages = std::move(var_usages);
+ n->value_to_bound_var = std::move(value_to_bound_var);
+ data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(PatternCheckContextNode);
+
Pass FuseOps(int fuse_opt_level) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = //
[=](IRModule m, PassContext pc) {
@@ -1094,14 +1150,11 @@ Pass FuseOps(int fuse_opt_level) {
TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);
-Pass FuseOpsByPattern(const tvm::Array<String>& pattern_names,
- const tvm::Array<DFPattern>& patterns,
- const tvm::Array<runtime::PackedFunc>& checks, bool bind_constants,
+Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants,
bool annotate_codegen) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = //
[=](IRModule m, PassContext pc) {
- return relax::FuseOpsByPattern(pattern_names, patterns, checks, m, bind_constants,
- annotate_codegen);
+ return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen);
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index de15f7083a..0bae6801ca 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -26,7 +26,6 @@ from tvm.contrib.pickle_memoize import memoize
from tvm.relax.backend import get_patterns_with_prefix
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
from tvm.script import relax as R
-from tvm.script import tir as T
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder
@@ -296,6 +295,43 @@ def test_conv2d_offload(data_shape, weight_shape, dtype, epilogue, residual_bloc
tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
+def test_cutlass_partition_conv2d_residual_blocked():
+ @tvm.script.ir_module
+ class Conv2dReLU:
+ """
+ This conv2d should not be fused as a conv2d residual block, because both lhs and rhs of
+ the last R.add depend on the result of conv2d.
+ """
+
+ @R.function
+ def main(
+ data: R.Tensor((32, 3, 3, 16), "float32"),
+ weight: R.Tensor((16, 3, 3, 16), "float32"),
+ bias: R.Tensor((1, 1, 1, 16), "float32"),
+ ):
+ with R.dataflow():
+ conv1 = R.nn.conv2d(
+ data,
+ weight,
+ padding=(1, 1),
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ )
+ out = R.nn.relu(conv1 + bias)
+ # residual depends on conv result, which cannot be handled in cutlass
+ result = out + out
+ R.output(result)
+
+ return result
+
+ mod = partition_for_cutlass(Conv2dReLU, annotate_codegen=False)
+ for f_var in mod.functions:
+ func = mod[f_var]
+ if func.attrs and "Composite" in func.attrs:
+ # verify that the function is not fused as residual block
+ assert func.attrs["Composite"] == "cutlass.conv2d_bias_relu"
+
+
@pytest.mark.parametrize(
"x_shape, y_shape, transpose_y, epilogue, residual_block",
[
@@ -451,6 +487,7 @@ def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, dtype):
mod = get_relax_matmul_module(
x_shape, y_shape, dtype, with_bias=False, transposed_y=transpose_y
)
+ mod = partition_for_cutlass(mod)
assert len(mod.functions) == 1
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 3816e11bc5..2f3e2d479f 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -14,14 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import pytest
import numpy as np
+import pytest
import tvm
-
from tvm import relax
-from tvm.script import relax as R, tir as T, ir as I
-from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern, is_op, wildcard
+from tvm.relax.dpl.pattern import is_op, make_fused_bias_activation_pattern, wildcard
+from tvm.relax.transform import PatternCheckContext
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
@tvm.script.ir_module
@@ -600,13 +602,23 @@ def test_unused():
def test_check_pattern():
- pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None)
-
- def pred(match, expr):
+ lhs = wildcard()
+ rhs = wildcard()
+ out = is_op("relax.nn.conv2d")(lhs, rhs)
+ annotation_patterns = {"root": out, "lhs": lhs, "rhs": rhs}
+
+ def pred(context: PatternCheckContext):
+ lhs = context.annotated_expr["lhs"]
+ rhs = context.annotated_expr["rhs"]
+ expr = context.annotated_expr["root"]
+ assert isinstance(lhs, relax.expr.Var) and lhs.name_hint == "data"
+ assert isinstance(rhs, relax.expr.Var) and rhs.name_hint == "weight1"
assert isinstance(expr, relax.expr.Call) and expr.op.name == "relax.nn.conv2d"
- return expr.struct_info.dtype == "float32"
+ return False
- check(Conv2dx2, [("cutlass.conv2d", pat, pred)], Conv2dx2) # expect no partitioning
+ check(
+ Conv2dReLU, [("cutlass.conv2d", out, annotation_patterns, pred)], Conv2dReLU
+ ) # expect no partitioning
def test_bind_constants():