You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/02 06:57:10 UTC
[tvm] branch main updated: fuse constant padding into conv kernels
(#7515)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 633ee11 fuse constant padding into conv kernels (#7515)
633ee11 is described below
commit 633ee118efecd04efb4be9bf6053deae6e8fac3b
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Mon Mar 1 23:56:50 2021 -0700
fuse constant padding into conv kernels (#7515)
* fuse constant padding into conv kernels
* change the kernel to support other layouts
* add channel-last test
* add a comment about bailing early
---
src/relay/transforms/simplify_expr.cc | 116 ++++++++++++++++++++++++++
tests/python/relay/test_pass_simplify_expr.py | 78 +++++++++++++++++
2 files changed, 194 insertions(+)
diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 74e48dc..bfe04e1 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -83,6 +83,121 @@ class SimplifyReshape : public SimplifyPattern {
};
/*!
+ * \brief SimplifyConvPad matches a pad followed by a conv/convtranspose/pool/etc
+ * with a pad attribute and merges the padding into the kernel.
+ */
+class SimplifyConvPad : public SimplifyPattern {
+ public:
+ SimplifyConvPad() {
+ x_ = IsWildcard();
+ w_ = IsWildcard();
+ pad_ = IsOp("nn.pad")({x_});
+ conv1d_ = IsOp("nn.conv1d");
+ conv2d_ = IsOp("nn.conv2d");
+ conv3d_ = IsOp("nn.conv3d");
+ conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
+ pattern_ = conv_;
+ }
+ template <typename T>
+ Attrs MakeConvAttrs(const T* old_attrs, const Array<PrimExpr> padding) const {
+ ICHECK(old_attrs);
+ ICHECK(padding.size() == old_attrs->padding.size())
+ << "Number of dimensions to pad and convolution padding attributes should have the same "
+ "extent";
+
+ auto new_attrs = make_object<T>();
+ Array<PrimExpr> combined_padding;
+ for (size_t i = 0; i < padding.size(); ++i) {
+ combined_padding.push_back(padding[i] + old_attrs->padding[i]);
+ }
+ new_attrs->strides = old_attrs->strides;
+ new_attrs->padding = combined_padding;
+ new_attrs->dilation = old_attrs->dilation;
+ new_attrs->groups = old_attrs->groups;
+ new_attrs->channels = old_attrs->channels;
+ new_attrs->kernel_size = old_attrs->kernel_size;
+ new_attrs->data_layout = old_attrs->data_layout;
+ new_attrs->kernel_layout = old_attrs->kernel_layout;
+ new_attrs->out_layout = old_attrs->out_layout;
+ new_attrs->out_dtype = old_attrs->out_dtype;
+ return Attrs(new_attrs);
+ }
+ template <typename T>
+ Attrs GetAttrs(const PadAttrs* param, const T* attrs) const {
+ ICHECK(param);
+ ICHECK(attrs);
+ ICHECK(attrs->data_layout.size() == param->pad_width.size())
+ << "Data Layout and padding attributes should have the same extent";
+
+ std::string data_layout = attrs->data_layout;
+ std::set<char> image_dims({'H', 'W', 'D'});
+ Array<PrimExpr> padding;
+ // If we're padding a non-spatial dimension, don't simplify
+ // Convolution can only pad on spatial axes
+ for (size_t i = 0; i < param->pad_width.size(); ++i) {
+ if (!image_dims.count(data_layout[i])) {
+ for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
+ if (param->pad_width[i][j] != 0) {
+ return Attrs();
+ }
+ }
+ }
+ }
+ for (size_t j = 0; j < param->pad_width[0].size(); ++j) {
+ for (size_t i = 0; i < param->pad_width.size(); ++i) {
+ if (image_dims.count(data_layout[i])) {
+ padding.push_back(param->pad_width[i][j]);
+ }
+ }
+ }
+
+ return MakeConvAttrs(attrs, padding);
+ }
+ Expr callback(const Expr& pre, const Expr& post,
+ const Map<DFPattern, Array<Expr>>& node_map) const override {
+ const CallNode* call_node = post.as<CallNode>();
+ ICHECK(call_node);
+ auto pad = node_map[pad_][0];
+ const CallNode* pad_node = pad.as<CallNode>();
+ ICHECK(pad_node);
+ const PadAttrs* param = pad_node->attrs.as<PadAttrs>();
+ ICHECK(param);
+ if (param->pad_mode == "constant" && param->pad_value == 0.0) {
+ Attrs attrs;
+ if (node_map.count(conv1d_)) {
+ attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
+ } else if (node_map.count(conv2d_)) {
+ attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
+ } else if (node_map.count(conv3d_)) {
+ attrs = GetAttrs(param, call_node->attrs.as<Conv3DAttrs>());
+ } else {
+ return post;
+ }
+ if (!attrs.defined()) {
+ return post;
+ }
+ auto x = node_map[x_][0];
+ auto w = node_map[w_][0];
+ return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
+ }
+ return post;
+ }
+
+ private:
+ /*! \brief Pattern input */
+ DFPattern x_;
+ /*! \brief Pattern input weight */
+ DFPattern w_;
+ /*! \brief Pattern pad */
+ DFPattern pad_;
+ /*! \brief Pattern conv */
+ DFPattern conv_;
+ DFPattern conv1d_;
+ DFPattern conv2d_;
+ DFPattern conv3d_;
+};
+
+/*!
* \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
*/
class FullElementwise : public SimplifyPattern {
@@ -163,6 +278,7 @@ class ExprSimplifier {
explicit ExprSimplifier(IRModule mod) : mod_(mod) {
CreateCallback(SimplifyReshape());
CreateCallback(FullElementwise());
+ CreateCallback(SimplifyConvPad());
}
template <typename T>
void CreateCallback(const T& pattern) {
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index 423f0a4..e3e497e 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -19,6 +19,8 @@ from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass
+import numpy as np
+
def test_simplify_reshape():
def before():
@@ -122,6 +124,82 @@ def test_simplify_full_elementwise():
validate(shape, value, dtype)
def test_simplify_conv_pad():
    """Check that SimplifyExpr folds zero constant nn.pad ops into the padding
    attribute of a following 1/2/3-D convolution, and leaves any other pad
    configuration untouched."""
    conv_ops = [relay.nn.conv1d, relay.nn.conv2d, relay.nn.conv3d]

    def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
        # Build data/weight shapes for channels-first (NC*) or channels-last (N*C).
        if layout[1] == "C":
            shape = [1, 3] + [10] * ndim
            wshape = [8, 3] + [3] * ndim
        elif layout[-1] == "C":
            shape = [1] + [10] * ndim + [3]
            wshape = [8] + [3] * ndim + [3]
        else:
            raise ValueError("This test only supports NC* and N*C")

        x = relay.var("x", shape=shape, dtype="float32")
        w = relay.var("w", shape=wshape, dtype="float32")
        pad = relay.nn.pad(x, pad_width, pad_value, pad_mode)
        conv_op = conv_ops[ndim - 1]
        kernel_layout = "DHWIO"[3 - ndim :]
        if layout[1] == "C":
            conv = conv_op(pad, w, padding=orig_padding)
        else:
            conv = conv_op(
                pad, w, padding=orig_padding, data_layout=layout, kernel_layout=kernel_layout
            )

        if pad_mode == "constant" and pad_value == 0:
            # Expected fused padding: spatial "before" values first, then
            # spatial "after" values, each combined with the conv's own padding.
            gathered = [
                pad_width[i][j]
                for j in range(2)
                for i in range(len(pad_width))
                if layout[i] in ["D", "H", "W"]
            ]
            new_padding = [p + q for p, q in zip(gathered, orig_padding)]
            if layout[1] == "C":
                after = conv_op(x, w, padding=new_padding)
            else:
                after = conv_op(
                    x, w, padding=new_padding, data_layout=layout, kernel_layout=kernel_layout
                )
        else:
            # Non-fusable pads must pass through the simplifier unchanged.
            after = conv

        zz = run_opt_pass(conv, transform.SimplifyExpr())
        expected = run_opt_pass(after, transform.InferType())
        assert tvm.ir.structural_equal(zz, expected)

        # Numerically compare the original and simplified graphs on random data.
        mod1 = tvm.IRModule.from_expr(conv)
        mod2 = tvm.IRModule.from_expr(zz)

        with tvm.transform.PassContext(disabled_pass="SimplifyExpr"):
            ex1 = relay.create_executor("vm", mod=mod1, ctx=tvm.cpu(), target="llvm")
            ex2 = relay.create_executor("vm", mod=mod2, ctx=tvm.cpu(), target="llvm")
            x_np = np.random.rand(*shape).astype("float32")
            w_np = np.random.rand(*wshape).astype("float32")
            result1 = ex1.evaluate()(x_np, w_np)
            result2 = ex2.evaluate()(x_np, w_np)

        tvm.testing.assert_allclose(result1.asnumpy(), result2.asnumpy())

    for orig_pad in [[0, 0], [2, 0], [0, 2]]:
        for i_pad in [[0, 0], [1, 1], [1, 0]]:
            for ndim in [1, 2, 3]:
                for channels_last in [False, True]:
                    if channels_last:
                        layout = "NDHWC"
                        layout = layout[0:1] + layout[4 - ndim : 4] + layout[-1:]
                        padding = [[0, 0]] + [i_pad] * ndim + [[0, 0]]
                    else:
                        layout = "NCDHW"
                        layout = layout[0:2] + layout[5 - ndim :]
                        padding = [[0, 0]] * 2 + [i_pad] * ndim

                    validate(ndim, padding, 0, "constant", orig_pad * ndim, layout)
            # Non-zero pad value and non-constant mode must not be fused.
            ndim = 2
            validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 1, "constant", orig_pad * ndim, "NCHW")
            validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim, "NCHW")
+
if __name__ == "__main__":
    # Run every test in this file when executed directly, matching the file's
    # convention of one call per test function.
    test_simplify_reshape()
    test_simplify_full_elementwise()
    test_simplify_conv_pad()