You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/17 00:01:21 UTC

[tvm] 03/18: dnnl pattern matching

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/rebase-09162022-autotensorization
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit f6911e91091c276ee0edfc91e290ae794f58963c
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Thu Sep 1 21:46:53 2022 -0700

    dnnl pattern matching
---
 python/tvm/relay/op/contrib/dnnl.py | 64 +++++++++++++++++++++++++++----------
 1 file changed, 47 insertions(+), 17 deletions(-)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index f7752e41b0..e27449ac43 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -36,22 +36,18 @@ import logging
 from functools import reduce
 
 import tvm.ir
-from tvm.ir import Op
 from tvm import relay
+from tvm.ir import Op
+from tvm.relay import expr as _expr
 from tvm.relay import transform
-from tvm.relay.expr import GlobalVar
-from tvm.relay.expr_functor import ExprMutator, ExprVisitor
-from tvm.relay.expr import const
-
 from tvm.relay.analysis import analysis as _analysis
-from tvm.relay import expr as _expr
+from tvm.relay.expr import Call, GlobalVar, TupleGetItem, const
+from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
-from tvm.relay.expr import Call, TupleGetItem
 from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
+from ...dataflow_pattern import DFPatternCallback, is_constant, is_expr, is_op, rewrite, wildcard
 from .register import register_pattern_table
 
-
 logger = logging.getLogger("DNNL")
 supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None]
 
@@ -809,7 +805,7 @@ def prune_dnnl_subgraphs(mod):
     return new_mod
 
 
-class LayerNormRewrite(DFPatternCallback):
+class LayerNormRewritePattern1(DFPatternCallback):
     """
     A callback to rewrite the following operators into a single layer normalization operator.
 
@@ -826,7 +822,42 @@ class LayerNormRewrite(DFPatternCallback):
             /* ty=Tensor[(1, 3136, 64), float32] */;
     10   %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */)
             /* ty=Tensor[(1, 3136, 64), float32] */;
+    """
+
+    def __init__(self):
+        super(LayerNormRewritePattern1, self).__init__()
+        self.data = wildcard()
+        self.gamma = wildcard()
+        self.beta = wildcard()
+        mu = is_op("mean")(self.data)
+        diff = is_op("subtract")(self.data, mu)
+        cdiff = is_op("cast")(diff)
+        const_two = (
+            is_expr(relay.const(2))
+            | is_expr(relay.const(2.0))
+            | is_expr(relay.const(2.0, dtype="float16"))
+        )
+        p1 = is_op("power")(cdiff, const_two)
+        mp1 = is_op("mean")(p1)
+        eps = is_constant()  # TODO: check epsilon is something reasonable
+        added_eps = is_op("add")(mp1, eps)
+        deno = is_op("sqrt")(added_eps)
+        div_out = is_op("divide")(diff, deno)
+        div_out2 = diff * is_op("rsqrt")(added_eps)
+        weighted = is_op("multiply")(div_out | div_out2, self.gamma)
+        added_bias = is_op("add")(weighted, self.beta)
+        self.pattern = added_bias
 
+    def callback(self, pre, post, node_map):
+        data = node_map[self.data][0]
+        gamma = node_map[self.gamma][0]
+        beta = node_map[self.beta][0]
+        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta)
+
+
+class LayerNormRewritePattern2(DFPatternCallback):
+    """
+    A callback to rewrite the following operators into a single layer normalization operator.
     Pattern #2:
     1   %0 = mean(%input, axis=[-1], keepdims=True);
     2   %1 = variance(%input, %0, axis=[-1], keepdims=True);
@@ -842,19 +873,16 @@ class LayerNormRewrite(DFPatternCallback):
     """
 
     def __init__(self):
-        super(LayerNormRewrite, self).__init__()
+        super(LayerNormRewritePattern2, self).__init__()
         self.data = wildcard()
         self.gamma = wildcard()
         self.beta = wildcard()
         mu = is_op("mean")(self.data)
-        diff = is_op("subtract")(self.data, mu)
-        cdiff = diff | is_op("cast")(diff)
-        const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
-        p1 = is_op("power")(cdiff, const_two)
-        mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
+        mp1 = is_op("variance")(self.data, mu)
         eps = is_expr(relay.const(1e-5)) | is_expr(relay.const(1e-6))
         added_eps = is_op("add")(mp1, eps)
         deno = is_op("sqrt")(added_eps)
+        diff = is_op("subtract")(self.data, mu)
         div_out = is_op("divide")(diff, deno)
         div_out2 = diff * is_op("rsqrt")(added_eps)
         weighted = is_op("multiply")(div_out | div_out2, self.gamma)
@@ -872,7 +900,9 @@ def rewrite_layer_norm(mod):
     """Rewrite the input graph to replace multiple operators with a TVM native layer normalization
     operator so that we can offload them to dnnl layer normalization byoc part.
     """
-    mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
+    mod["main"] = rewrite(LayerNormRewritePattern1(), mod["main"])
+    mod["main"] = rewrite(LayerNormRewritePattern2(), mod["main"])
+
     return mod