You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by ka...@apache.org on 2018/07/13 05:43:45 UTC

[1/3] incubator-singa git commit: SINGA-378 Implement maxpooling operation and its related functions for autograd

Repository: incubator-singa
Updated Branches:
  refs/heads/master f134a24e2 -> a36291824


SINGA-378 Implement maxpooling operation and its related functions for autograd

- Implement the corresponding max-pooling functions for the GPU (cuDNN) backend.

- Write the SWIG interface file for the max-pooling functions.

- Implement the max-pooling layer and max-pooling operation in Python.

- Modify the example code.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/571818eb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/571818eb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/571818eb

Branch: refs/heads/master
Commit: 571818eb06ca1e88e51a1ad8adb61f43c349ee62
Parents: f134a24
Author: xuewanqi <xu...@outlook.com>
Authored: Wed Jul 11 14:30:02 2018 +0000
Committer: Wang Wei <wa...@gmail.com>
Committed: Thu Jul 12 17:08:21 2018 +0800

----------------------------------------------------------------------
 examples/autograd/mnist_cnn.py |  23 +++--
 python/singa/autograd.py       | 162 ++++++++++++++++++++----------------
 src/api/model_operation.i      |  38 ++++++++-
 src/model/operation/pooling.cc | 126 ++++++++++++++++++++++++++++
 src/model/operation/pooling.h  |  63 ++++++++++++++
 5 files changed, 327 insertions(+), 85 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/571818eb/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index b1d8dbe..2cb3cae 100755
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -84,7 +84,7 @@ if __name__ == '__main__':
         dev = device.get_default_device()
     else:
         print('Using GPU')
-        dev = device.create_cuda_gpu_on(1)
+        dev = device.create_cuda_gpu()
 
     train, test = load_data(args.file_path)
 
@@ -92,7 +92,7 @@ if __name__ == '__main__':
     num_classes = 10
     epochs = 1
 
-    sgd = optimizer.SGD(0.01)
+    sgd = optimizer.SGD(0.001)
 
     x_train = preprocess(train[0])
     y_train = to_categorical(train[1], num_classes)
@@ -110,27 +110,32 @@ if __name__ == '__main__':
     conv2 = autograd.Conv2D(32, 32, 3, padding=1)
     bn2 = autograd.BatchNorm(32)
     linear = autograd.Linear(32 * 28 * 28, 10)
-
+    pooling1 = autograd.MaxPool2D(3, 1, padding=1)
+    pooling2 = autograd.MaxPool2D(3, 1, padding=1)
 
     def forward(x, t):
         y = conv1(x)
         y = autograd.relu(y)
         y = bn1(y)
         y = autograd.max_pool_2d(y)
+        y = pooling1(y)
+
         y = conv2(y)
-        y = bn2(y)
         y = autograd.relu(y)
-        y = autograd.max_pool_2d(y)
-        y=autograd.flatten(y)
+        y = bn2(y)
+        y = pooling2(y)
+        y = autograd.flatten(y)
         y = linear(y)
         loss = autograd.softmax_cross_entropy(y, t)
         return loss, y
 
     autograd.training = True
-    for epoch in range(epochs):
+    for epoch in range(50):
         for i in range(batch_number):
-            inputs = tensor.Tensor(device=dev, data=x_train[ i * 100:(1 + i) * 100], stores_grad=False)
-            targets = tensor.Tensor(device=dev, data=y_train[i * 100:(1 + i) * 100], requires_grad=False, stores_grad=False)
+            inputs = tensor.Tensor(device=dev, data=x_train[
+                                   i * 100:(1 + i) * 100], stores_grad=False)
+            targets = tensor.Tensor(device=dev, data=y_train[
+                                    i * 100:(1 + i) * 100], requires_grad=False, stores_grad=False)
 
             loss, y = forward(inputs, targets)
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/571818eb/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index d272dcd..fcdc020 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -37,12 +37,9 @@ def infer_dependency(op):
     '''
     Infer the dependency of all operations with the
     given op as the last operation.
-
     Operation A is depending on B is A uses the output(s) of B.
-
     Args:
         op: an Operation instance, e.g. the loss operation.
-
     Return:
         a Counter instance with the operation as the key,
         and the number of operations that are depending on it as the value
@@ -74,12 +71,10 @@ def gradients(y, dy=None):
 def backward(y, dy=None):
     '''
     Run the backward propagation starting at y.
-
     Args:
         y: a Tensor instance, usually the loss
         dy: a number or a Tensor instance, for the gradient of the
             objective/loss w.r.t y, usually 1.0
-
     Return:
         a dictionary storing the gradient tensors of all tensors
         whose stores_grad is true (e.g. parameter tensors)
@@ -156,7 +151,6 @@ class Operation(object):
     '''
     An operation includes the forward and backward function of
     tensor calculation.
-
     Steps to add a specific operation Xxxx:
     1. create a subclass of Operation, name it as Xxxx
     2. override the forward() and backward(); The arguments of forward()
@@ -169,10 +163,8 @@ class Operation(object):
     def _do_forward(self, *xs):
         '''
         Do not call this function from user code. It is called by __call__().
-
         Args:
             xs, Tensor instance(s)
-
         Returns:
             Tensor instance(s)
         '''
@@ -218,10 +210,8 @@ class Operation(object):
 
     def forward(self, *xs):
         '''Forward propagation.
-
         Args:
             xs: input args consisting of only CTensors.
-
         Returns:
             CTensor instance(s)
         '''
@@ -229,10 +219,8 @@ class Operation(object):
 
     def backward(self, *dys):
         ''' Backward propagation.
-
         Args:
             dys: input args consisting of only CTensors.
-
         Returns:
             CTensor instance(s)
         '''
@@ -244,7 +232,6 @@ class Operation(object):
 
 class Dummy(Operation):
     '''Dummy operation whice serves as a placehoder for autograd
-
     Args:
         name(string): set it for debug
     '''
@@ -262,7 +249,6 @@ class ReLU(Operation):
         '''
         Args:
             x(CTensor): input tensor
-
         Returns:
             a new CTensor whose element y = x if x >= 0; otherwise 0;
         '''
@@ -274,7 +260,6 @@ class ReLU(Operation):
         '''
         Args:
             dy(CTensor): dL / dy
