You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/12/29 07:34:17 UTC

[tvm] branch main updated: DNNL-BYOC enhancement (#9797)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 75cd670  DNNL-BYOC enhancement (#9797)
75cd670 is described below

commit 75cd670ae40451b8dc86536bcd05c5666bc3e954
Author: Ivy Zhang <ya...@intel.com>
AuthorDate: Wed Dec 29 15:33:55 2021 +0800

    DNNL-BYOC enhancement (#9797)
    
    * add unit test for byoc-dnnl
    
    * add byoc-dnnl patterns and their test cases
---
 python/tvm/relay/op/contrib/dnnl.py             | 140 +++++++++-
 src/relay/backend/contrib/dnnl/codegen.cc       |  18 ++
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc   |  99 ++++---
 tests/python/contrib/test_dnnl.py               | 350 ++++++++++++++++++++++++
 tests/python/relay/test_pass_partition_graph.py |  67 ++++-
 5 files changed, 622 insertions(+), 52 deletions(-)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index a2fdc19..05b5880 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -32,10 +32,17 @@ it is supported. For example:
 - The other way is to implement the function by themselves to
 check the attributes of the op and decide if it should be offloaded to DNNL.
 """
+import logging
+
 import tvm.ir
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+
 from ...dataflow_pattern import wildcard, is_op
 from .register import register_pattern_table
 
+logger = logging.getLogger("DNNL")
+
 
 def _register_external_op_helper(op_name, supported=True):
     """The helper function to indicate that a given operator can be supported
@@ -63,11 +70,26 @@ _register_external_op_helper("nn.batch_norm")
 _register_external_op_helper("nn.conv2d")
 _register_external_op_helper("nn.dense")
 _register_external_op_helper("nn.relu")
+_register_external_op_helper("tanh")
+_register_external_op_helper("sigmoid")
 _register_external_op_helper("add")
 _register_external_op_helper("multiply")
 
 
-def make_pattern(with_bias=True):
+def make_conv_pattern(with_bias=True, with_eltwise=None):
+    """Create patterns related to nn.conv2d.
+
+    Parameters
+    ----------
+    with_bias : bool
+        Whether attach `bias_add` to `nn.conv2d`.
+    with_eltwise : str
+        The attached elementwise post-op name.
+    Returns
+    -------
+    conv_out : CallPattern
+        Call node sequence.
+    """
     data = wildcard()
     weight = wildcard()
     bias = wildcard()
@@ -76,12 +98,120 @@ def make_pattern(with_bias=True):
         conv_out = is_op("add")(conv, bias)
     else:
         conv_out = conv
-    return is_op("nn.relu")(conv_out)
+    if with_eltwise:
+        return is_op(with_eltwise)(conv_out)
+    return conv_out
+
+
+def make_dense_pattern(with_bias=True, with_eltwise=None):
+    """Create patterns related to nn.dense.
+
+    Parameters
+    ----------
+    with_bias : bool
+        Whether attach `bias_add` to `nn.dense`.
+    with_eltwise : str
+        The attached elementwise post-op name.
+    Returns
+    -------
+    dense_out : CallPattern
+        Call node sequence.
+    """
+    data = wildcard()
+    weight = wildcard()
+    bias = wildcard()
+    dense = is_op("nn.dense")(data, weight)
+    if with_bias:
+        dense_out = is_op("add")(dense, bias)
+    else:
+        dense_out = dense
+    if with_eltwise:
+        dense_out = is_op(with_eltwise)(dense_out)
+    return dense_out
+
+
+def make_dnnl_pattern(op, with_bias, with_eltwise):
+    """Create one named DNNL composite pattern.
+
+    Parameters
+    ----------
+    op : str
+        First op of the pattern without the "nn." prefix: "conv2d" or "dense".
+    with_bias : bool
+        Whether to attach a bias `add` after the first op.
+    with_eltwise : Optional[str]
+        Elementwise post-op (e.g. "nn.relu"); its last dotted segment suffixes the name.
+    Returns
+    -------
+    pattern : Tuple[str, CallPattern]
+        ("dnnl.<op>[_bias][_<post-op>]", pattern), or () with a warning if op is unsupported.
+    """
+    pat_name = "dnnl." + op
+    pat_name += "_bias" if with_bias else ""
+    pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
+    if op == "conv2d":
+        dnnl_pattern = (pat_name, make_conv_pattern(with_bias, with_eltwise))
+    elif op == "dense":
+        dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
+    else:
+        logger.warning("Currently, only conv2d and dense op are supported, but got %s.", op)
+        dnnl_pattern = ()
+    return dnnl_pattern
 
 
 @register_pattern_table("dnnl")
 def pattern_table():
-    conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
-    conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
-    dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
+    """Return the DNNL composite patterns as a list of (name, CallPattern) pairs.
+
+    Covers conv2d/dense, with and without bias, with each post-op in `elt_list`;
+    bare conv2d/dense are left to the single-op rules. Keep the order stable:
+    callers index into this list. NOTE(review): codegen appears to handle only
+    "dnnl.dense_bias" among the dense composites emitted here -- confirm.
+    """
+    elt_list = ["nn.relu", "tanh", "sigmoid", None]
+    dnnl_patterns = []
+    for with_bias in [True, False]:
+        for elt in elt_list:
+            if not with_bias and not elt:
+                continue
+            dnnl_patterns.append(make_dnnl_pattern("conv2d", with_bias, elt))
+            dnnl_patterns.append(make_dnnl_pattern("dense", with_bias, elt))
     return dnnl_patterns
+
+
+def partition_for_dnnl(mod, params=None):
+    """Partition the graph, greedily offloading supported operators to DNNL.
+
+    Parameters
+    ----------
+    mod : Module
+        Module to run passes on; its "main" is rebound in place if `params` is non-empty.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters, bound into "main" by name before the passes run.
+    Returns
+    -------
+    mod : Module
+        Annotated and partitioned module.
+    """
+
+    if params:
+        mod["main"] = bind_params_by_name(mod["main"], params)
+    seq = tvm.transform.Sequential(
+        [
+            transform.CanonicalizeOps(),
+            transform.InferType(),
+            transform.SimplifyInference(),
+            transform.FoldConstant(),
+            transform.FoldScaleAxis(),
+            # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
+            transform.SimplifyExpr(),
+            transform.FoldConstant(),
+            transform.MergeComposite(pattern_table()),
+            transform.AnnotateTarget("dnnl"),
+            transform.MergeCompilerRegions(),
+            transform.PartitionGraph(),
+        ]
+    )
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    return mod
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index fa1dbc6..b1b2f58 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -455,9 +455,27 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
 
       if (name == "dnnl.conv2d_bias_relu") {
         call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
+      } else if (name == "dnnl.conv2d_bias_tanh") {
+        call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "tanh"});
+        ICHECK(call->op.as<OpNode>()) << "Not op node";
+      } else if (name == "dnnl.conv2d_bias_sigmoid") {
+        call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "sigmoid"});
+        ICHECK(call->op.as<OpNode>()) << "Not op node";
+      } else if (name == "dnnl.conv2d_bias") {
+        call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "add"});
+        ICHECK(call->op.as<OpNode>()) << "Not op node";
       } else if (name == "dnnl.conv2d_relu") {
         call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
         ICHECK(call->op.as<OpNode>()) << "Not op node";
+      } else if (name == "dnnl.conv2d_tanh") {
+        call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "tanh"});
+        ICHECK(call->op.as<OpNode>()) << "Not op node";
+      } else if (name == "dnnl.conv2d_sigmoid") {
+        call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "sigmoid"});
+        ICHECK(call->op.as<OpNode>()) << "Not op node";
+      } else if (name == "dnnl.dense_bias") {
+        call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.dense", "add"});
+        ICHECK(call->op.as<OpNode>()) << "Not op node";
       } else {
         LOG(FATAL) << "Unrecognized DNNL pattern: " << name;
       }
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index b32d137..f9f1961 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -103,15 +103,31 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
         if ("nn.conv2d" == op_name) {
           Conv2d(nid);
         } else if ("dnnl.conv2d_relu" == op_name) {
-          Conv2d(nid, true, false);
+          Conv2d(nid, true, false, dnnl::algorithm::eltwise_relu);
+        } else if ("dnnl.conv2d_tanh" == op_name) {
+          Conv2d(nid, true, false, dnnl::algorithm::eltwise_tanh);
+        } else if ("dnnl.conv2d_sigmoid" == op_name) {
+          Conv2d(nid, true, false, dnnl::algorithm::eltwise_logistic);
+        } else if ("dnnl.conv2d_bias" == op_name) {
+          Conv2d(nid, false, true);
         } else if ("dnnl.conv2d_bias_relu" == op_name) {
-          Conv2d(nid, true, true);
+          Conv2d(nid, true, true, dnnl::algorithm::eltwise_relu);
+        } else if ("dnnl.conv2d_bias_tanh" == op_name) {
+          Conv2d(nid, true, true, dnnl::algorithm::eltwise_tanh);
+        } else if ("dnnl.conv2d_bias_sigmoid" == op_name) {
+          Conv2d(nid, true, true, dnnl::algorithm::eltwise_logistic);
         } else if ("nn.dense" == op_name) {
           Dense(nid);
+        } else if ("dnnl.dense_bias" == op_name) {
+          Dense(nid, true);
         } else if ("nn.batch_norm" == op_name) {
           BatchNorm(nid);
         } else if ("nn.relu" == op_name) {
-          Relu(nid);
+          Eltwise(nid, dnnl::algorithm::eltwise_relu);
+        } else if ("tanh" == op_name) {
+          Eltwise(nid, dnnl::algorithm::eltwise_tanh);
+        } else if ("sigmoid" == op_name) {
+          Eltwise(nid, dnnl::algorithm::eltwise_logistic);
         } else if ("add" == op_name) {
           Binary(nid, dnnl::algorithm::binary_add);
         } else if ("multiply" == op_name) {
@@ -150,7 +166,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     return entry_out_mem_[eid].first;
   }
 
