You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ta...@apache.org on 2019/09/23 14:04:13 UTC
[incubator-mxnet] branch mkldnn-v1.0 updated: [mkldnn-v1.0] Add
MKL-DNN BN (#16199)
This is an automated email from the ASF dual-hosted git repository.
taolv pushed a commit to branch mkldnn-v1.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/mkldnn-v1.0 by this push:
new 0b8805a [mkldnn-v1.0] Add MKL-DNN BN (#16199)
0b8805a is described below
commit 0b8805afc963296ecfeb0599e6c63858ca6a85b0
Author: rongzha1 <ro...@intel.com>
AuthorDate: Mon Sep 23 22:03:04 2019 +0800
[mkldnn-v1.0] Add MKL-DNN BN (#16199)
* add mkldnn bn
* add static_cast to transform data type
* change mkldnn_args_map_t
* retrigger CI
---
src/operator/nn/batch_norm.cc | 14 +-
src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 232 ++++++++-----------------
2 files changed, 81 insertions(+), 165 deletions(-)
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 3214e3b..da042a1 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -28,7 +28,7 @@
#include <nnvm/op_attr_types.h>
#include "../elemwise_op_common.h"
#include "../operator_common.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include "./mkldnn/mkldnn_batch_norm-inl.h"
#endif
@@ -379,7 +379,7 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
return true;
}
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
mxnet::TShape shape = input.shape();
return SupportMKLDNN(input) && shape.ndim() == 4
@@ -454,7 +454,7 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
bool dispatched = false;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
if (!dispatched) {
dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
in_attrs, out_attrs);
@@ -592,11 +592,11 @@ then set ``gamma`` to 1 and its gradient to 0.
.set_attr<nnvm::FInferType>("FInferType", BatchNormType)
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
.set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormComputeExCPU)
#endif
.set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -623,13 +623,13 @@ NNVM_REGISTER_OP(_backward_BatchNorm)
.set_num_outputs(3)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
.set_attr_parser(ParamParser<BatchNormParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormGradComputeExCPU)
#endif
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 61de08f..ef5886e 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -26,7 +26,7 @@
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_NORM_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_NORM_INL_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include <vector>
#include <utility>
#include <mkldnn.hpp>
@@ -44,54 +44,44 @@ typedef mkldnn::batch_normalization_forward::desc t_bn_f_desc;
typedef mkldnn::batch_normalization_backward::primitive_desc t_bn_b_pdesc;
typedef mkldnn::batch_normalization_backward::desc t_bn_b_desc;
-using mkldnn::use_global_stats;
-using mkldnn::use_scale_shift;
-using mkldnn::forward_training;
-using mkldnn::forward_inference;
-
-inline static unsigned _GetFlags(const std::vector<NDArray> &in_data,
+inline static mkldnn::normalization_flags _GetFlags(const std::vector<NDArray> &in_data,
const std::vector<NDArray> &aux_states,
const BatchNormParam &param, bool is_train_and_not_global_stats) {
- unsigned flags = 0U;
+ mkldnn::normalization_flags flags = static_cast<mkldnn::normalization_flags>(0U);
if (in_data.size() == 3U) {
- flags |= use_scale_shift;
+ flags |= mkldnn::normalization_flags::use_scale_shift;
}
// aux_states[0]: inMean
// aux_states[1]: inVariance
if (aux_states.size() == 2U && !is_train_and_not_global_stats) {
- flags |= use_global_stats;
+ flags |= mkldnn::normalization_flags::use_global_stats;
}
return flags;
}
-template <typename DType>
inline static t_bn_f_pdesc _GetFwd(const mkldnn::memory &data_mem,
bool is_train,
- DType eps,
- unsigned flags) {
- auto data_mpd = data_mem.get_primitive_desc();
- auto data_md = data_mpd.desc();
- auto engine = CpuEngine::Get()->get_engine();
+ float eps,
+ mkldnn::normalization_flags flags) {
+ auto data_md = data_mem.get_desc();
+ auto engine = CpuEngine::Get()->get_engine();
if (is_train) {
- t_bn_f_desc bnFwd_desc(forward_training, data_md, eps, flags);
+ t_bn_f_desc bnFwd_desc(mkldnn::prop_kind::forward_training, data_md, eps, flags);
return t_bn_f_pdesc(bnFwd_desc, engine);
} else {
- t_bn_f_desc bnFwd_desc(forward_inference, data_md, eps, flags);
+ t_bn_f_desc bnFwd_desc(mkldnn::prop_kind::forward_inference, data_md, eps, flags);
return t_bn_f_pdesc(bnFwd_desc, engine);
}
}
-template <typename DType>
inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem,
const mkldnn::memory &diff_mem,
- DType eps,
- unsigned flags) {
- auto data_mpd = data_mem.get_primitive_desc();
- auto data_md = data_mpd.desc();
- auto diff_mpd = diff_mem.get_primitive_desc();
- auto diff_md = diff_mpd.desc();
+ float eps,
+ mkldnn::normalization_flags flags) {
+ auto data_md = data_mem.get_desc();
+ auto diff_md = diff_mem.get_desc();
auto engine = CpuEngine::Get()->get_engine();
t_bn_b_desc bnBwd_desc(mkldnn::prop_kind::backward, diff_md, data_md, eps, flags);
@@ -101,18 +91,15 @@ inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem,
typedef ParamOpSign<BatchNormParam> MKLDNNBNSignature;
class MKLDNNBNForward {
- std::shared_ptr<const mkldnn::memory> data_m;
std::shared_ptr<const mkldnn::memory> weight_m;
- std::shared_ptr<const mkldnn::memory> out_m;
- std::shared_ptr<const mkldnn::memory> mean_m;
- std::shared_ptr<const mkldnn::memory> var_m;
std::shared_ptr<mkldnn::batch_normalization_forward> fwd;
bool is_train_and_not_global_stats;
t_bn_f_pdesc pd;
public:
MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train_and_not_global_stats): pd(_pd) {
- weight_m.reset(new mkldnn::memory(pd.weights_primitive_desc()));
+ weight_m.reset(new mkldnn::memory(pd.weights_desc(), CpuEngine::Get()->get_engine()));
+ fwd.reset(new mkldnn::batch_normalization_forward(pd));
this->is_train_and_not_global_stats = is_train_and_not_global_stats;
}
@@ -124,59 +111,6 @@ class MKLDNNBNForward {
return pd;
}
- const mkldnn::memory &GetMean() const {
- return *mean_m;
- }
-
- const mkldnn::memory &GetVar() const {
- return *var_m;
- }
-
- void SetDataHandle(const mkldnn::memory *data, const mkldnn::memory *mean,
- const mkldnn::memory *var, const mkldnn::memory *out) {
- if (data_m) {
- data_m->set_data_handle(data->get_data_handle());
- } else {
- data_m.reset(new mkldnn::memory(data->get_primitive_desc(),
- data->get_data_handle()));
- }
- if (out_m) {
- out_m->set_data_handle(out->get_data_handle());
- } else {
- out_m.reset(new mkldnn::memory(out->get_primitive_desc(),
- out->get_data_handle()));
- }
- if (mean_m) {
- mean_m->set_data_handle(mean->get_data_handle());
- } else {
- mean_m.reset(new mkldnn::memory(mean->get_primitive_desc(),
- mean->get_data_handle()));
- }
- if (var_m) {
- var_m->set_data_handle(var->get_data_handle());
- } else {
- var_m.reset(new mkldnn::memory(var->get_primitive_desc(),
- var->get_data_handle()));
- }
-
- if (fwd == nullptr) {
- if (!is_train_and_not_global_stats)
- fwd.reset(new mkldnn::batch_normalization_forward(
- pd, *data_m, mkldnn::primitive::at(*mean_m),
- mkldnn::primitive::at(*var_m), *weight_m, *out_m));
- else
- fwd.reset(new mkldnn::batch_normalization_forward(
- pd, mkldnn::primitive::at(*data_m),
- mkldnn::primitive::at(*weight_m), *out_m,
- *mean_m, *var_m));
- }
- }
-
- void SetDataHandle(const NDArray &data, const NDArray &mean,
- const NDArray &var, const mkldnn::memory &out) {
- SetDataHandle(data.GetMKLDNNData(), mean.GetMKLDNNData(), var.GetMKLDNNData(), &out);
- }
-
const mkldnn::batch_normalization_forward &GetFwd() const {
return *fwd;
}
@@ -185,7 +119,7 @@ class MKLDNNBNForward {
template<typename DType>
static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
const OpContext &ctx, const mkldnn::memory *data_mem,
- unsigned flags) {
+ mkldnn::normalization_flags flags) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNForward, OpHash> fwds;
#else
@@ -193,13 +127,12 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
#endif
MKLDNNBNSignature key(param);
key.AddSign(ctx.is_train);
- key.AddSign(param.use_global_stats);
key.AddSign(*data_mem);
auto it = fwds.find(key);
if (it == fwds.end()) {
auto fwd_pd = _GetFwd(*data_mem, ctx.is_train,
- (DType) param.eps, flags);
+ param.eps, flags);
MKLDNNBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats);
it = AddToCache(&fwds, key, fwd);
}
@@ -209,7 +142,7 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
template<typename DType>
static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
const OpContext &ctx, const NDArray &in_data,
- unsigned flags) {
+ mkldnn::normalization_flags flags) {
return GetBNForward<DType>(param, ctx, in_data.GetMKLDNNData(), flags);
}
@@ -220,18 +153,20 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
const std::vector<NDArray> &out_data,
const std::vector<NDArray> &aux_states) {
TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
- unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats);
+ mkldnn::normalization_flags flags = _GetFlags(in_data,
+ aux_states,
+ param,
+ ctx.is_train && !param.use_global_stats);
const NDArray &data = in_data[batchnorm::kData];
-
auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
- const NDArray &out = out_data[batchnorm::kOut];
+ const NDArray &out = out_data[batchnorm::kOut];
// for output memory
- auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_primitive_desc());
+ auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
// mxnet will always use scale shift.
// But if fix_gamma is true, then all scale elements will be set to 1.0f
- if (flags & use_scale_shift) {
+ if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
const NDArray &gamma = in_data[batchnorm::kGamma];
const NDArray &beta = in_data[batchnorm::kBeta];
CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage);
@@ -241,7 +176,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
DType* weight_buf = reinterpret_cast<DType *>(weight_mem.get_data_handle());
nnvm::dim_t channels_ = data.shape()[1];
- CHECK(weight_mem.get_primitive_desc().get_size() == channels_ * sizeof(DType) * 2);
+ CHECK(weight_mem.get_desc().get_size() == channels_ * sizeof(DType) * 2);
DType* weight_ptr = gamma.data().dptr<DType>();
DType* bias_ptr = beta.data().dptr<DType>();
if (!param.fix_gamma) {
@@ -249,17 +184,22 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
memcpy(&weight_buf[channels_], bias_ptr, sizeof(weight_buf[0]) * channels_);
} else if (IsBNWriting(req[batchnorm::kGamma])) {
for (int i = 0; i < channels_; i++) {
- weight_buf[i] = (DType)1.0f;
- weight_ptr[i] = (DType)1.0f;
+ weight_buf[i] = static_cast<DType>(1.0f);
+ weight_ptr[i] = static_cast<DType>(1.0f);
weight_buf[channels_ + i] = bias_ptr[i]; // bias
}
} else {
for (int i = 0; i < channels_; i++) {
- weight_buf[i] = (DType)1.0f;
+ weight_buf[i] = static_cast<DType>(1.0f);
weight_buf[channels_ + i] = bias_ptr[i]; // bias
}
}
+ mkldnn_args_map_t net_args;
+ net_args[MKLDNN_ARG_SRC] = *data.GetMKLDNNData();
+ net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem;
+ net_args[MKLDNN_ARG_DST] = *out_mem;
+
if (!ctx.is_train || param.use_global_stats) {
DType* omean = out_data[batchnorm::kMean].data().dptr<DType>();
DType* ovar = out_data[batchnorm::kVar].data().dptr<DType>();
@@ -270,26 +210,21 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
omean[i] = inmean[i];
ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps);
}
-
- fwd.SetDataHandle(data, aux_states[batchnorm::kMovingMean],
- aux_states[batchnorm::kMovingVar],
- *out_mem);
- MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+ net_args[MKLDNN_ARG_MEAN] = *(aux_states[batchnorm::kMovingMean].GetMKLDNNData());
+ net_args[MKLDNN_ARG_VARIANCE] = *(aux_states[batchnorm::kMovingVar].GetMKLDNNData());
+ MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
MKLDNNStream::Get()->Submit();
} else { // training
const NDArray &outMean = out_data[batchnorm::kMean];
const NDArray &outVar = out_data[batchnorm::kVar];
- DType* omean = outMean.data().dptr<DType>();
- DType* ovar = outVar.data().dptr<DType>();
-
- fwd.SetDataHandle(data, outMean, outVar, *out_mem);
- MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+ net_args[MKLDNN_ARG_MEAN] = *(outMean.GetMKLDNNData());
+ net_args[MKLDNN_ARG_VARIANCE] = *(outVar.GetMKLDNNData());
+ MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
MKLDNNStream::Get()->Submit();
- DType* mean_mem_ptr = reinterpret_cast<DType*>(fwd.GetMean().get_data_handle());
- DType* var_mem_ptr = reinterpret_cast<DType*>(fwd.GetVar().get_data_handle());
+
+ DType* ovar = outVar.data().dptr<DType>();
for (int i = 0; i < channels_; i++) {
- omean[i] = mean_mem_ptr[i];
- ovar[i] = VARIANCE_TO_INVSTD(var_mem_ptr[i], param.eps);
+ ovar[i] = VARIANCE_TO_INVSTD(ovar[i], param.eps);
}
}
} else { // no input gamma and beta
@@ -299,11 +234,6 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
class MKLDNNBNBackward {
std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
- std::shared_ptr<mkldnn::memory> data_m;
- std::shared_ptr<mkldnn::memory> diff_m;
- std::shared_ptr<mkldnn::memory> gradi_m;
- std::shared_ptr<mkldnn::memory> mean_m;
- std::shared_ptr<mkldnn::memory> var_m;
const std::shared_ptr<mkldnn::memory> weight_m;
const std::shared_ptr<mkldnn::memory> gradw_m;
@@ -311,41 +241,16 @@ class MKLDNNBNBackward {
const t_bn_b_pdesc pd;
explicit MKLDNNBNBackward(const t_bn_b_pdesc &_pd)
- : weight_m(new mkldnn::memory(_pd.weights_primitive_desc())),
- gradw_m(new mkldnn::memory(_pd.diff_weights_primitive_desc())),
- pd(_pd) {}
+ : weight_m(new mkldnn::memory(_pd.weights_desc(), CpuEngine::Get()->get_engine())),
+ gradw_m(new mkldnn::memory(_pd.diff_weights_desc(), CpuEngine::Get()->get_engine())),
+ pd(_pd) {
+ bwd.reset(new mkldnn::batch_normalization_backward(pd));
+ }
const mkldnn::memory &GetWeight() const { return *weight_m; }
const mkldnn::memory &GetGradw() const { return *gradw_m; }
- void SetDataHandle(const mkldnn::memory &data, const mkldnn::memory &diff,
- const NDArray &mean, const mkldnn::memory &var,
- const mkldnn::memory &gradi) {
- auto mean_ptr = mean.data().dptr_;
- if (bwd == nullptr) {
- data_m.reset(new mkldnn::memory(data.get_primitive_desc(),
- data.get_data_handle()));
- diff_m.reset(new mkldnn::memory(diff.get_primitive_desc(),
- diff.get_data_handle()));
- gradi_m.reset(new mkldnn::memory(gradi.get_primitive_desc(),
- gradi.get_data_handle()));
- mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(), mean_ptr));
- var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(),
- var.get_data_handle()));
- bwd.reset(new mkldnn::batch_normalization_backward(
- pd, *data_m, mkldnn::primitive::at(*mean_m),
- mkldnn::primitive::at(*var_m), *diff_m, *weight_m, *gradi_m,
- *gradw_m));
- } else {
- data_m->set_data_handle(data.get_data_handle());
- diff_m->set_data_handle(diff.get_data_handle());
- gradi_m->set_data_handle(gradi.get_data_handle());
- mean_m->set_data_handle(mean_ptr);
- var_m->set_data_handle(var.get_data_handle());
- }
- }
-
const mkldnn::batch_normalization_backward &GetBwd() const { return *bwd; }
};
@@ -353,7 +258,7 @@ template <typename DType>
static MKLDNNBNBackward &GetBNBackward(
const BatchNormParam &param, const OpContext &ctx, const NDArray &in_data,
const mkldnn::memory &in_mem, const NDArray &diff_data,
- const mkldnn::memory &diff_mem, unsigned flags) {
+ const mkldnn::memory &diff_mem, mkldnn::normalization_flags flags) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
#else
@@ -385,7 +290,10 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
CHECK_EQ(in_data.size(), 3U);
CHECK_EQ(out_data.size(), 3U);
CHECK_EQ(in_grad.size(), 3U);
- unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats);
+ mkldnn::normalization_flags flags = _GetFlags(in_data,
+ aux_states,
+ param,
+ ctx.is_train && !param.use_global_stats);
const NDArray &data = in_data[batchnorm::kData];
const NDArray &diff = out_grad[batchnorm::kOut];
@@ -405,13 +313,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
// MKLDNN batchnorm should run on special layouts. If one of them isn't, we
// should reorder them.
if (data.IsDefaultData())
- data_mem = data.GetMKLDNNDataReorder(diff_mem->get_primitive_desc());
+ data_mem = data.GetMKLDNNDataReorder(diff_mem->get_desc());
else if (diff.IsDefaultData())
- diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_primitive_desc());
+ diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc());
auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
- auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_primitive_desc());
+ auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_desc());
- if (flags & use_scale_shift) {
+ if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
const NDArray &gamma = in_data[batchnorm::kGamma];
const NDArray &beta = in_data[batchnorm::kBeta];
DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
@@ -420,20 +328,27 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
if (!param.fix_gamma)
weight_buf[i] = (gamma.data().dptr<DType>())[i]; // weight
else
- weight_buf[i] = (DType)1.0f;
+ weight_buf[i] = static_cast<DType>(1.0f);
}
for (int i = 0; i < channels_; i++) {
weight_buf[channels_ + i] = (beta.data().dptr<DType>())[i]; // bias
}
+ mkldnn_args_map_t net_args;
+ net_args[MKLDNN_ARG_SRC] = *data_mem;
+ net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem;
+ net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight();
+ net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
+ net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem;
+
// training but no input mean and variance
if (ctx.is_train && !param.use_global_stats) {
DType* moving_mean_ptr = reinterpret_cast<DType *>(moving_mean.data().dptr<DType>());
DType* moving_var_ptr = reinterpret_cast<DType *>(moving_var.data().dptr<DType>());
DType* out_mean_ptr = reinterpret_cast<DType *>(out_mean.data().dptr<DType>());
DType* out_var_ptr = reinterpret_cast<DType *>(out_var.data().dptr<DType>());
- mkldnn::memory var_mem(bwd.pd.variance_primitive_desc());
+ mkldnn::memory var_mem(bwd.pd.variance_desc(), CpuEngine::Get()->get_engine());
DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());
DType minus_mom = (1.0f - param.momentum);
@@ -445,13 +360,14 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
moving_var_ptr[i] = moving_var_ptr[i] * param.momentum +
variance * minus_mom;
}
- bwd.SetDataHandle(*data_mem, *diff_mem, out_mean, var_mem, *gradi_mem);
- MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
+ net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData());
+ net_args[MKLDNN_ARG_VARIANCE] = var_mem;
+ MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
MKLDNNStream::Get()->Submit();
} else {
- bwd.SetDataHandle(*data_mem, *diff_mem, moving_mean,
- *moving_var.GetMKLDNNData(), *gradi_mem);
- MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
+ net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData());
+ net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData());
+ MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
MKLDNNStream::Get()->Submit();
}