Posted to commits@tvm.apache.org by ma...@apache.org on 2022/08/01 01:32:57 UTC

[tvm] branch main updated: [BYOC-DNNL] add post_sum pattern (#12151)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c07d77f99c [BYOC-DNNL] add post_sum pattern (#12151)
c07d77f99c is described below

commit c07d77f99c024b9e2c162b574482dbbbd71d4680
Author: Ivy Zhang <ya...@intel.com>
AuthorDate: Mon Aug 1 09:32:52 2022 +0800

    [BYOC-DNNL] add post_sum pattern (#12151)
    
    * add post_sum pattern
    
    * add checkers for sum pattern
    
    * fix lint
    
    * fix error in test_pass_partition_graph
    
    * fix lint error
---
 python/tvm/relay/op/contrib/dnnl.py             | 106 +++++++++++++++++++++++-
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc   |  14 +++-
 tests/python/contrib/test_dnnl.py               |  42 ++++++++++
 tests/python/relay/test_pass_partition_graph.py |   6 +-
 4 files changed, 164 insertions(+), 4 deletions(-)
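
In short, this change teaches the DNNL BYOC path to fuse an elementwise sum (and an optional relu) into a convolution as a DNNL post-op. The new patterns target Relay chains of the following shape; a minimal sketch, with variable names taken from the pattern builder and shapes chosen purely for illustration:

    from tvm import relay

    # Chain matched by "dnnl.conv2d_bias_sum_relu"; without the final
    # relu it is matched by "dnnl.conv2d_bias_sum" instead.
    data1 = relay.var("data1", shape=(1, 32, 8, 8))
    weight = relay.var("weight", shape=(16, 32, 3, 3))
    bias = relay.var("bias", shape=(16, 1, 1))
    data2 = relay.var("data2", shape=(1, 16, 6, 6))
    out = relay.nn.conv2d(data1, weight)
    out = relay.add(out, bias)   # bias add
    out = relay.add(out, data2)  # post-op sum with a second input
    out = relay.nn.relu(out)     # optional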

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index fa98ed002c..f17b325dce 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -33,8 +33,10 @@ it is supported. For example:
 check the attributes of the op and decide if it should be offloaded to DNNL.
 """
 import logging
+from functools import reduce
 
 import tvm.ir
+from tvm.ir import Op
 from tvm import relay
 from tvm.relay import transform
 from tvm.relay.expr import GlobalVar
@@ -44,7 +46,7 @@ from tvm.relay.expr import const
 from tvm.relay.analysis import analysis as _analysis
 from tvm.relay import expr as _expr
 
-
+from tvm.relay.expr import Call, TupleGetItem
 from ... import _ffi_api
 from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
 from .register import register_pattern_table
@@ -167,6 +169,94 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
     return append_eltwise_ops(conv_out, with_eltwise)
 
 
+def make_conv_bias_sum_relu_pattern(conv_type, has_relu=True):
+    """Create patterns with sum op.
+
+    Parameters
+    ----------
+    conv_type : str
+        Should be nn.conv1d / nn.conv2d / nn.conv3d.
+    has_relu : bool
+        Whether to attach relu.
+    Returns
+    -------
+    out : CallPattern
+        Call node sequence.
+    """
+    data1 = wildcard()
+    weight = wildcard()
+    bias = wildcard()
+    data2 = wildcard()
+    out = is_op(conv_type)(data1, weight)
+    out = is_op("add")(out, bias)
+    out = is_op("add")(out, data2)
+    if has_relu:
+        out = is_op("nn.relu")(out)
+    return out
+
+
+def get_op_name(expr):
+    """Get the operator name from an expression."""
+    if isinstance(expr, Op):
+        return expr.name
+    if isinstance(expr, Call):
+        return get_op_name(expr.op)
+    if isinstance(expr, TupleGetItem):
+        return get_op_name(expr.tuple_value)
+    if isinstance(expr, relay.Tuple):
+        return get_op_name(expr.fields[0])
+    return ""
+
+
+def get_args(expr):
+    """Get the arguments from an expression."""
+    if isinstance(expr, Call):
+        return expr.args
+    if isinstance(expr, TupleGetItem):
+        return get_args(expr.tuple_value)
+    if isinstance(expr, relay.Tuple):
+        return [arg for args in map(get_args, expr.fields) for arg in args]
+    return []
+
+
+def get_attrs(expr):
+    """Get the attributes from an expression."""
+    if isinstance(expr, Call):
+        return expr.attrs
+    if isinstance(expr, TupleGetItem):
+        return get_attrs(expr.tuple_value)
+    return {}
+
+
+def make_predicate(checker):
+    """Check whether the conv_bias_add_sum pattern is as expected."""
+
+    def predicate(expr):
+        if get_op_name(expr) == "nn.relu":
+            expr = expr.args[0]
+        for e, op_name in zip([expr, expr.args[0]], ["sum", "bias_add"]):
+            args = get_args(e)
+            attrs = get_attrs(e.args[0])
+            if not checker(attrs, args, op_name):
+                return False
+        return True
+
+    return predicate
+
+
+def add_checker(attrs, args, op_name):
+    """Check if add is supported by DNNL."""
+    if op_name == "sum":
+        if tuple(get_shape(args[0])) != tuple(get_shape(args[1])):
+            return False
+    if op_name == "bias_add":
+        channel = dict(attrs)["channels"]
+        const_shape = get_shape(args[1])
+        if channel != reduce(lambda x, y: x * y, const_shape):
+            return False
+    return True
+
+
 def make_dense_pattern(with_bias=True, with_eltwise=None):
     """Create patterns related to nn.dense.
 
@@ -306,6 +396,20 @@ def pattern_table():
     dnnl_patterns = list()
     dnnl_patterns.append(make_qnn_conv2d_pattern())
     dnnl_patterns.append(make_qnn_dense_pattern())
+    dnnl_patterns.append(
+        (
+            "dnnl.conv2d_bias_sum_relu",
+            make_conv_bias_sum_relu_pattern("nn.conv2d"),
+            make_predicate(add_checker),
+        )
+    )
+    dnnl_patterns.append(
+        (
+            "dnnl.conv2d_bias_sum",
+            make_conv_bias_sum_relu_pattern("nn.conv2d", False),
+            make_predicate(add_checker),
+        )
+    )
 
     elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
     for with_bias in [True, False]:
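
A note on how the pieces in this file fit together: make_conv_bias_sum_relu_pattern builds the DFPattern, and make_predicate(add_checker) filters matches, rejecting sums whose operands differ in shape and biases whose element count does not equal the conv channel count. A minimal, hypothetical sketch of matching and partitioning (shapes and names are illustrative, not from this commit; compiling and running the partitioned module additionally requires a DNNL-enabled TVM build):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import (
        make_conv_bias_sum_relu_pattern,
        partition_for_dnnl,
    )

    # conv2d -> add(bias) -> add(sum input) -> relu, as the pattern expects.
    x = relay.var("x", shape=(1, 32, 8, 8))
    w = relay.const(np.random.rand(16, 32, 3, 3).astype("float32"))
    b = relay.const(np.random.rand(16, 1, 1).astype("float32"))
    y = relay.var("y", shape=(1, 16, 6, 6))
    out = relay.nn.conv2d(x, w, channels=16, kernel_size=(3, 3))
    out = relay.add(out, b)
    out = relay.add(out, y)
    out = relay.nn.relu(out)

    # The DFPattern built above matches this chain structurally.
    assert make_conv_bias_sum_relu_pattern("nn.conv2d").match(out)

    # Partitioning should yield a composite function tagged
    # "dnnl.conv2d_bias_sum_relu", predicate permitting.
    mod = partition_for_dnnl(tvm.IRModule.from_expr(out))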
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index dcf1a86785..1fe8fccc77 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -182,8 +182,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
 
     if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr;
 
-    // parsing of name to extract attributes
-    auto op_name = nodes_[nid].GetOpName();
     // Define RegExp.
     std::regex bias_add_pat(".*_bias.*");
     std::regex relu_pat(".*_relu.*");
@@ -192,9 +190,16 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::regex clip_pat(".*_clip.*");
     std::regex gelu_pat(".*_gelu.*");
     std::regex swish_pat(".*_swish.*");
+    std::regex sum_pat(".*_sum.*");
+
+    // parsing of name to extract attributes
+    auto op_name = nodes_[nid].GetOpName();
 
     // Parsing post-ops.
     dnnl::post_ops ops;
+    if (std::regex_match(op_name, sum_pat)) {
+      ops.append_sum(1.f);
+    }
     if (std::regex_match(op_name, relu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
     }
@@ -280,6 +285,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
 
   void Convolution(const size_t& nid) {
     auto node = nodes_[nid];
+    auto op_name = nodes_[nid].GetOpName();
 
     // Setup attributes.
     auto src_tr = GetInput(nid, 0);
@@ -361,6 +367,10 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
 
     // TODO(@apeskov): Simulation of inplace primitive. just as PoC.
     auto sum_in_tr = GetInputByName(nid, "sum_idx").TreatAs(dst_layout);
+    if (op_name.find("_sum") != std::string::npos) {
+      sum_in_tr = GetInput(nid, node.GetInputs().size() - 1);
+      sum_in_tr = sum_in_tr.TreatAs(dst_layout);
+    }
 
     Submit(dnnl::convolution_forward(conv_prim_desc),
            {{DNNL_ARG_SRC, src_tr},
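
On the runtime side, post-ops are recovered from the composite function's name rather than from attributes: a "_sum" substring appends a dnnl sum post-op (append_sum(1.f)) ahead of any eltwise post-op, so DNNL computes relu(conv + bias + sum_input), with the convolution's last input supplying the tensor to be summed. A rough Python analogue of the regex dispatch, for illustration only (the authoritative logic is the C++ above):

    import re

    # std::regex_match requires a whole-string match; re.fullmatch is
    # the closest Python analogue.
    op_name = "dnnl.conv2d_bias_sum_relu"
    assert re.fullmatch(".*_sum.*", op_name)   # -> ops.append_sum(1.f)
    assert re.fullmatch(".*_relu.*", op_name)  # -> eltwise_relu post-op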
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index e744cab6e9..74d0da1238 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -788,6 +788,48 @@ def test_conv2d_pattern(run_module, dtype="float32"):
     run_and_verify_func(config, run_module=run_module, dtype=dtype)
 
 
+def test_conv2d_bias_sum_relu(run_module, dtype="float32"):
+    x_shape = (1, 32, 8, 8)
+    k_shape = (16, 32, 3, 3)
+
+    def get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape, dtype="float32"):
+        out, dic, param_lst = get_conv2d_bias(x_shape=x_shape, k_shape=k_shape, dtype=dtype)
+        beta = relay.const(np.zeros(k_shape[0]).astype(dtype))
+        gamma = relay.const(np.ones(k_shape[0]).astype(dtype))
+        moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype))
+        moving_var = relay.const(np.ones(k_shape[0]).astype(dtype))
+        out, _, _ = relay.nn.batch_norm(
+            out,
+            gamma=gamma,
+            beta=beta,
+            moving_mean=moving_mean,
+            moving_var=moving_var,
+            axis=1,
+            center=True,
+            scale=True,
+            epsilon=1e-5,
+        )
+        sum_data = relay.var("data1", shape=sum_shape, dtype=dtype)
+        out = relay.add(out, sum_data)
+        dic["data1"] = sum_shape
+        param_lst += ["data1"]
+        return relay.nn.relu(out), dic, param_lst
+
+    conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
+        x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype
+    )
+    conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
+    config = conv2d_bn_sum_relu, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+    conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
+        x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype
+    )
+    conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
+    config = conv2d_bn_sum_relu, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
 def test_conv2d_transpose(run_module, dtype="float32"):
     x_shape = (1, 32, 8, 8)
     for k_shape, groups in [((32, 16, 3, 3), 1), ((32, 1, 3, 3), 32), ((32, 4, 3, 3), 16)]:
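
A note on the two shapes exercised above: during partitioning the batch_norm is simplified and folded away, leaving the conv2d -> bias add -> add -> relu chain the pattern targets. With sum_shape=(1, 16, 6, 6) the add is truly elementwise and may fuse as a DNNL post-op; with sum_shape=(1, 16, 1, 1) relay.add broadcasts, so add_checker's shape test rejects the match and the graph runs unfused, which run_and_verify_func still validates numerically. A small numpy illustration of the distinction the checker draws (numpy stands in for Relay broadcasting semantics):

    import numpy as np

    conv_out = np.zeros((1, 16, 6, 6), dtype="float32")
    same = np.ones((1, 16, 6, 6), dtype="float32")
    bcast = np.ones((1, 16, 1, 1), dtype="float32")

    # Shapes match: eligible for append_sum fusion.
    assert same.shape == conv_out.shape
    # Broadcast add: the result shape is fine, but the operand shapes
    # differ, which is exactly what add_checker screens out.
    assert (conv_out + bcast).shape == conv_out.shape
    assert bcast.shape != conv_out.shape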
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 31aa4e4fe2..e796f60a9c 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -919,7 +919,11 @@ def test_mixed_single_multiple_outputs():
 
 def test_dnnl_fuse():
     dnnl_patterns = get_pattern_table("dnnl")
-    dnnl_pat_dic = dict(dnnl_patterns)
+    valid_pats = list()
+    for pattern in dnnl_patterns:
+        if len(pattern) == 2:
+            valid_pats.append(pattern)
+    dnnl_pat_dic = dict(valid_pats)
     (
         conv2d_bias_relu_pat,
         conv2d_bias_sigmoid_pat,
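
Context for this last hunk: pattern-table entries used to be (name, pattern) pairs, but the two new sum patterns are (name, pattern, predicate) triples, and dict() over a mixed-arity list raises ValueError. The test therefore keeps only the 2-tuples before building its lookup dict. A toy illustration (fake entries, not the real table):

    pats = [("a", "pat_a"), ("b", "pat_b", lambda expr: True)]
    # dict(pats) would raise:
    #   ValueError: dictionary update sequence element #1 has length 3;
    #   2 is required
    dnnl_pat_dic = dict(p for p in pats if len(p) == 2)
    assert dnnl_pat_dic == {"a": "pat_a"}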