Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/15 07:39:40 UTC

[tvm] branch main updated: Enable conv family fused with gelu (#12106)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fa057213d4 Enable conv family fused with gelu (#12106)
fa057213d4 is described below

commit fa057213d44271c09bf444bf4d6388c414f218c0
Author: billishyahao <ya...@intel.com>
AuthorDate: Fri Jul 15 15:39:33 2022 +0800

    Enable conv family fused with gelu (#12106)
---
 python/tvm/relay/op/contrib/dnnl.py           | 58 ++++++++++++++++-----------
 src/relay/backend/contrib/dnnl/codegen.cc     | 51 +++--------------------
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 10 ++---
 tests/python/contrib/test_dnnl.py             | 11 +++--
 4 files changed, 51 insertions(+), 79 deletions(-)
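
What "fused with gelu" means here: the pattern matcher recognizes the erf-based
GeLU decomposition

    gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))

as the five Relay ops divide -> erf -> add -> multiply -> multiply, and the DNNL
runtime then folds that whole subgraph into a single eltwise_gelu_erf post-op on
the conv (previously only dense) primitive. A minimal sketch, not part of the
patch, of the decomposition the pattern matches; the pattern binds the three
constants with wildcards, and the values below are the standard ones:

    import math
    from tvm import relay

    def gelu(x):
        # x / sqrt(2) -> erf -> + 1, then x * (...) * 0.5
        inner = relay.add(relay.erf(relay.divide(x, relay.const(math.sqrt(2.0)))), relay.const(1.0))
        return relay.multiply(relay.multiply(x, inner), relay.const(0.5))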

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 05416bb9a3..228619e0ef 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -108,6 +108,37 @@ _register_external_op_helper("nn.layer_norm")
 _register_external_op_helper("nn.batch_matmul")
 
 
+def append_eltwise_ops(op, eltwise):
+    """Append an element-wise post-op to a conv / conv_transpose / dense pattern.
+
+    Parameters
+    ----------
+    op : CallPattern
+        The pattern to which the element-wise post-op is attached.
+    eltwise : str
+        The name of the element-wise post-op to attach.
+    Returns
+    -------
+    pattern : CallPattern
+        Call node sequence.
+    """
+    if eltwise == "gelu":
+        const1 = wildcard()
+        const2 = wildcard()
+        const3 = wildcard()
+        div = is_op("divide")(op, const1)
+        erf_val = is_op("erf")(div)
+        added_erf_val = is_op("add")(erf_val, const2)
+        mul_val = is_op("multiply")(op, added_erf_val)
+        op = is_op("multiply")(mul_val, const3)
+    elif eltwise == "swish":
+        sig_out = is_op("sigmoid")(op)
+        op = is_op("multiply")(op, sig_out)
+    elif eltwise:
+        op = is_op(eltwise)(op)
+    return op
+
+
 def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
     """Create patterns related to conv and conv_transpose.
 
@@ -132,12 +163,7 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
         conv_out = is_op("add")(conv, bias)
     else:
         conv_out = conv
-    if with_eltwise == "swish":
-        sig_out = is_op("sigmoid")(conv_out)
-        conv_out = is_op("multiply")(conv_out, sig_out)
-    elif with_eltwise:
-        conv_out = is_op(with_eltwise)(conv_out)
-    return conv_out
+    return append_eltwise_ops(conv_out, with_eltwise)
 
 
 def make_dense_pattern(with_bias=True, with_eltwise=None):
@@ -165,21 +191,7 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
         dense_out = is_op("add")(dense, bias)
     else:
         dense_out = dense
-    if with_eltwise == "gelu":
-        const1 = wildcard()
-        const2 = wildcard()
-        const3 = wildcard()
-        div = is_op("divide")(dense_out, const1)
-        erf_val = is_op("erf")(div)
-        added_erf_val = is_op("add")(erf_val, const2)
-        mul_val = is_op("multiply")(dense_out, added_erf_val)
-        dense_out = is_op("multiply")(mul_val, const3)
-    elif with_eltwise == "swish":
-        sig_out = is_op("sigmoid")(dense_out)
-        dense_out = is_op("multiply")(dense_out, sig_out)
-    elif with_eltwise:
-        dense_out = is_op(with_eltwise)(dense_out)
-    return dense_out
+    return append_eltwise_ops(dense_out, with_eltwise)
 
 
 def make_dnnl_pattern(op_name, with_bias, with_eltwise):
@@ -203,7 +215,6 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
         pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::]
     pat_name += "_bias" if with_bias else ""
     pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
-    pat_name = pat_name.replace("_swish", "_sigmoid_mul")
     if "conv" in op_name:
         dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise))
     elif op_name == "nn.dense":
@@ -307,8 +318,7 @@ def pattern_table():
                 "nn.conv2d_transpose",
                 "nn.conv3d_transpose",
             ]:
-                if elt != "gelu":
-                    dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
+                dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
             dnnl_patterns.append(make_dnnl_pattern("nn.dense", with_bias, elt))
     return dnnl_patterns
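
With append_eltwise_ops shared by make_conv_pattern and make_dense_pattern, the
conv family picks up the gelu (and swish) matching that dense already had, and
pattern_table no longer has to skip gelu for the conv ops. A quick sketch,
assuming a checkout with this patch applied, that checks the new conv + gelu
pattern against a hand-built expression:

    import math
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import make_conv_pattern

    x = relay.var("x", shape=(1, 32, 8, 8))
    w = relay.var("w", shape=(16, 32, 3, 3))
    conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=16)
    erf = relay.erf(relay.divide(conv, relay.const(math.sqrt(2.0))))
    out = relay.multiply(relay.multiply(conv, relay.add(erf, relay.const(1.0))), relay.const(0.5))

    pat = make_conv_pattern("nn.conv2d", with_bias=False, with_eltwise="gelu")
    assert pat.match(out)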
 
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index 4abfc9d9b1..cbd11b4542 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -465,42 +465,6 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
   using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
   using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
 
-  std::map<std::string, std::string> op_map{
-      {"bias", "add"},
-      {"relu", "nn.relu"},
-      {"tanh", "tanh"},
-      {"sigmoid", "sigmoid"},
-      {"clip", "clip"},
-      {"mul", "multiply"},
-      {"nn.deconv2d", "nn.conv2d_transpose"},
-      {"nn.deconv3d", "nn.conv3d_transpose"},
-  };
-
-  std::vector<std::string> ParsingOpList(const std::string& pattern_name,
-                                         std::string interval = "_") {
-    ICHECK_NE(pattern_name, "");
-    std::vector<std::string> op_list;
-    size_t pos = 0, start = 0;
-
-    while ((pos = pattern_name.find(interval, start)) != std::string::npos) {
-      std::string op_name = pattern_name.substr(start, pos - start);
-      if (op_name.find("dnnl") != std::string::npos) {
-        op_name.replace(op_name.find("dnnl"), 4, "nn");
-        if (op_name.find("deconv") != std::string::npos) {
-          op_name = op_map[op_name];
-        }
-      } else {
-        op_name = op_map[op_name];
-      }
-      if (pos > start) op_list.push_back(op_name);
-      start = pos + interval.size();
-    }
-    if (pattern_name.size() > start) {
-      op_list.push_back(op_map[pattern_name.substr(start)]);
-    }
-    return op_list;
-  }
-
  public:
   DNNLJSONSerializer(const std::string& symbol, const Expr& expr)
       : JSONSerializer("dnnl_" + symbol, expr) {}
@@ -521,24 +485,19 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
       name = comp.value();
 
       if (name.find("dnnl.deconv2d") != std::string::npos) {
-        std::vector<std::string> op_list = ParsingOpList(name);
-        call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+        call = GetRootCall(fn->body.as<CallNode>(), 10, "nn.conv2d_transpose");
         ICHECK(call->op.as<OpNode>()) << "Not op node";
       } else if (name.find("dnnl.deconv3d") != std::string::npos) {
-        std::vector<std::string> op_list = ParsingOpList(name);
-        call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+        call = GetRootCall(fn->body.as<CallNode>(), 10, "nn.conv3d_transpose");
         ICHECK(call->op.as<OpNode>()) << "Not op node";
       } else if (name.find("dnnl.conv1d") != std::string::npos) {
-        std::vector<std::string> op_list = ParsingOpList(name);
-        call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+        call = GetRootCall(fn->body.as<CallNode>(), 10, "nn.conv1d");
         ICHECK(call->op.as<OpNode>()) << "Not op node";
       } else if (name.find("dnnl.conv2d") != std::string::npos) {
-        std::vector<std::string> op_list = ParsingOpList(name);
-        call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+        call = GetRootCall(fn->body.as<CallNode>(), 10, "nn.conv2d");
         ICHECK(call->op.as<OpNode>()) << "Not op node";
       } else if (name.find("dnnl.conv3d") != std::string::npos) {
-        std::vector<std::string> op_list = ParsingOpList(name);
-        call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+        call = GetRootCall(fn->body.as<CallNode>(), 10, "nn.conv3d");
         ICHECK(call->op.as<OpNode>()) << "Not op node";
       } else if (name.find("dnnl.dense") != std::string::npos) {
         call = GetRootCall(fn->body.as<CallNode>(), 10, "nn.dense");
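
ParsingOpList split the pattern name on "_" and mapped each token to exactly one
Relay op, which no longer works once gelu is in the mix: the single "_gelu"
suffix stands for five ops (divide, erf, add, multiply, multiply). The
serializer now just searches the composite body for the known root op, with 10
as a recursion-depth bound, mirroring the existing nn.dense case. The
name-to-root-op dispatch it performs, sketched in Python for illustration (the
real code is the C++ above):

    ROOT_OPS = {
        "dnnl.deconv2d": "nn.conv2d_transpose",
        "dnnl.deconv3d": "nn.conv3d_transpose",
        "dnnl.conv1d": "nn.conv1d",
        "dnnl.conv2d": "nn.conv2d",
        "dnnl.conv3d": "nn.conv3d",
        "dnnl.dense": "nn.dense",
    }

    def root_op(composite_name):
        # Order matters: the deconv prefixes must be tried before the conv ones.
        for prefix, op in ROOT_OPS.items():
            if prefix in composite_name:
                return op
        raise ValueError("unsupported composite: " + composite_name)

    assert root_op("dnnl.conv2d_bias_gelu") == "nn.conv2d"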
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index c6e50eafea..93c53dda16 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -191,6 +191,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::regex sigmoid_pat(".*_sigmoid.*");
     std::regex clip_pat(".*_clip.*");
     std::regex gelu_pat(".*_gelu.*");
+    std::regex swish_pat(".*_swish.*");
 
     // Parsing post-ops.
     dnnl::post_ops ops;
@@ -206,11 +207,10 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
     }
     if (std::regex_match(op_name, sigmoid_pat)) {
-      if (op_name.find("_sigmoid_mul") != std::string::npos) {
-        ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
-      } else {
-        ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
-      }
+      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
+    }
+    if (std::regex_match(op_name, swish_pat)) {
+      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
     }
     if (std::regex_match(op_name, gelu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
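
Swish used to be encoded as "_sigmoid_mul" in the pattern name, so the sigmoid
branch had to special-case it; with the rename dropped in dnnl.py above the
name keeps "_swish", swish gets its own regex, and the two branches no longer
overlap. The dispatch, exercised in plain Python:

    import re

    # The same composite-name regexes as above.
    swish_pat = re.compile(r".*_swish.*")
    sigmoid_pat = re.compile(r".*_sigmoid.*")
    gelu_pat = re.compile(r".*_gelu.*")

    assert swish_pat.match("dnnl.conv2d_bias_swish")
    assert not sigmoid_pat.match("dnnl.conv2d_bias_swish")  # no overlap anymore
    assert gelu_pat.match("dnnl.conv2d_gelu")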
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index dfe1b7265d..1bf8068b2e 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -249,6 +249,9 @@ def add_activation(activation, out, dic, param_lst):
         sig_out = relay.sigmoid(out)
         out = relay.multiply(out, sig_out)
         return out, dic, param_lst
+    elif activation == "gelu":
+        out = gelu_helper(out)
+        return out, dic, param_lst
     else:
         return out, dic, param_lst
 
@@ -762,7 +765,7 @@ def test_conv2d_weights_const(run_module, dtype="float32"):
 def test_conv2d_pattern(run_module, dtype="float32"):
     x_shape = (1, 32, 8, 8)
     k_shape = (16, 32, 3, 3)
-    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
     for a in activation_lst:
         conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype)
         conv2d = tvm.IRModule.from_expr(conv2d)
@@ -804,7 +807,7 @@ def test_conv2d_transpose(run_module, dtype="float32"):
 
 
 def test_conv2d_transpose_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
     for a in activation_lst:
         conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype)
         conv2d = tvm.IRModule.from_expr(conv2d)
@@ -837,7 +840,7 @@ def test_conv3d(run_module, dtype="float32"):
 
 
 def test_conv3d_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
     for a in activation_lst:
         conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype)
         conv3d = tvm.IRModule.from_expr(conv3d)
@@ -870,7 +873,7 @@ def test_conv3d_transpose(run_module, dtype="float32"):
 
 
 def test_conv3d_transpose_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
     for a in activation_lst:
         conv3d, dic, param_lst = get_conv3d_transpose(activation=a, dtype=dtype)
         conv3d = tvm.IRModule.from_expr(conv3d)
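
The tests go through gelu_helper, which is defined elsewhere in test_dnnl.py and
is not shown in this diff. End to end, partitioning a conv2d followed by the
erf-based GeLU decomposition should now yield a fused composite named
"dnnl.conv2d_gelu" (per make_dnnl_pattern above). A sketch, assuming a checkout
with this patch applied:

    import math
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import partition_for_dnnl

    x = relay.var("x", shape=(1, 32, 8, 8))
    w = relay.var("w", shape=(16, 32, 3, 3))
    conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=16)
    erf = relay.erf(relay.divide(conv, relay.const(math.sqrt(2.0))))
    out = relay.multiply(relay.multiply(conv, relay.add(erf, relay.const(1.0))), relay.const(0.5))

    mod = partition_for_dnnl(tvm.IRModule.from_expr(out))
    print(mod)  # expect a composite function tagged "dnnl.conv2d_gelu"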