You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/19 07:17:26 UTC

[GitHub] [incubator-tvm] jackwish commented on a change in pull request #4351: [QNN] Lowering for Depthwise Convolution.

jackwish commented on a change in pull request #4351: [QNN] Lowering for Depthwise Convolution.
URL: https://github.com/apache/incubator-tvm/pull/4351#discussion_r347756284
 
 

 ##########
 File path: src/relay/qnn/op/convolution.cc
 ##########
 @@ -417,23 +565,33 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
         param->kernel_layout == "HWOI")
       << "qnn.conv2d supports only OIHW/HWIO/HWOI kernel data layout.";
 
-  int batch_size, in_channels, out_channels, kernel_h, kernel_w;
-  std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w) =
+  int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier;
+  std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) =
       GetWorkload(arg_types, param);
 
-  // Fallback to int32 conv if there is dilation or depthwise conv2d
+  // Fallback to int32 conv if there is dilation or grouped conv2d
+
   CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
   auto dilation_h = get_const_int(param->dilation[0]);
   auto dilation_w = get_const_int(param->dilation[1]);
-  if (dilation_h != 1 || dilation_w != 1 || param->groups != 1) {
+  if (dilation_h != 1 || dilation_w != 1 || (param->groups != 1 && !is_depthwise(param))) {
     return Conv2DFallBack(data, weight, param);
+  } else if (is_depthwise(param)) {
+    CHECK_NE(channel_multiplier, -1);
+    auto padded_data = Conv2DPadInput(data, param);
+    auto term1 = Conv2DFirstTerm(padded_data, weight, param);
+    auto term2 =
+        DepthwiseConv2DSecondTerm(padded_data, param, kernel_h, kernel_w, channel_multiplier);
+    auto term3 = DepthwiseConv2DThirdTerm(weight, param, out_channels, channel_multiplier);
+    auto term4 = DepthwiseConv2DFourthTerm(param, kernel_h, kernel_w);
+    return Conv2DCombineTerms(term1, term2, term3, term4, param);
   }
 
   auto padded_data = Conv2DPadInput(data, param);
   auto term1 = Conv2DFirstTerm(padded_data, weight, param);
   auto term2 = Conv2DSecondTerm(padded_data, param, kernel_h, kernel_w, out_channels);
-  auto term3 = Conv2DThirdTerm(weight, param, batch_size, out_channels);
-  auto term4 = Conv2DFourthTerm(param, batch_size, in_channels, kernel_h, kernel_w);
+  auto term3 = Conv2DThirdTerm(weight, param, out_channels);
+  auto term4 = Conv2DFourthTerm(param, in_channels, kernel_h, kernel_w);
 
 Review comment:
  That is interesting. Why remove the batch semantics?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services