You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/08/24 14:46:35 UTC
[incubator-mxnet] branch master updated: [operator] Integrate
oneDNN layer normalization implementation (#19562)
This is an automated email from the ASF dual-hosted git repository.
akarbown pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 695ba2e [operator] Integrate oneDNN layer normalization implementation (#19562)
695ba2e is described below
commit 695ba2e5bd4506050bc768d5ff5964994fc5d534
Author: bartekkuncer <ba...@intel.com>
AuthorDate: Tue Aug 24 16:44:11 2021 +0200
[operator] Integrate oneDNN layer normalization implementation (#19562)
* [operator] Integrate oneDNN layer normalization implementation
* change sizeof(float) to mshadow_sizeof(inputs[layernorm::kBwdGamma].dtype())
* remove eps from key and unify layernorm_fwd_t/mkldnn::layer_normalization_forward
* add author
---
src/operator/nn/layer_norm-inl.h | 26 ++-
src/operator/nn/layer_norm.cc | 66 +++++++
src/operator/nn/layer_norm.cu | 4 +
src/operator/nn/mkldnn/mkldnn_base-inl.h | 2 +
src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h | 1 +
src/operator/nn/mkldnn/mkldnn_batch_dot.cc | 1 +
src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h | 103 ++++++++++
src/operator/nn/mkldnn/mkldnn_layer_norm.cc | 260 +++++++++++++++++++++++++
src/operator/nn/mkldnn/mkldnn_ops-inl.h | 12 ++
9 files changed, 474 insertions(+), 1 deletion(-)
diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h
index 79d0906..b55ef29 100644
--- a/src/operator/nn/layer_norm-inl.h
+++ b/src/operator/nn/layer_norm-inl.h
@@ -45,7 +45,9 @@ namespace op {
namespace layernorm {
enum LayerNormOpInputs {kData, kGamma, kBeta}; // kGamma: scaling parameters, kBeta: shift biases
-enum LayerNormOpOutputs {kOut, kMean, kStd}; // req, out_data
+enum LayerNormOpOutputs {kOut, kMean, kStd}; // indices for req, out_data
+// Input indices of _backward_LayerNorm: output gradient, data, gamma, mean, std, and - in
+// oneDNN builds only - beta, which LayerNorm's FGradient appends last.
+enum LayerNormOpInputsBwd {kBwdOutGrad, kBwdData, kBwdGamma, kBwdMean, kBwdStd, kBwdBeta};
+// Output indices of _backward_LayerNorm: gradients w.r.t. data, gamma and beta.
+enum LayerNormOpOutputsBwd {kBwdDataGrad, kBwdGammaGrad, kBwdBetaGrad};
} // namespace layernorm
struct LayerNormParam : public dmlc::Parameter<LayerNormParam> {
@@ -71,6 +73,11 @@ struct LayerNormParam : public dmlc::Parameter<LayerNormParam> {
(*dict)["eps"] = eps_s.str();
(*dict)["output_mean_var"] = output_mean_var_s.str();
}
+
+  // Exact field-wise equality over axis, eps and output_mean_var (the same fields the
+  // std::hash specialization combines). NOTE(review): presumably required so the parameter
+  // can take part in ParamOpSign cache keys - confirm.
+  bool operator==(const LayerNormParam& other) const {
+    const bool same_axis = axis == other.axis;
+    const bool same_eps = eps == other.eps;  // exact compare on purpose: this is a key
+    return same_axis && same_eps && output_mean_var == other.output_mean_var;
+  }
};
inline int GetRealAxis(int axis, int ndim) {
@@ -257,7 +264,11 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
+#if MXNET_USE_ONEDNN == 1
+ CHECK_EQ(inputs.size(), 6U); // additional beta tensor
+#else
CHECK_EQ(inputs.size(), 5U);
+#endif
const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
int axis = param.axis;
if (axis < 0) {
@@ -313,4 +324,17 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
} // namespace op
} // namespace mxnet
+
+namespace std {
+// Hash for LayerNormParam combining exactly the fields compared by LayerNormParam::operator==
+// (axis, eps, output_mean_var), so equal parameters always hash equally.
+template <>
+struct hash<mxnet::op::LayerNormParam> {
+  // Must be a const member: the standard Hash requirements (and the unordered containers)
+  // invoke the hasher through a const object.
+  size_t operator()(const mxnet::op::LayerNormParam& val) const {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.axis);
+    ret = dmlc::HashCombine(ret, val.eps);
+    ret = dmlc::HashCombine(ret, val.output_mean_var);
+    return ret;
+  }
+};
+} // namespace std
#endif // MXNET_OPERATOR_NN_LAYER_NORM_INL_H_
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index 1a040fa..4e8a80e 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -57,6 +57,10 @@
#include "layer_norm-inl.h"
#include <nnvm/op_attr_types.h>
#include "../elemwise_op_common.h"
+#if MXNET_USE_ONEDNN == 1
+#include "./mkldnn/mkldnn_base-inl.h"
+#include "./mkldnn/mkldnn_ops-inl.h"
+#endif // MXNET_USE_ONEDNN
#if MSHADOW_USE_MKL == 1
#include "../mkl_functions-inl.h"
@@ -392,6 +396,50 @@ void LayerNormGradCompute<cpu>(const nnvm::NodeAttrs& attrs,
return LayerNormGradComputeGeneral<cpu>(attrs, ctx, inputs, req, outputs);
}
+#if MXNET_USE_ONEDNN == 1
+// Storage-type inference registered for both LayerNorm and _backward_LayerNorm in oneDNN
+// builds; the work is delegated entirely to MKLDNNStorageType.
+static bool LayerNormInferStorageType(const nnvm::NodeAttrs& attrs,
+                                      const int dev_mask,
+                                      DispatchMode* dispatch_mode,
+                                      std::vector<int>* in_attrs,
+                                      std::vector<int>* out_attrs) {
+  CHECK(!in_attrs->empty());
+  // NOTE(review): the meaning of this boolean argument is defined by MKLDNNStorageType.
+  const bool mkldnn_flag = true;
+  return MKLDNNStorageType(attrs, dev_mask, mkldnn_flag, dispatch_mode, in_attrs, out_attrs);
+}
+
+// FComputeEx<cpu> for LayerNorm: uses the oneDNN kernel when SupportMKLDNNLayerNorm accepts
+// the inputs, otherwise falls back to the native LayerNormCompute<cpu>.
+static void LayerNormComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                  const OpContext& ctx,
+                                  const std::vector<NDArray>& inputs,
+                                  const std::vector<OpReqType>& req,
+                                  const std::vector<NDArray>& outputs) {
+  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
+  if (!SupportMKLDNNLayerNorm(param, inputs)) {
+    FallBackCompute(LayerNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  // NOTE(review): MKLDNN_OPCHECK_* presumably compare the oneDNN result with the native
+  // kernel when op checking is enabled - confirm in mkldnn_base-inl.h.
+  MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+  MKLDNNRun(MKLDNNLayerNormForward, attrs, ctx, inputs, req, outputs);
+  MKLDNN_OPCHECK_RUN(LayerNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+// FComputeEx<cpu> for _backward_LayerNorm: uses the oneDNN backward kernel when
+// SupportMKLDNNLayerNorm accepts the inputs, otherwise falls back to LayerNormGradCompute<cpu>.
+static void LayerNormGradComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                      const OpContext& ctx,
+                                      const std::vector<NDArray>& inputs,
+                                      const std::vector<OpReqType>& req,
+                                      const std::vector<NDArray>& outputs) {
+  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
+  if (!SupportMKLDNNLayerNorm(param, inputs)) {
+    FallBackCompute(LayerNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  // NOTE(review): MKLDNN_OPCHECK_* presumably compare the oneDNN result with the native
+  // kernel when op checking is enabled - confirm in mkldnn_base-inl.h.
+  MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
+  MKLDNNRun(MKLDNNLayerNormBackward, attrs, ctx, inputs, req, outputs);
+  MKLDNN_OPCHECK_RUN(LayerNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+#endif
+
NNVM_REGISTER_OP(LayerNorm)
.add_alias("_npx_layer_norm")
.describe(R"code(Layer normalization.
@@ -439,6 +487,11 @@ axis to be the last item in the input shape.
.set_attr<mxnet::FInferShape>("FInferShape", LayerNormShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 3>)
.set_attr<FCompute>("FCompute<cpu>", LayerNormCompute<cpu>)
+#if MXNET_USE_ONEDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FInferStorageType>("FInferStorageType", LayerNormInferStorageType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", LayerNormComputeExCPU)
+#endif
.set_attr<nnvm::FGradient>("FGradient", [](const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> heads;
@@ -447,6 +500,10 @@ axis to be the last item in the input shape.
heads.push_back(n->inputs[1]); // gamma
heads.emplace_back(n, 1, 0); // mean
heads.emplace_back(n, 2, 0); // std
+#if MXNET_USE_ONEDNN == 1
+ heads.push_back(n->inputs[2]); // beta - needed for MKLDNN backward propagation;
+ // added at the end in case of fallback to non MKLDNN version
+#endif
return MakeGradNode("_backward_LayerNorm", n, heads, n->attrs.dict);
})
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
@@ -464,11 +521,20 @@ axis to be the last item in the input shape.
NNVM_REGISTER_OP(_backward_LayerNorm)
+// oneDNN builds take beta as a 6th input (appended last by LayerNorm's FGradient) because the
+// oneDNN backward primitive consumes gamma and beta together as one scale/shift tensor.
+#if MXNET_USE_ONEDNN == 1
+.set_num_inputs(6)
+#else
.set_num_inputs(5)
+#endif
.set_num_outputs(3)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<LayerNormParam>)
.set_attr<FCompute>("FCompute<cpu>", LayerNormGradCompute<cpu>)
+// oneDNN path; LayerNormGradComputeExCPU falls back to LayerNormGradCompute<cpu> itself when
+// SupportMKLDNNLayerNorm rejects the inputs.
+#if MXNET_USE_ONEDNN == 1
+.set_attr<FInferStorageType>("FInferStorageType", LayerNormInferStorageType)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FComputeEx>("FComputeEx<cpu>", LayerNormGradComputeExCPU)
+#endif
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
});
diff --git a/src/operator/nn/layer_norm.cu b/src/operator/nn/layer_norm.cu
index 9a33e06..1e719ed 100644
--- a/src/operator/nn/layer_norm.cu
+++ b/src/operator/nn/layer_norm.cu
@@ -689,7 +689,11 @@ void LayerNormGradGPUContig(const LayerNormParam param,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
+#if MXNET_USE_ONEDNN == 1
+ CHECK_EQ(inputs.size(), 6U); // additional beta tensor
+#else
CHECK_EQ(inputs.size(), 5U);
+#endif
const TBlob out_grad = inputs[0];
const TBlob in_data = inputs[1];
const TBlob gamma = inputs[2];
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 2cef524..2ee0793 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -202,6 +202,7 @@ struct SoftmaxParam;
struct SoftmaxOutputParam;
struct TransposeParam;
struct ReshapeParam;
+struct LayerNormParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNAct(const ActivationParam& param, const NDArray& input);
bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param);
@@ -216,6 +217,7 @@ bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param,
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam& param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray& data);
bool SupportMKLDNNBatchDot(const std::vector<NDArray>& inputs, const NDArray& output);
+bool SupportMKLDNNLayerNorm(const LayerNormParam& param, const std::vector<NDArray> &inputs);
} // namespace op
static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h
index 34c3eb9..2459ea1 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h
@@ -19,6 +19,7 @@
/*!
* \file mkldnn_batch_dot-inl.h
+ * \author: Bartosz Kuncer, bartosz.kuncer@intel.com
*/
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_dot.cc b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
index f7c93ef..87ddb98 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
+++ b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
@@ -19,6 +19,7 @@
/*!
* \file mkldnn_batch_dot.cc
+ * \author: Bartosz Kuncer, bartosz.kuncer@intel.com
*/
#if MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h
new file mode 100644
index 0000000..a14673b
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_layer_norm-inl.h
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_layer_norm-inl.h
+ * \author: Bartosz Kuncer, bartosz.kuncer@intel.com
+ */
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "../layer_norm-inl.h"
+#include "./mkldnn_base-inl.h"
+#include "./mkldnn_ops-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// Short names for the oneDNN layer-normalization primitives and their primitive descriptors.
+using layernorm_fwd_t = mkldnn::layer_normalization_forward;
+using layernorm_fwd_pd_t = layernorm_fwd_t::primitive_desc;
+
+using layernorm_bwd_t = mkldnn::layer_normalization_backward;
+using layernorm_bwd_pd_t = layernorm_bwd_t::primitive_desc;
+
+// Primitive-cache key: the operator parameters plus the NDArray signatures added via AddSign.
+using LayerNormSignature = ParamOpSign<LayerNormParam>;
+
+// Wrapper around a oneDNN layer_normalization_forward primitive and its primitive descriptor.
+// Instances are cached per thread by GetCached, keyed on the parameters and input signature.
+// Copies share the same primitive (members are shared_ptr); no user-declared destructor so the
+// implicit copy and move operations are all available (Rule of Zero).
+class MKLDNNLayerNormFwd {
+ public:
+  // Returns the cached forward wrapper for (param, data), creating it on first use.
+  static MKLDNNLayerNormFwd& GetCached(const LayerNormParam& param,
+                                       const OpContext& ctx,
+                                       const NDArray& data);
+
+  // Builds the primitive descriptor for data's oneDNN memory layout, then the primitive.
+  MKLDNNLayerNormFwd(const LayerNormParam& param, const NDArray& data);
+
+  // Creates a forward_training descriptor that uses scale and shift (gamma and beta).
+  // Also used by MKLDNNLayerNormBwd, which needs the forward descriptor as a hint.
+  static std::shared_ptr<layernorm_fwd_pd_t> CreatePrimitiveDesc(
+      const LayerNormParam& param,
+      const mkldnn::memory::desc& src_md);
+
+  // Runs the primitive: outputs[kOut] is written according to req; mean and std are written.
+  void Execute(const LayerNormParam& param,
+               const OpContext& ctx,
+               const std::vector<NDArray>& inputs,
+               const OpReqType& req,
+               const std::vector<NDArray>& outputs) const;
+
+ private:
+  // NOTE: fwd is declared before fwd_pd, so the constructor must not build fwd from fwd_pd in
+  // its member-initializer list.
+  std::shared_ptr<layernorm_fwd_t> fwd;
+  std::shared_ptr<layernorm_fwd_pd_t> fwd_pd;
+};
+
+// Wrapper around a oneDNN layer_normalization_backward primitive, its primitive descriptor and
+// the forward primitive descriptor it was created from. Instances are cached per thread by
+// GetCached. No user-declared destructor, so implicit copy and move are available.
+class MKLDNNLayerNormBwd {
+ public:
+  // Returns the cached backward wrapper for (param, inputs), creating it on first use.
+  static MKLDNNLayerNormBwd& GetCached(const LayerNormParam& param,
+                                       const std::vector<NDArray>& inputs);
+
+  // data_md / diff_md: oneDNN memory descriptors of the forward input and output gradient.
+  MKLDNNLayerNormBwd(const LayerNormParam& param,
+                     const std::vector<NDArray>& inputs,
+                     const mkldnn::memory::desc& data_md,
+                     const mkldnn::memory::desc& diff_md);
+
+  // Creates a backward descriptor (data gradient plus scale/shift gradient), passing
+  // layernorm_fwd_pd as the forward hint.
+  static std::shared_ptr<layernorm_bwd_pd_t> CreatePrimitiveDesc(
+      const LayerNormParam& param,
+      const mkldnn::memory::desc& data_md,
+      const mkldnn::memory::desc& diff_md,
+      const layernorm_fwd_pd_t& layernorm_fwd_pd);
+
+  // Computes the data, gamma and beta gradients; indexed by the layernorm::kBwd* enums.
+  void Execute(const std::vector<NDArray>& inputs,
+               const std::vector<NDArray>& outputs,
+               const std::vector<OpReqType>& req) const;
+
+ private:
+  std::shared_ptr<layernorm_bwd_t> bwd;
+  // Declaration order matters: the constructor initializes bwd_pd from *fwd_pd, so fwd_pd must
+  // be declared (and therefore initialized) before bwd_pd.
+  std::shared_ptr<layernorm_fwd_pd_t> fwd_pd;
+  std::shared_ptr<layernorm_bwd_pd_t> bwd_pd;
+};
+
+} // namespace op
+} // namespace mxnet
+#endif // MXNET_USE_ONEDNN == 1
+#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LAYER_NORM_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_layer_norm.cc b/src/operator/nn/mkldnn/mkldnn_layer_norm.cc
new file mode 100644
index 0000000..8b8e122
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_layer_norm.cc
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_layer_norm.cc
+ * \author: Bartosz Kuncer, bartosz.kuncer@intel.com
+ */
+
+#if MXNET_USE_ONEDNN == 1
+
+#include "./mkldnn_layer_norm-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// Decides whether the oneDNN layer-norm kernels should handle these inputs.
+// Returns true only for 2D-5D inputs normalized over the last axis, with a large enough shape
+// and float32/bfloat16 data with float32 gamma and beta.
+bool SupportMKLDNNLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs) {
+  const mxnet::TShape& shape = inputs[layernorm::kData].shape();
+  const int ndim = shape.ndim();
+
+  // Rank and axis first, so shape[0] below is always a valid dimension.
+  if (ndim < 2 || ndim > 5 || GetRealAxis(param.axis, ndim) != ndim - 1) {
+    return false;
+  }
+
+  // Native implementation (LayerNormCPU) is faster than oneDNN's for small tensors. Heuristic
+  // measured on a CLX machine: both the leading dimension and the number of elements per
+  // leading index must be at least 1024. The leading dimension is tested first so that a
+  // zero-sized leading dimension is rejected before it is used as a divisor.
+  constexpr int64_t shapeLimit = 1024;
+  const int64_t rows = shape[0];
+  if (rows < shapeLimit ||
+      shape.Size() / static_cast<size_t>(rows) < static_cast<size_t>(shapeLimit)) {
+    return false;
+  }
+
+  // NOTE(review): for _backward_LayerNorm this receives the backward inputs, so kData, kGamma
+  // and kBeta index out_grad, data and gamma there rather than data, gamma and beta.
+  const int data_dtype = inputs[layernorm::kData].dtype();
+  return (data_dtype == mshadow::kFloat32 || data_dtype == mshadow::kBfloat16) &&
+         inputs[layernorm::kGamma].dtype() == mshadow::kFloat32 &&
+         inputs[layernorm::kBeta].dtype() == mshadow::kFloat32;
+}
+
+// oneDNN forward entry point: fetch (or build) the cached primitive for this parameter/input
+// combination and run it. Only req[kOut] is passed on to Execute.
+void MKLDNNLayerNormForward(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs) {
+  const auto& layer_norm_param = nnvm::get<LayerNormParam>(attrs.parsed);
+  const NDArray& data = inputs[layernorm::kData];
+  const OpReqType out_req = req[layernorm::kOut];
+  MKLDNNLayerNormFwd::GetCached(layer_norm_param, ctx, data)
+      .Execute(layer_norm_param, ctx, inputs, out_req, outputs);
+}
+
+MKLDNNLayerNormFwd& MKLDNNLayerNormFwd::GetCached(const LayerNormParam& param,
+                                                  const OpContext& ctx,
+                                                  const NDArray& data) {
+  // One cache per thread, keyed by the layer-norm parameters plus the signature of `data`.
+  using layernorm_fwd_map = std::unordered_map<LayerNormSignature, MKLDNNLayerNormFwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local layernorm_fwd_map cache;
+#else
+  static MX_THREAD_LOCAL layernorm_fwd_map cache;
+#endif
+
+  LayerNormSignature key(param);
+  key.AddSign(data);
+
+  auto found = cache.find(key);
+  if (found != cache.end()) {
+    return found->second;
+  }
+  // Cache miss: build the primitive for this shape/layout and store it.
+  MKLDNNLayerNormFwd fresh(param, data);
+  return AddToCache(&cache, key, fresh)->second;
+}
+
+// Builds the primitive descriptor from data's current oneDNN memory layout, then the primitive.
+// Assigned in the body (not the initializer list) because fwd is declared before fwd_pd.
+MKLDNNLayerNormFwd::MKLDNNLayerNormFwd(const LayerNormParam& param, const NDArray& data) {
+  const auto src_md = data.GetMKLDNNData()->get_desc();
+  this->fwd_pd = MKLDNNLayerNormFwd::CreatePrimitiveDesc(param, src_md);
+  this->fwd = std::make_shared<layernorm_fwd_t>(*this->fwd_pd);
+}
+
+// Training-mode forward descriptor with scale and shift (gamma, beta) on the CPU engine;
+// the training mode is what lets Execute bind MEAN/VARIANCE as outputs.
+std::shared_ptr<layernorm_fwd_pd_t> MKLDNNLayerNormFwd::CreatePrimitiveDesc(
+    const LayerNormParam& param,
+    const mkldnn::memory::desc& src_md) {
+  const auto flags = dnnl::normalization_flags::use_scale_shift;
+  const layernorm_fwd_t::desc fwd_desc(
+      mkldnn::prop_kind::forward_training, src_md, param.eps, flags);
+  return std::make_shared<layernorm_fwd_pd_t>(fwd_desc, CpuEngine::Get()->get_engine());
+}
+
+// Builds a dense row-major oneDNN memory descriptor with the same dims as `_shape` and the
+// given data type (used for the mean and std outputs). The innermost dimension is contiguous.
+inline mkldnn::memory::desc GetMeanVarDesc(const mkldnn::memory::data_type& dtype,
+                                           const mxnet::TShape& _shape) {
+  const int ndim = _shape.ndim();
+  // Guard: with no dimensions there is nothing to describe, and indexing dim 0 would be UB.
+  CHECK_GT(ndim, 0) << "GetMeanVarDesc expects a shape with at least one dimension";
+
+  mkldnn::memory::dims shape(ndim), strides(ndim);
+  mkldnn::memory::dim stride = 1;
+  for (int i = ndim - 1; i >= 0; --i) {
+    shape[i] = _shape[i];
+    strides[i] = stride;  // product of all inner extents
+    stride *= shape[i];
+  }
+
+  return mkldnn::memory::desc{shape, dtype, strides};
+}
+
+// Packs gamma (scale) and beta (shift) into the single 2 x C SCALE_SHIFT tensor oneDNN takes when
+// both are used: row 0 holds gamma, row 1 holds beta. Assumes gamma and beta share one dtype.
+inline mkldnn::memory GetScaleShiftMem(const NDArray& gamma, const NDArray& beta) {
+  constexpr mkldnn::memory::dim kRows = 2;  // gamma row + beta row
+  CHECK_EQ(gamma.shape()[0], beta.shape()[0]);
+  const mkldnn::memory::desc scale_shift_md(mkldnn::memory::dims{kRows, gamma.shape()[0]},
+                                            get_mkldnn_type(gamma.dtype()),
+                                            mkldnn::memory::format_tag::nc);
+  mkldnn::memory scale_shift_mem(scale_shift_md, CpuEngine::Get()->get_engine());
+
+  char* dst = static_cast<char*>(scale_shift_mem.get_data_handle());
+  const size_t row_bytes = scale_shift_md.get_size() / kRows;
+  memcpy(dst, gamma.data().dptr_, row_bytes);              // row 0: scale
+  memcpy(dst + row_bytes, beta.data().dptr_, row_bytes);   // row 1: shift
+  return scale_shift_mem;
+}
+
+// Runs the cached forward primitive on inputs[kData] with gamma/beta packed as SCALE_SHIFT.
+// Only outputs[kOut] honours `req`; the mean and std outputs are always written directly
+// (OutDataOp::Noop). `param` and `ctx` are not used here.
+void MKLDNNLayerNormFwd::Execute(const LayerNormParam& param,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const OpReqType& req,
+ const std::vector<NDArray>& outputs) const {
+ // Mean and std share one descriptor built from outputs[kMean]'s dtype and shape.
+ auto mean_var_md = GetMeanVarDesc(get_mkldnn_type(outputs[layernorm::kMean].dtype()),
+ outputs[layernorm::kMean].shape());
+ auto mean_mem = mkldnn_output_t(
+ OutDataOp::Noop,
+ const_cast<NDArray&>(outputs[layernorm::kMean]).CreateMKLDNNData(mean_var_md));
+ // NOTE(review): oneDNN produces *variance* here and it is stored in the kStd output; confirm
+ // this matches what the native LayerNormCompute/LayerNormGradCompute expect in that slot,
+ // e.g. when backward falls back to the native kernel after a oneDNN forward.
+ auto variance_mem =
+ mkldnn_output_t(OutDataOp::Noop,
+ const_cast<NDArray&>(outputs[layernorm::kStd]).CreateMKLDNNData(mean_var_md));
+
+ auto output_mem = CreateMKLDNNMem(outputs[layernorm::kOut], fwd_pd->dst_desc(), req);
+ auto scale_shift_mem = GetScaleShiftMem(inputs[layernorm::kGamma], inputs[layernorm::kBeta]);
+
+ mkldnn_args_map_t args = {{MKLDNN_ARG_SRC, *inputs[layernorm::kData].GetMKLDNNData()},
+ {MKLDNN_ARG_DST, *output_mem.second},
+ {MKLDNN_ARG_MEAN, *mean_mem.second},
+ {MKLDNN_ARG_VARIANCE, *variance_mem.second},
+ {MKLDNN_ARG_SCALE_SHIFT, scale_shift_mem}};
+
+ // Register the primitive, queue the output commits, then execute everything on the stream.
+ MKLDNNStream::Get()->RegisterPrimArgs(*fwd, args);
+ CommitOutput(outputs[layernorm::kOut], output_mem);
+ CommitOutput(outputs[layernorm::kMean], mean_mem);
+ CommitOutput(outputs[layernorm::kStd], variance_mem);
+ MKLDNNStream::Get()->Submit();
+}
+
+// Builds the forward primitive descriptor (the hint oneDNN needs for the backward one), then
+// the backward descriptor and primitive. `inputs` is not used here.
+// Relies on member declaration order: fwd_pd is declared before bwd_pd, so it is already
+// initialized when bwd_pd's initializer dereferences it.
+MKLDNNLayerNormBwd::MKLDNNLayerNormBwd(const LayerNormParam& param,
+ const std::vector<NDArray>& inputs,
+ const mkldnn::memory::desc& data_md,
+ const mkldnn::memory::desc& diff_md)
+ : fwd_pd(MKLDNNLayerNormFwd::CreatePrimitiveDesc(param, data_md)),
+ bwd_pd(CreatePrimitiveDesc(param, data_md, diff_md, *fwd_pd)) {
+ bwd = std::make_shared<layernorm_bwd_t>(*bwd_pd);
+}
+
+// Backward descriptor (prop_kind::backward: data gradient plus scale/shift gradient) on the CPU
+// engine, created with the matching forward descriptor as hint.
+std::shared_ptr<layernorm_bwd_pd_t> MKLDNNLayerNormBwd::CreatePrimitiveDesc(
+    const LayerNormParam& param,
+    const mkldnn::memory::desc& data_md,
+    const mkldnn::memory::desc& diff_md,
+    const layernorm_fwd_pd_t& layernorm_fwd_pd) {
+  const layernorm_bwd_t::desc bwd_desc(dnnl::prop_kind::backward,
+                                       diff_md,
+                                       data_md,
+                                       param.eps,
+                                       dnnl::normalization_flags::use_scale_shift);
+  return std::make_shared<layernorm_bwd_pd_t>(
+      bwd_desc, CpuEngine::Get()->get_engine(), layernorm_fwd_pd);
+}
+
+// Runs the cached backward primitive. inputs/outputs/req are indexed by the layernorm::kBwd*
+// enums. oneDNN produces the gamma and beta gradients together in one 2 x C DIFF_SCALE_SHIFT
+// buffer (row 0: gamma, row 1: beta). Each row is then committed according to its OWN output
+// req: kNullOp outputs are left untouched and kAddTo outputs accumulate.
+void MKLDNNLayerNormBwd::Execute(const std::vector<NDArray>& inputs,
+                                 const std::vector<NDArray>& outputs,
+                                 const std::vector<OpReqType>& req) const {
+  auto scale_shift_mem =
+      GetScaleShiftMem(inputs[layernorm::kBwdGamma], inputs[layernorm::kBwdBeta]);
+  auto diff_weights_ndarray = NDArray(scale_shift_mem.get_desc());
+  // Size in bytes of one row (gamma or beta) of the packed scale/shift gradient.
+  const size_t bytes = inputs[layernorm::kBwdGamma].shape()[0] *
+                       mshadow::mshadow_sizeof(inputs[layernorm::kBwdGamma].dtype());
+
+  const OpReqType gamma_req = req[layernorm::kBwdGammaGrad];
+  const OpReqType beta_req = req[layernorm::kBwdBetaGrad];
+  const bool accumulate = gamma_req == kAddTo || beta_req == kAddTo;
+  if (accumulate) {
+    // The primitive's result is added onto diff_weights_ndarray (kAddTo below), so seed each
+    // row with the existing gradient if that output accumulates, and with zeros otherwise.
+    auto seed_row = [&](OpReqType row_req, const NDArray& grad, size_t offset) {
+      char* row = static_cast<char*>(diff_weights_ndarray.data().dptr_) + offset;
+      if (row_req == kAddTo) {
+        memcpy(row, grad.data().dptr_, bytes);
+      } else {
+        memset(row, 0, bytes);
+      }
+    };
+    seed_row(gamma_req, outputs[layernorm::kBwdGammaGrad], 0);
+    seed_row(beta_req, outputs[layernorm::kBwdBetaGrad], bytes);
+  }
+
+  mkldnn_output_t diff_src_mem = CreateMKLDNNMem(
+      outputs[layernorm::kBwdDataGrad], bwd_pd->diff_src_desc(), req[layernorm::kBwdDataGrad]);
+  mkldnn_output_t diff_weights_mem = CreateMKLDNNMem(
+      diff_weights_ndarray, bwd_pd->diff_weights_desc(), accumulate ? kAddTo : kWriteTo);
+  mkldnn_args_map_t args = {{MKLDNN_ARG_DIFF_DST, *inputs[layernorm::kBwdOutGrad].GetMKLDNNData()},
+                            {MKLDNN_ARG_SRC, *inputs[layernorm::kBwdData].GetMKLDNNData()},
+                            {MKLDNN_ARG_SCALE_SHIFT, scale_shift_mem},
+                            {MKLDNN_ARG_MEAN, *inputs[layernorm::kBwdMean].GetMKLDNNData()},
+                            {MKLDNN_ARG_VARIANCE, *inputs[layernorm::kBwdStd].GetMKLDNNData()},
+                            {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second},
+                            {MKLDNN_ARG_DIFF_SCALE_SHIFT, *diff_weights_mem.second}};
+  MKLDNNStream::Get()->RegisterPrimArgs(*bwd, args);
+  CommitOutput(outputs[layernorm::kBwdDataGrad], diff_src_mem);
+  CommitOutput(diff_weights_ndarray, diff_weights_mem);
+  MKLDNNStream::Get()->Submit();
+
+  // Copy each gradient row to its output unless that output was not requested.
+  auto commit_row = [&](OpReqType row_req, const NDArray& grad, size_t offset) {
+    if (row_req == kNullOp) {
+      return;
+    }
+    const char* row = static_cast<const char*>(diff_weights_ndarray.data().dptr_) + offset;
+    memcpy(grad.data().dptr_, row, bytes);
+  };
+  commit_row(gamma_req, outputs[layernorm::kBwdGammaGrad], 0);
+  commit_row(beta_req, outputs[layernorm::kBwdBetaGrad], bytes);
+}
+
+MKLDNNLayerNormBwd& MKLDNNLayerNormBwd::GetCached(const LayerNormParam& param,
+                                                  const std::vector<NDArray>& inputs) {
+  // Per-thread cache keyed on the parameters plus the signatures of all six backward inputs.
+  using layernorm_bwd_map = std::unordered_map<LayerNormSignature, MKLDNNLayerNormBwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local layernorm_bwd_map cache;
+#else
+  static MX_THREAD_LOCAL layernorm_bwd_map cache;
+#endif
+  LayerNormSignature key(param);
+  for (const int idx : {layernorm::kBwdOutGrad, layernorm::kBwdData, layernorm::kBwdGamma,
+                        layernorm::kBwdMean, layernorm::kBwdStd, layernorm::kBwdBeta}) {
+    key.AddSign(inputs[idx]);
+  }
+
+  auto entry = cache.find(key);
+  if (entry == cache.end()) {
+    // Cache miss: build from the current oneDNN layouts of the data and output gradient.
+    const auto data_md = inputs[layernorm::kBwdData].GetMKLDNNData()->get_desc();
+    const auto diff_md = inputs[layernorm::kBwdOutGrad].GetMKLDNNData()->get_desc();
+    MKLDNNLayerNormBwd created(param, inputs, data_md, diff_md);
+    entry = AddToCache(&cache, key, created);
+  }
+  return entry->second;
+}
+
+// oneDNN backward entry point: look up (or build) the cached primitive and execute it with the
+// full request vector. `ctx` is not used.
+void MKLDNNLayerNormBackward(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<NDArray>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<NDArray>& outputs) {
+  const auto& layer_norm_param = nnvm::get<LayerNormParam>(attrs.parsed);
+  MKLDNNLayerNormBwd::GetCached(layer_norm_param, inputs).Execute(inputs, outputs, req);
+}
+
+} // namespace op
+} // namespace mxnet
+#endif // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 3b8c39f..44a6b8f 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -165,6 +165,18 @@ void MKLDNNBatchDotForward(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);
+/* For layer normalization */
+// oneDNN forward pass of LayerNorm; run from LayerNormComputeExCPU (FComputeEx<cpu>) when
+// SupportMKLDNNLayerNorm accepts the inputs.
+void MKLDNNLayerNormForward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray> &outputs);
+// oneDNN backward pass of LayerNorm; run from LayerNormGradComputeExCPU. Expects the six
+// _backward_LayerNorm inputs (including beta) of oneDNN builds.
+void MKLDNNLayerNormBackward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs);
+
void MKLDNNSum(const mkldnn::memory& arr1, const mkldnn::memory& arr2, const mkldnn::memory& out);
void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,