You are viewing a plain-text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/10/10 01:11:31 UTC

[incubator-mxnet] branch master updated: Improve mkldnn fallback. (#12663)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 443ded4  Improve mkldnn fallback. (#12663)
443ded4 is described below

commit 443ded4f8ab455a4c4ec0b9d431564b8ccc785ea
Author: Zhennan Qin <zh...@intel.com>
AuthorDate: Wed Oct 10 09:11:18 2018 +0800

    Improve mkldnn fallback. (#12663)
---
 src/executor/attach_op_execs_pass.cc               | 22 +++++++++++++++-------
 src/operator/quantization/dequantize.cc            |  1 +
 .../mkldnn/mkldnn_quantized_pooling.cc             |  1 +
 src/operator/quantization/quantize.cc              |  1 +
 src/operator/quantization/requantize.cc            |  1 +
 5 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 0e415ef..a0176fa 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -159,9 +159,13 @@ class StatefulComputeExExecutor : public OpExecutor {
     op_ctx.run_ctx = rctx;
 #if MXNET_USE_MKLDNN == 1
     InvalidateOutputs(out_array, req);
-    CreateDefaultInputs(in_array, &in_array_fallback);
-    fcompute_(state_, op_ctx, in_array_fallback, req, out_array);
-    return;
+    // TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented
+    const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
+    if (!is_mkldnn.get(attrs_.op, false)) {
+      CreateDefaultInputs(in_array, &in_array_fallback);
+      fcompute_(state_, op_ctx, in_array_fallback, req, out_array);
+      return;
+    }
 #endif
     fcompute_(state_, op_ctx, in_array, req, out_array);
   }
@@ -180,12 +184,14 @@ class StatefulComputeExExecutor : public OpExecutor {
     return state_;
   }
 
-  explicit StatefulComputeExExecutor(const OpStatePtr& state,
+  explicit StatefulComputeExExecutor(const NodeAttrs& attrs,
+                                     const OpStatePtr& state,
                                      const FStatefulComputeEx& fcompute,
                                      ExecType exec_type)
-      : state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
+      : attrs_(attrs), state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
 
  private:
+  NodeAttrs attrs_;
   OpStatePtr state_;
   FStatefulComputeEx fcompute_;
   ExecType exec_type_;
@@ -302,7 +308,8 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
         op, "FStatefulComputeEx", vctx[i]);
     // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
     if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
-      ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
+      ret[i] = std::make_shared<StatefulComputeExExecutor>(inode.source->attrs, state,
+                                                           fcompute_ex, exec_type);
     } else {
       FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
           op, "FStatefulCompute", vctx[i]);
@@ -322,7 +329,8 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
     // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
     if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
       ret[i] = std::make_shared<StatefulComputeExExecutor>(
-          ret[fwd_id].get()->state(), fcompute_ex, exec_type);
+          inode.source->attrs, ret[fwd_id].get()->state(), fcompute_ex,
+          exec_type);
     } else {
       FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
           op, "FStatefulCompute", vctx[i]);
diff --git a/src/operator/quantization/dequantize.cc b/src/operator/quantization/dequantize.cc
index bbd7941..e20bc17 100644
--- a/src/operator/quantization/dequantize.cc
+++ b/src/operator/quantization/dequantize.cc
@@ -72,6 +72,7 @@ by keep zero centered for the quantized value:
 .set_attr<nnvm::FInferType>("FInferType", DequantizeType)
 .set_attr<FInferStorageType>("FInferStorageType", DequantizeStorageType)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNDequantizeCompute)
 #endif
 .set_attr<FCompute>("FCompute<cpu>", DequantizeCompute<cpu>)
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc
index b81881a..07e1441 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc
@@ -46,6 +46,7 @@ static void MKLDNNQuantizedPoolingForward(const nnvm::NodeAttrs& attrs, const Op
 }
 
 NNVM_REGISTER_OP(_contrib_quantized_pooling)
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizedPoolingForward);
 
 }  // namespace op
diff --git a/src/operator/quantization/quantize.cc b/src/operator/quantization/quantize.cc
index 25fb19d..5227751 100644
--- a/src/operator/quantization/quantize.cc
+++ b/src/operator/quantization/quantize.cc
@@ -83,6 +83,7 @@ where
 .set_attr<nnvm::FInferType>("FInferType", QuantizeType)
 .set_attr<FInferStorageType>("FInferStorageType", QuantizeStorageType)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizeCompute)
 #endif
 .set_attr<FCompute>("FCompute<cpu>", QuantizeCompute<cpu>)
diff --git a/src/operator/quantization/requantize.cc b/src/operator/quantization/requantize.cc
index 5ce0ff0..68b1b65 100644
--- a/src/operator/quantization/requantize.cc
+++ b/src/operator/quantization/requantize.cc
@@ -65,6 +65,7 @@ inference accuracy.
 .set_attr<nnvm::FInferType>("FInferType", RequantizeType)
 .set_attr<FInferStorageType>("FInferStorageType", RequantizeStorageType)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNRequantizeForward)
 #else
 .set_attr<FCompute>("FCompute<cpu>", RequantizeForward<cpu>)