-  void Conv2d(const size_t& nid, const bool has_relu = false, const bool has_bias = false) {
+  void Conv2d(const size_t& nid, const bool has_elt = false, const bool has_bias = false,
+              dnnl::algorithm algo = dnnl::algorithm::eltwise_relu) {
     auto node = nodes_[nid];
 
     // Setup attributes.
@@ -159,24 +176,29 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
     dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
     std::vector<std::string> str_strides = node.GetAttr<std::vector<std::string>>("strides");
+    std::vector<std::string> str_dilates = node.GetAttr<std::vector<std::string>>("dilation");
     std::vector<std::string> str_padding = node.GetAttr<std::vector<std::string>>("padding");
     dnnl::memory::dim groups = std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
 
-    dnnl::memory::dim N = input_shape[0],       // batch size
-        IC = input_shape[1],                    // input channels
-        IH = input_shape[2],                    // input height
-        IW = input_shape[3],                    // input width
-        OC = weight_shape[0],                   // output channels
-        KH = weight_shape[2],                   // weight height
-        KW = weight_shape[3],                   // weight width
-        PW_L = std::stoi(str_padding[1]),       // width padding: left
-        PW_R = std::stoi(str_padding[3]),       // width padding: right
-        PH_L = std::stoi(str_padding[0]),       // height padding: top
-        PH_R = std::stoi(str_padding[2]),       // height padding: bottom
-        SH = std::stoi(str_strides[0]),         // height-wise stride
-        SW = std::stoi(str_strides[1]),         // weight-wise stride
-        OH = (IH - KH + PH_L + PH_R) / SH + 1,  // output height
-        OW = (IW - KW + PW_L + PW_R) / SW + 1;  // output width
+    dnnl::memory::dim N = input_shape[0],        // batch size
+        IC = input_shape[1],                     // input channels
+        IH = input_shape[2],                     // input height
+        IW = input_shape[3],                     // input width
+        OC = weight_shape[0],                    // output channels
+        KH = weight_shape[2],                    // weight height
+        KW = weight_shape[3],                    // weight width
+        PW_L = std::stoi(str_padding[1]),        // width padding: left
+        PW_R = std::stoi(str_padding[3]),        // width padding: right
+        PH_L = std::stoi(str_padding[0]),        // height padding: top
+        PH_R = std::stoi(str_padding[2]),        // height padding: bottom
+        SH = std::stoi(str_strides[0]),          // height-wise stride
+        SW = std::stoi(str_strides[1]),          // weight-wise stride
+        DH = std::stoi(str_dilates[0]) - 1,      // height-wise dilate
+        DW = std::stoi(str_dilates[1]) - 1,      // weight-wise dilate
+        DKH = 1 + (KH - 1) * (DH + 1),           // dilated weight height
+        DKW = 1 + (KW - 1) * (DW + 1),           // dilated weight width
+        OH = (IH - DKH + PH_L + PH_R) / SH + 1,  // output height
+        OW = (IW - DKW + PW_L + PW_R) / SW + 1;  // output width
 
     // Memory shapes.
     dnnl::memory::dims src_dims = {N, IC, IH, IW};
@@ -187,6 +209,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     dnnl::memory::dims bias_dims = {OC};
     dnnl::memory::dims dst_dims = {N, OC, OH, OW};
     dnnl::memory::dims strides_dims = {SH, SW};
+    dnnl::memory::dims dilates_dims = {DH, DW};
     dnnl::memory::dims padding_dims_l = {PH_L, PW_L};
     dnnl::memory::dims padding_dims_r = {PH_R, PW_R};
 
@@ -199,13 +222,14 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // Covn2d description.
     auto conv_desc = dnnl::convolution_forward::desc(
         dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md,
-        conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, padding_dims_l, padding_dims_r);
+        conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, dilates_dims, padding_dims_l,
+        padding_dims_r);
 
-    // Enable ReLU
+    // Enable elementwise post-ops
     dnnl::primitive_attr attr;
-    if (has_relu) {
+    if (has_elt) {
       dnnl::post_ops ops;
-      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
+      ops.append_eltwise(1.f, algo, 0.f, 0.f);
       attr.set_post_ops(ops);
     }
 
@@ -245,7 +269,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                          {DNNL_ARG_DST, conv2d_dst_memory}});
   }
 
