You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2020/04/08 04:58:17 UTC
[incubator-mxnet] branch master updated: [mkldnn] optimize for mkldnn batchnorm backward (#17902)
This is an automated email from the ASF dual-hosted git repository.
patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 13841dd [mkldnn] optimize for mkldnn batchnorm backward (#17902)
13841dd is described below
commit 13841ddd2d6a53f9f0c22f527a0363d818489bd0
Author: rongzha1 <ro...@intel.com>
AuthorDate: Wed Apr 8 12:57:29 2020 +0800
[mkldnn] optimize for mkldnn batchnorm backward (#17902)
* optimize backward batchnorm
* use memcpy instead of 'for' loops
* remove unnecessary pointer casts and add const to some variables
---
src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 50 ++++++++++++++------------
1 file changed, 27 insertions(+), 23 deletions(-)
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 4de0bb3..d407d94 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -180,9 +180,10 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
CHECK(weight_mem.get_desc().get_size() == channels_ * sizeof(float) * 2);
float* weight_ptr = gamma.data().dptr<float>();
float* bias_ptr = beta.data().dptr<float>();
+ const size_t copy_size = sizeof(weight_buf[0]) * channels_;
if (!param.fix_gamma) {
- memcpy(weight_buf, weight_ptr, sizeof(weight_buf[0]) * channels_);
- memcpy(&weight_buf[channels_], bias_ptr, sizeof(weight_buf[0]) * channels_);
+ memcpy(weight_buf, weight_ptr, copy_size);
+ memcpy(&weight_buf[channels_], bias_ptr, copy_size);
} else if (IsBNWriting(req[batchnorm::kGamma])) {
for (int i = 0; i < channels_; i++) {
weight_buf[i] = 1.0f;
@@ -332,17 +333,18 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const NDArray &beta = in_data[batchnorm::kBeta];
DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
nnvm::dim_t channels_ = data.shape()[1];
- for (int i = 0; i < channels_; i++) {
- if (!param.fix_gamma)
- weight_buf[i] = (gamma.data().dptr<DType>())[i]; // weight
- else
+ DType *weight_ptr = gamma.data().dptr<DType>();
+ DType* bias_ptr = beta.data().dptr<DType>();
+ const size_t copy_size = sizeof(DType) * channels_;
+ if (!param.fix_gamma) {
+ memcpy(weight_buf, weight_ptr, copy_size);
+ memcpy(&weight_buf[channels_], bias_ptr, copy_size);
+ } else {
+ for (int i = 0; i < channels_; i++) {
weight_buf[i] = static_cast<DType>(1.0f);
+ }
+ memcpy(&weight_buf[channels_], bias_ptr, copy_size);
}
-
- for (int i = 0; i < channels_; i++) {
- weight_buf[channels_ + i] = (beta.data().dptr<DType>())[i]; // bias
- }
-
mkldnn_args_map_t net_args;
net_args[MKLDNN_ARG_SRC] = *data_mem;
net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem;
@@ -352,10 +354,10 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
// training but no input mean and variance
if (ctx.is_train && !param.use_global_stats) {
- DType* moving_mean_ptr = reinterpret_cast<DType *>(moving_mean.data().dptr<DType>());
- DType* moving_var_ptr = reinterpret_cast<DType *>(moving_var.data().dptr<DType>());
- DType* out_mean_ptr = reinterpret_cast<DType *>(out_mean.data().dptr<DType>());
- DType* out_var_ptr = reinterpret_cast<DType *>(out_var.data().dptr<DType>());
+ DType* moving_mean_ptr = moving_mean.data().dptr<DType>();
+ DType* moving_var_ptr = moving_var.data().dptr<DType>();
+ DType* out_mean_ptr = out_mean.data().dptr<DType>();
+ DType* out_var_ptr = out_var.data().dptr<DType>();
mkldnn::memory var_mem(bwd.pd.variance_desc(), CpuEngine::Get()->get_engine());
DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());
@@ -381,15 +383,17 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
// copy data from gradw_mem to in_grad[1] and in_grad[2]
DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
- for (int i = 0; i < channels_; i++) {
- if (!param.fix_gamma)
- (in_grad[1].data().dptr<DType>())[i] = gw_buf[i];
- else
- (in_grad[1].data().dptr<DType>())[i] = 0.0f;
- }
+ DType *w_grad_1 = in_grad[1].data().dptr<DType>();
+ DType *w_grad_2 = in_grad[2].data().dptr<DType>();
- for (int i = 0; i < channels_; i++) {
- (in_grad[2].data().dptr<DType>())[i] = gw_buf[i + channels_];
+ if (!param.fix_gamma) {
+ memcpy(w_grad_1, gw_buf, copy_size);
+ memcpy(w_grad_2, &gw_buf[channels_], copy_size);
+ } else {
+ for (int i = 0; i < channels_; i++) {
+ (in_grad[1].data().dptr<DType>())[i] = 0.0f;
+ }
+ memcpy(w_grad_2, &gw_buf[channels_], copy_size);
}
} else {
LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";