You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/15 07:39:40 UTC
[tvm] branch main updated: Enable conv family fused with gelu (#12106)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fa057213d4 Enable conv family fused with gelu (#12106)
fa057213d4 is described below
commit fa057213d44271c09bf444bf4d6388c414f218c0
Author: billishyahao <ya...@intel.com>
AuthorDate: Fri Jul 15 15:39:33 2022 +0800
Enable conv family fused with gelu (#12106)
---
python/tvm/relay/op/contrib/dnnl.py | 58 ++++++++++++++++-----------
src/relay/backend/contrib/dnnl/codegen.cc | 51 +++--------------------
src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 10 ++---
tests/python/contrib/test_dnnl.py | 11 +++--
4 files changed, 51 insertions(+), 79 deletions(-)
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 05416bb9a3..228619e0ef 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -108,6 +108,37 @@ _register_external_op_helper("nn.layer_norm")
_register_external_op_helper("nn.batch_matmul")
+def append_eltwise_ops(op, eltwise):
+ """Append element-wise post-ops to conv / conv_transpose / dense
+
+ Parameters
+ ----------
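+ op : CallPattern
+ The conv / conv_transpose / dense pattern (optionally followed by a bias add) to which the element-wise post-op is appended.
+ eltwise : str or None
+ The element-wise post-op to append: "gelu", "swish", or the name of a single Relay op. If None or empty, nothing is appended.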
+ Returns
+ -------
+ pattern : CallPattern
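+ The input pattern extended with the calls of the requested post-op, or the input pattern unchanged when no post-op is requested.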
+ """
+ if eltwise == "gelu":
+ const1 = wildcard()
+ const2 = wildcard()
+ const3 = wildcard()
+ div = is_op("divide")(op, const1)
+ erf_val = is_op("erf")(div)
+ added_erf_val = is_op("add")(erf_val, const2)
+ mul_val = is_op("multiply")(op, added_erf_val)
+ op = is_op("multiply")(mul_val, const3)
+ elif eltwise == "swish":
+ sig_out = is_op("sigmoid")(op)
+ op = is_op("multiply")(op, sig_out)
+ elif eltwise:
+ op = is_op(eltwise)(op)
+ return op
+
+
def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
"""Create patterns related to conv and conv_transpose.
@@ -132,12 +163,7 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
conv_out = is_op("add")(conv, bias)
else:
conv_out = conv
- if with_eltwise == "swish":
- sig_out = is_op("sigmoid")(conv_out)
- conv_out = is_op("multiply")(conv_out, sig_out)
- elif with_eltwise:
- conv_out = is_op(with_eltwise)(conv_out)
- return conv_out
+ return append_eltwise_ops(conv_out, with_eltwise)
def make_dense_pattern(with_bias=True, with_eltwise=None):
@@ -165,21 +191,7 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
dense_out = is_op("add")(dense, bias)
else:
dense_out = dense
- if with_eltwise == "gelu":
- const1 = wildcard()
- const2 = wildcard()
- const3 = wildcard()
- div = is_op("divide")(dense_out, const1)
- erf_val = is_op("erf")(div)
- added_erf_val = is_op("add")(erf_val, const2)
- mul_val = is_op("multiply")(dense_out, added_erf_val)
- dense_out = is_op("multiply")(mul_val, const3)
- elif with_eltwise == "swish":
- sig_out = is_op("sigmoid")(dense_out)
- dense_out = is_op("multiply")(dense_out, sig_out)
- elif with_eltwise:
- dense_out = is_op(with_eltwise)(dense_out)
- return dense_out
+ return append_eltwise_ops(dense_out, with_eltwise)
def make_dnnl_pattern(op_name, with_bias, with_eltwise):
@@ -203,7 +215,6 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::]
pat_name += "_bias" if with_bias else ""
pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
- pat_name = pat_name.replace("_swish", "_sigmoid_mul")
if "conv" in op_name:
dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise))
elif op_name == "nn.dense":
@@ -307,8 +318,7 @@ def pattern_table():
"nn.conv2d_transpose",
"nn.conv3d_transpose",
]:
- if elt != "gelu":
- dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
+ dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
dnnl_patterns.append(make_dnnl_pattern("nn.dense", with_bias, elt))
return dnnl_patterns
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index 4abfc9d9b1..cbd11b4542 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -465,42 +465,6 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
- std::map<std::string, std::string> op_map{
- {"bias", "add"},
- {"relu", "nn.relu"},
- {"tanh", "tanh"},
- {"sigmoid", "sigmoid"},
- {"clip", "clip"},
- {"mul", "multiply"},
- {"nn.deconv2d", "nn.conv2d_transpose"},
- {"nn.deconv3d", "nn.conv3d_transpose"},
- };
-
- std::vector<std::string> ParsingOpList(const std::string& pattern_name,
- std::string interval = "_") {
- ICHECK_NE(pattern_name, "");
- std::vector<std::string> op_list;
- size_t pos = 0, start = 0;
-
- while ((pos = pattern_name.find(interval, start)) != std::string::npos) {
- std::string op_name = pattern_name.substr(start, pos - start);
- if (op_name.find("dnnl") != std::string::npos) {
- op_name.replace(op_name.find("dnnl"), 4, "nn");
- if (op_name.find("deconv") != std::string::npos) {
- op_name = op_map[op_name];
- }
- } else {
- op_name = op_map[op_name];
- }
- if (pos > start) op_list.push_back(op_name);
- start = pos + interval.size();
- }
- if (pattern_name.size() > start) {
- op_list.push_back(op_map[pattern_name.substr(start)]);
- }
- return op_list;
- }
-
public:
DNNLJSONSerializer(const std::string& symbol, const Expr& expr)
: JSONSerializer("dnnl_" + symbol, expr) {}
@@ -521,24 +485,19 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
name = comp.value();
if (name.find("dnnl.deconv2d") != std::string::npos) {
- std::vector<std::string> op_list = ParsingOpList(name);
- call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+ call = GetRootCall(fn->body.as<CallNode>(), 10, "nn.conv2d_transpose");
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name.find("dnnl.deconv3d") != std::string::npos) {
- std::vector<std::string> op_list = ParsingOpList(name);
- call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+ call = GetRootCall(fn->body.as<CallNode>(), 10, "nn.conv3d_transpose");
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name.find("dnnl.conv1d") != std::string::npos) {
- std::vector<std::string> op_list = ParsingOpList(name);
- call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+ call = GetRootCall(fn->body.as<CallNode>(), 10, "nn.conv1d");
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name.find("dnnl.conv2d") != std::string::npos) {
- std::vector<std::string> op_list = ParsingOpList(name);
- call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+ call = GetRootCall(fn->body.as<CallNode>(), 10, "nn.conv2d");
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name.find("dnnl.conv3d") != std::string::npos) {
- std::vector<std::string> op_list = ParsingOpList(name);
- call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+ call = GetRootCall(fn->body.as<CallNode>(), 10, "nn.conv3d");
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name.find("dnnl.dense") != std::string::npos) {
call = GetRootCall(fn->body.as<CallNode>(), 10, "nn.dense");
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index c6e50eafea..93c53dda16 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -191,6 +191,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::regex sigmoid_pat(".*_sigmoid.*");
std::regex clip_pat(".*_clip.*");
std::regex gelu_pat(".*_gelu.*");
+ std::regex swish_pat(".*_swish.*");
// Parsing post-ops.
dnnl::post_ops ops;
@@ -206,11 +207,10 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
}
if (std::regex_match(op_name, sigmoid_pat)) {
- if (op_name.find("_sigmoid_mul") != std::string::npos) {
- ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
- } else {
- ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
- }
+ ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
+ }
+ if (std::regex_match(op_name, swish_pat)) {
+ ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
}
if (std::regex_match(op_name, gelu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index dfe1b7265d..1bf8068b2e 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -249,6 +249,9 @@ def add_activation(activation, out, dic, param_lst):
sig_out = relay.sigmoid(out)
out = relay.multiply(out, sig_out)
return out, dic, param_lst
+ elif activation == "gelu":
+ out = gelu_helper(out)
+ return out, dic, param_lst
else:
return out, dic, param_lst
@@ -762,7 +765,7 @@ def test_conv2d_weights_const(run_module, dtype="float32"):
def test_conv2d_pattern(run_module, dtype="float32"):
x_shape = (1, 32, 8, 8)
k_shape = (16, 32, 3, 3)
- activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
+ activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
for a in activation_lst:
conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype)
conv2d = tvm.IRModule.from_expr(conv2d)
@@ -804,7 +807,7 @@ def test_conv2d_transpose(run_module, dtype="float32"):
def test_conv2d_transpose_pattern(run_module, dtype="float32"):
- activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
+ activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
for a in activation_lst:
conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype)
conv2d = tvm.IRModule.from_expr(conv2d)
@@ -837,7 +840,7 @@ def test_conv3d(run_module, dtype="float32"):
def test_conv3d_pattern(run_module, dtype="float32"):
- activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
+ activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
for a in activation_lst:
conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype)
conv3d = tvm.IRModule.from_expr(conv3d)
@@ -870,7 +873,7 @@ def test_conv3d_transpose(run_module, dtype="float32"):
def test_conv3d_transpose_pattern(run_module, dtype="float32"):
- activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
+ activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"]
for a in activation_lst:
conv3d, dic, param_lst = get_conv3d_transpose(activation=a, dtype=dtype)
conv3d = tvm.IRModule.from_expr(conv3d)