Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/11 00:56:03 UTC
[tvm] branch main updated: [BYOC-DNNL] support more post-ops (#12002)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 927620e20f [BYOC-DNNL] support more post-ops (#12002)
927620e20f is described below
commit 927620e20fe226578fa0c32b9706d27874791b83
Author: Ivy Zhang <ya...@intel.com>
AuthorDate: Mon Jul 11 08:55:57 2022 +0800
[BYOC-DNNL] support more post-ops (#12002)
* support post-op swish
* support post-op clip
* enhance get_shape and get_dtype in dnnl.py to support efficientnet
* add checks that with_eltwise is in the supported list
* fix lint
* fix test
---
python/tvm/relay/op/contrib/dnnl.py | 22 ++++-
src/relay/backend/contrib/dnnl/codegen.cc | 9 ++
src/relay/backend/utils.h | 4 +
src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 12 ++-
tests/python/contrib/test_dnnl.py | 126 ++++++------------------
tests/python/relay/test_pass_partition_graph.py | 26 ++++-
6 files changed, 95 insertions(+), 104 deletions(-)
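For context, a minimal end-to-end sketch of what this patch enables (assuming a TVM build with USE_DNNL=ON; the shapes below are illustrative, not taken from the patch):

    # Partition a conv2d + swish graph for DNNL.
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import partition_for_dnnl

    x = relay.var("x", shape=(1, 32, 8, 8), dtype="float32")
    w = relay.var("w", shape=(16, 32, 3, 3), dtype="float32")
    conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=16)
    out = relay.multiply(conv, relay.sigmoid(conv))  # swish(x) = x * sigmoid(x)

    mod = tvm.IRModule.from_expr(out)
    # conv2d + sigmoid + multiply should be fused into one DNNL composite.
    mod = partition_for_dnnl(mod)
    print(mod)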
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index b3ef478f20..9b6b45240a 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -51,6 +51,7 @@ from .register import register_pattern_table
logger = logging.getLogger("DNNL")
+supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
def _register_external_op_helper(op_name, supported=True):
@@ -120,6 +121,8 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
conv_out : CallPattern
Call node sequence.
"""
+ if with_eltwise not in supported_post_elts:
+ raise ValueError("Unsupported eltwise post-op: %s" % with_eltwise)
data = wildcard()
weight = wildcard()
bias = wildcard()
@@ -128,8 +131,11 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
conv_out = is_op("add")(conv, bias)
else:
conv_out = conv
- if with_eltwise:
- return is_op(with_eltwise)(conv_out)
+ if with_eltwise == "swish":
+ sig_out = is_op("sigmoid")(conv_out)
+ conv_out = is_op("multiply")(conv_out, sig_out)
+ elif with_eltwise:
+ conv_out = is_op(with_eltwise)(conv_out)
return conv_out
@@ -147,6 +153,8 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
dense_out : CallPattern
Call node sequence.
"""
+ if with_eltwise not in supported_post_elts:
+ raise ValueError("Unsupported eltwise post-op: %s" % with_eltwise)
data = wildcard()
weight = wildcard()
bias = wildcard()
@@ -165,6 +173,9 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
added_erf_val = is_op("add")(erf_val, const2)
mul_val = is_op("multiply")(dense_out, added_erf_val)
dense_out = is_op("multiply")(mul_val, const3)
+ elif with_eltwise == "swish":
+ sig_out = is_op("sigmoid")(dense_out)
+ dense_out = is_op("multiply")(dense_out, sig_out)
elif with_eltwise:
dense_out = is_op(with_eltwise)(dense_out)
return dense_out
@@ -191,6 +202,7 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::]
pat_name += "_bias" if with_bias else ""
pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
+ pat_name = pat_name.replace("_swish", "_sigmoid_mul")
if "conv" in op_name:
dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise))
elif op_name == "nn.dense":
@@ -282,7 +294,7 @@ def pattern_table():
dnnl_patterns.append(make_qnn_conv2d_pattern())
dnnl_patterns.append(make_qnn_dense_pattern())
- elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", None]
+ elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
for with_bias in [True, False]:
for elt in elt_list:
if not with_bias and not elt:
@@ -380,6 +392,8 @@ def get_shape(tensor):
if isinstance(tensor, tvm.ir.container.Array):
return tensor[-1].shape
if isinstance(tensor, relay.expr.Call):
+ if tensor.op.name == "multiply":
+ return tensor.type_args[0].shape
return tensor.checked_type.shape
raise TypeError("Unsupport data type: %s" % type(tensor))
@@ -395,6 +409,8 @@ def get_dtype(tensor):
if isinstance(tensor, tvm.ir.container.Array):
return tensor[-1].dtype
if isinstance(tensor, relay.expr.Call):
+ if tensor.op.name == "multiply":
+ return tensor.type_args[0].dtype
return tensor.checked_type.dtype
raise TypeError("Unsupport data type: %s" % type(tensor))
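Since Relay has no standalone swish operator, the pattern above matches the decomposed form multiply(x, sigmoid(x)); the get_shape/get_dtype additions then read the type from the multiply's first argument, i.e. the conv/dense output. A standalone sketch of the same matching idea with the dataflow pattern API (illustrative, not part of the patch):

    from tvm import relay
    from tvm.relay.dataflow_pattern import is_op, wildcard

    inp = wildcard()
    # Reusing `inp` on both sides forces the matcher to bind the same
    # subexpression, which is exactly the swish shape: x * sigmoid(x).
    swish_pat = is_op("multiply")(inp, is_op("sigmoid")(inp))

    x = relay.var("x", shape=(1, 16), dtype="float32")
    assert swish_pat.match(relay.multiply(x, relay.sigmoid(x)))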
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index 2f47c23a7c..4abfc9d9b1 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -470,6 +470,8 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
{"relu", "nn.relu"},
{"tanh", "tanh"},
{"sigmoid", "sigmoid"},
+ {"clip", "clip"},
+ {"mul", "multiply"},
{"nn.deconv2d", "nn.conv2d_transpose"},
{"nn.deconv3d", "nn.conv3d_transpose"},
};
@@ -566,6 +568,13 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
"kernel", /* op_type_ */
inputs, 1 /* num_outputs_ */);
SetCallNodeAttribute(node, call);
+ // If the pattern has a post-op `clip`, assume the last op is clip and add its attrs to the pattern attrs.
+ if (name.find("_clip") != std::string::npos) {
+ auto clip_call = cn->op.as<FunctionNode>()->body.as<CallNode>();
+ ICHECK(IsOp(clip_call, "clip"));
+ SetCallNodeAttribute(node, clip_call);
+ }
+ // For QNN.
for (const auto& kvp : extra_attrs) node->SetAttr(kvp.first, kvp.second);
return AddNode(node, GetRef<Expr>(cn));
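Unlike relu/tanh/sigmoid, clip carries a_min/a_max attributes, so the serializer copies them from the composite function's final call onto the JSON node. In Relay terms the lookup amounts to (an illustrative sketch, not the serializer itself):

    from tvm import relay

    x = relay.var("x", shape=(1, 8), dtype="float32")
    # A composite body whose last op is clip, as the ICHECK above assumes.
    body = relay.clip(relay.nn.relu(x), a_min=0.0, a_max=6.0)
    assert body.op.name == "clip"
    print(body.attrs.a_min, body.attrs.a_max)  # 0.0 6.0, forwarded to the runtime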
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index d6fae8c72b..57c0661311 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -470,6 +470,10 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
current_call->args[valid_node_idx].as<VarNode>()) {
valid_node_idx++;
}
+ while (valid_node_idx < current_call->args.size() &&
+ !(IsOp(current_call->args[valid_node_idx].as<CallNode>(), expected_op_names[depth - 1]))) {
+ valid_node_idx++;
+ }
const auto* next_call = current_call->args[valid_node_idx].as<CallNode>();
return GetRootCall(next_call, depth - 1, expected_op_names);
}
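The new loop in GetRootCall skips argument positions whose op does not match the expected op at the current depth; without it, the extra sigmoid branch introduced by the swish pattern would derail the walk back to the root conv/dense call. A rough Python analogue of the traversal (simplified; the real logic is the C++ above):

    from tvm import relay

    def get_root_call(call, depth, expected_op_names):
        """Walk from the last call back to the root, following expected ops."""
        if depth == 0:
            return call
        for arg in call.args:
            # Skip vars/constants and calls to other ops (e.g. the sigmoid
            # branch of a swish) until the expected op is found.
            if isinstance(arg, relay.Call) and arg.op.name == expected_op_names[depth - 1]:
                return get_root_call(arg, depth - 1, expected_op_names)
        raise ValueError("no argument matches " + expected_op_names[depth - 1])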
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index a46f170fea..6c0fd64066 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -189,6 +189,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::regex relu_pat(".*_relu.*");
std::regex tanh_pat(".*_tanh.*");
std::regex sigmoid_pat(".*_sigmoid.*");
+ std::regex clip_pat(".*_clip.*");
std::regex gelu_pat(".*_gelu.*");
// Parsing post-ops.
@@ -199,8 +200,17 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
if (std::regex_match(op_name, tanh_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
}
+ if (std::regex_match(op_name, clip_pat)) {
+ float a_min = GetNodeAttr<float>(nodes_[nid], "a_min");
+ float a_max = GetNodeAttr<float>(nodes_[nid], "a_max");
+ ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
+ }
if (std::regex_match(op_name, sigmoid_pat)) {
- ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
+ if (op_name.find("_sigmoid_mul") != std::string::npos) {
+ ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
+ } else {
+ ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
+ }
}
if (std::regex_match(op_name, gelu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
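On the runtime side the fused pattern name is the only carrier of post-op information; the regexes above translate name suffixes into DNNL eltwise post-ops. The mapping, summarized as a sketch with (algorithm, alpha, beta) as passed to append_eltwise:

    POST_OP_MAP = {
        "_relu":        ("eltwise_relu",     0.0, 0.0),
        "_tanh":        ("eltwise_tanh",     0.0, 0.0),
        "_sigmoid":     ("eltwise_logistic", 0.0, 0.0),
        "_sigmoid_mul": ("eltwise_swish",    1.0, 1.0),  # the renamed swish pattern
        "_gelu":        ("eltwise_gelu_erf", 0.0, 0.0),
        # "_clip" instead uses (a_min, a_max) read from the node attributes.
    }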
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 078483798c..6c7034741a 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -192,7 +192,6 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, te
if use_dnnl:
processed_mod = partition_for_dnnl(processed_mod, params, alter_layout)
check_dnnl_used(processed_mod)
-
with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
mode, mod=processed_mod, device=dev, target=target
@@ -237,6 +236,23 @@ def run_and_verify_func(
)
+def add_activation(activation, out, dic, param_lst):
+ if activation == "relu":
+ return relay.nn.relu(out), dic, param_lst
+ elif activation == "tanh":
+ return relay.tanh(out), dic, param_lst
+ elif activation == "sigmoid":
+ return relay.sigmoid(out), dic, param_lst
+ elif activation == "clip":
+ return relay.clip(out, 0.0, 6.0), dic, param_lst
+ elif activation == "swish":
+ sig_out = relay.sigmoid(out)
+ out = relay.multiply(out, sig_out)
+ return out, dic, param_lst
+ else:
+ return out, dic, param_lst
+
+
def get_conv1d(
x_shape=((1, 3, 224)),
k_shape=(16, 3, 3),
@@ -262,15 +278,7 @@ def get_conv1d(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]
-
- if activation == "relu":
- return relay.nn.relu(out), dic, param_lst
- elif activation == "tanh":
- return relay.tanh(out), dic, param_lst
- elif activation == "sigmoid":
- return relay.sigmoid(out), dic, param_lst
- else:
- return out, dic, param_lst
+ return add_activation(activation, out, dic, param_lst)
def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dtype="float32"):
@@ -279,15 +287,7 @@ def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dt
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[0],)
param_lst += ["bias"]
-
- if activation == "relu":
- return relay.nn.relu(out), dic, param_lst
- elif activation == "tanh":
- return relay.tanh(out), dic, param_lst
- elif activation == "sigmoid":
- return relay.sigmoid(out), dic, param_lst
- else:
- return out, dic, param_lst
+ return add_activation(activation, out, dic, param_lst)
def get_conv1d_bias_bn_relu(x_shape=(1, 3, 224), k_shape=(10, 3, 3), dtype="float32"):
@@ -334,15 +334,7 @@ def get_conv2d(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]
-
- if activation == "relu":
- return relay.nn.relu(out), dic, param_lst
- elif activation == "tanh":
- return relay.tanh(out), dic, param_lst
- elif activation == "sigmoid":
- return relay.sigmoid(out), dic, param_lst
- else:
- return out, dic, param_lst
+ return add_activation(activation, out, dic, param_lst)
def get_conv2d_transpose(
@@ -367,15 +359,7 @@ def get_conv2d_transpose(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]
-
- if activation == "relu":
- return relay.nn.relu(out), dic, param_lst
- elif activation == "tanh":
- return relay.tanh(out), dic, param_lst
- elif activation == "sigmoid":
- return relay.sigmoid(out), dic, param_lst
- else:
- return out, dic, param_lst
+ return add_activation(activation, out, dic, param_lst)
def get_conv2d_weights_const(
@@ -412,15 +396,7 @@ def get_conv2d_bias(
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[0],)
param_lst += ["bias"]
-
- if activation == "relu":
- return relay.nn.relu(out), dic, param_lst
- elif activation == "tanh":
- return relay.tanh(out), dic, param_lst
- elif activation == "sigmoid":
- return relay.sigmoid(out), dic, param_lst
- else:
- return out, dic, param_lst
+ return add_activation(activation, out, dic, param_lst)
def get_conv2d_transpose_bias(
@@ -431,15 +407,7 @@ def get_conv2d_transpose_bias(
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[1],)
param_lst += ["bias"]
-
- if activation == "relu":
- return relay.nn.relu(out), dic, param_lst
- elif activation == "tanh":
- return relay.tanh(out), dic, param_lst
- elif activation == "sigmoid":
- return relay.sigmoid(out), dic, param_lst
- else:
- return out, dic, param_lst
+ return add_activation(activation, out, dic, param_lst)
def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
@@ -503,15 +471,7 @@ def get_conv3d(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]
-
- if activation == "relu":
- return relay.nn.relu(out), dic, param_lst
- elif activation == "tanh":
- return relay.tanh(out), dic, param_lst
- elif activation == "sigmoid":
- return relay.sigmoid(out), dic, param_lst
- else:
- return out, dic, param_lst
+ return add_activation(activation, out, dic, param_lst)
def get_conv3d_transpose(
@@ -542,15 +502,7 @@ def get_conv3d_transpose(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]
-
- if activation == "relu":
- return relay.nn.relu(out), dic, param_lst
- elif activation == "tanh":
- return relay.tanh(out), dic, param_lst
- elif activation == "sigmoid":
- return relay.sigmoid(out), dic, param_lst
- else:
- return out, dic, param_lst
+ return add_activation(activation, out, dic, param_lst)
def get_conv3d_bias(
@@ -561,15 +513,7 @@ def get_conv3d_bias(
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[0],)
param_lst += ["bias"]
-
- if activation == "relu":
- return relay.nn.relu(out), dic, param_lst
- elif activation == "tanh":
- return relay.tanh(out), dic, param_lst
- elif activation == "sigmoid":
- return relay.sigmoid(out), dic, param_lst
- else:
- return out, dic, param_lst
+ return add_activation(activation, out, dic, param_lst)
def get_conv3d_transpose_bias(
@@ -580,15 +524,7 @@ def get_conv3d_transpose_bias(
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[1],)
param_lst += ["bias"]
-
- if activation == "relu":
- return relay.nn.relu(out), dic, param_lst
- elif activation == "tanh":
- return relay.tanh(out), dic, param_lst
- elif activation == "sigmoid":
- return relay.sigmoid(out), dic, param_lst
- else:
- return out, dic, param_lst
+ return add_activation(activation, out, dic, param_lst)
def gelu_helper(data):
@@ -797,7 +733,7 @@ def test_conv2d_weights_const(run_module, dtype="float32"):
def test_conv2d_pattern(run_module, dtype="float32"):
x_shape = (1, 32, 8, 8)
k_shape = (16, 32, 3, 3)
- activation_lst = [None, "relu", "tanh", "sigmoid"]
+ activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
for a in activation_lst:
conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype)
conv2d = tvm.IRModule.from_expr(conv2d)
@@ -839,7 +775,7 @@ def test_conv2d_transpose(run_module, dtype="float32"):
def test_conv2d_transpose_pattern(run_module, dtype="float32"):
- activation_lst = [None, "relu", "tanh", "sigmoid"]
+ activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
for a in activation_lst:
conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype)
conv2d = tvm.IRModule.from_expr(conv2d)
@@ -872,7 +808,7 @@ def test_conv3d(run_module, dtype="float32"):
def test_conv3d_pattern(run_module, dtype="float32"):
- activation_lst = [None, "relu", "tanh", "sigmoid"]
+ activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
for a in activation_lst:
conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype)
conv3d = tvm.IRModule.from_expr(conv3d)
@@ -905,7 +841,7 @@ def test_conv3d_transpose(run_module, dtype="float32"):
def test_conv3d_transpose_pattern(run_module, dtype="float32"):
- activation_lst = [None, "relu", "tanh", "sigmoid"]
+ activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
for a in activation_lst:
conv3d, dic, param_lst = get_conv3d_transpose(activation=a, dtype=dtype)
conv3d = tvm.IRModule.from_expr(conv3d)
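With add_activation factored out, every get_conv* builder accepts the two new activations directly; a typical test case now reads (a usage sketch based on the helpers above):

    conv2d, dic, param_lst = get_conv2d(
        x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), activation="swish", dtype="float32"
    )
    conv2d = tvm.IRModule.from_expr(conv2d)
    config = conv2d, dic, param_lst
    run_and_verify_func(config, run_module=run_module, dtype="float32")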
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 58b41189a0..4b7ac92136 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -919,6 +919,7 @@ def test_mixed_single_multiple_outputs():
def test_dnnl_fuse():
dnnl_patterns = get_pattern_table("dnnl")
+ dnnl_pat_dic = dict(dnnl_patterns)
(
conv2d_bias_relu_pat,
conv2d_bias_sigmoid_pat,
@@ -926,11 +927,26 @@ def test_dnnl_fuse():
conv2d_relu_pat,
conv2d_sigmoid_pat,
) = (
- dnnl_patterns[3],
- dnnl_patterns[15],
- dnnl_patterns[22],
- dnnl_patterns[28],
- dnnl_patterns[40],
+ (
+ "dnnl.conv2d_bias_relu",
+ dnnl_pat_dic["dnnl.conv2d_bias_relu"],
+ ),
+ (
+ "dnnl.conv2d_bias_sigmoid",
+ dnnl_pat_dic["dnnl.conv2d_bias_sigmoid"],
+ ),
+ (
+ "dnnl.conv2d_bias",
+ dnnl_pat_dic["dnnl.conv2d_bias"],
+ ),
+ (
+ "dnnl.conv2d_relu",
+ dnnl_pat_dic["dnnl.conv2d_relu"],
+ ),
+ (
+ "dnnl.conv2d_sigmoid",
+ dnnl_pat_dic["dnnl.conv2d_sigmoid"],
+ ),
)
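Indexing dnnl_patterns by position broke whenever new patterns shifted the list, as the clip and swish variants do here; the name-keyed dict makes the lookups order-independent. The same idiom works anywhere a pattern is needed by name (sketch):

    from tvm.relay.op.contrib.register import get_pattern_table

    dnnl_pat_dic = dict(get_pattern_table("dnnl"))
    conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", dnnl_pat_dic["dnnl.conv2d_bias_relu"])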
def get_blocks(