-  void Dense(const size_t& nid) {
+  void Dense(const size_t& nid, const bool has_bias = false) {
     auto node = nodes_[nid];
 
     // Setup attributes.
@@ -281,9 +305,18 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // Memories.
     auto data_memory = BindDNNLMemory(data_entry, data_md);
     auto weight_memory = BindDNNLMemory(weight_entry, weight_md);
+
+    // Bias memory.
     auto bias_memory = dnnl::memory(bias_md, engine_);
-    float bias[OC] = {0};
-    write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
+    if (has_bias) {
+      auto bias_entry = node.GetInputs()[2];
+      BindDNNLMemory(bias_entry, bias_memory);
+    } else {
+      float bias[OC] = {0};
+      write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
+    }
+
+    // Output memory.
     JSONGraphNodeEntry out_entry(nid, 0);
     auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc());
 
@@ -335,20 +368,20 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                          {DNNL_ARG_VARIANCE, variance_memory}});
   }
 
-  void Relu(const size_t& nid) {
+  void Eltwise(const size_t& nid, dnnl::algorithm algo) {
     auto node = nodes_[nid];
 
     auto data_entry = node.GetInputs()[0];
     dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
     dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32);
 
-    auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference,
-                                                 dnnl::algorithm::eltwise_relu, data_md, 0);
-    auto relu_prim_desc = dnnl::eltwise_forward::primitive_desc(relu_desc, engine_);
-    ICHECK(data_md == relu_prim_desc.dst_desc());
+    auto elt_desc =
+        dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, 0);
+    auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_);
+    ICHECK(data_md == elt_prim_desc.dst_desc());
 
