Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/18 17:49:45 UTC

[GitHub] eric-haibin-lin closed pull request #10780: [MXNET-375] Lp Pooling and Global Lp Pooling

URL: https://github.com/apache/incubator-mxnet/pull/10780

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
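
For context, a sketch of the math the new kernels implement (inferred from the a_pow_p, a_root_p, and lp_grad helpers in the diff): over each pooling window W, Lp pooling computes

    y = \left( \sum_{i \in W} x_i^p \right)^{1/p},
    \qquad
    \frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y} \cdot \left( \frac{x_i}{y} \right)^{p-1}

with dedicated specializations for p = 1, 2, and 3 so the common cases avoid a general power call.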


diff --git a/src/operator/nn/pool.cuh b/src/operator/nn/pool.cuh
index 0e9cff0c51e..9d004d295be 100644
--- a/src/operator/nn/pool.cuh
+++ b/src/operator/nn/pool.cuh
@@ -80,7 +80,9 @@
 
 #include <mxnet/base.h>
 #include <mxnet/operator.h>
+#include "./pool_utils.h"
 #include "../mxnet_op.h"
+#include "../mshadow_op.h"
 #include "../../common/cuda_utils.h"
 
 namespace mxnet {
@@ -208,27 +210,26 @@ __global__ void pool_max_3d_gpu_kernel(const int nthreads, const DType* in_data,
  * \brief avg/sum pooling gpu kernel for 1-D images.
  * Do not call this kernel directly. Use the interface pool().
  */
-template <typename DType>
+template <typename DType, int p = 1>
 __global__ void pool_sum_1d_gpu_kernel(const int nthreads, const DType* in_data, const int channels,
                                        const int width, const int pooled_width, const int kernel_w,
-                                       const int stride_w, const int pad_w,
-                                       DType* out_data, bool getAvg = false) {
+                                       const int stride_w, const int pad_w, DType* out_data,
+                                       const bool getAvg = false) {
   CUDA_KERNEL_LOOP(index, nthreads) {
-	  const int pw = index % pooled_width;
-	  const int c = (index / pooled_width) % channels;
-	  const int n = index / pooled_width / channels;
-	  int wstart = pw * stride_w - pad_w;
-	  int wend = min(wstart + kernel_w, width + pad_w);
-	  const int pool_size = (getAvg? (wend - wstart) : 1);
-	  wstart = max(wstart, 0);
-	  wend = min(wend, width);
-	  DType sum = 0;
-	  const DType* out_slice =
-	 		in_data + (n * channels + c) * width;
+    const int pw = index % pooled_width;
+    const int c = (index / pooled_width) % channels;
+    const int n = index / pooled_width / channels;
+    int wstart = pw * stride_w - pad_w;
+    int wend = min(wstart + kernel_w, width + pad_w);
+    const int pool_size = (getAvg? (wend - wstart) : 1);
+    wstart = max(wstart, 0);
+    wend = min(wend, width);
+    DType sum = 0;
+    const DType* out_slice = in_data + (n * channels + c) * width;
     for (int w = wstart; w < wend; ++w) {
-      sum += out_slice[w];
+      sum += a_pow_p<DType, p>::Map(out_slice[w]) / pool_size;
     }
-    out_data[index] = sum / pool_size;
+    out_data[index] = a_root_p<DType, p>::Map(sum);
   }
 }
 
@@ -236,37 +237,36 @@ __global__ void pool_sum_1d_gpu_kernel(const int nthreads, const DType* in_data,
  * \brief avg/sum pooling gpu kernel for 2-D images.
  * Do not call this kernel directly. Use the interface pool().
  */
-template <typename DType>
+template <typename DType, int p = 1>
 __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data, const int channels,
                                        const int height, const int width,
                                        const int pooled_height, const int pooled_width,
                                        const int kernel_h, const int kernel_w,
                                        const int stride_h, const int stride_w,
-                                       const int pad_h, const int pad_w,
-                                       DType* out_data, bool getAvg = false) {
+                                       const int pad_h, const int pad_w, DType* out_data,
+                                       const bool getAvg = false) {
   CUDA_KERNEL_LOOP(index, nthreads) {
-	  const int pw = index % pooled_width;
-	  const int ph = (index / pooled_width) % pooled_height;
-	  const int c = (index / pooled_width / pooled_height) % channels;
-	  const int n = index / pooled_width / pooled_height / channels;
-	  int hstart = ph * stride_h - pad_h;
-	  int wstart = pw * stride_w - pad_w;
-	  int hend = min(hstart + kernel_h, height + pad_h);
-	  int wend = min(wstart + kernel_w, width + pad_w);
-	  const int pool_size = (getAvg? (hend - hstart) * (wend - wstart) : 1);
-	  hstart = max(hstart, 0);
-	  wstart = max(wstart, 0);
-	  hend = min(hend, height);
-	  wend = min(wend, width);
-	  DType sum = 0;
-	  const DType* out_slice =
-	 		in_data + (n * channels + c) * height * width;
-	  for (int h = hstart; h < hend; ++h) {
-		  for (int w = wstart; w < wend; ++w) {
-		    sum += out_slice[h * width + w];
-		  }
-	  }
-    out_data[index] = sum / pool_size;
+    const int pw = index % pooled_width;
+    const int ph = (index / pooled_width) % pooled_height;
+    const int c = (index / pooled_width / pooled_height) % channels;
+    const int n = index / pooled_width / pooled_height / channels;
+    int hstart = ph * stride_h - pad_h;
+    int wstart = pw * stride_w - pad_w;
+    int hend = min(hstart + kernel_h, height + pad_h);
+    int wend = min(wstart + kernel_w, width + pad_w);
+    const int pool_size = (getAvg? (hend - hstart) * (wend - wstart) : 1);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+    hend = min(hend, height);
+    wend = min(wend, width);
+    DType sum = 0;
+    const DType* out_slice = in_data + (n * channels + c) * height * width;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        sum += a_pow_p<DType, p>::Map(out_slice[h * width + w]) / pool_size;
+      }
+    }
+    out_data[index] = a_root_p<DType, p>::Map(sum);
   }
 }
 
@@ -274,7 +274,7 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data,
  * \brief avg/sum pooling gpu kernel for 3-D images.
  * Do not call this kernel directly. Use the interface pool().
  */
-template <typename DType>
+template <typename DType, int p = 1>
 __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data, const int channels,
                                        const int depth, const int height, const int width,
                                        const int pooled_depth, const int pooled_height,
@@ -282,37 +282,36 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data,
                                        const int kernel_h, const int kernel_w,
                                        const int stride_d, const int stride_h, const int stride_w,
                                        const int pad_d, const int pad_h, const int pad_w,
-                                       DType* out_data, bool getAvg = false) {
+                                       DType* out_data, const bool getAvg = false) {
   CUDA_KERNEL_LOOP(index, nthreads) {
-	  const int pw = index % pooled_width;
-	  const int ph = (index / pooled_width) % pooled_height;
+    const int pw = index % pooled_width;
+    const int ph = (index / pooled_width) % pooled_height;
     const int pd = (index / pooled_width / pooled_height) % pooled_depth;
-	  const int c = (index / pooled_width / pooled_height / pooled_depth) % channels;
-	  const int n = index / pooled_width / pooled_height / pooled_depth / channels;
+    const int c = (index / pooled_width / pooled_height / pooled_depth) % channels;
+    const int n = index / pooled_width / pooled_height / pooled_depth / channels;
     int dstart = pd * stride_d - pad_d;
-	  int hstart = ph * stride_h - pad_h;
-	  int wstart = pw * stride_w - pad_w;
+    int hstart = ph * stride_h - pad_h;
+    int wstart = pw * stride_w - pad_w;
     int dend = min(dstart + kernel_d, depth + pad_d);
-	  int hend = min(hstart + kernel_h, height + pad_h);
-	  int wend = min(wstart + kernel_w, width + pad_w);
-	  const int pool_size = (getAvg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
+    int hend = min(hstart + kernel_h, height + pad_h);
+    int wend = min(wstart + kernel_w, width + pad_w);
+    const int pool_size = (getAvg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
     dstart = max(dstart, 0);
-	  hstart = max(hstart, 0);
-	  wstart = max(wstart, 0);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
     dend = min(dend, depth);
-	  hend = min(hend, height);
-	  wend = min(wend, width);
-	  DType sum = 0;
-	  const DType* out_slice =
-	 		in_data + (n * channels + c) * depth * height * width;
+    hend = min(hend, height);
+    wend = min(wend, width);
+    DType sum = 0;
+    const DType* out_slice = in_data + (n * channels + c) * depth * height * width;
     for (int d = dstart; d < dend; ++d) {
       for (int h = hstart; h < hend; ++h) {
         for (int w = wstart; w < wend; ++w) {
-          sum += out_slice[(d * height + h) * width + w];
+          sum += a_pow_p<DType, p>::Map(out_slice[(d * height + h) * width + w]) / pool_size;
         }
       }
     }
-    out_data[index] = sum / pool_size;
+    out_data[index] = a_root_p<DType, p>::Map(sum);
   }
 }
 
@@ -482,34 +481,38 @@ __global__ void unpool_max_3d_gpu_kernel(const int nthreads, const DType* out_gr
  * \brief avg/sum unpooling gpu kernel for 1-D images.
  * Do not call this kernel directly. Use the interface unpool().
  */
-template<typename DType>
+template<typename DType, int p = 1>
 __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_grad,
+                                         const DType* in_data, const DType* out_data,
                                          const int channels, const int width,
                                          const int pooled_width, const int kernel_w,
-                                         const int stride_w, const int pad_w,
-                                         DType* in_grad, bool isAvg = false) {
+                                         const int stride_w, const int pad_w, DType* in_grad,
+                                         const bool isAvg = false) {
   // index is the input image index in NCW
   CUDA_KERNEL_LOOP(index, nthreads) {
-	  // find out the local index
-	  // find out the local offset
-	  const int w = index % width + pad_w;
-	  const int c = (index / width) % channels;
-	  const int n = index / width / channels;
-	  const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-	  const int pwend = min(w / stride_w + 1, pooled_width);
-	  DType gradient = 0;
-	  const DType* out_grad_slice =
+    // find out the local index
+    // find out the local offset
+    const int w = index % width + pad_w;
+    const int c = (index / width) % channels;
+    const int n = index / width / channels;
+    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
+    const int pwend = min(w / stride_w + 1, pooled_width);
+    DType gradient = 0;
+    const DType* out_grad_slice =
       out_grad + (n * channels + c) * pooled_width;
+    const DType* out_data_slice =
+      out_data + (n * channels + c) * pooled_width;
     for (int pw = pwstart; pw < pwend; ++pw) {
       // figure out the pooling size
       int wstart = pw * stride_w - pad_w;
       int wend = min(wstart + kernel_w, width + pad_w);
       int pool_size = (isAvg? (wend - wstart) : 1);
-      gradient += out_grad_slice[pw] / pool_size;
+      gradient +=
+        lp_grad<DType, p>::Map(out_grad_slice[pw], in_data[index], out_data_slice[pw]) / pool_size;
     }
     // if req=kWriteTo, in_grad has already been assigned zero values in unpool()
     // use "+=" here instead of "=" to accommodate when req=kAddTo
-	  in_grad[index] += gradient;
+    in_grad[index] += gradient;
   }
 }
 
@@ -517,43 +520,50 @@ __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_gr
  * \brief avg/sum unpooling gpu kernel for 2-D images.
  * Do not call this kernel directly. Use the interface unpool().
  */
-template<typename DType>
+template<typename DType, int p = 1>
 __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_grad,
+                                         const DType* in_data, const DType* out_data,
                                          const int channels, const int height, const int width,
                                          const int pooled_height, const int pooled_width,
                                          const int kernel_h, const int kernel_w,
                                          const int stride_h, const int stride_w,
-                                         const int pad_h, const int pad_w,
-                                         DType* in_grad, bool isAvg = false) {
+                                         const int pad_h, const int pad_w, DType* in_grad,
+                                         const bool isAvg = false) {
   // index is the input image index in NCHW
   CUDA_KERNEL_LOOP(index, nthreads) {
-	  // find out the local index
-	  // find out the local offset
-	  const int w = index % width + pad_w;
-	  const int h = (index / width) % height + pad_h;
-	  const int c = (index / width / height) % channels;
-	  const int n = index / width / height / channels;
-	  const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-	  const int phend = min(h / stride_h + 1, pooled_height);
-	  const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-	  const int pwend = min(w / stride_w + 1, pooled_width);
-	  DType gradient = 0;
-	  const DType* out_grad_slice =
+    // find out the local index
+    // find out the local offset
+    const int w = index % width + pad_w;
+    const int h = (index / width) % height + pad_h;
+    const int c = (index / width / height) % channels;
+    const int n = index / width / height / channels;
+    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
+    const int phend = min(h / stride_h + 1, pooled_height);
+    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
+    const int pwend = min(w / stride_w + 1, pooled_width);
+    DType gradient = 0;
+    const DType* out_grad_slice =
       out_grad + (n * channels + c) * pooled_height * pooled_width;
-	  for (int ph = phstart; ph < phend; ++ph) {
-	 	  for (int pw = pwstart; pw < pwend; ++pw) {
-		    // figure out the pooling size
-			  int hstart = ph * stride_h - pad_h;
-			  int wstart = pw * stride_w - pad_w;
-			  int hend = min(hstart + kernel_h, height + pad_h);
-			  int wend = min(wstart + kernel_w, width + pad_w);
-			  int pool_size = (isAvg? (hend - hstart) * (wend - wstart) : 1);
-			  gradient += out_grad_slice[ph * pooled_width + pw] / pool_size;
-		  }
-	  }
+    const DType* out_data_slice =
+      out_data + (n * channels + c) * pooled_height * pooled_width;
+    for (int ph = phstart; ph < phend; ++ph) {
+      for (int pw = pwstart; pw < pwend; ++pw) {
+        // figure out the pooling size
+        int hstart = ph * stride_h - pad_h;
+        int wstart = pw * stride_w - pad_w;
+        int hend = min(hstart + kernel_h, height + pad_h);
+        int wend = min(wstart + kernel_w, width + pad_w);
+        int pool_size = (isAvg? (hend - hstart) * (wend - wstart) : 1);
+        int out_index = ph * pooled_width + pw;
+        gradient +=
+          lp_grad<DType, p>::Map(out_grad_slice[out_index],
+                                 in_data[index],
+                                 out_data_slice[out_index]) / pool_size;
+      }
+    }
     // if req=kWriteTo, in_grad has already been assigned zero values in unpool()
     // use "+=" here instead of "=" to accommodate when req=kAddTo
-	  in_grad[index] += gradient;
+    in_grad[index] += gradient;
   }
 }
 
@@ -561,33 +571,36 @@ __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_gr
  * \brief avg/sum unpooling gpu kernel for 3-D images.
  * Do not call this kernel directly. Use the interface unpool().
  */
-template<typename DType>
+template<typename DType, int p = 1>
 __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_grad,
+                                         const DType* in_data, const DType* out_data,
                                          const int channels, const int depth, const int height,
                                          const int width, const int pooled_depth,
                                          const int pooled_height, const int pooled_width,
                                          const int kernel_d, const int kernel_h,
                                          const int kernel_w, const int stride_d, const int stride_h,
                                          const int stride_w, const int pad_d, const int pad_h,
-                                         const int pad_w, DType* in_grad, bool isAvg = false) {
+                                         const int pad_w, DType* in_grad, const bool isAvg = false) {
   // index is the input image index in NCDHW
   CUDA_KERNEL_LOOP(index, nthreads) {
-	  // find out the local index
-	  // find out the local offset
-	  const int w = index % width + pad_w;
-	  const int h = (index / width) % height + pad_h;
+    // find out the local index
+    // find out the local offset
+    const int w = index % width + pad_w;
+    const int h = (index / width) % height + pad_h;
     const int d = (index / width / height) % depth + pad_d;
-	  const int c = (index / width / height / depth) % channels;
-	  const int n = index / width / height / depth / channels;
+    const int c = (index / width / height / depth) % channels;
+    const int n = index / width / height / depth / channels;
     const int pdstart = (d < kernel_d) ? 0 : (d - kernel_d) / stride_d + 1;
     const int pdend = min(d / stride_d + 1, pooled_depth);
-	  const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-	  const int phend = min(h / stride_h + 1, pooled_height);
-	  const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-	  const int pwend = min(w / stride_w + 1, pooled_width);
-	  DType gradient = 0;
-	  const DType* out_grad_slice =
+    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
+    const int phend = min(h / stride_h + 1, pooled_height);
+    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
+    const int pwend = min(w / stride_w + 1, pooled_width);
+    DType gradient = 0;
+    const DType* out_grad_slice =
       out_grad + (n * channels + c) * pooled_depth * pooled_height * pooled_width;
+    const DType* out_data_slice =
+      out_data + (n * channels + c) * pooled_depth * pooled_height * pooled_width;
     for (int pd = pdstart; pd < pdend; ++pd) {
       for (int ph = phstart; ph < phend; ++ph) {
         for (int pw = pwstart; pw < pwend; ++pw) {
@@ -599,13 +612,16 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr
           int hend = min(hstart + kernel_h, height + pad_h);
           int wend = min(wstart + kernel_w, width + pad_w);
           int pool_size = (isAvg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
-          gradient += out_grad_slice[(pd * pooled_height + ph) * pooled_width + pw] / pool_size;
+          int out_index = (pd * pooled_height + ph) * pooled_width + pw;
+          gradient += lp_grad<DType, p>::Map(out_grad_slice[out_index],
+                                             in_data[index],
+                                             out_data_slice[out_index]) / pool_size;
         }
       }
     }
     // if req=kWriteTo, in_grad has already been assigned zero values in unpool()
     // use "+=" here instead of "=" to accommodate when req=kAddTo
-	  in_grad[index] += gradient;
+    in_grad[index] += gradient;
   }
 }
 
@@ -621,8 +637,9 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr
  * \param pool_type supported pooling type: max, avg, sum
  * \param req_type operator request type, only support kWriteTo for now
  * \param out_data pointer of the output tensor data in the format of NCW, NCHW, or NCDHW
+ * \param p_value value of p for Lp pooling
  */
-template<typename DType>
+template<typename DType, int p>
 inline void pool(mshadow::Stream<gpu>* s, const DType* in_data, const TShape& ishape,
                  const TShape& oshape, const TShape& kernel, const TShape& pad,
                  const TShape& stride, const int pool_type, OpReqType req_type,
@@ -651,6 +668,13 @@ inline void pool(mshadow::Stream<gpu>* s, const DType* in_data, const TShape& is
                                    oshape.Size(), in_data, ishape[1], ishape[2], oshape[2],
                                    kernel[0], stride[0], pad[0], out_data);
       MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_1d_gpu_kernel);
+    } else if (pool_enum::kLpPooling == pool_type) {
+      // NOLINT_NEXT_LINE(whitespace/operators)
+      pool_sum_1d_gpu_kernel<DType, p><<<cuda_get_num_blocks(oshape.Size()), mshadow::cuda::kBaseThreadNum,
+                               0, mshadow::Stream<gpu>::GetStream(s)>>>(
+                                   oshape.Size(), in_data, ishape[1], ishape[2], oshape[2],
+                                   kernel[0], stride[0], pad[0], out_data);
+      MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_1d_gpu_kernel);
     } else {
       LOG(FATAL) << "Unknown pooling type " << pool_type;
     }
@@ -679,6 +703,14 @@ inline void pool(mshadow::Stream<gpu>* s, const DType* in_data, const TShape& is
                                    oshape[2], oshape[3], kernel[0], kernel[1],
                                    stride[0], stride[1], pad[0], pad[1], out_data);
       MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_2d_gpu_kernel);
+    } else if (pool_enum::kLpPooling == pool_type) {
+      // NOLINT_NEXT_LINE(whitespace/operators)
+      pool_sum_2d_gpu_kernel<DType, p><<<cuda_get_num_blocks(oshape.Size()), mshadow::cuda::kBaseThreadNum,
+                               0, mshadow::Stream<gpu>::GetStream(s)>>>(
+                                   oshape.Size(), in_data, ishape[1], ishape[2], ishape[3],
+                                   oshape[2], oshape[3], kernel[0], kernel[1],
+                                   stride[0], stride[1], pad[0], pad[1], out_data);
+      MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_2d_gpu_kernel);
     } else {
       LOG(FATAL) << "Unknown pooling type " << pool_type;
     }
@@ -710,6 +742,15 @@ inline void pool(mshadow::Stream<gpu>* s, const DType* in_data, const TShape& is
                                    kernel[1], kernel[2], stride[0], stride[1], stride[2],
                                    pad[0], pad[1], pad[2], out_data);
       MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_3d_gpu_kernel);
+    } else if (pool_enum::kLpPooling == pool_type) {
+      // NOLINT_NEXT_LINE(whitespace/operators)
+      pool_sum_3d_gpu_kernel<DType, p><<<cuda_get_num_blocks(oshape.Size()), mshadow::cuda::kBaseThreadNum,
+                               0, mshadow::Stream<gpu>::GetStream(s)>>>(
+                                   oshape.Size(), in_data, ishape[1], ishape[2], ishape[3],
+                                   ishape[4], oshape[2], oshape[3], oshape[4], kernel[0],
+                                   kernel[1], kernel[2], stride[0], stride[1], stride[2],
+                                   pad[0], pad[1], pad[2], out_data);
+      MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_3d_gpu_kernel);
     } else {
       LOG(FATAL) << "Unknown pooling type " << pool_type;
     }
@@ -730,8 +771,9 @@ inline void pool(mshadow::Stream<gpu>* s, const DType* in_data, const TShape& is
  * \param pool_type supported pooling type: max, avg, sum
  * \param req_type operator request type: kNullOp, kNullWriteInplace, kNullWriteTo, kNullAddTo
  * \param in_grad pointer of the gradient of the operator's input tensor
+ * \param p_value value of p for Lp pooling
  */
-template<typename DType>
+template<typename DType, int p>
 inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType* in_data,
                    const DType* out_data, const TShape& ishape, const TShape& oshape,
                    const TShape& kernel, const TShape& pad, const TShape& stride,
@@ -754,7 +796,7 @@ inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType*
       // NOLINT_NEXT_LINE(whitespace/operators)
       unpool_sum_1d_gpu_kernel<<<cuda_get_num_blocks(ishape.Size()), mshadow::cuda::kBaseThreadNum,
                                  0, mshadow::Stream<gpu>::GetStream(s)>>>(
-                                     ishape.Size(), out_grad,
+                                     ishape.Size(), out_grad, in_data, out_data,
                                      ishape[1], ishape[2], oshape[2], kernel[0],
                                      stride[0], pad[0], in_grad, true);
       MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_1d_gpu_kernel);
@@ -762,14 +804,22 @@ inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType*
       // NOLINT_NEXT_LINE(whitespace/operators)
       unpool_sum_1d_gpu_kernel<<<cuda_get_num_blocks(ishape.Size()), mshadow::cuda::kBaseThreadNum,
                                  0, mshadow::Stream<gpu>::GetStream(s)>>>(
-                                     ishape.Size(), out_grad,
+                                     ishape.Size(), out_grad, in_data, out_data,
+                                     ishape[1], ishape[2], oshape[2], kernel[0],
+                                     stride[0], pad[0], in_grad);
+      MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_1d_gpu_kernel);
+    } else if (pool_enum::kLpPooling == pool_type) {
+      // NOLINT_NEXT_LINE(whitespace/operators)
+      unpool_sum_1d_gpu_kernel<DType, p><<<cuda_get_num_blocks(ishape.Size()), mshadow::cuda::kBaseThreadNum,
+                                 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+                                     ishape.Size(), out_grad, in_data, out_data,
                                      ishape[1], ishape[2], oshape[2], kernel[0],
                                      stride[0], pad[0], in_grad);
       MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_1d_gpu_kernel);
     } else {
       LOG(FATAL) << "Unknown pooling type " << pool_type;
     }
-  } else  if (kernel.ndim() == 2) {
+  } else if (kernel.ndim() == 2) {
     if (pool_enum::kMaxPooling == pool_type) {
       // NOLINT_NEXT_LINE(whitespace/operators)
       unpool_max_2d_gpu_kernel<<<cuda_get_num_blocks(oshape.Size()), mshadow::cuda::kBaseThreadNum,
@@ -783,7 +833,7 @@ inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType*
       // NOLINT_NEXT_LINE(whitespace/operators)
       unpool_sum_2d_gpu_kernel<<<cuda_get_num_blocks(ishape.Size()), mshadow::cuda::kBaseThreadNum,
                                  0, mshadow::Stream<gpu>::GetStream(s)>>>(
-                                     ishape.Size(), out_grad,
+                                     ishape.Size(), out_grad, in_data, out_data,
                                      ishape[1], ishape[2], ishape[3],
                                      oshape[2], oshape[3], kernel[0], kernel[1],
                                      stride[0], stride[1], pad[0], pad[1], in_grad, true);
@@ -792,7 +842,16 @@ inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType*
       // NOLINT_NEXT_LINE(whitespace/operators)
       unpool_sum_2d_gpu_kernel<<<cuda_get_num_blocks(ishape.Size()), mshadow::cuda::kBaseThreadNum,
                                  0, mshadow::Stream<gpu>::GetStream(s)>>>(
-                                     ishape.Size(), out_grad,
+                                     ishape.Size(), out_grad, in_data, out_data,
+                                     ishape[1], ishape[2], ishape[3],
+                                     oshape[2], oshape[3], kernel[0], kernel[1],
+                                     stride[0], stride[1], pad[0], pad[1], in_grad);
+      MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_2d_gpu_kernel);
+    } else if (pool_enum::kLpPooling == pool_type) {
+      // NOLINT_NEXT_LINE(whitespace/operators)
+      unpool_sum_2d_gpu_kernel<DType, p><<<cuda_get_num_blocks(ishape.Size()), mshadow::cuda::kBaseThreadNum,
+                                 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+                                     ishape.Size(), out_grad, in_data, out_data,
                                      ishape[1], ishape[2], ishape[3],
                                      oshape[2], oshape[3], kernel[0], kernel[1],
                                      stride[0], stride[1], pad[0], pad[1], in_grad);
@@ -815,7 +874,7 @@ inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType*
       // NOLINT_NEXT_LINE(whitespace/operators)
       unpool_sum_3d_gpu_kernel<<<cuda_get_num_blocks(ishape.Size()), mshadow::cuda::kBaseThreadNum,
                                  0, mshadow::Stream<gpu>::GetStream(s)>>>(
-                                     ishape.Size(), out_grad,
+                                     ishape.Size(), out_grad, in_data, out_data,
                                      ishape[1], ishape[2], ishape[3], ishape[4],
                                      oshape[2], oshape[3], oshape[4], kernel[0], kernel[1],
                                      kernel[2], stride[0], stride[1], stride[2], pad[0], pad[1],
@@ -825,7 +884,17 @@ inline void unpool(mshadow::Stream<gpu>* s, const DType* out_grad, const DType*
       // NOLINT_NEXT_LINE(whitespace/operators)
       unpool_sum_3d_gpu_kernel<<<cuda_get_num_blocks(ishape.Size()), mshadow::cuda::kBaseThreadNum,
                                  0, mshadow::Stream<gpu>::GetStream(s)>>>(
-                                     ishape.Size(), out_grad,
+                                     ishape.Size(), out_grad, in_data, out_data,
+                                     ishape[1], ishape[2], ishape[3], ishape[4],
+                                     oshape[2], oshape[3], oshape[4], kernel[0], kernel[1],
+                                     kernel[2], stride[0], stride[1], stride[2], pad[0], pad[1],
+                                     pad[2], in_grad);
+      MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_3d_gpu_kernel);
+    } else if (pool_enum::kLpPooling == pool_type) {
+      // NOLINT_NEXT_LINE(whitespace/operators)
+      unpool_sum_3d_gpu_kernel<DType, p><<<cuda_get_num_blocks(ishape.Size()), mshadow::cuda::kBaseThreadNum,
+                                 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+                                     ishape.Size(), out_grad, in_data, out_data,
                                      ishape[1], ishape[2], ishape[3], ishape[4],
                                      oshape[2], oshape[3], oshape[4], kernel[0], kernel[1],
                                      kernel[2], stride[0], stride[1], stride[2], pad[0], pad[1],
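
In scalar form, the forward change in these kernels reduces to the loop below; this is a minimal standalone C++ sketch (names simplified, std::pow standing in for the a_pow_p / a_root_p helpers), not code from the patch. The same logic applies to the CPU functions in pool.h further down; for Lp pooling getAvg is false, so pool_size is 1 and no averaging is applied.

    #include <cmath>
    #include <vector>

    // Lp pooling over one 1-D window [wstart, wend), mirroring the inner loop of
    // pool_sum_1d_gpu_kernel / pool_sum_1d_cpu when pool_type is "lp".
    double lp_pool_window(const std::vector<double>& x, int wstart, int wend, int p) {
      double sum = 0.0;
      for (int w = wstart; w < wend; ++w) {
        sum += std::pow(x[w], p);      // a_pow_p<DType, p>::Map(x[w])
      }
      return std::pow(sum, 1.0 / p);   // a_root_p<DType, p>::Map(sum)
    }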
diff --git a/src/operator/nn/pool.h b/src/operator/nn/pool.h
index 79accb5d521..9fe43b2bd46 100644
--- a/src/operator/nn/pool.h
+++ b/src/operator/nn/pool.h
@@ -62,7 +62,9 @@
 #include <mxnet/base.h>
 #include <mxnet/operator.h>
 #include <algorithm>
+#include "./pool_utils.h"
 #include "../mxnet_op.h"
+#include "../mshadow_op.h"
 
 namespace mxnet {
 namespace op {
@@ -70,7 +72,7 @@ namespace op {
 namespace pool_enum {
 enum PoolingOpInputs {kData};
 enum PoolingOpOutputs {kOut, kMask};
-enum PoolingOpType {kMaxPooling, kAvgPooling, kSumPooling};
+enum PoolingOpType {kMaxPooling, kAvgPooling, kSumPooling, kLpPooling};
 enum PoolingOpPadConventionType {kValid, kFull};
 }  // namespace pool_enum
 
@@ -211,10 +213,10 @@ inline void pool_max_3d_cpu(const DType* in_data, const TShape& ishape, const TS
  * \brief avg/sum pooling cpu function for 1-D images.
  * Do not call this kernel directly. Use the interface pool().
  */
-template<typename DType>
+template<typename DType, int p = 1>
 inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape,
                             const TShape& kernel, const TShape& pad, const TShape& stride,
-                            DType* out_data, bool getAvg = false) {
+                            DType* out_data, const bool getAvg = false) {
   const int width = ishape[2];
   const int pooled_width = oshape[2];
   const int kernel_w = kernel[0];
@@ -227,14 +229,14 @@ inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TS
       for (int pw = 0; pw < pooled_width; ++pw) {
         int wstart = pw * stride_w - pad_w;
         int wend = std::min(wstart + kernel_w, width + pad_w);
-        int pool_size = (wend - wstart);
+        int pool_size = (getAvg ? (wend - wstart) : 1);
         wstart = std::max(wstart, 0);
         wend = std::min(wend, width);
         DType sum = 0;
         for (int w = wstart; w < wend; ++w) {
-          sum += in_data[w];
+          sum += a_pow_p<DType, p>::Map(in_data[w]) / pool_size;
         }
-        out_data[pw] = (getAvg? sum/pool_size : sum);
+        out_data[pw] = a_root_p<DType, p>::Map(sum);
       }
       in_data += in_data_offset;
       out_data += out_data_offset;
@@ -246,10 +248,10 @@ inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TS
  * \brief avg/sum pooling cpu function for 2-D images.
  * Do not call this kernel directly. Use the interface pool().
  */
-template<typename DType>
+template<typename DType, int p = 1>
 inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape,
                             const TShape& kernel, const TShape& pad, const TShape& stride,
-                            DType* out_data, bool getAvg = false) {
+                            DType* out_data, const bool getAvg = false) {
   const int height = ishape[2], width = ishape[3];
   const int pooled_height = oshape[2], pooled_width = oshape[3];
   const int kernel_h = kernel[0], kernel_w = kernel[1];
@@ -265,7 +267,7 @@ inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TS
           int wstart = pw * stride_w - pad_w;
           int hend = std::min(hstart + kernel_h, height + pad_h);
           int wend = std::min(wstart + kernel_w, width + pad_w);
-          int pool_size = (hend - hstart) * (wend - wstart);
+          int pool_size = (getAvg ? (hend - hstart) * (wend - wstart) : 1);
           hstart = std::max(hstart, 0);
           wstart = std::max(wstart, 0);
           hend = std::min(hend, height);
@@ -273,10 +275,10 @@ inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TS
           DType sum = 0;
           for (int h = hstart; h < hend; ++h) {
             for (int w = wstart; w < wend; ++w) {
-              sum += in_data[h*width+w];
+              sum += a_pow_p<DType, p>::Map(in_data[h*width+w]) / pool_size;
             }
           }
-          out_data[ph*pooled_width+pw] = (getAvg? sum/pool_size : sum);
+          out_data[ph*pooled_width+pw] = a_root_p<DType, p>::Map(sum);
         }
       }
       in_data += in_data_offset;
@@ -289,10 +291,10 @@ inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TS
  * \brief avg/sum pooling cpu function for 3-D images.
  * Do not call this kernel directly. Use the interface pool().
  */
-template<typename DType>
+template<typename DType, int p = 1>
 inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape,
                             const TShape& kernel, const TShape& pad, const TShape& stride,
-                            DType* out_data, bool getAvg = false) {
+                            DType* out_data, const bool getAvg = false) {
   const int depth = ishape[2], height = ishape[3], width = ishape[4];
   const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4];
   const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2];
@@ -311,7 +313,7 @@ inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TS
             int dend = std::min(dstart + kernel_d, depth + pad_d);
             int hend = std::min(hstart + kernel_h, height + pad_h);
             int wend = std::min(wstart + kernel_w, width + pad_w);
-            int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+            int pool_size = (getAvg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
             dstart = std::max(dstart, 0);
             hstart = std::max(hstart, 0);
             wstart = std::max(wstart, 0);
@@ -322,11 +324,11 @@ inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TS
             for (int d = dstart; d < dend; ++d) {
               for (int h = hstart; h < hend; ++h) {
                 for (int w = wstart; w < wend; ++w) {
-                  sum += in_data[(d*height+h)*width+w];
+                  sum += a_pow_p<DType, p>::Map(in_data[(d*height+h)*width+w]) / pool_size;
                 }
               }
             }
-            out_data[(pd*pooled_height+ph)*pooled_width+pw] = (getAvg? sum/pool_size : sum);
+            out_data[(pd*pooled_height+ph)*pooled_width+pw] = a_root_p<DType, p>::Map(sum);
           }
         }
       }
@@ -504,11 +506,11 @@ inline void unpool_max_3d_cpu(const DType* out_grad, const DType* in_data,
  * \brief avg/sum unpooling cpu function for 1-D images.
  * Do not call this kernel directly. Use the interface unpool().
  */
-template<typename DType>
-inline void unpool_sum_1d_cpu(const DType* out_grad, const TShape& ishape,
-                              const TShape& oshape, const TShape& kernel,
+template<typename DType, int p = 1>
+inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data,
+                              const TShape& ishape, const TShape& oshape, const TShape& kernel,
                               const TShape& pad, const TShape& stride,
-                              DType* in_grad, bool isAvg = false) {
+                              DType* in_grad, const bool isAvg = false) {
   const int width = ishape[2];
   const int pooled_width = oshape[2];
   const int kernel_w = kernel[0];
@@ -521,18 +523,17 @@ inline void unpool_sum_1d_cpu(const DType* out_grad, const TShape& ishape,
       for (int pw = 0; pw < pooled_width; ++pw) {
         int wstart = pw * stride_w - pad_w;
         int wend = std::min(wstart + kernel_w, width + pad_w);
-        int pool_size = 1;
-        if (isAvg) {
-          pool_size = wend - wstart;
-        }
+        int pool_size = (isAvg ? (wend - wstart) : 1);
         wstart = std::max(wstart, 0);
         wend = std::min(wend, width);
         for (int w = wstart; w < wend; ++w) {
-          in_grad[w] += out_grad[pw] / pool_size;
+          in_grad[w] += lp_grad<DType, p>::Map(out_grad[pw], in_data[w], out_data[pw]) / pool_size;
         }
       }
       in_grad += in_grad_offset;
+      in_data += in_grad_offset;
       out_grad += out_grad_offset;
+      out_data += out_grad_offset;
     }
   }
 }
@@ -541,11 +542,11 @@ inline void unpool_sum_1d_cpu(const DType* out_grad, const TShape& ishape,
  * \brief avg/sum unpooling cpu function for 2-D images.
  * Do not call this kernel directly. Use the interface unpool().
  */
-template<typename DType>
-inline void unpool_sum_2d_cpu(const DType* out_grad, const TShape& ishape,
-                              const TShape& oshape, const TShape& kernel,
+template<typename DType, int p = 1>
+inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data,
+                              const TShape& ishape, const TShape& oshape, const TShape& kernel,
                               const TShape& pad, const TShape& stride,
-                              DType* in_grad, bool isAvg = false) {
+                              DType* in_grad, const bool isAvg = false) {
   const int height = ishape[2], width = ishape[3];
   const int pooled_height = oshape[2], pooled_width = oshape[3];
   const int kernel_h = kernel[0], kernel_w = kernel[1];
@@ -561,10 +562,7 @@ inline void unpool_sum_2d_cpu(const DType* out_grad, const TShape& ishape,
           int wstart = pw * stride_w - pad_w;
           int hend = std::min(hstart + kernel_h, height + pad_h);
           int wend = std::min(wstart + kernel_w, width + pad_w);
-          int pool_size = 1;
-          if (isAvg) {
-            pool_size = (hend - hstart) * (wend - wstart);
-          }
+          int pool_size = (isAvg ? (hend - hstart) * (wend - wstart) : 1);
           hstart = std::max(hstart, 0);
           wstart = std::max(wstart, 0);
           hend = std::min(hend, height);
@@ -572,13 +570,18 @@ inline void unpool_sum_2d_cpu(const DType* out_grad, const TShape& ishape,
           const int pool_index = ph * pooled_width + pw;
           for (int h = hstart; h < hend; ++h) {
             for (int w = wstart; w < wend; ++w) {
-              in_grad[h*width+w] += out_grad[pool_index] / pool_size;
+              in_grad[h*width+w] +=
+                lp_grad<DType, p>::Map(out_grad[pool_index],
+                                       in_data[h*width+w],
+                                       out_data[pool_index]) / pool_size;
             }
           }
         }
       }
       in_grad += in_grad_offset;
+      in_data += in_grad_offset;
       out_grad += out_grad_offset;
+      out_data += out_grad_offset;
     }
   }
 }
@@ -587,11 +590,11 @@ inline void unpool_sum_2d_cpu(const DType* out_grad, const TShape& ishape,
  * \brief avg/sum unpooling cpu function for 3-D images.
  * Do not call this kernel directly. Use the interface unpool().
  */
-template<typename DType>
-inline void unpool_sum_3d_cpu(const DType* out_grad, const TShape& ishape,
-                              const TShape& oshape, const TShape& kernel,
+template<typename DType, int p = 1>
+inline void unpool_sum_3d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data,
+                              const TShape& ishape, const TShape& oshape, const TShape& kernel,
                               const TShape& pad, const TShape& stride,
-                              DType* in_grad, bool isAvg = false) {
+                              DType* in_grad, const bool isAvg = false) {
   const int depth = ishape[2], height = ishape[3], width = ishape[4];
   const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4];
   const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2];
@@ -610,10 +613,7 @@ inline void unpool_sum_3d_cpu(const DType* out_grad, const TShape& ishape,
             int dend = std::min(dstart + kernel_d, depth + pad_d);
             int hend = std::min(hstart + kernel_h, height + pad_h);
             int wend = std::min(wstart + kernel_w, width + pad_w);
-            int pool_size = 1;
-            if (isAvg) {
-              pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
-            }
+            int pool_size = (isAvg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1);
             dstart = std::max(dstart, 0);
             hstart = std::max(hstart, 0);
             wstart = std::max(wstart, 0);
@@ -624,7 +624,10 @@ inline void unpool_sum_3d_cpu(const DType* out_grad, const TShape& ishape,
             for (int d = dstart; d < dend; ++d) {
               for (int h = hstart; h < hend; ++h) {
                 for (int w = wstart; w < wend; ++w) {
-                  in_grad[(d*height+h)*width+w] += out_grad[pool_index] / pool_size;
+                  in_grad[(d*height+h)*width+w] +=
+                    lp_grad<DType, p>::Map(out_grad[pool_index],
+                                           in_data[(d*height+h)*width+w],
+                                           out_data[pool_index]) / pool_size;
                 }
               }
             }
@@ -632,7 +635,9 @@ inline void unpool_sum_3d_cpu(const DType* out_grad, const TShape& ishape,
         }
       }
       in_grad += in_grad_offset;
+      in_data += in_grad_offset;
       out_grad += out_grad_offset;
+      out_data += out_grad_offset;
     }
   }
 }
@@ -649,8 +654,9 @@ inline void unpool_sum_3d_cpu(const DType* out_grad, const TShape& ishape,
  * \param pool_type supported pooling type: max, avg, sum
  * \param req_type operator request type, only support kWriteTo for now
  * \param out_data pointer of the output tensor data in the format of NCW, NCHW, or NCDHW
+ * \param p_value value of p for Lp pooling
  */
-template<typename DType>
+template<typename DType, int p>
 inline void pool(mshadow::Stream<cpu>* s, const DType* in_data, const TShape& ishape,
                  const TShape& oshape, const TShape& kernel, const TShape& pad,
                  const TShape& stride, const int pool_type, OpReqType req_type,
@@ -663,6 +669,8 @@ inline void pool(mshadow::Stream<cpu>* s, const DType* in_data, const TShape& is
       pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true);
     } else if (pool_enum::kSumPooling == pool_type) {
       pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
+    } else if (pool_enum::kLpPooling == pool_type) {
+      pool_sum_1d_cpu<DType, p>(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else {
       LOG(FATAL) << "Unknown pooling type " << pool_type;
     }
@@ -673,6 +681,8 @@ inline void pool(mshadow::Stream<cpu>* s, const DType* in_data, const TShape& is
       pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true);
     } else if (pool_enum::kSumPooling == pool_type) {
       pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
+    } else if (pool_enum::kLpPooling == pool_type) {
+      pool_sum_2d_cpu<DType, p>(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else {
       LOG(FATAL) << "Unknown pooling type " << pool_type;
     }
@@ -683,6 +693,8 @@ inline void pool(mshadow::Stream<cpu>* s, const DType* in_data, const TShape& is
       pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true);
     } else if (pool_enum::kSumPooling == pool_type) {
       pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data);
+    } else if (pool_enum::kLpPooling == pool_type) {
+      pool_sum_3d_cpu<DType, p>(in_data, ishape, oshape, kernel, pad, stride, out_data);
     } else {
       LOG(FATAL) << "Unknown pooling type " << pool_type;
     }
@@ -705,12 +717,13 @@ inline void pool(mshadow::Stream<cpu>* s, const DType* in_data, const TShape& is
  * \param pool_type supported pooling type: max, avg, sum
  * \param req_type operator request type: kNullOp, kNullWriteInplace, kNullWriteTo, kNullAddTo
  * \param in_grad pointer of the gradient of the operator's input tensor
+ * \param p_value value of p for Lp pooling
  */
-template<typename DType>
+template<typename DType, int p>
 inline void unpool(mshadow::Stream<cpu>* s, const DType* out_grad, const DType* in_data,
                    const DType* out_data, const TShape& ishape, const TShape& oshape,
                    const TShape& kernel, const TShape& pad, const TShape& stride,
-                   const int pool_type, OpReqType req_type, DType* in_grad) {
+                   const int pool_type, OpReqType req_type, DType* in_grad, const int p_value = 2) {
   if (mxnet::kNullOp == req_type) return;
   if (mxnet::kAddTo != req_type) {
     mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, ishape.Size(), in_grad);
@@ -719,9 +732,13 @@ inline void unpool(mshadow::Stream<cpu>* s, const DType* out_grad, const DType*
     if (pool_enum::kMaxPooling == pool_type) {
       unpool_max_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kAvgPooling == pool_type) {
-      unpool_sum_1d_cpu(out_grad, ishape, oshape, kernel, pad, stride, in_grad, true);
+      unpool_sum_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad,
+                        true);
     } else if (pool_enum::kSumPooling == pool_type) {
-      unpool_sum_1d_cpu(out_grad, ishape, oshape, kernel, pad, stride, in_grad);
+      unpool_sum_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
+    } else if (pool_enum::kLpPooling == pool_type) {
+      unpool_sum_1d_cpu<DType, p>(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride,
+                                  in_grad);
     } else {
       LOG(FATAL) << "Unknown pooling type " << pool_type;
     }
@@ -729,9 +746,13 @@ inline void unpool(mshadow::Stream<cpu>* s, const DType* out_grad, const DType*
     if (pool_enum::kMaxPooling == pool_type) {
       unpool_max_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kAvgPooling == pool_type) {
-      unpool_sum_2d_cpu(out_grad, ishape, oshape, kernel, pad, stride, in_grad, true);
+      unpool_sum_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad,
+                        true);
     } else if (pool_enum::kSumPooling == pool_type) {
-      unpool_sum_2d_cpu(out_grad, ishape, oshape, kernel, pad, stride, in_grad);
+      unpool_sum_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
+    } else if (pool_enum::kLpPooling == pool_type) {
+      unpool_sum_2d_cpu<DType, p>(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride,
+                                  in_grad);
     } else {
       LOG(FATAL) << "Unknown pooling type " << pool_type;
     }
@@ -739,9 +760,13 @@ inline void unpool(mshadow::Stream<cpu>* s, const DType* out_grad, const DType*
     if (pool_enum::kMaxPooling == pool_type) {
       unpool_max_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
     } else if (pool_enum::kAvgPooling == pool_type) {
-      unpool_sum_3d_cpu(out_grad, ishape, oshape, kernel, pad, stride, in_grad, true);
+      unpool_sum_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad,
+                        true);
     } else if (pool_enum::kSumPooling == pool_type) {
-      unpool_sum_3d_cpu(out_grad, ishape, oshape, kernel, pad, stride, in_grad);
+      unpool_sum_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad);
+    } else if (pool_enum::kLpPooling == pool_type) {
+      unpool_sum_3d_cpu<DType, p>(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride,
+                                  in_grad);
     } else {
       LOG(FATAL) << "Unknown pooling type " << pool_type;
     }
diff --git a/src/operator/nn/pool_utils.h b/src/operator/nn/pool_utils.h
new file mode 100644
index 00000000000..641cc4a995a
--- /dev/null
+++ b/src/operator/nn/pool_utils.h
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_NN_POOL_UTILS_H_
+#define MXNET_OPERATOR_NN_POOL_UTILS_H_
+
+#include "../mshadow_op.h"
+
+namespace mxnet {
+namespace op {
+
+template<typename DType, int p>
+struct a_pow_p {
+  static MSHADOW_XINLINE DType Map(const DType a) {
+    return mshadow_op::power::Map(a, DType(p));
+  }
+};
+
+template<typename DType>
+struct a_pow_p<DType, 1> {
+  static MSHADOW_XINLINE DType Map(const DType a) {
+    return a;
+  }
+};
+
+template<typename DType>
+struct a_pow_p<DType, 2> {
+  static MSHADOW_XINLINE DType Map(const DType a) {
+    return a*a;
+  }
+};
+
+template<typename DType>
+struct a_pow_p<DType, 3> {
+  static MSHADOW_XINLINE DType Map(const DType a) {
+    return a*a*a;
+  }
+};
+
+template<typename DType, int p>
+struct a_root_p {
+  static MSHADOW_XINLINE DType Map(const DType a) {
+    return mshadow_op::power::Map(a, DType(1.0 / p));
+  }
+};
+
+template<typename DType>
+struct a_root_p<DType, 1> {
+  static MSHADOW_XINLINE DType Map(const DType a) {
+    return a;
+  }
+};
+
+template<typename DType>
+struct a_root_p<DType, 2> {
+  static MSHADOW_XINLINE DType Map(const DType a) {
+    return mshadow_op::square_root::Map(a);
+  }
+};
+
+template<typename DType>
+struct a_root_p<DType, 3> {
+  static MSHADOW_XINLINE DType Map(const DType a) {
+    return mshadow_op::cube_root::Map(a);
+  }
+};
+
+template<typename DType, int p>
+struct lp_grad {
+  static MSHADOW_XINLINE DType Map(const DType grad, const DType in_data, const DType out_data) {
+    return grad * mshadow_op::power::Map(in_data / out_data, DType(p - 1));
+  }
+};
+
+template<typename DType>
+struct lp_grad<DType, 1> {
+  static MSHADOW_XINLINE DType Map(const DType grad, const DType in_data, const DType out_data) {
+    return grad;
+  }
+};
+
+template<typename DType>
+struct lp_grad<DType, 2> {
+  static MSHADOW_XINLINE DType Map(const DType grad, const DType in_data, const DType out_data) {
+    return grad * in_data / out_data;
+  }
+};
+
+template<typename DType>
+struct lp_grad<DType, 3> {
+  static MSHADOW_XINLINE DType Map(const DType grad, const DType in_data, const DType out_data) {
+    return grad * in_data * in_data / (out_data * out_data);
+  }
+};
+
+}   // namespace op
+}   // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NN_POOL_UTILS_H_
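
The lp_grad helpers above route the incoming gradient to each input element as grad * (in/out)^(p-1); the p = 2 specialization is simply grad * in / out. A small standalone sanity check of that equivalence (an illustrative sketch using std::pow, not part of the patch):

    #include <cassert>
    #include <cmath>

    // Scalar form of the generic lp_grad<DType, p>::Map.
    double lp_grad_map(double grad, double in, double out, int p) {
      return grad * std::pow(in / out, p - 1);
    }

    int main() {
      // For a {3, 4} window with p = 2, out = sqrt(3*3 + 4*4) = 5.
      double grad = 0.5, in = 3.0, out = 5.0;
      // The p = 2 specialization grad * in / out must match the generic form.
      assert(std::fabs(lp_grad_map(grad, in, out, 2) - grad * in / out) < 1e-12);
      return 0;
    }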
diff --git a/src/operator/nn/pooling-inl.h b/src/operator/nn/pooling-inl.h
index 5993bf5149d..a4770b49e85 100644
--- a/src/operator/nn/pooling-inl.h
+++ b/src/operator/nn/pooling-inl.h
@@ -21,7 +21,7 @@
  * Copyright (c) 2017 by Contributors
  * \file pooling-inl.h
  * \brief
- * \author Bing Xu, Jun Wu, Da Zheng
+ * \author Bing Xu, Jun Wu, Da Zheng, Hao Jin
 */
 
 #ifndef MXNET_OPERATOR_NN_POOLING_INL_H_
@@ -49,6 +49,7 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
   int pooling_convention;
   bool global_pool;
   bool cudnn_off;
+  dmlc::optional<int> p_value;
   DMLC_DECLARE_PARAMETER(PoolingParam) {
     DMLC_DECLARE_FIELD(kernel).set_default(TShape())  // add default value here
     .enforce_nonzero()
@@ -58,6 +59,7 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
     .add_enum("max", pool_enum::kMaxPooling)
     .add_enum("avg", pool_enum::kAvgPooling)
     .add_enum("sum", pool_enum::kSumPooling)
+    .add_enum("lp", pool_enum::kLpPooling)
     .describe("Pooling type to be applied.");
 
     DMLC_DECLARE_FIELD(global_pool).set_default(false)
@@ -77,6 +79,9 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
 
     DMLC_DECLARE_FIELD(pad).set_default(TShape())
     .describe("Pad for pooling: (y, x) or (d, y, x). Defaults to no padding.");
+
+    DMLC_DECLARE_FIELD(p_value).set_default(dmlc::optional<int>())
+    .describe("Value of p for Lp pooling, can be 1 or 2, required for Lp Pooling");
   }
 
   bool operator==(const PoolingParam& other) const {
@@ -86,7 +91,8 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
            this->pool_type          == other.pool_type &&
            this->pooling_convention == other.pooling_convention &&
            this->global_pool        == other.global_pool &&
-           this->cudnn_off          == other.cudnn_off;
+           this->cudnn_off          == other.cudnn_off &&
+           this->p_value            == other.p_value;
   }
 };
 
@@ -105,6 +111,7 @@ struct hash<mxnet::op::PoolingParam> {
     ret = dmlc::HashCombine(ret, val.pooling_convention);
     ret = dmlc::HashCombine(ret, val.global_pool);
     ret = dmlc::HashCombine(ret, val.cudnn_off);
+    ret = dmlc::HashCombine(ret, val.p_value);
     return ret;
   }
 };
@@ -144,12 +151,33 @@ class PoolingOp {
       }
       stride = TShape(ishape.ndim() - 2);
     }
-
-    pool(s, in_data.dptr<DType>(), in_data.shape_, out_data.shape_,
-         kernel,
-         padding,
-         stride,
-         param_.pool_type, req, out_data.dptr<DType>());
+    const int p_value = (param_.pool_type == pool_enum::kLpPooling && param_.p_value.has_value()) ?
+                        param_.p_value.value() : 1;
+    switch (p_value) {
+      case 1:
+        pool<DType, 1>(s, in_data.dptr<DType>(), in_data.shape_, out_data.shape_,
+          kernel,
+          padding,
+          stride,
+          param_.pool_type, req, out_data.dptr<DType>());
+        break;
+      case 2:
+        pool<DType, 2>(s, in_data.dptr<DType>(), in_data.shape_, out_data.shape_,
+          kernel,
+          padding,
+          stride,
+          param_.pool_type, req, out_data.dptr<DType>());
+        break;
+      case 3:
+        pool<DType, 3>(s, in_data.dptr<DType>(), in_data.shape_, out_data.shape_,
+          kernel,
+          padding,
+          stride,
+          param_.pool_type, req, out_data.dptr<DType>());
+        break;
+      default:
+        LOG(FATAL) << "p value of " << p_value << " is not supported yet...";
+    }
   }
 
   void Backward(const OpContext& ctx, const TBlob& out_grad,
@@ -171,12 +199,36 @@ class PoolingOp {
       stride = TShape(ishape.ndim() - 2);
     }
 
-    unpool(s, out_grad.dptr<DType>(), in_data.dptr<DType>(), out_data.dptr<DType>(),
+    const int p_value = (param_.pool_type == pool_enum::kLpPooling && param_.p_value.has_value()) ?
+                        param_.p_value.value() : 1;
+    switch (p_value) {
+      case 1:
+        unpool<DType, 1>(s, out_grad.dptr<DType>(), in_data.dptr<DType>(), out_data.dptr<DType>(),
+           in_grad.shape_, out_grad.shape_,
+           kernel,
+           padding,
+           stride,
+           param_.pool_type, req, in_grad.dptr<DType>());
+        break;
+      case 2:
+        unpool<DType, 2>(s, out_grad.dptr<DType>(), in_data.dptr<DType>(), out_data.dptr<DType>(),
            in_grad.shape_, out_grad.shape_,
            kernel,
            padding,
            stride,
            param_.pool_type, req, in_grad.dptr<DType>());
+        break;
+      case 3:
+        unpool<DType, 3>(s, out_grad.dptr<DType>(), in_data.dptr<DType>(), out_data.dptr<DType>(),
+           in_grad.shape_, out_grad.shape_,
+           kernel,
+           padding,
+           stride,
+           param_.pool_type, req, in_grad.dptr<DType>());
+        break;
+      default:
+        LOG(FATAL) << "p value of " << p_value << " is not supported yet...";
+    }
   }
 
  private:
@@ -200,7 +252,8 @@ void PoolingCompute(const nnvm::NodeAttrs& attrs,
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     if (pool_enum::kMaxPooling == param.pool_type
         || pool_enum::kAvgPooling == param.pool_type
-        || pool_enum::kSumPooling == param.pool_type) {
+        || pool_enum::kSumPooling == param.pool_type
+        || pool_enum::kLpPooling == param.pool_type) {
       PoolingOp<xpu, DType> op;
       op.Init(param);
       op.Forward(ctx, inputs[0], req[0], outputs[0]);
@@ -239,7 +292,8 @@ void PoolingGradCompute(const nnvm::NodeAttrs& attrs,
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     if (pool_enum::kMaxPooling == param.pool_type
         || pool_enum::kAvgPooling == param.pool_type
-        || pool_enum::kSumPooling == param.pool_type) {
+        || pool_enum::kSumPooling == param.pool_type
+        || pool_enum::kLpPooling == param.pool_type) {
       PoolingOp<xpu, DType> op;
       op.Init(param);
       op.Backward(ctx, inputs[ograd_idx], inputs[in_data_idx],
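
The Forward and Backward paths above dispatch on p_value to the templated pool<DType, p> / unpool<DType, p> kernels, so only p = 1, 2 and 3 are compiled in. From the Python API the new mode is selected with pool_type='lp' plus the new p_value parameter; a minimal usage sketch (assuming a build that includes this change):

    import mxnet as mx

    data = mx.nd.random.uniform(shape=(1, 3, 8, 8))
    # L2 pooling over non-overlapping 2x2 windows
    out = mx.nd.Pooling(data, kernel=(2, 2), stride=(2, 2),
                        pool_type='lp', p_value=2)
    print(out.shape)  # (1, 3, 4, 4)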
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index ca472b3ca1b..3ff94da3c2d 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -92,6 +92,9 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
                          std::vector<TShape> *out_shape) {
   const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
   CHECK_EQ(in_shape->size(), 1U);
+  if (param.pool_type == pool_enum::kLpPooling) {
+    CHECK(param.p_value.has_value()) << "p_value is required when pool_type is 'lp'";
+  }
   const TShape &dshape = (*in_shape)[0];
   CHECK_GE(dshape.ndim(), 3U)
       << "Pooling: Input data should be  3D in (batch, channel, x)"
@@ -344,11 +347,23 @@ Three pooling options are supported by ``pool_type``:
 - **avg**: average pooling
 - **max**: max pooling
 - **sum**: sum pooling
+- **lp**: Lp pooling
 
 For 3-D pooling, an additional *depth* dimension is added before
 *height*. Namely the input data will have shape *(batch_size, channel, depth,
 height, width)*.
 
+Notes on Lp pooling:
+
+Lp pooling was first introduced in this paper: https://arxiv.org/pdf/1204.3968.pdf.
+L-1 pooling is simply sum pooling, while L-inf pooling is simply max pooling.
+Lp pooling therefore lies between the two; in practice, the most common value of p is 2.
+
+For each window ``X``, the mathematical expression for Lp pooling is:
+
+.. math::
+  f(X) = \sqrt[p]{\sum\limits_{x \in X} x^p}
+
 )code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs([](const NodeAttrs& attrs) {
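
As a quick sanity check of the documented formula, take the window X = {1, 2, 2}:

    f(X) = \left(\sum_{x \in X} x^p\right)^{1/p}, \qquad
    f(X)\big|_{p=2} = \sqrt{1^2 + 2^2 + 2^2} = 3,

while p = 1 gives 5 (sum pooling) and p \to \infty gives 2 (max pooling), matching the limits described in the notes above.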
diff --git a/src/operator/nn/pooling.cu b/src/operator/nn/pooling.cu
index 2187fd87ca8..997218620c3 100644
--- a/src/operator/nn/pooling.cu
+++ b/src/operator/nn/pooling.cu
@@ -66,6 +66,9 @@ void PoolingCompute<gpu>(const nnvm::NodeAttrs& attrs,
         case pool_enum::kSumPooling:
           LOG(WARNING) << "Sum pooling is not supported by cudnn, MXNet sum pooling is applied.";
           break;
+        case pool_enum::kLpPooling:
+          LOG(WARNING) << "Lp pooling is not supported by cudnn, MXNet Lp pooling is applied.";
+          break;
       }
     });
   }
@@ -74,7 +77,8 @@ void PoolingCompute<gpu>(const nnvm::NodeAttrs& attrs,
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     if (pool_enum::kMaxPooling == param.pool_type
         || pool_enum::kAvgPooling == param.pool_type
-        || pool_enum::kSumPooling == param.pool_type) {
+        || pool_enum::kSumPooling == param.pool_type
+        || pool_enum::kLpPooling == param.pool_type) {
       PoolingOp<gpu, DType> op;
       op.Init(param);
       op.Forward(ctx, inputs[0], req[0], outputs[0]);
@@ -119,6 +123,9 @@ void PoolingGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
         case pool_enum::kSumPooling:
           LOG(WARNING) << "Sum pooling is not supported by cudnn, MXNet sum pooling is applied.";
           break;
+        case pool_enum::kLpPooling:
+          LOG(WARNING) << "Lp pooling is not supported by cudnn, MXNet Lp pooling is applied.";
+          break;
       }
     });
   }
@@ -127,7 +134,8 @@ void PoolingGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     if (pool_enum::kMaxPooling == param.pool_type
         || pool_enum::kAvgPooling == param.pool_type
-        || pool_enum::kSumPooling == param.pool_type) {
+        || pool_enum::kSumPooling == param.pool_type
+        || pool_enum::kLpPooling == param.pool_type) {
       PoolingOp<gpu, DType> op;
       op.Init(param);
       op.Backward(ctx, inputs[ograd_idx], inputs[in_data_idx],
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index b9f2b6791d0..d5e02626218 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -741,7 +741,7 @@ def test_pooling_with_type():
 @with_seed()
 def test_pooling_versions():
     def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, stride,
-                                     pooling_convention='valid', global_pool=False):
+                                     pooling_convention='valid', global_pool=False, p_value=2):
         ctx_list = []
         sym_list = []
         # PoolingV1 cpu
@@ -765,140 +765,164 @@ def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, str
             ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
             if not global_pool:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                               pooling_convention=pooling_convention, name='pool'))
+                                               pooling_convention=pooling_convention, name='pool', p_value=p_value))
             else:
-                sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, name='pool'))
+                sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, name='pool', p_value=p_value))
         # Pooling gpu
         if 'pool_gpu' in pool_op_list:
             ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
             if not global_pool:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                               pooling_convention=pooling_convention, cudnn_off=True, name='pool'))
+                                               pooling_convention=pooling_convention, cudnn_off=True, name='pool', p_value=p_value))
             else:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, cudnn_off=True,
-                                               name='pool'))
+                                               name='pool', p_value=p_value))
         # CuDNNPooling
         if 'pool_cudnn' in pool_op_list:
             ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
             if not global_pool:
                 sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                               pooling_convention=pooling_convention, cudnn_off=False, name='pool'))
+                                               pooling_convention=pooling_convention, p_value=p_value, cudnn_off=False, name='pool'))
             else:
-                sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, cudnn_off=False,
-                                               name='pool'))
+                sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, p_value=p_value,
+                                               cudnn_off=False, name='pool'))
         check_consistency(sym_list, ctx_list)
 
-    def test_1d_pooling(pool_type):
+    def test_1d_pooling(pool_type, p_value=2):
         data = (2, 3, 20)
         kernel = (4,)
         pad = (0,)
         stride = (1,)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value)
 
         pad = (2,)
         stride = (2,)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value)
 
         pad = (0,)
         stride = (1,)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False)
+                                     pooling_convention='full', global_pool=False, p_value=p_value)
 
         pad = (2,)
         stride = (2,)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False)
+                                     pooling_convention='full', global_pool=False, p_value=p_value)
 
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     global_pool=True)
+                                     global_pool=True, p_value=p_value)
 
-    def test_2d_pooling(pool_type):
+    def test_2d_pooling(pool_type, p_value=2):
         data = (2, 3, 20, 20)
         kernel = (4, 5)
         pad = (0, 0)
         stride = (1, 1)
-        test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
-                                     data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False)
+        if pool_type == 'lp':
+            test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'],
+                                         data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                         pooling_convention='valid', global_pool=False, p_value=p_value)
+        else:
+            test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
+                                         data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                         pooling_convention='valid', global_pool=False)
 
         # pool_v1 has bugs when pad is not 0, do not test PoolingV1 here
         pad = (2, 3)
         stride = (2, 3)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value)
 
         pad = (0, 0)
         stride = (1, 1)
-        test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
-                                     data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False)
+        if pool_type == 'lp':
+            test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
+                                         data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                         pooling_convention='full', global_pool=False, p_value=p_value)
+        else:
+            test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
+                                         data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                         pooling_convention='full', global_pool=False)
 
         # pool_v1 has bugs when pad is not 0, do not test PoolingV1 here
         pad = (2, 3)
         stride = (2, 3)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False)
-
-        test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
-                                     data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     global_pool=True)
-
-    def test_3d_pooling(pool_type):
+                                     pooling_convention='full', global_pool=False, p_value=p_value)
+
+        if pool_type == 'lp':
+            test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
+                                         data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                         global_pool=True, p_value=p_value)
+        else:
+            test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'],
+                                         data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                         global_pool=True)
+
+    def test_3d_pooling(pool_type, p_value=2):
         data = (2, 3, 20, 20, 20)
         kernel = (4, 5, 3)
         pad = (0, 0, 0)
         stride = (1, 1, 1)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value)
 
         pad = (2, 3, 3)
         stride = (2, 3, 1)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='valid', global_pool=False)
+                                     pooling_convention='valid', global_pool=False, p_value=p_value)
 
         pad = (0, 0, 0)
         stride = (1, 1, 1)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False)
+                                     pooling_convention='full', global_pool=False, p_value=p_value)
 
         pad = (2, 3, 3)
         stride = (2, 3, 1)
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     pooling_convention='full', global_pool=False)
+                                     pooling_convention='full', global_pool=False, p_value=p_value)
 
         test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'],
                                      data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                     global_pool=True)
+                                     global_pool=True, p_value=p_value)
 
     test_1d_pooling('max')
     test_1d_pooling('avg')
     test_1d_pooling('sum')
+    test_1d_pooling('lp', p_value=1)
+    test_1d_pooling('lp', p_value=2)
+    test_1d_pooling('lp', p_value=3)
 
     test_2d_pooling('max')
     test_2d_pooling('avg')
     test_2d_pooling('sum')
+    test_2d_pooling('lp', p_value=1)
+    test_2d_pooling('lp', p_value=2)
+    test_2d_pooling('lp', p_value=3)
 
     test_3d_pooling('max')
     test_3d_pooling('avg')
     test_3d_pooling('sum')
+    test_3d_pooling('lp', p_value=1)
+    test_3d_pooling('lp', p_value=2)
+    test_3d_pooling('lp', p_value=3)
 
 
 @with_seed()
 def test_global_pooling():
-    def test_1d_pooling(pool_type):
+    def test_1d_pooling(pool_type, p_value=2):
         data = (2, 3, 20)
         kernel = (4,)
         pad = (2,)
@@ -911,43 +935,43 @@ def test_1d_pooling(pool_type):
 
         ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, name='pool', p_value=p_value))
 
         ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, name='pool', p_value=p_value))
 
         ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, name='pool', p_value=p_value))
 
         ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, cudnn_off=False, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, cudnn_off=False, name='pool'))
 
         ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, cudnn_off=False, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, cudnn_off=False, name='pool'))
 
         ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, cudnn_off=False, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, cudnn_off=False, name='pool'))
 
         ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, cudnn_off=True, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, cudnn_off=True, name='pool'))
 
         ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, cudnn_off=True, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, cudnn_off=True, name='pool'))
 
         ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, cudnn_off=True, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, cudnn_off=True, name='pool'))
 
         check_consistency(sym_list, ctx_list)
 
-    def test_2d_pooling(pool_type):
+    def test_2d_pooling(pool_type, p_value=2):
         data = (2, 3, 20, 20)
         kernel = (4, 4)
         pad = (2, 2)
@@ -958,53 +982,54 @@ def test_2d_pooling(pool_type):
 
         pooling_convention = 'valid'
 
-        ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
-        sym_list.append(mx.sym.Pooling_v1(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, name='pool'))
+        if pool_type != 'lp':
+            ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
+            sym_list.append(mx.sym.Pooling_v1(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
+                                              pooling_convention=pooling_convention, global_pool=True, name='pool'))
 
-        ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
-        sym_list.append(mx.sym.Pooling_v1(kernel=kernel, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, name='pool'))
+            ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
+            sym_list.append(mx.sym.Pooling_v1(kernel=kernel, pool_type=pool_type,
+                                              pooling_convention=pooling_convention, global_pool=True, name='pool'))
 
-        ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
-        sym_list.append(mx.sym.Pooling_v1(pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, name='pool'))
+            ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
+            sym_list.append(mx.sym.Pooling_v1(pool_type=pool_type,
+                                              pooling_convention=pooling_convention, global_pool=True, name='pool'))
 
         ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, name='pool'))
 
         ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, name='pool'))
 
         ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, name='pool'))
 
         ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, cudnn_off=False, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, cudnn_off=False, name='pool'))
 
         ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, cudnn_off=False, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, cudnn_off=False, name='pool'))
 
         ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, cudnn_off=False, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, cudnn_off=False, name='pool'))
 
         ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, cudnn_off=True, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, cudnn_off=True, name='pool'))
 
         ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, cudnn_off=True, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, cudnn_off=True, name='pool'))
 
         ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
         sym_list.append(mx.sym.Pooling(pool_type=pool_type,
-                                       pooling_convention=pooling_convention, global_pool=True, cudnn_off=True, name='pool'))
+                                       pooling_convention=pooling_convention, global_pool=True, p_value=p_value, cudnn_off=True, name='pool'))
 
 
         check_consistency(sym_list, ctx_list)
@@ -1012,10 +1037,16 @@ def test_2d_pooling(pool_type):
     test_1d_pooling('max')
     test_1d_pooling('avg')
     test_1d_pooling('sum')
+    test_1d_pooling('lp', p_value=1)
+    test_1d_pooling('lp', p_value=2)
+    test_1d_pooling('lp', p_value=3)
 
     test_2d_pooling('max')
     test_2d_pooling('avg')
     test_2d_pooling('sum')
+    test_2d_pooling('lp', p_value=1)
+    test_2d_pooling('lp', p_value=2)
+    test_2d_pooling('lp', p_value=3)
 
 
 @with_seed()
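
The new test cases above exercise p = 1, 2 and 3 for 1-D, 2-D and 3-D pooling through check_consistency. A stripped-down sketch of the same pattern for a single Lp case (hypothetical shapes, requires a GPU, mirroring the helpers above):

    import numpy as np
    import mxnet as mx
    from mxnet.test_utils import check_consistency

    # The input is auto-named 'pool_data' because the op is named 'pool'
    sym = mx.sym.Pooling(kernel=(3, 3), stride=(2, 2), pool_type='lp',
                         p_value=3, name='pool')
    ctx_list = [{'ctx': mx.cpu(0), 'pool_data': (2, 3, 10, 10),
                 'type_dict': {'pool_data': np.float32}},
                {'ctx': mx.gpu(0), 'pool_data': (2, 3, 10, 10),
                 'type_dict': {'pool_data': np.float32}}]
    check_consistency(sym, ctx_list)  # forward and backward must agree across contexts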


 
