Posted to commits@tvm.apache.org by ma...@apache.org on 2022/08/01 08:56:56 UTC

[tvm] branch main updated: Enable conv family fused with mish (#12228)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a49273e050 Enable conv family fused with mish (#12228)
a49273e050 is described below

commit a49273e05092480bde8593c6a137bb251b5dee6c
Author: billishyahao <ya...@intel.com>
AuthorDate: Mon Aug 1 16:56:49 2022 +0800

    Enable conv family fused with mish (#12228)
---
 python/tvm/relay/op/contrib/dnnl.py           | 11 +++++++++--
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc |  4 ++++
 tests/python/contrib/test_dnnl.py             | 15 +++++++++++----
 3 files changed, 24 insertions(+), 6 deletions(-)
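
Background on the activation being fused: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x)). The patch matches this in its decomposed Relay form (exp -> add 1 -> log -> tanh -> multiply) rather than as a single op. A minimal NumPy sketch of the reference math, for illustration only (not part of the patch):

    import numpy as np

    def mish_ref(x):
        # mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))),
        # written in the same decomposed form the Relay pattern matches:
        # exp -> add(1) -> log -> tanh -> multiply.
        return x * np.tanh(np.log(np.exp(x) + 1.0))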

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index f17b325dce..46c20e947f 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -53,7 +53,7 @@ from .register import register_pattern_table
 
 
 logger = logging.getLogger("DNNL")
-supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
+supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None]
 
 
 def _register_external_op_helper(op_name, supported=True):
@@ -137,6 +137,13 @@ def append_eltwise_ops(op, eltwise):
     elif eltwise == "swish":
         sig_out = is_op("sigmoid")(op)
         op = is_op("multiply")(op, sig_out)
+    elif eltwise == "mish":
+        const1 = wildcard()
+        exp = is_op("exp")(op)
+        add = is_op("add")(exp, const1)
+        log = is_op("log")(add)
+        tanh = is_op("tanh")(log)
+        op = is_op("multiply")(op, tanh)
     elif eltwise:
         op = is_op(eltwise)(op)
     return op
@@ -411,7 +418,7 @@ def pattern_table():
         )
     )
 
-    elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
+    elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None]
     for with_bias in [True, False]:
         for elt in elt_list:
             if not with_bias and not elt:
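
The hunks above extend append_eltwise_ops and the pattern table so DNNL recognizes the decomposed mish following a conv output. A standalone sketch of the same dataflow pattern, assuming the usual imports from tvm.relay.dataflow_pattern (illustrative; in the patch itself, `op` is the already-built conv-plus-optional-bias pattern):

    from tvm.relay.dataflow_pattern import is_op, wildcard

    # Matches x * tanh(log(exp(x) + c)), i.e. mish when c == 1.
    inp = wildcard()     # stands in for the conv (+ bias) output
    const1 = wildcard()  # the 1.0 constant inside softplus
    exp = is_op("exp")(inp)
    add = is_op("add")(exp, const1)
    log = is_op("log")(add)
    tanh = is_op("tanh")(log)
    mish_pat = is_op("multiply")(inp, tanh)

Note that the final multiply takes both the original input and the tanh branch, which is what ties the activation back to the conv output it follows.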
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 1fe8fccc77..d019f4e811 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -191,6 +191,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::regex gelu_pat(".*_gelu.*");
     std::regex swish_pat(".*_swish.*");
     std::regex sum_pat(".*_sum.*");
+    std::regex mish_pat(".*_mish.*");
 
     // parsing of name to extract attributes
     auto op_name = nodes_[nid].GetOpName();
@@ -220,6 +221,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     if (std::regex_match(op_name, gelu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
     }
+    if (std::regex_match(op_name, mish_pat)) {
+      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_mish, 1.f, 0.f);
+    }
     if (ops.len() != 0) {
       attr.set_post_ops(ops);
     }
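
On the runtime side, the decomposed pattern collapses back into a single oneDNN eltwise post-op: the JSON runtime recovers fused post-ops from the composite function name via the regexes above, then calls post_ops::append_eltwise (oneDNN 2.x signature: scale, algorithm, alpha, beta; for mish, oneDNN ignores alpha and beta, so the 1.f passed above should be inert). A hypothetical Python rendering of that regex dispatch, just to show the naming convention (composite names carry the eltwise suffix, e.g. "dnnl.conv2d_bias_mish"):

    import re

    # Mirror of the C++ std::regex_match dispatch in the JSON runtime
    # (hypothetical illustration, not a TVM API).
    POST_OP_PATS = {
        "relu": re.compile(r".*_relu.*"),
        "gelu": re.compile(r".*_gelu.*"),
        "mish": re.compile(r".*_mish.*"),
    }

    def parse_post_ops(op_name):
        return [name for name, pat in POST_OP_PATS.items() if pat.fullmatch(op_name)]

    assert parse_post_ops("dnnl.conv2d_bias_mish") == ["mish"]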
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 74d0da1238..8de8bd9ce6 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -252,6 +252,13 @@ def add_activation(activation, out, dic, param_lst):
     elif activation == "gelu":
         out = gelu_helper(out)
         return out, dic, param_lst
+    elif activation == "mish":
+        exp = relay.exp(out)
+        add = relay.add(exp, relay.const(1.0))
+        log = relay.log(add)
+        tanh = relay.tanh(log)
+        out = relay.multiply(out, tanh)
+        return out, dic, param_lst
     else:
         return out, dic, param_lst
 
@@ -765,7 +772,7 @@ def test_conv2d_weights_const(run_module, dtype="float32"):
 def test_conv2d_pattern(run_module, dtype="float32"):
     x_shape = (1, 32, 8, 8)
     k_shape = (16, 32, 3, 3)
-    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu", "mish"]
     for a in activation_lst:
         conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype)
         conv2d = tvm.IRModule.from_expr(conv2d)
@@ -849,7 +856,7 @@ def test_conv2d_transpose(run_module, dtype="float32"):
 
 
 def test_conv2d_transpose_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu", "mish"]
     for a in activation_lst:
         conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype)
         conv2d = tvm.IRModule.from_expr(conv2d)
@@ -882,7 +889,7 @@ def test_conv3d(run_module, dtype="float32"):
 
 
 def test_conv3d_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu", "mish"]
     for a in activation_lst:
         conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype)
         conv3d = tvm.IRModule.from_expr(conv3d)
@@ -915,7 +922,7 @@ def test_conv3d_transpose(run_module, dtype="float32"):
 
 
 def test_conv3d_transpose_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu", "mish"]
     for a in activation_lst:
         conv3d, dic, param_lst = get_conv3d_transpose(activation=a, dtype=dtype)
         conv3d = tvm.IRModule.from_expr(conv3d)
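
End to end, a conv followed by the decomposed mish should now be partitioned into a single DNNL composite, as the tests above exercise for conv2d/conv3d and their transposes. A hedged usage sketch, assuming a TVM build with the DNNL contrib enabled (partition_for_dnnl is the helper exported from tvm.relay.op.contrib.dnnl; the exact composite name, e.g. "dnnl.conv2d_mish", follows the pattern_table naming above):

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import partition_for_dnnl

    x = relay.var("x", shape=(1, 32, 8, 8), dtype="float32")
    w = relay.var("w", shape=(16, 32, 3, 3), dtype="float32")
    conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=16)

    # Decomposed mish, same form as add_activation in the tests above.
    mish = relay.multiply(
        conv, relay.tanh(relay.log(relay.add(relay.exp(conv), relay.const(1.0))))
    )

    mod = partition_for_dnnl(tvm.IRModule.from_expr(mish))
    print(mod)  # expect a composite function with "mish" in its name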