You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/01/18 00:28:44 UTC

[GitHub] [incubator-tvm] apivovarov commented on a change in pull request #4478: [TOPI] implement pool3d op

apivovarov commented on a change in pull request #4478: [TOPI] implement pool3d op
URL: https://github.com/apache/incubator-tvm/pull/4478#discussion_r368187999
 
 

 ##########
 File path: src/relay/op/nn/pooling.cc
 ##########
 @@ -720,5 +740,220 @@ RELAY_REGISTER_OP("nn.avg_pool2d_grad")
 .set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<AvgPool2DAttrs, topi::nn::kAvgPool>);
 
 
+// relay.nn.max_pool3d & relay.nn.avg_pool3d
+TVM_REGISTER_NODE_TYPE(MaxPool3DAttrs);
+TVM_REGISTER_NODE_TYPE(AvgPool3DAttrs);
+
+template <typename AttrType>
+bool Pool3DRel(const Array<Type>& types,
+               int num_inputs,
+               const Attrs& attrs,
+               const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+
+  if (data == nullptr) return false;
+
+  const auto dshape = data->shape;
+  CHECK_GE(dshape.size(), 3U)
+      << "Pool3D only support input >= 3-D: input must have depth, height and width";
+  const auto param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+
+  Layout layout(param->layout);
+  CHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) &&
+        layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) &&
+        !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
+    << "Invalid layout " << layout
+    << ". Pool3D layout must have D, H and W, which cannot be split";
+
+  const auto didx = layout.IndexOf(LayoutAxis::Get('D'));
+  const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
+  const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
+
+  IndexExpr pad_d, pad_h, pad_w;
+  if (param->padding.size() == 1) {
+    pad_d = param->padding[0] * 2;
+    pad_h = param->padding[0] * 2;
+    pad_w = param->padding[0] * 2;
+  } else if (param->padding.size() == 3) {
+    // (front, top, left)
+    pad_d = param->padding[0] * 2;
+    pad_h = param->padding[1] * 2;
+    pad_w = param->padding[2] * 2;
+  } else if (param->padding.size() == 6) {
+    // (front, top, left, back, bottom, right)
+    pad_d = param->padding[0] + param->padding[3];
+    pad_h = param->padding[1] + param->padding[4];
 
 Review comment:
   `pad_d` and `pad_h` are not used. The fix is in https://github.com/apache/incubator-tvm/pull/4738 (cc @optima2005 @masahi).

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services