Posted to commits@tvm.apache.org by ma...@apache.org on 2022/08/01 01:32:57 UTC
[tvm] branch main updated: [BYOC-DNNL] add post_sum pattern (#12151)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c07d77f99c [BYOC-DNNL] add post_sum pattern (#12151)
c07d77f99c is described below
commit c07d77f99c024b9e2c162b574482dbbbd71d4680
Author: Ivy Zhang <ya...@intel.com>
AuthorDate: Mon Aug 1 09:32:52 2022 +0800
[BYOC-DNNL] add post_sum pattern (#12151)
* add post_sum pattern
* add checkers for sum pattern
* fix lint
* fix error in test_pass_partition_graph
* fix lint error
---
 python/tvm/relay/op/contrib/dnnl.py             | 106 +++++++++++++++++++++++-
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc   |  14 +++-
 tests/python/contrib/test_dnnl.py               |  42 ++++++++++
 tests/python/relay/test_pass_partition_graph.py |   6 +-
 4 files changed, 164 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index fa98ed002c..f17b325dce 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -33,8 +33,10 @@ it is supported. For example:
check the attributes of the op and decide if it should be offloaded to DNNL.
"""
import logging
+from functools import reduce
import tvm.ir
+from tvm.ir import Op
from tvm import relay
from tvm.relay import transform
from tvm.relay.expr import GlobalVar
@@ -44,7 +46,7 @@ from tvm.relay.expr import const
from tvm.relay.analysis import analysis as _analysis
from tvm.relay import expr as _expr
-
+from tvm.relay.expr import Call, TupleGetItem
from ... import _ffi_api
from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
from .register import register_pattern_table
@@ -167,6 +169,94 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
     return append_eltwise_ops(conv_out, with_eltwise)
+def make_conv_bias_sum_relu_pattern(conv_type, has_relu=True):
+    """Create patterns with sum op.
+
+    Parameters
+    ----------
+    conv_type : str
+        Should be nn.conv1d / nn.conv2d / nn.conv3d.
+    has_relu : bool
+        Whether to attach relu.
+    Returns
+    -------
+    out : CallPattern
+        Call node sequence.
+    """
+    data1 = wildcard()
+    weight = wildcard()
+    bias = wildcard()
+    data2 = wildcard()
+    out = is_op(conv_type)(data1, weight)
+    out = is_op("add")(out, bias)
+    out = is_op("add")(out, data2)
+    if has_relu:
+        out = is_op("nn.relu")(out)
+    return out
+
+
+def get_op_name(expr):
+    """Get the operator name from an expression."""
+    if isinstance(expr, Op):
+        return expr.name
+    if isinstance(expr, Call):
+        return get_op_name(expr.op)
+    if isinstance(expr, TupleGetItem):
+        return get_op_name(expr.tuple_value)
+    if isinstance(expr, relay.Tuple):
+        return get_op_name(expr.fields[0])
+    return ""
+
+
+def get_args(expr):
+    """Get the arguments from an expression."""
+    if isinstance(expr, Call):
+        return expr.args
+    if isinstance(expr, TupleGetItem):
+        return get_args(expr.tuple_value)
+    if isinstance(expr, relay.Tuple):
+        return [arg for args in map(get_args, expr.fields) for arg in args]
+    return []
+
+
+def get_attrs(expr):
+    """Get the attributes from an expression."""
+    if isinstance(expr, Call):
+        return expr.attrs
+    if isinstance(expr, TupleGetItem):
+        return get_attrs(expr.tuple_value)
+    return {}
+
+
+def make_predicate(checker):
+    """Check whether the conv_bias_add_sum pattern is as expected."""
+
+    def predicate(expr):
+        if get_op_name(expr) == "nn.relu":
+            expr = expr.args[0]
+        for e, op_name in zip([expr, expr.args[0]], ["sum", "bias_add"]):
+            args = get_args(e)
+            attrs = get_attrs(e.args[0])
+            if not checker(attrs, args, op_name):
+                return False
+        return True
+
+    return predicate
+
+
+def add_checker(attrs, args, op_name):
+    """Check if add is supported by DNNL."""
+    if op_name == "sum":
+        if tuple(get_shape(args[0])) != tuple(get_shape(args[1])):
+            return False
+    if op_name == "bias_add":
+        channel = dict(attrs)["channels"]
+        const_shape = get_shape(args[1])
+        if channel != reduce(lambda x, y: x * y, const_shape):
+            return False
+    return True
+
+
 def make_dense_pattern(with_bias=True, with_eltwise=None):
     """Create patterns related to nn.dense.
@@ -306,6 +396,20 @@ def pattern_table():
     dnnl_patterns = list()
     dnnl_patterns.append(make_qnn_conv2d_pattern())
     dnnl_patterns.append(make_qnn_dense_pattern())
+    dnnl_patterns.append(
+        (
+            "dnnl.conv2d_bias_sum_relu",
+            make_conv_bias_sum_relu_pattern("nn.conv2d"),
+            make_predicate(add_checker),
+        )
+    )
+    dnnl_patterns.append(
+        (
+            "dnnl.conv2d_bias_sum",
+            make_conv_bias_sum_relu_pattern("nn.conv2d", False),
+            make_predicate(add_checker),
+        )
+    )
     elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
     for with_bias in [True, False]:
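
To see what the new pattern actually matches, the following standalone sketch rebuilds the same structure as make_conv_bias_sum_relu_pattern("nn.conv2d") and checks it against a small conv2d + bias + sum + relu graph. The shapes and variable names here are illustrative only, not part of this commit:

import numpy as np
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

# Same structure as make_conv_bias_sum_relu_pattern("nn.conv2d") above.
pat = is_op("nn.conv2d")(wildcard(), wildcard())
pat = is_op("add")(pat, wildcard())  # bias add
pat = is_op("add")(pat, wildcard())  # sum with the extra input
pat = is_op("nn.relu")(pat)

# A tiny graph of the matching shape: conv2d -> add -> add -> relu.
x = relay.var("x", shape=(1, 32, 8, 8))
w = relay.var("w", shape=(16, 32, 3, 3))
bias = relay.const(np.zeros((16, 1, 1), dtype="float32"))
sum_in = relay.var("sum_in", shape=(1, 16, 6, 6))
out = relay.nn.conv2d(x, w, kernel_size=(3, 3))
out = relay.add(out, bias)
out = relay.add(out, sum_in)
out = relay.nn.relu(out)

assert pat.match(out)  # the pattern recognizes this sequence

The predicate built by make_predicate(add_checker) then rejects matches whose sum input shape differs from the conv output shape, since the sum post-op accumulates element-wise and needs identical shapes.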
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index dcf1a86785..1fe8fccc77 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -182,8 +182,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr;
-    // parsing of name to extract attributes
-    auto op_name = nodes_[nid].GetOpName();
     // Define RegExp.
     std::regex bias_add_pat(".*_bias.*");
     std::regex relu_pat(".*_relu.*");
@@ -192,9 +190,16 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::regex clip_pat(".*_clip.*");
     std::regex gelu_pat(".*_gelu.*");
     std::regex swish_pat(".*_swish.*");
+    std::regex sum_pat(".*_sum.*");
+
+    // parsing of name to extract attributes
+    auto op_name = nodes_[nid].GetOpName();
     // Parsing post-ops.
     dnnl::post_ops ops;
+    if (std::regex_match(op_name, sum_pat)) {
+      ops.append_sum(1.f);
+    }
     if (std::regex_match(op_name, relu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
     }
@@ -280,6 +285,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
   void Convolution(const size_t& nid) {
     auto node = nodes_[nid];
+    auto op_name = nodes_[nid].GetOpName();
     // Setup attributes.
     auto src_tr = GetInput(nid, 0);
@@ -361,6 +367,10 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // TODO(@apeskov): Simulation of inplace primitive. just as PoC.
     auto sum_in_tr = GetInputByName(nid, "sum_idx").TreatAs(dst_layout);
+    if (op_name.find("_sum") != std::string::npos) {
+      sum_in_tr = GetInput(nid, node.GetInputs().size() - 1);
+      sum_in_tr = sum_in_tr.TreatAs(dst_layout);
+    }
     Submit(dnnl::convolution_forward(conv_prim_desc),
            {{DNNL_ARG_SRC, src_tr},
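
Note that the runtime recovers post-ops from the composite function's name rather than from node attributes, which is why the pattern names above end in _sum / _sum_relu. The Python snippet below only illustrates that naming convention; post_ops_from_name is a hypothetical helper mirroring the C++ regexes, not part of the runtime:

import re

def post_ops_from_name(op_name):
    """Sketch of the dnnl_json_runtime.cc logic: derive post-ops from the name."""
    ops = []
    if re.match(".*_sum.*", op_name):
        ops.append("sum")   # ops.append_sum(1.f) in the C++ code
    if re.match(".*_relu.*", op_name):
        ops.append("relu")  # ops.append_eltwise(1.f, eltwise_relu, 0.f, 0.f)
    return ops

print(post_ops_from_name("dnnl.conv2d_bias_sum_relu"))  # ['sum', 'relu']

The order matters: sum is appended before relu, so oneDNN accumulates into the destination first and applies the activation afterwards.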
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index e744cab6e9..74d0da1238 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -788,6 +788,48 @@ def test_conv2d_pattern(run_module, dtype="float32"):
     run_and_verify_func(config, run_module=run_module, dtype=dtype)
+def test_conv2d_bias_sum_relu(run_module, dtype="float32"):
+    x_shape = (1, 32, 8, 8)
+    k_shape = (16, 32, 3, 3)
+
+    def get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape, dtype="float32"):
+        out, dic, param_lst = get_conv2d_bias(x_shape=x_shape, k_shape=k_shape, dtype=dtype)
+        beta = relay.const(np.zeros(k_shape[0]).astype(dtype))
+        gamma = relay.const(np.ones(k_shape[0]).astype(dtype))
+        moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype))
+        moving_var = relay.const(np.ones(k_shape[0]).astype(dtype))
+        out, _, _ = relay.nn.batch_norm(
+            out,
+            gamma=gamma,
+            beta=beta,
+            moving_mean=moving_mean,
+            moving_var=moving_var,
+            axis=1,
+            center=True,
+            scale=True,
+            epsilon=1e-5,
+        )
+        sum_data = relay.var("data1", shape=sum_shape, dtype=dtype)
+        out = relay.add(out, sum_data)
+        dic["data1"] = sum_shape
+        param_lst += ["data1"]
+        return relay.nn.relu(out), dic, param_lst
+
+    conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
+        x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype
+    )
+    conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
+    config = conv2d_bn_sum_relu, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+    conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
+        x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype
+    )
+    conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
+    config = conv2d_bn_sum_relu, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
 def test_conv2d_transpose(run_module, dtype="float32"):
     x_shape = (1, 32, 8, 8)
     for k_shape, groups in [((32, 16, 3, 3), 1), ((32, 1, 3, 3), 32), ((32, 4, 3, 3), 16)]:
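
run_and_verify_func drives the full compile-and-compare flow. To inspect just the partitioning, one could run the BYOC pipeline by hand; a minimal sketch, assuming the IRModule built by the test above is in scope:

from tvm.relay.op.contrib.dnnl import partition_for_dnnl

mod = partition_for_dnnl(conv2d_bn_sum_relu)  # IRModule from the test above
# The partitioned module should contain a function annotated with
# Composite="dnnl.conv2d_bias_sum_relu".
print(mod.astext(show_meta_data=False))

The second call with sum_shape=(1, 16, 1, 1) exercises the predicate path: the sum input no longer matches the conv output shape, so add_checker rejects the fusion and the add stays outside the composite.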
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 31aa4e4fe2..e796f60a9c 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -919,7 +919,11 @@ def test_mixed_single_multiple_outputs():
 def test_dnnl_fuse():
     dnnl_patterns = get_pattern_table("dnnl")
-    dnnl_pat_dic = dict(dnnl_patterns)
+    valid_pats = list()
+    for pattern in dnnl_patterns:
+        if len(pattern) == 2:
+            valid_pats.append(pattern)
+    dnnl_pat_dic = dict(valid_pats)
     (
         conv2d_bias_relu_pat,
         conv2d_bias_sigmoid_pat,
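
The filter above is needed because the DNNL pattern table now mixes two entry arities: (name, pattern) and, for the new sum patterns, (name, pattern, predicate). Python's dict() only accepts two-element items, so the three-element entries must be skipped. A toy illustration with stand-in values:

entries = [
    ("dnnl.conv2d_relu", "pattern"),
    ("dnnl.conv2d_bias_sum_relu", "pattern", "predicate"),  # new 3-tuple form
]
try:
    dict(entries)
except ValueError as err:
    print(err)  # dictionary update sequence element #1 has length 3; 2 is required
dnnl_pat_dic = dict(e for e in entries if len(e) == 2)  # mirrors the fix above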