Posted to commits@tvm.apache.org by ma...@apache.org on 2022/08/01 08:56:56 UTC
[tvm] branch main updated: Enable conv family fused with mish (#12228)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a49273e050 Enable conv family fused with mish (#12228)
a49273e050 is described below
commit a49273e05092480bde8593c6a137bb251b5dee6c
Author: billishyahao <ya...@intel.com>
AuthorDate: Mon Aug 1 16:56:49 2022 +0800
Enable conv family fused with mish (#12228)
---
python/tvm/relay/op/contrib/dnnl.py | 11 +++++++++--
src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 4 ++++
tests/python/contrib/test_dnnl.py | 15 +++++++++++----
3 files changed, 24 insertions(+), 6 deletions(-)
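Note on the change (editor's summary, not part of the original commit message): this commit teaches the DNNL (oneDNN) BYOC path to treat mish as a fusable post-activation for the conv family (conv2d, conv3d, and their transposed variants), so a convolution followed by mish is offloaded as a single fused oneDNN primitive. Mish is mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))). Relay has no dedicated mish operator, so the patterns below match its decomposition into exp, add, log, tanh, and multiply. A minimal NumPy reference of that function, purely for illustration:

    # Illustrative only; not code from the commit.
    import numpy as np

    def mish_ref(x):
        # log1p(exp(x)) is softplus(x); tanh of it gates the input
        return x * np.tanh(np.log1p(np.exp(x)))

    x = np.linspace(-4.0, 4.0, 9, dtype="float32")
    print(mish_ref(x))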
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index f17b325dce..46c20e947f 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -53,7 +53,7 @@ from .register import register_pattern_table
logger = logging.getLogger("DNNL")
-supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
+supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None]
def _register_external_op_helper(op_name, supported=True):
@@ -137,6 +137,13 @@ def append_eltwise_ops(op, eltwise):
elif eltwise == "swish":
sig_out = is_op("sigmoid")(op)
op = is_op("multiply")(op, sig_out)
+ elif eltwise == "mish":
+ const1 = wildcard()
+ exp = is_op("exp")(op)
+ add = is_op("add")(exp, const1)
+ log = is_op("log")(add)
+ tanh = is_op("tanh")(log)
+ op = is_op("multiply")(op, tanh)
elif eltwise:
op = is_op(eltwise)(op)
return op
@@ -411,7 +418,7 @@ def pattern_table():
)
)
- elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
+ elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None]
for with_bias in [True, False]:
for elt in elt_list:
if not with_bias and not elt:
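The pattern added in append_eltwise_ops reuses the same `op` pattern node as both the input to `exp` and the first operand of the final `multiply`, so it only fires when both refer to the same tensor (the conv output), i.e. exactly x * tanh(log(exp(x) + 1)). A minimal standalone sketch of that shape, with a plain wildcard standing in for the conv pattern and variable names of my choosing (not from the commit):

    from tvm import relay
    from tvm.relay.dataflow_pattern import is_op, wildcard

    op = wildcard()
    exp = is_op("exp")(op)
    add = is_op("add")(exp, wildcard())
    mish_pat = is_op("multiply")(op, is_op("tanh")(is_op("log")(add)))

    x = relay.var("x", shape=(1, 16, 8, 8))
    mish_expr = relay.multiply(
        x, relay.tanh(relay.log(relay.add(relay.exp(x), relay.const(1.0))))
    )
    assert mish_pat.match(mish_expr)  # both uses of `op` bind to the same x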
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 1fe8fccc77..d019f4e811 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -191,6 +191,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::regex gelu_pat(".*_gelu.*");
std::regex swish_pat(".*_swish.*");
std::regex sum_pat(".*_sum.*");
+ std::regex mish_pat(".*_mish.*");
// parsing of name to extract attributes
auto op_name = nodes_[nid].GetOpName();
@@ -220,6 +221,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
if (std::regex_match(op_name, gelu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
}
+ if (std::regex_match(op_name, mish_pat)) {
+ ops.append_eltwise(1.f, dnnl::algorithm::eltwise_mish, 1.f, 0.f);
+ }
if (ops.len() != 0) {
attr.set_post_ops(ops);
}
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 74d0da1238..8de8bd9ce6 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -252,6 +252,13 @@ def add_activation(activation, out, dic, param_lst):
elif activation == "gelu":
out = gelu_helper(out)
return out, dic, param_lst
+ elif activation == "mish":
+ exp = relay.exp(out)
+ add = relay.add(exp, relay.const(1.0))
+ log = relay.log(add)
+ tanh = relay.tanh(log)
+ out = relay.multiply(out, tanh)
+ return out, dic, param_lst
else:
return out, dic, param_lst
@@ -765,7 +772,7 @@ def test_conv2d_weights_const(run_module, dtype="float32"):
def test_conv2d_pattern(run_module, dtype="float32"):
x_shape = (1, 32, 8, 8)
k_shape = (16, 32, 3, 3)
- activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
+ activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu", "mish"]
for a in activation_lst:
conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype)
conv2d = tvm.IRModule.from_expr(conv2d)
@@ -849,7 +856,7 @@ def test_conv2d_transpose(run_module, dtype="float32"):
def test_conv2d_transpose_pattern(run_module, dtype="float32"):
- activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
+ activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu", "mish"]
for a in activation_lst:
conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype)
conv2d = tvm.IRModule.from_expr(conv2d)
@@ -882,7 +889,7 @@ def test_conv3d(run_module, dtype="float32"):
def test_conv3d_pattern(run_module, dtype="float32"):
- activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
+ activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu", "mish"]
for a in activation_lst:
conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype)
conv3d = tvm.IRModule.from_expr(conv3d)
@@ -915,7 +922,7 @@ def test_conv3d_transpose(run_module, dtype="float32"):
def test_conv3d_transpose_pattern(run_module, dtype="float32"):
- activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
+ activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu", "mish"]
for a in activation_lst:
conv3d, dic, param_lst = get_conv3d_transpose(activation=a, dtype=dtype)
conv3d = tvm.IRModule.from_expr(conv3d)
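The tests extend the existing conv2d, conv2d_transpose, conv3d, and conv3d_transpose pattern tests with "mish", building the activation from the same exp/add/log/tanh/multiply decomposition so that partitioning folds it into the DNNL composite. A short hand-written sketch of what that exercises; shapes are illustrative, it assumes partition_for_dnnl from python/tvm/relay/op/contrib/dnnl.py, and it is not code from the commit:

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import partition_for_dnnl

    x = relay.var("x", shape=(1, 32, 8, 8), dtype="float32")
    w = relay.const(np.random.uniform(-1, 1, (16, 32, 3, 3)).astype("float32"))
    conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=16)
    out = relay.multiply(
        conv, relay.tanh(relay.log(relay.add(relay.exp(conv), relay.const(1.0))))
    )

    mod = tvm.IRModule.from_expr(out)
    mod = partition_for_dnnl(mod)
    print(mod)  # expected: a dnnl composite covering conv2d plus the mish chain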