Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/18 18:11:01 UTC

[GitHub] eric-haibin-lin closed pull request #11021: [MXNET-380] count_include_pad argument for Avg Pooling

URL: https://github.com/apache/incubator-mxnet/pull/11021
 
 
   

This PR was merged from a forked repository. Because GitHub hides the original
diff after a merge, it is reproduced below for the sake of provenance.
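
For context, the new count_include_pad flag only changes the divisor used by
average pooling: when it is left at its default (true, the previous behaviour)
every window is divided by the full kernel size, padding included; when it is
set to false, the divisor is the number of elements that actually fall inside
the input. A minimal NumPy sketch of the forward pass, written purely for
illustration and not taken from the PR:

    import numpy as np

    def avg_pool2d(x, kernel, stride, pad, count_include_pad=True):
        """Minimal 2-D average pooling over one channel (illustration only)."""
        kh, kw = kernel
        sh, sw = stride
        ph, pw = pad
        h, w = x.shape
        out_h = (h + 2 * ph - kh) // sh + 1
        out_w = (w + 2 * pw - kw) // sw + 1
        out = np.zeros((out_h, out_w))
        for i in range(out_h):
            for j in range(out_w):
                hs, ws = i * sh - ph, j * sw - pw
                he, we = min(hs + kh, h + ph), min(ws + kw, w + pw)
                size = (he - hs) * (we - ws)      # divisor with padding counted
                hs, ws = max(hs, 0), max(ws, 0)
                he, we = min(he, h), min(we, w)
                if not count_include_pad:
                    size = (he - hs) * (we - ws)  # divisor over valid elements only
                out[i, j] = x[hs:he, ws:we].sum() / size
        return out

    x = np.arange(1.0, 10.0).reshape(3, 3)
    print(avg_pool2d(x, (2, 2), (1, 1), (1, 1), True)[0, 0])   # 0.25 (1 / 4)
    print(avg_pool2d(x, (2, 2), (1, 1), (1, 1), False)[0, 0])  # 1.0  (1 / 1)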

diff --git a/cpp-package/scripts/OpWrapperGenerator.py b/cpp-package/scripts/OpWrapperGenerator.py
index 0c000d9955f..8facde16840 100644
--- a/cpp-package/scripts/OpWrapperGenerator.py
+++ b/cpp-package/scripts/OpWrapperGenerator.py
@@ -77,6 +77,7 @@ def GetConvertEnumVariableToString(self, variable=''):
 
 class Arg:
     typeDict = {'boolean':'bool',\
+        'boolean or None':'dmlc::optional<bool>',\
         'Shape(tuple)':'Shape',\
         'Symbol':'Symbol',\
         'NDArray':'Symbol',\
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 2fbf7d8786d..24f30270ad6 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -675,7 +675,7 @@ def __init__(self, channels, kernel_size, strides=(1, 1, 1), padding=(0, 0, 0),
 class _Pooling(HybridBlock):
     """Abstract class for different pooling layers."""
     def __init__(self, pool_size, strides, padding, ceil_mode, global_pool,
-                 pool_type, **kwargs):
+                 pool_type, count_include_pad=None, **kwargs):
         super(_Pooling, self).__init__(**kwargs)
         if strides is None:
             strides = pool_size
@@ -687,6 +687,8 @@ def __init__(self, pool_size, strides, padding, ceil_mode, global_pool,
             'kernel': pool_size, 'stride': strides, 'pad': padding,
             'global_pool': global_pool, 'pool_type': pool_type,
             'pooling_convention': 'full' if ceil_mode else 'valid'}
+        if count_include_pad is not None:
+            self._kwargs['count_include_pad'] = count_include_pad
 
     def _alias(self):
         return 'pool'
@@ -863,6 +865,8 @@ class AvgPool1D(_Pooling):
         respectively. padding is applied on 'W' dimension.
     ceil_mode : bool, default False
         When `True`, will use ceil instead of floor to compute the output shape.
+    count_include_pad : bool, default True
+        When 'False', will exclude padding elements when computing the average value.
 
 
     Inputs:
@@ -879,13 +883,13 @@ class AvgPool1D(_Pooling):
           equation.
     """
     def __init__(self, pool_size=2, strides=None, padding=0, layout='NCW',
-                 ceil_mode=False, **kwargs):
+                 ceil_mode=False, count_include_pad=True, **kwargs):
         assert layout == 'NCW', "Only supports 'NCW' layout for now"
         if isinstance(pool_size, numeric_types):
             pool_size = (pool_size,)
         assert len(pool_size) == 1, "pool_size must be a number or a list of 1 ints"
         super(AvgPool1D, self).__init__(
-            pool_size, strides, padding, ceil_mode, False, 'avg', **kwargs)
+            pool_size, strides, padding, ceil_mode, False, 'avg', count_include_pad, **kwargs)
 
 
 class AvgPool2D(_Pooling):
@@ -907,6 +911,8 @@ class AvgPool2D(_Pooling):
         dimensions respectively. padding is applied on 'H' and 'W' dimension.
     ceil_mode : bool, default False
         When True, will use ceil instead of floor to compute the output shape.
+    count_include_pad : bool, default True
+        When 'False', will exclude padding elements when computing the average value.
 
 
     Inputs:
@@ -926,13 +932,13 @@ class AvgPool2D(_Pooling):
           equation.
     """
     def __init__(self, pool_size=(2, 2), strides=None, padding=0,
-                 ceil_mode=False, layout='NCHW', **kwargs):
+                 ceil_mode=False, layout='NCHW', count_include_pad=True, **kwargs):
         assert layout == 'NCHW', "Only supports 'NCHW' layout for now"
         if isinstance(pool_size, numeric_types):
             pool_size = (pool_size,)*2
         assert len(pool_size) == 2, "pool_size must be a number or a list of 2 ints"
         super(AvgPool2D, self).__init__(
-            pool_size, strides, padding, ceil_mode, False, 'avg', **kwargs)
+            pool_size, strides, padding, ceil_mode, False, 'avg', count_include_pad, **kwargs)
 
 
 class AvgPool3D(_Pooling):
@@ -955,6 +961,8 @@ class AvgPool3D(_Pooling):
         dimension.
     ceil_mode : bool, default False
         When True, will use ceil instead of floor to compute the output shape.
+    count_include_pad : bool, default True
+        When 'False', will exclude padding elements when computing the average value.
 
 
     Inputs:
@@ -975,13 +983,13 @@ class AvgPool3D(_Pooling):
           equation.
     """
     def __init__(self, pool_size=(2, 2, 2), strides=None, padding=0,
-                 ceil_mode=False, layout='NCDHW', **kwargs):
+                 ceil_mode=False, layout='NCDHW', count_include_pad=True, **kwargs):
         assert layout == 'NCDHW', "Only supports 'NCDHW' layout for now"
         if isinstance(pool_size, numeric_types):
             pool_size = (pool_size,)*3
         assert len(pool_size) == 3, "pool_size must be a number or a list of 3 ints"
         super(AvgPool3D, self).__init__(
-            pool_size, strides, padding, ceil_mode, False, 'avg', **kwargs)
+            pool_size, strides, padding, ceil_mode, False, 'avg', count_include_pad, **kwargs)
 
 
 class GlobalMaxPool1D(_Pooling):
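
The Gluon change above simply threads the new keyword through _Pooling into the
underlying Pooling operator. A short usage sketch (illustrative only, assuming
an MXNet build that contains this PR):

    from mxnet import nd
    from mxnet.gluon import nn

    # Pooling layers have no parameters, so they can be called directly.
    pool = nn.AvgPool2D(pool_size=(2, 2), strides=(1, 1), padding=(1, 1),
                        count_include_pad=False)
    x = nd.ones((1, 1, 3, 3))
    y = pool(x)
    print(y.asnumpy()[0, 0, 0, 0])  # 1.0; with count_include_pad=True it would be 0.25
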
diff --git a/src/operator/nn/cudnn/cudnn_pooling-inl.h b/src/operator/nn/cudnn/cudnn_pooling-inl.h
index 84cf6403043..bc3ee366007 100644
--- a/src/operator/nn/cudnn/cudnn_pooling-inl.h
+++ b/src/operator/nn/cudnn/cudnn_pooling-inl.h
@@ -51,7 +51,11 @@ class CuDNNPoolingOp {
         mode_ = CUDNN_POOLING_MAX;
         break;
       case pool_enum::kAvgPooling:
-        mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+        if (param_.count_include_pad.has_value() && !param_.count_include_pad.value()) {
+          mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
+        } else {
+          mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+        }
         break;
       default:
         LOG(FATAL) << "Not implmented";
@@ -263,7 +267,7 @@ class CuDNNPoolingOp {
                                              &(pad_vec[0]),
                                              &(stride_vec[0])));
       #else
-      LOG(FATAL) << "3D pooling only support CUDNN v5 and abouve";
+      LOG(FATAL) << "3D pooling only support CUDNN v5 and above";
       #endif
     }
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index 259af2b9402..9fd88a13c46 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -121,7 +121,11 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
       return mkldnn::algorithm::pooling_max;
       break;
     case pool_enum::kAvgPooling:
-      return mkldnn::algorithm::pooling_avg_include_padding;
+      if (param.count_include_pad.has_value() && !param.count_include_pad.value()) {
+        return mkldnn::algorithm::pooling_avg_exclude_padding;
+      } else {
+        return mkldnn::algorithm::pooling_avg_include_padding;
+      }
       break;
     default:
       LOG(FATAL) << "MKLDNN Pooling: Unknown pooling method.";
diff --git a/src/operator/nn/pool.cuh b/src/operator/nn/pool.cuh
index 9d004d295be..976aacf63a5 100644
--- a/src/operator/nn/pool.cuh
+++ b/src/operator/nn/pool.cuh
@@ -214,16 +214,19 @@ template <typename DType, int p = 1>
 __global__ void pool_sum_1d_gpu_kernel(const int nthreads, const DType* in_data, const int channels,
                                        const int width, const int pooled_width, const int kernel_w,
                                        const int stride_w, const int pad_w, DType* out_data,
-                                       const bool getAvg = false) {
+                                       const bool get_avg = false, const bool count_include_pad = true) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     const int pw = index % pooled_width;
     const int c = (index / pooled_width) % channels;
     const int n = index / pooled_width / channels;
     int wstart = pw * stride_w - pad_w;
     int wend = min(wstart + kernel_w, width + pad_w);
-    const int pool_size = (getAvg? (wend - wstart) : 1);
+    int pool_size = (get_avg? (wend - wstart) : 1);
     wstart = max(wstart, 0);
     wend = min(wend, width);
+    if (get_avg && !count_include_pad) {
+      pool_size = (wend - wstart);
+    }
     DType sum = 0;
     const DType* out_slice = in_data + (n * channels + c) * width;
     for (int w = wstart; w < wend; ++w) {
@@ -244,7 +247,8 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data,
                                        const int kernel_h, const int kernel_w,
                                        const int stride_h, const int stride_w,
                                        const int pad_h, const int pad_w, DType* out_data,
-                                       const bool getAvg = false) {
+                                       const bool get_avg = false,
+                                       const bool count_include_pad = true) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     const int pw = index % pooled_width;
     const int ph = (index / pooled_width) % pooled_height;
@@ -254,11 +258,14 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data,
     int wstart = pw * stride_w - pad_w;
     int hend = min(hstart + kernel_h, height + pad_h);
     int wend = min(wstart + kernel_w, width + pad_w);
-    const int pool_size = (getAvg? (hend - hstart) * (wend - wstart) : 1);
+    int pool_size = (get_avg? (hend - hstart) * (wend - wstart) : 1);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
     hend = min(hend, height);
     wend = min(wend, width);
+    if (get_avg && !count_include_pad) {
+      pool_size = (hend - hstart) * (wend - wstart);
+    }
     DType sum = 0;
     const DType* out_slice = in_data + (n * channels + c) * height * width;
     for (int h = hstart; h < hend; ++h) {
@@ -282,7 +289,8 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data,
                                        const int kernel_h, const int kernel_w,
                                        const int stride_d, const int stride_h, const int stride_w,
                                        const int pad_d, const int pad_h, const int pad_w,
-                                       DType* out_data, const bool getAvg = false) {
+                                       DType* out_data, const bool get_avg = false,
+                                       const bool count_include_pad = true) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     const int pw = index % pooled_width;
     const int ph = (index / pooled_width) % pooled_height;
@@ -295,13 +303,16 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data,
     int dend = min(dstart + kernel_d, depth + pad_d);
     int hend = min(hstart + kernel_h, height + pad_h);
     int wend = min(wstart + kernel_w, width + pad_w);
-    const int pool_size = (getAvg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
+    int pool_size = (get_avg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
     dstart = max(dstart, 0);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
     dend = min(dend, depth);
     hend = min(hend, height);
     wend = min(wend, width);
+    if (get_avg && !count_include_pad) {
+      pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+    }
     DType sum = 0;
     const DType* out_slice = in_data + (n * channels + c) * depth * height * width;
     for (int d = dstart; d < dend; ++d) {
@@ -311,7 +322,9 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data,
         }
       }
     }
-    out_data[index] = a_root_p<DType, p>::Map(sum);
+    out_data[index] = (pool_size == 0) ?
+                      DType(nanf("")) :
+                      a_root_p<DType, p>::Map(sum);
   }
 }
 
@@ -487,7 +500,8 @@ __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_gr
                                          const int channels, const int width,
                                          const int pooled_width, const int kernel_w,
                                          const int stride_w, const int pad_w, DType* in_grad,
-                                         const bool isAvg = false) {
+                                         const bool is_avg = false,
+                                         const bool count_include_pad = true) {
   // index is the input image index in NCW
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
@@ -506,7 +520,12 @@ __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_gr
       // figure out the pooling size
       int wstart = pw * stride_w - pad_w;
       int wend = min(wstart + kernel_w, width + pad_w);
-      int pool_size = (isAvg? (wend - wstart) : 1);
+      int pool_size = (is_avg? (wend - wstart) : 1);
+      if (is_avg && !count_include_pad) {
+        wstart = max(wstart, 0);
+        wend = min(wend, width);
+        pool_size = (wend - wstart);
+      }
       gradient +=
         lp_grad<DType, p>::Map(out_grad_slice[pw], in_data[index], out_data_slice[pw]) / pool_size;
     }
@@ -528,7 +547,8 @@ __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_gr
                                          const int kernel_h, const int kernel_w,
                                          const int stride_h, const int stride_w,
                                          const int pad_h, const int pad_w, DType* in_grad,
-                                         const bool isAvg = false) {
+                                         const bool is_avg = false,
+                                         const bool count_include_pad = true) {
   // index is the input image index in NCHW
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
@@ -553,8 +573,15 @@ __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_gr
         int wstart = pw * stride_w - pad_w;
         int hend = min(hstart + kernel_h, height + pad_h);
         int wend = min(wstart + kernel_w, width + pad_w);
-        int pool_size = (isAvg? (hend - hstart) * (wend - wstart) : 1);
+        int pool_size = (is_avg? (hend - hstart) * (wend - wstart) : 1);
         int out_index = ph * pooled_width + pw;
+        if (is_avg && !count_include_pad) {
+          hstart = max(hstart, 0);
+          wstart = max(wstart, 0);
+          hend = min(hend, height);
+          wend = min(wend, width);
+          pool_size = (hend - hstart) * (wend - wstart);
+        }
         gradient +=
           lp_grad<DType, p>::Map(out_grad_slice[out_index],
                                  in_data[index],
@@ -580,7 +607,8 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr
                                          const int kernel_d, const int kernel_h,
                                          const int kernel_w, const int stride_d, const int stride_h,
                                          const int stride_w, const int pad_d, const int pad_h,
-                                         const int pad_w, DType* in_grad, const bool isAvg = false) {
+                                         const int pad_w, DType* in_grad, const bool is_avg = false,
+                                         const bool count_include_pad = true) {
   // index is the input image index in NCDHW
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
@@ -611,8 +639,17 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr
           int dend = min(dstart + kernel_d, depth + pad_d);
           int hend = min(hstart + kernel_h, height + pad_h);
           int wend = min(wstart + kernel_w, width + pad_w);
-          int pool_size = (isAvg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
+          int pool_size = (is_avg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
           int out_index = (pd * pooled_height + ph) * pooled_width + pw;
+          if (is_avg && !count_include_pad) {
+            dstart = max(dstart, 0);
+            hstart = max(hstart, 0);
+            wstart = max(wstart, 0);
+            dend = min(dend, depth);
+            hend = min(hend, height);
+            wend = min(wend, width);
+            pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+          }
           gradient += lp_grad<DType, p>::Map(out_grad_slice[out_index],
                                              in_data[index],
                                              out_data_slice[out_index]) / pool_size;
@@ -643,7 +680,7 @@ template<typename DType, int p>
 inline void pool(mshadow::Stream<gpu>* s, const DType* in_data, const TShape& ishape,
                  const TShape& oshape, const TShape& kernel, const TShape& pad,
                  const TShape& stride, const int pool_type, OpReqType req_type,
-                 DType* out_data) {
+                 DType* out_data, const bool count_include_pad) {
   CHECK_EQ(req_type, kWriteTo) << "Only support req=kWriteTo in pooling operations";
   using namespace mxnet_op;
   if (kernel.ndim() == 1) {
@@ -659,7 +696,8 @@ inline void pool(mshadow::Stream<gpu>* s, const DType* in_data, const TShape& is
       pool_sum_1d_gpu_kernel<<<cuda_get_num_blocks(oshape.Size()), mshadow::cuda::kBaseThreadNum,
                                0, mshadow::Stream<gpu>::GetStream(s)>>>(
                                    oshape.Size(), in_data, ishape[1], ishape[2], oshape[2],
-                                   kernel[0], stride[0], pad[0], out_data, true);
+                                   kernel[0], stride[0], pad[0], out_data,
+                                   true, count_include_pad);
       MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_1d_gpu_kernel);
     } else if (pool_enum::kSumPooling == pool_type) {
       // NOLINT_NEXT_LINE(whitespace/operators)
@@ -693,7 +731,8 @@ inline void pool(mshadow::Stream<gpu>* s, const DType* in_data, const TShape& is
                                0, mshadow::Stream<gpu>::GetStream(s)>>>(
                                    oshape.Size(), in_data, ishape[1], ishape[2], ishape[3],
                                    oshape[2], oshape[3], kernel[0], kernel[1],
-                                   stride[0], stride[1], pad[0], pad[1], out_data, true);
+                                   stride[0], stride[1], pad[0], pad[1], out_data,
+                                   true, count_include_pad);
       MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_2d_gpu_kernel);
     } else if (pool_enum::kSumPooling == pool_type) {
       // NOLINT_NEXT_LINE(whitespace/operators)
@@ -731,7 +770,7 @@ inline void pool(mshadow::Stream<gpu>* s, const DType* in_data, const TShape& is
                                    oshape.Size(), in_data, ishape[1], ishape[2], ishape[3],
                                    ishape[4], oshape[2], oshape[3], oshape[4], kernel[0],
                                    kernel[1], kernel[2], stride[0], stride[1], stride[2],
-                                   pad[0], pad[1], pad[2], out_data, true);
+                                   pad[0], pad[1], pad[2], out_data, true, count_include_pad);
       MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_3d_gpu_kernel);
     } else if (pool_enum::kSumPooling == pool_type) {
       // NOLINT_NEXT_LINE(whitespace/operators)
@@ -777,7 +816,8 @@ template<typename DType, int p>
 inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType* in_data,
                    const DType* out_data, const TShape& ishape, const TShape& oshape,
                    const TShape& kernel, const TShape& pad, const TShape& stride,
-                   const int pool_type, OpReqType req_type, DType* in_grad) {
+                   const int pool_type, OpReqType req_type, DType* in_grad,
+                   const bool count_include_pad) {
   if (mxnet::kNullOp == req_type) return;
   if (mxnet::kAddTo != req_type) {
     mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(s, ishape.Size(), in_grad);
@@ -798,7 +838,7 @@ inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType*
                                  0, mshadow::Stream<gpu>::GetStream(s)>>>(
                                      ishape.Size(), out_grad, in_data, out_data,
                                      ishape[1], ishape[2], oshape[2], kernel[0],
-                                     stride[0], pad[0], in_grad, true);
+                                     stride[0], pad[0], in_grad, true, count_include_pad);
       MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_1d_gpu_kernel);
     } else if (pool_enum::kSumPooling == pool_type) {
       // NOLINT_NEXT_LINE(whitespace/operators)
@@ -836,7 +876,8 @@ inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType*
                                      ishape.Size(), out_grad, in_data, out_data,
                                      ishape[1], ishape[2], ishape[3],
                                      oshape[2], oshape[3], kernel[0], kernel[1],
-                                     stride[0], stride[1], pad[0], pad[1], in_grad, true);
+                                     stride[0], stride[1], pad[0], pad[1], in_grad,
+                                     true, count_include_pad);
       MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_2d_gpu_kernel);
     } else if (pool_enum::kSumPooling == pool_type) {
       // NOLINT_NEXT_LINE(whitespace/operators)
@@ -878,7 +919,7 @@ inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType*
                                      ishape[1], ishape[2], ishape[3], ishape[4],
                                      oshape[2], oshape[3], oshape[4], kernel[0], kernel[1],
                                      kernel[2], stride[0], stride[1], stride[2], pad[0], pad[1],
-                                     pad[2], in_grad, true);
+                                     pad[2], in_grad, true, count_include_pad);
       MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_3d_gpu_kernel);
     } else if (pool_enum::kSumPooling == pool_type) {
       // NOLINT_NEXT_LINE(whitespace/operators)
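
In the backward (unpool_sum_*) kernels above, the same divisor is reused to
scale the gradient: every valid input position in a window receives
out_grad / pool_size, and when padding is excluded pool_size is recomputed from
the clamped window bounds so that forward and backward stay consistent. A toy
check of that bookkeeping (illustration only, not PR code):

    # Hypothetical 1-D window with kernel=2, pad=1 at the left border:
    # only x[0] falls inside the image, so one valid element.
    valid = [3.0]
    d_incl, d_excl = 2, len(valid)
    print(sum(valid) / d_incl, [1.0 / d_incl for _ in valid])  # 1.5 [0.5]
    print(sum(valid) / d_excl, [1.0 / d_excl for _ in valid])  # 3.0 [1.0]
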
diff --git a/src/operator/nn/pool.h b/src/operator/nn/pool.h
index 9fe43b2bd46..8f7a5edc832 100644
--- a/src/operator/nn/pool.h
+++ b/src/operator/nn/pool.h
@@ -216,7 +216,8 @@ inline void pool_max_3d_cpu(const DType* in_data, const TShape& ishape, const TS
 template<typename DType, int p = 1>
 inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape,
                             const TShape& kernel, const TShape& pad, const TShape& stride,
-                            DType* out_data, const bool getAvg = false) {
+                            DType* out_data,
+                            const bool get_avg = false, const bool count_include_pad = true) {
   const int width = ishape[2];
   const int pooled_width = oshape[2];
   const int kernel_w = kernel[0];
@@ -229,9 +230,12 @@ inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TS
       for (int pw = 0; pw < pooled_width; ++pw) {
         int wstart = pw * stride_w - pad_w;
         int wend = std::min(wstart + kernel_w, width + pad_w);
-        int pool_size = (getAvg ? (wend - wstart) : 1);
+        int pool_size = (get_avg ? (wend - wstart) : 1);
         wstart = std::max(wstart, 0);
         wend = std::min(wend, width);
+        if (get_avg && !count_include_pad) {
+          pool_size = (wend - wstart);
+        }
         DType sum = 0;
         for (int w = wstart; w < wend; ++w) {
           sum += a_pow_p<DType, p>::Map(in_data[w]) / pool_size;
@@ -251,7 +255,8 @@ inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TS
 template<typename DType, int p = 1>
 inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape,
                             const TShape& kernel, const TShape& pad, const TShape& stride,
-                            DType* out_data, const bool getAvg = false) {
+                            DType* out_data,
+                            const bool get_avg = false, const bool count_include_pad = true) {
   const int height = ishape[2], width = ishape[3];
   const int pooled_height = oshape[2], pooled_width = oshape[3];
   const int kernel_h = kernel[0], kernel_w = kernel[1];
@@ -267,11 +272,14 @@ inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TS
           int wstart = pw * stride_w - pad_w;
           int hend = std::min(hstart + kernel_h, height + pad_h);
           int wend = std::min(wstart + kernel_w, width + pad_w);
-          int pool_size = (getAvg ? (hend - hstart) * (wend - wstart) : 1);
+          int pool_size = (get_avg ? (hend - hstart) * (wend - wstart) : 1);
           hstart = std::max(hstart, 0);
           wstart = std::max(wstart, 0);
           hend = std::min(hend, height);
           wend = std::min(wend, width);
+          if (get_avg && !count_include_pad) {
+            pool_size = (hend - hstart) * (wend - wstart);
+          }
           DType sum = 0;
           for (int h = hstart; h < hend; ++h) {
             for (int w = wstart; w < wend; ++w) {
@@ -294,7 +302,8 @@ inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TS
 template<typename DType, int p = 1>
 inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape,
                             const TShape& kernel, const TShape& pad, const TShape& stride,
-                            DType* out_data, const bool getAvg = false) {
+                            DType* out_data,
+                            const bool get_avg = false, const bool count_include_pad = true) {
   const int depth = ishape[2], height = ishape[3], width = ishape[4];
   const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4];
   const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2];
@@ -313,13 +322,16 @@ inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TS
             int dend = std::min(dstart + kernel_d, depth + pad_d);
             int hend = std::min(hstart + kernel_h, height + pad_h);
             int wend = std::min(wstart + kernel_w, width + pad_w);
-            int pool_size = (getAvg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
+            int pool_size = (get_avg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
             dstart = std::max(dstart, 0);
             hstart = std::max(hstart, 0);
             wstart = std::max(wstart, 0);
             dend = std::min(dend, depth);
             hend = std::min(hend, height);
             wend = std::min(wend, width);
+            if (get_avg && !count_include_pad) {
+              pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+            }
             DType sum = 0;
             for (int d = dstart; d < dend; ++d) {
               for (int h = hstart; h < hend; ++h) {
@@ -328,7 +340,9 @@ inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TS
                 }
               }
             }
-            out_data[(pd*pooled_height+ph)*pooled_width+pw] = a_root_p<DType, p>::Map(sum);
+            out_data[(pd*pooled_height+ph)*pooled_width+pw] = (pool_size == 0) ?
+                                                              DType(nanf("")) :
+                                                              a_root_p<DType, p>::Map(sum);
           }
         }
       }
@@ -509,8 +523,8 @@ inline void unpool_max_3d_cpu(const DType* out_grad, const DType* in_data,
 template<typename DType, int p = 1>
 inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data,
                               const TShape& ishape, const TShape& oshape, const TShape& kernel,
-                              const TShape& pad, const TShape& stride,
-                              DType* in_grad, const bool isAvg = false) {
+                              const TShape& pad, const TShape& stride, DType* in_grad,
+                              const bool is_avg = false, const bool count_include_pad = true) {
   const int width = ishape[2];
   const int pooled_width = oshape[2];
   const int kernel_w = kernel[0];
@@ -523,9 +537,12 @@ inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const
       for (int pw = 0; pw < pooled_width; ++pw) {
         int wstart = pw * stride_w - pad_w;
         int wend = std::min(wstart + kernel_w, width + pad_w);
-        int pool_size = (isAvg ? (wend - wstart) : 1);
+        int pool_size = (is_avg ? (wend - wstart) : 1);
         wstart = std::max(wstart, 0);
         wend = std::min(wend, width);
+        if (is_avg && !count_include_pad) {
+          pool_size = (wend - wstart);
+        }
         for (int w = wstart; w < wend; ++w) {
           in_grad[w] += lp_grad<DType, p>::Map(out_grad[pw], in_data[w], out_data[pw]) / pool_size;
         }
@@ -545,8 +562,8 @@ inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const
 template<typename DType, int p = 1>
 inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data,
                               const TShape& ishape, const TShape& oshape, const TShape& kernel,
-                              const TShape& pad, const TShape& stride,
-                              DType* in_grad, const bool isAvg = false) {
+                              const TShape& pad, const TShape& stride, DType* in_grad,
+                              const bool is_avg = false, const bool count_include_pad = true) {
   const int height = ishape[2], width = ishape[3];
   const int pooled_height = oshape[2], pooled_width = oshape[3];
   const int kernel_h = kernel[0], kernel_w = kernel[1];
@@ -562,11 +579,14 @@ inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const
           int wstart = pw * stride_w - pad_w;
           int hend = std::min(hstart + kernel_h, height + pad_h);
           int wend = std::min(wstart + kernel_w, width + pad_w);
-          int pool_size = (isAvg ? (hend - hstart) * (wend - wstart) : 1);
+          int pool_size = (is_avg ? (hend - hstart) * (wend - wstart) : 1);
           hstart = std::max(hstart, 0);
           wstart = std::max(wstart, 0);
           hend = std::min(hend, height);
           wend = std::min(wend, width);
+          if (is_avg && !count_include_pad) {
+            pool_size = (hend - hstart) * (wend - wstart);
+          }
           const int pool_index = ph * pooled_width + pw;
           for (int h = hstart; h < hend; ++h) {
             for (int w = wstart; w < wend; ++w) {
@@ -593,8 +613,8 @@ inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const
 template<typename DType, int p = 1>
 inline void unpool_sum_3d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data,
                               const TShape& ishape, const TShape& oshape, const TShape& kernel,
-                              const TShape& pad, const TShape& stride,
-                              DType* in_grad, const bool isAvg = false) {
+                              const TShape& pad, const TShape& stride, DType* in_grad,
+                              const bool is_avg = false, const bool count_include_pad = true) {
   const int depth = ishape[2], height = ishape[3], width = ishape[4];
   const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4];
   const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2];
@@ -613,13 +633,16 @@ inline void unpool_sum_3d_cpu(const DType* out_grad, const DType* in_data, const
             int dend = std::min(dstart + kernel_d, depth + pad_d);
             int hend = std::min(hstart + kernel_h, height + pad_h);
             int wend = std::min(wstart + kernel_w, width + pad_w);
-            int pool_size = (isAvg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
+            int pool_size = (is_avg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
             dstart = std::max(dstart, 0);
             hstart = std::max(hstart, 0);
             wstart = std::max(wstart, 0);
             dend = std::min(dend, depth);
             hend = std::min(hend, height);
             wend = std::min(wend, width);
+            if (is_avg && !count_include_pad) {
+              pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+            }
             const int pool_index = (pd * pooled_height + ph) * pooled_width + pw;
             for (int d = dstart; d < dend; ++d) {
               for (int h = hstart; h < hend; ++h) {
@@ -660,13 +683,14 @@ template<typename DType, int p>
 inline void pool(mshadow::Stream<cpu>* s, const DType* in_data, const TShape& ishape,
                  const TShape& oshape, const TShape& kernel, const TShape& pad,
                  const TShape& stride, const int pool_type, OpReqType req_type,
-                 DType* out_data) {
+                 DType* out_data, const bool count_include_pad) {
   CHECK_EQ(req_type, kWriteTo) << "Only support req=kWriteTo in pooling operations";
   if (kernel.ndim() == 1) {
     if (pool_enum::kMaxPooling == pool_type) {
       pool_max_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else if (pool_enum::kAvgPooling == pool_type) {
-      pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true);
+      pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data,
+                      true, count_include_pad);
     } else if (pool_enum::kSumPooling == pool_type) {
       pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else if (pool_enum::kLpPooling == pool_type) {
@@ -678,7 +702,8 @@ inline void pool(mshadow::Stream<cpu>* s, const DType* in_data, const TShape& is
     if (pool_enum::kMaxPooling == pool_type) {
       pool_max_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else if (pool_enum::kAvgPooling == pool_type) {
-      pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true);
+      pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data,
+                      true, count_include_pad);
     } else if (pool_enum::kSumPooling == pool_type) {
       pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else if (pool_enum::kLpPooling == pool_type) {
@@ -690,7 +715,8 @@ inline void pool(mshadow::Stream<cpu>* s, const DType* in_data, const TShape& is
     if (pool_enum::kMaxPooling == pool_type) {
       pool_max_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else if (pool_enum::kAvgPooling == pool_type) {
-      pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true);
+      pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data,
+                      true, count_include_pad);
     } else if (pool_enum::kSumPooling == pool_type) {
       pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else if (pool_enum::kLpPooling == pool_type) {
@@ -723,7 +749,8 @@ template<typename DType, int p>
 inline void unpool(mshadow::Stream<cpu>* s, const DType* out_grad, const DType* in_data,
                    const DType* out_data, const TShape& ishape, const TShape& oshape,
                    const TShape& kernel, const TShape& pad, const TShape& stride,
-                   const int pool_type, OpReqType req_type, DType* in_grad, const int p_value = 2) {
+                   const int pool_type, OpReqType req_type, DType* in_grad,
+                   const bool count_include_pad) {
   if (mxnet::kNullOp == req_type) return;
   if (mxnet::kAddTo != req_type) {
     mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, ishape.Size(), in_grad);
@@ -733,7 +760,7 @@ inline void unpool(mshadow::Stream<cpu>* s, const DType* out_grad, const DType*
       unpool_max_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kAvgPooling == pool_type) {
       unpool_sum_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad,
-                        true);
+                        true, count_include_pad);
     } else if (pool_enum::kSumPooling == pool_type) {
       unpool_sum_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kLpPooling == pool_type) {
@@ -747,7 +774,7 @@ inline void unpool(mshadow::Stream<cpu>* s, const DType* out_grad, const DType*
       unpool_max_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kAvgPooling == pool_type) {
       unpool_sum_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad,
-                        true);
+                        true, count_include_pad);
     } else if (pool_enum::kSumPooling == pool_type) {
       unpool_sum_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kLpPooling == pool_type) {
@@ -761,7 +788,7 @@ inline void unpool(mshadow::Stream<cpu>* s, const DType* out_grad, const DType*
       unpool_max_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kAvgPooling == pool_type) {
       unpool_sum_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad,
-                        true);
+                        true, count_include_pad);
     } else if (pool_enum::kSumPooling == pool_type) {
       unpool_sum_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kLpPooling == pool_type) {
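
One consequence handled in both the CPU and GPU sum kernels above: with
pooling_convention='full', a window can lie entirely inside the padding, so
with count_include_pad=false the divisor would be zero. The kernels write NaN
for such outputs, and the GPU test further down passes
equal_nan=(not count_include_pad) to check_consistency for exactly this reason.
A small illustration of how such a window arises (hypothetical sizes, not PR
code):

    # 1-D, width=3, kernel=2, stride=2, pad=1, 'full' convention -> 3 outputs
    width, kernel, stride, pad = 3, 2, 2, 1
    pw = 2                                    # last output position
    wstart = pw * stride - pad                # 3
    wend = min(wstart + kernel, width + pad)  # 4
    valid = min(wend, width) - max(wstart, 0)
    print(valid)  # 0 -> no valid elements, so the op emits NaN here
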
diff --git a/src/operator/nn/pooling-inl.h b/src/operator/nn/pooling-inl.h
index a4770b49e85..0c4acf9d318 100644
--- a/src/operator/nn/pooling-inl.h
+++ b/src/operator/nn/pooling-inl.h
@@ -50,6 +50,7 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
   bool global_pool;
   bool cudnn_off;
   dmlc::optional<int> p_value;
+  dmlc::optional<bool> count_include_pad;
   DMLC_DECLARE_PARAMETER(PoolingParam) {
     DMLC_DECLARE_FIELD(kernel).set_default(TShape())  // add default value here
     .enforce_nonzero()
@@ -81,7 +82,13 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
     .describe("Pad for pooling: (y, x) or (d, y, x). Defaults to no padding.");
 
     DMLC_DECLARE_FIELD(p_value).set_default(dmlc::optional<int>())
-    .describe("Value of p for Lp pooling, can be 1 or 2, required for Lp Pooling");
+    .describe("Value of p for Lp pooling, can be 1 or 2, required for Lp Pooling.");
+
+    DMLC_DECLARE_FIELD(count_include_pad).set_default(dmlc::optional<bool>())
+    .describe("Only used for AvgPool, specify whether to count padding elements for average"
+              "calculation. For example, with a 5*5 kernel on a 3*3 corner of a image,"
+              "the sum of the 9 valid elements will be divided by 25 if this is set to true,"
+              "or it will be divided by 9 if this is set to false. Defaults to true.");
   }
 
   bool operator==(const PoolingParam& other) const {
@@ -92,7 +99,8 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
            this->pooling_convention == other.pooling_convention &&
            this->global_pool        == other.global_pool &&
            this->cudnn_off          == other.cudnn_off &&
-           this->p_value            == other.p_value;
+           this->p_value            == other.p_value &&
+           this->count_include_pad  == other.count_include_pad;
   }
 };
 
@@ -112,6 +120,7 @@ struct hash<mxnet::op::PoolingParam> {
     ret = dmlc::HashCombine(ret, val.global_pool);
     ret = dmlc::HashCombine(ret, val.cudnn_off);
     ret = dmlc::HashCombine(ret, val.p_value);
+    ret = dmlc::HashCombine(ret, val.count_include_pad);
     return ret;
   }
 };
@@ -153,27 +162,29 @@ class PoolingOp {
     }
     const int p_value = (param_.pool_type == pool_enum::kLpPooling && param_.p_value.has_value()) ?
                         param_.p_value.value() : 1;
+    const bool count_include_pad = (param_.count_include_pad.has_value()) ?
+                                   param_.count_include_pad.value() : true;
     switch (p_value) {
       case 1:
         pool<DType, 1>(s, in_data.dptr<DType>(), in_data.shape_, out_data.shape_,
           kernel,
           padding,
           stride,
-          param_.pool_type, req, out_data.dptr<DType>());
+          param_.pool_type, req, out_data.dptr<DType>(), count_include_pad);
         break;
       case 2:
         pool<DType, 2>(s, in_data.dptr<DType>(), in_data.shape_, out_data.shape_,
           kernel,
           padding,
           stride,
-          param_.pool_type, req, out_data.dptr<DType>());
+          param_.pool_type, req, out_data.dptr<DType>(), count_include_pad);
         break;
       case 3:
         pool<DType, 3>(s, in_data.dptr<DType>(), in_data.shape_, out_data.shape_,
           kernel,
           padding,
           stride,
-          param_.pool_type, req, out_data.dptr<DType>());
+          param_.pool_type, req, out_data.dptr<DType>(), count_include_pad);
         break;
       default:
         LOG(FATAL) << "p value of " << p_value << " is not supported yet...";
@@ -201,6 +212,8 @@ class PoolingOp {
 
     const int p_value = (param_.pool_type == pool_enum::kLpPooling && param_.p_value.has_value()) ?
                         param_.p_value.value() : 1;
+    const bool count_include_pad = (param_.count_include_pad.has_value()) ?
+                                   param_.count_include_pad.value() : true;
     switch (p_value) {
       case 1:
         unpool<DType, 1>(s, out_grad.dptr<DType>(), in_data.dptr<DType>(), out_data.dptr<DType>(),
@@ -208,7 +221,7 @@ class PoolingOp {
            kernel,
            padding,
            stride,
-           param_.pool_type, req, in_grad.dptr<DType>());
+           param_.pool_type, req, in_grad.dptr<DType>(), count_include_pad);
         break;
       case 2:
         unpool<DType, 2>(s, out_grad.dptr<DType>(), in_data.dptr<DType>(), out_data.dptr<DType>(),
@@ -216,7 +229,7 @@ class PoolingOp {
            kernel,
            padding,
            stride,
-           param_.pool_type, req, in_grad.dptr<DType>());
+           param_.pool_type, req, in_grad.dptr<DType>(), count_include_pad);
         break;
       case 3:
         unpool<DType, 3>(s, out_grad.dptr<DType>(), in_data.dptr<DType>(), out_data.dptr<DType>(),
@@ -224,7 +237,7 @@ class PoolingOp {
            kernel,
            padding,
            stride,
-           param_.pool_type, req, in_grad.dptr<DType>());
+           param_.pool_type, req, in_grad.dptr<DType>(), count_include_pad);
         break;
       default:
         LOG(FATAL) << "p value of " << p_value << " is not supported yet...";
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 7c3d670ba22..1c6785a5702 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -740,8 +740,8 @@ def test_pooling_with_type():
 
 @with_seed()
 def test_pooling_versions():
-    def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, stride,
-                                     pooling_convention='valid', global_pool=False, p_value=2):
+    def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, stride, pooling_convention='valid',
+                                     global_pool=False, p_value=2, count_include_pad=True, tol=None):
         ctx_list = []
         sym_list = []
         # PoolingV1 cpu
@@ -765,61 +765,69 @@ def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, str
             ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
             if not global_pool:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                               pooling_convention=pooling_convention, name='pool', p_value=p_value))
+                                               pooling_convention=pooling_convention, name='pool',
+                                               p_value=p_value, count_include_pad=count_include_pad))
             else:
-                sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, name='pool', p_value=p_value))
+                sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, name='pool',
+                                               p_value=p_value, count_include_pad=count_include_pad))
         # Pooling gpu
         if 'pool_gpu' in pool_op_list:
             ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
             if not global_pool:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                               pooling_convention=pooling_convention, cudnn_off=True, name='pool', p_value=p_value))
+                                               pooling_convention=pooling_convention, cudnn_off=True, name='pool',
+                                               p_value=p_value, count_include_pad=count_include_pad))
             else:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, cudnn_off=True,
-                                               name='pool', p_value=p_value))
+                                               name='pool', p_value=p_value, count_include_pad=count_include_pad))
         # CuDNNPooling
         if 'pool_cudnn' in pool_op_list:
             ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
             if not global_pool:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                               pooling_convention=pooling_convention, p_value=p_value, cudnn_off=False, name='pool'))
+                                               pooling_convention=pooling_convention, p_value=p_value, cudnn_off=False,
+                                               name='pool', count_include_pad=count_include_pad))
             else:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, p_value=p_value,
-                                               cudnn_off=False, name='pool'))
-        check_consistency(sym_list, ctx_list)
+                                               cudnn_off=False, name='pool', count_include_pad=count_include_pad))
+        check_consistency(sym_list, ctx_list, equal_nan=(not count_include_pad), tol=tol)
 
-    def test_1d_pooling(pool_type, p_value=2):
+    def test_1d_pooling(pool_type, p_value=2, count_include_pad=True):
         data = (2, 3, 20)
         kernel = (4,)
         pad = (0,)
         stride = (1,)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False, p_value=p_value)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (2,)
         stride = (2,)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False, p_value=p_value)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (0,)
         stride = (1,)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False, p_value=p_value)
+                                     pooling_convention='full', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (2,)
         stride = (2,)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False, p_value=p_value)
+                                     pooling_convention='full', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     global_pool=True, p_value=p_value)
+                                     global_pool=True, p_value=p_value, count_include_pad=count_include_pad)
 
-    def test_2d_pooling(pool_type, p_value=2):
+    def test_2d_pooling(pool_type, p_value=2, count_include_pad=True):
         data = (2, 3, 20, 20)
         kernel = (4, 5)
         pad = (0, 0)
@@ -831,14 +839,15 @@ def test_2d_pooling(pool_type, p_value=2):
         else:
             test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                          data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                         pooling_convention='valid', global_pool=False)
+                                         pooling_convention='valid', global_pool=False, count_include_pad=count_include_pad)
 
         # pool_v1 has bugs when pad is not 0, do not test PoolingV1 here
         pad = (2, 3)
         stride = (2, 3)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False, p_value=p_value)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (0, 0)
         stride = (1, 1)
@@ -847,16 +856,24 @@ def test_2d_pooling(pool_type, p_value=2):
                                          data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
                                          pooling_convention='full', global_pool=False, p_value=p_value)
         else:
-            test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
-                                         data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                         pooling_convention='full', global_pool=False)
+            if count_include_pad:
+                test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
+                                             data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                             pooling_convention='full', global_pool=False,
+                                             count_include_pad=count_include_pad)
+            else:
+                test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
+                                             data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                             pooling_convention='full', global_pool=False,
+                                             count_include_pad=count_include_pad)
 
         # pool_v1 has bugs when pad is not 0, do not test PoolingV1 here
         pad = (2, 3)
         stride = (2, 3)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False, p_value=p_value)
+                                     pooling_convention='full', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         if pool_type == 'lp':
             test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
@@ -865,55 +882,62 @@ def test_2d_pooling(pool_type, p_value=2):
         else:
             test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                          data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                         global_pool=True)
+                                         global_pool=True, count_include_pad=count_include_pad)
 
-    def test_3d_pooling(pool_type, p_value=2):
+    def test_3d_pooling(pool_type, p_value=2, count_include_pad=True):
         data = (2, 3, 20, 20, 20)
         kernel = (4, 5, 3)
         pad = (0, 0, 0)
         stride = (1, 1, 1)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False, p_value=p_value)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (2, 3, 3)
         stride = (2, 3, 1)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False, p_value=p_value)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (0, 0, 0)
         stride = (1, 1, 1)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False, p_value=p_value)
+                                     pooling_convention='full', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (2, 3, 3)
         stride = (2, 3, 1)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False, p_value=p_value)
+                                     pooling_convention='full', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     global_pool=True, p_value=p_value)
+                                     global_pool=True, p_value=p_value, count_include_pad=count_include_pad)
 
     test_1d_pooling('max')
-    test_1d_pooling('avg')
+    test_1d_pooling('avg', count_include_pad=True)
+    test_1d_pooling('avg', count_include_pad=False)
     test_1d_pooling('sum')
     test_1d_pooling('lp', p_value=1)
     test_1d_pooling('lp', p_value=2)
     test_1d_pooling('lp', p_value=3)
 
     test_2d_pooling('max')
-    test_2d_pooling('avg')
+    test_2d_pooling('avg', count_include_pad=True)
+    test_2d_pooling('avg', count_include_pad=False)
     test_2d_pooling('sum')
     test_2d_pooling('lp', p_value=1)
     test_2d_pooling('lp', p_value=2)
     test_2d_pooling('lp', p_value=3)
 
     test_3d_pooling('max')
-    test_3d_pooling('avg')
+    test_3d_pooling('avg', count_include_pad=True)
+    test_3d_pooling('avg', count_include_pad=False)
     test_3d_pooling('sum')
     test_3d_pooling('lp', p_value=1)
     test_3d_pooling('lp', p_value=2)
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index bf1e0deb200..50ecd5c8809 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -470,6 +470,7 @@ def test_pool():
         nn.MaxPool1D(3),
         nn.MaxPool1D(3, 2),
         nn.AvgPool1D(),
+        nn.AvgPool1D(count_include_pad=False),
         nn.GlobalAvgPool1D(),
         ]
     for layer in layers1d:
@@ -481,6 +482,7 @@ def test_pool():
         nn.MaxPool2D((3, 3)),
         nn.MaxPool2D(3, 2),
         nn.AvgPool2D(),
+        nn.AvgPool2D(count_include_pad=False),
         nn.GlobalAvgPool2D(),
         ]
     for layer in layers2d:
@@ -491,6 +493,7 @@ def test_pool():
         nn.MaxPool3D((3, 3, 3)),
         nn.MaxPool3D(3, 2),
         nn.AvgPool3D(),
+        nn.AvgPool3D(count_include_pad=False),
         nn.GlobalAvgPool3D(),
         ]
     for layer in layers3d:
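
At the operator level, the flag is exercised the same way as in the new tests,
e.g. (illustrative only, assuming a build with this PR):

    import mxnet as mx

    x = mx.nd.ones((1, 1, 3, 3))
    incl = mx.nd.Pooling(x, kernel=(2, 2), pad=(1, 1), stride=(1, 1),
                         pool_type='avg', count_include_pad=True)
    excl = mx.nd.Pooling(x, kernel=(2, 2), pad=(1, 1), stride=(1, 1),
                         pool_type='avg', count_include_pad=False)
    print(incl.asnumpy()[0, 0, 0, 0])  # 0.25: one valid element divided by 4
    print(excl.asnumpy()[0, 0, 0, 0])  # 1.0 : one valid element divided by 1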


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services