Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/12/13 06:03:24 UTC

[GitHub] [incubator-tvm] optima2005 commented on a change in pull request #4476: Implement 1d deconvolution

optima2005 commented on a change in pull request #4476: Implement 1d deconvolution
URL: https://github.com/apache/incubator-tvm/pull/4476#discussion_r357496355
 
 

 ##########
 File path: src/relay/op/nn/convolution.cc
 ##########
 @@ -328,6 +328,160 @@             (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
 .add_type_rel("Conv2DTranspose", Conv2DTransposeRel);
 
 
+// relay.nn.conv1d_transpose
+TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs);
+
+bool Conv1DTransposeRel(const Array<Type>& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  static const Layout kNCW("NCW");
+  static const Layout kOIW("OIW");
+
+  const Conv1DTransposeAttrs* param = attrs.as<Conv1DTransposeAttrs>();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
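+  // All shape arithmetic below is done in the canonical NCW / OIW layouts;
+  // the three checks that follow verify that the user-supplied layouts can
+  // be mapped onto them bijectively.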
+  const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCW);
+  CHECK(trans_in_layout.defined())
+    << "Conv only support input layouts that are convertible from NCW."
+    << " But got " << in_layout;
+
+  const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIW);
+  CHECK(trans_kernel_layout.defined())
+    << "Conv only support kernel layouts that are convertible from OIW."
+    << " But got "<< kernel_layout;
+
+  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+  const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCW);
+  CHECK(trans_out_layout.defined())
+    << "Conv only support output layouts that are convertible from NCW."
+    << " But got " << out_layout;
+
+  IndexExpr channels, dilated_ksize_x;
+
+  auto dshape_ncw = trans_in_layout.ForwardShape(data->shape);
+
+  // infer weight if the kernel_size and channels are defined
+  if (param->kernel_size.defined() && param->channels.defined()) {
+    CHECK_EQ(param->kernel_size.size(), 1);
+    CHECK_EQ(param->dilation.size(), 1);
+
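+    // For a transposed convolution the weight is stored as
+    // (in_channels, out_channels / groups, kernel_width) in OIW order, so
+    // the first dimension comes from the data's channel dimension.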
+    Array<IndexExpr> wshape({dshape_ncw[1],
+            indexdiv(param->channels, param->groups),
+            param->kernel_size[0]});
+
+    wshape = trans_kernel_layout.BackwardShape(wshape);
+    dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+    channels = param->channels;
+
+    // assign result to reporter
+    reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
+  } else {
+    // use weight to infer the conv shape.
+    if (weight == nullptr) return false;
+    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
+    if (param->kernel_size.defined()) {
+      CHECK_EQ(param->kernel_size.size(), 1);
+      // check the size
+      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
+          << "Conv1D: shape of weight is inconsistent with kernel_size, "
+          << " kernel_size=" << param->kernel_size
+          << " wshape=" << Array<IndexExpr>(wshape);
+    }
+    if (param->channels.defined()) {
+      CHECK(reporter->AssertEQ(param->channels, wshape[1]))
+          << "Conv1D: shape of weight is inconsistent with channels, "
+          << " channels=" << param->channels
+          << " wshape=" << Array<IndexExpr>(wshape);
+    }
+    CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]))
+        << "Conv1DTranspose: shape of weight is inconsistent with in_channels,"
+        << " wshape=" << Array<IndexExpr>(wshape);
+    channels = wshape[1];
+    dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];
+  }
+  // Transposed-convolution output length:
+  //   stride * (input_len - 1) + dilated_kernel - 2 * padding + output_padding
+  Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
+  oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x -
+                 2 * param->padding[0] + param->output_padding[0]));
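+  // Sanity check of the expression above: input length 4, stride 2, dilated
+  // kernel size 3, padding 1 and output_padding 1 give
+  // 2*(4-1) + 3 - 2*1 + 1 = 8, the inverse of a stride-2 conv1d (8 -> 4).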
 
 Review comment:
   Please use asymmetric padding (head_padding + tail_padding) instead of the symmetric `2 * param->padding[0]` term; see #4511.
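
   A minimal sketch of the change being requested, assuming the padding
   attribute carries the head and tail amounts as `param->padding[0]` and
   `param->padding[1]` in the style of #4511 (the exact attribute layout is
   an assumption, not taken from this patch):

       // Sketch only: replace the symmetric 2 * padding term with separate
       // head and tail padding, per the asymmetric-padding convention.
       IndexExpr pad_head = param->padding[0];  // assumed head entry
       IndexExpr pad_tail = param->padding[1];  // assumed tail entry
       oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x -
                      pad_head - pad_tail + param->output_padding[0]));

   When head and tail padding are equal this reduces to the current symmetric
   expression, so existing cases keep the same output length.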
