You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2019/06/01 23:40:48 UTC

[incubator-mxnet] branch master updated: Fix mkldnn backend when using naive engine (#15089)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cdb7b72  Fix mkldnn backend when using naive engine (#15089)
cdb7b72 is described below

commit cdb7b7272f0bf4d8f70fb64be17053ef4126cbf0
Author: Zhennan Qin <zh...@intel.com>
AuthorDate: Sat Jun 1 18:40:10 2019 -0500

    Fix mkldnn backend when using naive engine (#15089)
    
    * Fix mkldnn backend when using naive engine
    
    * Rerun CI
    
    * Rerun CI
    
    * Rerun CI
    
    * Add comment
---
 src/engine/naive_engine.cc                                         | 4 ++++
 src/operator/nn/mkldnn/mkldnn_convolution.cc                       | 5 +++--
 src/operator/nn/mkldnn/mkldnn_deconvolution.cc                     | 7 ++++---
 src/operator/nn/mkldnn/mkldnn_fully_connected.cc                   | 5 ++++-
 src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc          | 7 ++++---
 .../quantization/mkldnn/mkldnn_quantized_fully_connected.cc        | 5 ++++-
 6 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 93853c4..7291f46 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -142,6 +142,10 @@ class NaiveEngine final : public Engine {
       opr->opr_name);
   }
 
+/*!
+ * \brief NaiveEngine's PushAsync is intentionally synchronous.
+ * Users should not make any assumptions about execution order when using the async interface of any engine.
+ */
   void PushAsync(AsyncFn exec_fun,
                  Context exec_ctx,
                  std::vector<VarHandle> const& const_vars,
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index a394ede..6a91ae0 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -411,11 +411,12 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
     // For inference, we want to reorder the weight array so we don't need to
     // reorder data every time.
     if (weight.IsDefaultData()) {
-      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(),
-                              param.conv_param.num_group);
       // We also need to modify the layout on the original weight array. The
       // data conversion happens after the weight array is used.
       weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
+      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(),
+                              param.conv_param.num_group);
+
     } else {
       weight_mem = weight.GetMKLDNNData();
       CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 9d00cf8..557901c 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -262,10 +262,11 @@ void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
     // For inference, we want to reorder the weight array so we don't need to
     // reorder data every time.
     if (weight.IsDefaultData()) {
-      weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
-      // We also need to modify the layout on the original weight array. The
-      // data conversion happens after the weight array is used.
+      // We also need to modify the layout on the original weight array.
+      // Do not reorder the two statements below: the naive engine executes
+      // PushAsync synchronously.
       const_cast<NDArray&>(weight).MKLDNNDataReorderAsync(fwd_pd.weights_primitive_desc());
+      weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
     } else {
       weight_mem = weight.GetMKLDNNData();
       CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc());
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index 03d7e62..1dfd2a9 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -253,8 +253,11 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
     weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
   } else {
     if (weight.IsDefaultData()) {
-      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
+      // We also need to modify the layout on the original weight array.
+      // Do not reorder the two statements below: the naive engine executes
+      // PushAsync synchronously.
       weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
+      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
     } else {
       weight_mem = weight.GetMKLDNNData();
       CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
index 55028d8..f810717 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
@@ -52,10 +52,11 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
   // For inference, we want to reorder the weight array so we don't need to
   // reorder data every time.
   if (weight.IsDefaultData()) {
-    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group);
-    // We also need to modify the layout on the original weight array. The
-    // data conversion happens after the weight array is used.
+    // We also need to modify the layout on the original weight array.
+    // Do not reorder the two statements below: the naive engine executes
+    // PushAsync synchronously.
     weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
+    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group);
   } else {
     weight_mem = weight.GetMKLDNNData();
     CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index e8abab2..aca129a 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -93,8 +93,11 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs,
   const mkldnn::memory *weight_mem = nullptr;
 
   if (weight.IsDefaultData()) {
-    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1);
+    // We also need to modify the layout on the original weight array.
+    // Do not reorder the two statements below: the naive engine executes
+    // PushAsync synchronously.
     weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
+    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1);
   } else {
     weight_mem = weight.GetMKLDNNData();
     CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());