You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ta...@apache.org on 2019/10/11 07:08:18 UTC

[incubator-mxnet] branch mkldnn-v1.0 updated: [mkldnn-v1.0] Add MKL-DNN reshape&flatten&expand_dims (#16258)

This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch mkldnn-v1.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/mkldnn-v1.0 by this push:
     new 922b616  [mkldnn-v1.0] Add MKL-DNN reshape&flatten&expand_dims (#16258)
922b616 is described below

commit 922b6162e8b62af98fe211d1e24f1aa10716a0e6
Author: Wuxun Zhang <wu...@intel.com>
AuthorDate: Fri Oct 11 15:07:24 2019 +0800

    [mkldnn-v1.0] Add MKL-DNN reshape&flatten&expand_dims (#16258)
    
    * Add mkldnn 1.0 support for reshape/flatten/expanddims ops
    
    * improve log & modify definition location of args_map_
    
    * fix comments
    
    * rebase code
    
    * trigger CI
    
    * trigger CI
    
    * trigger CI
    
    * trigger CI
---
 src/operator/nn/mkldnn/mkldnn_base-inl.h           |   2 +-
 .../{mkldnn_flatten.cc => mkldnn_expand_dims.cc}   | 149 ++++++++++-----------
 src/operator/nn/mkldnn/mkldnn_flatten-inl.h        |   2 +-
 src/operator/nn/mkldnn/mkldnn_flatten.cc           |   6 +-
 src/operator/nn/mkldnn/mkldnn_ops-inl.h            |  28 ++--
 src/operator/nn/mkldnn/mkldnn_reshape-inl.h        |  33 ++++-
 src/operator/nn/mkldnn/mkldnn_reshape.cc           | 124 +++++++----------
 src/operator/tensor/matrix_op-inl.h                |  14 ++
 src/operator/tensor/matrix_op.cc                   |  54 ++++++--
 9 files changed, 225 insertions(+), 187 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index e4c4b98..c93cdb4 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -189,7 +189,7 @@ bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input)
 bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
 bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
-bool SupportMKLDNNReshape(const ReshapeParam &param, const NDArray &data);
+bool SupportMKLDNNReshape(const NDArray &in_data, const NDArray &out_data);
 }  // namespace op
 
 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_flatten.cc b/src/operator/nn/mkldnn/mkldnn_expand_dims.cc
similarity index 50%
copy from src/operator/nn/mkldnn/mkldnn_flatten.cc
copy to src/operator/nn/mkldnn/mkldnn_expand_dims.cc
index 4090eb0..dcd85f1 100644
--- a/src/operator/nn/mkldnn/mkldnn_flatten.cc
+++ b/src/operator/nn/mkldnn/mkldnn_expand_dims.cc
@@ -1,79 +1,70 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file mkldnn_flatten.cc
- * \brief Implement flatten operator by using mkldnn reorder primitive
- * \author Wuxun Zhang
-*/
-
-#if MXNET_USE_MKLDNN == 1
-
-#include "mkldnn_flatten-inl.h"
-
-namespace mxnet {
-namespace op {
-
-static MKLDNNFlattenFwd &GetFlattenForward(const OpReqType &req,
-                                           const NDArray &input,
-                                           const NDArray &output) {
-#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<OpSignature,
-                                         MKLDNNFlattenFwd, OpHash> fwds;
-#else
-  static MX_THREAD_LOCAL std::unordered_map<OpSignature,
-                                            MKLDNNFlattenFwd, OpHash> fwds;
-#endif
-  OpSignature key;
-  key.AddSign(req);
-  key.AddSign(input);
-
-  auto it = fwds.find(key);
-  if (it == fwds.end()) {
-    MKLDNNFlattenFwd fwd(req, input, output);
-    it = AddToCache(&fwds, key, fwd);
-  }
-  return it->second;
-}
-
-void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
-                          const OpContext &ctx,
-                          const NDArray &input,
-                          const OpReqType &req,
-                          const NDArray &output) {
-  if (req == kNullOp) return;
-  CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
-
-  auto fwd = GetFlattenForward(req, input, output);
-  auto ws_size = fwd.GetWorkspaceSize();
-  void* ws_ptr = nullptr;
-  if (ws_size) {
-    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
-    mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
-      .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
-    ws_ptr = reinterpret_cast<void*>(ws.dptr_);
-  }
-
-  fwd.Execute(input, output, ws_ptr);
-}
-
-}  // namespace op
-}  // namespace mxnet
-
-#endif
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_expand_dims.cc
+ * \brief Implement expand_dims operator via MKL-DNN reorder primitive
+ * \author Wuxun Zhang
+*/
+
+#if MXNET_USE_MKLDNN == 100
+
+#include "mkldnn_reshape-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class MKLDNNExpandDimsFwd : public MKLDNNReshapeFwd {
+ public:
+  explicit MKLDNNExpandDimsFwd(const OpReqType &req,
+                               const NDArray &input,
+                               const NDArray &output)
+    : MKLDNNReshapeFwd(req, input, output) {}
+};
+
+typedef ParamOpSign<ExpandDimParam> MKLDNNExpandDimsSignature;
+
+void MKLDNNExpandDimsForward(const nnvm::NodeAttrs &attrs,
+                             const OpContext &ctx,
+                             const NDArray &input,
+                             const OpReqType &req,
+                             const NDArray &output) {
+  const ExpandDimParam& param = nnvm::get<ExpandDimParam>(attrs.parsed);
+  if (req == kNullOp) return;
+  CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
+
+  auto fwd = GetCachedForward<MKLDNNExpandDimsFwd, ExpandDimParam,
+                              MKLDNNExpandDimsSignature>(param, req, input, output);
+
+  auto ws_size = fwd.GetWorkspaceSize();
+  void* ws_ptr = nullptr;
+  if (ws_size) {
+    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
+      .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
+    ws_ptr = reinterpret_cast<void*>(ws.dptr_);
+  }
+
+  fwd.Execute(input, output, req, ws_ptr);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif
diff --git a/src/operator/nn/mkldnn/mkldnn_flatten-inl.h b/src/operator/nn/mkldnn/mkldnn_flatten-inl.h
index ae890d8..89e52cc 100644
--- a/src/operator/nn/mkldnn/mkldnn_flatten-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_flatten-inl.h
@@ -25,7 +25,7 @@
 
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include "mkldnn_reshape-inl.h"
 
diff --git a/src/operator/nn/mkldnn/mkldnn_flatten.cc b/src/operator/nn/mkldnn/mkldnn_flatten.cc
index 4090eb0..4058399 100644
--- a/src/operator/nn/mkldnn/mkldnn_flatten.cc
+++ b/src/operator/nn/mkldnn/mkldnn_flatten.cc
@@ -19,11 +19,11 @@
 
 /*!
  * \file mkldnn_flatten.cc
- * \brief Implement flatten operator by using mkldnn reorder primitive
+ * \brief Implement flatten operator via MKL-DNN reorder primitive
  * \author Wuxun Zhang
 */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include "mkldnn_flatten-inl.h"
 
@@ -70,7 +70,7 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
     ws_ptr = reinterpret_cast<void*>(ws.dptr_);
   }
 
