You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2021/04/23 06:19:03 UTC

[GitHub] [incubator-mxnet] bgawrych commented on a change in pull request #20202: [v1.x] OneDNN pooling operator, additional function to prepare kernels

bgawrych commented on a change in pull request #20202:
URL: https://github.com/apache/incubator-mxnet/pull/20202#discussion_r618962643



##########
File path: src/operator/nn/mkldnn/mkldnn_pooling.cc
##########
@@ -106,127 +106,72 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
   switch (param.pool_type) {
     case pool_enum::kMaxPooling:
       return mkldnn::algorithm::pooling_max;
-      break;
     case pool_enum::kAvgPooling:
-      if (param.count_include_pad.has_value() && !param.count_include_pad.value()) {
+      if (param.count_include_pad.has_value() &&
+          !param.count_include_pad.value()) {
         return mkldnn::algorithm::pooling_avg_exclude_padding;
       } else {
         return mkldnn::algorithm::pooling_avg_include_padding;
       }
-      break;
     default:
       LOG(FATAL) << "MKLDNN Pooling: Unknown pooling method.";
       return mkldnn::algorithm::pooling_max;
   }
 }
 
+void prepareKernels(mkldnn::memory::dims &kernel, mkldnn::memory::dims &strides,
+                    mkldnn::memory::dims &pad_l, mkldnn::memory::dims &pad_r,
+                    const PoolingParam &param, const mkldnn::memory::desc &data_md, int kernel_ndims) {
+  CHECK_GE(param.pad.ndim(), kernel_ndims);
+  CHECK_GE(param.stride.ndim(), kernel_ndims);
+
+  for (int idx = 0; idx < kernel_ndims; ++idx) {
+    kernel[idx] = param.kernel[idx];
+    pad_l[idx] = param.pad[idx];
+    pad_r[idx] = param.pad[idx];
+    strides[idx] = param.stride[idx];
+  }
+  if (param.pooling_convention == pool_enum::kFull) {
+    for (int idx = 0; idx < kernel_ndims; ++idx) {
+      pad_r[idx] = GetPaddingSizeFull(data_md.data.dims[idx + 2], pad_l[idx],
+                                      pad_r[idx], kernel[idx], strides[idx]);
+    }
+  }
+  if (param.global_pool) {
+    for (int idx = 0; idx < kernel_ndims; ++idx) {
+      kernel[idx] = data_md.data.dims[idx + 2];
+      strides[idx] = 1;
+      pad_l[idx] = pad_r[idx] = 0;
+    }
+  }
+  for (int idx = 0; idx < kernel_ndims; ++idx) {
+    CHECK_GT(kernel[idx], 0) << "Filter dimensions cannot be zero.";
+  }
+}
+
 void InitPoolingPrimitiveParams(const PoolingParam &param,
                                 const mkldnn::memory::desc &data_md,
                                 const mkldnn::memory::dims &new_kernel,
                                 const mkldnn::memory::dims &new_strides,
                                 const mkldnn::memory::dims &new_pad_l,
                                 const mkldnn::memory::dims &new_pad_r) {
   const int kernel_ndims = param.kernel.ndim();
-  mkldnn::memory::dims& kernel = const_cast<mkldnn::memory::dims&>(new_kernel);
-  mkldnn::memory::dims& strides = const_cast<mkldnn::memory::dims&>(new_strides);
-  mkldnn::memory::dims& pad_l = const_cast<mkldnn::memory::dims&>(new_pad_l);
-  mkldnn::memory::dims& pad_r = const_cast<mkldnn::memory::dims&>(new_pad_r);
-  if (kernel_ndims == 1) {
-    CHECK_GE(param.pad.ndim(), 1);
-    CHECK_GE(param.stride.ndim(), 1);
-    kernel[0] = param.kernel[0];
-    pad_l[0] = param.pad[0];
-    pad_r[0] = param.pad[0];
-    strides[0] = param.stride[0];
-
-    if (param.pooling_convention == pool_enum::kFull) {
-      pad_r[0] =
-        GetPaddingSizeFull(data_md.data.dims[2], pad_l[0], pad_r[0], kernel[0], strides[0]);
-    }
-
-    if (param.global_pool) {
-      kernel[0] = data_md.data.dims[2];
-      strides[0] = 1;
-      pad_l[0] = pad_r[0] = 0;
-    }
+  mkldnn::memory::dims &kernel = const_cast<mkldnn::memory::dims &>(new_kernel);
+  mkldnn::memory::dims &strides =
+      const_cast<mkldnn::memory::dims &>(new_strides);
+  mkldnn::memory::dims &pad_l = const_cast<mkldnn::memory::dims &>(new_pad_l);
+  mkldnn::memory::dims &pad_r = const_cast<mkldnn::memory::dims &>(new_pad_r);
 
-    CHECK_GT(kernel[0], 0) << "Filter dimensions cannot be zero.";
-  } else if (kernel_ndims == 2) {
-    CHECK_GE(param.pad.ndim(), 2);
-    CHECK_GE(param.stride.ndim(), 2);
-    kernel[0] = param.kernel[0];
-    kernel[1] = param.kernel[1];
-    pad_l[0] = param.pad[0];
-    pad_l[1] = param.pad[1];
-    pad_r[0] = param.pad[0];
-    pad_r[1] = param.pad[1];
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-
-    if (param.pooling_convention == pool_enum::kFull) {
-      pad_r[0] =
-        GetPaddingSizeFull(data_md.data.dims[2], pad_l[0], pad_r[0], kernel[0], strides[0]);
-      pad_r[1] =
-        GetPaddingSizeFull(data_md.data.dims[3], pad_l[1], pad_r[1], kernel[1], strides[1]);
-    }
-
-    if (param.global_pool) {
-      kernel[0] = data_md.data.dims[2];
-      kernel[1] = data_md.data.dims[3];
-      strides[0] = strides[1] = 1;
-      pad_l[0] = pad_l[1] = pad_r[0] = pad_r[1] = 0;
-    }
-
-    CHECK_GT(kernel[0], 0) << "Filter dimensions cannot be zero.";
-    CHECK_GT(kernel[1], 0) << "Filter dimensions cannot be zero.";
-  } else {
-    CHECK_GE(param.pad.ndim(), 3);
-    CHECK_GE(param.stride.ndim(), 3);
-    kernel[0] = param.kernel[0];
-    kernel[1] = param.kernel[1];
-    kernel[2] = param.kernel[2];
-    pad_l[0] = param.pad[0];
-    pad_l[1] = param.pad[1];
-    pad_l[2] = param.pad[2];
-    pad_r[0] = param.pad[0];
-    pad_r[1] = param.pad[1];
-    pad_r[2] = param.pad[2];
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-    strides[2] = param.stride[2];
-
-    if (param.pooling_convention == pool_enum::kFull) {
-      pad_r[0] =
-        GetPaddingSizeFull(data_md.data.dims[2], pad_l[0], pad_r[0], kernel[0], strides[0]);
-      pad_r[1] =
-        GetPaddingSizeFull(data_md.data.dims[3], pad_l[1], pad_r[1], kernel[1], strides[1]);
-      pad_r[2] =
-        GetPaddingSizeFull(data_md.data.dims[4], pad_l[2], pad_r[2], kernel[2], strides[2]);
-    }
-
-    if (param.global_pool) {
-      kernel[0] = data_md.data.dims[2];
-      kernel[1] = data_md.data.dims[3];
-      kernel[2] = data_md.data.dims[4];
-      strides[0] = strides[1] = strides[2] = 1;
-      pad_l[0] = pad_l[1] = pad_l[2] = pad_r[0] = pad_r[1] = pad_r[2] = 0;
-    }
-
-    CHECK_GT(kernel[0], 0) << "Filter dimensions cannot be zero.";
-    CHECK_GT(kernel[1], 0) << "Filter dimensions cannot be zero.";
-    CHECK_GT(kernel[2], 0) << "Filter dimensions cannot be zero.";
-  }
+  prepareKernels(kernel, strides, pad_l, pad_r, param, data_md, kernel_ndims);

Review comment:
       Upper case 'P'? Every other function starts with an upper-case letter.

##########
File path: src/operator/nn/mkldnn/mkldnn_pooling.cc
##########
@@ -106,127 +106,72 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
   switch (param.pool_type) {
     case pool_enum::kMaxPooling:
       return mkldnn::algorithm::pooling_max;
-      break;
     case pool_enum::kAvgPooling:
-      if (param.count_include_pad.has_value() && !param.count_include_pad.value()) {
+      if (param.count_include_pad.has_value() &&
+          !param.count_include_pad.value()) {
         return mkldnn::algorithm::pooling_avg_exclude_padding;
       } else {
         return mkldnn::algorithm::pooling_avg_include_padding;
       }
-      break;
     default:
       LOG(FATAL) << "MKLDNN Pooling: Unknown pooling method.";
       return mkldnn::algorithm::pooling_max;
   }
 }
 
+void prepareKernels(mkldnn::memory::dims &kernel, mkldnn::memory::dims &strides,
+                    mkldnn::memory::dims &pad_l, mkldnn::memory::dims &pad_r,
+                    const PoolingParam &param, const mkldnn::memory::desc &data_md, int kernel_ndims) {
+  CHECK_GE(param.pad.ndim(), kernel_ndims);
+  CHECK_GE(param.stride.ndim(), kernel_ndims);
+
+  for (int idx = 0; idx < kernel_ndims; ++idx) {
+    kernel[idx] = param.kernel[idx];
+    pad_l[idx] = param.pad[idx];
+    pad_r[idx] = param.pad[idx];
+    strides[idx] = param.stride[idx];
+  }
+  if (param.pooling_convention == pool_enum::kFull) {
+    for (int idx = 0; idx < kernel_ndims; ++idx) {
+      pad_r[idx] = GetPaddingSizeFull(data_md.data.dims[idx + 2], pad_l[idx],
+                                      pad_r[idx], kernel[idx], strides[idx]);
+    }
+  }
+  if (param.global_pool) {
+    for (int idx = 0; idx < kernel_ndims; ++idx) {
+      kernel[idx] = data_md.data.dims[idx + 2];
+      strides[idx] = 1;
+      pad_l[idx] = pad_r[idx] = 0;
+    }
+  }
+  for (int idx = 0; idx < kernel_ndims; ++idx) {
+    CHECK_GT(kernel[idx], 0) << "Filter dimensions cannot be zero.";
+  }
+}
+
 void InitPoolingPrimitiveParams(const PoolingParam &param,
                                 const mkldnn::memory::desc &data_md,
                                 const mkldnn::memory::dims &new_kernel,
                                 const mkldnn::memory::dims &new_strides,
                                 const mkldnn::memory::dims &new_pad_l,
                                 const mkldnn::memory::dims &new_pad_r) {
   const int kernel_ndims = param.kernel.ndim();
-  mkldnn::memory::dims& kernel = const_cast<mkldnn::memory::dims&>(new_kernel);
-  mkldnn::memory::dims& strides = const_cast<mkldnn::memory::dims&>(new_strides);
-  mkldnn::memory::dims& pad_l = const_cast<mkldnn::memory::dims&>(new_pad_l);
-  mkldnn::memory::dims& pad_r = const_cast<mkldnn::memory::dims&>(new_pad_r);
-  if (kernel_ndims == 1) {
-    CHECK_GE(param.pad.ndim(), 1);
-    CHECK_GE(param.stride.ndim(), 1);
-    kernel[0] = param.kernel[0];
-    pad_l[0] = param.pad[0];
-    pad_r[0] = param.pad[0];
-    strides[0] = param.stride[0];
-
-    if (param.pooling_convention == pool_enum::kFull) {
-      pad_r[0] =
-        GetPaddingSizeFull(data_md.data.dims[2], pad_l[0], pad_r[0], kernel[0], strides[0]);
-    }
-
-    if (param.global_pool) {
-      kernel[0] = data_md.data.dims[2];
-      strides[0] = 1;
-      pad_l[0] = pad_r[0] = 0;
-    }
+  mkldnn::memory::dims &kernel = const_cast<mkldnn::memory::dims &>(new_kernel);

Review comment:
       Do you need these const_casts?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org