Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/25 17:40:32 UTC

[GitHub] anirudh2290 closed pull request #12594: [MXNET-867] Pooling1D with "same" padding

anirudh2290 closed pull request #12594: [MXNET-867] Pooling1D with "same" padding
URL: https://github.com/apache/incubator-mxnet/pull/12594
 
 
   

This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is displayed below for the sake of provenance:

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 3ae61298de8..55416355d8a 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -180,3 +180,4 @@ List of Contributors
 * [Per Goncalves da Silva](https://github.com/perdasilva)
 * [Zhijingcheng Yu](https://github.com/jasonyu1996)
 * [Cheng-Che Lee](https://github.com/stu1130)
+* [Chaitanya Bapat](https://github.com/ChaiBapchya)
\ No newline at end of file
diff --git a/src/operator/nn/pool.h b/src/operator/nn/pool.h
index 8f7a5edc832..33005c8e5f0 100644
--- a/src/operator/nn/pool.h
+++ b/src/operator/nn/pool.h
@@ -73,7 +73,7 @@ namespace pool_enum {
 enum PoolingOpInputs {kData};
 enum PoolingOpOutputs {kOut, kMask};
 enum PoolingOpType {kMaxPooling, kAvgPooling, kSumPooling, kLpPooling};
-enum PoolingOpPadConventionType {kValid, kFull};
+enum PoolingOpPadConventionType {kValid, kFull, kSame};
 }  // namespace pool_enum
 
 /*!
diff --git a/src/operator/nn/pooling-inl.h b/src/operator/nn/pooling-inl.h
index ad74a8feae3..71d85da9ba5 100644
--- a/src/operator/nn/pooling-inl.h
+++ b/src/operator/nn/pooling-inl.h
@@ -74,6 +74,7 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
     DMLC_DECLARE_FIELD(pooling_convention).set_default(pool_enum::kValid)
     .add_enum("full", pool_enum::kFull)
     .add_enum("valid", pool_enum::kValid)
+    .add_enum("same", pool_enum::kSame)
     .describe("Pooling convention to be applied.");
 
     DMLC_DECLARE_FIELD(stride).set_default(TShape())
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index 558722edb20..611568807a9 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -96,6 +96,13 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
     CHECK(param.p_value.has_value());
   }
   const TShape &dshape = (*in_shape)[0];
+  if (param.pooling_convention == pool_enum::kSame) {
+    CHECK_EQ(dshape.ndim(), 3U)
+      << "Pooling: Input data should be 3D in (batch, channel, x)"
+      << ". Currently 'same' supports Max Pooling 1-D";
+    CHECK(param.pad[0] == 0 && param.pad[1] == 0 && param.pad[2] == 0)
+      << "Same pooling convention disables the use of pad parameter.";
+  }
   CHECK_GE(dshape.ndim(), 3U)
       << "Pooling: Input data should be  3D in (batch, channel, x)"
       << " Or 4D in (batch, channel, y, x) "
@@ -126,11 +133,15 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
       oshape[2] = 1 +
                   (dshape[2] + 2 * param.pad[0] - param.kernel[0]) /
                       param.stride[0];
-    } else {
+    } else if (param.pooling_convention == pool_enum::kFull) {
       oshape[2] = 1 + static_cast<int>(std::ceil(
                           static_cast<float>(dshape[2] + 2 * param.pad[0] -
                                              param.kernel[0]) /
                           param.stride[0]));
+    } else {
+      oshape[2] = static_cast<int>(std::ceil(
+                          static_cast<float>(dshape[2] + 2 * param.pad[0]) /
+                          param.stride[0]));
     }
     out_shape->clear();
     out_shape->push_back(oshape);  // save output shape
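
For reference, the "same" branch added above sets the output length to
ceil((input + 2 * pad) / stride), independent of the kernel size. A minimal
arithmetic sketch of that rule (illustrative values only, not taken from
this patch):

    import math

    # "same" convention: out_len = ceil((in_len + 2 * pad) / stride)
    in_len, pad, stride = 7, 0, 3      # hypothetical example values
    out_len = math.ceil((in_len + 2 * pad) / stride)
    print(out_len)                     # prints 3
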
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 43c357808f1..a7f484e81b3 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6975,6 +6975,40 @@ def test_valid_kernel_size():
         mx.nd.array(np.random.rand(1, 1, 28, 28)),
         kernel_size=valid_kernel_size)
 
+@with_seed()
+def test_valid_max_pooling_pad_type_same():
+    import math
+    input_data = mx.nd.array(np.random.rand(1,1,10))
+    stride = 2
+    kernel = 2
+    output_data=mx.nd.Pooling(
+        input_data,
+        kernel=kernel,
+        stride=stride,
+        pad=(0,0,0),
+        pool_type='max',
+        name='pooling',
+        pooling_convention="same")
+    assert(math.ceil(input_data.shape[2]/stride) == output_data.shape[2])
+
+@with_seed()
+def test_invalid_max_pooling_pad_type_same():
+    import math
+    input_data = mx.nd.array(np.random.rand(1,1,10))
+    stride = 2
+    kernel = 2
+    pad = 2
+    assert_exception(
+        mx.nd.Pooling,
+        MXNetError,
+        input_data,
+        stride=stride,
+        kernel=kernel,
+        pad=pad,
+        pool_type='max',
+        name='pooling',
+        pooling_convention="same")
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
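
For illustration, a hedged usage sketch of the new "same" convention through
the Python NDArray API, mirroring the test added above (argument values are
taken from that test, not a general guarantee):

    import math
    import numpy as np
    import mxnet as mx

    # 1-D max pooling with the new "same" convention; input is (batch, channel, x).
    data = mx.nd.array(np.random.rand(1, 1, 10))
    out = mx.nd.Pooling(data, kernel=2, stride=2, pad=(0, 0, 0),
                        pool_type='max', pooling_convention='same')
    # Output width is ceil(input_width / stride) = ceil(10 / 2) = 5.
    assert out.shape[2] == math.ceil(data.shape[2] / 2)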


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services