-  fwd.Execute(input, output, ws_ptr);
+  fwd.Execute(input, output, req, ws_ptr);
 }
 
 }  // namespace op
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 793aad7..ec97c93 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -63,18 +63,6 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                           const std::vector<NDArray>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<NDArray>& outputs);
-
-void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
-                          const OpContext &ctx,
-                          const NDArray &input,
-                          const OpReqType &req,
-                          const NDArray &output);
-
-void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
-                          const OpContext &ctx,
-                          const NDArray &input,
-                          const OpReqType &req,
-                          const NDArray &output);
 #endif
 
 #if MXNET_USE_MKLDNN == 100
@@ -142,6 +130,22 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
                             const NDArray &data,
                             const OpReqType &req,
                             const NDArray &output);
+
+void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
+                          const OpContext &ctx,
+                          const NDArray &input,
+                          const OpReqType &req,
+                          const NDArray &output);
+void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &input,
+                          const OpReqType &req,
+                          const NDArray &output);
+void MKLDNNExpandDimsForward(const nnvm::NodeAttrs &attrs,
+                             const OpContext &ctx,
+                             const NDArray &input,
+                             const OpReqType &req,
+                             const NDArray &output);
 #endif
 
 }  // namespace op
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
index 63e367b..aa0f11c 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
@@ -26,7 +26,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <vector>
 #include "mkldnn_base-inl.h"
 #include "../../tensor/matrix_op-inl.h"