-
         Returns:
             dx(CTensor): dL / dx = dy if x >= 0; otherwise 0;
         '''
@@ -291,13 +276,10 @@ class Matmul(Operation):
 
     def forward(self, x, w):
         '''Do forward propgation.
-
         Store the x(or w) if w(or x) requires gradient.
-
         Args:
             x (CTensor): matrix
             w (CTensor): matrix
-
         Returns:
             a CTensor for the result
         '''
@@ -309,7 +291,6 @@ class Matmul(Operation):
         '''
         Args:
             dy (CTensor): data for the dL / dy, L is the loss
-
         Returns:
             a tuple for (dx, dw)
         '''
@@ -329,7 +310,6 @@ class AddBias(Operation):
     def __init__(self, axis=0):
         '''
         To indicate the calculation axis, 0 for row, 1 for column.
-
         Args:
             axis: 0 or 1, default is 0.
         '''
@@ -340,7 +320,6 @@ class AddBias(Operation):
         Args:
             x: matrix.
             b: bias to be added.
-
         Return:
             the result Tensor
         '''
@@ -354,7 +333,6 @@ class AddBias(Operation):
         '''
         Args:
             dy (CTensor): data for the dL / dy, L is the loss.
-
         Return:
             a tuple for (db, dx), db is data for dL / db, dx is data
             for dL / dx.
@@ -382,7 +360,6 @@ class SoftMax(Operation):
         '''
         Args:
             x(data): the input 1d or 2d tensor
-
         Returns:
             the result Tensor
         '''
@@ -398,7 +375,6 @@ class SoftMax(Operation):
         '''
         Args:
             dy (CTensor): data for the dL / dy, L is the loss
-
         Returns:
             dx (Ctensor): data for the dL / dx, L is the loss,
             x is the input of current Opertion
@@ -435,7 +411,6 @@ def soft_max(x, axis=0):
 class CrossEntropy(Operation):
     '''
     Calculte negative log likelihood loss for a batch of training data.
-
     '''
 
     def forward(self, x, t):
@@ -444,7 +419,6 @@ class CrossEntropy(Operation):
             x (CTensor): 1d or 2d tensor, the prediction data(output)
                          of current network.
             t (CTensor): 1d or 2d tensor, the target data for training.
