You are viewing a plain-text version of this content; the link to the canonical version was lost in the plain-text conversion.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/30 22:07:33 UTC

[GitHub] anirudh2290 closed pull request #12019: [MXNET-753] Fallback when using non-MKLDNN supported operators

anirudh2290 closed pull request #12019: [MXNET-753] Fallback when using non-MKLDNN supported operators
URL: https://github.com/apache/incubator-mxnet/pull/12019
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this pull request comes from a fork, the diff is reproduced
below (GitHub would not otherwise display it):

diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index c011c1d9ce0..0e415ef5112 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -159,6 +159,9 @@ class StatefulComputeExExecutor : public OpExecutor {
     op_ctx.run_ctx = rctx;
 #if MXNET_USE_MKLDNN == 1
     InvalidateOutputs(out_array, req);
+    CreateDefaultInputs(in_array, &in_array_fallback);
+    fcompute_(state_, op_ctx, in_array_fallback, req, out_array);
+    return;
 #endif
     fcompute_(state_, op_ctx, in_array, req, out_array);
   }
@@ -226,6 +229,13 @@ class FComputeExExecutor : public OpExecutor {
     op_ctx.run_ctx = rctx;
 #if MXNET_USE_MKLDNN == 1
     InvalidateOutputs(out_array, req);
+    // TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented
+    const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
+    if (!is_mkldnn.get(attrs_.op, false)) {
+      CreateDefaultInputs(in_array, &in_array_fallback);
+      fcompute_(attrs_, op_ctx, in_array_fallback, req, out_array);
+      return;
+    }
 #endif
     fcompute_(attrs_, op_ctx, in_array, req, out_array);
   }
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index cd1db0ac194..52f7c790c77 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -86,6 +86,10 @@ class OpExecutor {
   virtual OpStatePtr state() const {
     return OpStatePtr();
   }
+
+  // TODO(alexzai): (MXNET-856) Remove instance member after subgraph feature added
+ protected:
+  std::vector<NDArray> in_array_fallback;
 };
 
 /*!
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index b8c2045fba1..ba44ebd4ed4 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -155,6 +155,7 @@ The following activation functions are supported:
 })
 .set_attr<FCompute>("FCompute<cpu>", ActivationCompute<cpu>)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ActivationComputeExCPU)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", ActivationGrad{"_backward_Activation"})
@@ -184,6 +185,7 @@ NNVM_REGISTER_OP(_backward_Activation)
 #endif
 .set_attr_parser(ParamParser<ActivationParam>)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ActivationGradComputeExCPU)
 #endif
 .set_attr<FCompute>("FCompute<cpu>", ActivationGradCompute<cpu>);
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index b15f84e107e..4ea494d64e4 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -601,6 +601,7 @@ the sparse tensors will fallback.
 #endif
 .set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
@@ -633,6 +634,7 @@ NNVM_REGISTER_OP(_backward_BatchNorm)
 #endif
 .set_attr_parser(ParamParser<BatchNormParam>)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormGradComputeExCPU)
 #endif
 .set_attr<FCompute>("FCompute<cpu>", BatchNormGradCompute<cpu>);
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 9df459e9224..ac8a814ce70 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -367,6 +367,7 @@ Example::
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
+.set_attr<bool>("TIsMKLDNN", true)
 #endif
 CONCAT_FORWARD_ATTRS
 .set_attr<nnvm::FInferShape>("FInferShape", ConcatShape)
@@ -387,6 +388,7 @@ NNVM_REGISTER_OP(_backward_Concat)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FInferStorageType>("FInferStorageType", BackwardConcatStorageType)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ConcatGradComputeExCPU)
 #endif
 .set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 8f25cf0dcbb..d5abe629123 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -484,6 +484,7 @@ There are other options to tune the performance.
 #endif
 .set_attr<FCompute>("FCompute<cpu>", ConvolutionCompute<cpu>)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ConvolutionComputeExCPU)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", ConvolutionGrad{"_backward_Convolution"})
@@ -509,6 +510,7 @@ NNVM_REGISTER_OP(_backward_Convolution)
 })
 .set_attr_parser(ConvolutionParamParser)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ConvolutionGradComputeExCPU)
 #endif
 .set_attr<FCompute>("FCompute<cpu>", ConvolutionGradCompute<cpu>);
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index a4be1a0c56a..1ab391d92b0 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -413,6 +413,7 @@ NNVM_REGISTER_OP(Deconvolution)
 })
 .set_attr<FCompute>("FCompute<cpu>", DeconvolutionCompute<cpu>)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", DeconvolutionComputeExCPU)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", DeconvolutionGrad{"_backward_Deconvolution"})
@@ -436,6 +437,7 @@ NNVM_REGISTER_OP(_backward_Deconvolution)
 })
 .set_attr_parser(DeconvolutionParamParser)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", DeconvolutionGradComputeExCPU)
 #endif
 .set_attr<FCompute>("FCompute<cpu>", DeconvolutionGradCompute<cpu>);
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index eb881d29abd..d8a32f0ae96 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -290,6 +290,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
     return std::vector<std::string>{"output"};
 })
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
@@ -322,6 +323,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
 .set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
 .set_attr_parser(ParamParser<FullyConnectedParam>)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedGradComputeExCPU)
 #endif
 .set_attr<FCompute>("FCompute<cpu>", FullyConnectedGradCompute<cpu>);
diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc
index 587cf930920..a428eb1e4fa 100644
--- a/src/operator/nn/lrn.cc
+++ b/src/operator/nn/lrn.cc
@@ -180,6 +180,7 @@ number of kernels in the layer.
 })
 .set_attr<FCompute>("FCompute<cpu>", LRNCompute<cpu>)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", LRNComputeExCPU)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", LRNGrad{"_backward_LRN"})
@@ -194,6 +195,7 @@ NNVM_REGISTER_OP(_backward_LRN)
 #endif
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", LRNGradComputeExCPU)
 // Native compute requires norm while MKLDNN does not so cannot be compared in debug mode
 .set_attr<bool>("TExcludeMKLDNNDebug", true)
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 273afcd32dc..6eb90f845d3 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -356,6 +356,18 @@ static inline void InvalidateOutputs(const std::vector<NDArray> &arrs,
   }
 }
 
+// TODO(alexzai): (MXNET-856) Remove helper function after subgraph feature added
+// Fills *out_arrs with arrs, reordering MKLDNN-layout arrays to the default layout.
+static inline void CreateDefaultInputs(const std::vector<NDArray> &arrs,
+                                       std::vector<NDArray> *out_arrs) {
+  out_arrs->clear();
+  out_arrs->reserve(arrs.size());
+  for (const NDArray &arr : arrs) {
+    // Only MKLDNN-formatted data needs a reorder; other arrays pass through.
+    out_arrs->push_back(arr.IsMKLDNNData() ? arr.Reorder2Default() : arr);
+  }
+}
+
 const mkldnn::memory *GetWeights(const NDArray &arr,
                                  const mkldnn::memory::primitive_desc &target_pd,
                                  int num_groups);
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index 2d118142bc7..c133b63623a 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -395,6 +395,7 @@ For each window ``X``, the mathematical expression for Lp pooling is:
 .set_attr<nnvm::FInferShape>("FInferShape", PoolingShape)
 .set_attr<FCompute>("FCompute<cpu>", PoolingCompute<cpu>)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", PoolingComputeExCPU)
 #endif
 .set_attr<nnvm::FGradient>("FGradient",
@@ -424,6 +425,7 @@ NNVM_REGISTER_OP(_backward_Pooling)
 #endif
 .set_attr_parser(PoolingParamParser)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", PoolingGradComputeExCPU)
 #endif
 .set_attr<FCompute>("FCompute<cpu>", PoolingGradCompute<cpu>);
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 88b7b5fc473..81e775cac52 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -98,6 +98,7 @@ Example::
 })
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::softmax_fwd>)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxComputeExCPU)
 .set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
 #endif
diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc
index 9630988165c..1666537e286 100644
--- a/src/operator/tensor/elemwise_sum.cc
+++ b/src/operator/tensor/elemwise_sum.cc
@@ -179,6 +179,9 @@ The storage type of ``add_n`` output depends on storage types of inputs
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+#endif
 .set_attr<nnvm::FInferShape>("FInferShape", ElementWiseSumShape)
 .set_attr<nnvm::FInferType>("FInferType", ElementWiseSumType)
 .set_attr<FInferStorageType>("FInferStorageType", ElementWiseSumForwardInferStorageType)
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index e09a6cccddb..eb070a41127 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -299,7 +299,11 @@ class UnaryOp : public OpBase {
         }
         break;
       case kWriteInplace:
+// cannot check if ptrs are the same for MKLDNN because we may have
+// created copies of input when reordering. WriteInPlace will still write to original array
+#if MXNET_USE_MKLDNN == 0
         CHECK_EQ(inputs[0].dptr_, outputs[0].dptr_);
+#endif
         break;
       case kNullOp:
         break;
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index f7f21f9076a..c3e9c2dc91d 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -206,6 +206,7 @@ MXNET_OPERATOR_REGISTER_UNARY(_copy)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
+.set_attr<bool>("TIsMKLDNN", true)
 #endif
 .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
   [](const NodeAttrs& attrs){
@@ -225,6 +226,7 @@ NNVM_REGISTER_OP(_backward_copy)
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", CopyEx)
 #if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index ba4cf3f0116..e597d0f5fc5 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -381,6 +381,50 @@ def check_fullyconnected_training(stype):
     for stype in stypes:
         check_fullyconnected_training(stype)
 
+@with_seed()
+def test_non_mkldnn_fcomputeex():
+    # test special case where an MKLDNN-formatted NDArray feeds into a non-MKLDNN FComputeEx operator
+    # conv is an example where an MKLDNN NDArray is created from regular NDArrays
+    # CustomOp is an example of a non-MKLDNN FComputeEx operator
+
+    @mx.operator.register("custom")
+    class CustomProp(mx.operator.CustomOpProp):
+        def __init__(self):
+            super(CustomProp, self).__init__(need_top_grad=True)  # backward reads out_grad
+
+        def list_arguments(self):
+            return ['data']
+
+        def list_outputs(self):
+            return ['output']
+
+        def infer_shape(self, in_shape):
+            data_shape = in_shape[0]
+            output_shape = in_shape[0]
+            return [data_shape], [output_shape], []
+
+        def infer_type(self, in_type):
+            dtype = in_type[0]
+            return [dtype], [dtype], []
+
+        def create_operator(self, ctx, shapes, dtypes):
+            return Custom()
+
+
+    class Custom(mx.operator.CustomOp):
+        def forward(self, is_train, req, in_data, out_data, aux):
+            # identity op: copy the input straight to the output
+            self.assign(out_data[0], req[0], in_data[0])
+
+        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+            self.assign(in_grad[0], req[0], out_grad[0])
+
+    data = mx.symbol.Variable('data')
+    conv = mx.sym.Convolution(data=data, kernel=(5, 5), pad=(1, 1), stride=(1,1), num_filter=8, name="conv", no_bias=True)
+    custom = mx.symbol.Custom(name='custom', data=conv, op_type='custom')
+    exec1 = custom.bind(mx.cpu(), args={'data': mx.nd.ones([10,3,96,96]), 'conv_weight': mx.nd.ones([8,3,5,5])})
+    exec1.forward()[0].wait_to_read()
+
 
 if __name__ == '__main__':
     install.test_mkldnn_install()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use
the URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services