-    auto relu = dnnl::eltwise_forward(relu_prim_desc);
-    net_.push_back(relu);
+    auto elt = dnnl::eltwise_forward(elt_prim_desc);
+    net_.push_back(elt);
 
     auto data_memory = BindDNNLMemory(data_entry, data_md);
     JSONGraphNodeEntry out_entry(nid, 0);
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
new file mode 100755
index 0000000..7adf3e4
--- /dev/null
+++ b/tests/python/contrib/test_dnnl.py
@@ -0,0 +1,350 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+import itertools
+import tvm
+import tvm.relay.testing
+from tvm import relay
+from tvm.relay.op.contrib import dnnl
+import tvm.testing
+
+has_dnnl_codegen = pytest.mark.skipif(
+    not tvm.get_global_func("relay.ext.dnnl", True), reason="DNNL codegen not available"
+)
+
+run_module = tvm.testing.parameter(
+    pytest.param(False, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm()]),
+    pytest.param(True, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm()]),
+    ids=["compile", "run"],
+)
+
+
+def vmobj_to_list(o):
+    if isinstance(o, tvm.nd.NDArray):
+        return [o.numpy()]
+    elif isinstance(o, tvm.runtime.container.ADT) or isinstance(o, list):
+        return [vmobj_to_list(f) for f in o]
+    else:
+        raise RuntimeError("Unknown object type: %s" % type(o))
+
+
+def assert_result_dict_holds(result_dict):
+    for k1, k2 in itertools.combinations(result_dict, 2):
+        res1 = vmobj_to_list(result_dict[k1])
+        res2 = vmobj_to_list(result_dict[k2])
+        for r1, r2 in zip(res1, res2):
+            tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)
+
+
+def run_and_verify(mod, input, params, target, run_module):
+    """Compile `mod` on graph/VM with and without DNNL; run and compare if `run_module`."""
+    dev = tvm.cpu()
+    result_dict = dict()
+    for mode in ["graph", "vm"]:
+        for use_dnnl in [False, True]:
+            result_key = mode + ("_dnnl" if use_dnnl else "")
+            # Partition into a new name: rebinding `mod` would make the next
+            # "reference" run execute the DNNL-partitioned module instead.
+            run_mod = dnnl.partition_for_dnnl(mod, params) if use_dnnl else mod
+            if use_dnnl:
+                # At least one subgraph must actually have been offloaded.
+                assert any("dnnl" in gv.name_hint for gv in run_mod.get_global_vars())
+            with tvm.transform.PassContext(opt_level=3):
+                executor = relay.create_executor(mode, mod=run_mod, device=dev, target=target)
+                func = executor.evaluate()
+            if run_module:
+                if isinstance(input, dict):
+                    result_dict[result_key] = func(**input, **params)
+                else:
+                    result_dict[result_key] = func(input, **params)
+
+    if run_module:
+        assert_result_dict_holds(result_dict)
+
+
+def run_and_verify_func(config, run_module, target="llvm", dtype="float32"):
+    """Test a Relay func by compiling, running, and comparing TVM and DNNL outputs.
+
+    Parameters
+    ----------
+    config : Tuple[relay.Function, Dict[str, NDArray], List[str]]
+        A tuple containing 1) The function to test, 2) A dictionary of var names to input shapes and
+        3) A list of which vars should be considered params.
+
+    run_module: bool
+        If True, the built module will be run after being compiled.
+    """
+    f, input_shapes, is_param = config
+    params = {x: np.random.uniform(-1, 1, input_shapes[x]).astype(dtype) for x in is_param}
+    input_dict = {
+        k: np.random.uniform(-1, 1, v).astype(dtype)
+        for k, v in input_shapes.items()
+        if k not in is_param
+    }
+    run_and_verify(f, input_dict, params, target, run_module)
+
+
+def get_conv2d(
+    x_shape=(1, 32, 8, 8),
+    k_shape=(16, 32, 3, 3),
+    groups=1,
+    padding=(0, 0),
+    strides=(1, 1),
+    dilation=(1, 1),
+    activation=None,
+    dtype="float32",
+):
+    x = relay.var("x", shape=(x_shape), dtype=dtype)
+    kernel = relay.var("kernel", shape=(k_shape), dtype=dtype)
+    out = relay.nn.conv2d(
+        x,
+        kernel,
+        kernel_size=k_shape[2:4],
+        groups=groups,
+        padding=padding,
+        strides=strides,
+        dilation=dilation,
+        channels=k_shape[0],
+    )
+    dic = {"x": x_shape, "kernel": k_shape}
+    param_lst = ["kernel"]
+
+    if activation == "relu":
+        return relay.nn.relu(out), dic, param_lst
+    elif activation == "tanh":
+        return relay.tanh(out), dic, param_lst
+    elif activation == "sigmoid":
+        return relay.sigmoid(out), dic, param_lst
+    else:
+        return out, dic, param_lst
+
+
+def get_conv2d_weights_const(
+    x_shape=(1, 32, 8, 8),
+    k_shape=(16, 32, 3, 3),
+    groups=1,
+    padding=(0, 0),
+    strides=(1, 1),
+    dilation=(1, 1),
+    dtype="float32",
+):
+    x = relay.var("x", shape=(x_shape), dtype=dtype)
+    kernel = relay.const(np.ones(k_shape).astype(dtype))
+    out = relay.nn.conv2d(
+        x,
+        kernel,
+        channels=k_shape[0],
+        kernel_size=k_shape[2:4],
+        groups=groups,
+        padding=padding,
+        strides=strides,
+        dilation=dilation,
+    )
+    dic = {"x": x_shape}
+    param_lst = []
+    return out, dic, param_lst
+
+
+def get_conv2d_bias(
+    x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), activation=None, dtype="float32"
+):
+    conv, dic, param_lst = get_conv2d(x_shape=x_shape, k_shape=k_shape, dtype=dtype)
+    bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype)
+    out = relay.nn.bias_add(conv, bias)
+    dic["bias"] = (k_shape[0],)
+    param_lst += ["bias"]
+
+    if activation == "relu":
+        return relay.nn.relu(out), dic, param_lst
+    elif activation == "tanh":
+        return relay.tanh(out), dic, param_lst
+    elif activation == "sigmoid":
+        return relay.sigmoid(out), dic, param_lst
+    else:
+        return out, dic, param_lst
+
+
+def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
+    conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype)
+    beta = relay.const(np.zeros(k_shape[0]).astype(dtype))
+    gamma = relay.const(np.ones(k_shape[0]).astype(dtype))
+    moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype))
+    moving_var = relay.const(np.ones(k_shape[0]).astype(dtype))
+    conv2d_bias_bn, _, _ = relay.nn.batch_norm(
+        conv2d_bias,
+        gamma=gamma,
+        beta=beta,
+        moving_mean=moving_mean,
+        moving_var=moving_var,
+        axis=1,
+        center=True,
+        scale=True,
+        epsilon=1e-5,
+    )
+    return relay.nn.relu(conv2d_bias_bn), dic, param_lst
+
+
+def get_dense(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"):
+    x = relay.var("x", shape=(x_shape), dtype=dtype)
+    kernel = relay.var("kernel", shape=(k_shape), dtype=dtype)
+    out = relay.nn.dense(x, kernel, units=k_shape[0])
+    dic = {"x": x_shape, "kernel": k_shape}
+    param_lst = ["kernel"]
+    return out, dic, param_lst
+
+
+def get_dense_bias(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"):
+    dense, dic, param_lst = get_dense(x_shape=x_shape, k_shape=k_shape, dtype=dtype)
+    bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype)
+    out = relay.nn.bias_add(dense, bias)
+    dic["bias"] = (k_shape[0],)
+    param_lst += ["bias"]
+    return out, dic, param_lst
+
+
+def test_dnnl_not_compatible(run_module, target="llvm", dtype="float32"):
+    xshape = (1, 32, 14, 14)
+    x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
+
+    x = relay.var("x", shape=(xshape), dtype=dtype)
+    y = relay.add(x, x)
+    z = relay.cast(relay.cast(y, "int32"), "float32")
+    out = relay.nn.relu(z)
+    f = relay.Function([x], out)
+    mod = tvm.IRModule()
+    mod["main"] = f
+    mod = dnnl.partition_for_dnnl(mod)
+    for mode in ["graph", "vm"]:
+        with tvm.transform.PassContext(opt_level=3):
+            func = relay.create_executor(mode, mod=mod, device=tvm.cpu(0), target=target).evaluate()
+            if run_module:
+                results = func(x_data)
+
+
+def test_multiple_outputs(run_module, dtype="float32"):
+    def get_graph():
+        x = relay.var("x", shape=(1, 3), dtype=dtype)
+        y = relay.var("y", shape=(1, 3), dtype=dtype)
+        z = relay.add(x, y)
+        w = relay.add(z, y)
+        out = relay.Tuple((z, w))
+        f = tvm.IRModule.from_expr(out)
+        return f, {"x": (1, 3), "y": (1, 3)}, []
+
+    run_and_verify_func(get_graph(), run_module=run_module, dtype=dtype)
+
+
+def test_unary(run_module):
+    def get_graph(op, x_shape=(1, 8, 3, 3)):
+        x = relay.var("x", shape=(x_shape), dtype="float32")
+        out = op(x)
+        f = tvm.IRModule.from_expr(out)
+        return f, {"x": x_shape}, []
+
+    for op in [
+        relay.nn.relu,
+        relay.tanh,
+        relay.sigmoid,
+    ]:
+        run_and_verify_func(get_graph(op), run_module=run_module)
+
+
+def test_conv2d(run_module, dtype="float32"):
+    x_shape = (1, 32, 8, 8)
+    for k_shape, groups in [((16, 32, 3, 3), 1), ((32, 1, 3, 3), 32)]:
+        for padding in [(0, 0), (1, 1)]:
+            for strides in [(1, 1), (2, 2)]:
+                for dilation in [(1, 1), (2, 2)]:
+                    conv2d, dic, param_lst = get_conv2d(
+                        x_shape=x_shape,
+                        k_shape=k_shape,
+                        groups=groups,
+                        padding=padding,
+                        strides=strides,
+                        dilation=dilation,
+                        dtype=dtype,
+                    )
+                    conv2d = tvm.IRModule.from_expr(conv2d)
+                    config = conv2d, dic, param_lst
+                    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+def test_conv2d_weights_const(run_module, dtype="float32"):
+    x_shape = (1, 32, 8, 8)
+    k_shape = (16, 32, 3, 3)
+    conv2d, dic, param_lst = get_conv2d_weights_const(x_shape, k_shape, dtype=dtype)
+    conv2d = tvm.IRModule.from_expr(conv2d)
+    config = conv2d, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+def test_conv2d_pattern(run_module, dtype="float32"):
+    x_shape = (1, 32, 8, 8)
+    k_shape = (16, 32, 3, 3)
+    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    for a in activation_lst:
+        conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype)
+        conv2d = tvm.IRModule.from_expr(conv2d)
+        config = conv2d, dic, param_lst
+        run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+        conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, activation=a, dtype=dtype)
+        conv2d_bias = tvm.IRModule.from_expr(conv2d_bias)
+        config = conv2d_bias, dic, param_lst
+        run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+    conv2d_bias_bn_relu, dic, param_lst = get_conv2d_bias_bn_relu(x_shape, k_shape, dtype=dtype)
+    conv2d_bias_bn_relu = tvm.IRModule.from_expr(conv2d_bias_bn_relu)
+    config = conv2d_bias_bn_relu, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+def test_dense(run_module, dtype="float32"):
+    x_shape = (1, 16)
+    k_shape = (32, 16)
+
+    dense, dic, param_lst = get_dense(x_shape, k_shape, dtype=dtype)
+    dense = tvm.IRModule.from_expr(dense)
+    config = dense, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+    dense, dic, param_lst = get_dense(x_shape, k_shape=(1, 16), dtype=dtype)
+    dense = tvm.IRModule.from_expr(dense)
+    config = dense, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+def test_dense_pattern(run_module, dtype="float32"):
+    x_shape = (1, 16)
+    k_shape = (32, 16)
+
+    dense, dic, param_lst = get_dense(x_shape, k_shape, dtype=dtype)
+    dense = tvm.IRModule.from_expr(dense)
+    config = dense, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+    dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, dtype=dtype)
+    dense_bias = tvm.IRModule.from_expr(dense_bias)
+    config = dense_bias, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 80fb2e0..736ece2 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -919,10 +919,25 @@ def test_mixed_single_multiple_outputs():
 
 def test_dnnl_fuse():
     dnnl_patterns = get_pattern_table("dnnl")
