Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/16 22:47:33 UTC
[tvm] 09/20: dnnl pattern matching
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch aluo/rebase-08312022-autotensorization-fq2i-changes
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 946815850b8f7b13a02fddb46e5cc7b7be01aa58
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Thu Sep 1 21:46:53 2022 -0700
dnnl pattern matching
---
python/tvm/relay/op/contrib/dnnl.py | 64 +++++++++++++++++++++++++++----------
1 file changed, 47 insertions(+), 17 deletions(-)
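For context, the rewrites below are built on TVM's dataflow pattern API: a DFPatternCallback subclass declares a pattern in __init__ and returns the replacement expression from callback(). A minimal toy sketch of that mechanism (the add-zero pattern here is illustrative, not part of this commit):

    from tvm import relay
    from tvm.relay.dataflow_pattern import DFPatternCallback, is_expr, is_op, rewrite, wildcard

    class AddZeroElim(DFPatternCallback):
        """Toy callback: rewrite add(x, 0.0) -> x."""

        def __init__(self):
            super().__init__()
            self.x = wildcard()                       # matches any subexpression
            zero = is_expr(relay.const(0.0))          # matches the literal constant 0.0
            self.pattern = is_op("add")(self.x, zero)

        def callback(self, pre, post, node_map):
            # node_map maps each sub-pattern to the expression(s) it captured
            return node_map[self.x][0]

    x = relay.var("x", shape=(4,), dtype="float32")
    out = rewrite(AddZeroElim(), relay.add(x, relay.const(0.0)))
    # `out` is now just the free variable `x`

The LayerNormRewritePattern1/LayerNormRewritePattern2 classes in the diff follow the same shape: wildcards capture data, gamma, and beta, and the callback emits a single nn.layer_norm.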
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index f7752e41b0..e27449ac43 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -36,22 +36,18 @@ import logging
from functools import reduce
import tvm.ir
-from tvm.ir import Op
from tvm import relay
+from tvm.ir import Op
+from tvm.relay import expr as _expr
from tvm.relay import transform
-from tvm.relay.expr import GlobalVar
-from tvm.relay.expr_functor import ExprMutator, ExprVisitor
-from tvm.relay.expr import const
-
from tvm.relay.analysis import analysis as _analysis
-from tvm.relay import expr as _expr
+from tvm.relay.expr import Call, GlobalVar, TupleGetItem, const
+from tvm.relay.expr_functor import ExprMutator, ExprVisitor
-from tvm.relay.expr import Call, TupleGetItem
from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
+from ...dataflow_pattern import DFPatternCallback, is_constant, is_expr, is_op, rewrite, wildcard
from .register import register_pattern_table
-
logger = logging.getLogger("DNNL")
supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None]
@@ -809,7 +805,7 @@ def prune_dnnl_subgraphs(mod):
return new_mod
-class LayerNormRewrite(DFPatternCallback):
+class LayerNormRewritePattern1(DFPatternCallback):
"""
A callback to rewrite the following operators into a single layer normalization operator.
@@ -826,7 +822,42 @@ class LayerNormRewrite(DFPatternCallback):
/* ty=Tensor[(1, 3136, 64), float32] */;
10 %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */)
/* ty=Tensor[(1, 3136, 64), float32] */;
+ """
+
+ def __init__(self):
+ super(LayerNormRewritePattern1, self).__init__()
+ self.data = wildcard()
+ self.gamma = wildcard()
+ self.beta = wildcard()
+ mu = is_op("mean")(self.data)
+ diff = is_op("subtract")(self.data, mu)
+ cdiff = is_op("cast")(diff)
+ const_two = (
+ is_expr(relay.const(2))
+ | is_expr(relay.const(2.0))
+ | is_expr(relay.const(2.0, dtype="float16"))
+ )
+ p1 = is_op("power")(cdiff, const_two)
+ mp1 = is_op("mean")(p1)
+ eps = is_constant() # TODO: check epsilon is something reasonable
+ added_eps = is_op("add")(mp1, eps)
+ deno = is_op("sqrt")(added_eps)
+ div_out = is_op("divide")(diff, deno)
+ div_out2 = diff * is_op("rsqrt")(added_eps)
+ weighted = is_op("multiply")(div_out | div_out2, self.gamma)
+ added_bias = is_op("add")(weighted, self.beta)
+ self.pattern = added_bias
+
+ def callback(self, pre, post, node_map):
+ data = node_map[self.data][0]
+ gamma = node_map[self.gamma][0]
+ beta = node_map[self.beta][0]
+ return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta)
+
+
+class LayerNormRewritePattern2(DFPatternCallback):
+ """
+ A callback to rewrite the following operators into a single layer normalization operator.
Pattern #2:
1 %0 = mean(%input, axis=[-1], keepdims=True);
2 %1 = variance(%input, %0, axis=[-1], keepdims=True);
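Before the Pattern #2 hunks, a sanity-check sketch of the kind of graph Pattern #1 above is meant to collapse, assuming this commit is applied; shapes, dtypes, and the epsilon value are illustrative, and note that Pattern #1 as written requires the cast between the subtract and the power:

    import numpy as np
    from tvm import relay

    x = relay.var("x", shape=(1, 3136, 64), dtype="float32")
    gamma = relay.const(np.ones(64, dtype="float32"))
    beta = relay.const(np.zeros(64, dtype="float32"))

    mu = relay.mean(x, axis=[-1], keepdims=True)
    diff = relay.subtract(x, mu)
    cdiff = relay.cast(diff, "float32")               # the pattern expects this cast
    var = relay.mean(relay.power(cdiff, relay.const(2.0)), axis=[-1], keepdims=True)
    norm = relay.divide(diff, relay.sqrt(relay.add(var, relay.const(1e-5))))
    out = relay.add(relay.multiply(norm, gamma), beta)

    func = relay.Function([x], out)
    # rewrite(LayerNormRewritePattern1(), func) should fold this whole subgraph
    # into a single nn.layer_norm(x, gamma, beta) call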
@@ -842,19 +873,16 @@ class LayerNormRewrite(DFPatternCallback):
"""
def __init__(self):
- super(LayerNormRewrite, self).__init__()
+ super(LayerNormRewritePattern2, self).__init__()
self.data = wildcard()
self.gamma = wildcard()
self.beta = wildcard()
mu = is_op("mean")(self.data)
- diff = is_op("subtract")(self.data, mu)
- cdiff = diff | is_op("cast")(diff)
- const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
- p1 = is_op("power")(cdiff, const_two)
- mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
+ mp1 = is_op("variance")(self.data, mu)
eps = is_expr(relay.const(1e-5)) | is_expr(relay.const(1e-6))
added_eps = is_op("add")(mp1, eps)
deno = is_op("sqrt")(added_eps)
+ diff = is_op("subtract")(self.data, mu)
div_out = is_op("divide")(diff, deno)
div_out2 = diff * is_op("rsqrt")(added_eps)
weighted = is_op("multiply")(div_out | div_out2, self.gamma)
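Pattern #2 covers graphs where the variance comes from the dedicated variance op rather than mean(power(diff, 2)). Both patterns accept either denominator form, divide(diff, sqrt(var + eps)) or diff * rsqrt(var + eps), since rsqrt(v) == 1/sqrt(v); a quick NumPy check of that equivalence (values are illustrative):

    import numpy as np

    x = np.random.rand(1, 3136, 64).astype("float32")
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    eps = np.float32(1e-5)

    a = (x - mu) / np.sqrt(var + eps)           # divide / sqrt form
    b = (x - mu) * (1.0 / np.sqrt(var + eps))   # multiply / rsqrt form
    np.testing.assert_allclose(a, b, rtol=1e-6)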
@@ -872,7 +900,9 @@ def rewrite_layer_norm(mod):
"""Rewrite the input graph to replace multiple operators with a TVM native layer normalization
operator so that we can offload them to dnnl layer normalization byoc part.
"""
- mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
+ mod["main"] = rewrite(LayerNormRewritePattern1(), mod["main"])
+ mod["main"] = rewrite(LayerNormRewritePattern2(), mod["main"])
+
return mod
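A hedged end-to-end usage sketch, reusing `func` (the expanded layer-norm Function from the Pattern #1 example above); the module construction is illustrative:

    import tvm
    from tvm.relay.op.contrib.dnnl import rewrite_layer_norm

    mod = tvm.IRModule.from_expr(func)
    mod = rewrite_layer_norm(mod)
    print(mod["main"])   # expected: the body is a single nn.layer_norm call,
                         # ready for DNNL BYOC partitioning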