You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/05/24 14:03:40 UTC

[incubator-mxnet] branch v1.x updated: OneDNN pooling: this pull-request contains: (#20202)

This is an automated email from the ASF dual-hosted git repository.

akarbown pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new c76987e  OneDNN pooling: this pull-request contains: (#20202)
c76987e is described below

commit c76987eb18e3d43a1817046fe38dcd261648afb9
Author: mozga <ma...@intel.com>
AuthorDate: Mon May 24 16:01:23 2021 +0200

    OneDNN pooling: this pull-request contains: (#20202)
    
    1. Several repeated blocks of lines were replaced by a single helper function.
    2. Other duplicated code was removed.
---
 src/operator/nn/mkldnn/mkldnn_pooling.cc | 139 ++++++++++---------------------
 1 file changed, 43 insertions(+), 96 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index 84ccfe0..a613310 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -106,20 +106,51 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
   switch (param.pool_type) {
     case pool_enum::kMaxPooling:
       return mkldnn::algorithm::pooling_max;
-      break;
     case pool_enum::kAvgPooling:
-      if (param.count_include_pad.has_value() && !param.count_include_pad.value()) {
+      if (param.count_include_pad.has_value() &&
+          !param.count_include_pad.value()) {
         return mkldnn::algorithm::pooling_avg_exclude_padding;
       } else {
         return mkldnn::algorithm::pooling_avg_include_padding;
       }
-      break;
     default:
       LOG(FATAL) << "MKLDNN Pooling: Unknown pooling method.";
       return mkldnn::algorithm::pooling_max;
   }
 }
 
+void PrepareKernels(mkldnn::memory::dims *kernel, mkldnn::memory::dims *strides,
+                    mkldnn::memory::dims *pad_l, mkldnn::memory::dims *pad_r,
+                    const PoolingParam &param,
+                    const mkldnn::memory::desc &data_md, int kernel_ndims) {
+  CHECK_GE(param.pad.ndim(), kernel_ndims);
+  CHECK_GE(param.stride.ndim(), kernel_ndims);
+
+  for (int idx = 0; idx < kernel_ndims; ++idx) {
+    kernel->at(idx) = param.kernel[idx];
+    pad_l->at(idx) = param.pad[idx];
+    pad_r->at(idx) = param.pad[idx];
+    strides->at(idx) = param.stride[idx];
+  }
+  if (param.pooling_convention == pool_enum::kFull) {
+    for (int idx = 0; idx < kernel_ndims; ++idx) {
+      pad_r->at(idx) =
+          GetPaddingSizeFull(data_md.data.dims[idx + 2], pad_l->at(idx),
+                             pad_r->at(idx), kernel->at(idx), strides->at(idx));
+    }
+  }
+  if (param.global_pool) {
+    for (int idx = 0; idx < kernel_ndims; ++idx) {
+      kernel->at(idx) = data_md.data.dims[idx + 2];
+      strides->at(idx) = 1;
+      pad_l->at(idx) = pad_r->at(idx) = 0;
+    }
+  }
+  for (int idx = 0; idx < kernel_ndims; ++idx) {
+    CHECK_GT(kernel->at(idx), 0) << "Filter dimensions cannot be zero.";
+  }
+}
+
 void InitPoolingPrimitiveParams(const PoolingParam &param,
                                 const mkldnn::memory::desc &data_md,
                                 const mkldnn::memory::dims &new_kernel,
@@ -127,106 +158,22 @@ void InitPoolingPrimitiveParams(const PoolingParam &param,
                                 const mkldnn::memory::dims &new_pad_l,
                                 const mkldnn::memory::dims &new_pad_r) {
   const int kernel_ndims = param.kernel.ndim();
-  mkldnn::memory::dims& kernel = const_cast<mkldnn::memory::dims&>(new_kernel);
-  mkldnn::memory::dims& strides = const_cast<mkldnn::memory::dims&>(new_strides);
-  mkldnn::memory::dims& pad_l = const_cast<mkldnn::memory::dims&>(new_pad_l);
-  mkldnn::memory::dims& pad_r = const_cast<mkldnn::memory::dims&>(new_pad_r);
-  if (kernel_ndims == 1) {
-    CHECK_GE(param.pad.ndim(), 1);
-    CHECK_GE(param.stride.ndim(), 1);
-    kernel[0] = param.kernel[0];
-    pad_l[0] = param.pad[0];
-    pad_r[0] = param.pad[0];
-    strides[0] = param.stride[0];
-
-    if (param.pooling_convention == pool_enum::kFull) {
-      pad_r[0] =
-        GetPaddingSizeFull(data_md.data.dims[2], pad_l[0], pad_r[0], kernel[0], strides[0]);
-    }
-
-    if (param.global_pool) {
-      kernel[0] = data_md.data.dims[2];
-      strides[0] = 1;
-      pad_l[0] = pad_r[0] = 0;
-    }
+  mkldnn::memory::dims &kernel = const_cast<mkldnn::memory::dims &>(new_kernel);
+  mkldnn::memory::dims &strides =
+      const_cast<mkldnn::memory::dims &>(new_strides);
+  mkldnn::memory::dims &pad_l = const_cast<mkldnn::memory::dims &>(new_pad_l);
+  mkldnn::memory::dims &pad_r = const_cast<mkldnn::memory::dims &>(new_pad_r);
 
-    CHECK_GT(kernel[0], 0) << "Filter dimensions cannot be zero.";
-  } else if (kernel_ndims == 2) {
-    CHECK_GE(param.pad.ndim(), 2);
-    CHECK_GE(param.stride.ndim(), 2);
-    kernel[0] = param.kernel[0];
-    kernel[1] = param.kernel[1];
-    pad_l[0] = param.pad[0];
-    pad_l[1] = param.pad[1];
-    pad_r[0] = param.pad[0];
-    pad_r[1] = param.pad[1];
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-
-    if (param.pooling_convention == pool_enum::kFull) {
-      pad_r[0] =
-        GetPaddingSizeFull(data_md.data.dims[2], pad_l[0], pad_r[0], kernel[0], strides[0]);
-      pad_r[1] =
-        GetPaddingSizeFull(data_md.data.dims[3], pad_l[1], pad_r[1], kernel[1], strides[1]);
-    }
-
-    if (param.global_pool) {
-      kernel[0] = data_md.data.dims[2];
-      kernel[1] = data_md.data.dims[3];
-      strides[0] = strides[1] = 1;
-      pad_l[0] = pad_l[1] = pad_r[0] = pad_r[1] = 0;
-    }
-
-    CHECK_GT(kernel[0], 0) << "Filter dimensions cannot be zero.";
-    CHECK_GT(kernel[1], 0) << "Filter dimensions cannot be zero.";
-  } else {
-    CHECK_GE(param.pad.ndim(), 3);
-    CHECK_GE(param.stride.ndim(), 3);
-    kernel[0] = param.kernel[0];
-    kernel[1] = param.kernel[1];
-    kernel[2] = param.kernel[2];
-    pad_l[0] = param.pad[0];
-    pad_l[1] = param.pad[1];
-    pad_l[2] = param.pad[2];
-    pad_r[0] = param.pad[0];
-    pad_r[1] = param.pad[1];
-    pad_r[2] = param.pad[2];
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-    strides[2] = param.stride[2];
-
-    if (param.pooling_convention == pool_enum::kFull) {
-      pad_r[0] =
-        GetPaddingSizeFull(data_md.data.dims[2], pad_l[0], pad_r[0], kernel[0], strides[0]);
-      pad_r[1] =
-        GetPaddingSizeFull(data_md.data.dims[3], pad_l[1], pad_r[1], kernel[1], strides[1]);
-      pad_r[2] =
-        GetPaddingSizeFull(data_md.data.dims[4], pad_l[2], pad_r[2], kernel[2], strides[2]);
-    }
-
-    if (param.global_pool) {
-      kernel[0] = data_md.data.dims[2];
-      kernel[1] = data_md.data.dims[3];
-      kernel[2] = data_md.data.dims[4];
-      strides[0] = strides[1] = strides[2] = 1;
-      pad_l[0] = pad_l[1] = pad_l[2] = pad_r[0] = pad_r[1] = pad_r[2] = 0;
-    }
-
-    CHECK_GT(kernel[0], 0) << "Filter dimensions cannot be zero.";
-    CHECK_GT(kernel[1], 0) << "Filter dimensions cannot be zero.";
-    CHECK_GT(kernel[2], 0) << "Filter dimensions cannot be zero.";
-  }
+  PrepareKernels(&kernel, &strides, &pad_l, &pad_r, param, data_md, kernel_ndims);
 
   if (pad_l[0] != 0 || (kernel_ndims == 2 && pad_l[1] != 0) ||
-     (kernel_ndims == 3 && pad_l[2] != 0)) {
+      (kernel_ndims == 3 && pad_l[2] != 0)) {
     CHECK(param.pool_type == pool_enum::kAvgPooling ||
           param.pool_type == pool_enum::kMaxPooling)
         << "Padding implemented only for average and max pooling.";
     CHECK_LT(pad_l[0], kernel[0]);
-    if (kernel_ndims > 1)
-      CHECK_LT(pad_l[1], kernel[1]);
-    if (kernel_ndims > 2)
-      CHECK_LT(pad_l[2], kernel[2]);
+    if (kernel_ndims > 1) CHECK_LT(pad_l[1], kernel[1]);
+    if (kernel_ndims > 2) CHECK_LT(pad_l[2], kernel[2]);
   }
 }