-    conv2d_bias_relu_pat, conv2d_relu_pat = dnnl_patterns
-
-    def get_blocks(prefix, data, in_channel, out_channel, include_bn=True, include_sigmoid=False):
+    (
+        conv2d_bias_relu_pat,
+        conv2d_bias_sigmoid_pat,
+        conv2d_bias_pat,
+        conv2d_relu_pat,
+        conv2d_sigmoid_pat,
+    ) = (dnnl_patterns[0], dnnl_patterns[4], dnnl_patterns[6], dnnl_patterns[8], dnnl_patterns[12])
+
+    def get_blocks(
+        prefix,
+        data,
+        in_channel,
+        out_channel,
+        include_bias_add=True,
+        include_bn=True,
+        include_sigmoid=False,
+    ):
         weight = relay.var(prefix + "weight")
+        bias = relay.var(prefix + "bias")
         bn_gamma = relay.var(prefix + "bn_gamma")
         bn_beta = relay.var(prefix + "bn_beta")
         bn_mmean = relay.var(prefix + "bn_mean")
@@ -931,6 +946,8 @@ def test_dnnl_fuse():
         layer = relay.nn.conv2d(
             data=data, weight=weight, kernel_size=(3, 3), channels=out_channel, padding=(1, 1)
         )
+        if include_bias_add:
+            layer = relay.nn.bias_add(layer, bias)
         if include_bn:
             bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta, bn_mmean, bn_mvar)
             layer = bn_output[0]
