You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/10/10 01:11:31 UTC
[incubator-mxnet] branch master updated: Improve mkldnn fallback. (#12663)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 443ded4 Improve mkldnn fallback. (#12663)
443ded4 is described below
commit 443ded4f8ab455a4c4ec0b9d431564b8ccc785ea
Author: Zhennan Qin <zh...@intel.com>
AuthorDate: Wed Oct 10 09:11:18 2018 +0800
Improve mkldnn fallback. (#12663)
---
src/executor/attach_op_execs_pass.cc | 22 +++++++++++++++-------
src/operator/quantization/dequantize.cc | 1 +
.../mkldnn/mkldnn_quantized_pooling.cc | 1 +
src/operator/quantization/quantize.cc | 1 +
src/operator/quantization/requantize.cc | 1 +
5 files changed, 19 insertions(+), 7 deletions(-)
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 0e415ef..a0176fa 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -159,9 +159,13 @@ class StatefulComputeExExecutor : public OpExecutor {
op_ctx.run_ctx = rctx;
#if MXNET_USE_MKLDNN == 1
InvalidateOutputs(out_array, req);
- CreateDefaultInputs(in_array, &in_array_fallback);
- fcompute_(state_, op_ctx, in_array_fallback, req, out_array);
- return;
+ // TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented
+ const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
+ if (!is_mkldnn.get(attrs_.op, false)) {
+ CreateDefaultInputs(in_array, &in_array_fallback);
+ fcompute_(state_, op_ctx, in_array_fallback, req, out_array);
+ return;
+ }
#endif
fcompute_(state_, op_ctx, in_array, req, out_array);
}
@@ -180,12 +184,14 @@ class StatefulComputeExExecutor : public OpExecutor {
return state_;
}
- explicit StatefulComputeExExecutor(const OpStatePtr& state,
+ explicit StatefulComputeExExecutor(const NodeAttrs& attrs,
+ const OpStatePtr& state,
const FStatefulComputeEx& fcompute,
ExecType exec_type)
- : state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
+ : attrs_(attrs), state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
private:
+ NodeAttrs attrs_;
OpStatePtr state_;
FStatefulComputeEx fcompute_;
ExecType exec_type_;
@@ -302,7 +308,8 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
- ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
+ ret[i] = std::make_shared<StatefulComputeExExecutor>(inode.source->attrs, state,
+ fcompute_ex, exec_type);
} else {
FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
op, "FStatefulCompute", vctx[i]);
@@ -322,7 +329,8 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<StatefulComputeExExecutor>(
- ret[fwd_id].get()->state(), fcompute_ex, exec_type);
+ inode.source->attrs, ret[fwd_id].get()->state(), fcompute_ex,
+ exec_type);
} else {
FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
op, "FStatefulCompute", vctx[i]);
diff --git a/src/operator/quantization/dequantize.cc b/src/operator/quantization/dequantize.cc
index bbd7941..e20bc17 100644
--- a/src/operator/quantization/dequantize.cc
+++ b/src/operator/quantization/dequantize.cc
@@ -72,6 +72,7 @@ by keep zero centered for the quantized value:
.set_attr<nnvm::FInferType>("FInferType", DequantizeType)
.set_attr<FInferStorageType>("FInferStorageType", DequantizeStorageType)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNDequantizeCompute)
#endif
.set_attr<FCompute>("FCompute<cpu>", DequantizeCompute<cpu>)
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc
index b81881a..07e1441 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc
@@ -46,6 +46,7 @@ static void MKLDNNQuantizedPoolingForward(const nnvm::NodeAttrs& attrs, const Op
}
NNVM_REGISTER_OP(_contrib_quantized_pooling)
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizedPoolingForward);
} // namespace op
diff --git a/src/operator/quantization/quantize.cc b/src/operator/quantization/quantize.cc
index 25fb19d..5227751 100644
--- a/src/operator/quantization/quantize.cc
+++ b/src/operator/quantization/quantize.cc
@@ -83,6 +83,7 @@ where
.set_attr<nnvm::FInferType>("FInferType", QuantizeType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizeStorageType)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizeCompute)
#endif
.set_attr<FCompute>("FCompute<cpu>", QuantizeCompute<cpu>)
diff --git a/src/operator/quantization/requantize.cc b/src/operator/quantization/requantize.cc
index 5ce0ff0..68b1b65 100644
--- a/src/operator/quantization/requantize.cc
+++ b/src/operator/quantization/requantize.cc
@@ -65,6 +65,7 @@ inference accuracy.
.set_attr<nnvm::FInferType>("FInferType", RequantizeType)
.set_attr<FInferStorageType>("FInferStorageType", RequantizeStorageType)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNRequantizeForward)
#else
.set_attr<FCompute>("FCompute<cpu>", RequantizeForward<cpu>)