Posted to commits@tvm.apache.org by ma...@apache.org on 2021/12/29 07:34:17 UTC
[tvm] branch main updated: DNNL-BYOC enhancement (#9797)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 75cd670 DNNL-BYOC enhancement (#9797)
75cd670 is described below
commit 75cd670ae40451b8dc86536bcd05c5666bc3e954
Author: Ivy Zhang <ya...@intel.com>
AuthorDate: Wed Dec 29 15:33:55 2021 +0800
DNNL-BYOC enhancement (#9797)
* add unit test for byoc-dnnl
* add byoc-dnnl pattern and their test cases
---
python/tvm/relay/op/contrib/dnnl.py | 140 +++++++++-
src/relay/backend/contrib/dnnl/codegen.cc | 18 ++
src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 99 ++++---
tests/python/contrib/test_dnnl.py | 350 ++++++++++++++++++++++++
tests/python/relay/test_pass_partition_graph.py | 67 ++++-
5 files changed, 622 insertions(+), 52 deletions(-)
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index a2fdc19..05b5880 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -32,10 +32,17 @@ it is supported. For example:
- The other way is to implement the function by themselves to
check the attributes of the op and decide if it should be offloaded to DNNL.
"""
+import logging
+
import tvm.ir
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+
from ...dataflow_pattern import wildcard, is_op
from .register import register_pattern_table
+logger = logging.getLogger("DNNL")
+
def _register_external_op_helper(op_name, supported=True):
"""The helper function to indicate that a given operator can be supported
@@ -63,11 +70,26 @@ _register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu")
+_register_external_op_helper("tanh")
+_register_external_op_helper("sigmoid")
_register_external_op_helper("add")
_register_external_op_helper("multiply")
-def make_pattern(with_bias=True):
+def make_conv_pattern(with_bias=True, with_eltwise=None):
+ """Create patterns related to nn.conv2d.
+
+ Parameters
+ ----------
+ with_bias : bool
+ Whether to attach `bias_add` to `nn.conv2d`.
+ with_eltwise : str
+ The attached elementwise post-op name.
+ Returns
+ -------
+ conv_out : CallPattern
+ Call node sequence.
+ """
data = wildcard()
weight = wildcard()
bias = wildcard()
@@ -76,12 +98,120 @@ def make_pattern(with_bias=True):
conv_out = is_op("add")(conv, bias)
else:
conv_out = conv
- return is_op("nn.relu")(conv_out)
+ if with_eltwise:
+ return is_op(with_eltwise)(conv_out)
+ return conv_out
+
+
+def make_dense_pattern(with_bias=True, with_eltwise=None):
+ """Create patterns related to nn.dense.
+
+ Parameters
+ ----------
+ with_bias : bool
+ Whether to attach `bias_add` to `nn.dense`.
+ with_eltwise : str
+ The attached elementwise post-op name.
+ Returns
+ -------
+ dense_out : CallPattern
+ Call node sequence.
+ """
+ data = wildcard()
+ weight = wildcard()
+ bias = wildcard()
+ dense = is_op("nn.dense")(data, weight)
+ if with_bias:
+ dense_out = is_op("add")(dense, bias)
+ else:
+ dense_out = dense
+ if with_eltwise:
+ dense_out = is_op(with_eltwise)(dense_out)
+ return dense_out
+
+
+def make_dnnl_pattern(op, with_bias, with_eltwise):
+ """Create dnnl patterns.
+
+ Parameters
+ ----------
+ op : str
+ The first call node's op name.
+ with_bias : bool
+ Whether to attach `bias_add` to the chosen op.
+ with_eltwise : str
+ The attached elementwise post-op name.
+ Returns
+ -------
+ pattern : Tuple(pattern_name, CallPattern)
+ Created pattern name, along with its CallPattern.
+ """
+ pat_name = "dnnl." + op
+ pat_name += "_bias" if with_bias else ""
+ pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
+ if op == "conv2d":
+ dnnl_pattern = (pat_name, make_conv_pattern(with_bias, with_eltwise))
+ elif op == "dense":
+ dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
+ else:
+ logger.warning("Currently, only conv2d and dense ops are supported, but got %s.", op)
+ dnnl_pattern = ()
+ return dnnl_pattern
@register_pattern_table("dnnl")
def pattern_table():
- conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
- conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
- dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
+ """Create dnnl patterns.
+
+ Returns
+ -------
+ dnnl_patterns : List[dnnl_pattern]
+ Created patterns.
+ """
+ elt_list = ["nn.relu", "tanh", "sigmoid", None]
+ dnnl_patterns = []
+ for with_bias in [True, False]:
+ for elt in elt_list:
+ if not with_bias and not elt:
+ return dnnl_patterns
+ dnnl_patterns.append(make_dnnl_pattern("conv2d", with_bias, elt))
+ dnnl_patterns.append(make_dnnl_pattern("dense", with_bias, elt))
return dnnl_patterns
+
+
+def partition_for_dnnl(mod, params=None):
+ """Partition the graph greedily offloading supported operators to DNNL.
+
+ Parameters
+ ----------
+ mod : Module
+ The module to run passes on.
+ params : Optional[Dict[str, NDArray]]
+ Constant input parameters.
+ Returns
+ -------
+ mod : Module
+ Annotated and partitioned module.
+ """
+
+ if params:
+ mod["main"] = bind_params_by_name(mod["main"], params)
+ seq = tvm.transform.Sequential(
+ [
+ transform.CanonicalizeOps(),
+ transform.InferType(),
+ transform.SimplifyInference(),
+ transform.FoldConstant(),
+ transform.FoldScaleAxis(),
+ # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
+ transform.SimplifyExpr(),
+ transform.FoldConstant(),
+ transform.MergeComposite(pattern_table()),
+ transform.AnnotateTarget("dnnl"),
+ transform.MergeCompilerRegions(),
+ transform.PartitionGraph(),
+ ]
+ )
+ with tvm.transform.PassContext(opt_level=3):
+ mod = seq(mod)
+ return mod
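For reference, a minimal usage sketch of the new partition_for_dnnl entry point (not part of the patch; the shapes and variable names are illustrative). Partitioning itself only needs the Python-side pattern table; executing the resulting module additionally requires TVM to be built with the DNNL codegen.

    # Sketch: offload a conv2d + bias_add + relu chain through the new entry point.
    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import dnnl

    x = relay.var("x", shape=(1, 32, 8, 8), dtype="float32")
    w = relay.var("w", shape=(16, 32, 3, 3), dtype="float32")
    b = relay.var("b", shape=(16,), dtype="float32")
    y = relay.nn.relu(
        relay.nn.bias_add(relay.nn.conv2d(x, w, channels=16, kernel_size=(3, 3)), b)
    )
    mod = tvm.IRModule.from_expr(relay.Function([x, w, b], y))

    params = {
        "w": np.random.rand(16, 32, 3, 3).astype("float32"),
        "b": np.random.rand(16).astype("float32"),
    }
    mod = dnnl.partition_for_dnnl(mod, params)
    # The conv2d-bias_add-relu chain should now sit inside a "dnnl" composite function.
    print(mod)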
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index fa1dbc6..b1b2f58 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -455,9 +455,27 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
if (name == "dnnl.conv2d_bias_relu") {
call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
+ } else if (name == "dnnl.conv2d_bias_tanh") {
+ call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "tanh"});
+ ICHECK(call->op.as<OpNode>()) << "Not op node";
+ } else if (name == "dnnl.conv2d_bias_sigmoid") {
+ call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "sigmoid"});
+ ICHECK(call->op.as<OpNode>()) << "Not op node";
+ } else if (name == "dnnl.conv2d_bias") {
+ call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "add"});
+ ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name == "dnnl.conv2d_relu") {
call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
ICHECK(call->op.as<OpNode>()) << "Not op node";
+ } else if (name == "dnnl.conv2d_tanh") {
+ call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "tanh"});
+ ICHECK(call->op.as<OpNode>()) << "Not op node";
+ } else if (name == "dnnl.conv2d_sigmoid") {
+ call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "sigmoid"});
+ ICHECK(call->op.as<OpNode>()) << "Not op node";
+ } else if (name == "dnnl.dense_bias") {
+ call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.dense", "add"});
+ ICHECK(call->op.as<OpNode>()) << "Not op node";
} else {
LOG(FATAL) << "Unrecognized DNNL pattern: " << name;
}
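The serializer above dispatches on the composite name attached by MergeComposite. A small sketch (not part of the patch) that replays the naming rule from make_dnnl_pattern shows every name that can be produced; note that pattern_table() itself stops before the bias-less, post-op-less combination, so the bare dnnl.conv2d and dnnl.dense names are never registered.

    # Sketch: enumerate the composite names generated by make_dnnl_pattern's naming rule.
    for op in ["conv2d", "dense"]:
        for with_bias in [True, False]:
            for elt in ["nn.relu", "tanh", "sigmoid", None]:
                name = "dnnl." + op
                name += "_bias" if with_bias else ""
                name += ("_" + elt.split(".")[-1]) if elt else ""
                print(name)  # e.g. dnnl.conv2d_bias_relu, dnnl.dense_bias, dnnl.conv2d_sigmoid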
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index b32d137..f9f1961 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -103,15 +103,31 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
if ("nn.conv2d" == op_name) {
Conv2d(nid);
} else if ("dnnl.conv2d_relu" == op_name) {
- Conv2d(nid, true, false);
+ Conv2d(nid, true, false, dnnl::algorithm::eltwise_relu);
+ } else if ("dnnl.conv2d_tanh" == op_name) {
+ Conv2d(nid, true, false, dnnl::algorithm::eltwise_tanh);
+ } else if ("dnnl.conv2d_sigmoid" == op_name) {
+ Conv2d(nid, true, false, dnnl::algorithm::eltwise_logistic);
+ } else if ("dnnl.conv2d_bias" == op_name) {
+ Conv2d(nid, false, true);
} else if ("dnnl.conv2d_bias_relu" == op_name) {
- Conv2d(nid, true, true);
+ Conv2d(nid, true, true, dnnl::algorithm::eltwise_relu);
+ } else if ("dnnl.conv2d_bias_tanh" == op_name) {
+ Conv2d(nid, true, true, dnnl::algorithm::eltwise_tanh);
+ } else if ("dnnl.conv2d_bias_sigmoid" == op_name) {
+ Conv2d(nid, true, true, dnnl::algorithm::eltwise_logistic);
} else if ("nn.dense" == op_name) {
Dense(nid);
+ } else if ("dnnl.dense_bias" == op_name) {
+ Dense(nid, true);
} else if ("nn.batch_norm" == op_name) {
BatchNorm(nid);
} else if ("nn.relu" == op_name) {
- Relu(nid);
+ Eltwise(nid, dnnl::algorithm::eltwise_relu);
+ } else if ("tanh" == op_name) {
+ Eltwise(nid, dnnl::algorithm::eltwise_tanh);
+ } else if ("sigmoid" == op_name) {
+ Eltwise(nid, dnnl::algorithm::eltwise_logistic);
} else if ("add" == op_name) {
Binary(nid, dnnl::algorithm::binary_add);
} else if ("multiply" == op_name) {
@@ -150,7 +166,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
return entry_out_mem_[eid].first;
}
- void Conv2d(const size_t& nid, const bool has_relu = false, const bool has_bias = false) {
+ void Conv2d(const size_t& nid, const bool has_elt = false, const bool has_bias = false,
+ dnnl::algorithm algo = dnnl::algorithm::eltwise_relu) {
auto node = nodes_[nid];
// Setup attributes.
@@ -159,24 +176,29 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
std::vector<std::string> str_strides = node.GetAttr<std::vector<std::string>>("strides");
+ std::vector<std::string> str_dilates = node.GetAttr<std::vector<std::string>>("dilation");
std::vector<std::string> str_padding = node.GetAttr<std::vector<std::string>>("padding");
dnnl::memory::dim groups = std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
- dnnl::memory::dim N = input_shape[0], // batch size
- IC = input_shape[1], // input channels
- IH = input_shape[2], // input height
- IW = input_shape[3], // input width
- OC = weight_shape[0], // output channels
- KH = weight_shape[2], // weight height
- KW = weight_shape[3], // weight width
- PW_L = std::stoi(str_padding[1]), // width padding: left
- PW_R = std::stoi(str_padding[3]), // width padding: right
- PH_L = std::stoi(str_padding[0]), // height padding: top
- PH_R = std::stoi(str_padding[2]), // height padding: bottom
- SH = std::stoi(str_strides[0]), // height-wise stride
- SW = std::stoi(str_strides[1]), // weight-wise stride
- OH = (IH - KH + PH_L + PH_R) / SH + 1, // output height
- OW = (IW - KW + PW_L + PW_R) / SW + 1; // output width
+ dnnl::memory::dim N = input_shape[0], // batch size
+ IC = input_shape[1], // input channels
+ IH = input_shape[2], // input height
+ IW = input_shape[3], // input width
+ OC = weight_shape[0], // output channels
+ KH = weight_shape[2], // weight height
+ KW = weight_shape[3], // weight width
+ PW_L = std::stoi(str_padding[1]), // width padding: left
+ PW_R = std::stoi(str_padding[3]), // width padding: right
+ PH_L = std::stoi(str_padding[0]), // height padding: top
+ PH_R = std::stoi(str_padding[2]), // height padding: bottom
+ SH = std::stoi(str_strides[0]), // height-wise stride
+ SW = std::stoi(str_strides[1]), // width-wise stride
+ DH = std::stoi(str_dilates[0]) - 1, // height-wise dilation
+ DW = std::stoi(str_dilates[1]) - 1, // width-wise dilation
+ DKH = 1 + (KH - 1) * (DH + 1), // dilated weight height
+ DKW = 1 + (KW - 1) * (DW + 1), // dilated weight width
+ OH = (IH - DKH + PH_L + PH_R) / SH + 1, // output height
+ OW = (IW - DKW + PW_L + PW_R) / SW + 1; // output width
// Memory shapes.
dnnl::memory::dims src_dims = {N, IC, IH, IW};
@@ -187,6 +209,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
dnnl::memory::dims bias_dims = {OC};
dnnl::memory::dims dst_dims = {N, OC, OH, OW};
dnnl::memory::dims strides_dims = {SH, SW};
+ dnnl::memory::dims dilates_dims = {DH, DW};
dnnl::memory::dims padding_dims_l = {PH_L, PW_L};
dnnl::memory::dims padding_dims_r = {PH_R, PW_R};
@@ -199,13 +222,14 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
// Conv2d description.
auto conv_desc = dnnl::convolution_forward::desc(
dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md,
- conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, padding_dims_l, padding_dims_r);
+ conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, dilates_dims, padding_dims_l,
+ padding_dims_r);
- // Enable ReLU
+ // Enable elementwise post-ops
dnnl::primitive_attr attr;
- if (has_relu) {
+ if (has_elt) {
dnnl::post_ops ops;
- ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
+ ops.append_eltwise(1.f, algo, 0.f, 0.f);
attr.set_post_ops(ops);
}
@@ -245,7 +269,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{DNNL_ARG_DST, conv2d_dst_memory}});
}
- void Dense(const size_t& nid) {
+ void Dense(const size_t& nid, const bool has_bias = false) {
auto node = nodes_[nid];
// Setup attributes.
@@ -281,9 +305,18 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
// Memories.
auto data_memory = BindDNNLMemory(data_entry, data_md);
auto weight_memory = BindDNNLMemory(weight_entry, weight_md);
+
+ // Bias memory.
auto bias_memory = dnnl::memory(bias_md, engine_);
- float bias[OC] = {0};
- write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
+ if (has_bias) {
+ auto bias_entry = node.GetInputs()[2];
+ BindDNNLMemory(bias_entry, bias_memory);
+ } else {
+ float bias[OC] = {0};
+ write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
+ }
+
+ // Output memory.
JSONGraphNodeEntry out_entry(nid, 0);
auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc());
@@ -335,20 +368,20 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{DNNL_ARG_VARIANCE, variance_memory}});
}
- void Relu(const size_t& nid) {
+ void Eltwise(const size_t& nid, dnnl::algorithm algo) {
auto node = nodes_[nid];
auto data_entry = node.GetInputs()[0];
dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32);
- auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference,
- dnnl::algorithm::eltwise_relu, data_md, 0);
- auto relu_prim_desc = dnnl::eltwise_forward::primitive_desc(relu_desc, engine_);
- ICHECK(data_md == relu_prim_desc.dst_desc());
+ auto elt_desc =
+ dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, 0);
+ auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_);
+ ICHECK(data_md == elt_prim_desc.dst_desc());
- auto relu = dnnl::eltwise_forward(relu_prim_desc);
- net_.push_back(relu);
+ auto elt = dnnl::eltwise_forward(elt_prim_desc);
+ net_.push_back(elt);
auto data_memory = BindDNNLMemory(data_entry, data_md);
JSONGraphNodeEntry out_entry(nid, 0);
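The dilated-kernel and output-size arithmetic introduced above can be checked with plain numbers; a short sketch (not part of the patch, using an illustrative 8x8 input and 3x3 kernel with Relay dilation 2):

    # DNNL stores dilation - 1, so Relay dilation (2, 2) becomes DH = DW = 1.
    IH, IW = 8, 8                  # input height/width
    KH, KW = 3, 3                  # kernel height/width
    SH, SW = 1, 1                  # strides
    PH_L = PH_R = PW_L = PW_R = 1  # padding
    DH, DW = 2 - 1, 2 - 1          # DNNL dilate fields
    DKH = 1 + (KH - 1) * (DH + 1)  # dilated kernel height = 5
    DKW = 1 + (KW - 1) * (DW + 1)  # dilated kernel width = 5
    OH = (IH - DKH + PH_L + PH_R) // SH + 1  # output height = 6
    OW = (IW - DKW + PW_L + PW_R) // SW + 1  # output width = 6
    print(OH, OW)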
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
new file mode 100755
index 0000000..7adf3e4
--- /dev/null
+++ b/tests/python/contrib/test_dnnl.py
@@ -0,0 +1,350 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+import itertools
+import tvm
+import tvm.relay.testing
+from tvm import relay
+from tvm.relay.op.contrib import dnnl
+import tvm.testing
+
+has_dnnl_codegen = pytest.mark.skipif(
+ not tvm.get_global_func("relay.ext.dnnl", True), reason="DNNL codegen not available"
+)
+
+run_module = tvm.testing.parameter(
+ pytest.param(False, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm()]),
+ pytest.param(True, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm()]),
+ ids=["compile", "run"],
+)
+
+
+def vmobj_to_list(o):
+ if isinstance(o, tvm.nd.NDArray):
+ return [o.numpy()]
+ elif isinstance(o, tvm.runtime.container.ADT) or isinstance(o, list):
+ return [vmobj_to_list(f) for f in o]
+ else:
+ raise RuntimeError("Unknown object type: %s" % type(o))
+
+
+def assert_result_dict_holds(result_dict):
+ for k1, k2 in itertools.combinations(result_dict, 2):
+ res1 = vmobj_to_list(result_dict[k1])
+ res2 = vmobj_to_list(result_dict[k2])
+ for r1, r2 in zip(res1, res2):
+ tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)
+
+
+def run_and_verify(mod, input, params, target, run_module):
+ def check_dnnl_used(mod):
+ num_dnnl_subgraphs = sum(
+ [1 if "dnnl" in gv.name_hint else 0 for gv in mod.get_global_vars()]
+ )
+ assert num_dnnl_subgraphs >= 1
+
+ dev = tvm.cpu()
+ result_dict = dict()
+ for mode in ["graph", "vm"]:
+ for use_dnnl in [False, True]:
+ result_key = mode + ("_dnnl" if use_dnnl else "")
+ if use_dnnl:
+ mod = dnnl.partition_for_dnnl(mod, params)
+ with tvm.transform.PassContext(opt_level=3):
+ func = relay.create_executor(mode, mod=mod, device=dev, target=target).evaluate()
+ if run_module:
+ if isinstance(input, dict):
+ result_dict[result_key] = func(**input, **params)
+ else:
+ result_dict[result_key] = func(input, **params)
+
+ if run_module:
+ assert_result_dict_holds(result_dict)
+
+
+def run_and_verify_func(config, run_module, target="llvm", dtype="float32"):
+ """Test a Relay func by compiling, running, and comparing TVM and DNNL outputs.
+
+ Parameters
+ ----------
+ config : Tuple[relay.Function, Dict[str, NDArray], List[str]]
+ A tuple containing 1) The function to test, 2) A dictionary of var names to input shapes and
+ 3) A list of which vars should be considered params.
+
+ run_module: bool
+ If True, the built module will be run after being compiled.
+ """
+ f, input_shapes, is_param = config
+ params = {x: np.random.uniform(-1, 1, input_shapes[x]).astype(dtype) for x in is_param}
+ input_dict = {
+ k: np.random.uniform(-1, 1, v).astype(dtype)
+ for k, v in input_shapes.items()
+ if k not in is_param
+ }
+ run_and_verify(f, input_dict, params, target, run_module)
+
+
+def get_conv2d(
+ x_shape=(1, 32, 8, 8),
+ k_shape=(16, 32, 3, 3),
+ groups=1,
+ padding=(0, 0),
+ strides=(1, 1),
+ dilation=(1, 1),
+ activation=None,
+ dtype="float32",
+):
+ x = relay.var("x", shape=(x_shape), dtype=dtype)
+ kernel = relay.var("kernel", shape=(k_shape), dtype=dtype)
+ out = relay.nn.conv2d(
+ x,
+ kernel,
+ kernel_size=k_shape[2:4],
+ groups=groups,
+ padding=padding,
+ strides=strides,
+ dilation=dilation,
+ channels=k_shape[0],
+ )
+ dic = {"x": x_shape, "kernel": k_shape}
+ param_lst = ["kernel"]
+
+ if activation == "relu":
+ return relay.nn.relu(out), dic, param_lst
+ elif activation == "tanh":
+ return relay.tanh(out), dic, param_lst
+ elif activation == "sigmoid":
+ return relay.sigmoid(out), dic, param_lst
+ else:
+ return out, dic, param_lst
+
+
+def get_conv2d_weights_const(
+ x_shape=(1, 32, 8, 8),
+ k_shape=(16, 32, 3, 3),
+ groups=1,
+ padding=(0, 0),
+ strides=(1, 1),
+ dilation=(1, 1),
+ dtype="float32",
+):
+ x = relay.var("x", shape=(x_shape), dtype=dtype)
+ kernel = relay.const(np.ones(k_shape).astype(dtype))
+ out = relay.nn.conv2d(
+ x,
+ kernel,
+ channels=k_shape[0],
+ kernel_size=k_shape[2:4],
+ groups=groups,
+ padding=padding,
+ strides=strides,
+ dilation=dilation,
+ )
+ dic = {"x": x_shape}
+ param_lst = []
+ return out, dic, param_lst
+
+
+def get_conv2d_bias(
+ x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), activation=None, dtype="float32"
+):
+ conv, dic, param_lst = get_conv2d(x_shape=x_shape, k_shape=k_shape, dtype=dtype)
+ bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype)
+ out = relay.nn.bias_add(conv, bias)
+ dic["bias"] = (k_shape[0],)
+ param_lst += ["bias"]
+
+ if activation == "relu":
+ return relay.nn.relu(out), dic, param_lst
+ elif activation == "tanh":
+ return relay.tanh(out), dic, param_lst
+ elif activation == "sigmoid":
+ return relay.sigmoid(out), dic, param_lst
+ else:
+ return out, dic, param_lst
+
+
+def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
+ conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype)
+ beta = relay.const(np.zeros(k_shape[0]).astype(dtype))
+ gamma = relay.const(np.ones(k_shape[0]).astype(dtype))
+ moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype))
+ moving_var = relay.const(np.ones(k_shape[0]).astype(dtype))
+ conv2d_bias_bn, _, _ = relay.nn.batch_norm(
+ conv2d_bias,
+ gamma=gamma,
+ beta=beta,
+ moving_mean=moving_mean,
+ moving_var=moving_var,
+ axis=1,
+ center=True,
+ scale=True,
+ epsilon=1e-5,
+ )
+ return relay.nn.relu(conv2d_bias_bn), dic, param_lst
+
+
+def get_dense(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"):
+ x = relay.var("x", shape=(x_shape), dtype=dtype)
+ kernel = relay.var("kernel", shape=(k_shape), dtype=dtype)
+ out = relay.nn.dense(x, kernel, units=k_shape[0])
+ dic = {"x": x_shape, "kernel": k_shape}
+ param_lst = ["kernel"]
+ return out, dic, param_lst
+
+
+def get_dense_bias(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"):
+ dense, dic, param_lst = get_dense(x_shape=x_shape, k_shape=k_shape, dtype=dtype)
+ bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype)
+ out = relay.nn.bias_add(dense, bias)
+ dic["bias"] = (k_shape[0],)
+ param_lst += ["bias"]
+ return out, dic, param_lst
+
+
+def test_dnnl_not_compatible(run_module, target="llvm", dtype="float32"):
+ xshape = (1, 32, 14, 14)
+ x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
+
+ x = relay.var("x", shape=(xshape), dtype=dtype)
+ y = relay.add(x, x)
+ z = relay.cast(relay.cast(y, "int32"), "float32")
+ out = relay.nn.relu(z)
+ f = relay.Function([x], out)
+ mod = tvm.IRModule()
+ mod["main"] = f
+ mod = dnnl.partition_for_dnnl(mod)
+ for mode in ["graph", "vm"]:
+ with tvm.transform.PassContext(opt_level=3):
+ func = relay.create_executor(mode, mod=mod, device=tvm.cpu(0), target=target).evaluate()
+ if run_module:
+ results = func(x_data)
+
+
+def test_multiple_outputs(run_module, dtype="float32"):
+ def get_graph():
+ x = relay.var("x", shape=(1, 3), dtype=dtype)
+ y = relay.var("y", shape=(1, 3), dtype=dtype)
+ z = relay.add(x, y)
+ w = relay.add(z, y)
+ out = relay.Tuple((z, w))
+ f = tvm.IRModule.from_expr(out)
+ return f, {"x": (1, 3), "y": (1, 3)}, []
+
+ run_and_verify_func(get_graph(), run_module=run_module, dtype=dtype)
+
+
+def test_unary(run_module):
+ def get_graph(op, x_shape=(1, 8, 3, 3)):
+ x = relay.var("x", shape=(x_shape), dtype="float32")
+ out = op(x)
+ f = tvm.IRModule.from_expr(out)
+ return f, {"x": x_shape}, []
+
+ for op in [
+ relay.nn.relu,
+ relay.tanh,
+ relay.sigmoid,
+ ]:
+ run_and_verify_func(get_graph(op), run_module=run_module)
+
+
+def test_conv2d(run_module, dtype="float32"):
+ x_shape = (1, 32, 8, 8)
+ for k_shape, groups in [((16, 32, 3, 3), 1), ((32, 1, 3, 3), 32)]:
+ for padding in [(0, 0), (1, 1)]:
+ for strides in [(1, 1), (2, 2)]:
+ for dilation in [(1, 1), (2, 2)]:
+ conv2d, dic, param_lst = get_conv2d(
+ x_shape=x_shape,
+ k_shape=k_shape,
+ groups=groups,
+ padding=padding,
+ strides=strides,
+ dilation=dilation,
+ dtype=dtype,
+ )
+ conv2d = tvm.IRModule.from_expr(conv2d)
+ config = conv2d, dic, param_lst
+ run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+def test_conv2d_weights_const(run_module, dtype="float32"):
+ x_shape = (1, 32, 8, 8)
+ k_shape = (16, 32, 3, 3)
+ conv2d, dic, param_lst = get_conv2d_weights_const(x_shape, k_shape, dtype=dtype)
+ conv2d = tvm.IRModule.from_expr(conv2d)
+ config = conv2d, dic, param_lst
+ run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+def test_conv2d_pattern(run_module, dtype="float32"):
+ x_shape = (1, 32, 8, 8)
+ k_shape = (16, 32, 3, 3)
+ activation_lst = [None, "relu", "tanh", "sigmoid"]
+ for a in activation_lst:
+ conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype)
+ conv2d = tvm.IRModule.from_expr(conv2d)
+ config = conv2d, dic, param_lst
+ run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+ conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, activation=a, dtype=dtype)
+ conv2d_bias = tvm.IRModule.from_expr(conv2d_bias)
+ config = conv2d_bias, dic, param_lst
+ run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+ conv2d_bias_bn_relu, dic, param_lst = get_conv2d_bias_bn_relu(x_shape, k_shape, dtype=dtype)
+ conv2d_bias_bn_relu = tvm.IRModule.from_expr(conv2d_bias_bn_relu)
+ config = conv2d_bias_bn_relu, dic, param_lst
+ run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+def test_dense(run_module, dtype="float32"):
+ x_shape = (1, 16)
+ k_shape = (32, 16)
+
+ dense, dic, param_lst = get_dense(x_shape, k_shape, dtype=dtype)
+ dense = tvm.IRModule.from_expr(dense)
+ config = dense, dic, param_lst
+ run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+ dense, dic, param_lst = get_dense(x_shape, k_shape=(1, 16), dtype=dtype)
+ dense = tvm.IRModule.from_expr(dense)
+ config = dense, dic, param_lst
+ run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+def test_dense_pattern(run_module, dtype="float32"):
+ x_shape = (1, 16)
+ k_shape = (32, 16)
+
+ dense, dic, param_lst = get_dense(x_shape, k_shape, dtype=dtype)
+ dense = tvm.IRModule.from_expr(dense)
+ config = dense, dic, param_lst
+ run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+ dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, dtype=dtype)
+ dense_bias = tvm.IRModule.from_expr(dense_bias)
+ config = dense_bias, dic, param_lst
+ run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+if __name__ == "__main__":
+ import sys
+
+ sys.exit(pytest.main([__file__] + sys.argv[1:]))
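The helpers in this new test file can also be driven directly; a sketch (not part of the patch, assuming it is appended to test_dnnl.py) that exercises the new dilation handling in dnnl_json_runtime.cc. Compiling the DNNL partitions requires TVM built with the DNNL codegen enabled.

    # Sketch: a dilated conv2d case using the file's own helpers.
    conv2d, dic, param_lst = get_conv2d(
        x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dilation=(2, 2), dtype="float32"
    )
    config = tvm.IRModule.from_expr(conv2d), dic, param_lst
    run_and_verify_func(config, run_module=False)  # run_module=True also executes and compares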
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 80fb2e0..736ece2 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -919,10 +919,25 @@ def test_mixed_single_multiple_outputs():
def test_dnnl_fuse():
dnnl_patterns = get_pattern_table("dnnl")
- conv2d_bias_relu_pat, conv2d_relu_pat = dnnl_patterns
-
- def get_blocks(prefix, data, in_channel, out_channel, include_bn=True, include_sigmoid=False):
+ (
+ conv2d_bias_relu_pat,
+ conv2d_bias_sigmoid_pat,
+ conv2d_bias_pat,
+ conv2d_relu_pat,
+ conv2d_sigmoid_pat,
+ ) = (dnnl_patterns[0], dnnl_patterns[4], dnnl_patterns[6], dnnl_patterns[8], dnnl_patterns[12])
+
+ def get_blocks(
+ prefix,
+ data,
+ in_channel,
+ out_channel,
+ include_bias_add=True,
+ include_bn=True,
+ include_sigmoid=False,
+ ):
weight = relay.var(prefix + "weight")
+ bias = relay.var(prefix + "bias")
bn_gamma = relay.var(prefix + "bn_gamma")
bn_beta = relay.var(prefix + "bn_beta")
bn_mmean = relay.var(prefix + "bn_mean")
@@ -931,6 +946,8 @@ def test_dnnl_fuse():
layer = relay.nn.conv2d(
data=data, weight=weight, kernel_size=(3, 3), channels=out_channel, padding=(1, 1)
)
+ if include_bias_add:
+ layer = relay.nn.bias_add(layer, bias)
if include_bn:
bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta, bn_mmean, bn_mvar)
layer = bn_output[0]
@@ -940,11 +957,11 @@ def test_dnnl_fuse():
layer = relay.nn.relu(layer)
return layer
- def get_net(include_bn=True, include_sigmoid=False):
+ def get_net(include_bias_add=True, include_bn=True, include_sigmoid=False):
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
- block1 = get_blocks("block1_", data, 3, 8, include_bn, include_sigmoid)
+ block1 = get_blocks("block1_", data, 3, 8, include_bias_add, include_bn, include_sigmoid)
# The second block is always conv + relu, to make it more interesting
- block2 = get_blocks("block2_", block1, 8, 8, False, include_sigmoid)
+ block2 = get_blocks("block2_", block1, 8, 8, False, False, include_sigmoid)
return relay.Function(relay.analysis.free_vars(block2), block2)
def get_partitoned_mod(mod, params, pattern_table):
@@ -959,9 +976,18 @@ def test_dnnl_fuse():
transform.FoldScaleAxis(),
]
)
+ # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
+ remove_linear_pass = tvm.transform.Sequential(
+ [
+ transform.SimplifyExpr(),
+ transform.FoldConstant(),
+ ]
+ )
composite_partition = tvm.transform.Sequential(
[
+ transform.CanonicalizeOps(),
remove_bn_pass,
+ remove_linear_pass,
transform.MergeComposite(pattern_table),
transform.AnnotateTarget("dnnl"),
transform.PartitionGraph(),
@@ -971,25 +997,38 @@ def test_dnnl_fuse():
with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
return composite_partition(mod)
- def test_detect_pattern(pattern_table, include_bn, include_sigmoid, num_expected_partition):
- net = get_net(include_bn, include_sigmoid)
+ def test_detect_pattern(
+ pattern_table, include_bias_add, include_bn, include_sigmoid, num_expected_partition
+ ):
+ net = get_net(include_bias_add, include_bn, include_sigmoid)
mod, params = tvm.relay.testing.create_workload(net)
mod = get_partitoned_mod(mod, params, pattern_table)
assert len(mod.functions) - 1 == num_expected_partition # -1 for main
def test_partition():
# conv + bn + relu, conv + relu -> fused conv_bias_relu, conv, and relu
- test_detect_pattern([conv2d_bias_relu_pat], True, False, 3)
+ test_detect_pattern([conv2d_bias_relu_pat], False, True, False, 3)
# conv + bn + relu, conv + relu -> conv, bias, relu, and fused conv_relu
- test_detect_pattern([conv2d_relu_pat], True, False, 4)
+ test_detect_pattern([conv2d_relu_pat], False, True, False, 4)
# conv + bn + relu, conv + relu -> fused conv_bias_relu, and fused conv_relu
- test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, False, 2)
+ test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], False, True, False, 2)
+ # conv + bias_add + bn + relu, conv + relu -> fused conv_bias_relu, and fused conv_relu
+ test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, True, False, 2)
# conv + relu, conv + relu -> two fused conv_relu
- test_detect_pattern([conv2d_relu_pat], False, False, 2)
+ test_detect_pattern([conv2d_relu_pat], False, False, False, 2)
# conv + relu, conv + relu -> no fusion, 4 partition each with a single op
- test_detect_pattern([conv2d_bias_relu_pat], False, False, 4)
+ test_detect_pattern([conv2d_bias_relu_pat], False, False, False, 4)
# conv + bn + sigmoid + relu, conv + sigmoid + relu -> no fusion
- test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, True, 5)
+ test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], False, True, True, 7)
+ # conv + bias_add + bn + sigmoid + relu, conv + sigmoid + relu -> fused conv_bias
+ # and single op sigmoid, relu, conv, sigmoid, relu
+ test_detect_pattern([conv2d_bias_pat, conv2d_relu_pat], True, True, True, 6)
+ # conv + bias_add + bn + sigmoid + relu, conv + sigmoid + relu -> fused conv_bias_sigmoid
+ # and single op relu, conv, sigmoid, relu
+ test_detect_pattern([conv2d_bias_sigmoid_pat, conv2d_relu_pat], True, True, True, 5)
+ # conv + bias_add + bn + sigmoid + relu, conv + sigmoid + relu -> fused conv_bias_sigmoid,
+ # fused conv_sigmoid and single op relu, relu
+ test_detect_pattern([conv2d_bias_sigmoid_pat, conv2d_sigmoid_pat], True, True, True, 4)
def test_partition_mobilenet():
mod, params = relay.testing.mobilenet.get_workload()
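The indices 0, 4, 6, 8 and 12 used above follow from the append order in pattern_table(); a sketch (not part of the patch) that prints the registered order, which should come out as listed in the comments:

    from tvm.relay.op.contrib import dnnl  # noqa: F401  registers the "dnnl" table on import
    from tvm.relay.op.contrib.register import get_pattern_table

    # pattern_table() walks with_bias in [True, False] and eltwise in
    # ["nn.relu", "tanh", "sigmoid", None], appending a conv2d pattern then a dense
    # pattern at each step, and returns before the bias-less, post-op-less case:
    #   0 dnnl.conv2d_bias_relu     1 dnnl.dense_bias_relu
    #   2 dnnl.conv2d_bias_tanh     3 dnnl.dense_bias_tanh
    #   4 dnnl.conv2d_bias_sigmoid  5 dnnl.dense_bias_sigmoid
    #   6 dnnl.conv2d_bias          7 dnnl.dense_bias
    #   8 dnnl.conv2d_relu          9 dnnl.dense_relu
    #  10 dnnl.conv2d_tanh         11 dnnl.dense_tanh
    #  12 dnnl.conv2d_sigmoid      13 dnnl.dense_sigmoid
    for i, (name, _pattern) in enumerate(get_pattern_table("dnnl")):
        print(i, name)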