-
         Returns:
             loss (CTensor): scalar.
         '''
@@ -461,7 +435,6 @@ class CrossEntropy(Operation):
         Args:
             dy (float or CTensor): scalar, accumulate gradient from outside
                                 of current network, usually equal to 1.0
-
         Returns:
             dx (CTensor): data for the dL /dx, L is the loss, x is the output
                           of current network. note that this is true for
@@ -510,60 +483,33 @@ def ctensor2numpy(x):
     return np_array.reshape(x.shape())
 
 
-class MaxPool2d(Operation):
-
-    def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1,
-                 return_indices=False, ceil_mode=False, **kwargs):
-
-        inner_params = {'name': 'MaxPool2d',
-                        'border_mode': 'same',
-                        'data_format': 'NCHW',
-                        'input_sample_shape': None
-                        }
+class _MaxPool2D(Operation):
 
-        for kwarg in kwargs:
-            if kwarg not in inner_params:
-                raise TypeError('Keyword argument not understood:', kwarg)
-            else:
-                inner_params[kwarg] = kwargs[kwarg]
+    def __init__(self, handle):
+        self.handle = handle
 
-        if padding == 0:
-            pad = None
+    def forward(self, x):
+        if self.handle.device_id == -1:
+            raise NotImplementedError
         else:
-            pad = padding
-
-        if dilation != 1 or return_indices or ceil_mode:
-            raise ValueError('Not implemented yet')
-
-        self.PyLayer = layer.Pooling2D(inner_params['name'],
-                                       model_pb2.PoolingConf.MAX,
-                                       kernel_size, stride, inner_params[
-                                           'border_mode'],
-                                       pad, inner_params['data_format'],
-                                       inner_params['input_sample_shape'])
+            y = singa.GpuPoolingForward(x, self.handle)
 
-    def __call__(self, x):
         if training:
-            self.flag = model_pb2.kTrain
-        else:
-            self.flag = model_pb2.kEval
-
-        if not self.PyLayer.has_setup:
-            self.PyLayer.setup(x.shape[1:])
+            self.cache = (x, y)
 
-        return self._do_forward(x)
-
-    def forward(self, *xs):
-        return self.PyLayer.layer.Forward(self.flag, xs[0])
+        return y
 
     def backward(self, dy):
-        return self.PyLayer.layer.Backward(0, dy)[0]
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        else:
+            dx = singa.GpuPoolingBackward(
+                dy, self.cache[0], self.cache[1], self.handle)
+        return dx
 
 
-def max_pool_2d(x, kernel_size=3, stride=1, padding=0, dilation=1,
-                return_indices=False, ceil_mode=False, **kwargs):
-    return MaxPool2d(kernel_size, stride, padding, dilation, return_indices,
-                     ceil_mode, **kwargs)(x)[0]
+def max_pool_2d(x, handle):
+    return _MaxPool2D(handle)(x)[0]
 
 
 class Flatten(Operation):
@@ -771,6 +717,9 @@ class Conv2D(Layer):
         return y
 
 
+<< << << < HEAD
+
+
 class BatchNorm(Layer):
 
     def __init__(self, num_features, momentum=0.9):
@@ -811,7 +760,7 @@ class BatchNorm(Layer):
         self.handle.device_id = x.device.id()
 
         y = batchnorm(x, self.scale, self.bias,
-                        self.running_mean, self.running_var, self.handle)
+                      self.running_mean, self.running_var, self.handle)
         return y
 
 
@@ -857,3 +806,72 @@ class _BatchNorm(Operation):
 
 def batchnorm(x, scale, bias, running_mean, running_var, handle):
     return _BatchNorm(running_mean, running_var, handle)(x, scale, bias)[0]
+
+
+class MaxPool2D(Layer):
+
+    def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
+                 return_indices=False, ceil_mode=False):
+        if isinstance(kernel_size, int):
+            self.kernel_size = (kernel_size, kernel_size)
+        elif isinstance(kernel_size, tuple):
+            self.kernel_size = kernel_size
+        else:
+            raise TypeError('Wrong kernel_size type.')
+
+        if stride is None:
+            self.stride = self.kernel_size
+        elif isinstance(stride, int):
+            self.stride = (stride, stride)
+        elif isinstance(stride, tuple):
+            self.stride = stride
+        else:
+            raise TypeError('Wrong stride type.')
+
+        if isinstance(padding, int):
+            self.padding = (padding, padding)
+        elif isinstance(padding, tuple):
+            self.padding = padding
+        else:
+            raise TypeError('Wrong padding type.')
+
+        if dilation != 1:
+            raise ValueError('Not implemented yet')
+
+        if return_indices is not False:
+            raise ValueError('Not implemented yet')
+
+        self.ceil_mode = ceil_mode
+
+    def __call__(self, x):
+        if self.ceil_mode:
+            out_shape_h = int(math.ceil(
+                (x.shape[2] + 2 * self.padding[0] - self.kernel_size[0]) / self.stride[0])) + 1
+            out_shape_w = int(math.ceil(
+                (x.shape[3] + 2 * self.padding[1] - self.kernel_size[1]) / self.stride[1])) + 1
+        else:
+            out_shape_h = int(
+                (x.shape[2] + 2 * self.padding[0] - self.kernel_size[0]) // self.stride[0]) + 1
+            out_shape_w = int(
+                (x.shape[3] + 2 * self.padding[1] - self.kernel_size[1]) // self.stride[1]) + 1
+        if x.device.id() == -1:
+            if not hasattr(self, 'handle'):
+                self.handle = singa.PoolingHandle(x.data, self.kernel_size, self.stride,
+                                                  self.padding, self.ceil_mode, 'MAX')
+            elif x.shape[0] != self.handle.batchsize or out_shape_h != self.handle.pooled_height or \
+                    out_shape_w != self.handle.pooled_width:
+                self.handle = singa.PoolingHandle(x.data, self.kernel_size, self.stride,
+                                                  self.padding, self.ceil_mode, 'MAX')
+        else:
+            if not hasattr(self, 'handle'):
+                self.handle = singa.CudnnPoolingHandle(x.data, self.kernel_size, self.stride,
+                                                       self.padding, self.ceil_mode, 'MAX', False)  # False for nan_prop
+            elif x.shape[0] != self.handle.batchsize or out_shape_h != self.handle.pooled_height or \
+                    out_shape_w != self.handle.pooled_width:
+                self.handle = singa.CudnnPoolingHandle(x.data, self.kernel_size, self.stride,
+                                                       self.padding, self.ceil_mode, 'MAX', False)  # False for nan_prop
+
+        self.handle.device_id = x.device.id()
+
+        y = max_pool_2d(x, self.handle)
+        return y

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/571818eb/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index eb41fd0..4800ff1 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -6,6 +6,8 @@
 %{
 #include "../src/model/operation/convolution.h"
 #include "../src/model/operation/batchnorm.h"
+#include "../src/model/operation/pooling.h"
+
 %}
 
 namespace singa {
@@ -29,7 +31,6 @@ Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
 Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch);
 
 
-
 class BatchNormHandle{
   public:
     BatchNormHandle(const float momentum, const Tensor& input);
@@ -38,6 +39,18 @@ class BatchNormHandle{
 };
 
 
+class PoolingHandle {
+ public:
+  PoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+                const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+                const bool ceil_mode = false, const std::string pooling_method = "MAX");
+
+  size_t batchsize;
+
+  size_t pooled_height;
+  size_t pooled_width;
+};
+
 
 #if USE_CUDNN
 class CudnnConvHandle: public ConvHandle {
@@ -60,8 +73,6 @@ Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
 Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle &cch);
 
 
-
-
 class CudnnBatchNormHandle: public BatchNormHandle{
     public:
       CudnnBatchNormHandle(const float momentum, const Tensor& input);
@@ -78,6 +89,25 @@ Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh, const Tens
 const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
   const Tensor& dy, const Tensor& x, const Tensor& bnScale, const Tensor& mean, const Tensor& var);
 
+
+class CudnnPoolingHandle : public PoolingHandle {
+ public:
+  CudnnPoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+                     const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+                     const bool ceil_mode = false, const std::string pooling_method = "MAX",
+                     const bool NaN_prop = false);
+
+  size_t batchsize;
+  
+  size_t pooled_height;
+  size_t pooled_width;
+};
+
+Tensor GpuPoolingForward(const Tensor &x, const CudnnPoolingHandle &cph);
+
+Tensor GpuPoolingBackward(const Tensor &dy, const Tensor& x, const Tensor& y,
+                          const CudnnPoolingHandle &cph);
+
 #endif  // USE_CUDNN
-}
 
+}  //namespace singa
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/571818eb/src/model/operation/pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/pooling.cc b/src/model/operation/pooling.cc
new file mode 100644
index 0000000..0abda35
--- /dev/null
+++ b/src/model/operation/pooling.cc
@@ -0,0 +1,126 @@
+#include "./pooling.h"
+#include <cmath>
+
+namespace singa {
+
+PoolingHandle::PoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+                             const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+                             const bool ceil_mode, const std::string pooling_method) {
+  kernel_h = kernel_size[0];
+  kernel_w = kernel_size[1];
+
+  pad_h = padding[0];
+  pad_w = padding[1];
+
+  stride_h = stride[0];
+  stride_w = stride[1];
+
+  batchsize = input.shape(0);
+  channels = input.shape(1);
+  height = input.shape(2);
+  width = input.shape(3);
+
+  pooled_height = 1;
+  if (ceil_mode) {
+    if (stride_h > 0)
+      pooled_height = static_cast<int>(ceil(static_cast<float>(height + 2 * pad_h - kernel_h) / stride_h)) + 1;
+    pooled_width = static_cast<int>(ceil(static_cast<float>(width + 2 * pad_w - kernel_w) / stride_w)) + 1;
+  }
+  else {
+    if (stride_h > 0)
+      pooled_height =
+        static_cast<size_t>((height + 2 * pad_h - kernel_h) / stride_h) + 1;
+    pooled_width =
+      static_cast<size_t>((width + 2 * pad_w - kernel_w) / stride_w) + 1;
+  }
+
+  method = pooling_method;
+  CHECK(method == "MAX" || method == "AVERAGE")
+      << "Padding implemented only for average and max pooling.";
+}
+
+#ifdef USE_CUDNN
+
+CudnnPoolingHandle::CudnnPoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+                                       const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+                                       const bool ceil_mode, const std::string pooling_method, const bool NaN_prop)
+  : PoolingHandle(input, kernel_size, stride, padding, ceil_mode, pooling_method) {
+  if (NaN_prop)
+    nan_prop = CUDNN_PROPAGATE_NAN;
+  else
+    nan_prop = CUDNN_NOT_PROPAGATE_NAN;
+
+  DataType dtype = input.data_type();
+
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc));
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc));
+  CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc));
+
+
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW,
+                                         GetCudnnDataType(dtype), batchsize,
+                                         channels, height, width));
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(
+                y_desc, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize, channels,
+                pooled_height, pooled_width));
+  auto pool_method = CUDNN_POOLING_MAX;
+  if (method == "MAX")
+    pool_method = CUDNN_POOLING_MAX;
+  else if (method == "AVERAGE")
+    pool_method = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
+  else
+    LOG(FATAL) << "Not implemented!";
+
+  CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc, pool_method, nan_prop,
+                                          kernel_h, kernel_w, pad_h, pad_w,
+                                          stride_h, stride_w));
+};
+
+CudnnPoolingHandle::~CudnnPoolingHandle() {
+  if (pool_desc != nullptr)
+    CUDNN_CHECK(cudnnDestroyPoolingDescriptor(pool_desc));
+  if (x_desc != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_desc));
+  if (y_desc != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc));
+};
+
+Tensor GpuPoolingForward(const Tensor &x, const CudnnPoolingHandle &cph) {
+  CHECK_EQ(x.device()->lang(), kCuda);
+  CHECK_EQ(x.nDim(), 4u);
+
+  DataType dtype = x.data_type();
+  auto dev = x.device();
+  Shape shape{cph.batchsize, cph.channels, cph.pooled_height, cph.pooled_width};
+  Tensor output = Tensor(shape, dev, dtype);
+
+  output.device()->Exec([&x, &output, &cph](Context * ctx) {
+    Block *inblock = x.block(), *outblock = output.block();
+    float alpha = 1.0f, beta = 0.0f;
+    cudnnPoolingForward(ctx->cudnn_handle, cph.pool_desc, &alpha,
+                        cph.x_desc, inblock->data(), &beta, cph.y_desc,
+                        outblock->mutable_data());
+  }, {x.block()}, {output.block()});
+  return output;
+};
+
+Tensor GpuPoolingBackward(const Tensor &dy, const Tensor& x, const Tensor& y,
+                          const CudnnPoolingHandle &cph) {
+  CHECK_EQ(dy.device()->lang(), kCuda);
+  CHECK_EQ(dy.nDim(), 4u);
+
+  Tensor dx;
+  dx.ResetLike(x);
+
+  dx.device()->Exec([&dx, &dy, &x, &y, &cph](Context * ctx) {
+    Block *dyblock = dy.block(), *dxblock = dx.block(), *yblock = y.block(),
+           *xblock = x.block();
+    float alpha = 1.0f, beta = 0.0f;
+    cudnnPoolingBackward(ctx->cudnn_handle, cph.pool_desc, &alpha,
+                         cph.y_desc, yblock->data(), cph.y_desc,
+                         dyblock->data(), cph.x_desc, xblock->data(), &beta,
+                         cph.x_desc, dxblock->mutable_data());
+  }, {dy.block(), y.block(), x.block()}, {dx.block()});
+  return dx;
+};
+#endif  //USE_CUDNN
+
+}  //namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/571818eb/src/model/operation/pooling.h
----------------------------------------------------------------------
diff --git a/src/model/operation/pooling.h b/src/model/operation/pooling.h
new file mode 100644
index 0000000..9ed7e33
--- /dev/null
+++ b/src/model/operation/pooling.h
@@ -0,0 +1,63 @@
+#ifndef SINGA_MODEL_OPERATION_POOLING_H_
+#define SINGA_MODEL_OPERATION_POOLING_H_
+
+#include <string>
+#include "singa/core/tensor.h"
+
+#ifdef USE_CUDNN
+#include <cudnn.h>
+#include "../layer/cudnn_utils.h"
+#endif
+
+namespace singa {
+
+class PoolingHandle {
+public:
+  PoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+                const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+                const bool ceil_mode = false, const std::string pooling_method = "MAX");
+
+  size_t kernel_w;
+  size_t pad_w;
+  size_t stride_w;
+  size_t kernel_h;
+  size_t pad_h;
+  size_t stride_h;
+
+  size_t batchsize;
+  size_t channels;
+  size_t height;
+  size_t width;
+
+  size_t pooled_height;
+  size_t pooled_width;
+
+  std::string method;
+};
+
+#ifdef USE_CUDNN
+class CudnnPoolingHandle : public PoolingHandle {
+public:
+  CudnnPoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+                     const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+                     const bool ceil_mode = false, const std::string pooling_method = "MAX",
+                     const bool NaN_prop = false);
+  ~CudnnPoolingHandle();
+
+  cudnnTensorDescriptor_t x_desc = nullptr;
+  cudnnTensorDescriptor_t y_desc = nullptr;
+  cudnnPoolingDescriptor_t pool_desc = nullptr;
+  cudnnNanPropagation_t nan_prop;
+
+};
+
+Tensor GpuPoolingForward(const Tensor &x, const CudnnPoolingHandle &cph);
+
+Tensor GpuPoolingBackward(const Tensor &dy, const Tensor& x, const Tensor& y,
+                          const CudnnPoolingHandle &cph);
+
+#endif  //USE_CUDNN
+
+}  // namespace singa
+
+#endif  // SINGA_MODEL_OPERATION_POOLING_H_
\ No newline at end of file


[2/3] incubator-singa git commit: SINGA-378 Implement maxpooling operation and its related functions for autograd

Posted by ka...@apache.org.
SINGA-378 Implement maxpooling operation and its related functions for autograd

Update the API of the pooling functions.

Add MaxPooling2D, AvgPooling2D, MaxPooling1D and AvgPooling1D.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/fb5cb9ab
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/fb5cb9ab
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/fb5cb9ab

Branch: refs/heads/master
Commit: fb5cb9ab000d776eed11a5f4fd3b0e7285a109c0
Parents: 571818e
Author: Wang Wei <wa...@gmail.com>
Authored: Thu Jul 12 17:53:22 2018 +0800
Committer: Wang Wei <wa...@gmail.com>
Committed: Thu Jul 12 17:53:22 2018 +0800

----------------------------------------------------------------------
 examples/autograd/mnist_cnn.py |   6 +-
 python/singa/autograd.py       | 227 +++++++++++++++++++-----------------
 src/api/model_operation.i      |  10 +-
 src/model/operation/pooling.cc |  76 +++++-------
 src/model/operation/pooling.h  |  19 ++-
 5 files changed, 169 insertions(+), 169 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/fb5cb9ab/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index 2cb3cae..d42dc76 100755
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -110,8 +110,8 @@ if __name__ == '__main__':
     conv2 = autograd.Conv2D(32, 32, 3, padding=1)
     bn2 = autograd.BatchNorm(32)
     linear = autograd.Linear(32 * 28 * 28, 10)
-    pooling1 = autograd.MaxPool2D(3, 1, padding=1)
-    pooling2 = autograd.MaxPool2D(3, 1, padding=1)
+    pooling1 = autograd.MaxPooling2D(3, 1, padding=1)
+    pooling2 = autograd.AvgPooling2D(3, 1, padding=1)
 
     def forward(x, t):
         y = conv1(x)
@@ -130,7 +130,7 @@ if __name__ == '__main__':
         return loss, y
 
     autograd.training = True
-    for epoch in range(50):
+    for epoch in range(epochs):
         for i in range(batch_number):
             inputs = tensor.Tensor(device=dev, data=x_train[
                                    i * 100:(1 + i) * 100], stores_grad=False)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/fb5cb9ab/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index fcdc020..7b4d18d 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -483,35 +483,6 @@ def ctensor2numpy(x):
     return np_array.reshape(x.shape())
 
 
-class _MaxPool2D(Operation):
-
-    def __init__(self, handle):
-        self.handle = handle
-
-    def forward(self, x):
-        if self.handle.device_id == -1:
-            raise NotImplementedError
-        else:
-            y = singa.GpuPoolingForward(x, self.handle)
-
-        if training:
-            self.cache = (x, y)
-
-        return y
-
-    def backward(self, dy):
-        if self.handle.device_id == -1:
-            raise NotImplementedError
-        else:
-            dx = singa.GpuPoolingBackward(
-                dy, self.cache[0], self.cache[1], self.handle)
-        return dx
-
-
-def max_pool_2d(x, handle):
-    return _MaxPool2D(handle)(x)[0]
-
-
 class Flatten(Operation):
 
     def __init(self, start_axis=1):
@@ -534,6 +505,46 @@ def flatten(x):
     return Flatten()(x)[0]
 
 
+class Layer(object):
+
+    def __init__(self):
+        pass
+
+    def device_check(self, *inputs):
+        x_device = inputs[0].device
+        for var in inputs:
+            if var.device.id() != x_device:
+                var.to_device(x_device)
+
+
+class Linear(Layer):
+
+    def __init__(self, in_features, out_features, bias=True):
+        w_shape = (in_features, out_features)
+        b_shape = (1, out_features)
+        self.bias = bias
+
+        self.W = Tensor(shape=w_shape,
+                        requires_grad=True, stores_grad=True)
+        std = math.sqrt(2.0 / (in_features + out_features))
+        self.W.gaussian(0.0, std)
+
+        if self.bias:
+            self.b = Tensor(shape=b_shape,
+                            requires_grad=True, stores_grad=True)
+            self.b.set_value(0.0)
+
+    def __call__(self, x):
+        if self.bias:
+            self.device_check(x, self.W, self.b)
+        else:
+            self.device_check(x, self.W)
+        y = matmul(x, self.W)
+        if self.bias:
+            y = add_bias(y, self.b, axis=0)
+        return y
+
+
 class _Conv2D(Operation):
 
     def __init__(self, handle):
@@ -583,50 +594,10 @@ class _Conv2D(Operation):
                 return dx, dW, None
 
 
-def conv2d(x, W, b, handle):
+def conv2d(handle, x, W, b):
     return _Conv2D(handle)(x, W, b)[0]
 
 
-class Layer(object):
-
-    def __init__(self):
-        pass
-
-    def device_check(self, *inputs):
-        x_device = inputs[0].device
-        for var in inputs:
-            if var.device.id() != x_device:
-                var.to_device(x_device)
-
-
-class Linear(Layer):
-
-    def __init__(self, in_features, out_features, bias=True):
-        w_shape = (in_features, out_features)
-        b_shape = (1, out_features)
-        self.bias = bias
-
-        self.W = Tensor(shape=w_shape,
-                        requires_grad=True, stores_grad=True)
-        std = math.sqrt(2.0 / (in_features + out_features))
-        self.W.gaussian(0.0, std)
-
-        if self.bias:
-            self.b = Tensor(shape=b_shape,
-                            requires_grad=True, stores_grad=True)
-            self.b.set_value(0.0)
-
-    def __call__(self, x):
-        if self.bias:
-            self.device_check(x, self.W, self.b)
-        else:
-            self.device_check(x, self.W)
-        y = matmul(x, self.W)
-        if self.bias:
-            y = add_bias(y, self.b, axis=0)
-        return y
-
-
 class Conv2D(Layer):
 
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
@@ -713,13 +684,10 @@ class Conv2D(Layer):
                                                     self.padding, self.in_channels, self.out_channels, self.bias)
         self.handle.device_id = x.device.id()
 
-        y = conv2d(x, self.W, self.b, self.handle)
+        y = conv2d(self.handle, x, self.W, self.b)
         return y
 
 
-<< << << < HEAD
-
-
 class BatchNorm(Layer):
 
     def __init__(self, num_features, momentum=0.9):
@@ -759,14 +727,14 @@ class BatchNorm(Layer):
                     self.momentum, x.data)
         self.handle.device_id = x.device.id()
 
-        y = batchnorm(x, self.scale, self.bias,
-                      self.running_mean, self.running_var, self.handle)
+        y = batchnorm(self.handle, x, self.scale, self.bias,
+                      self.running_mean, self.running_var)
         return y
 
 
 class _BatchNorm(Operation):
 
-    def __init__(self, running_mean, running_var, handle):
+    def __init__(self, handle, running_mean, running_var):
         self.running_mean = running_mean.data
         self.running_var = running_var.data
         self.handle = handle
@@ -804,14 +772,42 @@ class _BatchNorm(Operation):
             return dx, ds, db
 
 
-def batchnorm(x, scale, bias, running_mean, running_var, handle):
-    return _BatchNorm(running_mean, running_var, handle)(x, scale, bias)[0]
+def batchnorm(handle, x, scale, bias, running_mean, running_var):
+    return _BatchNorm(handle, running_mean, running_var, handle)(x, scale, bias)[0]
+
+
+class _Pooling2D(Operation):
+
+    def __init__(self, handle):
+        self.handle = handle
+
+    def forward(self, x):
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        else:
+            y = singa.GpuPoolingForward(self.handle, x)
+
+        if training:
+            self.cache = (x, y)
 
+        return y
+
+    def backward(self, dy):
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        else:
+            dx = singa.GpuPoolingBackward(self.handle,
+                                          dy, self.cache[0], self.cache[1])
+        return dx
+
+
+def pooling_2d(handle, x):
+    return _Pooling2D(handle)(x)[0]
 
-class MaxPool2D(Layer):
 
-    def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
-                 return_indices=False, ceil_mode=False):
+class Pooling2D(Layer):
+
+    def __init__(self, kernel_size, stride=None, padding=0, is_max=True):
         if isinstance(kernel_size, int):
             self.kernel_size = (kernel_size, kernel_size)
         elif isinstance(kernel_size, tuple):
@@ -825,6 +821,8 @@ class MaxPool2D(Layer):
             self.stride = (stride, stride)
         elif isinstance(stride, tuple):
             self.stride = stride
+            assert stride[0] > 0 or (kernel_size[0] == 1 and padding[
+                0] == 0), 'stride[0]=0, but kernel_size[0]=%d, padding[0]=%d' % (kernel_size[0], padding[0])
         else:
             raise TypeError('Wrong stride type.')
 
@@ -835,43 +833,62 @@ class MaxPool2D(Layer):
         else:
             raise TypeError('Wrong padding type.')
 
-        if dilation != 1:
-            raise ValueError('Not implemented yet')
-
-        if return_indices is not False:
-            raise ValueError('Not implemented yet')
-
-        self.ceil_mode = ceil_mode
+        self.is_max = is_max
 
     def __call__(self, x):
-        if self.ceil_mode:
-            out_shape_h = int(math.ceil(
-                (x.shape[2] + 2 * self.padding[0] - self.kernel_size[0]) / self.stride[0])) + 1
-            out_shape_w = int(math.ceil(
-                (x.shape[3] + 2 * self.padding[1] - self.kernel_size[1]) / self.stride[1])) + 1
-        else:
-            out_shape_h = int(
-                (x.shape[2] + 2 * self.padding[0] - self.kernel_size[0]) // self.stride[0]) + 1
-            out_shape_w = int(
-                (x.shape[3] + 2 * self.padding[1] - self.kernel_size[1]) // self.stride[1]) + 1
+
+        out_shape_h = int(
+            (x.shape[2] + 2 * self.padding[0] - self.kernel_size[0]) // self.stride[0]) + 1
+        out_shape_w = int(
+            (x.shape[3] + 2 * self.padding[1] - self.kernel_size[1]) // self.stride[1]) + 1
         if x.device.id() == -1:
             if not hasattr(self, 'handle'):
-                self.handle = singa.PoolingHandle(x.data, self.kernel_size, self.stride,
-                                                  self.padding, self.ceil_mode, 'MAX')
+                self.handle = singa.PoolingHandle(
+                    x.data, self.kernel_size, self.stride, self.padding, self.is_max)
             elif x.shape[0] != self.handle.batchsize or out_shape_h != self.handle.pooled_height or \
                     out_shape_w != self.handle.pooled_width:
                 self.handle = singa.PoolingHandle(x.data, self.kernel_size, self.stride,
-                                                  self.padding, self.ceil_mode, 'MAX')
+                                                  self.padding, self.is_max)
         else:
             if not hasattr(self, 'handle'):
                 self.handle = singa.CudnnPoolingHandle(x.data, self.kernel_size, self.stride,
-                                                       self.padding, self.ceil_mode, 'MAX', False)  # False for nan_prop
+                                                       self.padding, self.is_max)  # False for nan_prop
             elif x.shape[0] != self.handle.batchsize or out_shape_h != self.handle.pooled_height or \
                     out_shape_w != self.handle.pooled_width:
                 self.handle = singa.CudnnPoolingHandle(x.data, self.kernel_size, self.stride,
-                                                       self.padding, self.ceil_mode, 'MAX', False)  # False for nan_prop
+                                                       self.padding, self.is_max)  # False for nan_prop
 
         self.handle.device_id = x.device.id()
 
-        y = max_pool_2d(x, self.handle)
+        y = pooling_2d(self.handle, x)
         return y
+
+
+class MaxPooling2D(Pooling2D):
+
+    def __init__(self, kernel_size, stride=None, padding=0):
+        super(MaxPooling2D, self).__init__(kernel_size, stride, padding, True)
+
+
+class AvgPooling2D(Pooling2D):
+
+    def __init__(self, kernel_size, stride=None, padding=0):
+        super(AvgPooling2D, self).__init__(kernel_size, stride, padding, False)
+
+
+class MaxPooling1D(Pooling2D):
+
+    def __init__(self, kernel_size, stride=None, padding=0):
+        if stride is None:
+            stride = kernel_size
+        super(MaxPooling2D, self).__init__(
+            (1, kernel_size), (0, stride), (0, padding), True)
+
+
+class AvgPooling1D(Pooling2D):
+
+    def __init__(self, kernel_size, stride=None, padding=0):
+        if stride is None:
+            stride = kernel_size
+        super(MaxPooling2D, self).__init__(
+            (1, kernel_size), (0, stride), (0, padding), False)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/fb5cb9ab/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 4800ff1..3d9bdbe 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -43,7 +43,7 @@ class PoolingHandle {
  public:
   PoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
                 const std::vector<size_t>& stride, const std::vector<size_t>& padding,
-                const bool ceil_mode = false, const std::string pooling_method = "MAX");
+                const bool is_max=true);
 
   size_t batchsize;
 
@@ -94,8 +94,7 @@ class CudnnPoolingHandle : public PoolingHandle {
  public:
   CudnnPoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
                      const std::vector<size_t>& stride, const std::vector<size_t>& padding,
-                     const bool ceil_mode = false, const std::string pooling_method = "MAX",
-                     const bool NaN_prop = false);
+                     const bool is_max=true);
 
   size_t batchsize;
   
@@ -103,10 +102,9 @@ class CudnnPoolingHandle : public PoolingHandle {
   size_t pooled_width;
 };
 
-Tensor GpuPoolingForward(const Tensor &x, const CudnnPoolingHandle &cph);
+Tensor GpuPoolingForward(const CudnnPoolingHandle &cph, const Tensor &x);
 
-Tensor GpuPoolingBackward(const Tensor &dy, const Tensor& x, const Tensor& y,
-                          const CudnnPoolingHandle &cph);
+Tensor GpuPoolingBackward(const CudnnPoolingHandle &cph, const Tensor &dy, const Tensor& x, const Tensor& y);
 
 #endif  // USE_CUDNN
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/fb5cb9ab/src/model/operation/pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/pooling.cc b/src/model/operation/pooling.cc
index 0abda35..0072671 100644
--- a/src/model/operation/pooling.cc
+++ b/src/model/operation/pooling.cc
@@ -3,9 +3,10 @@
 
 namespace singa {
 
-PoolingHandle::PoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+PoolingHandle::PoolingHandle(const Tensor &input,
+                             const std::vector<size_t>& kernel_size,
                              const std::vector<size_t>& stride, const std::vector<size_t>& padding,
-                             const bool ceil_mode, const std::string pooling_method) {
+                             const bool is_max) {
   kernel_h = kernel_size[0];
   kernel_w = kernel_size[1];
 
@@ -21,34 +22,24 @@ PoolingHandle::PoolingHandle(const Tensor &input, const std::vector<size_t>& ker
   width = input.shape(3);
 
   pooled_height = 1;
-  if (ceil_mode) {
-    if (stride_h > 0)
-      pooled_height = static_cast<int>(ceil(static_cast<float>(height + 2 * pad_h - kernel_h) / stride_h)) + 1;
-    pooled_width = static_cast<int>(ceil(static_cast<float>(width + 2 * pad_w - kernel_w) / stride_w)) + 1;
-  }
-  else {
-    if (stride_h > 0)
-      pooled_height =
-        static_cast<size_t>((height + 2 * pad_h - kernel_h) / stride_h) + 1;
-    pooled_width =
-      static_cast<size_t>((width + 2 * pad_w - kernel_w) / stride_w) + 1;
-  }
-
-  method = pooling_method;
-  CHECK(method == "MAX" || method == "AVERAGE")
-      << "Padding implemented only for average and max pooling.";
+
+  if (stride_h > 0)
+    pooled_height =
+      static_cast<size_t>((height + 2 * pad_h - kernel_h) / stride_h) + 1;
+  pooled_width =
+    static_cast<size_t>((width + 2 * pad_w - kernel_w) / stride_w) + 1;
+  is_max_pooling = is_max;
 }
 
 #ifdef USE_CUDNN
 
-CudnnPoolingHandle::CudnnPoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+CudnnPoolingHandle::CudnnPoolingHandle(const Tensor &input,
+                                       const std::vector<size_t>& kernel_size,
                                        const std::vector<size_t>& stride, const std::vector<size_t>& padding,
-                                       const bool ceil_mode, const std::string pooling_method, const bool NaN_prop)
-  : PoolingHandle(input, kernel_size, stride, padding, ceil_mode, pooling_method) {
-  if (NaN_prop)
-    nan_prop = CUDNN_PROPAGATE_NAN;
-  else
-    nan_prop = CUDNN_NOT_PROPAGATE_NAN;
+                                       const bool is_max)
+  : PoolingHandle(input, kernel_size, stride, padding, is_max) {
+
+#nan_prop = CUDNN_NOT_PROPAGATE_NAN;
 
   DataType dtype = input.data_type();
 
@@ -64,12 +55,10 @@ CudnnPoolingHandle::CudnnPoolingHandle(const Tensor &input, const std::vector<si
                 y_desc, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize, channels,
                 pooled_height, pooled_width));
   auto pool_method = CUDNN_POOLING_MAX;
-  if (method == "MAX")
+  if (is_max)
     pool_method = CUDNN_POOLING_MAX;
-  else if (method == "AVERAGE")
-    pool_method = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
   else
-    LOG(FATAL) << "Not implemented!";
+    pool_method = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
 
   CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc, pool_method, nan_prop,
                                           kernel_h, kernel_w, pad_h, pad_w,
@@ -81,26 +70,24 @@ CudnnPoolingHandle::~CudnnPoolingHandle() {
     CUDNN_CHECK(cudnnDestroyPoolingDescriptor(pool_desc));
   if (x_desc != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_desc));
   if (y_desc != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc));
-};
+}
+
 
 Tensor GpuPoolingForward(const Tensor &x, const CudnnPoolingHandle &cph) {
   CHECK_EQ(x.device()->lang(), kCuda);
   CHECK_EQ(x.nDim(), 4u);
 
-  DataType dtype = x.data_type();
-  auto dev = x.device();
-  Shape shape{cph.batchsize, cph.channels, cph.pooled_height, cph.pooled_width};
-  Tensor output = Tensor(shape, dev, dtype);
+  Tensor output = Tensor({cph.batchsize, cph.channels, cph.pooled_height, cph.pooled_width},
+                         x.device(), x.data_type());
 
-  output.device()->Exec([&x, &output, &cph](Context * ctx) {
-    Block *inblock = x.block(), *outblock = output.block();
+  output.device()->Exec([&](Context * ctx) {
     float alpha = 1.0f, beta = 0.0f;
     cudnnPoolingForward(ctx->cudnn_handle, cph.pool_desc, &alpha,
-                        cph.x_desc, inblock->data(), &beta, cph.y_desc,
-                        outblock->mutable_data());
+                        cph.x_desc, x.block()->data(), &beta, cph.y_desc,
+                        output.block()->mutable_data());
   }, {x.block()}, {output.block()});
   return output;
-};
+}
 
 Tensor GpuPoolingBackward(const Tensor &dy, const Tensor& x, const Tensor& y,
                           const CudnnPoolingHandle &cph) {
@@ -110,14 +97,13 @@ Tensor GpuPoolingBackward(const Tensor &dy, const Tensor& x, const Tensor& y,
   Tensor dx;
   dx.ResetLike(x);
 
-  dx.device()->Exec([&dx, &dy, &x, &y, &cph](Context * ctx) {
-    Block *dyblock = dy.block(), *dxblock = dx.block(), *yblock = y.block(),
-           *xblock = x.block();
+  dx.device()->Exec([&](Context * ctx) {
+
     float alpha = 1.0f, beta = 0.0f;
     cudnnPoolingBackward(ctx->cudnn_handle, cph.pool_desc, &alpha,
-                         cph.y_desc, yblock->data(), cph.y_desc,
-                         dyblock->data(), cph.x_desc, xblock->data(), &beta,
-                         cph.x_desc, dxblock->mutable_data());
+                         cph.y_desc, y.block()->data(), cph.y_desc,
+                         dy.block()->data(), cph.x_desc, x.block()->data(), &beta,
+                         cph.x_desc, dx.block()->mutable_data());
   }, {dy.block(), y.block(), x.block()}, {dx.block()});
   return dx;
 };

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/fb5cb9ab/src/model/operation/pooling.h
----------------------------------------------------------------------
diff --git a/src/model/operation/pooling.h b/src/model/operation/pooling.h
index 9ed7e33..a4d1051 100644
--- a/src/model/operation/pooling.h
+++ b/src/model/operation/pooling.h
@@ -12,10 +12,10 @@
 namespace singa {
 
 class PoolingHandle {
-public:
+ public:
   PoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
                 const std::vector<size_t>& stride, const std::vector<size_t>& padding,
-                const bool ceil_mode = false, const std::string pooling_method = "MAX");
+                const bool is_max = true);
 
   size_t kernel_w;
   size_t pad_w;
@@ -32,29 +32,28 @@ public:
   size_t pooled_height;
   size_t pooled_width;
 
-  std::string method;
+  bool is_max_pooling;
 };
 
 #ifdef USE_CUDNN
 class CudnnPoolingHandle : public PoolingHandle {
-public:
+ public:
   CudnnPoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
                      const std::vector<size_t>& stride, const std::vector<size_t>& padding,
-                     const bool ceil_mode = false, const std::string pooling_method = "MAX",
-                     const bool NaN_prop = false);
+                     const bool is_max = true);
   ~CudnnPoolingHandle();
 
   cudnnTensorDescriptor_t x_desc = nullptr;
   cudnnTensorDescriptor_t y_desc = nullptr;
   cudnnPoolingDescriptor_t pool_desc = nullptr;
-  cudnnNanPropagation_t nan_prop;
+  cudnnNanPropagation_t nan_prop = CUDNN_PROPAGATE_NAN;
 
 };
 
-Tensor GpuPoolingForward(const Tensor &x, const CudnnPoolingHandle &cph);
+Tensor GpuPoolingForward(const CudnnPoolingHandle &cph, const Tensor &x);
 
-Tensor GpuPoolingBackward(const Tensor &dy, const Tensor& x, const Tensor& y,
-                          const CudnnPoolingHandle &cph);
+Tensor GpuPoolingBackward(const CudnnPoolingHandle &cph, const Tensor &dy,
+                          const Tensor& x, const Tensor& y);
 
 #endif  //USE_CUDNN
 


[3/3] incubator-singa git commit: SINGA-378 Implement maxpooling operation and its related functions for autograd

Posted by ka...@apache.org.
SINGA-378 Implement maxpooling operation and its related functions for autograd

- Fix some bugs and test the code (the mnist_cnn.py example now runs correctly)


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/a3629182
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/a3629182
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/a3629182

Branch: refs/heads/master
Commit: a36291824bdfd99b907adc68b5fc206c9053bdc8
Parents: fb5cb9a
Author: xuewanqi <xu...@outlook.com>
Authored: Thu Jul 12 11:10:49 2018 +0000
Committer: xuewanqi <xu...@outlook.com>
Committed: Thu Jul 12 11:10:49 2018 +0000

----------------------------------------------------------------------
 examples/autograd/mnist_cnn.py | 1 -
 python/singa/autograd.py       | 2 +-
 src/model/operation/pooling.cc | 8 ++++----
 3 files changed, 5 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a3629182/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index d42dc76..92fc43a 100755
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -117,7 +117,6 @@ if __name__ == '__main__':
         y = conv1(x)
         y = autograd.relu(y)
         y = bn1(y)
-        y = autograd.max_pool_2d(y)
         y = pooling1(y)
 
         y = conv2(y)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a3629182/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 7b4d18d..16d7f82 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -773,7 +773,7 @@ class _BatchNorm(Operation):
 
 
 def batchnorm(handle, x, scale, bias, running_mean, running_var):
-    return _BatchNorm(handle, running_mean, running_var, handle)(x, scale, bias)[0]
+    return _BatchNorm(handle, running_mean, running_var)(x, scale, bias)[0]
 
 
 class _Pooling2D(Operation):

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a3629182/src/model/operation/pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/pooling.cc b/src/model/operation/pooling.cc
old mode 100644
new mode 100755
index 0072671..03ff804
--- a/src/model/operation/pooling.cc
+++ b/src/model/operation/pooling.cc
@@ -39,7 +39,7 @@ CudnnPoolingHandle::CudnnPoolingHandle(const Tensor &input,
                                        const bool is_max)
   : PoolingHandle(input, kernel_size, stride, padding, is_max) {
 
-#nan_prop = CUDNN_NOT_PROPAGATE_NAN;
+//nan_prop = CUDNN_NOT_PROPAGATE_NAN;
 
   DataType dtype = input.data_type();
 
@@ -73,7 +73,7 @@ CudnnPoolingHandle::~CudnnPoolingHandle() {
 }
 
 
-Tensor GpuPoolingForward(const Tensor &x, const CudnnPoolingHandle &cph) {
+Tensor GpuPoolingForward(const CudnnPoolingHandle &cph, const Tensor &x) {
   CHECK_EQ(x.device()->lang(), kCuda);
   CHECK_EQ(x.nDim(), 4u);
 
@@ -89,8 +89,8 @@ Tensor GpuPoolingForward(const Tensor &x, const CudnnPoolingHandle &cph) {
   return output;
 }
 
-Tensor GpuPoolingBackward(const Tensor &dy, const Tensor& x, const Tensor& y,
-                          const CudnnPoolingHandle &cph) {
+Tensor GpuPoolingBackward(const CudnnPoolingHandle &cph, const Tensor &dy,
+                          const Tensor& x, const Tensor& y) {
   CHECK_EQ(dy.device()->lang(), kCuda);
   CHECK_EQ(dy.nDim(), 4u);