Posted to commits@mxnet.apache.org by pa...@apache.org on 2020/04/08 04:58:17 UTC

[incubator-mxnet] branch master updated: [mkldnn] optimize for mkldnn batchnorm backward (#17902)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 13841dd  [mkldnn] optimize for mkldnn batchnorm backward (#17902)
13841dd is described below

commit 13841ddd2d6a53f9f0c22f527a0363d818489bd0
Author: rongzha1 <ro...@intel.com>
AuthorDate: Wed Apr 8 12:57:29 2020 +0800

    [mkldnn] optimize for mkldnn batchnorm backward (#17902)
    
    * optimize for backward batchnorm
    
    * using memcpy instead of 'for' loop
    
    * rm unnecessary pointer cast and add const for some variable
---
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 50 ++++++++++++++------------
 1 file changed, 27 insertions(+), 23 deletions(-)
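
The change replaces the element-wise loops that fill the fused MKL-DNN scale/shift buffer
with memcpy calls. Below is a minimal, self-contained C++ sketch of that copy pattern;
the function name FillScaleShift and the bare float pointers are illustrative only, and
the MXNet NDArray / MKL-DNN plumbing shown in the diff is omitted. MKL-DNN's batch norm
expects a single weights buffer of 2 * channels entries laid out as [gamma..., beta...].

    #include <cstddef>
    #include <cstring>

    // Illustrative sketch of the copy pattern used in this patch (not the
    // actual MXNet function): fill the fused [gamma..., beta...] buffer.
    void FillScaleShift(float* weight_buf,       // 2 * channels entries
                        const float* gamma_ptr,  // learned scale
                        const float* beta_ptr,   // learned shift
                        size_t channels,
                        bool fix_gamma) {
      const size_t copy_size = sizeof(float) * channels;
      if (!fix_gamma) {
        std::memcpy(weight_buf, gamma_ptr, copy_size);            // scale block
        std::memcpy(weight_buf + channels, beta_ptr, copy_size);  // shift block
      } else {
        // fix_gamma pins the scale to 1; only the shift is copied.
        for (size_t i = 0; i < channels; ++i) weight_buf[i] = 1.0f;
        std::memcpy(weight_buf + channels, beta_ptr, copy_size);
      }
    }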

diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 4de0bb3..d407d94 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -180,9 +180,10 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
     CHECK(weight_mem.get_desc().get_size() == channels_ * sizeof(float) * 2);
     float* weight_ptr = gamma.data().dptr<float>();
     float* bias_ptr = beta.data().dptr<float>();
+    const size_t copy_size = sizeof(weight_buf[0]) * channels_;
     if (!param.fix_gamma) {
-      memcpy(weight_buf, weight_ptr, sizeof(weight_buf[0]) * channels_);
-      memcpy(&weight_buf[channels_], bias_ptr, sizeof(weight_buf[0]) * channels_);
+      memcpy(weight_buf, weight_ptr, copy_size);
+      memcpy(&weight_buf[channels_], bias_ptr, copy_size);
     } else if (IsBNWriting(req[batchnorm::kGamma])) {
       for (int i = 0; i < channels_; i++) {
         weight_buf[i] = 1.0f;
@@ -332,17 +333,18 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
     const NDArray &beta     = in_data[batchnorm::kBeta];
     DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
     nnvm::dim_t channels_ = data.shape()[1];
-    for (int i = 0; i < channels_; i++) {
-      if (!param.fix_gamma)
-        weight_buf[i] = (gamma.data().dptr<DType>())[i];   // weight
-      else
+    DType *weight_ptr = gamma.data().dptr<DType>();
+    DType* bias_ptr = beta.data().dptr<DType>();
+    const size_t copy_size = sizeof(DType) * channels_;
+    if (!param.fix_gamma) {
+      memcpy(weight_buf, weight_ptr, copy_size);
+      memcpy(&weight_buf[channels_], bias_ptr, copy_size);
+    } else {
+      for (int i = 0; i < channels_; i++) {
         weight_buf[i] = static_cast<DType>(1.0f);
+      }
+      memcpy(&weight_buf[channels_], bias_ptr, copy_size);
     }
-
-    for (int i = 0; i < channels_; i++) {
-      weight_buf[channels_ + i] = (beta.data().dptr<DType>())[i];  // bias
-    }
-
     mkldnn_args_map_t net_args;
     net_args[MKLDNN_ARG_SRC] = *data_mem;
     net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem;
@@ -352,10 +354,10 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
 
     // training but no input mean and variance
     if (ctx.is_train && !param.use_global_stats) {
-      DType* moving_mean_ptr  = reinterpret_cast<DType *>(moving_mean.data().dptr<DType>());
-      DType* moving_var_ptr   = reinterpret_cast<DType *>(moving_var.data().dptr<DType>());
-      DType* out_mean_ptr     = reinterpret_cast<DType *>(out_mean.data().dptr<DType>());
-      DType* out_var_ptr      = reinterpret_cast<DType *>(out_var.data().dptr<DType>());
+      DType* moving_mean_ptr  = moving_mean.data().dptr<DType>();
+      DType* moving_var_ptr   = moving_var.data().dptr<DType>();
+      DType* out_mean_ptr     = out_mean.data().dptr<DType>();
+      DType* out_var_ptr      = out_var.data().dptr<DType>();
       mkldnn::memory var_mem(bwd.pd.variance_desc(), CpuEngine::Get()->get_engine());
       DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());
 
@@ -381,15 +383,17 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
 
     // copy data from gradw_mem to in_grad[1] and in_grad[2]
     DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
-    for (int i = 0; i < channels_; i++) {
-      if (!param.fix_gamma)
-        (in_grad[1].data().dptr<DType>())[i] = gw_buf[i];
-      else
-        (in_grad[1].data().dptr<DType>())[i] = 0.0f;
-    }
+    DType *w_grad_1 = in_grad[1].data().dptr<DType>();
+    DType *w_grad_2 = in_grad[2].data().dptr<DType>();
 
-    for (int i = 0; i < channels_; i++) {
-      (in_grad[2].data().dptr<DType>())[i] = gw_buf[i + channels_];
+    if (!param.fix_gamma) {
+      memcpy(w_grad_1, gw_buf, copy_size);
+      memcpy(w_grad_2, &gw_buf[channels_], copy_size);
+    } else {
+      for (int i = 0; i < channels_; i++) {
+        (in_grad[1].data().dptr<DType>())[i] = 0.0f;
+      }
+      memcpy(w_grad_2, &gw_buf[channels_], copy_size);
     }
   } else {
     LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";
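
For reference, the final hunk applies the same idea when copying the fused gradient
buffer back into the per-parameter gradients: gw_buf holds [dgamma..., dbeta...], and
the per-element loops become memcpy calls, with the gamma gradient zeroed when
fix_gamma is set. A hedged sketch, assuming plain float buffers and illustrative names
(CopyWeightGrads, gamma_grad, beta_grad):

    #include <cstddef>
    #include <cstring>

    // Illustrative sketch of the gradient copy-out in the last hunk.
    void CopyWeightGrads(const float* gw_buf,  // fused [dgamma..., dbeta...]
                         float* gamma_grad,    // in_grad[1] in the diff
                         float* beta_grad,     // in_grad[2] in the diff
                         size_t channels,
                         bool fix_gamma) {
      const size_t copy_size = sizeof(float) * channels;
      if (!fix_gamma) {
        std::memcpy(gamma_grad, gw_buf, copy_size);
      } else {
        // With a fixed gamma there is nothing to learn for the scale,
        // so its gradient is written as zeros, matching the diff.
        for (size_t i = 0; i < channels; ++i) gamma_grad[i] = 0.0f;
      }
      std::memcpy(beta_grad, gw_buf + channels, copy_size);
    }

The remaining hunk only drops pointer casts that the commit message calls unnecessary,
since dptr<DType>() already yields a DType* and the reinterpret_cast added nothing.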