You are viewing a plain-text version of this content. The link to the canonical version (originally labeled "here") was not preserved in this plain-text rendering.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/04/03 20:22:14 UTC

[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403299155
 
 

 ##########
 File path: src/relay/op/nn/convolution.cc
 ##########
 @@ -662,96 +454,101 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
         ConvInferCorrectLayout<Conv2DWinogradAttrs>);
 
 // relay.nn.contrib_conv2d_winograd_weight_transform
-TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);
-
-bool Conv2DWinogradWeightTransformRel(const Array<Type>& types,
-                                      int num_inputs,
-                                      const Attrs& attrs,
-                                      const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
-  const auto* data = types[0].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-
-  const Conv2DWinogradWeightTransformAttrs* param = attrs.as<Conv2DWinogradWeightTransformAttrs>();
-  CHECK(param != nullptr);
-
-  CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
-
-  // each pad width element should be a pair of positive integers
-  std::vector<IndexExpr> oshape {
-      param->tile_size + data->shape[2] - 1,
-      param->tile_size + data->shape[3] - 1,
-      data->shape[0],
-      data->shape[1],
-  };
-
-  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
-                                                  data->dtype));
-  return true;
-}
-
-Expr MakeConv2DWinogradWeightTransform(Expr weight,
-                                       int tile_size) {
-  auto attrs = make_object<Conv2DWinogradWeightTransformAttrs>();
-  attrs->tile_size = tile_size;
-  static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform");
-  return Call(op, {weight}, Attrs(attrs), {});
-}
-
+TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs);
 
 TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
-.set_body_typed(MakeConv2DWinogradWeightTransform);
-
+.set_body_typed([](Expr weight,
+                   int tile_size) {
+  return MakeConvWinogradWeightTransform(
+    weight, tile_size, "nn.contrib_conv2d_winograd_weight_transform");
+});
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
-.describe(R"code(Weight transformation of winograd fast convolution algorithm.
+    .describe(R"code(Weight transformation of winograd fast convolution algorithm.
 
 Separate this into another operator in order to enable Precompute Pass to compute the
 weight transformation in advance.
 
 - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DWinogradWeightTransformAttrs>()
-.set_num_inputs(1)
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
+    .set_attrs_type<ConvWinogradWeightTransformAttrs>()
+    .set_num_inputs(1)
+    .add_argument("weight", "Tensor", "The weight tensor.")
+    .set_support_level(10)
+    .add_type_rel("Conv2DWinogradWeightTransform",
+                  Conv2DWinogradWeightTransformRel<ConvWinogradWeightTransformAttrs>);
+
+// relay.nn.contrib_conv3d_winograd_without_weight_transform
+TVM_REGISTER_NODE_TYPE(Conv3DWinogradAttrs);
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_without_weight_transform")
+.set_body_typed([](Expr data,
+                   Expr weight,
+                   int tile_size,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   Array<IndexExpr> dilation,
+                   int groups,
+                   IndexExpr channels,
+                   Array<IndexExpr> kernel_size,
+                   std::string data_layout,
+                   std::string kernel_layout,
+                   std::string out_layout,
+                   DataType out_dtype) {
+  return MakeConvWinograd<Conv3DWinogradAttrs>(
+    data, weight, tile_size, strides, padding, dilation,
+    groups, channels, kernel_size, data_layout,
+    kernel_layout, out_layout, out_dtype, "nn.contrib_conv3d_winograd_without_weight_transform");
+});
+
+RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform")
+    .describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout.
 
 Review comment:
   ditto

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services