Posted to commits@tvm.apache.org by ma...@apache.org on 2023/03/17 23:07:27 UTC

[tvm] branch unity updated: [Unity][BYOC] Improve expressiveness of the pattern check function in FuseOpsByPattern (#14310)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 5513095b5a [Unity][BYOC] Improve expressiveness of the pattern check function in FuseOpsByPattern (#14310)
5513095b5a is described below

commit 5513095b5a28481ea4aea4979177b16bdc78dac6
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Fri Mar 17 19:07:19 2023 -0400

    [Unity][BYOC] Improve expressiveness of the pattern check function in FuseOpsByPattern (#14310)
    
    * Change the input of FuseOpsByPattern and add check for result dependency in cutlass conv2d residual block
    
    * Rename FuseOpsPattern to FusionPattern and PatternCheckFunctionInput to PatternCheckContext
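    
    Example of the new API (a minimal sketch; the pattern name, annotation
    names, and dtype check are illustrative, and `mod` is assumed to be an
    existing Relax IRModule):
    
        from tvm.relax.dpl.pattern import is_op, wildcard
        from tvm.relax.transform import FuseOpsByPattern, FusionPattern, PatternCheckContext
    
        lhs, rhs = wildcard(), wildcard()
        out = is_op("relax.matmul")(lhs, rhs)
    
        def check(context: PatternCheckContext) -> bool:
            # annotated_expr is keyed by the names given in annotation_patterns
            return context.annotated_expr["lhs"].struct_info.dtype == "float16"
    
        pattern = FusionPattern(
            "example.matmul", out, {"lhs": lhs, "rhs": rhs, "root": out}, check
        )
        mod = FuseOpsByPattern([pattern], bind_constants=False)(mod)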
---
 include/tvm/relax/transform.h                      | 103 ++++++++++++++++--
 python/tvm/contrib/cutlass/build.py                |   8 +-
 python/tvm/relax/backend/contrib/cutlass.py        |  80 +++++++-------
 python/tvm/relax/backend/pattern_registry.py       |  75 ++-----------
 python/tvm/relax/backend/patterns.py               |  49 +++++----
 python/tvm/relax/transform/transform.py            | 118 ++++++++++++++++-----
 src/relax/backend/pattern_registry.cc              |  39 ++-----
 src/relax/backend/pattern_registry.h               |  59 +----------
 src/relax/transform/fuse_ops.cc                    |  87 ++++++++++++---
 tests/python/relax/test_codegen_cutlass.py         |  39 ++++++-
 .../relax/test_transform_fuse_ops_by_pattern.py    |  30 ++++--
 11 files changed, 414 insertions(+), 273 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 3ff863dd09..e0fe226e83 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -215,17 +215,108 @@ TVM_DLL Pass AnnotateTIROpPattern();
  */
 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
 
+/*!
+ * \brief The pattern object used as the input of FuseOpsByPattern. For a binding to be
+ * fused, it must be matched by `pattern` and the `check` function must return
+ * true.
+ */
+class FusionPatternNode : public Object {
+ public:
+  /*!
+   * \brief The name of the pattern. It becomes the value of the kComposite attribute
+   * of a fused function after a successful match.
+   */
+  String name;
+
+  /*!
+   * \brief The dataflow pattern that will be used to match expressions in the DataflowBlock.
+   * All the call nodes covered by the pattern will be extracted into the fused function.
+   */
+  DFPattern pattern;
+
+  /*!
+   * \brief The map used to extract important expressions from the pattern match
+   * result. All DFPatterns in this map should be part of `pattern`.
+   */
+  Map<String, DFPattern> annotation_patterns;
+
+  /*!
+   * \brief The function to determine whether the match result is accepted. This can be
+   * NullOpt if no check function is necessary for this pattern.
+   *
+   * It should have signature
+   * bool(const PatternCheckContext& context)
+   */
+  Optional<PackedFunc> check;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("pattern", &pattern);
+    v->Visit("annotation_patterns", &annotation_patterns);
+    v->Visit("check", &check);
+  }
+
+  static constexpr const char* _type_key = "relax.transform.FusionPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FusionPatternNode, Object);
+};
+
+class FusionPattern : public ObjectRef {
+ public:
+  FusionPattern(String name, DFPattern pattern, Map<String, DFPattern> annotation_patterns,
+                Optional<PackedFunc> check);
+
+  FusionPattern(String name, DFPattern pattern) : FusionPattern(name, pattern, {}, NullOpt) {}
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FusionPattern, ObjectRef, FusionPatternNode);
+};
+
+/*!
+ * \brief The input of FusionPattern::check.
+ */
+class PatternCheckContextNode : public Object {
+ public:
+  /*!
+   * \brief A map containing all expressions matched by the sub-patterns in
+   * FusionPattern::annotation_patterns.
+   */
+  Map<String, Expr> annotated_expr;
+
+  /*!
+   * \brief A map from each variable definition to the set of variables that use it.
+   */
+  Map<Var, Array<Var>> var_usages;
+
+  /*!
+   * \brief Map from value to its bound variable.
+   */
+  Map<Expr, Var> value_to_bound_var;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("annotated_expr", &annotated_expr);
+    v->Visit("var_usages", &var_usages);
+    v->Visit("value_to_bound_var", &value_to_bound_var);
+  }
+
+  static constexpr const char* _type_key = "relax.transform.PatternCheckContext";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PatternCheckContextNode, Object);
+};
+
+class PatternCheckContext : public ObjectRef {
+ public:
+  PatternCheckContext(Map<String, Expr> annotated_expr, Map<Var, Array<Var>> var_usages,
+                      Map<Expr, Var> value_to_bound_var);
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternCheckContext, ObjectRef,
+                                            PatternCheckContextNode);
+};
+
 /*!
  * \brief Apply pattern matching to each function in the given module, and group matched
  * expressions into a new function. The end result is similar to FuseOps, but fusion is driven
  * completely by the provided patterns.
  *
- * \param pattern_names The name of each pattern. It becomes the value of the kComposite attribute
- * of a fused function after successful matching.
  * \param patterns The patterns to detect. The order of the patterns determines the order
  * of priority in which they are matched. Higher-priority patterns should come earlier in the list.
- * \param checks The callback functions with type (Map<DFPattern, Expr>, Expr) -> bool. It takes a
- * match result and returns a boolean value to indicate whether the match result is accepted.
  * \param bind_constants Whether or not to keep bound constants of the grouped function.
  * \param annotate_codegen If true, wrap each created composite function with another function,
  * whose body consists only of a call to the composite function, and annotate the outer function
@@ -235,9 +326,7 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
  * an external backend without using the MergeCompositeFunctions pass.
  * \return The Pass.
  */
-TVM_DLL Pass FuseOpsByPattern(const tvm::Array<runtime::String>& pattern_names,
-                              const tvm::Array<DFPattern>& patterns,
-                              const tvm::Array<PackedFunc>& checks, bool bind_constants = true,
+TVM_DLL Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants = true,
                               bool annotate_codegen = false);
 
 /*!
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 7e92e6a887..47bdcaa790 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -568,11 +568,11 @@ def _extract_arg_idx(pattern_name, f):
     func_args = list(f.params)
 
     arg_idx = {}
-    for arg_name, arg_pattern in pattern_entry.arg_patterns.items():
-        arg_expr = matched_expr[arg_pattern]
+    for name, annotation_pattern in pattern_entry.annotation_patterns.items():
+        arg_expr = matched_expr[annotation_pattern]
         if arg_expr not in func_args:
-            raise ValueError(f"Cannot find arg {arg_name} in the fused function parameters")
-        arg_idx[arg_name] = func_args.index(arg_expr)
+            continue
+        arg_idx[name] = func_args.index(arg_expr)
 
     return arg_idx
 
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index e1b9226d68..4d539928cf 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -17,12 +17,12 @@
 
 """Pattern table for CUTLASS backend"""
 
-from typing import Mapping, Optional, Tuple
+from typing import Mapping, Optional, Sequence, Tuple
 
 import tvm
 from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
-from tvm.relax import Call, Expr, ShapeExpr, transform
-from tvm.relax.dpl import CallPattern, DFPattern
+from tvm.relax import ShapeExpr, Var, transform
+from tvm.relax.transform import PatternCheckContext
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
 from ..patterns import (
@@ -52,33 +52,27 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype):
     )
 
 
-def _find_call(op_name: str, match_result: Mapping[DFPattern, Expr]) -> Optional[Expr]:
-    result = None
+def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var, Sequence[Var]]):
+    if from_var == to_var:
+        return True
 
-    for pattern, expr in match_result.items():
-        if (
-            isinstance(expr, Call)
-            and isinstance(pattern, CallPattern)
-            and isinstance(expr.op, tvm.ir.Op)
-            and expr.op.name == op_name
-        ):
-            if result is not None:
-                raise ValueError(f"Found multiple matched call node for {op_name}")
-            result = expr
+    checked = set()
+    vars_to_check = [to_var]
+    while vars_to_check:
+        current_var = vars_to_check.pop()
+        for user in var_usages.get(current_var, []):
+            if user == from_var:
+                return True
+            if user not in checked:
+                checked.add(user)
+                vars_to_check.append(user)
 
-    return result
+    return False
 
 
-def _check_conv2d(
-    match_result: Mapping[DFPattern, Expr],
-    _: Expr,
-):
+def _check_conv2d(context: PatternCheckContext) -> bool:
     """Check if the given conv2d workload can be offloaded to CUTLASS."""
-
-    conv2d_call = _find_call("relax.nn.conv2d", match_result)
-    if conv2d_call is None:
-        return False
-
+    conv2d_call = context.annotated_expr["root"]
     data_layout = conv2d_call.attrs.data_layout
     kernel_layout = conv2d_call.attrs.kernel_layout
     data, weight, *_ = conv2d_call.args
@@ -89,6 +83,15 @@ def _check_conv2d(
     ):
         return False
 
+    if "residual" in context.annotated_expr:
+        residual = context.annotated_expr["residual"]
+        if not isinstance(residual, Var):
+            residual = context.value_to_bound_var[residual]
+        conv2d_var = context.value_to_bound_var[conv2d_call]
+        if _has_dependency(from_var=residual, to_var=conv2d_var, var_usages=context.var_usages):
+            # If residual depends on the result of conv2d, this cannot be handled by cutlass.
+            return False
+
     # pylint: disable=invalid-name
     IC = data.struct_info.shape.values[3]
     OC = weight.struct_info.shape.values[0]
@@ -96,17 +99,10 @@ def _check_conv2d(
     return not IC == OC == conv2d_call.attrs.groups
 
 
-def _check_matmul(
-    match_result: Mapping[DFPattern, Expr],
-    _: Expr,
-) -> bool:
+def _check_matmul(context: PatternCheckContext) -> bool:
     """Check if the given matmul workload can be offloaded to CUTLASS."""
-
-    matmul_call: Call = _find_call("relax.matmul", match_result)
-    if matmul_call is None:
-        return False
-
-    lhs, rhs, *_ = matmul_call.args
+    lhs = context.annotated_expr["lhs"]
+    rhs = context.annotated_expr["rhs"]
 
     lhs_dtype = lhs.struct_info.dtype
     rhs_dtype = rhs.struct_info.dtype
@@ -244,7 +240,7 @@ register_patterns(
 )
 
 
-def partition_for_cutlass(mod):
+def partition_for_cutlass(mod, annotate_codegen=True):
     """
     Partition the input module into CUTLASS-supported subgraphs.
 
@@ -253,6 +249,11 @@ def partition_for_cutlass(mod):
     mod: tvm.IRModule
         The IRModule to be partitioned.
 
+    annotate_codegen: bool
+        Whether to wrap each created composite function with another function, whose
+        body consists only of a call to the composite function. See the doc of FuseOpsByPattern
+        for more detail.
+
     Returns
     -------
     mod: tvm.IRModule
@@ -260,6 +261,7 @@ def partition_for_cutlass(mod):
         compiled by the CUTLASS backend.
     """
 
-    cutlass_pattern_entries = get_patterns_with_prefix("cutlass")
-    patterns = [(e.name, e.pattern, e.check) for e in cutlass_pattern_entries]
-    return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod)
+    patterns = get_patterns_with_prefix("cutlass")
+    return transform.FuseOpsByPattern(
+        patterns, bind_constants=False, annotate_codegen=annotate_codegen
+    )(mod)
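
A minimal sketch of the updated entry point (assuming `mod` is a Relax
IRModule containing CUTLASS-offloadable ops):

    from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

    # With annotate_codegen=False, composite functions are created without
    # the extra wrapper functions around them.
    mod = partition_for_cutlass(mod, annotate_codegen=False)
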
diff --git a/python/tvm/relax/backend/pattern_registry.py b/python/tvm/relax/backend/pattern_registry.py
index 5a35eba03d..5ec57164eb 100644
--- a/python/tvm/relax/backend/pattern_registry.py
+++ b/python/tvm/relax/backend/pattern_registry.py
@@ -20,55 +20,12 @@
 import atexit
 from typing import Callable, List, Mapping, Optional, Set, Tuple, Union
 
-import tvm
 from tvm.relax.dpl import DFPattern
-from tvm.runtime import Object
+from tvm.relax.transform import FusionPattern
 
 from ..expr import Expr
 from . import _ffi_api
 
-
-@tvm._ffi.register_object("relax.backend.PatternRegistryEntry")
-class PatternRegistryEntry(Object):
-    """
-    An entry in the pattern registry. This represents a single pattern that
-    can be used to identify expressions that can be handled by external
-    backends, like CUTLASS and TensorRT.
-
-    Parameters
-    ----------
-    name: str
-        The name of pattern. Usually it starts with the name of backend, like 'cutlass.matmul'.
-
-    pattern: DFPattern
-        The dataflow pattern that will be used to match expressions that can be handled
-        by external backends.
-
-    arg_patterns: Mapping[str, DFPattern]
-        The mapping from arg name to its pattern. It can be used to extract arg expression
-        from match result. All DFPattern in this map should be part of the `pattern`.
-
-    check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
-        The function to check whether the match result is accepted.
-    """
-
-    name: str
-    pattern: DFPattern
-    arg_patterns: Mapping[str, DFPattern]
-    check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
-
-    def __init__(
-        self,
-        name: str,
-        pattern: DFPattern,
-        arg_patterns: Mapping[str, DFPattern],
-        check: Callable[[Mapping[DFPattern, Expr], Expr], bool],
-    ):
-        self.__init_handle_by_constructor__(
-            _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns, check  # type: ignore
-        )
-
-
 _REGISTERED_PATTERN_NAMES: Set[str] = set()
 
 
@@ -96,7 +53,7 @@ def _ensure_cleanup_function_registered():
 
 CheckFunc = Callable[[Mapping[DFPattern, Expr], Expr], bool]
 Pattern = Union[
-    PatternRegistryEntry,
+    FusionPattern,
     Tuple[str, DFPattern],
     Tuple[str, DFPattern, Mapping[str, DFPattern]],
     Tuple[str, DFPattern, Mapping[str, DFPattern], CheckFunc],
@@ -118,29 +75,17 @@ def register_patterns(patterns: List[Pattern]):
 
     entries = []
     for item in patterns:
-        if isinstance(item, PatternRegistryEntry):
+        if isinstance(item, FusionPattern):
             entries.append(item)
         elif isinstance(item, tuple):
-            name, pattern, *rest = item
-
-            if len(rest) > 0:
-                arg_patterns = rest[0]
-            else:
-                arg_patterns = {}
-
-            if len(rest) > 1:
-                check = rest[1]
-            else:
-                check = lambda *_: True
-
-            entries.append(PatternRegistryEntry(name, pattern, arg_patterns, check))
-            _REGISTERED_PATTERN_NAMES.add(name)
+            entries.append(FusionPattern(*item))
+            _REGISTERED_PATTERN_NAMES.add(item[0])
         else:
-            raise TypeError(f"Cannot register type {type(pattern)} as pattern")
+            raise TypeError(f"Cannot register type {type(item)} as pattern")
     _ffi_api.RegisterPatterns(entries)
 
 
-def get_patterns_with_prefix(prefix: str) -> List[PatternRegistryEntry]:
+def get_patterns_with_prefix(prefix: str) -> List[FusionPattern]:
     """
     Get a list of patterns whose names start with `prefix`.
 
@@ -151,13 +96,13 @@ def get_patterns_with_prefix(prefix: str) -> List[PatternRegistryEntry]:
 
     Returns
     -------
-    patterns: PatternRegistryEntry
+    patterns: List[FusionPattern]
         Matched patterns, ordered by priority from high to low.
     """
     return _ffi_api.GetPatternsWithPrefix(prefix)
 
 
-def get_pattern(name: str) -> Optional[PatternRegistryEntry]:
+def get_pattern(name: str) -> Optional[FusionPattern]:
     """
     Find the pattern with a particular name.
 
@@ -168,7 +113,7 @@ def get_pattern(name: str) -> Optional[PatternRegistryEntry]:
 
     Returns
     -------
-    pattern: Optional[PatternRegistryEntry]
+    pattern: Optional[FusionPattern]
         The matched pattern. Returns None if no such pattern is found.
     """
     return _ffi_api.GetPattern(name)
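
A sketch of registering and querying patterns through the slimmed-down
registry (the "example.matmul" pattern is hypothetical):

    from tvm.relax.backend.pattern_registry import (
        get_pattern,
        get_patterns_with_prefix,
        register_patterns,
    )
    from tvm.relax.dpl.pattern import is_op, wildcard

    lhs, rhs = wildcard(), wildcard()
    out = is_op("relax.matmul")(lhs, rhs)

    # Tuples are expanded with FusionPattern(*item); the annotation patterns
    # and check function are optional.
    register_patterns([("example.matmul", out, {"lhs": lhs, "rhs": rhs, "root": out})])

    patterns = get_patterns_with_prefix("example")  # ordered by priority, high to low
    assert get_pattern("example.matmul") is not None
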
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index d770cc6faf..e27b91b3ea 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -24,18 +24,18 @@ from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard
 
 def _with_bias_activation_pattern(
     out: DFPattern,
-    args: Dict[str, DFPattern],
+    annotations: Dict[str, DFPattern],
     with_bias: bool = False,
     activation: str = None,
 ) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
     if with_bias:
-        args["bias"] = bias = wildcard()
+        annotations["bias"] = bias = wildcard()
         out = is_op("relax.add")(out, bias)
 
     if activation:
         out = is_op(activation)(out)
 
-    return out, args
+    return out, annotations
 
 
 def make_fused_bias_activation_pattern(
@@ -62,16 +62,17 @@ def make_fused_bias_activation_pattern(
     pattern: DFPattern
         The resulting pattern describing a fused operation
 
-    args: Mapping[str, DFPattern]
-        The mapping from arg name to its pattern. It can be used to extract
-        arg expression from match result.
+    annotations: Mapping[str, DFPattern]
+        A mapping from names to sub-patterns. It can be used to extract
+        important expressions from the match result to drive the partition
+        check function and codegen.
     """
     lhs = wildcard()
     rhs = wildcard()
-    args = {"lhs": lhs, "rhs": rhs}
     out = is_op(op_name)(lhs, rhs)
+    annotations = {"lhs": lhs, "rhs": rhs, "root": out}
 
-    return _with_bias_activation_pattern(out, args, with_bias, activation)
+    return _with_bias_activation_pattern(out, annotations, with_bias, activation)
 
 
 def make_residual_block_pattern(
@@ -99,9 +100,10 @@ def make_residual_block_pattern(
     pattern: DFPattern
         The resulting pattern describing a matrix multiplication.
 
-    args: Mapping[str, DFPattern]
-        The mapping from arg name to its pattern. It can be used to extract
-        arg expression from match result.
+    annotations: Mapping[str, DFPattern]
+        A mapping from names to sub-patterns. It can be used to extract
+        important expressions from the match result to drive the partition
+        check function and codegen.
     """
 
     if isinstance(node_output, tuple):
@@ -143,21 +145,23 @@ def make_matmul_pattern(
     pattern: DFPattern
         The resulting pattern describing a matrix multiplication.
 
-    args: Mapping[str, DFPattern]
-        The mapping from arg name to its pattern. It can be used to extract
-        arg expression from match result.
+    annotations: Mapping[str, DFPattern]
+        A mapping from names to sub-patterns. It can be used to extract
+        important expressions from the match result to drive the partition
+        check function and codegen.
     """
 
     lhs = wildcard()
     rhs = wildcard()
-    args = {"lhs": lhs, "rhs": rhs}
+    annotations = {"lhs": lhs, "rhs": rhs}
 
     if transposed_rhs:
         rhs = is_op("relax.permute_dims")(rhs)
 
     out = is_op("relax.matmul")(lhs, rhs)
+    annotations["root"] = out
 
-    return _with_bias_activation_pattern(out, args, with_bias, activation)
+    return _with_bias_activation_pattern(out, annotations, with_bias, activation)
 
 
 def make_attention_pattern(with_bias: bool = False):
@@ -169,19 +173,20 @@ def make_attention_pattern(with_bias: bool = False):
     pattern: DFPattern
         The resulting pattern describing a fused multi head attention.
 
-    args: Mapping[str, DFPattern]
-        The mapping from arg name to its pattern. It can be used to extract
-        arg expression from match result.
+    annotations: Mapping[str, DFPattern]
+        A mapping from names to sub-patterns. It can be used to extract
+        important expressions from the match result to drive the partition
+        check function and codegen.
     """
     query = wildcard()
     key = wildcard()
     value = wildcard()
-    args = {"query": query, "key": key, "value": value}
+    annotations = {"query": query, "key": key, "value": value}
     if with_bias:
         bias = wildcard()
-        args["bias"] = bias
+        annotations["bias"] = bias
         out = is_op("relax.nn.attention_bias")(query, key, value, bias)
     else:
         out = is_op("relax.nn.attention")(query, key, value)
 
-    return out, args
+    return out, annotations
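
The helpers above now return an annotations mapping alongside the pattern;
a sketch of consuming it (the op and activation choices are illustrative):

    from tvm.relax.backend.patterns import make_fused_bias_activation_pattern

    pattern, annotations = make_fused_bias_activation_pattern(
        "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu"
    )
    # annotations maps names ("lhs", "rhs", "bias", "root") to sub-patterns,
    # for use by the partition check function and codegen.
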
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index c59104ca58..0df29dc093 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -19,11 +19,16 @@
 import functools
 import inspect
 import types
-from typing import Callable, Dict, Union, Optional, List, Tuple
-from tvm.tir import PrimFunc, IndexMap
+from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+
 import numpy as np  # type: ignore
+
 import tvm.ir
-from tvm.runtime import NDArray
+from tvm.relax import Expr, Var
+from tvm.relax.dpl import DFPattern
+from tvm.runtime import NDArray, Object
+from tvm.tir import IndexMap, PrimFunc
+
 from . import _ffi_api
 from .legalize_ops.common import LegalizeFunc
 
@@ -283,8 +288,75 @@ def FuseTIR() -> tvm.ir.transform.Pass:
     return _ffi_api.FuseTIR()  # type: ignore
 
 
+@tvm._ffi.register_object("relax.transform.PatternCheckContext")
+class PatternCheckContext(Object):
+    """
+    The input of the check function `FusionPattern.check`.
+
+    Parameters
+    ----------
+    annotated_expr: Mapping[str, Expr]
+        A map containing all expressions matched by the sub-patterns in
+        FusionPattern.annotation_patterns.
+
+    var_usages: Mapping[Var, Sequence[Var]]
+        A map from each variable definition to the set of variables that use it.
+
+    value_to_bound_var: Mapping[Expr, Var]
+        Map from value to its bound variable.
+    """
+
+    annotated_expr: Mapping[str, Expr]
+    var_usages: Mapping[Var, Sequence[Var]]
+    value_to_bound_var: Mapping[Expr, Var]
+
+
+@tvm._ffi.register_object("relax.transform.FusionPattern")
+class FusionPattern(Object):
+    """
+    The pattern used by `FuseOpsByPattern`. It is primarily a DFPattern, augmented
+    with extra information that helps during the fusion pass.
+
+    Parameters
+    ----------
+    name: str
+        The name of the pattern. Usually it starts with the backend name, like 'cutlass.matmul'.
+
+    pattern: DFPattern
+        The dataflow pattern that will be used to match expressions that can be handled
+        by external backends.
+
+    annotation_patterns: Mapping[str, DFPattern]
+        The map used to extract important expressions from the pattern match
+        result. All DFPatterns in this map should be part of `pattern`.
+
+    check: Callable[[PatternCheckContext], bool]
+        The function to check whether the match result is accepted.
+    """
+
+    name: str
+    pattern: DFPattern
+    annotation_patterns: Mapping[str, DFPattern]
+    check: Callable[[PatternCheckContext], bool]
+
+    def __init__(
+        self,
+        name: str,
+        pattern: DFPattern,
+        annotation_patterns: Optional[Mapping[str, DFPattern]] = None,
+        check: Optional[Callable[[PatternCheckContext], bool]] = None,
+    ):
+        if annotation_patterns is None:
+            annotation_patterns = {}
+        self.__init_handle_by_constructor__(
+            _ffi_api.FusionPattern, name, pattern, annotation_patterns, check  # type: ignore
+        )
+
+
 def FuseOpsByPattern(
-    patterns: List[Tuple], bind_constants: bool = True, annotate_codegen: bool = False
+    patterns: List[Union[FusionPattern, Tuple]],
+    bind_constants: bool = True,
+    annotate_codegen: bool = False,
 ) -> tvm.ir.transform.Pass:
     """Apply pattern matching to each function in the given module, and group matched expressions
     into a new function.
@@ -293,15 +365,12 @@ def FuseOpsByPattern(
 
     Parameters
     ----------
-    patterns : List[Union[Tuple[str, DFPattern], Tuple[str, DFPattern, Callable]]]
-        A list of tuple of (name, pattern) or (name, pattern, predicate) to be matched.
-        The predicate is a function with type (Map<DFPattern, Expr>, Expr) -> bool. It takes a
-        match result and returns a boolean value to indicate whether the match result is accepted.
+    patterns : List[Union[FusionPattern, Tuple]]
+        A list of patterns to be matched. The order of the patterns determines the order of priority
+        in which they are matched. Higher-priority patterns should come earlier in the list.
 
-        The patterns to detect. The order of the patterns determines the order of priority in which
-        they are matched. Higher-priority patterns should come earlier in the list.
-        The string is the name of the corresponding pattern. It becomes the value of the kComposite
-        attribute of a fused function after a successful matching.
+        In addition to FusionPattern, a tuple can be passed as an item of this list;
+        the pattern will then be constructed via FusionPattern(*item).
 
     bind_constants : bool
         Whether or not to keep bound constants in the grouped function.
@@ -321,22 +390,19 @@ def FuseOpsByPattern(
         The registered pass for pattern-based fusion.
 
     """
-    pattern_names = []
-    df_patterns = []
-    checks = []
-    for tup in patterns:
-        if len(tup) == 2:
-            pattern_names.append(tup[0])
-            df_patterns.append(tup[1])
-            checks.append(lambda *_: True)
-        elif len(tup) == 3:
-            pattern_names.append(tup[0])
-            df_patterns.append(tup[1])
-            checks.append(tup[2])
+    converted_patterns = []
+    for pattern in patterns:
+        if isinstance(pattern, tuple):
+            converted_patterns.append(FusionPattern(*pattern))
+        elif isinstance(pattern, FusionPattern):
+            converted_patterns.append(pattern)
         else:
-            raise ValueError("Invalid pattern: {}".format(tup))
+            raise ValueError(f"Invalid pattern: {pattern}")
+
     return _ffi_api.FuseOpsByPattern(
-        pattern_names, df_patterns, checks, bind_constants, annotate_codegen
+        converted_patterns,
+        bind_constants,
+        annotate_codegen,
     )  # type: ignore
 
 
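FuseOpsByPattern also accepts the tuple shorthand directly; each tuple is
expanded with FusionPattern(*item). A sketch with illustrative patterns:

    from tvm.relax.dpl.pattern import is_op, wildcard
    from tvm.relax.transform import FuseOpsByPattern

    conv2d_pat = is_op("relax.nn.conv2d")(wildcard(), wildcard())
    matmul_pat = is_op("relax.matmul")(wildcard(), wildcard())

    # A bare (name, pattern) pair gets empty annotation patterns and no check
    # function; earlier entries have higher matching priority.
    fuse = FuseOpsByPattern(
        [("example.conv2d", conv2d_pat), ("example.matmul", matmul_pat)],
        annotate_codegen=True,
    )
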
diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc
index 553018d690..34ebb4d6dd 100644
--- a/src/relax/backend/pattern_registry.cc
+++ b/src/relax/backend/pattern_registry.cc
@@ -24,25 +24,12 @@
 namespace tvm {
 namespace relax {
 namespace backend {
-
-PatternRegistryEntry::PatternRegistryEntry(String name, DFPattern pattern,
-                                           Map<String, DFPattern> arg_patterns, PackedFunc check) {
-  ObjectPtr<PatternRegistryEntryNode> n = make_object<PatternRegistryEntryNode>();
-  n->name = std::move(name);
-  n->pattern = std::move(pattern);
-  n->arg_patterns = std::move(arg_patterns);
-  n->check = check;
-  data_ = std::move(n);
-}
-
-TVM_REGISTER_NODE_TYPE(PatternRegistryEntryNode);
-
-static std::vector<PatternRegistryEntry>* GetRegistryTable() {
-  static std::vector<PatternRegistryEntry> table;
+static std::vector<FusionPattern>* GetRegistryTable() {
+  static std::vector<FusionPattern> table;
   return &table;
 }
 
-void RegisterPatterns(Array<PatternRegistryEntry> entries) {
+void RegisterPatterns(Array<FusionPattern> entries) {
   auto* table = GetRegistryTable();
   for (const auto& entry : entries) {
     table->push_back(entry);
@@ -53,16 +40,15 @@ void RemovePatterns(Array<String> names) {
   std::unordered_set<String> name_set{names.begin(), names.end()};
 
   auto* table = GetRegistryTable();
-  table->erase(std::remove_if(table->begin(), table->end(),
-                              [&](const PatternRegistryEntry& entry) {
-                                return name_set.count(entry->name) > 0;
-                              }),
-               table->end());
+  table->erase(
+      std::remove_if(table->begin(), table->end(),
+                     [&](const FusionPattern& entry) { return name_set.count(entry->name) > 0; }),
+      table->end());
 }
 
-Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix) {
+Array<FusionPattern> GetPatternsWithPrefix(const String& prefix) {
   auto* table = GetRegistryTable();
-  Array<PatternRegistryEntry> result;
+  Array<FusionPattern> result;
   for (auto it = table->rbegin(); it != table->rend(); ++it) {
     if (support::StartsWith((*it)->name, prefix.data())) {
       result.push_back(*it);
@@ -71,7 +57,7 @@ Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix) {
   return result;
 }
 
-Optional<PatternRegistryEntry> GetPattern(const String& pattern_name) {
+Optional<FusionPattern> GetPattern(const String& pattern_name) {
   auto* table = GetRegistryTable();
   for (auto it = table->rbegin(); it != table->rend(); ++it) {
     if ((*it)->name == pattern_name) {
@@ -81,11 +67,6 @@ Optional<PatternRegistryEntry> GetPattern(const String& pattern_name) {
   return NullOpt;
 }
 
-TVM_REGISTER_GLOBAL("relax.backend.PatternRegistryEntry")
-    .set_body_typed([](String name, DFPattern pattern, Map<String, DFPattern> arg_patterns,
-                       PackedFunc check) {
-      return PatternRegistryEntry(name, pattern, arg_patterns, check);
-    });
 TVM_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns);
 TVM_REGISTER_GLOBAL("relax.backend.RemovePatterns").set_body_typed(RemovePatterns);
 TVM_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix").set_body_typed(GetPatternsWithPrefix);
diff --git a/src/relax/backend/pattern_registry.h b/src/relax/backend/pattern_registry.h
index e765f56b4e..72eea1238d 100644
--- a/src/relax/backend/pattern_registry.h
+++ b/src/relax/backend/pattern_registry.h
@@ -28,6 +28,7 @@
 
 #include <tvm/relax/dataflow_pattern.h>
 #include <tvm/relax/expr.h>
+#include <tvm/relax/transform.h>
 #include <tvm/runtime/container/optional.h>
 #include <tvm/runtime/object.h>
 
@@ -35,57 +36,7 @@ namespace tvm {
 namespace relax {
 namespace backend {
 
-/*!
- * \brief An entry in the pattern registry. This represents a single pattern that
- * can be used to identify expressions that can be handled by external
- * backends, like CUTLASS and TensorRT.
- */
-class PatternRegistryEntryNode : public Object {
- public:
-  /*!
-   * \brief The name of pattern. Usually it starts with the name of backend, like
-   * 'cutlass.matmul'.
-   */
-  String name;
-  /*!
-   * \brief The dataflow pattern that will be used to match expressions that can
-   * be handled by external backends.
-   */
-  DFPattern pattern;
-  /*!
-   * \brief The mapping from arg name to its pattern. It can be used to extract
-   * arg expression from match result. All DFPattern in this map should be part of
-   * the `pattern`.
-   */
-  Map<String, DFPattern> arg_patterns;
-
-  /*!
-   * \brief The function to check whether the match result is accepted.
-   *
-   * It should have signature
-   * bool(const Map<DFPattern, Expr>& match_result, const Expr& matched_expr)
-   */
-  PackedFunc check;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("name", &name);
-    v->Visit("pattern", &pattern);
-    v->Visit("arg_patterns", &arg_patterns);
-    v->Visit("check", &check);
-  }
-
-  static constexpr const char* _type_key = "relax.backend.PatternRegistryEntry";
-  TVM_DECLARE_FINAL_OBJECT_INFO(PatternRegistryEntryNode, Object);
-};
-
-class PatternRegistryEntry : public ObjectRef {
- public:
-  PatternRegistryEntry(String name, DFPattern pattern, Map<String, DFPattern> arg_patterns,
-                       PackedFunc check);
-
-  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternRegistryEntry, ObjectRef,
-                                            PatternRegistryEntryNode);
-};
+using transform::FusionPattern;
 
 /*!
  * \brief Register patterns which will be used to partition the DataflowBlock
@@ -93,7 +44,7 @@ class PatternRegistryEntry : public ObjectRef {
  * \param patterns Patterns to be registered. Patterns that appear later in the list have
  *        higher priority when partitioning DataflowBlock.
  */
-void RegisterPatterns(Array<PatternRegistryEntry> entries);
+void RegisterPatterns(Array<FusionPattern> patterns);
 
 /*!
  * \brief Remove patterns from the registry by their name.
@@ -106,14 +57,14 @@ void RemovePatterns(Array<String> names);
 * \param prefix The pattern name prefix.
  * \return Matched patterns, ordered by priority from high to low.
  */
-Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix);
+Array<FusionPattern> GetPatternsWithPrefix(const String& prefix);
 
 /*!
  * \brief Find the pattern with a particular name.
  * \param name The pattern name.
  * \return The matched pattern. NullOpt if not found.
  */
-Optional<PatternRegistryEntry> GetPattern(const String& name);
+Optional<FusionPattern> GetPattern(const String& name);
 
 }  // namespace backend
 }  // namespace relax
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 6d7c278d80..76f53eebc5 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -40,6 +40,7 @@
 
 #include "../../relay/analysis/graph_partitioner.h"
 #include "../../support/arena.h"
+#include "tvm/relax/expr.h"
 
 namespace tvm {
 namespace relax {
@@ -905,18 +906,30 @@ class PatternBasedPartitioner : ExprVisitor {
   using Group = GraphPartitioner::Group;
   using GroupMap = OperatorFusor::GroupMap;
   using ExprVisitor::VisitExpr_;
-  using FCheckMatch = runtime::TypedPackedFunc<bool(const Map<DFPattern, Expr>&, const Expr&)>;
+  using FCheckMatch = runtime::TypedPackedFunc<bool(const transform::PatternCheckContext&)>;
 
-  static GroupMap Run(String pattern_name, DFPattern pattern, FCheckMatch check, Expr expr,
+  static GroupMap Run(String pattern_name, DFPattern pattern,
+                      Map<String, DFPattern> annotation_patterns, FCheckMatch check, Expr expr,
                       support::Arena* arena) {
-    PatternBasedPartitioner part(pattern_name, pattern, check, arena);
+    PatternBasedPartitioner part(pattern_name, pattern, annotation_patterns, check, arena);
     part.VisitExpr(expr);
     return part.group_map_;
   }
 
-  PatternBasedPartitioner(String pattern_name, DFPattern pattern, FCheckMatch check,
+  PatternBasedPartitioner(String pattern_name, DFPattern pattern,
+                          Map<String, DFPattern> annotation_patterns, FCheckMatch check,
                           support::Arena* arena)
-      : pat_name_(pattern_name), pat_(pattern), check_(check), arena_(arena) {}
+      : pat_name_(pattern_name),
+        pat_(pattern),
+        annotation_pat_(annotation_patterns),
+        check_(check),
+        arena_(arena) {}
+
+  void VisitBindingBlock_(const DataflowBlockNode* block) final {
+    current_block_use_def_ = DataflowBlockUseDef(GetRef<DataflowBlock>(block));
+    ExprVisitor::VisitBindingBlock_(block);
+    current_block_use_def_ = {};
+  }
 
   void VisitVarDef(const Var& var) final { group_map_[var.get()] = arena_->make<Group>(); }
 
@@ -931,7 +944,9 @@ class PatternBasedPartitioner : ExprVisitor {
   void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final {
     VisitVarDef(binding->var);
     if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef<Call>(call), bindings_)) {
-      if (!check_(matches_opt.value(), GetRef<Call>(call))) {
+      if (check_ != nullptr &&
+          !check_(transform::PatternCheckContext(GetAnnotatedExpr(matches_opt.value()),
+                                                 current_block_use_def_, value_to_bound_var_))) {
         return;
       }
       // If a match is found, put all matching expressions into the same group.
@@ -975,12 +990,24 @@ class PatternBasedPartitioner : ExprVisitor {
     return group_map_[bound_var.get()]->FindRoot();
   }
 
+  Map<String, Expr> GetAnnotatedExpr(const Map<DFPattern, Expr> matched_result) {
+    Map<String, Expr> annotated_expr;
+    for (const auto& it : annotation_pat_) {
+      if (matched_result.count(it.second)) {
+        annotated_expr.Set(it.first, matched_result[it.second]);
+      }
+    }
+    return annotated_expr;
+  }
+
   String pat_name_;
   DFPattern pat_;
+  Map<String, DFPattern> annotation_pat_;
   FCheckMatch check_;
   support::Arena* arena_;
   Map<Var, Expr> bindings_;
   Map<Expr, Var> value_to_bound_var_;
+  Map<Var, Array<Var>> current_block_use_def_;
   GroupMap group_map_;
 };
 
@@ -1054,19 +1081,18 @@ class CompositeFunctionAnnotator : public ExprMutator {
   std::unordered_map<const GlobalVarNode*, GlobalVar> gvar_map_;
 };
 
-IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names,
-                          const tvm::Array<DFPattern>& patterns,
-                          const tvm::Array<runtime::PackedFunc>& checks, IRModule mod,
+IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns, IRModule mod,
                           bool bind_constants, bool annotate_codegen) {
   support::Arena arena;
-  for (size_t i = 0; i < pattern_names.size(); ++i) {
+  for (const auto& pattern : patterns) {
     OperatorFusor::GroupMap group_map;
     for (const auto& entry : mod->functions) {
       if (entry.second->IsInstance<tir::PrimFuncNode>()) {
         continue;
       }
-      auto map = PatternBasedPartitioner::Run(pattern_names[i], patterns[i], checks[i],
-                                              entry.second, &arena);
+      auto map = PatternBasedPartitioner::Run(
+          pattern->name, pattern->pattern, pattern->annotation_patterns,
+          pattern->check.value_or(nullptr), entry.second, &arena);
       group_map.insert(map.begin(), map.end());
     }
     mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants);
@@ -1079,6 +1105,36 @@ IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names,
 
 namespace transform {
 
+FusionPattern::FusionPattern(String name, DFPattern pattern,
+                             Map<String, DFPattern> annotation_patterns,
+                             Optional<PackedFunc> check) {
+  ObjectPtr<FusionPatternNode> n = make_object<FusionPatternNode>();
+  n->name = std::move(name);
+  n->pattern = std::move(pattern);
+  n->annotation_patterns = std::move(annotation_patterns);
+  n->check = check;
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(FusionPatternNode);
+TVM_REGISTER_GLOBAL("relax.transform.FusionPattern")
+    .set_body_typed([](String name, DFPattern pattern, Map<String, DFPattern> annotation_patterns,
+                       Optional<PackedFunc> check) {
+      return FusionPattern(name, pattern, annotation_patterns, check);
+    });
+
+PatternCheckContext::PatternCheckContext(Map<String, Expr> annotated_expr,
+                                         Map<Var, Array<Var>> var_usages,
+                                         Map<Expr, Var> value_to_bound_var) {
+  ObjectPtr<PatternCheckContextNode> n = make_object<PatternCheckContextNode>();
+  n->annotated_expr = std::move(annotated_expr);
+  n->var_usages = std::move(var_usages);
+  n->value_to_bound_var = std::move(value_to_bound_var);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(PatternCheckContextNode);
+
 Pass FuseOps(int fuse_opt_level) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
       [=](IRModule m, PassContext pc) {
@@ -1094,14 +1150,11 @@ Pass FuseOps(int fuse_opt_level) {
 
 TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);
 
-Pass FuseOpsByPattern(const tvm::Array<String>& pattern_names,
-                      const tvm::Array<DFPattern>& patterns,
-                      const tvm::Array<runtime::PackedFunc>& checks, bool bind_constants,
+Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants,
                       bool annotate_codegen) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
       [=](IRModule m, PassContext pc) {
-        return relax::FuseOpsByPattern(pattern_names, patterns, checks, m, bind_constants,
-                                       annotate_codegen);
+        return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen);
       };
   return CreateModulePass(/*pass_function=*/pass_func,       //
                           /*opt_level=*/0,                   //
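
The partitioner now builds a PatternCheckContext from the current dataflow
block before invoking the check function. A sketch of a check that exercises
all three context fields (the single-consumer rule is illustrative):

    from tvm.relax.transform import PatternCheckContext

    def single_consumer_check(context: PatternCheckContext) -> bool:
        root = context.annotated_expr["root"]         # the matched call expression
        root_var = context.value_to_bound_var[root]   # variable the call is bound to
        users = context.var_usages.get(root_var, [])  # variables that consume root_var
        return len(users) <= 1
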
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index de15f7083a..0bae6801ca 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -26,7 +26,6 @@ from tvm.contrib.pickle_memoize import memoize
 from tvm.relax.backend import get_patterns_with_prefix
 from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
 from tvm.script import relax as R
-from tvm.script import tir as T
 from tvm.script.ir_builder import IRBuilder
 from tvm.script.ir_builder import relax as relax_builder
 
@@ -296,6 +295,43 @@ def test_conv2d_offload(data_shape, weight_shape, dtype, epilogue, residual_bloc
     tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
 
 
+def test_cutlass_partition_conv2d_residual_blocked():
+    @tvm.script.ir_module
+    class Conv2dReLU:
+        """
+        This conv2d should not be fused as a conv2d residual block, because both the
+        lhs and rhs of the last R.add depend on the result of the conv2d.
+        """
+
+        @R.function
+        def main(
+            data: R.Tensor((32, 3, 3, 16), "float32"),
+            weight: R.Tensor((16, 3, 3, 16), "float32"),
+            bias: R.Tensor((1, 1, 1, 16), "float32"),
+        ):
+            with R.dataflow():
+                conv1 = R.nn.conv2d(
+                    data,
+                    weight,
+                    padding=(1, 1),
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                )
+                out = R.nn.relu(conv1 + bias)
+                # residual depends on the conv2d result, which cutlass cannot handle
+                result = out + out
+                R.output(result)
+
+            return result
+
+    mod = partition_for_cutlass(Conv2dReLU, annotate_codegen=False)
+    for f_var in mod.functions:
+        func = mod[f_var]
+        if func.attrs and "Composite" in func.attrs:
+            # verify that the function was not fused as a residual block
+            assert func.attrs["Composite"] == "cutlass.conv2d_bias_relu"
+
+
 @pytest.mark.parametrize(
     "x_shape, y_shape, transpose_y, epilogue, residual_block",
     [
@@ -451,6 +487,7 @@ def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, dtype):
     mod = get_relax_matmul_module(
         x_shape, y_shape, dtype, with_bias=False, transposed_y=transpose_y
     )
+    mod = partition_for_cutlass(mod)
 
     assert len(mod.functions) == 1
 
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 3816e11bc5..2f3e2d479f 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -14,14 +14,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import pytest
 import numpy as np
+import pytest
 
 import tvm
-
 from tvm import relax
-from tvm.script import relax as R, tir as T, ir as I
-from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern, is_op, wildcard
+from tvm.relax.dpl.pattern import is_op, make_fused_bias_activation_pattern, wildcard
+from tvm.relax.transform import PatternCheckContext
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
 
 
 @tvm.script.ir_module
@@ -600,13 +602,23 @@ def test_unused():
 
 
 def test_check_pattern():
-    pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None)
-
-    def pred(match, expr):
+    lhs = wildcard()
+    rhs = wildcard()
+    out = is_op("relax.nn.conv2d")(lhs, rhs)
+    annotation_patterns = {"root": out, "lhs": lhs, "rhs": rhs}
+
+    def pred(context: PatternCheckContext):
+        lhs = context.annotated_expr["lhs"]
+        rhs = context.annotated_expr["rhs"]
+        expr = context.annotated_expr["root"]
+        assert isinstance(lhs, relax.expr.Var) and lhs.name_hint == "data"
+        assert isinstance(rhs, relax.expr.Var) and rhs.name_hint == "weight1"
         assert isinstance(expr, relax.expr.Call) and expr.op.name == "relax.nn.conv2d"
-        return expr.struct_info.dtype == "float32"
+        return False
 
-    check(Conv2dx2, [("cutlass.conv2d", pat, pred)], Conv2dx2)  # expect no partitioning
+    check(
+        Conv2dReLU, [("cutlass.conv2d", out, annotation_patterns, pred)], Conv2dReLU
+    )  # expect no partitioning
 
 
 def test_bind_constants():