@@ -36,7 +36,6 @@ namespace op {
 
 class MKLDNNReshapeFwd {
  protected:
-  std::shared_ptr<mkldnn::memory> data_;
   std::shared_ptr<mkldnn::memory> out_;
   std::shared_ptr<mkldnn::memory> temp_;
   std::vector<mkldnn::primitive> prims_;
@@ -47,15 +46,39 @@ class MKLDNNReshapeFwd {
                    const NDArray &input,
                    const NDArray &output);
   int GetWorkspaceSize();
-  void SetNewMem(const NDArray &input,
-                 const NDArray &output,
-                 void* workspace = nullptr);
   void Execute(const NDArray &input,
                const NDArray &output,
+               const OpReqType &req,
                void* workspace = nullptr);
 };
 
 typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;
+
+template<typename MKLDNNOpFwdType, typename ParamType, typename MKLDNNSigatureType>
+MKLDNNOpFwdType &GetCachedForward(const ParamType& param,
+                                  const OpReqType &req,
+                                  const NDArray &input,
+                                  const NDArray &output) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNSigatureType,
+                                         MKLDNNOpFwdType, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNSigatureType,
+                                            MKLDNNOpFwdType, OpHash> fwds;
+#endif
+  MKLDNNSigatureType key(param);
+  key.AddSign(req);
+  key.AddSign(input);
+  key.AddSign(output);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNOpFwdType fwd(req, input, output);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
 MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
                                     const OpReqType &req,
                                     const NDArray &input,
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc
index 063c85d..d180125 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape.cc
+++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc
@@ -23,7 +23,7 @@
  * \author Tao Lv
 */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include <mkldnn.hpp>
 #include "mkldnn_reshape-inl.h"
@@ -31,13 +31,14 @@
 namespace mxnet {
 namespace op {
 
-bool SupportMKLDNNReshape(const ReshapeParam &param,
-                          const NDArray &data) {
-  auto data_ndim = data.shape().ndim();
+bool SupportMKLDNNReshape(const NDArray &in_data,
+                          const NDArray &out_data) {
+  auto in_ndim = in_data.shape().ndim();
+  auto out_ndim = out_data.shape().ndim();
 
-  if (data_ndim > 4 ||
-      data.dtype() != mshadow::kFloat32 ||
-      param.shape.ndim() > 4)
+  if (in_ndim > 4 ||
+      in_data.dtype() != mshadow::kFloat32 ||
+      out_ndim > 4)
     return false;
 
   return true;
@@ -48,21 +49,16 @@ MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
                                    const NDArray &output) {
   auto engine = CpuEngine::Get()->get_engine();
 
-  // data_
+  // source
   auto in_mem = input.GetMKLDNNData();
-  auto in_pd = in_mem->get_primitive_desc();
-  data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);
+  auto in_md = in_mem->get_desc();
 
   // temp_
-  auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
-  auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
-  auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
-  auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
-  auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
-  temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+  auto temp_md = GetDesc(in_md, GetDefaultFormat(in_md));
+  temp_ = std::make_shared<mkldnn::memory>(temp_md, engine, nullptr);
 
   // destination
-  out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+  out_ = std::make_shared<mkldnn::memory>(temp_md, engine, nullptr);
 
   if (req == kWriteInplace) {
     // If the input has MKL-DNN internal layout, we need reorder it to a temporal buffer with
@@ -70,17 +66,17 @@ MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
     // address with input buffer.
     // If the input has default layout, then nothing need to do.
     if (input.IsMKLDNNData()) {
-      prims_.push_back(mkldnn::reorder(*data_, *temp_));   // reorder to default
-      prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy back
+      prims_.push_back(mkldnn::reorder(*in_mem, *temp_));  // reorder to default
+      prims_.push_back(mkldnn::reorder(*temp_, *out_));   // copy back
       needInvalidateInput = true;
     }
   } else if (req == kWriteTo) {
     if (input.IsMKLDNNData()) {
-      prims_.push_back(mkldnn::reorder(*data_, *temp_));   // reorder to default
-      prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy to the output buffer
+      prims_.push_back(mkldnn::reorder(*in_mem, *temp_));   // reorder to default
+      prims_.push_back(mkldnn::reorder(*temp_, *out_));     // copy to the output buffer
       needInvalidateInput = false;
     } else {
-      prims_.push_back(mkldnn::reorder(*data_, *out_));    // copy directly from input to output
+      prims_.push_back(mkldnn::reorder(*in_mem, *out_));    // copy directly from input to output
       needInvalidateInput = false;
     }
   } else {
@@ -89,42 +85,36 @@ MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
 }
 
 int MKLDNNReshapeFwd::GetWorkspaceSize() {
-  return temp_ ? temp_->get_primitive_desc().get_size() : 0;
-}
-
-void MKLDNNReshapeFwd::SetNewMem(const NDArray &input,
-                                 const NDArray &output,
-                                 void* workspace) {
-  if (input.IsMKLDNNData()) {
-    this->data_->set_data_handle(input.GetMKLDNNData()->get_data_handle());
-  } else {
-    MSHADOW_TYPE_SWITCH(input.dtype(), DTYPE, {
-      this->data_->set_data_handle(input.data().dptr<DTYPE>());
-    })
-  }
-
-  if (output.IsMKLDNNData()) {
-    this->out_->set_data_handle(output.GetMKLDNNData()->get_data_handle());
-  } else {
-    MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, {
-      this->out_->set_data_handle(output.data().dptr<DTYPE>());
-    })
-  }
-
-  if (workspace) {
-    this->temp_->set_data_handle(workspace);
-  }
+  return temp_ ? temp_->get_desc().get_size() : 0;
 }
 
 void MKLDNNReshapeFwd::Execute(const NDArray &input,
                                const NDArray &output,
+                               const OpReqType &req,
                                void* workspace) {
-  // set memory handles
-  SetNewMem(input, output, workspace);
-  // register primitives
   auto stream = MKLDNNStream::Get();
-  for (auto &v : this->prims_) {
-    stream->RegisterPrim(v);
+  auto in_mem = input.GetMKLDNNData();
+  // register primitives and arguments
+  std::vector<mkldnn_args_map_t> args_map;
+  size_t prims_size = prims_.size();
+  if (prims_size == 1) {
+    args_map.push_back({{MKLDNN_ARG_FROM, *in_mem},
+                        {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
+  } else if (prims_size == 2) {
+    if (workspace) {
+      temp_->set_data_handle(workspace);
+    }
+    args_map.push_back({{MKLDNN_ARG_FROM, *in_mem},
+                        {MKLDNN_ARG_TO, *temp_}});
+    args_map.push_back({{MKLDNN_ARG_FROM, *temp_},
+                        {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
+  } else {
+    CHECK(prims_size == 0 && req != kWriteTo)
+          << "kWriteTo should never reach here.";
+  }
+
+  for (size_t i = 0; i < prims_size; i++) {
+    stream->RegisterPrimArgs(prims_[i], args_map[i]);
   }
   stream->Submit();
   // invalidate mkldnn memory in input
@@ -133,30 +123,6 @@ void MKLDNNReshapeFwd::Execute(const NDArray &input,
   }
 }
 
-MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
-                                    const OpReqType &req,
-                                    const NDArray &input,
-                                    const NDArray &output) {
-#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<MKLDNNReshapeSignature,
-                                         MKLDNNReshapeFwd, OpHash> fwds;
-#else
-  static MX_THREAD_LOCAL std::unordered_map<MKLDNNReshapeSignature,
-                                            MKLDNNReshapeFwd, OpHash> fwds;
-#endif
-  MKLDNNReshapeSignature key(param);
-  key.AddSign(req);
-  key.AddSign(input);
-  key.AddSign(output);
-
-  auto it = fwds.find(key);
-  if (it == fwds.end()) {
-    MKLDNNReshapeFwd fwd(req, input, output);
-    it = AddToCache(&fwds, key, fwd);
-  }
-  return it->second;
-}
-
 void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
                           const OpContext &ctx,
                           const NDArray &input,
@@ -166,7 +132,9 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
   if (req == kNullOp) return;
   CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
 
-  auto fwd = GetReshapeForward(param, req, input, output);
+  auto fwd = GetCachedForward<MKLDNNReshapeFwd, ReshapeParam,
+                              MKLDNNReshapeSignature>(param, req, input, output);
+
   auto ws_size = fwd.GetWorkspaceSize();
   void* ws_ptr = nullptr;
   if (ws_size) {
@@ -176,7 +144,7 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
     ws_ptr = reinterpret_cast<void*>(ws.dptr_);
   }
 
-  fwd.Execute(input, output, ws_ptr);
+  fwd.Execute(input, output, req, ws_ptr);
 }
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 5a2bd03..3f1a5f8 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -394,6 +394,10 @@ struct ExpandDimParam : public dmlc::Parameter<ExpandDimParam> {
               "the input `NDArray`'s dimension is `ndim`, the range of "
               "the inserted axis is `[-ndim, ndim]`");
   }
+
+  bool operator==(const ExpandDimParam &other) const {
+    return this->axis == other.axis;
+  }
 };
 
 
@@ -2936,6 +2940,16 @@ struct hash<mxnet::op::ReshapeParam> {
     return ret;
   }
 };
+
+template<>
+struct hash<mxnet::op::ExpandDimParam> {
+  size_t operator()(const mxnet::op::ExpandDimParam& val) {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.axis);
+    return ret;
+  }
+};
+
 }  // namespace std
 
 #endif  // MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index a4f0db0..6bf1ec0 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -25,9 +25,11 @@
 // this will be invoked by gcc and compile CPU version
 #include "./matrix_op-inl.h"
 #include "./elemwise_unary_op.h"
+#if MXNET_USE_MKLDNN == 100
 #include "../nn/mkldnn/mkldnn_ops-inl.h"
 #include "../nn/mkldnn/mkldnn_base-inl.h"
 #include "../nn/mkldnn/mkldnn_slice-inl.h"
+#endif
 
 namespace mxnet {
 namespace op {
@@ -105,19 +107,18 @@ DMLC_REGISTER_PARAMETER(SqueezeParam);
 DMLC_REGISTER_PARAMETER(DepthToSpaceParam);
 DMLC_REGISTER_PARAMETER(SplitParam);
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const OpContext& ctx,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
-  const ReshapeParam& param = nnvm::get<ReshapeParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   // If inputs are supposed to be in MKLDNN format and
   // MKLDNNsupport the data type or the shape. Then convert
   // it to the output format and shape
-  if (SupportMKLDNNReshape(param, inputs[0])) {
+  if (SupportMKLDNNReshape(inputs[0], outputs[0])) {
     MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     return;
   }
@@ -207,7 +208,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ReshapeComputeExCPU)
 .set_attr<FInferStorageType>("FInferStorageType", ReshapeStorageType)
@@ -233,7 +234,7 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
                       const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   auto data_ndim = inputs[0].shape().ndim();
   if (data_ndim <= 4 && inputs[0].dtype() == mshadow::kFloat32) {
     MKLDNNFlattenForward(attrs, ctx, inputs[0], req[0], outputs[0]);
@@ -248,7 +249,7 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
 #endif
 }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
                                       const int dev_mask,
                                       DispatchMode* dispatch_mode,
@@ -294,13 +295,13 @@ Example::
 .set_num_outputs(1)
 .set_attr<mxnet::FInferShape>("FInferShape", FlattenShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", FlattenStorageType)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_backward_copy" })
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", FlattenEx)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -411,6 +412,33 @@ Examples::
 .add_arguments(TransposeParam::__FIELDS__());
 
 
+#if MXNET_USE_MKLDNN == 100
+static void ExpandDimEx(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<NDArray>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  auto data_ndim = inputs[0].shape().ndim();
+  if (data_ndim <= 3 && inputs[0].dtype() == mshadow::kFloat32) {
+    MKLDNNExpandDimsForward(attrs, ctx, inputs[0], req[0], outputs[0]);
+    return;
+  }
+  FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+inline static bool ExpandDimStorageType(const nnvm::NodeAttrs& attrs,
+                                        const int dev_mask,
+                                        DispatchMode* dispatch_mode,
+                                        std::vector<int>* in_attrs,
+                                        std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+}
+#endif
+
 NNVM_REGISTER_OP(expand_dims)
 .add_alias("_npi_expand_dims")
 .describe(R"code(Inserts a new axis of size 1 into the array shape
@@ -424,6 +452,9 @@ will return a new array with shape ``(2,1,3,4)``.
 .set_attr_parser(ParamParser<ExpandDimParam>)
 .set_attr<mxnet::FInferShape>("FInferShape", ExpandDimShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+#if MXNET_USE_MKLDNN == 100
+.set_attr<FInferStorageType>("FInferStorageType", ExpandDimStorageType)
+#endif
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
     return std::vector<std::pair<int, int> >{{0, 0}};
@@ -434,6 +465,13 @@ will return a new array with shape ``(2,1,3,4)``.
   })
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+#if MXNET_USE_MKLDNN == 100
+.set_attr<FComputeEx>("FComputeEx<cpu>", ExpandDimEx)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+#endif
 .add_argument("data", "NDArray-or-Symbol", "Source input")
 .add_arguments(ExpandDimParam::__FIELDS__());