You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/08/24 14:46:35 UTC

[incubator-mxnet] branch master updated: [operator] Integrate oneDNN layer normalization implementation (#19562)

This is an automated email from the ASF dual-hosted git repository.

akarbown pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 695ba2e  [operator] Integrate oneDNN layer normalization implementation (#19562)
695ba2e is described below

commit 695ba2e5bd4506050bc768d5ff5964994fc5d534
Author: bartekkuncer <ba...@intel.com>
AuthorDate: Tue Aug 24 16:44:11 2021 +0200

    [operator] Integrate oneDNN layer normalization implementation (#19562)
    
    * [operator] Integrate oneDNN layer normalization implementation
    
    * change sizeof(float) to mshadow_sizeof(inputs[layernorm::kBwdGamma].dtype())
    
    * remove eps from key and unify layernorm_fwd_t/mkldnn::layer_normalization_forward
    
    * add author
---
 src/operator/nn/layer_norm-inl.h               |  26 ++-
 src/operator/nn/layer_norm.cc                  |  66 +++++++
 src/operator/nn/layer_norm.cu                  |   4 +
 src/operator/nn/mkldnn/mkldnn_base-inl.h       |   2 +
 src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h  |   1 +
 src/operator/nn/mkldnn/mkldnn_batch_dot.cc     |   1 +
 src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h | 103 ++++++++++
 src/operator/nn/mkldnn/mkldnn_layer_norm.cc    | 260 +++++++++++++++++++++++++
 src/operator/nn/mkldnn/mkldnn_ops-inl.h        |  12 ++
 9 files changed, 474 insertions(+), 1 deletion(-)

diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h
index 79d0906..b55ef29 100644
--- a/src/operator/nn/layer_norm-inl.h
+++ b/src/operator/nn/layer_norm-inl.h
@@ -45,7 +45,9 @@ namespace op {
 
 namespace layernorm {
 enum LayerNormOpInputs {kData, kGamma, kBeta};  // kGamma: scaling parameters, kBeta: shift biases
-enum LayerNormOpOutputs {kOut, kMean, kStd};  // req, out_data
+enum LayerNormOpOutputs {kOut, kMean, kStd};  // indices for req, out_data
+// Input indices of _backward_LayerNorm. kBwdBeta is only passed in oneDNN builds
+// (MXNET_USE_ONEDNN == 1), where the backward node receives 6 inputs instead of 5.
+enum LayerNormOpInputsBwd {kBwdOutGrad, kBwdData, kBwdGamma, kBwdMean, kBwdStd, kBwdBeta};
+// Output indices of _backward_LayerNorm: gradients w.r.t. data, gamma and beta.
+enum LayerNormOpOutputsBwd {kBwdDataGrad, kBwdGammaGrad, kBwdBetaGrad};
 }  // namespace layernorm
 
 struct LayerNormParam : public dmlc::Parameter<LayerNormParam> {
@@ -71,6 +73,11 @@ struct LayerNormParam : public dmlc::Parameter<LayerNormParam> {
     (*dict)["eps"] = eps_s.str();
     (*dict)["output_mean_var"] = output_mean_var_s.str();
   }
+
+  /*! \brief Field-wise equality over axis, eps and output_mean_var. */
+  bool operator==(const LayerNormParam& other) const {
+    if (axis != other.axis) return false;
+    if (eps != other.eps) return false;
+    return output_mean_var == other.output_mean_var;
+  }
 };
 
 inline int GetRealAxis(int axis, int ndim) {
@@ -257,7 +264,11 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
                                  const std::vector<TBlob>& outputs) {
   using namespace mshadow;
   using namespace mshadow::expr;
+#if MXNET_USE_ONEDNN == 1
+  CHECK_EQ(inputs.size(), 6U);  // additional beta tensor
+#else
   CHECK_EQ(inputs.size(), 5U);
+#endif
   const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
   int axis = param.axis;
   if (axis < 0) {
@@ -313,4 +324,17 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
 
 }  // namespace op
 }  // namespace mxnet
+
+namespace std {
+/*!
+ * \brief Hash of LayerNormParam combining the same fields compared by operator==
+ *        (axis, eps, output_mean_var).
+ */
+template <>
+struct hash<mxnet::op::LayerNormParam> {
+  // const-qualified so the specialization meets the standard Hash requirements
+  // (it must be callable through a const hasher object).
+  size_t operator()(const mxnet::op::LayerNormParam& val) const {
+    size_t ret = 0;
+    ret        = dmlc::HashCombine(ret, val.axis);
+    ret        = dmlc::HashCombine(ret, val.eps);
+    ret        = dmlc::HashCombine(ret, val.output_mean_var);
+    return ret;
+  }
+};
+}  // namespace std
 #endif  // MXNET_OPERATOR_NN_LAYER_NORM_INL_H_
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index 1a040fa..4e8a80e 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -57,6 +57,10 @@
 #include "layer_norm-inl.h"
 #include <nnvm/op_attr_types.h>
 #include "../elemwise_op_common.h"
+#if MXNET_USE_ONEDNN == 1
+#include "./mkldnn/mkldnn_base-inl.h"
+#include "./mkldnn/mkldnn_ops-inl.h"
+#endif  // MXNET_USE_ONEDNN
 
 #if MSHADOW_USE_MKL == 1
 #include "../mkl_functions-inl.h"
@@ -392,6 +396,50 @@ void LayerNormGradCompute<cpu>(const nnvm::NodeAttrs& attrs,
   return LayerNormGradComputeGeneral<cpu>(attrs, ctx, inputs, req, outputs);
 }
 
+#if MXNET_USE_ONEDNN == 1
+static bool LayerNormInferStorageType(const nnvm::NodeAttrs& attrs,
+                                      const int dev_mask,
+                                      DispatchMode* dispatch_mode,
+                                      std::vector<int>* in_attrs,
+                                      std::vector<int>* out_attrs) {
+  // Storage-type inference shared by LayerNorm and _backward_LayerNorm in oneDNN builds:
+  // delegates to MKLDNNStorageType with its support flag set to true.
+  CHECK(!in_attrs->empty());
+  const bool support_mkldnn = true;
+  return MKLDNNStorageType(attrs, dev_mask, support_mkldnn, dispatch_mode, in_attrs, out_attrs);
+}
+
+static void LayerNormComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                  const OpContext& ctx,
+                                  const std::vector<NDArray>& inputs,
+                                  const std::vector<OpReqType>& req,
+                                  const std::vector<NDArray>& outputs) {
+  // Forward dispatch: use the oneDNN kernel when SupportMKLDNNLayerNorm accepts the inputs,
+  // otherwise fall back to the native CPU kernel.
+  const auto& param = nnvm::get<LayerNormParam>(attrs.parsed);
+  if (!SupportMKLDNNLayerNorm(param, inputs)) {
+    FallBackCompute(LayerNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+  MKLDNNRun(MKLDNNLayerNormForward, attrs, ctx, inputs, req, outputs);
+  MKLDNN_OPCHECK_RUN(LayerNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+static void LayerNormGradComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                      const OpContext& ctx,
+                                      const std::vector<NDArray>& inputs,
+                                      const std::vector<OpReqType>& req,
+                                      const std::vector<NDArray>& outputs) {
+  // Backward dispatch between the oneDNN kernel and the native CPU kernel.
+  // NOTE(review): SupportMKLDNNLayerNorm indexes `inputs` with the forward indices
+  // (kData/kGamma/kBeta = 0/1/2), which for the backward node are out_grad/data/gamma;
+  // beta (kBwdBeta) is not inspected. Confirm this is the intended backward support check.
+  const auto& param = nnvm::get<LayerNormParam>(attrs.parsed);
+  if (!SupportMKLDNNLayerNorm(param, inputs)) {
+    FallBackCompute(LayerNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
+  MKLDNNRun(MKLDNNLayerNormBackward, attrs, ctx, inputs, req, outputs);
+  MKLDNN_OPCHECK_RUN(LayerNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+#endif
+
 NNVM_REGISTER_OP(LayerNorm)
 .add_alias("_npx_layer_norm")
 .describe(R"code(Layer normalization.
@@ -439,6 +487,11 @@ axis to be the last item in the input shape.
 .set_attr<mxnet::FInferShape>("FInferShape", LayerNormShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 3>)
 .set_attr<FCompute>("FCompute<cpu>", LayerNormCompute<cpu>)
+#if MXNET_USE_ONEDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FInferStorageType>("FInferStorageType", LayerNormInferStorageType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", LayerNormComputeExCPU)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", [](const nnvm::ObjectPtr& n,
                                            const std::vector<nnvm::NodeEntry>& ograds) {
   std::vector<nnvm::NodeEntry> heads;
@@ -447,6 +500,10 @@ axis to be the last item in the input shape.
   heads.push_back(n->inputs[1]);  // gamma
   heads.emplace_back(n, 1, 0);  // mean
   heads.emplace_back(n, 2, 0);  // std
+#if MXNET_USE_ONEDNN == 1
+  heads.push_back(n->inputs[2]);  // beta - needed for MKLDNN backward propagation;
+                                  // added at the end in case of fallback to non MKLDNN version
+#endif
   return MakeGradNode("_backward_LayerNorm", n, heads, n->attrs.dict);
 })
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
@@ -464,11 +521,20 @@ axis to be the last item in the input shape.
 
 
 NNVM_REGISTER_OP(_backward_LayerNorm)
+// oneDNN builds append beta as a 6th input (see LayerNorm's FGradient) because the oneDNN
+// backward primitive consumes gamma and beta together as one SCALE_SHIFT tensor.
+#if MXNET_USE_ONEDNN == 1
+.set_num_inputs(6)
+#else
 .set_num_inputs(5)
+#endif
 .set_num_outputs(3)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr_parser(ParamParser<LayerNormParam>)
 .set_attr<FCompute>("FCompute<cpu>", LayerNormGradCompute<cpu>)
+#if MXNET_USE_ONEDNN == 1
+// NDArray-based CPU execution goes through LayerNormGradComputeExCPU, which chooses between
+// the oneDNN kernel and the native kernel on each call.
+.set_attr<FInferStorageType>("FInferStorageType", LayerNormInferStorageType)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FComputeEx>("FComputeEx<cpu>", LayerNormGradComputeExCPU)
+#endif
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 });
diff --git a/src/operator/nn/layer_norm.cu b/src/operator/nn/layer_norm.cu
index 9a33e06..1e719ed 100644
--- a/src/operator/nn/layer_norm.cu
+++ b/src/operator/nn/layer_norm.cu
@@ -689,7 +689,11 @@ void LayerNormGradGPUContig(const LayerNormParam param,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
   using namespace mshadow;
+#if MXNET_USE_ONEDNN == 1
+  CHECK_EQ(inputs.size(), 6U);  // additional beta tensor
+#else
   CHECK_EQ(inputs.size(), 5U);
+#endif
   const TBlob out_grad = inputs[0];
   const TBlob in_data = inputs[1];
   const TBlob gamma = inputs[2];
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 2cef524..2ee0793 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -202,6 +202,7 @@ struct SoftmaxParam;
 struct SoftmaxOutputParam;
 struct TransposeParam;
 struct ReshapeParam;
+struct LayerNormParam;
 bool SupportMKLDNNAct(const ActivationParam& param);
 bool SupportMKLDNNAct(const ActivationParam& param, const NDArray& input);
 bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param);
@@ -216,6 +217,7 @@ bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param,
 bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam& param);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray& data);
 bool SupportMKLDNNBatchDot(const std::vector<NDArray>& inputs, const NDArray& output);
+// True when the oneDNN layer normalization kernel should handle `inputs` with `param`;
+// otherwise the LayerNorm FComputeEx functions fall back to the native CPU kernel.
+bool SupportMKLDNNLayerNorm(const LayerNormParam& param, const std::vector<NDArray> &inputs);
 }  // namespace op
 
 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h
index 34c3eb9..2459ea1 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h
@@ -19,6 +19,7 @@
 
 /*!
  * \file mkldnn_batch_dot-inl.h
+ * \author: Bartosz Kuncer, bartosz.kuncer@intel.com
  */
 
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_dot.cc b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
index f7c93ef..87ddb98 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
+++ b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
@@ -19,6 +19,7 @@
 
 /*!
  * \file mkldnn_batch_dot.cc
+ * \author: Bartosz Kuncer, bartosz.kuncer@intel.com
  */
 
 #if MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h
new file mode 100644
index 0000000..a14673b
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_layer_norm-inl.h
+ * \author: Bartosz Kuncer, bartosz.kuncer@intel.com
+ */
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <utility>
+#include <vector>
+
+#include "../layer_norm-inl.h"
+#include "./mkldnn_base-inl.h"
+#include "./mkldnn_ops-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// Short names for the oneDNN layer normalization primitives and their primitive descriptors.
+using layernorm_fwd_t    = mkldnn::layer_normalization_forward;
+using layernorm_fwd_pd_t = layernorm_fwd_t::primitive_desc;
+
+using layernorm_bwd_t    = mkldnn::layer_normalization_backward;
+using layernorm_bwd_pd_t = layernorm_bwd_t::primitive_desc;
+
+// Primitive-cache key: the LayerNormParam plus the signatures of the arrays added via AddSign.
+using LayerNormSignature = ParamOpSign<LayerNormParam>;
+
+/*!
+ * \brief oneDNN forward layer-normalization primitive together with its primitive
+ *        descriptor. Instances are created and looked up through GetCached().
+ */
+class MKLDNNLayerNormFwd {
+ public:
+  /*!
+   * \brief Returns the primitive cached for (param, signature of `data`), building it on a
+   *        miss. The cache is thread-local and the returned reference points into it.
+   */
+  static MKLDNNLayerNormFwd& GetCached(const LayerNormParam& param,
+                                       const OpContext& ctx,
+                                       const NDArray& data);
+
+  /*! \brief Builds the primitive descriptor and primitive for the memory layout of `data`. */
+  MKLDNNLayerNormFwd(const LayerNormParam& param, const NDArray& data);
+
+  /*!
+   * \brief Creates a forward_training primitive descriptor with scale/shift enabled for the
+   *        source layout `src_md` and epsilon `param.eps`. Also used as the backward hint.
+   */
+  static std::shared_ptr<layernorm_fwd_pd_t> CreatePrimitiveDesc(
+      const LayerNormParam& param,
+      const mkldnn::memory::desc& src_md);
+
+  /*!
+   * \brief Runs the primitive: reads data/gamma/beta from `inputs` and writes the normalized
+   *        output (using `req`) plus mean and variance into `outputs`.
+   */
+  void Execute(const LayerNormParam& param,
+               const OpContext& ctx,
+               const std::vector<NDArray>& inputs,
+               const OpReqType& req,
+               const std::vector<NDArray>& outputs) const;
+
+  ~MKLDNNLayerNormFwd() {}
+
+ private:
+  // Declared before fwd_pd, so the constructor assigns both in its body (not in an
+  // initializer list): the primitive can only be built once *fwd_pd exists.
+  std::shared_ptr<layernorm_fwd_t> fwd;
+  std::shared_ptr<layernorm_fwd_pd_t> fwd_pd;
+};
+
+/*!
+ * \brief oneDNN backward layer-normalization primitive computing the data gradient and the
+ *        gamma/beta gradients. Instances are created and looked up through GetCached().
+ */
+class MKLDNNLayerNormBwd {
+ public:
+  /*!
+   * \brief Returns the thread-local cached primitive for (param, signatures of all six
+   *        backward inputs), building it on a miss.
+   */
+  static MKLDNNLayerNormBwd& GetCached(const LayerNormParam& param,
+                                       const std::vector<NDArray>& inputs);
+
+  /*!
+   * \brief Builds the forward hint descriptor, the backward descriptor and the primitive for
+   *        data layout `data_md` and output-gradient layout `diff_md`. `inputs` is unused.
+   */
+  MKLDNNLayerNormBwd(const LayerNormParam& param,
+                     const std::vector<NDArray>& inputs,
+                     const mkldnn::memory::desc& data_md,
+                     const mkldnn::memory::desc& diff_md);
+
+  /*! \brief Creates the backward primitive descriptor; `layernorm_fwd_pd` is the hint. */
+  static std::shared_ptr<layernorm_bwd_pd_t> CreatePrimitiveDesc(
+      const LayerNormParam& param,
+      const mkldnn::memory::desc& data_md,
+      const mkldnn::memory::desc& diff_md,
+      const layernorm_fwd_pd_t& layernorm_fwd_pd);
+
+  /*!
+   * \brief Runs the primitive on `inputs` (indexed by layernorm::kBwd*) and writes the
+   *        data/gamma/beta gradients into `outputs` (indexed by layernorm::kBwd*Grad).
+   */
+  void Execute(const std::vector<NDArray>& inputs,
+               const std::vector<NDArray>& outputs,
+               const std::vector<OpReqType>& req) const;
+
+  ~MKLDNNLayerNormBwd() {}
+
+ private:
+  std::shared_ptr<layernorm_bwd_t> bwd;
+  // fwd_pd must stay declared before bwd_pd: the constructor's initializer list builds bwd_pd
+  // from *fwd_pd, and members are initialized in declaration order.
+  std::shared_ptr<layernorm_fwd_pd_t> fwd_pd;
+  std::shared_ptr<layernorm_bwd_pd_t> bwd_pd;
+};
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_ONEDNN == 1
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_layer_norm.cc b/src/operator/nn/mkldnn/mkldnn_layer_norm.cc
new file mode 100644
index 0000000..8b8e122
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_layer_norm.cc
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_layer_norm.cc
+ * \author: Bartosz Kuncer, bartosz.kuncer@intel.com
+ */
+
+#if MXNET_USE_ONEDNN == 1
+
+#include "./mkldnn_layer_norm-inl.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Decides whether the oneDNN layer normalization kernel should be used.
+ *
+ * Requires a 2D-5D input normalized over its last axis, large enough for oneDNN to beat the
+ * native kernel, float32/bfloat16 data (index kData) and float32 arrays at kGamma/kBeta.
+ * Also called by the backward pass with the backward input vector, where indices 0/1/2 hold
+ * out_grad/data/gamma rather than data/gamma/beta.
+ */
+bool SupportMKLDNNLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs) {
+  const mxnet::TShape& shape = inputs[layernorm::kData].shape();
+  const int ndim             = shape.ndim();
+
+  // Check the rank first: the size heuristic below indexes shape[0].
+  if (ndim < 2 || ndim > 5) {
+    return false;
+  }
+
+  // Native implementation (which can be found in function LayerNormCPU) is faster than oneDNN's
+  // one for small tensors. Below is the heuristic based on measurements on clx machine deciding
+  // whether the shape is better for oneDNN or native implementation. shape[0] is compared with
+  // the limit before it is used as a divisor, so an empty leading dimension (shape[0] == 0)
+  // selects the native path instead of triggering an integer division by zero.
+  auto ShapeBetterForMKLDNN = [](const mxnet::TShape& shape) {
+    constexpr size_t shapeLimit = 1024;
+    return shape[0] >= shapeLimit && shape.Size() / shape[0] >= shapeLimit;
+  };
+
+  return (ShapeBetterForMKLDNN(shape) && (GetRealAxis(param.axis, ndim) == ndim - 1) &&
+          (inputs[layernorm::kData].dtype() == mshadow::kFloat32 ||
+           inputs[layernorm::kData].dtype() == mshadow::kBfloat16) &&
+          inputs[layernorm::kGamma].dtype() == mshadow::kFloat32 &&
+          inputs[layernorm::kBeta].dtype() == mshadow::kFloat32);
+}
+
+void MKLDNNLayerNormForward(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs) {
+  // Look up (or build and cache) the forward primitive for this param/data, then run it.
+  // Only the request of the normalized output (kOut) is forwarded to Execute().
+  const auto& param   = nnvm::get<LayerNormParam>(attrs.parsed);
+  const NDArray& data = inputs[layernorm::kData];
+  MKLDNNLayerNormFwd& cached_fwd = MKLDNNLayerNormFwd::GetCached(param, ctx, data);
+  cached_fwd.Execute(param, ctx, inputs, req[layernorm::kOut], outputs);
+}
+
+MKLDNNLayerNormFwd& MKLDNNLayerNormFwd::GetCached(const LayerNormParam& param,
+                                                  const OpContext& ctx,
+                                                  const NDArray& data) {
+  // Per-thread primitive cache keyed on the layer-norm params and the signature of `data`.
+  // `ctx` does not take part in the lookup.
+  using layernorm_fwd_map = std::unordered_map<LayerNormSignature, MKLDNNLayerNormFwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local layernorm_fwd_map layer_norm_fwds;
+#else
+  static MX_THREAD_LOCAL layernorm_fwd_map layer_norm_fwds;
+#endif
+
+  LayerNormSignature key(param);
+  key.AddSign(data);
+
+  auto cached = layer_norm_fwds.find(key);
+  if (cached != layer_norm_fwds.end()) {
+    return cached->second;
+  }
+  MKLDNNLayerNormFwd fresh_fwd(param, data);
+  return AddToCache(&layer_norm_fwds, key, fresh_fwd)->second;
+}
+
+MKLDNNLayerNormFwd::MKLDNNLayerNormFwd(const LayerNormParam& param, const NDArray& data) {
+  // `fwd` is declared before `fwd_pd`, so both are assigned here in the body rather than in an
+  // initializer list: the primitive needs the already-built descriptor.
+  const auto* data_mem = data.GetMKLDNNData();
+  fwd_pd               = CreatePrimitiveDesc(param, data_mem->get_desc());
+  fwd                  = std::make_shared<layernorm_fwd_t>(*fwd_pd);
+}
+
+std::shared_ptr<layernorm_fwd_pd_t> MKLDNNLayerNormFwd::CreatePrimitiveDesc(
+    const LayerNormParam& param,
+    const mkldnn::memory::desc& src_md) {
+  // Training-mode forward with gamma/beta (scale/shift) enabled. Execute() binds the MEAN and
+  // VARIANCE outputs, and the backward pass reuses this descriptor as its hint.
+  const auto flags = dnnl::normalization_flags::use_scale_shift;
+  const layernorm_fwd_t::desc fwd_desc(
+      mkldnn::prop_kind::forward_training, src_md, param.eps, flags);
+  const mkldnn::engine& cpu_engine = CpuEngine::Get()->get_engine();
+  return std::make_shared<layernorm_fwd_pd_t>(fwd_desc, cpu_engine);
+}
+
+inline mkldnn::memory::desc GetMeanVarDesc(const mkldnn::memory::data_type& dtype,
+                                           const mxnet::TShape& _shape) {
+  // Describes a dense row-major tensor with the dims of `_shape`: the innermost stride is 1 and
+  // every outer stride is the product of all dims inside it.
+  const auto ndim = _shape.ndim();
+  mkldnn::memory::dims dims(ndim, 1);
+  mkldnn::memory::dims strides(ndim, 1);
+  for (int axis = 0; axis < ndim; ++axis) {
+    dims[axis] = _shape[axis];
+  }
+  for (int axis = ndim - 2; axis >= 0; --axis) {
+    strides[axis] = strides[axis + 1] * dims[axis + 1];
+  }
+  return mkldnn::memory::desc{dims, dtype, strides};
+}
+
+inline mkldnn::memory GetScaleShiftMem(const NDArray& gamma, const NDArray& beta) {
+  // oneDNN takes scale and shift as one 2 x C SCALE_SHIFT tensor when both are used. In MXNet
+  // the scale is called gamma and the shift beta: row 0 gets gamma, row 1 gets beta. The
+  // tensor uses gamma's dtype for both rows.
+  constexpr size_t gammaAndBeta = 2;
+  const auto channels           = gamma.shape()[0];
+  CHECK_EQ(channels, beta.shape()[0]);
+  const mkldnn::memory::desc scale_shift_md(mkldnn::memory::dims{gammaAndBeta, channels},
+                                            get_mkldnn_type(gamma.dtype()),
+                                            mkldnn::memory::format_tag::nc);
+  mkldnn::memory scale_shift_mem(scale_shift_md, CpuEngine::Get()->get_engine());
+  const size_t row_bytes = scale_shift_md.get_size() / gammaAndBeta;
+  char* rows             = static_cast<char*>(scale_shift_mem.get_data_handle());
+  memcpy(rows, gamma.data().dptr_, row_bytes);
+  memcpy(rows + row_bytes, beta.data().dptr_, row_bytes);
+  return scale_shift_mem;
+}
+
+// Runs the cached forward primitive. `param` and `ctx` are not used by this implementation.
+// Inputs are indexed by layernorm::kData/kGamma/kBeta and outputs by kOut/kMean/kStd.
+void MKLDNNLayerNormFwd::Execute(const LayerNormParam& param,
+                                 const OpContext& ctx,
+                                 const std::vector<NDArray>& inputs,
+                                 const OpReqType& req,
+                                 const std::vector<NDArray>& outputs) const {
+  // Mean and variance share one dense descriptor built from the mean output's dtype and shape.
+  auto mean_var_md = GetMeanVarDesc(get_mkldnn_type(outputs[layernorm::kMean].dtype()),
+                                    outputs[layernorm::kMean].shape());
+  // The statistics are written straight into their output arrays (OutDataOp::Noop); `req` is
+  // applied only to the normalized output. const_cast is used because `outputs` is const.
+  auto mean_mem    = mkldnn_output_t(
+      OutDataOp::Noop,
+      const_cast<NDArray&>(outputs[layernorm::kMean]).CreateMKLDNNData(mean_var_md));
+  // NOTE(review): oneDNN's VARIANCE is stored into the output named kStd (the "std" head fed to
+  // the backward node). Confirm the native kernel stores the same quantity there; otherwise
+  // output_mean_var results and a native-backward fallback would disagree with this path.
+  auto variance_mem =
+      mkldnn_output_t(OutDataOp::Noop,
+                      const_cast<NDArray&>(outputs[layernorm::kStd]).CreateMKLDNNData(mean_var_md));
+
+  auto output_mem      = CreateMKLDNNMem(outputs[layernorm::kOut], fwd_pd->dst_desc(), req);
+  // gamma and beta are packed into the single SCALE_SHIFT tensor the primitive expects.
+  auto scale_shift_mem = GetScaleShiftMem(inputs[layernorm::kGamma], inputs[layernorm::kBeta]);
+
+  mkldnn_args_map_t args = {{MKLDNN_ARG_SRC, *inputs[layernorm::kData].GetMKLDNNData()},
+                            {MKLDNN_ARG_DST, *output_mem.second},
+                            {MKLDNN_ARG_MEAN, *mean_mem.second},
+                            {MKLDNN_ARG_VARIANCE, *variance_mem.second},
+                            {MKLDNN_ARG_SCALE_SHIFT, scale_shift_mem}};
+
+  // Queue the primitive, commit all three outputs, then execute the stream with Submit().
+  MKLDNNStream::Get()->RegisterPrimArgs(*fwd, args);
+  CommitOutput(outputs[layernorm::kOut], output_mem);
+  CommitOutput(outputs[layernorm::kMean], mean_mem);
+  CommitOutput(outputs[layernorm::kStd], variance_mem);
+  MKLDNNStream::Get()->Submit();
+}
+
+// Builds the forward primitive descriptor (the backward hint), the backward primitive
+// descriptor and the backward primitive. The initializer list dereferences fwd_pd while
+// building bwd_pd, which is valid only because fwd_pd is declared before bwd_pd in the class.
+// `inputs` is not used here.
+MKLDNNLayerNormBwd::MKLDNNLayerNormBwd(const LayerNormParam& param,
+                                       const std::vector<NDArray>& inputs,
+                                       const mkldnn::memory::desc& data_md,
+                                       const mkldnn::memory::desc& diff_md)
+    : fwd_pd(MKLDNNLayerNormFwd::CreatePrimitiveDesc(param, data_md)),
+      bwd_pd(CreatePrimitiveDesc(param, data_md, diff_md, *fwd_pd)) {
+  bwd = std::make_shared<layernorm_bwd_t>(*bwd_pd);
+}
+
+std::shared_ptr<layernorm_bwd_pd_t> MKLDNNLayerNormBwd::CreatePrimitiveDesc(
+    const LayerNormParam& param,
+    const mkldnn::memory::desc& data_md,
+    const mkldnn::memory::desc& diff_md,
+    const layernorm_fwd_pd_t& layernorm_fwd_pd) {
+  // Full backward (prop_kind::backward) with scale/shift enabled and the forward epsilon; the
+  // forward primitive descriptor is passed as the required hint.
+  const auto flags = dnnl::normalization_flags::use_scale_shift;
+  const layernorm_bwd_t::desc bwd_desc(
+      dnnl::prop_kind::backward, diff_md, data_md, param.eps, flags);
+  const mkldnn::engine& cpu_engine = CpuEngine::Get()->get_engine();
+  return std::make_shared<layernorm_bwd_pd_t>(bwd_desc, cpu_engine, layernorm_fwd_pd);
+}
+
+// Commits one weight gradient produced by the backward primitive into its output array.
+// `src` points to `bytes` bytes of gradient data laid out in `out`'s dtype. The output's own
+// `req` is honored: kNullOp leaves `out` untouched, kAddTo accumulates element-wise, and any
+// other request overwrites it.
+static void CommitLayerNormWeightGrad(const NDArray& out,
+                                      const OpReqType req,
+                                      const char* src,
+                                      const size_t bytes) {
+  switch (req) {
+    case kNullOp:
+      return;
+    case kAddTo:
+      MSHADOW_REAL_TYPE_SWITCH(out.dtype(), DType, {
+        DType* dst          = out.data().dptr<DType>();
+        const DType* addend = reinterpret_cast<const DType*>(src);
+        const size_t count  = bytes / sizeof(DType);
+        for (size_t i = 0; i < count; ++i) {
+          dst[i] += addend[i];
+        }
+      });
+      return;
+    default:
+      memcpy(out.data().dptr_, src, bytes);
+      return;
+  }
+}
+
+// Runs the cached backward primitive. Inputs are indexed by layernorm::kBwd* and outputs by
+// layernorm::kBwd*Grad; each output is written according to its own entry in `req`.
+void MKLDNNLayerNormBwd::Execute(const std::vector<NDArray>& inputs,
+                                 const std::vector<NDArray>& outputs,
+                                 const std::vector<OpReqType>& req) const {
+  auto scale_shift_mem =
+      GetScaleShiftMem(inputs[layernorm::kBwdGamma], inputs[layernorm::kBwdBeta]);
+  // Scratch buffer with the SCALE_SHIFT layout [diff_gamma; diff_beta]. oneDNN writes both
+  // weight gradients into it; they are committed to the two separate outputs afterwards.
+  auto diff_weights_ndarray = NDArray(scale_shift_mem.get_desc());
+  // Size in bytes of one row (gamma or beta) of the scratch buffer.
+  const size_t bytes = static_cast<size_t>(inputs[layernorm::kBwdGamma].shape()[0]) *
+                       mshadow::mshadow_sizeof(inputs[layernorm::kBwdGamma].dtype());
+
+  mkldnn_output_t diff_src_mem = CreateMKLDNNMem(
+      outputs[layernorm::kBwdDataGrad], bwd_pd->diff_src_desc(), req[layernorm::kBwdDataGrad]);
+  // The scratch buffer is always overwritten. The gamma and beta requests may differ (e.g.
+  // grad_req='null' on only one of them), so they are applied individually after execution.
+  mkldnn_output_t diff_weights_mem =
+      CreateMKLDNNMem(diff_weights_ndarray, bwd_pd->diff_weights_desc(), kWriteTo);
+  mkldnn_args_map_t args = {{MKLDNN_ARG_DIFF_DST, *inputs[layernorm::kBwdOutGrad].GetMKLDNNData()},
+                            {MKLDNN_ARG_SRC, *inputs[layernorm::kBwdData].GetMKLDNNData()},
+                            {MKLDNN_ARG_SCALE_SHIFT, scale_shift_mem},
+                            {MKLDNN_ARG_MEAN, *inputs[layernorm::kBwdMean].GetMKLDNNData()},
+                            {MKLDNN_ARG_VARIANCE, *inputs[layernorm::kBwdStd].GetMKLDNNData()},
+                            {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second},
+                            {MKLDNN_ARG_DIFF_SCALE_SHIFT, *diff_weights_mem.second}};
+  MKLDNNStream::Get()->RegisterPrimArgs(*bwd, args);
+  CommitOutput(outputs[layernorm::kBwdDataGrad], diff_src_mem);
+  CommitOutput(diff_weights_ndarray, diff_weights_mem);
+  MKLDNNStream::Get()->Submit();
+
+  // Commit the scale/shift gradients, honoring each output's own request.
+  const char* diff_gamma = static_cast<const char*>(diff_weights_ndarray.data().dptr_);
+  const char* diff_beta  = diff_gamma + bytes;
+  CommitLayerNormWeightGrad(
+      outputs[layernorm::kBwdGammaGrad], req[layernorm::kBwdGammaGrad], diff_gamma, bytes);
+  CommitLayerNormWeightGrad(
+      outputs[layernorm::kBwdBetaGrad], req[layernorm::kBwdBetaGrad], diff_beta, bytes);
+}
+
+MKLDNNLayerNormBwd& MKLDNNLayerNormBwd::GetCached(const LayerNormParam& param,
+                                                  const std::vector<NDArray>& inputs) {
+  // Per-thread cache of backward primitives keyed on the params and every backward input.
+  using layernorm_bwd_map = std::unordered_map<LayerNormSignature, MKLDNNLayerNormBwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local layernorm_bwd_map layer_norm_bwds;
+#else
+  static MX_THREAD_LOCAL layernorm_bwd_map layer_norm_bwds;
+#endif
+  LayerNormSignature key(param);
+  for (const int idx : {layernorm::kBwdOutGrad,
+                        layernorm::kBwdData,
+                        layernorm::kBwdGamma,
+                        layernorm::kBwdMean,
+                        layernorm::kBwdStd,
+                        layernorm::kBwdBeta}) {
+    key.AddSign(inputs[idx]);
+  }
+
+  auto found = layer_norm_bwds.find(key);
+  if (found == layer_norm_bwds.end()) {
+    const mkldnn::memory::desc data_md = inputs[layernorm::kBwdData].GetMKLDNNData()->get_desc();
+    const mkldnn::memory::desc diff_md = inputs[layernorm::kBwdOutGrad].GetMKLDNNData()->get_desc();
+    MKLDNNLayerNormBwd fresh_bwd(param, inputs, data_md, diff_md);
+    found = AddToCache(&layer_norm_bwds, key, fresh_bwd);
+  }
+  return found->second;
+}
+
+void MKLDNNLayerNormBackward(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<NDArray>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<NDArray>& outputs) {
+  // Look up (or build and cache) the backward primitive and run it; `ctx` is not used.
+  const auto& param = nnvm::get<LayerNormParam>(attrs.parsed);
+  MKLDNNLayerNormBwd::GetCached(param, inputs).Execute(inputs, outputs, req);
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 3b8c39f..44a6b8f 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -165,6 +165,18 @@ void MKLDNNBatchDotForward(const nnvm::NodeAttrs& attrs,
                            const std::vector<OpReqType>& req,
                            const std::vector<NDArray>& outputs);
 
+/* For layer normalization */
+// Forward pass: inputs are data/gamma/beta and outputs are out/mean/std, both indexed by the
+// layernorm:: enums.
+void MKLDNNLayerNormForward(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs);
+// Backward pass: consumes the six layernorm::kBwd* inputs and produces the data, gamma and
+// beta gradients.
+void MKLDNNLayerNormBackward(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<NDArray>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<NDArray>& outputs);
+
 void MKLDNNSum(const mkldnn::memory& arr1, const mkldnn::memory& arr2, const mkldnn::memory& out);
 
 void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,