Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/02 06:57:10 UTC

[tvm] branch main updated: fuse constant padding into conv kernels (#7515)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 633ee11  fuse constant padding into conv kernels (#7515)
633ee11 is described below

commit 633ee118efecd04efb4be9bf6053deae6e8fac3b
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Mon Mar 1 23:56:50 2021 -0700

    fuse constant padding into conv kernels (#7515)
    
    * fuse constant padding into conv kernels
    
    * change the kernel to support other layouts
    
    * add channel-last test
    
    * add a comment about bailing early
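    
    Roughly, the rewrite this adds (a sketch in the Relay Python API; the
    shapes and values here are illustrative, not taken from the patch):
    
        from tvm import relay, transform
    
        x = relay.var("x", shape=(1, 3, 10, 10), dtype="float32")
        w = relay.var("w", shape=(8, 3, 3, 3), dtype="float32")
        # before: an explicit zero pad feeding a conv2d
        y = relay.nn.conv2d(relay.nn.pad(x, [(0, 0), (0, 0), (1, 1), (1, 1)]), w)
        # after transform.SimplifyExpr(): the pad is folded into the conv's
        # padding attribute
        y = relay.nn.conv2d(x, w, padding=(1, 1, 1, 1))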
---
 src/relay/transforms/simplify_expr.cc         | 116 ++++++++++++++++++++++++++
 tests/python/relay/test_pass_simplify_expr.py |  78 +++++++++++++++++
 2 files changed, 194 insertions(+)

diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 74e48dc..bfe04e1 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -83,6 +83,121 @@ class SimplifyReshape : public SimplifyPattern {
 };
 
 /*!
+ * \brief SimplifyConvPad matches a nn.pad followed by a conv1d/conv2d/conv3d
+ * and merges the explicit padding into the convolution's padding attribute.
+ */
+class SimplifyConvPad : public SimplifyPattern {
+ public:
+  SimplifyConvPad() {
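+    // Match any input fed through nn.pad into a 1/2/3-D convolution; the
+    // alternation lets a single pattern cover all three conv arities.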
+    x_ = IsWildcard();
+    w_ = IsWildcard();
+    pad_ = IsOp("nn.pad")({x_});
+    conv1d_ = IsOp("nn.conv1d");
+    conv2d_ = IsOp("nn.conv2d");
+    conv3d_ = IsOp("nn.conv3d");
+    conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
+    pattern_ = conv_;
+  }
+  template <typename T>
+  Attrs MakeConvAttrs(const T* old_attrs, const Array<PrimExpr>& padding) const {
+    ICHECK(old_attrs);
+    ICHECK(padding.size() == old_attrs->padding.size())
+        << "Collected spatial padding and the convolution's padding attribute "
+           "must have the same number of values";
+
+    auto new_attrs = make_object<T>();
+    Array<PrimExpr> combined_padding;
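+    // The explicit pad amounts simply add to the conv's existing padding.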
+    for (size_t i = 0; i < padding.size(); ++i) {
+      combined_padding.push_back(padding[i] + old_attrs->padding[i]);
+    }
+    new_attrs->strides = old_attrs->strides;
+    new_attrs->padding = combined_padding;
+    new_attrs->dilation = old_attrs->dilation;
+    new_attrs->groups = old_attrs->groups;
+    new_attrs->channels = old_attrs->channels;
+    new_attrs->kernel_size = old_attrs->kernel_size;
+    new_attrs->data_layout = old_attrs->data_layout;
+    new_attrs->kernel_layout = old_attrs->kernel_layout;
+    new_attrs->out_layout = old_attrs->out_layout;
+    new_attrs->out_dtype = old_attrs->out_dtype;
+    return Attrs(new_attrs);
+  }
+  template <typename T>
+  Attrs GetAttrs(const PadAttrs* param, const T* attrs) const {
+    ICHECK(param);
+    ICHECK(attrs);
+    ICHECK(attrs->data_layout.size() == param->pad_width.size())
+        << "data_layout and pad_width must cover the same number of dimensions";
+
+    std::string data_layout = attrs->data_layout;
+    std::set<char> image_dims({'H', 'W', 'D'});
+    Array<PrimExpr> padding;
+    // If we're padding a non-spatial dimension, don't simplify
+    // Convolution can only pad on spatial axes
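+    // (e.g. with NCHW data, any nonzero padding on N or C means we bail out)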
+    for (size_t i = 0; i < param->pad_width.size(); ++i) {
+      if (!image_dims.count(data_layout[i])) {
+        for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
+          if (param->pad_width[i][j] != 0) {
+            return Attrs();
+          }
+        }
+      }
+    }
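+    // Collect the spatial padding as all "before" values followed by all
+    // "after" values, which is the layout of the conv padding attribute
+    // (e.g. [top, left, bottom, right] for NCHW conv2d).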
+    for (size_t j = 0; j < param->pad_width[0].size(); ++j) {
+      for (size_t i = 0; i < param->pad_width.size(); ++i) {
+        if (image_dims.count(data_layout[i])) {
+          padding.push_back(param->pad_width[i][j]);
+        }
+      }
+    }
+
+    return MakeConvAttrs(attrs, padding);
+  }
+  Expr callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    const CallNode* call_node = post.as<CallNode>();
+    ICHECK(call_node);
+    auto pad = node_map[pad_][0];
+    const CallNode* pad_node = pad.as<CallNode>();
+    ICHECK(pad_node);
+    const PadAttrs* param = pad_node->attrs.as<PadAttrs>();
+    ICHECK(param);
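+    // Only a constant pad with value 0 matches the implicit zero padding a
+    // convolution performs; any other mode or value must be left in place.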
+    if (param->pad_mode == "constant" && param->pad_value == 0.0) {
+      Attrs attrs;
+      if (node_map.count(conv1d_)) {
+        attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
+      } else if (node_map.count(conv2d_)) {
+        attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
+      } else if (node_map.count(conv3d_)) {
+        attrs = GetAttrs(param, call_node->attrs.as<Conv3DAttrs>());
+      } else {
+        return post;
+      }
+      if (!attrs.defined()) {
+        return post;
+      }
+      auto x = node_map[x_][0];
+      auto w = node_map[w_][0];
+      return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
+    }
+    return post;
+  }
+
+ private:
+  /*! \brief Pattern input */
+  DFPattern x_;
+  /*! \brief Pattern input weight */
+  DFPattern w_;
+  /*! \brief Pattern pad */
+  DFPattern pad_;
+  /*! \brief Pattern conv */
+  DFPattern conv_;
+  DFPattern conv1d_;
+  DFPattern conv2d_;
+  DFPattern conv3d_;
+};
+
+/*!
 * \brief FullElementwise replaces full-like tensors feeding elementwise ops with scalar constants
  */
 class FullElementwise : public SimplifyPattern {
@@ -163,6 +278,7 @@ class ExprSimplifier {
   explicit ExprSimplifier(IRModule mod) : mod_(mod) {
     CreateCallback(SimplifyReshape());
     CreateCallback(FullElementwise());
+    CreateCallback(SimplifyConvPad());
   }
   template <typename T>
   void CreateCallback(const T& pattern) {
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index 423f0a4..e3e497e 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -19,6 +19,8 @@ from tvm import relay
 from tvm.relay import transform
 from tvm.relay.testing import run_opt_pass
 
+import numpy as np
+import tvm.testing
+
 
 def test_simplify_reshape():
     def before():
@@ -122,6 +124,82 @@ def test_simplify_full_elementwise():
                 validate(shape, value, dtype)
 
 
+def test_simplify_conv_pad():
+    convs = [relay.nn.conv1d, relay.nn.conv2d, relay.nn.conv3d]
+
+    def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
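+        # Build pad(x) -> conv in the given layout and check that SimplifyExpr
+        # folds the pad into the conv (or leaves the graph alone when it must).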
+        if layout[1] == "C":
+            shape = [1, 3] + [10] * ndim
+            wshape = [8, 3] + [3] * ndim
+        elif layout[-1] == "C":
+            shape = [1] + [10] * ndim + [3]
+            wshape = [8] + [3] * ndim + [3]
+        else:
+            raise ValueError("This test only supports NC* and N*C")
+
+        x = relay.var("x", shape=shape, dtype="float32")
+        w = relay.var("w", shape=wshape, dtype="float32")
+        pad = relay.nn.pad(x, pad_width, pad_value, pad_mode)
+        if layout[1] == "C":
+            conv = convs[ndim - 1](pad, w, padding=orig_padding)
+        else:
+            conv = convs[ndim - 1](
+                pad, w, padding=orig_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
+            )
+
+        if pad_mode == "constant" and pad_value == 0:
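+            # The pass orders folded padding as all "before" values then all
+            # "after" values over the spatial axes, matching the conv padding
+            # attribute layout (e.g. [top, left, bottom, right] for NCHW).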
+            new_padding = []
+            for j in range(2):
+                for i in range(len(pad_width)):
+                    if layout[i] in ["D", "H", "W"]:
+                        new_padding.append(pad_width[i][j])
+            for i in range(len(new_padding)):
+                new_padding[i] += orig_padding[i]
+            if layout[1] == "C":
+                after = convs[ndim - 1](x, w, padding=new_padding)
+            else:
+                after = convs[ndim - 1](
+                    x, w, padding=new_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
+                )
+        else:
+            after = conv
+
+        zz = run_opt_pass(conv, transform.SimplifyExpr())
+        expected = run_opt_pass(after, transform.InferType())
+        assert tvm.ir.structural_equal(zz, expected)
+
+        mod1 = tvm.IRModule.from_expr(conv)
+        mod2 = tvm.IRModule.from_expr(zz)
+
+        x_np = np.random.rand(*shape).astype("float32")
+        w_np = np.random.rand(*wshape).astype("float32")
+        # Compile and run the unsimplified module with SimplifyExpr disabled
+        # (disabled_pass expects a list, and the VM compiles at evaluate time,
+        # so evaluation has to happen inside the PassContext).
+        with tvm.transform.PassContext(disabled_pass=["SimplifyExpr"]):
+            ex1 = relay.create_executor("vm", mod=mod1, ctx=tvm.cpu(), target="llvm")
+            result1 = ex1.evaluate()(x_np, w_np)
+        ex2 = relay.create_executor("vm", mod=mod2, ctx=tvm.cpu(), target="llvm")
+        result2 = ex2.evaluate()(x_np, w_np)
+
+        tvm.testing.assert_allclose(result1.asnumpy(), result2.asnumpy())
+
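+    # Sweep pad widths, original conv padding, conv dimensionality, and
+    # channels-first/channels-last layouts.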
+    for orig_pad in [[0, 0], [2, 0], [0, 2]]:
+        for i_pad in [[0, 0], [1, 1], [1, 0]]:
+            for ndim in [1, 2, 3]:
+                for channels_last in [0, 1]:
+                    if channels_last:
+                        layout = "NDHWC"
+                        layout = layout[0:1] + layout[4 - ndim : 4] + layout[-1:]
+                        padding = [[0, 0]] + [i_pad] * ndim + [[0, 0]]
+                    else:
+                        layout = "NCDHW"
+                        layout = layout[0:2] + layout[5 - ndim :]
+                        padding = [[0, 0]] * 2 + [i_pad] * ndim
+
+                    validate(ndim, padding, 0, "constant", orig_pad * ndim, layout)
+    # Negative cases: a nonzero constant pad value and a non-constant pad mode
+    # must not be folded into the convolution.
+    ndim = 2
+    validate(ndim, [[0, 0]] * 2 + [[1, 0]] * ndim, 1, "constant", [0, 2] * ndim, "NCHW")
+    validate(ndim, [[0, 0]] * 2 + [[1, 0]] * ndim, 0, "edge", [0, 2] * ndim, "NCHW")
+
+
 if __name__ == "__main__":
     test_simplify_reshape()
     test_simplify_full_elementwise()
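+    test_simplify_conv_pad()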