Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/11 00:56:03 UTC

[tvm] branch main updated: [BYOC-DNNL] support more post-ops (#12002)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 927620e20f [BYOC-DNNL] support more post-ops (#12002)
927620e20f is described below

commit 927620e20fe226578fa0c32b9706d27874791b83
Author: Ivy Zhang <ya...@intel.com>
AuthorDate: Mon Jul 11 08:55:57 2022 +0800

    [BYOC-DNNL] support more post-ops (#12002)
    
    * support post-op swish
    
    * support post-op clip
    
    * enhance get_shape and get_dtype in dnnl.py to support efficientnet
    
    * add checks for whether with_eltwise is in the supported list
    
    * fix lint
    
    * fix test
---
 python/tvm/relay/op/contrib/dnnl.py             |  22 ++++-
 src/relay/backend/contrib/dnnl/codegen.cc       |   9 ++
 src/relay/backend/utils.h                       |   4 +
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc   |  12 ++-
 tests/python/contrib/test_dnnl.py               | 126 ++++++------------------
 tests/python/relay/test_pass_partition_graph.py |  26 ++++-
 6 files changed, 95 insertions(+), 104 deletions(-)
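
An end-to-end sketch of the new swish fusion (illustrative only; it assumes a TVM build
at this commit with DNNL/oneDNN enabled, and the shapes below are chosen purely for
demonstration): a conv2d followed by x * sigmoid(x) should now be partitioned into a
single "dnnl.conv2d_sigmoid_mul" composite, and clip is handled the same way via the
new "_clip" pattern names.

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import partition_for_dnnl

    x = relay.var("x", shape=(1, 32, 8, 8), dtype="float32")
    w = relay.var("w", shape=(16, 32, 3, 3), dtype="float32")
    conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=16)
    out = relay.multiply(conv, relay.sigmoid(conv))  # swish = x * sigmoid(x)
    mod = tvm.IRModule.from_expr(out)

    # alter_layout=False keeps the NCHW layout to keep this sketch minimal.
    mod = partition_for_dnnl(mod, alter_layout=False)
    print(mod)  # the partitioned module should contain a "dnnl.conv2d_sigmoid_mul" composite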

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index b3ef478f20..9b6b45240a 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -51,6 +51,7 @@ from .register import register_pattern_table
 
 
 logger = logging.getLogger("DNNL")
+supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
 
 
 def _register_external_op_helper(op_name, supported=True):
@@ -120,6 +121,8 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
     conv_out : CallPattern
         Call node sequence.
     """
+    if with_eltwise not in supported_post_elts:
+        raise ValueError("Unsupported eltwise post-op: %s" % with_eltwise)
     data = wildcard()
     weight = wildcard()
     bias = wildcard()
@@ -128,8 +131,11 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
         conv_out = is_op("add")(conv, bias)
     else:
         conv_out = conv
-    if with_eltwise:
-        return is_op(with_eltwise)(conv_out)
+    if with_eltwise == "swish":
+        sig_out = is_op("sigmoid")(conv_out)
+        conv_out = is_op("multiply")(conv_out, sig_out)
+    elif with_eltwise:
+        conv_out = is_op(with_eltwise)(conv_out)
     return conv_out
 
 
@@ -147,6 +153,8 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
     dense_out : CallPattern
         Call node sequence.
     """
+    if with_eltwise not in supported_post_elts:
+        raise ValueError("Unsupported eltwise post-op: %s" % with_eltwise)
     data = wildcard()
     weight = wildcard()
     bias = wildcard()
@@ -165,6 +173,9 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
         added_erf_val = is_op("add")(erf_val, const2)
         mul_val = is_op("multiply")(dense_out, added_erf_val)
         dense_out = is_op("multiply")(mul_val, const3)
+    elif with_eltwise == "swish":
+        sig_out = is_op("sigmoid")(dense_out)
+        dense_out = is_op("multiply")(dense_out, sig_out)
     elif with_eltwise:
         dense_out = is_op(with_eltwise)(dense_out)
     return dense_out
@@ -191,6 +202,7 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
         pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::]
     pat_name += "_bias" if with_bias else ""
     pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
+    pat_name = pat_name.replace("_swish", "_sigmoid_mul")
     if "conv" in op_name:
         dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise))
     elif op_name == "nn.dense":
@@ -282,7 +294,7 @@ def pattern_table():
     dnnl_patterns.append(make_qnn_conv2d_pattern())
     dnnl_patterns.append(make_qnn_dense_pattern())
 
-    elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", None]
+    elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
     for with_bias in [True, False]:
         for elt in elt_list:
             if not with_bias and not elt:
@@ -380,6 +392,8 @@ def get_shape(tensor):
     if isinstance(tensor, tvm.ir.container.Array):
         return tensor[-1].shape
     if isinstance(tensor, relay.expr.Call):
+        if tensor.op.name == "multiply":
+            return tensor.type_args[0].shape
         return tensor.checked_type.shape
     raise TypeError("Unsupport data type: %s" % type(tensor))
 
@@ -395,6 +409,8 @@ def get_dtype(tensor):
     if isinstance(tensor, tvm.ir.container.Array):
         return tensor[-1].dtype
     if isinstance(tensor, relay.expr.Call):
+        if tensor.op.name == "multiply":
+            return tensor.type_args[0].dtype
         return tensor.checked_type.dtype
     raise TypeError("Unsupport data type: %s" % type(tensor))
 
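
For reference, the swish branches added above simply compose two dataflow patterns; a
minimal standalone sketch of the shape being matched (illustrative, mirroring the names
in the diff) looks like this, and the new get_shape/get_dtype branches cover the case
where such a composite ends in this "multiply" by reading the type of its first argument:

    from tvm.relay.dataflow_pattern import is_op, wildcard

    conv = is_op("nn.conv2d")(wildcard(), wildcard())
    sig = is_op("sigmoid")(conv)
    swish = is_op("multiply")(conv, sig)  # matches conv2d -> x * sigmoid(x)
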
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index 2f47c23a7c..4abfc9d9b1 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -470,6 +470,8 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
       {"relu", "nn.relu"},
       {"tanh", "tanh"},
       {"sigmoid", "sigmoid"},
+      {"clip", "clip"},
+      {"mul", "multiply"},
       {"nn.deconv2d", "nn.conv2d_transpose"},
       {"nn.deconv3d", "nn.conv3d_transpose"},
   };
@@ -566,6 +568,13 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
                                                 "kernel", /* op_type_ */
                                                 inputs, 1 /* num_outputs_ */);
     SetCallNodeAttribute(node, call);
+    // If the pattern has a post-op `clip`, assume the last op is `clip` and copy its attrs to the pattern attrs.
+    if (name.find("_clip") != std::string::npos) {
+      auto clip_call = cn->op.as<FunctionNode>()->body.as<CallNode>();
+      ICHECK(IsOp(clip_call, "clip"));
+      SetCallNodeAttribute(node, clip_call);
+    }
+    // For QNN.
     for (const auto& kvp : extra_attrs) node->SetAttr(kvp.first, kvp.second);
 
     return AddNode(node, GetRef<Expr>(cn));
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index d6fae8c72b..57c0661311 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -470,6 +470,10 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
          current_call->args[valid_node_idx].as<VarNode>()) {
     valid_node_idx++;
   }
+  while (valid_node_idx < current_call->args.size() &&
+         !(IsOp(current_call->args[valid_node_idx].as<CallNode>(), expected_op_names[depth - 1]))) {
+    valid_node_idx++;
+  }
   const auto* next_call = current_call->args[valid_node_idx].as<CallNode>();
   return GetRootCall(next_call, depth - 1, expected_op_names);
 }
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index a46f170fea..6c0fd64066 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -189,6 +189,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::regex relu_pat(".*_relu.*");
     std::regex tanh_pat(".*_tanh.*");
     std::regex sigmoid_pat(".*_sigmoid.*");
+    std::regex clip_pat(".*_clip.*");
     std::regex gelu_pat(".*_gelu.*");
 
     // Parsing post-ops.
@@ -199,8 +200,17 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     if (std::regex_match(op_name, tanh_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
     }
+    if (std::regex_match(op_name, clip_pat)) {
+      float a_min = GetNodeAttr<float>(nodes_[nid], "a_min");
+      float a_max = GetNodeAttr<float>(nodes_[nid], "a_max");
+      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
+    }
     if (std::regex_match(op_name, sigmoid_pat)) {
-      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
+      if (op_name.find("_sigmoid_mul") != std::string::npos) {
+        ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
+      } else {
+        ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
+      }
     }
     if (std::regex_match(op_name, gelu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
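
The "_sigmoid_mul" name handled above is lowered to DNNL's eltwise_swish with alpha = 1,
which computes x * sigmoid(alpha * x), so it is numerically the same as the matched
sigmoid-then-multiply pattern. A quick sanity sketch (illustrative, plain numpy, not
part of the runtime):

    import numpy as np

    def swish(x, alpha=1.0):
        # eltwise_swish definition: x * sigmoid(alpha * x); alpha = 1 matches the pattern
        return x / (1.0 + np.exp(-alpha * x))

    x = np.linspace(-3.0, 3.0, 7).astype("float32")
    assert np.allclose(swish(x), x * (1.0 / (1.0 + np.exp(-x))))
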
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 078483798c..6c7034741a 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -192,7 +192,6 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, te
             if use_dnnl:
                 processed_mod = partition_for_dnnl(processed_mod, params, alter_layout)
                 check_dnnl_used(processed_mod)
-
             with tvm.transform.PassContext(opt_level=3):
                 func = relay.create_executor(
                     mode, mod=processed_mod, device=dev, target=target
@@ -237,6 +236,23 @@ def run_and_verify_func(
     )
 
 
+def add_activation(activation, out, dic, param_lst):
+    if activation == "relu":
+        return relay.nn.relu(out), dic, param_lst
+    elif activation == "tanh":
+        return relay.tanh(out), dic, param_lst
+    elif activation == "sigmoid":
+        return relay.sigmoid(out), dic, param_lst
+    elif activation == "clip":
+        return relay.clip(out, 0.0, 6.0), dic, param_lst
+    elif activation == "swish":
+        sig_out = relay.sigmoid(out)
+        out = relay.multiply(out, sig_out)
+        return out, dic, param_lst
+    else:
+        return out, dic, param_lst
+
+
 def get_conv1d(
     x_shape=((1, 3, 224)),
     k_shape=(16, 3, 3),
@@ -262,15 +278,7 @@ def get_conv1d(
     )
     dic = {"x": x_shape, "kernel": k_shape}
     param_lst = ["kernel"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dtype="float32"):
@@ -279,15 +287,7 @@ def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dt
     out = relay.nn.bias_add(conv, bias)
     dic["bias"] = (k_shape[0],)
     param_lst += ["bias"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv1d_bias_bn_relu(x_shape=(1, 3, 224), k_shape=(10, 3, 3), dtype="float32"):
@@ -334,15 +334,7 @@ def get_conv2d(
     )
     dic = {"x": x_shape, "kernel": k_shape}
     param_lst = ["kernel"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv2d_transpose(
@@ -367,15 +359,7 @@ def get_conv2d_transpose(
     )
     dic = {"x": x_shape, "kernel": k_shape}
     param_lst = ["kernel"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv2d_weights_const(
@@ -412,15 +396,7 @@ def get_conv2d_bias(
     out = relay.nn.bias_add(conv, bias)
     dic["bias"] = (k_shape[0],)
     param_lst += ["bias"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv2d_transpose_bias(
@@ -431,15 +407,7 @@ def get_conv2d_transpose_bias(
     out = relay.nn.bias_add(conv, bias)
     dic["bias"] = (k_shape[1],)
     param_lst += ["bias"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
@@ -503,15 +471,7 @@ def get_conv3d(
     )
     dic = {"x": x_shape, "kernel": k_shape}
     param_lst = ["kernel"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv3d_transpose(
@@ -542,15 +502,7 @@ def get_conv3d_transpose(
     )
     dic = {"x": x_shape, "kernel": k_shape}
     param_lst = ["kernel"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv3d_bias(
@@ -561,15 +513,7 @@ def get_conv3d_bias(
     out = relay.nn.bias_add(conv, bias)
     dic["bias"] = (k_shape[0],)
     param_lst += ["bias"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv3d_transpose_bias(
@@ -580,15 +524,7 @@ def get_conv3d_transpose_bias(
     out = relay.nn.bias_add(conv, bias)
     dic["bias"] = (k_shape[1],)
     param_lst += ["bias"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def gelu_helper(data):
@@ -797,7 +733,7 @@ def test_conv2d_weights_const(run_module, dtype="float32"):
 def test_conv2d_pattern(run_module, dtype="float32"):
     x_shape = (1, 32, 8, 8)
     k_shape = (16, 32, 3, 3)
-    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
     for a in activation_lst:
         conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype)
         conv2d = tvm.IRModule.from_expr(conv2d)
@@ -839,7 +775,7 @@ def test_conv2d_transpose(run_module, dtype="float32"):
 
 
 def test_conv2d_transpose_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
     for a in activation_lst:
         conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype)
         conv2d = tvm.IRModule.from_expr(conv2d)
@@ -872,7 +808,7 @@ def test_conv3d(run_module, dtype="float32"):
 
 
 def test_conv3d_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
     for a in activation_lst:
         conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype)
         conv3d = tvm.IRModule.from_expr(conv3d)
@@ -905,7 +841,7 @@ def test_conv3d_transpose(run_module, dtype="float32"):
 
 
 def test_conv3d_transpose_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
     for a in activation_lst:
         conv3d, dic, param_lst = get_conv3d_transpose(activation=a, dtype=dtype)
         conv3d = tvm.IRModule.from_expr(conv3d)
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 58b41189a0..4b7ac92136 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -919,6 +919,7 @@ def test_mixed_single_multiple_outputs():
 
 def test_dnnl_fuse():
     dnnl_patterns = get_pattern_table("dnnl")
+    dnnl_pat_dic = dict(dnnl_patterns)
     (
         conv2d_bias_relu_pat,
         conv2d_bias_sigmoid_pat,
@@ -926,11 +927,26 @@ def test_dnnl_fuse():
         conv2d_relu_pat,
         conv2d_sigmoid_pat,
     ) = (
-        dnnl_patterns[3],
-        dnnl_patterns[15],
-        dnnl_patterns[22],
-        dnnl_patterns[28],
-        dnnl_patterns[40],
+        (
+            "dnnl.conv2d_bias_relu",
+            dnnl_pat_dic["dnnl.conv2d_bias_relu"],
+        ),
+        (
+            "dnnl.conv2d_bias_sigmoid",
+            dnnl_pat_dic["dnnl.conv2d_bias_sigmoid"],
+        ),
+        (
+            "dnnl.conv2d_bias",
+            dnnl_pat_dic["dnnl.conv2d_bias"],
+        ),
+        (
+            "dnnl.conv2d_relu",
+            dnnl_pat_dic["dnnl.conv2d_relu"],
+        ),
+        (
+            "dnnl.conv2d_sigmoid",
+            dnnl_pat_dic["dnnl.conv2d_sigmoid"],
+        ),
     )
 
     def get_blocks(