You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by an...@apache.org on 2018/08/30 22:07:43 UTC
[incubator-mxnet] branch master updated: [MXNET-753] Fallback when
using non-MKLDNN supported operators (#12019)
This is an automated email from the ASF dual-hosted git repository.
anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 32c9ca7 [MXNET-753] Fallback when using non-MKLDNN supported operators (#12019)
32c9ca7 is described below
commit 32c9ca74839ae4d275bcf9a027ea0a711373be81
Author: Alexander Zai <az...@gmail.com>
AuthorDate: Thu Aug 30 15:07:32 2018 -0700
[MXNET-753] Fallback when using non-MKLDNN supported operators (#12019)
* add fallback test
* wait to read throws error
* add TIsMKLDNN attr
* invalidate inputs if fcomputeex unsupported
* keep ptr to newly created default arrays
* add flag to all mkldnn operators
* update method name to CreateDefaultInputs
* remove dup attrs
* create new instance var to store copy
* only reorder if mkldnn
* add mkldnn flag to batch norm
* do not check input / output ptr for mkldnn as a copy is made
* fix lint
* update macro
* update custom update name
* add todo for fallback
* fix lint
* rename opexecutor name
* add fallback to opexecutor class
* fix lint
* add todos
* create fallback arrays in place
* revert in place diff
* create copy of arrays for fallback
* empty array
---
src/executor/attach_op_execs_pass.cc | 10 ++++++
src/executor/exec_pass.h | 4 +++
src/operator/nn/activation.cc | 2 ++
src/operator/nn/batch_norm.cc | 2 ++
src/operator/nn/concat.cc | 2 ++
src/operator/nn/convolution.cc | 2 ++
src/operator/nn/deconvolution.cc | 2 ++
src/operator/nn/fully_connected.cc | 2 ++
src/operator/nn/lrn.cc | 2 ++
src/operator/nn/mkldnn/mkldnn_base-inl.h | 12 +++++++
src/operator/nn/pooling.cc | 2 ++
src/operator/nn/softmax.cc | 1 +
src/operator/tensor/elemwise_sum.cc | 3 ++
src/operator/tensor/elemwise_unary_op.h | 4 +++
src/operator/tensor/elemwise_unary_op_basic.cc | 2 ++
tests/python/mkl/test_mkldnn.py | 44 ++++++++++++++++++++++++++
16 files changed, 96 insertions(+)
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index c011c1d..0e415ef 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -159,6 +159,9 @@ class StatefulComputeExExecutor : public OpExecutor {
op_ctx.run_ctx = rctx;
#if MXNET_USE_MKLDNN == 1
InvalidateOutputs(out_array, req);
+ CreateDefaultInputs(in_array, &in_array_fallback);
+ fcompute_(state_, op_ctx, in_array_fallback, req, out_array);
+ return;
#endif
fcompute_(state_, op_ctx, in_array, req, out_array);
}
@@ -226,6 +229,13 @@ class FComputeExExecutor : public OpExecutor {
op_ctx.run_ctx = rctx;
#if MXNET_USE_MKLDNN == 1
InvalidateOutputs(out_array, req);
+ // TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented
+ static const auto& is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
+ if (!is_mkldnn.get(attrs_.op, false)) {
+ CreateDefaultInputs(in_array, &in_array_fallback);
+ fcompute_(attrs_, op_ctx, in_array_fallback, req, out_array);
+ return;
+ }
#endif
fcompute_(attrs_, op_ctx, in_array, req, out_array);
}
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index cd1db0a..52f7c79 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -86,6 +86,10 @@ class OpExecutor {
virtual OpStatePtr state() const {
return OpStatePtr();
}
+
+ // TODO(alexzai): (MXNET-856) Remove instance member after subgraph feature added
+ protected:
+ std::vector<NDArray> in_array_fallback;  // default-layout copies of in_array (MKLDNN fallback)
};
/*!
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index b8c2045..ba44ebd 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -155,6 +155,7 @@ The following activation functions are supported:
})
.set_attr<FCompute>("FCompute<cpu>", ActivationCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ActivationComputeExCPU)
#endif
.set_attr<nnvm::FGradient>("FGradient", ActivationGrad{"_backward_Activation"})
@@ -184,6 +185,7 @@ NNVM_REGISTER_OP(_backward_Activation)
#endif
.set_attr_parser(ParamParser<ActivationParam>)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ActivationGradComputeExCPU)
#endif
.set_attr<FCompute>("FCompute<cpu>", ActivationGradCompute<cpu>);
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index b15f84e..4ea494d 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -601,6 +601,7 @@ the sparse tensors will fallback.
#endif
.set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
@@ -633,6 +634,7 @@ NNVM_REGISTER_OP(_backward_BatchNorm)
#endif
.set_attr_parser(ParamParser<BatchNormParam>)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormGradComputeExCPU)
#endif
.set_attr<FCompute>("FCompute<cpu>", BatchNormGradCompute<cpu>);
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 9df459e..ac8a814 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -367,6 +367,7 @@ Example::
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
+.set_attr<bool>("TIsMKLDNN", true)
#endif
CONCAT_FORWARD_ATTRS
.set_attr<nnvm::FInferShape>("FInferShape", ConcatShape)
@@ -387,6 +388,7 @@ NNVM_REGISTER_OP(_backward_Concat)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BackwardConcatStorageType)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatGradComputeExCPU)
#endif
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 8f25cf0..d5abe62 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -484,6 +484,7 @@ There are other options to tune the performance.
#endif
.set_attr<FCompute>("FCompute<cpu>", ConvolutionCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConvolutionComputeExCPU)
#endif
.set_attr<nnvm::FGradient>("FGradient", ConvolutionGrad{"_backward_Convolution"})
@@ -509,6 +510,7 @@ NNVM_REGISTER_OP(_backward_Convolution)
})
.set_attr_parser(ConvolutionParamParser)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConvolutionGradComputeExCPU)
#endif
.set_attr<FCompute>("FCompute<cpu>", ConvolutionGradCompute<cpu>);
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index a4be1a0..1ab391d 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -413,6 +413,7 @@ NNVM_REGISTER_OP(Deconvolution)
})
.set_attr<FCompute>("FCompute<cpu>", DeconvolutionCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", DeconvolutionComputeExCPU)
#endif
.set_attr<nnvm::FGradient>("FGradient", DeconvolutionGrad{"_backward_Deconvolution"})
@@ -436,6 +437,7 @@ NNVM_REGISTER_OP(_backward_Deconvolution)
})
.set_attr_parser(DeconvolutionParamParser)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", DeconvolutionGradComputeExCPU)
#endif
.set_attr<FCompute>("FCompute<cpu>", DeconvolutionGradCompute<cpu>);
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index eb881d2..d8a32f0 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -290,6 +290,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
return std::vector<std::string>{"output"};
})
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
@@ -322,6 +323,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
.set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
.set_attr_parser(ParamParser<FullyConnectedParam>)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedGradComputeExCPU)
#endif
.set_attr<FCompute>("FCompute<cpu>", FullyConnectedGradCompute<cpu>);
diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc
index 587cf93..a428eb1 100644
--- a/src/operator/nn/lrn.cc
+++ b/src/operator/nn/lrn.cc
@@ -180,6 +180,7 @@ number of kernels in the layer.
})
.set_attr<FCompute>("FCompute<cpu>", LRNCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", LRNComputeExCPU)
#endif
.set_attr<nnvm::FGradient>("FGradient", LRNGrad{"_backward_LRN"})
@@ -194,6 +195,7 @@ NNVM_REGISTER_OP(_backward_LRN)
#endif
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", LRNGradComputeExCPU)
// Native compute requires norm while MKLDNN does not so cannot be compared in debug mode
.set_attr<bool>("TExcludeMKLDNNDebug", true)
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 273afcd..6eb90f8 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -356,6 +356,18 @@ static inline void InvalidateOutputs(const std::vector<NDArray> &arrs,
}
}
+// TODO(alexzai): (MXNET-856) Remove helper function after subgraph feature added
+// Rebuilds *out_arrs from arrs: entries holding MKLDNN data are replaced by their
+// Reorder2Default() result, every other entry is copied over unchanged.
+static inline void CreateDefaultInputs(const std::vector<NDArray> &arrs,
+ std::vector<NDArray> *out_arrs) {
+ out_arrs->clear();
+ for (const NDArray &input : arrs) {
+ const bool needs_reorder = input.IsMKLDNNData();
+ out_arrs->push_back(needs_reorder ? input.Reorder2Default() : NDArray(input));
+ }
+}
+
const mkldnn::memory *GetWeights(const NDArray &arr,
const mkldnn::memory::primitive_desc &target_pd,
int num_groups);
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index 2d11814..c133b63 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -395,6 +395,7 @@ For each window ``X``, the mathematical expression for Lp pooling is:
.set_attr<nnvm::FInferShape>("FInferShape", PoolingShape)
.set_attr<FCompute>("FCompute<cpu>", PoolingCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", PoolingComputeExCPU)
#endif
.set_attr<nnvm::FGradient>("FGradient",
@@ -424,6 +425,7 @@ NNVM_REGISTER_OP(_backward_Pooling)
#endif
.set_attr_parser(PoolingParamParser)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", PoolingGradComputeExCPU)
#endif
.set_attr<FCompute>("FCompute<cpu>", PoolingGradCompute<cpu>);
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 88b7b5f..81e775c 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -98,6 +98,7 @@ Example::
})
.set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::softmax_fwd>)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
#endif
diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc
index 9630988..1666537 100644
--- a/src/operator/tensor/elemwise_sum.cc
+++ b/src/operator/tensor/elemwise_sum.cc
@@ -179,6 +179,9 @@ The storage type of ``add_n`` output depends on storage types of inputs
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+#endif
.set_attr<nnvm::FInferShape>("FInferShape", ElementWiseSumShape)
.set_attr<nnvm::FInferType>("FInferType", ElementWiseSumType)
.set_attr<FInferStorageType>("FInferStorageType", ElementWiseSumForwardInferStorageType)
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index e09a6cc..eb070a4 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -299,7 +299,11 @@ class UnaryOp : public OpBase {
}
break;
case kWriteInplace:
+// Not checked with MKLDNN: the input may be a default-layout copy made when reordering,
+// so its dptr_ can differ from the output's; the write still goes to the original array.
+#if MXNET_USE_MKLDNN == 0
CHECK_EQ(inputs[0].dptr_, outputs[0].dptr_);
+#endif
break;
case kNullOp:
break;
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index f7f21f9..c3e9c2d 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -206,6 +206,7 @@ MXNET_OPERATOR_REGISTER_UNARY(_copy)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
+.set_attr<bool>("TIsMKLDNN", true)
#endif
.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
[](const NodeAttrs& attrs){
@@ -225,6 +226,7 @@ NNVM_REGISTER_OP(_backward_copy)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", CopyEx)
#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index ba4cf3f..e597d0f5 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -381,6 +381,50 @@ def test_fullyconnected():
for stype in stypes:
check_fullyconnected_training(stype)
+@with_seed()
+def test_non_mkldnn_fcomputeex():
+    # test special case where MKLDNN formatted NDArray feeds into non-mkldnn fcomputeex operator
+    # conv is an example where MKLDNN NDArray is created from regular NDArrays
+    # CustomOp is an example of a non-mkldnn fcomputeex operator
+
+    @mx.operator.register("custom")
+    class CustomProp(mx.operator.CustomOpProp):
+        def __init__(self):
+            # need_top_grad=True: backward() copies out_grad into in_grad
+            super(CustomProp, self).__init__(need_top_grad=True)
+
+        def list_arguments(self):
+            return ['data']
+
+        def list_outputs(self):
+            return ['output']
+
+        def infer_shape(self, in_shape):
+            data_shape = in_shape[0]
+            output_shape = in_shape[0]
+            return [data_shape], [output_shape], []
+
+        def infer_type(self, in_type):
+            dtype = in_type[0]
+            return [dtype], [dtype], []
+
+        def create_operator(self, ctx, shapes, dtypes):
+            return Custom()
+
+    class Custom(mx.operator.CustomOp):
+        def forward(self, is_train, req, in_data, out_data, aux):
+            in_data[0].asnumpy()  # force a read of the input produced by the MKLDNN conv
+            self.assign(out_data[0], req[0], in_data[0])
+
+        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+            self.assign(in_grad[0], req[0], out_grad[0])
+
+    data = mx.symbol.Variable('data')
+    conv = mx.sym.Convolution(data=data, kernel=(5, 5), pad=(1, 1), stride=(1,1), num_filter=8, name="conv", no_bias=True)
+    custom = mx.symbol.Custom(name='custom', data=conv, op_type='custom')
+    exec1 = custom.bind(mx.cpu(), args={'data': mx.nd.ones([10,3,96,96]), 'conv_weight': mx.nd.ones([8,3,5,5])})
+    exec1.forward()[0].wait_to_read()
+
if __name__ == '__main__':
install.test_mkldnn_install()