@@ -940,11 +957,11 @@ def test_dnnl_fuse():
         layer = relay.nn.relu(layer)
         return layer
 
-    def get_net(include_bn=True, include_sigmoid=False):
+    def get_net(include_bias_add=True, include_bn=True, include_sigmoid=False):
         data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
-        block1 = get_blocks("block1_", data, 3, 8, include_bn, include_sigmoid)
+        block1 = get_blocks("block1_", data, 3, 8, include_bias_add, include_bn, include_sigmoid)
         # The second block is always conv + relu, to make it more interesting
-        block2 = get_blocks("block2_", block1, 8, 8, False, include_sigmoid)
+        block2 = get_blocks("block2_", block1, 8, 8, False, False, include_sigmoid)
         return relay.Function(relay.analysis.free_vars(block2), block2)
 
     def get_partitoned_mod(mod, params, pattern_table):
@@ -959,9 +976,18 @@ def test_dnnl_fuse():
                 transform.FoldScaleAxis(),
             ]
         )
+        # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
+        remove_linear_pass = tvm.transform.Sequential(
+            [
+                transform.SimplifyExpr(),
+                transform.FoldConstant(),
+            ]
+        )
         composite_partition = tvm.transform.Sequential(
             [
+                transform.CanonicalizeOps(),
                 remove_bn_pass,
+                remove_linear_pass,
                 transform.MergeComposite(pattern_table),
                 transform.AnnotateTarget("dnnl"),
                 transform.PartitionGraph(),
@@ -971,25 +997,38 @@ def test_dnnl_fuse():
         with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
             return composite_partition(mod)
 
-    def test_detect_pattern(pattern_table, include_bn, include_sigmoid, num_expected_partition):
-        net = get_net(include_bn, include_sigmoid)
+    def test_detect_pattern(
+        pattern_table, include_bias_add, include_bn, include_sigmoid, num_expected_partition
+    ):
+        net = get_net(include_bias_add, include_bn, include_sigmoid)
         mod, params = tvm.relay.testing.create_workload(net)
         mod = get_partitoned_mod(mod, params, pattern_table)
         assert len(mod.functions) - 1 == num_expected_partition  # -1 for main
 
     def test_partition():
         # conv + bn + relu, conv + relu -> fused conv_bias_relu, conv, and relu
-        test_detect_pattern([conv2d_bias_relu_pat], True, False, 3)
+        test_detect_pattern([conv2d_bias_relu_pat], False, True, False, 3)
         # conv + bn + relu, conv + relu -> conv, bias, relu, and fused conv_relu
-        test_detect_pattern([conv2d_relu_pat], True, False, 4)
+        test_detect_pattern([conv2d_relu_pat], False, True, False, 4)
         # conv + bn + relu, conv + relu -> fused conv_bias_relu, and fused conv_relu
-        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, False, 2)
+        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], False, True, False, 2)
+        # conv + bias_add + bn + relu, conv + relu -> fused conv_bias_relu, and fused conv_relu
+        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, True, False, 2)
         # conv + relu, conv + relu -> two fused conv_relu
