Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/31 06:53:08 UTC

[GitHub] zheng-da commented on a change in pull request #11778: [MXNET-483] C++ tests for mkldnn convolution/deconvolution operator

URL: https://github.com/apache/incubator-mxnet/pull/11778#discussion_r214258145
 
 

 ##########
 File path: src/operator/nn/mkldnn/mkldnn_convolution.cc
 ##########
 @@ -290,18 +300,32 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
   TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]);
   const std::vector<NDArray> &in_grad = outputs;
   const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
+
+  auto data = inputs[conv::kData + 1];
+  if (data.IsView() && data.IsMKLDNNData())
+    data = data.Reorder2Default();
+
+  auto weight = inputs[conv::kWeight + 1];
+  if (weight.IsView() && weight.IsMKLDNNData())
+    weight = weight.Reorder2Default();
+
+  const NDArray* bias = param.no_bias ? nullptr : &inputs[conv::kBias + 1];
+
+  auto out_grad = inputs[conv::kOut];
+  if (out_grad.IsView() && out_grad.IsMKLDNNData())
+    out_grad = out_grad.Reorder2Default();
+
   mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl(param, ctx.is_train,
-      inputs[conv::kData + 1], inputs[conv::kWeight + 1],
-      param.no_bias ? nullptr : &inputs[conv::kBias + 1], inputs[conv::kOut]);
+      data, weight, param.no_bias ? nullptr : bias, out_grad);
 
 Review comment:
  Why do you still need to test param.no_bias?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services