Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/06/18 18:11:08 UTC

[incubator-mxnet] branch master updated: [MXNET-380] count_include_pad argument for Avg Pooling (#11021)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f34dafe  [MXNET-380] count_include_pad argument for Avg Pooling (#11021)
f34dafe is described below

commit f34dafe2e2e96275537b594dc1475a2995a051b6
Author: Hao Jin <ha...@users.noreply.github.com>
AuthorDate: Mon Jun 18 11:10:59 2018 -0700

    [MXNET-380] count_include_pad argument for Avg Pooling (#11021)
    
    * add count_include_pad argument
    
    * add count_include_pad in cudnn pooling and corresponding tests
    
    * add gluon support for the new option
    
    * switch to built-in functions for artificial NaN
    
    * add doc for gluon
    
    * change isAvg/getAvg to is_avg/get_avg
---
 cpp-package/scripts/OpWrapperGenerator.py |  1 +
 python/mxnet/gluon/nn/conv_layers.py      | 22 +++++---
 src/operator/nn/cudnn/cudnn_pooling-inl.h |  8 ++-
 src/operator/nn/mkldnn/mkldnn_pooling.cc  |  6 ++-
 src/operator/nn/pool.cuh                  | 83 +++++++++++++++++++++--------
 src/operator/nn/pool.h                    | 75 +++++++++++++++++---------
 src/operator/nn/pooling-inl.h             | 29 +++++++---
 tests/python/gpu/test_operator_gpu.py     | 88 ++++++++++++++++++++-----------
 tests/python/unittest/test_gluon.py       |  3 ++
 9 files changed, 220 insertions(+), 95 deletions(-)
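
At the NDArray level, the new argument only changes the divisor used for each
averaging window. A minimal sketch of the contrast, assuming an MXNet build
that includes this commit:

    import mxnet as mx

    x = mx.nd.ones((1, 1, 3, 3))
    # 3x3 kernel with 1 pixel of padding: each corner window covers 4 real
    # elements and 5 padded zeros
    inc = mx.nd.Pooling(x, kernel=(3, 3), pad=(1, 1), stride=(1, 1),
                        pool_type='avg', count_include_pad=True)
    exc = mx.nd.Pooling(x, kernel=(3, 3), pad=(1, 1), stride=(1, 1),
                        pool_type='avg', count_include_pad=False)
    print(inc[0, 0, 0, 0].asscalar())  # 4/9 ~ 0.44: divided by the full window
    print(exc[0, 0, 0, 0].asscalar())  # 1.0: divided by the 4 valid elements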

diff --git a/cpp-package/scripts/OpWrapperGenerator.py b/cpp-package/scripts/OpWrapperGenerator.py
index 0c000d9..8facde1 100644
--- a/cpp-package/scripts/OpWrapperGenerator.py
+++ b/cpp-package/scripts/OpWrapperGenerator.py
@@ -77,6 +77,7 @@ class EnumType:
 
 class Arg:
     typeDict = {'boolean':'bool',\
+        'boolean or None':'dmlc::optional<bool>',\
         'Shape(tuple)':'Shape',\
         'Symbol':'Symbol',\
         'NDArray':'Symbol',\
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 2fbf7d8..24f3027 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -675,7 +675,7 @@ class Conv3DTranspose(_Conv):
 class _Pooling(HybridBlock):
     """Abstract class for different pooling layers."""
     def __init__(self, pool_size, strides, padding, ceil_mode, global_pool,
-                 pool_type, **kwargs):
+                 pool_type, count_include_pad=None, **kwargs):
         super(_Pooling, self).__init__(**kwargs)
         if strides is None:
             strides = pool_size
@@ -687,6 +687,8 @@ class _Pooling(HybridBlock):
             'kernel': pool_size, 'stride': strides, 'pad': padding,
             'global_pool': global_pool, 'pool_type': pool_type,
             'pooling_convention': 'full' if ceil_mode else 'valid'}
+        if count_include_pad is not None:
+            self._kwargs['count_include_pad'] = count_include_pad
 
     def _alias(self):
         return 'pool'
@@ -863,6 +865,8 @@ class AvgPool1D(_Pooling):
         respectively. padding is applied on 'W' dimension.
     ceil_mode : bool, default False
         When `True`, will use ceil instead of floor to compute the output shape.
+    count_include_pad : bool, default True
+        When `False`, will exclude padding elements when computing the average value.
 
 
     Inputs:
@@ -879,13 +883,13 @@ class AvgPool1D(_Pooling):
           equation.
     """
     def __init__(self, pool_size=2, strides=None, padding=0, layout='NCW',
-                 ceil_mode=False, **kwargs):
+                 ceil_mode=False, count_include_pad=True, **kwargs):
         assert layout == 'NCW', "Only supports 'NCW' layout for now"
         if isinstance(pool_size, numeric_types):
             pool_size = (pool_size,)
         assert len(pool_size) == 1, "pool_size must be a number or a list of 1 ints"
         super(AvgPool1D, self).__init__(
-            pool_size, strides, padding, ceil_mode, False, 'avg', **kwargs)
+            pool_size, strides, padding, ceil_mode, False, 'avg', count_include_pad, **kwargs)
 
 
 class AvgPool2D(_Pooling):
@@ -907,6 +911,8 @@ class AvgPool2D(_Pooling):
         dimensions respectively. padding is applied on 'H' and 'W' dimension.
     ceil_mode : bool, default False
         When True, will use ceil instead of floor to compute the output shape.
+    count_include_pad : bool, default True
+        When `False`, will exclude padding elements when computing the average value.
 
 
     Inputs:
@@ -926,13 +932,13 @@ class AvgPool2D(_Pooling):
           equation.
     """
     def __init__(self, pool_size=(2, 2), strides=None, padding=0,
-                 ceil_mode=False, layout='NCHW', **kwargs):
+                 ceil_mode=False, layout='NCHW', count_include_pad=True, **kwargs):
         assert layout == 'NCHW', "Only supports 'NCHW' layout for now"
         if isinstance(pool_size, numeric_types):
             pool_size = (pool_size,)*2
         assert len(pool_size) == 2, "pool_size must be a number or a list of 2 ints"
         super(AvgPool2D, self).__init__(
-            pool_size, strides, padding, ceil_mode, False, 'avg', **kwargs)
+            pool_size, strides, padding, ceil_mode, False, 'avg', count_include_pad, **kwargs)
 
 
 class AvgPool3D(_Pooling):
@@ -955,6 +961,8 @@ class AvgPool3D(_Pooling):
         dimension.
     ceil_mode : bool, default False
         When True, will use ceil instead of floor to compute the output shape.
+    count_include_pad : bool, default True
+        When `False`, will exclude padding elements when computing the average value.
 
 
     Inputs:
@@ -975,13 +983,13 @@ class AvgPool3D(_Pooling):
           equation.
     """
     def __init__(self, pool_size=(2, 2, 2), strides=None, padding=0,
-                 ceil_mode=False, layout='NCDHW', **kwargs):
+                 ceil_mode=False, layout='NCDHW', count_include_pad=True, **kwargs):
         assert layout == 'NCDHW', "Only supports 'NCDHW' layout for now"
         if isinstance(pool_size, numeric_types):
             pool_size = (pool_size,)*3
         assert len(pool_size) == 3, "pool_size must be a number or a list of 3 ints"
         super(AvgPool3D, self).__init__(
-            pool_size, strides, padding, ceil_mode, False, 'avg', **kwargs)
+            pool_size, strides, padding, ceil_mode, False, 'avg', count_include_pad, **kwargs)
 
 
 class GlobalMaxPool1D(_Pooling):
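
The Gluon layers simply forward the option to the underlying Pooling operator
whenever it is given explicitly. A usage sketch, with the expected outputs
worked out by hand under the semantics above:

    import mxnet as mx
    from mxnet.gluon import nn

    x = mx.nd.ones((1, 1, 4))
    # kernel 2, stride 2, one padded element on each side of the 'W' axis
    inc = nn.AvgPool1D(pool_size=2, strides=2, padding=1, count_include_pad=True)
    exc = nn.AvgPool1D(pool_size=2, strides=2, padding=1, count_include_pad=False)
    print(inc(x))  # [[[0.5, 1.0, 0.5]]]: edge windows average in a padded zero
    print(exc(x))  # [[[1.0, 1.0, 1.0]]]: padding never enters the divisor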
diff --git a/src/operator/nn/cudnn/cudnn_pooling-inl.h b/src/operator/nn/cudnn/cudnn_pooling-inl.h
index 84cf640..bc3ee36 100644
--- a/src/operator/nn/cudnn/cudnn_pooling-inl.h
+++ b/src/operator/nn/cudnn/cudnn_pooling-inl.h
@@ -51,7 +51,11 @@ class CuDNNPoolingOp {
         mode_ = CUDNN_POOLING_MAX;
         break;
       case pool_enum::kAvgPooling:
-        mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+        if (param_.count_include_pad.has_value() && !param_.count_include_pad.value()) {
+          mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
+        } else {
+          mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+        }
         break;
       default:
         LOG(FATAL) << "Not implmented";
@@ -263,7 +267,7 @@ class CuDNNPoolingOp {
                                              &(pad_vec[0]),
                                              &(stride_vec[0])));
       #else
-      LOG(FATAL) << "3D pooling only support CUDNN v5 and abouve";
+      LOG(FATAL) << "3D pooling only support CUDNN v5 and above";
       #endif
     }
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index 259af2b..9fd88a1 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -121,7 +121,11 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
       return mkldnn::algorithm::pooling_max;
       break;
     case pool_enum::kAvgPooling:
-      return mkldnn::algorithm::pooling_avg_include_padding;
+      if (param.count_include_pad.has_value() && !param.count_include_pad.value()) {
+        return mkldnn::algorithm::pooling_avg_exclude_padding;
+      } else {
+        return mkldnn::algorithm::pooling_avg_include_padding;
+      }
       break;
     default:
       LOG(FATAL) << "MKLDNN Pooling: Unknown pooling method.";
diff --git a/src/operator/nn/pool.cuh b/src/operator/nn/pool.cuh
index 9d004d2..976aacf 100644
--- a/src/operator/nn/pool.cuh
+++ b/src/operator/nn/pool.cuh
@@ -214,16 +214,19 @@ template <typename DType, int p = 1>
 __global__ void pool_sum_1d_gpu_kernel(const int nthreads, const DType* in_data, const int channels,
                                        const int width, const int pooled_width, const int kernel_w,
                                        const int stride_w, const int pad_w, DType* out_data,
-                                       const bool getAvg = false) {
+                                       const bool get_avg = false, const bool count_include_pad = true) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     const int pw = index % pooled_width;
     const int c = (index / pooled_width) % channels;
     const int n = index / pooled_width / channels;
     int wstart = pw * stride_w - pad_w;
     int wend = min(wstart + kernel_w, width + pad_w);
-    const int pool_size = (getAvg? (wend - wstart) : 1);
+    int pool_size = (get_avg? (wend - wstart) : 1);
     wstart = max(wstart, 0);
     wend = min(wend, width);
+    if (get_avg && !count_include_pad) {
+      pool_size = (wend - wstart);
+    }
     DType sum = 0;
     const DType* out_slice = in_data + (n * channels + c) * width;
     for (int w = wstart; w < wend; ++w) {
@@ -244,7 +247,8 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data,
                                        const int kernel_h, const int kernel_w,
                                        const int stride_h, const int stride_w,
                                        const int pad_h, const int pad_w, DType* out_data,
-                                       const bool getAvg = false) {
+                                       const bool get_avg = false,
+                                       const bool count_include_pad = true) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     const int pw = index % pooled_width;
     const int ph = (index / pooled_width) % pooled_height;
@@ -254,11 +258,14 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data,
     int wstart = pw * stride_w - pad_w;
     int hend = min(hstart + kernel_h, height + pad_h);
     int wend = min(wstart + kernel_w, width + pad_w);
-    const int pool_size = (getAvg? (hend - hstart) * (wend - wstart) : 1);
+    int pool_size = (get_avg? (hend - hstart) * (wend - wstart) : 1);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
     hend = min(hend, height);
     wend = min(wend, width);
+    if (get_avg && !count_include_pad) {
+      pool_size = (hend - hstart) * (wend - wstart);
+    }
     DType sum = 0;
     const DType* out_slice = in_data + (n * channels + c) * height * width;
     for (int h = hstart; h < hend; ++h) {
@@ -282,7 +289,8 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data,
                                        const int kernel_h, const int kernel_w,
                                        const int stride_d, const int stride_h, const int stride_w,
                                        const int pad_d, const int pad_h, const int pad_w,
-                                       DType* out_data, const bool getAvg = false) {
+                                       DType* out_data, const bool get_avg = false,
+                                       const bool count_include_pad = true) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     const int pw = index % pooled_width;
     const int ph = (index / pooled_width) % pooled_height;
@@ -295,13 +303,16 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data,
     int dend = min(dstart + kernel_d, depth + pad_d);
     int hend = min(hstart + kernel_h, height + pad_h);
     int wend = min(wstart + kernel_w, width + pad_w);
-    const int pool_size = (getAvg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
+    int pool_size = (get_avg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
     dstart = max(dstart, 0);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
     dend = min(dend, depth);
     hend = min(hend, height);
     wend = min(wend, width);
+    if (get_avg && !count_include_pad) {
+      pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+    }
     DType sum = 0;
     const DType* out_slice = in_data + (n * channels + c) * depth * height * width;
     for (int d = dstart; d < dend; ++d) {
@@ -311,7 +322,9 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data,
         }
       }
     }
-    out_data[index] = a_root_p<DType, p>::Map(sum);
+    out_data[index] = (pool_size == 0) ?
+                      DType(nanf("")) :
+                      a_root_p<DType, p>::Map(sum);
   }
 }
 
@@ -487,7 +500,8 @@ __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_gr
                                          const int channels, const int width,
                                          const int pooled_width, const int kernel_w,
                                          const int stride_w, const int pad_w, DType* in_grad,
-                                         const bool isAvg = false) {
+                                         const bool is_avg = false,
+                                         const bool count_include_pad = true) {
   // index is the input image index in NCW
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
@@ -506,7 +520,12 @@ __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_gr
       // figure out the pooling size
       int wstart = pw * stride_w - pad_w;
       int wend = min(wstart + kernel_w, width + pad_w);
-      int pool_size = (isAvg? (wend - wstart) : 1);
+      int pool_size = (is_avg? (wend - wstart) : 1);
+      if (is_avg && !count_include_pad) {
+        wstart = max(wstart, 0);
+        wend = min(wend, width);
+        pool_size = (wend - wstart);
+      }
       gradient +=
         lp_grad<DType, p>::Map(out_grad_slice[pw], in_data[index], out_data_slice[pw]) / pool_size;
     }
@@ -528,7 +547,8 @@ __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_gr
                                          const int kernel_h, const int kernel_w,
                                          const int stride_h, const int stride_w,
                                          const int pad_h, const int pad_w, DType* in_grad,
-                                         const bool isAvg = false) {
+                                         const bool is_avg = false,
+                                         const bool count_include_pad = true) {
   // index is the input image index in NCHW
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
@@ -553,8 +573,15 @@ __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_gr
         int wstart = pw * stride_w - pad_w;
         int hend = min(hstart + kernel_h, height + pad_h);
         int wend = min(wstart + kernel_w, width + pad_w);
-        int pool_size = (isAvg? (hend - hstart) * (wend - wstart) : 1);
+        int pool_size = (is_avg? (hend - hstart) * (wend - wstart) : 1);
         int out_index = ph * pooled_width + pw;
+        if (is_avg && !count_include_pad) {
+          hstart = max(hstart, 0);
+          wstart = max(wstart, 0);
+          hend = min(hend, height);
+          wend = min(wend, width);
+          pool_size = (hend - hstart) * (wend - wstart);
+        }
         gradient +=
           lp_grad<DType, p>::Map(out_grad_slice[out_index],
                                  in_data[index],
@@ -580,7 +607,8 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr
                                          const int kernel_d, const int kernel_h,
                                          const int kernel_w, const int stride_d, const int stride_h,
                                          const int stride_w, const int pad_d, const int pad_h,
-                                         const int pad_w, DType* in_grad, const bool isAvg = false) {
+                                         const int pad_w, DType* in_grad, const bool is_avg = false,
+                                         const bool count_include_pad = true) {
   // index is the input image index in NCDHW
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
@@ -611,8 +639,17 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr
           int dend = min(dstart + kernel_d, depth + pad_d);
           int hend = min(hstart + kernel_h, height + pad_h);
           int wend = min(wstart + kernel_w, width + pad_w);
-          int pool_size = (isAvg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
+          int pool_size = (is_avg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
           int out_index = (pd * pooled_height + ph) * pooled_width + pw;
+          if (is_avg && !count_include_pad) {
+            dstart = max(dstart, 0);
+            hstart = max(hstart, 0);
+            wstart = max(wstart, 0);
+            dend = min(dend, depth);
+            hend = min(hend, height);
+            wend = min(wend, width);
+            pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+          }
           gradient += lp_grad<DType, p>::Map(out_grad_slice[out_index],
                                              in_data[index],
                                              out_data_slice[out_index]) / pool_size;
@@ -643,7 +680,7 @@ template<typename DType, int p>
 inline void pool(mshadow::Stream<gpu>* s, const DType* in_data, const TShape& ishape,
                  const TShape& oshape, const TShape& kernel, const TShape& pad,
                  const TShape& stride, const int pool_type, OpReqType req_type,
-                 DType* out_data) {
+                 DType* out_data, const bool count_include_pad) {
   CHECK_EQ(req_type, kWriteTo) << "Only support req=kWriteTo in pooling operations";
   using namespace mxnet_op;
   if (kernel.ndim() == 1) {
@@ -659,7 +696,8 @@ inline void pool(mshadow::Stream<gpu>* s, const DType* in_data, const TShape& is
       pool_sum_1d_gpu_kernel<<<cuda_get_num_blocks(oshape.Size()), mshadow::cuda::kBaseThreadNum,
                                0, mshadow::Stream<gpu>::GetStream(s)>>>(
                                    oshape.Size(), in_data, ishape[1], ishape[2], oshape[2],
-                                   kernel[0], stride[0], pad[0], out_data, true);
+                                   kernel[0], stride[0], pad[0], out_data,
+                                   true, count_include_pad);
       MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_1d_gpu_kernel);
     } else if (pool_enum::kSumPooling == pool_type) {
       // NOLINT_NEXT_LINE(whitespace/operators)
@@ -693,7 +731,8 @@ inline void pool(mshadow::Stream<gpu>* s, const DType* in_data, const TShape& is
                                0, mshadow::Stream<gpu>::GetStream(s)>>>(
                                    oshape.Size(), in_data, ishape[1], ishape[2], ishape[3],
                                    oshape[2], oshape[3], kernel[0], kernel[1],
-                                   stride[0], stride[1], pad[0], pad[1], out_data, true);
+                                   stride[0], stride[1], pad[0], pad[1], out_data,
+                                   true, count_include_pad);
       MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_2d_gpu_kernel);
     } else if (pool_enum::kSumPooling == pool_type) {
       // NOLINT_NEXT_LINE(whitespace/operators)
@@ -731,7 +770,7 @@ inline void pool(mshadow::Stream<gpu>* s, const DType* in_data, const TShape& is
                                    oshape.Size(), in_data, ishape[1], ishape[2], ishape[3],
                                    ishape[4], oshape[2], oshape[3], oshape[4], kernel[0],
                                    kernel[1], kernel[2], stride[0], stride[1], stride[2],
-                                   pad[0], pad[1], pad[2], out_data, true);
+                                   pad[0], pad[1], pad[2], out_data, true, count_include_pad);
       MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_3d_gpu_kernel);
     } else if (pool_enum::kSumPooling == pool_type) {
       // NOLINT_NEXT_LINE(whitespace/operators)
@@ -777,7 +816,8 @@ template<typename DType, int p>
 inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType* in_data,
                    const DType* out_data, const TShape& ishape, const TShape& oshape,
                    const TShape& kernel, const TShape& pad, const TShape& stride,
-                   const int pool_type, OpReqType req_type, DType* in_grad) {
+                   const int pool_type, OpReqType req_type, DType* in_grad,
+                   const bool count_include_pad) {
   if (mxnet::kNullOp == req_type) return;
   if (mxnet::kAddTo != req_type) {
     mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(s, ishape.Size(), in_grad);
@@ -798,7 +838,7 @@ inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType*
                                  0, mshadow::Stream<gpu>::GetStream(s)>>>(
                                      ishape.Size(), out_grad, in_data, out_data,
                                      ishape[1], ishape[2], oshape[2], kernel[0],
-                                     stride[0], pad[0], in_grad, true);
+                                     stride[0], pad[0], in_grad, true, count_include_pad);
       MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_1d_gpu_kernel);
     } else if (pool_enum::kSumPooling == pool_type) {
       // NOLINT_NEXT_LINE(whitespace/operators)
@@ -836,7 +876,8 @@ inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType*
                                      ishape.Size(), out_grad, in_data, out_data,
                                      ishape[1], ishape[2], ishape[3],
                                      oshape[2], oshape[3], kernel[0], kernel[1],
-                                     stride[0], stride[1], pad[0], pad[1], in_grad, true);
+                                     stride[0], stride[1], pad[0], pad[1], in_grad,
+                                     true, count_include_pad);
       MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_2d_gpu_kernel);
     } else if (pool_enum::kSumPooling == pool_type) {
       // NOLINT_NEXT_LINE(whitespace/operators)
@@ -878,7 +919,7 @@ inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType*
                                      ishape[1], ishape[2], ishape[3], ishape[4],
                                      oshape[2], oshape[3], oshape[4], kernel[0], kernel[1],
                                      kernel[2], stride[0], stride[1], stride[2], pad[0], pad[1],
-                                     pad[2], in_grad, true);
+                                     pad[2], in_grad, true, count_include_pad);
       MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_3d_gpu_kernel);
     } else if (pool_enum::kSumPooling == pool_type) {
       // NOLINT_NEXT_LINE(whitespace/operators)
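
All of the kernels above apply the same divisor rule: compute pool_size from
the padded window first, clip the window to the image, then recompute
pool_size from the clipped bounds only when averaging with
count_include_pad=false. An illustrative NumPy mirror of the 1-D forward case
(a sketch following the variable names in the CUDA code, not the actual
source):

    import numpy as np

    def avg_pool_1d(x, kernel_w, stride_w, pad_w, count_include_pad=True):
        width = len(x)
        pooled_width = (width + 2 * pad_w - kernel_w) // stride_w + 1
        out = np.empty(pooled_width)
        for pw in range(pooled_width):
            wstart = pw * stride_w - pad_w
            wend = min(wstart + kernel_w, width + pad_w)
            pool_size = wend - wstart        # full, possibly padded window
            wstart, wend = max(wstart, 0), min(wend, width)
            if not count_include_pad:
                pool_size = wend - wstart    # valid elements only
            # a window that is all padding has pool_size 0; the kernels emit
            # NaN there via the nanf("") branch
            out[pw] = x[wstart:wend].sum() / pool_size if pool_size else np.nan
        return out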
diff --git a/src/operator/nn/pool.h b/src/operator/nn/pool.h
index 9fe43b2..8f7a5ed 100644
--- a/src/operator/nn/pool.h
+++ b/src/operator/nn/pool.h
@@ -216,7 +216,8 @@ inline void pool_max_3d_cpu(const DType* in_data, const TShape& ishape, const TS
 template<typename DType, int p = 1>
 inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape,
                             const TShape& kernel, const TShape& pad, const TShape& stride,
-                            DType* out_data, const bool getAvg = false) {
+                            DType* out_data,
+                            const bool get_avg = false, const bool count_include_pad = true) {
   const int width = ishape[2];
   const int pooled_width = oshape[2];
   const int kernel_w = kernel[0];
@@ -229,9 +230,12 @@ inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TS
       for (int pw = 0; pw < pooled_width; ++pw) {
         int wstart = pw * stride_w - pad_w;
         int wend = std::min(wstart + kernel_w, width + pad_w);
-        int pool_size = (getAvg ? (wend - wstart) : 1);
+        int pool_size = (get_avg ? (wend - wstart) : 1);
         wstart = std::max(wstart, 0);
         wend = std::min(wend, width);
+        if (get_avg && !count_include_pad) {
+          pool_size = (wend - wstart);
+        }
         DType sum = 0;
         for (int w = wstart; w < wend; ++w) {
           sum += a_pow_p<DType, p>::Map(in_data[w]) / pool_size;
@@ -251,7 +255,8 @@ inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TS
 template<typename DType, int p = 1>
 inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape,
                             const TShape& kernel, const TShape& pad, const TShape& stride,
-                            DType* out_data, const bool getAvg = false) {
+                            DType* out_data,
+                            const bool get_avg = false, const bool count_include_pad = true) {
   const int height = ishape[2], width = ishape[3];
   const int pooled_height = oshape[2], pooled_width = oshape[3];
   const int kernel_h = kernel[0], kernel_w = kernel[1];
@@ -267,11 +272,14 @@ inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TS
           int wstart = pw * stride_w - pad_w;
           int hend = std::min(hstart + kernel_h, height + pad_h);
           int wend = std::min(wstart + kernel_w, width + pad_w);
-          int pool_size = (getAvg ? (hend - hstart) * (wend - wstart) : 1);
+          int pool_size = (get_avg ? (hend - hstart) * (wend - wstart) : 1);
           hstart = std::max(hstart, 0);
           wstart = std::max(wstart, 0);
           hend = std::min(hend, height);
           wend = std::min(wend, width);
+          if (get_avg && !count_include_pad) {
+            pool_size = (hend - hstart) * (wend - wstart);
+          }
           DType sum = 0;
           for (int h = hstart; h < hend; ++h) {
             for (int w = wstart; w < wend; ++w) {
@@ -294,7 +302,8 @@ inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TS
 template<typename DType, int p = 1>
 inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape,
                             const TShape& kernel, const TShape& pad, const TShape& stride,
-                            DType* out_data, const bool getAvg = false) {
+                            DType* out_data,
+                            const bool get_avg = false, const bool count_include_pad = true) {
   const int depth = ishape[2], height = ishape[3], width = ishape[4];
   const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4];
   const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2];
@@ -313,13 +322,16 @@ inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TS
             int dend = std::min(dstart + kernel_d, depth + pad_d);
             int hend = std::min(hstart + kernel_h, height + pad_h);
             int wend = std::min(wstart + kernel_w, width + pad_w);
-            int pool_size = (getAvg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
+            int pool_size = (get_avg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
             dstart = std::max(dstart, 0);
             hstart = std::max(hstart, 0);
             wstart = std::max(wstart, 0);
             dend = std::min(dend, depth);
             hend = std::min(hend, height);
             wend = std::min(wend, width);
+            if (get_avg && !count_include_pad) {
+              pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+            }
             DType sum = 0;
             for (int d = dstart; d < dend; ++d) {
               for (int h = hstart; h < hend; ++h) {
@@ -328,7 +340,9 @@ inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TS
                 }
               }
             }
-            out_data[(pd*pooled_height+ph)*pooled_width+pw] = a_root_p<DType, p>::Map(sum);
+            out_data[(pd*pooled_height+ph)*pooled_width+pw] = (pool_size == 0) ?
+                                                              DType(nanf("")) :
+                                                              a_root_p<DType, p>::Map(sum);
           }
         }
       }
@@ -509,8 +523,8 @@ inline void unpool_max_3d_cpu(const DType* out_grad, const DType* in_data,
 template<typename DType, int p = 1>
 inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data,
                               const TShape& ishape, const TShape& oshape, const TShape& kernel,
-                              const TShape& pad, const TShape& stride,
-                              DType* in_grad, const bool isAvg = false) {
+                              const TShape& pad, const TShape& stride, DType* in_grad,
+                              const bool is_avg = false, const bool count_include_pad = true) {
   const int width = ishape[2];
   const int pooled_width = oshape[2];
   const int kernel_w = kernel[0];
@@ -523,9 +537,12 @@ inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const
       for (int pw = 0; pw < pooled_width; ++pw) {
         int wstart = pw * stride_w - pad_w;
         int wend = std::min(wstart + kernel_w, width + pad_w);
-        int pool_size = (isAvg ? (wend - wstart) : 1);
+        int pool_size = (is_avg ? (wend - wstart) : 1);
         wstart = std::max(wstart, 0);
         wend = std::min(wend, width);
+        if (is_avg && !count_include_pad) {
+          pool_size = (wend - wstart);
+        }
         for (int w = wstart; w < wend; ++w) {
           in_grad[w] += lp_grad<DType, p>::Map(out_grad[pw], in_data[w], out_data[pw]) / pool_size;
         }
@@ -545,8 +562,8 @@ inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const
 template<typename DType, int p = 1>
 inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data,
                               const TShape& ishape, const TShape& oshape, const TShape& kernel,
-                              const TShape& pad, const TShape& stride,
-                              DType* in_grad, const bool isAvg = false) {
+                              const TShape& pad, const TShape& stride, DType* in_grad,
+                              const bool is_avg = false, const bool count_include_pad = true) {
   const int height = ishape[2], width = ishape[3];
   const int pooled_height = oshape[2], pooled_width = oshape[3];
   const int kernel_h = kernel[0], kernel_w = kernel[1];
@@ -562,11 +579,14 @@ inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const
           int wstart = pw * stride_w - pad_w;
           int hend = std::min(hstart + kernel_h, height + pad_h);
           int wend = std::min(wstart + kernel_w, width + pad_w);
-          int pool_size = (isAvg ? (hend - hstart) * (wend - wstart) : 1);
+          int pool_size = (is_avg ? (hend - hstart) * (wend - wstart) : 1);
           hstart = std::max(hstart, 0);
           wstart = std::max(wstart, 0);
           hend = std::min(hend, height);
           wend = std::min(wend, width);
+          if (is_avg && !count_include_pad) {
+            pool_size = (hend - hstart) * (wend - wstart);
+          }
           const int pool_index = ph * pooled_width + pw;
           for (int h = hstart; h < hend; ++h) {
             for (int w = wstart; w < wend; ++w) {
@@ -593,8 +613,8 @@ inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const
 template<typename DType, int p = 1>
 inline void unpool_sum_3d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data,
                               const TShape& ishape, const TShape& oshape, const TShape& kernel,
-                              const TShape& pad, const TShape& stride,
-                              DType* in_grad, const bool isAvg = false) {
+                              const TShape& pad, const TShape& stride, DType* in_grad,
+                              const bool is_avg = false, const bool count_include_pad = true) {
   const int depth = ishape[2], height = ishape[3], width = ishape[4];
   const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4];
   const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2];
@@ -613,13 +633,16 @@ inline void unpool_sum_3d_cpu(const DType* out_grad, const DType* in_data, const
             int dend = std::min(dstart + kernel_d, depth + pad_d);
             int hend = std::min(hstart + kernel_h, height + pad_h);
             int wend = std::min(wstart + kernel_w, width + pad_w);
-            int pool_size = (isAvg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
+            int pool_size = (is_avg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
             dstart = std::max(dstart, 0);
             hstart = std::max(hstart, 0);
             wstart = std::max(wstart, 0);
             dend = std::min(dend, depth);
             hend = std::min(hend, height);
             wend = std::min(wend, width);
+            if (is_avg && !count_include_pad) {
+              pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+            }
             const int pool_index = (pd * pooled_height + ph) * pooled_width + pw;
             for (int d = dstart; d < dend; ++d) {
               for (int h = hstart; h < hend; ++h) {
@@ -660,13 +683,14 @@ template<typename DType, int p>
 inline void pool(mshadow::Stream<cpu>* s, const DType* in_data, const TShape& ishape,
                  const TShape& oshape, const TShape& kernel, const TShape& pad,
                  const TShape& stride, const int pool_type, OpReqType req_type,
-                 DType* out_data) {
+                 DType* out_data, const bool count_include_pad) {
   CHECK_EQ(req_type, kWriteTo) << "Only support req=kWriteTo in pooling operations";
   if (kernel.ndim() == 1) {
     if (pool_enum::kMaxPooling == pool_type) {
       pool_max_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else if (pool_enum::kAvgPooling == pool_type) {
-      pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true);
+      pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data,
+                      true, count_include_pad);
     } else if (pool_enum::kSumPooling == pool_type) {
       pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else if (pool_enum::kLpPooling == pool_type) {
@@ -678,7 +702,8 @@ inline void pool(mshadow::Stream<cpu>* s, const DType* in_data, const TShape& is
     if (pool_enum::kMaxPooling == pool_type) {
       pool_max_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else if (pool_enum::kAvgPooling == pool_type) {
-      pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true);
+      pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data,
+                      true, count_include_pad);
     } else if (pool_enum::kSumPooling == pool_type) {
       pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else if (pool_enum::kLpPooling == pool_type) {
@@ -690,7 +715,8 @@ inline void pool(mshadow::Stream<cpu>* s, const DType* in_data, const TShape& is
     if (pool_enum::kMaxPooling == pool_type) {
       pool_max_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else if (pool_enum::kAvgPooling == pool_type) {
-      pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true);
+      pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data,
+                      true, count_include_pad);
     } else if (pool_enum::kSumPooling == pool_type) {
       pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else if (pool_enum::kLpPooling == pool_type) {
@@ -723,7 +749,8 @@ template<typename DType, int p>
 inline void unpool(mshadow::Stream<cpu>* s, const DType* out_grad, const DType* in_data,
                    const DType* out_data, const TShape& ishape, const TShape& oshape,
                    const TShape& kernel, const TShape& pad, const TShape& stride,
-                   const int pool_type, OpReqType req_type, DType* in_grad, const int p_value = 2) {
+                   const int pool_type, OpReqType req_type, DType* in_grad,
+                   const bool count_include_pad) {
   if (mxnet::kNullOp == req_type) return;
   if (mxnet::kAddTo != req_type) {
     mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, ishape.Size(), in_grad);
@@ -733,7 +760,7 @@ inline void unpool(mshadow::Stream<cpu>* s, const DType* out_grad, const DType*
       unpool_max_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kAvgPooling == pool_type) {
       unpool_sum_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad,
-                        true);
+                        true, count_include_pad);
     } else if (pool_enum::kSumPooling == pool_type) {
       unpool_sum_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kLpPooling == pool_type) {
@@ -747,7 +774,7 @@ inline void unpool(mshadow::Stream<cpu>* s, const DType* out_grad, const DType*
       unpool_max_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kAvgPooling == pool_type) {
       unpool_sum_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad,
-                        true);
+                        true, count_include_pad);
     } else if (pool_enum::kSumPooling == pool_type) {
       unpool_sum_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kLpPooling == pool_type) {
@@ -761,7 +788,7 @@ inline void unpool(mshadow::Stream<cpu>* s, const DType* out_grad, const DType*
       unpool_max_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kAvgPooling == pool_type) {
       unpool_sum_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad,
-                        true);
+                        true, count_include_pad);
     } else if (pool_enum::kSumPooling == pool_type) {
       unpool_sum_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kLpPooling == pool_type) {
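
The unpool_* changes mirror the forward pass: each input element receives
out_grad / pool_size, with pool_size recomputed from the clipped window when
padding is excluded. A hedged autograd check of that symmetry, again assuming
a build with this commit:

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.ones((1, 1, 4))
    x.attach_grad()
    with autograd.record():
        y = mx.nd.Pooling(x, kernel=(2,), pad=(1,), stride=(2,),
                          pool_type='avg', count_include_pad=False)
    y.backward()
    # each element falls in exactly one window; the two edge windows hold a
    # single valid element, so their divisor is 1 instead of 2
    print(x.grad)  # [[[1.0, 0.5, 0.5, 1.0]]]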
diff --git a/src/operator/nn/pooling-inl.h b/src/operator/nn/pooling-inl.h
index 9c7b1af..ad74a8f 100644
--- a/src/operator/nn/pooling-inl.h
+++ b/src/operator/nn/pooling-inl.h
@@ -52,6 +52,7 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
   bool global_pool;
   bool cudnn_off;
   dmlc::optional<int> p_value;
+  dmlc::optional<bool> count_include_pad;
   DMLC_DECLARE_PARAMETER(PoolingParam) {
     DMLC_DECLARE_FIELD(kernel).set_default(TShape())  // add default value here
     .enforce_nonzero()
@@ -83,7 +84,13 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
     .describe("Pad for pooling: (y, x) or (d, y, x). Defaults to no padding.");
 
     DMLC_DECLARE_FIELD(p_value).set_default(dmlc::optional<int>())
-    .describe("Value of p for Lp pooling, can be 1 or 2, required for Lp Pooling");
+    .describe("Value of p for Lp pooling, can be 1 or 2, required for Lp Pooling.");
+
+    DMLC_DECLARE_FIELD(count_include_pad).set_default(dmlc::optional<bool>())
+    .describe("Only used for AvgPool, specify whether to count padding elements for average"
+              "calculation. For example, with a 5*5 kernel on a 3*3 corner of a image,"
+              "the sum of the 9 valid elements will be divided by 25 if this is set to true,"
+              "or it will be divided by 9 if this is set to false. Defaults to true.");
   }
 
   bool operator==(const PoolingParam& other) const {
@@ -94,7 +101,8 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
            this->pooling_convention == other.pooling_convention &&
            this->global_pool        == other.global_pool &&
            this->cudnn_off          == other.cudnn_off &&
-           this->p_value            == other.p_value;
+           this->p_value            == other.p_value &&
+           this->count_include_pad  == other.count_include_pad;
   }
 };
 
@@ -114,6 +122,7 @@ struct hash<mxnet::op::PoolingParam> {
     ret = dmlc::HashCombine(ret, val.global_pool);
     ret = dmlc::HashCombine(ret, val.cudnn_off);
     ret = dmlc::HashCombine(ret, val.p_value);
+    ret = dmlc::HashCombine(ret, val.count_include_pad);
     return ret;
   }
 };
@@ -155,27 +164,29 @@ class PoolingOp {
     }
     const int p_value = (param_.pool_type == pool_enum::kLpPooling && param_.p_value.has_value()) ?
                         param_.p_value.value() : 1;
+    const bool count_include_pad = (param_.count_include_pad.has_value()) ?
+                                   param_.count_include_pad.value() : true;
     switch (p_value) {
       case 1:
         pool<DType, 1>(s, in_data.dptr<DType>(), in_data.shape_, out_data.shape_,
           kernel,
           padding,
           stride,
-          param_.pool_type, req, out_data.dptr<DType>());
+          param_.pool_type, req, out_data.dptr<DType>(), count_include_pad);
         break;
       case 2:
         pool<DType, 2>(s, in_data.dptr<DType>(), in_data.shape_, out_data.shape_,
           kernel,
           padding,
           stride,
-          param_.pool_type, req, out_data.dptr<DType>());
+          param_.pool_type, req, out_data.dptr<DType>(), count_include_pad);
         break;
       case 3:
         pool<DType, 3>(s, in_data.dptr<DType>(), in_data.shape_, out_data.shape_,
           kernel,
           padding,
           stride,
-          param_.pool_type, req, out_data.dptr<DType>());
+          param_.pool_type, req, out_data.dptr<DType>(), count_include_pad);
         break;
       default:
         LOG(FATAL) << "p value of " << p_value << " is not supported yet...";
@@ -203,6 +214,8 @@ class PoolingOp {
 
     const int p_value = (param_.pool_type == pool_enum::kLpPooling && param_.p_value.has_value()) ?
                         param_.p_value.value() : 1;
+    const bool count_include_pad = (param_.count_include_pad.has_value()) ?
+                                   param_.count_include_pad.value() : true;
     switch (p_value) {
       case 1:
         unpool<DType, 1>(s, out_grad.dptr<DType>(), in_data.dptr<DType>(), out_data.dptr<DType>(),
@@ -210,7 +223,7 @@ class PoolingOp {
            kernel,
            padding,
            stride,
-           param_.pool_type, req, in_grad.dptr<DType>());
+           param_.pool_type, req, in_grad.dptr<DType>(), count_include_pad);
         break;
       case 2:
         unpool<DType, 2>(s, out_grad.dptr<DType>(), in_data.dptr<DType>(), out_data.dptr<DType>(),
@@ -218,7 +231,7 @@ class PoolingOp {
            kernel,
            padding,
            stride,
-           param_.pool_type, req, in_grad.dptr<DType>());
+           param_.pool_type, req, in_grad.dptr<DType>(), count_include_pad);
         break;
       case 3:
         unpool<DType, 3>(s, out_grad.dptr<DType>(), in_data.dptr<DType>(), out_data.dptr<DType>(),
@@ -226,7 +239,7 @@ class PoolingOp {
            kernel,
            padding,
            stride,
-           param_.pool_type, req, in_grad.dptr<DType>());
+           param_.pool_type, req, in_grad.dptr<DType>(), count_include_pad);
         break;
       default:
         LOG(FATAL) << "p value of " << p_value << " is not supported yet...";
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 9d33541..ed4aaa4 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -740,8 +740,8 @@ def test_pooling_with_type():
 
 @with_seed()
 def test_pooling_versions():
-    def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, stride,
-                                     pooling_convention='valid', global_pool=False, p_value=2):
+    def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, stride, pooling_convention='valid',
+                                     global_pool=False, p_value=2, count_include_pad=True, tol=None):
         ctx_list = []
         sym_list = []
         # PoolingV1 cpu
@@ -765,61 +765,69 @@ def test_pooling_versions():
             ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
             if not global_pool:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                               pooling_convention=pooling_convention, name='pool', p_value=p_value))
+                                               pooling_convention=pooling_convention, name='pool',
+                                               p_value=p_value, count_include_pad=count_include_pad))
             else:
-                sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, name='pool', p_value=p_value))
+                sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, name='pool',
+                                               p_value=p_value, count_include_pad=count_include_pad))
         # Pooling gpu
         if 'pool_gpu' in pool_op_list:
             ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
             if not global_pool:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                               pooling_convention=pooling_convention, cudnn_off=True, name='pool', p_value=p_value))
+                                               pooling_convention=pooling_convention, cudnn_off=True, name='pool',
+                                               p_value=p_value, count_include_pad=count_include_pad))
             else:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, cudnn_off=True,
-                                               name='pool', p_value=p_value))
+                                               name='pool', p_value=p_value, count_include_pad=count_include_pad))
         # CuDNNPooling
         if 'pool_cudnn' in pool_op_list:
             ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
             if not global_pool:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                               pooling_convention=pooling_convention, p_value=p_value, cudnn_off=False, name='pool'))
+                                               pooling_convention=pooling_convention, p_value=p_value, cudnn_off=False,
+                                               name='pool', count_include_pad=count_include_pad))
             else:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, p_value=p_value,
-                                               cudnn_off=False, name='pool'))
-        check_consistency(sym_list, ctx_list)
+                                               cudnn_off=False, name='pool', count_include_pad=count_include_pad))
+        check_consistency(sym_list, ctx_list, equal_nan=(not count_include_pad), tol=tol)
 
-    def test_1d_pooling(pool_type, p_value=2):
+    def test_1d_pooling(pool_type, p_value=2, count_include_pad=True):
         data = (2, 3, 20)
         kernel = (4,)
         pad = (0,)
         stride = (1,)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False, p_value=p_value)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (2,)
         stride = (2,)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False, p_value=p_value)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (0,)
         stride = (1,)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False, p_value=p_value)
+                                     pooling_convention='full', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (2,)
         stride = (2,)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False, p_value=p_value)
+                                     pooling_convention='full', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     global_pool=True, p_value=p_value)
+                                     global_pool=True, p_value=p_value, count_include_pad=count_include_pad)
 
-    def test_2d_pooling(pool_type, p_value=2):
+    def test_2d_pooling(pool_type, p_value=2, count_include_pad=True):
         data = (2, 3, 20, 20)
         kernel = (4, 5)
         pad = (0, 0)
@@ -831,14 +839,15 @@ def test_pooling_versions():
         else:
             test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                          data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                         pooling_convention='valid', global_pool=False)
+                                         pooling_convention='valid', global_pool=False, count_include_pad=count_include_pad)
 
         # pool_v1 has bugs when pad is not 0, do not test PoolingV1 here
         pad = (2, 3)
         stride = (2, 3)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False, p_value=p_value)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (0, 0)
         stride = (1, 1)
@@ -847,16 +856,24 @@ def test_pooling_versions():
                                          data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
                                          pooling_convention='full', global_pool=False, p_value=p_value)
         else:
-            test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
-                                         data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                         pooling_convention='full', global_pool=False)
+            if count_include_pad:
+                test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
+                                             data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                             pooling_convention='full', global_pool=False,
+                                             count_include_pad=count_include_pad)
+            else:
+                test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
+                                             data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                             pooling_convention='full', global_pool=False,
+                                             count_include_pad=count_include_pad)
 
         # pool_v1 has bugs when pad is not 0, do not test PoolingV1 here
         pad = (2, 3)
         stride = (2, 3)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False, p_value=p_value)
+                                     pooling_convention='full', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         if pool_type == 'lp':
             test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
@@ -865,55 +882,62 @@ def test_pooling_versions():
         else:
             test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                          data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                         global_pool=True)
+                                         global_pool=True, count_include_pad=count_include_pad)
 
-    def test_3d_pooling(pool_type, p_value=2):
+    def test_3d_pooling(pool_type, p_value=2, count_include_pad=True):
         data = (2, 3, 20, 20, 20)
         kernel = (4, 5, 3)
         pad = (0, 0, 0)
         stride = (1, 1, 1)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False, p_value=p_value)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (2, 3, 3)
         stride = (2, 3, 1)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False, p_value=p_value)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (0, 0, 0)
         stride = (1, 1, 1)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False, p_value=p_value)
+                                     pooling_convention='full', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         pad = (2, 3, 3)
         stride = (2, 3, 1)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False, p_value=p_value)
+                                     pooling_convention='full', global_pool=False, p_value=p_value,
+                                     count_include_pad=count_include_pad)
 
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     global_pool=True, p_value=p_value)
+                                     global_pool=True, p_value=p_value, count_include_pad=count_include_pad)
 
     test_1d_pooling('max')
-    test_1d_pooling('avg')
+    test_1d_pooling('avg', count_include_pad=True)
+    test_1d_pooling('avg', count_include_pad=False)
     test_1d_pooling('sum')
     test_1d_pooling('lp', p_value=1)
     test_1d_pooling('lp', p_value=2)
     test_1d_pooling('lp', p_value=3)
 
     test_2d_pooling('max')
-    test_2d_pooling('avg')
+    test_2d_pooling('avg', count_include_pad=True)
+    test_2d_pooling('avg', count_include_pad=False)
     test_2d_pooling('sum')
     test_2d_pooling('lp', p_value=1)
     test_2d_pooling('lp', p_value=2)
     test_2d_pooling('lp', p_value=3)
 
     test_3d_pooling('max')
-    test_3d_pooling('avg')
+    test_3d_pooling('avg', count_include_pad=True)
+    test_3d_pooling('avg', count_include_pad=False)
     test_3d_pooling('sum')
     test_3d_pooling('lp', p_value=1)
     test_3d_pooling('lp', p_value=2)
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 451fde2..e9259fd 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -470,6 +470,7 @@ def test_pool():
         nn.MaxPool1D(3),
         nn.MaxPool1D(3, 2),
         nn.AvgPool1D(),
+        nn.AvgPool1D(count_include_pad=False),
         nn.GlobalAvgPool1D(),
         ]
     for layer in layers1d:
@@ -481,6 +482,7 @@ def test_pool():
         nn.MaxPool2D((3, 3)),
         nn.MaxPool2D(3, 2),
         nn.AvgPool2D(),
+        nn.AvgPool2D(count_include_pad=False),
         nn.GlobalAvgPool2D(),
         ]
     for layer in layers2d:
@@ -491,6 +493,7 @@ def test_pool():
         nn.MaxPool3D((3, 3, 3)),
         nn.MaxPool3D(3, 2),
         nn.AvgPool3D(),
+        nn.AvgPool3D(count_include_pad=False),
         nn.GlobalAvgPool3D(),
         ]
     for layer in layers3d:

-- 
To stop receiving notification emails like this one, please contact
haibin@apache.org.