You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/24 18:45:26 UTC

[GitHub] [tvm] csullivan commented on a change in pull request #7515: fuse constant padding into conv kernels

csullivan commented on a change in pull request #7515:
URL: https://github.com/apache/tvm/pull/7515#discussion_r582188392



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -82,6 +82,101 @@ class SimplifyReshape : public SimplifyPattern {
   DFPattern x_;
 };
 
+/*!
+ * \brief SimplifyConvPad matches a pad followed by a conv/convtranspose/pool/etc
+ * with a pad attribute and merges the padding into the kernel.
+ */
+class SimplifyConvPad : public SimplifyPattern {
+ public:
+  SimplifyConvPad() {
+    x_ = IsWildcard();
+    w_ = IsWildcard();
+    pad_ = IsOp("nn.pad")({x_});
+    conv1d_ = IsOp("nn.conv1d");
+    conv2d_ = IsOp("nn.conv2d");
+    conv3d_ = IsOp("nn.conv3d");
+    conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
+    pattern_ = conv_;
+  }
+  template <typename T>
+  Attrs MakeConvAttrs(const Attrs& attrs, const Array<PrimExpr> padding) const {
+    const T* old_attrs = attrs.as<T>();
+    ICHECK(old_attrs);
+    auto new_attrs = make_object<T>();
+    Array<PrimExpr> combined_padding;
+    ICHECK(padding.size() == old_attrs->padding.size());

Review comment:
       Consider adding a descriptive error message to the check, e.g.:
   ```suggestion
       ICHECK(padding.size() == old_attrs->padding.size()) << "Number of dimensions to pad and convolution padding attributes should have the same extent";
   ```

##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -82,6 +82,101 @@ class SimplifyReshape : public SimplifyPattern {
   DFPattern x_;
 };
 
+/*!
+ * \brief SimplifyConvPad matches a pad followed by a conv/convtranspose/pool/etc
+ * with a pad attribute and merges the padding into the kernel.
+ */
+class SimplifyConvPad : public SimplifyPattern {
+ public:
+  SimplifyConvPad() {
+    x_ = IsWildcard();
+    w_ = IsWildcard();
+    pad_ = IsOp("nn.pad")({x_});
+    conv1d_ = IsOp("nn.conv1d");
+    conv2d_ = IsOp("nn.conv2d");
+    conv3d_ = IsOp("nn.conv3d");
+    conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
+    pattern_ = conv_;
+  }
+  template <typename T>
+  Attrs MakeConvAttrs(const Attrs& attrs, const Array<PrimExpr> padding) const {
+    const T* old_attrs = attrs.as<T>();
+    ICHECK(old_attrs);
+    auto new_attrs = make_object<T>();
+    Array<PrimExpr> combined_padding;
+    ICHECK(padding.size() == old_attrs->padding.size());
+    for (size_t i = 0; i < padding.size(); ++i) {
+      combined_padding.push_back(padding[i] + old_attrs->padding[i]);
+    }
+    new_attrs->strides = old_attrs->strides;
+    new_attrs->padding = combined_padding;
+    new_attrs->dilation = old_attrs->dilation;
+    new_attrs->groups = old_attrs->groups;
+    new_attrs->channels = old_attrs->channels;
+    new_attrs->kernel_size = old_attrs->kernel_size;
+    new_attrs->data_layout = old_attrs->data_layout;
+    new_attrs->kernel_layout = old_attrs->kernel_layout;
+    new_attrs->out_layout = old_attrs->out_layout;
+    new_attrs->out_dtype = old_attrs->out_dtype;
+    return Attrs(new_attrs);
+  }
+  Expr callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    const CallNode* call_node = post.as<CallNode>();
+    ICHECK(call_node);
+    auto pad = node_map[pad_][0];
+    const CallNode* pad_node = pad.as<CallNode>();
+    ICHECK(pad_node);
+    const PadAttrs* param = pad_node->attrs.as<PadAttrs>();
+    ICHECK(param);
+    if (param->pad_mode == "constant" && param->pad_value == 0.0) {
+      for (size_t i = 0; i < 2; ++i) {

Review comment:
       Should we avoid hardcoding the loop bound to 2 here, and instead iterate over the actual length of `param->pad_width`?
   
   ```suggestion
         for (size_t i = 0; i < param->pad_width.size(); ++i) {
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org