-        test_detect_pattern([conv2d_relu_pat], False, False, 2)
+        test_detect_pattern([conv2d_relu_pat], False, False, False, 2)
         # conv + relu, conv + relu -> no fusion, 4 partition each with a single op
-        test_detect_pattern([conv2d_bias_relu_pat], False, False, 4)
+        test_detect_pattern([conv2d_bias_relu_pat], False, False, False, 4)
         # conv + bn + sigmoid + relu, conv + sigmoid + relu -> no fusion
-        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, True, 5)
+        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], False, True, True, 7)
+        # conv + bias_add + bn + sigmoid + relu, conv + sigmoid + relu -> fused conv_bias
+        # and single op sigmoid, relu, conv, sigmoid, relu
+        test_detect_pattern([conv2d_bias_pat, conv2d_relu_pat], True, True, True, 6)
+        # conv + bias_add + bn + sigmoid + relu, conv + sigmoid + relu -> fused conv_bias_sigmoid
+        # and single op relu, conv, sigmoid, relu
+        test_detect_pattern([conv2d_bias_sigmoid_pat, conv2d_relu_pat], True, True, True, 5)
+        # conv + bias_add + bn + sigmoid + relu, conv + sigmoid + relu -> fused conv_bias_sigmoid,
+        # fused conv_sigmoid and single op relu, relu
+        test_detect_pattern([conv2d_bias_sigmoid_pat, conv2d_sigmoid_pat], True, True, True, 4)
 
     def test_partition_mobilenet():
         mod, params = relay.testing.mobilenet.get_workload()