You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2018/07/16 03:13:27 UTC

[1/4] incubator-singa git commit: SINGA-384 Implement ResNet using autograd API

Repository: incubator-singa
Updated Branches:
  refs/heads/master 76779be72 -> 870c5df0b


SINGA-384 Implement ResNet using autograd API

Add ResNet as an example of autograd.

Rename autograd operations to be consistent with torch

Pass the inference of resnet


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/2b5c3f70
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/2b5c3f70
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/2b5c3f70

Branch: refs/heads/master
Commit: 2b5c3f709ee2c0530f4a97ea26a34f55bff36c6e
Parents: 76779be
Author: Wang Wei <wa...@gmail.com>
Authored: Fri Jul 13 16:06:32 2018 +0800
Committer: Wang Wei <wa...@gmail.com>
Committed: Mon Jul 16 10:04:13 2018 +0800

----------------------------------------------------------------------
 examples/autograd/resnet.py      | 226 ++++++++++++++++++++++++++++++++++
 python/singa/autograd.py         |  38 +++---
 src/model/operation/batchnorm.cc |   1 -
 3 files changed, 243 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2b5c3f70/examples/autograd/resnet.py
----------------------------------------------------------------------
diff --git a/examples/autograd/resnet.py b/examples/autograd/resnet.py
new file mode 100644
index 0000000..930d9e0
--- /dev/null
+++ b/examples/autograd/resnet.py
@@ -0,0 +1,226 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+# the code is modified from
+# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+
+from singa import autograd
+from singa import tensor
+from singa import device
+
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152']
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return autograd.Conv2D(in_planes, out_planes, kernel_size=3, stride=stride,
+                           padding=1, bias=False)
+
+
+class BasicBlock(autograd.Layer):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = autograd.BatchNorm2d(planes)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = autograd.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def __call__(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = autograd.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = autograd.relu(out)
+
+        return out
+
+
+class Bottleneck(autograd.Layer):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = autograd.Conv2D(
+            inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = autograd.BatchNorm2d(planes)
+        self.conv2 = autograd.Conv2D(planes, planes, kernel_size=3, stride=stride,
+                                     padding=1, bias=False)
+        self.bn2 = autograd.BatchNorm2d(planes)
+        self.conv3 = autograd.Conv2D(
+            planes, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = autograd.BatchNorm2d(planes * self.expansion)
+
+        self.downsample = downsample
+        self.stride = stride
+
+    def __call__(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = autograd.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = autograd.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = autograd.relu(out)
+
+        return out
+
+
+class ResNet(autograd.Layer):
+
+    def __init__(self, block, layers, num_classes=1000):
+        self.inplanes = 64
+        super(ResNet, self).__init__()
+        self.conv1 = autograd.Conv2D(3, 64, kernel_size=7, stride=2, padding=3,
+                                     bias=False)
+        self.bn1 = autograd.BatchNorm2d(64)
+        self.maxpool = autograd.MaxPool2d(
+            kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = autograd.AvgPool2d(7, stride=1)
+        self.fc = autograd.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            conv = autograd.Conv2D(self.inplanes, planes * block.expansion,
+                                   kernel_size=1, stride=stride, bias=False)
+            bn = autograd.BatchNorm2d(planes * block.expansion)
+            downsample = lambda x: bn(conv(x))
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        def forward(x):
+            for layer in layers:
+                x = layer(x)
+            return x
+        return forward
+
+    def __call__(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = autograd.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = autograd.flatten(x)
+        x = self.fc(x)
+
+        return x
+
+
+def resnet18(pretrained=False, **kwargs):
+    """Constructs a ResNet-18 model.
+
+    Args:
+        pretrained (bool): currently ignored; no pre-trained weights are loaded
+    """
+    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+
+    return model
+
+
+def resnet34(pretrained=False, **kwargs):
+    """Constructs a ResNet-34 model.
+
+    Args:
+        pretrained (bool): currently ignored; no pre-trained weights are loaded
+    """
+    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+
+    return model
+
+
+def resnet50(pretrained=False, **kwargs):
+    """Constructs a ResNet-50 model.
+
+    Args:
+        pretrained (bool): currently ignored; no pre-trained weights are loaded
+    """
+    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+
+    return model
+
+
+def resnet101(pretrained=False, **kwargs):
+    """Constructs a ResNet-101 model.
+
+    Args:
+        pretrained (bool): currently ignored; no pre-trained weights are loaded
+    """
+    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+
+    return model
+
+
+def resnet152(pretrained=False, **kwargs):
+    """Constructs a ResNet-152 model.
+
+    Args:
+        pretrained (bool): currently ignored; no pre-trained weights are loaded
+    """
+    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+
+    return model
+
+
+if __name__ == '__main__':
+
+    model = resnet18()
+    x = tensor.Tensor((16, 3, 224, 224), device.create_cuda_gpu())
+    x.set_value(float(0.1))
+    autograd.training = True
+    y = model(x)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2b5c3f70/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index faa9685..c77c174 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -572,12 +572,12 @@ class Concat(Operation):
         return tuple(dxs)
 
 
-def concat(xs, axis=0):
+def cat(xs, axis=0):
     # xs is a tuple of multiple Tensors
     return Concat(axis)(*xs)[0]
 
 
-class _Conv2D(Operation):
+class _Conv2d(Operation):
 
     def __init__(self, handle):
         self.handle = handle
@@ -627,10 +627,10 @@ class _Conv2D(Operation):
 
 
 def conv2d(handle, x, W, b):
-    return _Conv2D(handle)(x, W, b)[0]
+    return _Conv2d(handle)(x, W, b)[0]
 
 
-class Conv2D(Layer):
+class Conv2d(Layer):
 
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True, **kwargs):
@@ -693,10 +693,6 @@ class Conv2D(Layer):
 
     def __call__(self, x):
         assert x.shape[1] == self.in_channels, 'in_channels dismatched'
-        assert (x.shape[2] + 2 * self.padding[0] - self.kernel_size[0]
-                ) % self.stride[0] == 0, 'invalid padding or strides.'
-        assert (x.shape[3] + 2 * self.padding[1] - self.kernel_size[1]
-                ) % self.stride[1] == 0, 'invalid padding or stride.'
 
         self.device_check(x, self.W, self.b)
 
@@ -720,7 +716,7 @@ class Conv2D(Layer):
         return y
 
 
-class BatchNorm(Layer):
+class BatchNorm2d(Layer):
 
     def __init__(self, num_features, momentum=0.9):
         self.channels = num_features
@@ -765,7 +761,7 @@ class BatchNorm(Layer):
         return y
 
 
-class _BatchNorm(Operation):
+class _BatchNorm2d(Operation):
 
     def __init__(self, handle, running_mean, running_var):
         self.running_mean = running_mean.data
@@ -805,11 +801,11 @@ class _BatchNorm(Operation):
             return dx, ds, db
 
 
-def batchnorm(handle, x, scale, bias, running_mean, running_var):
-    return _BatchNorm(handle, running_mean, running_var)(x, scale, bias)[0]
+def batchnorm_2d(handle, x, scale, bias, running_mean, running_var):
+    return _BatchNorm2d(handle, running_mean, running_var)(x, scale, bias)[0]
 
 
-class _Pooling2D(Operation):
+class _Pooling2d(Operation):
 
     def __init__(self, handle):
         self.handle = handle
@@ -838,7 +834,7 @@ def pooling_2d(handle, x):
-    return _Pooling2D(handle)(x)[0]
+    return _Pooling2d(handle)(x)[0]
 
 
-class Pooling2D(Layer):
+class Pooling2d(Layer):
 
     def __init__(self, kernel_size, stride=None, padding=0, is_max=True):
         if isinstance(kernel_size, int):
@@ -897,31 +893,31 @@ class Pooling2D(Layer):
         return y
 
 
-class MaxPooling2D(Pooling2D):
+class MaxPool2d(Pooling2d):
 
     def __init__(self, kernel_size, stride=None, padding=0):
-        super(MaxPooling2D, self).__init__(kernel_size, stride, padding, True)
+        super(MaxPool2d, self).__init__(kernel_size, stride, padding, True)
 
 
-class AvgPooling2D(Pooling2D):
+class AvgPool2d(Pooling2d):
 
     def __init__(self, kernel_size, stride=None, padding=0):
-        super(AvgPooling2D, self).__init__(kernel_size, stride, padding, False)
+        super(AvgPool2d, self).__init__(kernel_size, stride, padding, False)
 
 
-class MaxPooling1D(Pooling2D):
+class MaxPool1d(Pooling2d):
 
     def __init__(self, kernel_size, stride=None, padding=0):
         if stride is None:
             stride = kernel_size
-        super(MaxPooling2D, self).__init__(
-            (1, kernel_size), (0, stride), (0, padding), True)
+        super(MaxPool1d, self).__init__(
+            (1, kernel_size), (1, stride), (0, padding), True)
 
 
-class AvgPooling1D(Pooling2D):
+class AvgPool1d(Pooling2d):
 
     def __init__(self, kernel_size, stride=None, padding=0):
         if stride is None:
             stride = kernel_size
-        super(MaxPooling2D, self).__init__(
-            (1, kernel_size), (0, stride), (0, padding), False)
+        super(AvgPool1d, self).__init__(
+            (1, kernel_size), (1, stride), (0, padding), False)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2b5c3f70/src/model/operation/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.cc b/src/model/operation/batchnorm.cc
index 29eaba9..4673919 100755
--- a/src/model/operation/batchnorm.cc
+++ b/src/model/operation/batchnorm.cc
@@ -121,7 +121,6 @@ const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
   CHECK_EQ(mean.device()->lang(), kCuda);
   CHECK_EQ(var.device()->lang(), kCuda);
 
-  vector<Tensor> out_grads;
   Tensor dx;
   dx.ResetLike(dy);
 


[2/4] incubator-singa git commit: SINGA-385 Add new python module for optimizers

Posted by zh...@apache.org.
SINGA-385 Add new python module for optimizers

Add the base optimizer and SGD (with momentum).


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/117dfcfd
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/117dfcfd
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/117dfcfd

Branch: refs/heads/master
Commit: 117dfcfd052bb92142a30b59fc173a2ef6480332
Parents: 2b5c3f7
Author: Wang Wei <wa...@gmail.com>
Authored: Sat Jul 14 13:07:52 2018 +0800
Committer: Wang Wei <wa...@gmail.com>
Committed: Mon Jul 16 10:04:54 2018 +0800

----------------------------------------------------------------------
 examples/autograd/resnet.py | 117 ++++++++++++++++++++++++++++--
 python/singa/autograd.py    |  13 ++++
 python/singa/opt.py         | 152 +++++++++++++++++++++++++++++++++++++++
 python/singa/tensor.py      |  12 ++++
 4 files changed, 287 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/117dfcfd/examples/autograd/resnet.py
----------------------------------------------------------------------
diff --git a/examples/autograd/resnet.py b/examples/autograd/resnet.py
index 930d9e0..f1fb9d6 100644
--- a/examples/autograd/resnet.py
+++ b/examples/autograd/resnet.py
@@ -23,6 +23,11 @@
 from singa import autograd
 from singa import tensor
 from singa import device
+from singa import utils
+from singa import optimizer
+
+import pickle
+import numpy as np
 
 
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@@ -60,7 +65,7 @@ class BasicBlock(autograd.Layer):
         if self.downsample is not None:
             residual = self.downsample(x)
 
-        out += residual
+        out = autograd.add(out, residual)
         out = autograd.relu(out)
 
         return out
@@ -101,7 +106,7 @@ class Bottleneck(autograd.Layer):
         if self.downsample is not None:
             residual = self.downsample(x)
 
-        out += residual
+        out = autograd.add(out, residual)
         out = autograd.relu(out)
 
         return out
@@ -217,10 +222,120 @@ def resnet152(pretrained=False, **kwargs):
     return model
 
 
-if __name__ == '__main__':
+def load_dataset(filepath):
+    print('Loading data file %s' % filepath)
+    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
+    with open(filepath, 'rb') as fd:
+        try:
+            cifar10 = pickle.load(fd, encoding='latin1')
+        except TypeError:
+            cifar10 = pickle.load(fd)
+    image = cifar10['data'].astype(dtype=np.uint8)
+    image = image.reshape((-1, 3, 32, 32))
+    label = np.asarray(cifar10['labels'], dtype=np.uint8)
+    label = label.reshape(label.size, 1)
+    return image, label
+
+
+def load_train_data(dir_path, num_batches=5):
+    labels = []
+    batchsize = 10000
+    images = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8)
+    for did in range(1, num_batches + 1):
+        fname_train_data = dir_path + "/data_batch_{}".format(did)
+        image, label = load_dataset(fname_train_data)
+        images[(did - 1) * batchsize:did * batchsize] = image
+        labels.extend(label)
+    images = np.array(images, dtype=np.float32)
+    labels = np.array(labels, dtype=np.int32)
+    return images, labels
+
+
+def load_test_data(dir_path):
+    images, labels = load_dataset(dir_path + "/test_batch")
+    return np.array(images,  dtype=np.float32), np.array(labels, dtype=np.int32)
+
+
+def accuracy(pred, target):
+    # target may hold one-hot rows or integer class labels; the labels from
+    # load_dataset() have shape (N, 1), for which argmax(axis=1) is always 0.
+    y = np.argmax(pred, axis=1)
+    target = np.asarray(target)
+    if target.ndim == 2 and target.shape[1] > 1:
+        t = np.argmax(target, axis=1)
+    else:
+        t = target.reshape(-1)
+    a = y == t
+    return np.array(a, 'int').sum() / float(len(t))
+
+
+def train(data, net, max_epoch, get_lr, weight_decay=1e-5, batch_size=100):
+    print('Start intialization............')
+    dev = device.create_cuda_gpu()
+
+    opt = optimizer.SGD(momentum=0.9, weight_decay=weight_decay)
+
+    tx = tensor.Tensor((batch_size, 3, 32, 32), dev)
+    ty = tensor.Tensor((batch_size,), dev, tensor.int32)
+    train_x, train_y, test_x, test_y = data
+    num_train_batch = train_x.shape[0] // batch_size
+    num_test_batch = test_x.shape[0] // batch_size
+    idx = np.arange(train_x.shape[0], dtype=np.int32)
+    for epoch in range(max_epoch):
+        np.random.shuffle(idx)
+        loss, acc = 0.0, 0.0
+        print('Epoch %d' % epoch)
+        autograd.training = True
+        for b in range(num_train_batch):
+            x = train_x[idx[b * batch_size: (b + 1) * batch_size]]
+            y = train_y[idx[b * batch_size: (b + 1) * batch_size]]
+            tx.copy_from_numpy(x)
+            ty.copy_from_numpy(y)
+            x = net(tx)
+            loss = autograd.softmax_cross_entropy(x, ty)
+            np_loss = tensor.to_numpy(loss)
+            acc += accuracy(tensor.to_numpy(x), y)
+
+            for p, g in autograd.backward(loss):
+                opt.apply_with_lr(epoch, get_lr(epoch), g, p)
+            # update progress bar
+            utils.update_progress(b * 1.0 / num_train_batch,
+                                  'training loss = %f' % (np_loss[0]))
+
+        loss, acc = 0.0, 0.0
+        autograd.training = False
+        for b in range(num_test_batch):
+            x = test_x[b * batch_size: (b + 1) * batch_size]
+            y = test_y[b * batch_size: (b + 1) * batch_size]
+            tx.copy_from_numpy(x)
+            ty.copy_from_numpy(y)
+            x = net(tx)
+            l = autograd.softmax_cross_entropy(x, ty)
+            loss += tensor.to_numpy(l)[0]
+            acc += accuracy(tensor.to_numpy(x), y)
+
+        print('test loss = %f, test accuracy = %f' %
+              ((loss / num_test_batch), (acc / num_test_batch)))
+
+
+def resnet_lr(epoch):
+    if epoch < 81:
+        return 0.1
+    elif epoch < 122:
+        return 0.01
+    else:
+        return 0.001
 
+
+if __name__ == '__main__':
-    model = resnet18()
-    x = tensor.Tensor((16, 3, 224, 224), device.create_cuda_gpu())
-    x.set_value(float(0.1))
-    autograd.training = True
-    y = model(x)
+    # CIFAR-10 has 10 classes. NOTE(review): with 32x32 inputs the feature
+    # map before AvgPool2d(7) is 1x1, so this ResNet presumably still needs
+    # an input-size-aware pooling layer before it can train on CIFAR-10.
+    model = resnet18(num_classes=10)
+    data_dir = 'cifar-10-batches-py'  # extracted CIFAR-10 python-version archive
+    train_x, train_y = load_train_data(data_dir)
+    test_x, test_y = load_test_data(data_dir)
+    mean = np.average(train_x, axis=0)
+    train_x -= mean
+    test_x -= mean
+    train((train_x, train_y, test_x, test_y), model, 10, resnet_lr)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/117dfcfd/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index c77c174..63e3771 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -347,6 +347,19 @@ def add_bias(x, b, axis=0):
     return AddBias(axis)(x, b)[0]
 
 
+class Add(Operation):  # element-wise a + b recorded as an autograd Operation
+
+    def forward(self, a, b):
+        return a + b
+
+    def backward(self, dy):
+        return dy, dy  # same grad for both inputs; assumes a, b share a shape (no broadcast reduction) -- confirm
+
+
+def add(a, b):  # functional wrapper: returns element [0] of Add()(a, b)
+    return Add()(a, b)[0]
+
+
 class SoftMax(Operation):
     '''
     Apply SoftMax for each row of the Tensor or each column of the Tensor

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/117dfcfd/python/singa/opt.py
----------------------------------------------------------------------
diff --git a/python/singa/opt.py b/python/singa/opt.py
new file mode 100644
index 0000000..bf04b09
--- /dev/null
+++ b/python/singa/opt.py
@@ -0,0 +1,176 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+'''This module includes a set of optimizers for updating model parameters.
+It replaces the old optimizers from optimizer.py'''
+
+from singa import tensor
+
+
+class Optimizer(object):
+    r"""Base optimizer.
+
+    Args:
+        config (Dict): specify the default values of configurable variables.
+    """
+
+    def __init__(self, config):
+        self.config = config
+        # Not named ``step``: an instance attribute with that name would
+        # shadow the step() method and make ``optimizer.step()`` raise
+        # "TypeError: 'int' object is not callable".
+        self.step_counter = 0
+        # param -> config dict registered via register()
+        self.param2config = {}
+        # param -> dict of per-param state, e.g. SGD's momentum buffer
+        self.param2state = {}
+
+    def update(self, param, grad):
+        r"""Update the param values with given gradients.
+
+        Args:
+            param(Tensor): param values to be updated in-place
+            grad(Tensor): param gradients; the values may be updated
+                    in this function; do not use it anymore
+        """
+        pass
+
+    def step(self):
+        r"""To increment the step counter"""
+        self.step_counter += 1
+
+    def register(self, param_group, config):
+        r"""Register a config for every param in param_group.
+
+        Keys present in ``config`` override the default config (the one
+        passed to the constructor) for these params; see _get_config().
+        """
+        for param in param_group:
+            assert param not in self.param2config, 'param is already registered'
+
+            self.param2config[param] = config
+
+    def _get_config(self, param):
+        r"""Return the config used for param: the default config updated
+        with the config registered for param, if any."""
+        registered = self.param2config.get(param)
+        if registered is None:
+            return self.config
+        config = dict(self.config)
+        config.update(registered)
+        return config
+
+    def load(self):
+        pass
+
+    def save(self):
+        pass
+
+
+class SGD(Optimizer):
+    r"""Implements stochastic gradient descent (optionally with momentum).
+
+    Nesterov momentum is based on the formula from
+    `On the importance of initialization and momentum in deep learning`__.
+
+    Args:
+        lr(float): learning rate
+        momentum(float, optional): momentum factor(default: 0)
+        weight_decay(float, optional): weight decay(L2 penalty)(default: 0)
+        dampening(float, optional): dampening for momentum(default: 0)
+        nesterov(bool, optional): enables Nesterov momentum(default: False)
+
+    Example:
+        >>> from singa import opt
+        >>> optimizer = opt.SGD(lr=0.1, momentum=0.9)
+        >>> optimizer.update(param, grad)
+        >>> optimizer.step()
+
+    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
+
+    .. note::
+        The implementation of SGD with Momentum / Nesterov subtly differs from
+        Sutskever et al. and implementations in some other frameworks.
+
+        Considering the specific case of Momentum, the update can be written as
+
+        .. math::
+                  v = \rho * v + g \\
+                  p = p - lr * v
+
+        where p, g, v and :math:`\rho` denote the parameters, gradient,
+        velocity, and momentum respectively.
+
+        This is in contrast to Sutskever et al. and
+        other frameworks which employ an update of the form
+
+        .. math::
+             v = \rho * v + lr * g \\
+             p = p - v
+
+        The Nesterov version is analogously modified.
+    """
+
+    def __init__(self, lr=0.1, momentum=0, dampening=0,
+                 weight_decay=0, nesterov=False):
+        if momentum < 0.0:
+            raise ValueError("Invalid momentum value: {}".format(momentum))
+        if weight_decay < 0.0:
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay))
+
+        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
+                        weight_decay=weight_decay, nesterov=nesterov)
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError(
+                "Nesterov momentum requires a momentum and zero dampening")
+        super(SGD, self).__init__(defaults)
+
+    def update(self, param, grad):
+        """Performs a single optimization step.
+
+        Arguments:
+                param(Tensor): param values to be updated in-place
+                grad(Tensor): param gradients; the values may be updated
+                        in this function; cannot use it anymore
+        """
+        group = self._get_config(param)
+        weight_decay = group['weight_decay']
+        momentum = group['momentum']
+        dampening = group['dampening']
+        nesterov = group['nesterov']
+
+        if weight_decay != 0:
+            grad += param * weight_decay
+        if momentum != 0:
+            param_state = self.param2state.setdefault(param, {})
+            if 'momentum_buffer' not in param_state:
+                buf = param_state[
+                    'momentum_buffer'] = tensor.zeros_like(param)
+                buf *= momentum
+                buf += grad
+            else:
+                buf = param_state['momentum_buffer']
+                buf *= momentum
+                # Tensor * float (not float * Tensor), the form already
+                # used above for ``param * weight_decay``.
+                buf += grad * (1 - dampening)
+            if nesterov:
+                grad += buf * momentum
+            else:
+                grad = buf
+        param -= grad * group['lr']

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/117dfcfd/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index 46a47b7..441431f 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -602,6 +602,18 @@ def from_raw_tensors(tt):
     return ret
 
 
+def zeros_like(t):  # new Tensor with t's shape, device and dtype, filled with 0
+    ret = Tensor(t.shape, t.device, t.dtype)
+    ret.set_value(float(0))
+    return ret
+
+
+def ones_like(t):  # new Tensor with t's shape, device and dtype, filled with 1
+    ret = Tensor(t.shape, t.device, t.dtype)
+    ret.set_value(float(1))
+    return ret
+
+
 def product(shape):
     return reduce(lambda x, y: x * y, shape)
 


[4/4] incubator-singa git commit: SINGA-382 Implement concat operation for autograd

Posted by zh...@apache.org.
SINGA-382 Implement concat operation for autograd

Update the cnn_mnist example to use concat operation.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/870c5df0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/870c5df0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/870c5df0

Branch: refs/heads/master
Commit: 870c5df0b9fa6eb87044b49e1013ef2f5a5298e1
Parents: e651c1a
Author: Wang Wei <wa...@gmail.com>
Authored: Mon Jul 16 10:11:07 2018 +0800
Committer: Wang Wei <wa...@gmail.com>
Committed: Mon Jul 16 10:11:07 2018 +0800

----------------------------------------------------------------------
 examples/autograd/mnist_cnn.py | 24 +++++++++++++-----------
 tool/conda/singa/meta.yaml     |  1 +
 2 files changed, 14 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/870c5df0/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index 3ddd532..62ae5b2 100755
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -24,7 +24,7 @@ import os
 from singa import device
 from singa import tensor
 from singa import autograd
-from singa import optimizer
+from singa import opt
 
 
 def load_data(path):
@@ -92,7 +92,7 @@ if __name__ == '__main__':
     num_classes = 10
     epochs = 1
 
-    sgd = optimizer.SGD(0.001)
+    sgd = opt.SGD(lr=0.01)
 
     x_train = preprocess(train[0])
     y_train = to_categorical(train[1], num_classes)
@@ -105,14 +105,14 @@ if __name__ == '__main__':
     print('the shape of testing label is', y_test.shape)
 
     # operations initialization
-    conv1 = autograd.Conv2D(1, 32, 3, padding=1, bias=False)
-    bn1 = autograd.BatchNorm(32)
-    conv21 = autograd.Conv2D(32, 16, 3, padding=1)
-    conv22 = autograd.Conv2D(32, 16, 3, padding=1)
-    bn2 = autograd.BatchNorm(32)
+    conv1 = autograd.Conv2d(1, 32, 3, padding=1, bias=False)
+    bn1 = autograd.BatchNorm2d(32)
+    conv21 = autograd.Conv2d(32, 16, 3, padding=1)
+    conv22 = autograd.Conv2d(32, 16, 3, padding=1)
+    bn2 = autograd.BatchNorm2d(32)
     linear = autograd.Linear(32 * 28 * 28, 10)
-    pooling1 = autograd.MaxPooling2D(3, 1, padding=1)
-    pooling2 = autograd.AvgPooling2D(3, 1, padding=1)
+    pooling1 = autograd.MaxPool2d(3, 1, padding=1)
+    pooling2 = autograd.AvgPool2d(3, 1, padding=1)
 
     def forward(x, t):
         y = conv1(x)
@@ -121,7 +121,7 @@ if __name__ == '__main__':
         y = pooling1(y)
         y1 = conv21(y)
         y2 = conv22(y)
-        y = autograd.concat((y1, y2), 1)
+        y = autograd.cat((y1, y2), 1)
         y = bn2(y)
         y = autograd.relu(y)
         y = bn2(y)
@@ -148,4 +148,6 @@ if __name__ == '__main__':
                       tensor.to_numpy(loss)[0])
 
             for p, gp in autograd.backward(loss):
-                sgd.apply(epoch, gp, p, '')
+                sgd.update(p, gp)
+
+            sgd.step()

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/870c5df0/tool/conda/singa/meta.yaml
----------------------------------------------------------------------
diff --git a/tool/conda/singa/meta.yaml b/tool/conda/singa/meta.yaml
index ee76636..424532c 100644
--- a/tool/conda/singa/meta.yaml
+++ b/tool/conda/singa/meta.yaml
@@ -55,6 +55,7 @@ requirements:
     - flask-cors >=3.0.2
     - pillow >=2.3.0
     - future >=0.16.0
+    - tqdm
 
 test:
   source_files:


[3/4] incubator-singa git commit: SINGA-384 Implement ResNet using autograd API

Posted by zh...@apache.org.
SINGA-384 Implement ResNet using autograd API

Implement a simple CNN using the autograd API and train it on CIFAR-10

Benchmark resnet training time


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e651c1ae
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e651c1ae
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e651c1ae

Branch: refs/heads/master
Commit: e651c1ae68a600e162d317f1575f2b4b57b96622
Parents: 117dfcf
Author: wang wei <wa...@comp.nus.edu.sg>
Authored: Sat Jul 14 18:20:51 2018 +0800
Committer: Wang Wei <wa...@gmail.com>
Committed: Mon Jul 16 10:06:31 2018 +0800

----------------------------------------------------------------------
 examples/autograd/resnet.py    | 145 +++++++++---------------------------
 python/singa/autograd.py       |  34 +++++----
 python/singa/opt.py            |  11 ++-
 src/api/model_operation.i      |  24 +++---
 src/model/operation/pooling.cc |  22 +++---
 src/model/operation/pooling.h  |  34 ++++-----
 6 files changed, 100 insertions(+), 170 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e651c1ae/examples/autograd/resnet.py
----------------------------------------------------------------------
diff --git a/examples/autograd/resnet.py b/examples/autograd/resnet.py
index f1fb9d6..72c33ed 100644
--- a/examples/autograd/resnet.py
+++ b/examples/autograd/resnet.py
@@ -23,10 +23,10 @@
 from singa import autograd
 from singa import tensor
 from singa import device
-from singa import utils
-from singa import optimizer
+from singa import opt
 
 import numpy as np
+from tqdm import trange
 
 
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@@ -35,7 +35,7 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
 
 def conv3x3(in_planes, out_planes, stride=1):
     """3x3 convolution with padding"""
-    return autograd.Conv2D(in_planes, out_planes, kernel_size=3, stride=stride,
+    return autograd.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                            padding=1, bias=False)
 
 
@@ -75,13 +75,14 @@ class Bottleneck(autograd.Layer):
 
     def __init__(self, inplanes, planes, stride=1, downsample=None):
         super(Bottleneck, self).__init__()
-        self.conv1 = autograd.Conv2D(
+        self.conv1 = autograd.Conv2d(
             inplanes, planes, kernel_size=1, bias=False)
         self.bn1 = autograd.BatchNorm2d(planes)
-        self.conv2 = autograd.Conv2D(planes, planes, kernel_size=3, stride=stride,
+        self.conv2 = autograd.Conv2d(planes, planes, kernel_size=3,
+                                     stride=stride,
                                      padding=1, bias=False)
         self.bn2 = autograd.BatchNorm2d(planes)
-        self.conv3 = autograd.Conv2D(
+        self.conv3 = autograd.Conv2d(
             planes, planes * self.expansion, kernel_size=1, bias=False)
         self.bn3 = autograd.BatchNorm2d(planes * self.expansion)
 
@@ -116,7 +117,7 @@ class ResNet(autograd.Layer):
     def __init__(self, block, layers, num_classes=1000):
         self.inplanes = 64
         super(ResNet, self).__init__()
-        self.conv1 = autograd.Conv2D(3, 64, kernel_size=7, stride=2, padding=3,
+        self.conv1 = autograd.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                      bias=False)
         self.bn1 = autograd.BatchNorm2d(64)
         self.maxpool = autograd.MaxPool2d(
@@ -131,10 +132,12 @@ class ResNet(autograd.Layer):
     def _make_layer(self, block, planes, blocks, stride=1):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
-            conv = autograd.Conv2D(self.inplanes, planes * block.expansion,
+            conv = autograd.Conv2d(self.inplanes, planes * block.expansion,
                                    kernel_size=1, stride=stride, bias=False)
-            bn = autograd.BatchNorm2d(planes * block.expansion),
-            downsample = lambda x: bn(conv(x))
+            bn = autograd.BatchNorm2d(planes * block.expansion)
+
+            def downsample(x):
+                return bn(conv(x))
 
         layers = []
         layers.append(block(self.inplanes, planes, stride, downsample))
@@ -221,109 +224,29 @@ def resnet152(pretrained=False, **kwargs):
     return model
 
 
-def load_dataset(filepath):
-    print('Loading data file %s' % filepath)
-    with open(filepath, 'rb') as fd:
-        try:
-            cifar10 = pickle.load(fd, encoding='latin1')
-        except TypeError:
-            cifar10 = pickle.load(fd)
-    image = cifar10['data'].astype(dtype=np.uint8)
-    image = image.reshape((-1, 3, 32, 32))
-    label = np.asarray(cifar10['labels'], dtype=np.uint8)
-    label = label.reshape(label.size, 1)
-    return image, label
-
-
-def load_train_data(dir_path, num_batches=5):
-    labels = []
-    batchsize = 10000
-    images = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8)
-    for did in range(1, num_batches + 1):
-        fname_train_data = dir_path + "/data_batch_{}".format(did)
-        image, label = load_dataset(fname_train_data)
-        images[(did - 1) * batchsize:did * batchsize] = image
-        labels.extend(label)
-    images = np.array(images, dtype=np.float32)
-    labels = np.array(labels, dtype=np.int32)
-    return images, labels
-
-
-def load_test_data(dir_path):
-    images, labels = load_dataset(dir_path + "/test_batch")
-    return np.array(images,  dtype=np.float32), np.array(labels, dtype=np.int32)
-
-
-def accuracy(pred, target):
-    y = np.argmax(pred, axis=1)
-    t = np.argmax(target, axis=1)
-    a = y == t
-    return np.array(a, 'int').sum() / float(len(t))
-
-
-def train(data, net, max_epoch, get_lr, weight_decay=1e-5, batch_size=100):
+if __name__ == '__main__':
+    model = resnet18()
     print('Start intialization............')
-    dev = device.create_cuda_gpu()
+    dev = device.create_cuda_gpu_on(1)
 
-    opt = optimizer.SGD(momentum=0.9, weight_decay=weight_decay)
+    niters = 200
+    batch_size = 16
+    IMG_SIZE = 224
+    sgd = opt.SGD(lr=0.1, momentum=0.9, weight_decay=1e-5)
 
-    tx = tensor.Tensor((batch_size, 3, 32, 32), dev)
+    tx = tensor.Tensor((batch_size, 3, IMG_SIZE, IMG_SIZE), dev)
     ty = tensor.Tensor((batch_size,), dev, tensor.int32)
-    train_x, train_y, test_x, test_y = data
-    num_train_batch = train_x.shape[0] // batch_size
-    num_test_batch = test_x.shape[0] // batch_size
-    idx = np.arange(train_x.shape[0], dtype=np.int32)
-    for epoch in range(max_epoch):
-        np.random.shuffle(idx)
-        loss, acc = 0.0, 0.0
-        print('Epoch %d' % epoch)
-        autograd.training = True
-        for b in range(num_train_batch):
-            x = train_x[idx[b * batch_size: (b + 1) * batch_size]]
-            y = train_y[idx[b * batch_size: (b + 1) * batch_size]]
-            tx.copy_from_numpy(x)
-            ty.copy_from_numpy(y)
-            x = net(tx)
+    autograd.training = True
+    x = np.random.randn(batch_size, 3, IMG_SIZE, IMG_SIZE).astype(np.float32)
+    y = np.random.randint(0, 1000, batch_size, dtype=np.int32)
+    tx.copy_from_numpy(x)
+    ty.copy_from_numpy(y)
+
+    with trange(niters) as t:
+        for b in t:
+            x = model(tx)
             loss = autograd.softmax_cross_entropy(x, ty)
-            np_loss = tensor.to_numpy(loss)
-            acc += accuracy(tensor.to_numpy(x), y)
-
-            for p, g in autograd.backwards(loss):
-                opt.apply_with_lr(epoch, get_lr(epoch), g, p)
-            # update progress bar
-            utils.update_progress(b * 1.0 / num_train_batch,
-                                  'training loss = %f' % (np_loss[0]))
-
-        loss, acc = 0.0, 0.0
-        autograd.training = True
-        for b in range(num_test_batch):
-            x = test_x[b * batch_size: (b + 1) * batch_size]
-            y = test_y[b * batch_size: (b + 1) * batch_size]
-            tx.copy_from_numpy(x)
-            ty.copy_from_numpy(y)
-            x = net(tx)
-            l = autograd.softmax_cross_entropy(x, ty)
-            loss += tensor.to_numpy(l)[0]
-            acc += accuracy(x, y)
-
-        print('test loss = %f, test accuracy = %f' %
-              ((loss / num_test_batch), (acc / num_test_batch)))
-
-
-def resnet_lr(epoch):
-    if epoch < 81:
-        return 0.1
-    elif epoch < 122:
-        return 0.01
-    else:
-        return 0.001
-
-
-if __name__ == '__main__':
-    model = resnet18()
-    train_x, train_y = load_train_data()
-    test_x, test_y = load_test_data()
-    mean = np.average(train_x, axis=0)
-    train_x -= mean
-    test_x -= mean
-    train(model, (train_x, train_y, test_x, test_y), 10, resnet_lr)
+            for p, g in autograd.backward(loss):
+                # print(p.shape, g.shape)
+                # sgd.update(p, g)
+                pass

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e651c1ae/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 63e3771..a084764 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -350,7 +350,7 @@ def add_bias(x, b, axis=0):
 class Add(Operation):
 
     def forward(self, a, b):
-        return a + b
+        return singa.__add__(a, b)
 
     def backward(self, dy):
         return dy, dy
@@ -469,22 +469,24 @@ def cross_entropy(y, t):
 
 class SoftMaxCrossEntropy(Operation):
 
-    def forward(self, x, t):
+    def __init__(self, t):
+        self.t = t.data
+
+    def forward(self, x):
         self.p = singa.SoftMax(x)
-        self.t = t
         loss = CTensor((1,), self.p.device())
-        ret = singa.CrossEntropyFwd(self.p, t)
+        ret = singa.CrossEntropyFwd(self.p, self.t)
         loss.SetFloatValue(singa.SumAsFloat(ret) / x.shape()[0])
         return loss
 
     def backward(self, dy=1.0):
         dx = singa.SoftmaxCrossEntropyBwd(self.p, self.t)
-        return singa.DivFloat(dx, float(self.p.shape()[0])), None
+        return singa.DivFloat(dx, float(self.p.shape()[0]))
 
 
 def softmax_cross_entropy(x, t):
     # x is the logits and t is the ground truth; both are 2D.
-    return SoftMaxCrossEntropy()(x, t)[0]
+    return SoftMaxCrossEntropy(t)(x)[0]
 
 
 def ctensor2numpy(x):
@@ -769,7 +771,7 @@ class BatchNorm2d(Layer):
                     self.momentum, x.data)
         self.handle.device_id = x.device.id()
 
-        y = batchnorm(self.handle, x, self.scale, self.bias,
+        y = batchnorm_2d(self.handle, x, self.scale, self.bias,
                       self.running_mean, self.running_var)
         return y
 
@@ -794,7 +796,7 @@ class _BatchNorm2d(Operation):
             if self.handle.device_id == -1:
                 raise NotImplementedError
             else:
-                y, _, _ = singa.GpuBatchNormForwardInference(
+                y = singa.GpuBatchNormForwardInference(
                     self.handle, x, scale, bias, self.running_mean, self.running_var)
         return y
 
@@ -815,7 +817,7 @@ class _BatchNorm2d(Operation):
 
 
 def batchnorm_2d(handle, x, scale, bias, running_mean, running_var):
-    return _BatchNorm(handle, running_mean, running_var)(x, scale, bias)[0]
+    return _BatchNorm2d(handle, running_mean, running_var)(x, scale, bias)[0]
 
 
 class _Pooling2d(Operation):
@@ -844,7 +846,7 @@ class _Pooling2d(Operation):
 
 
 def pooling_2d(handle, x):
-    return _Pooling2D(handle)(x)[0]
+    return _Pooling2d(handle)(x)[0]
 
 
 class Pooling2d(Layer):
@@ -894,11 +896,11 @@ class Pooling2d(Layer):
         else:
             if not hasattr(self, 'handle'):
                 self.handle = singa.CudnnPoolingHandle(x.data, self.kernel_size, self.stride,
-                                                       self.padding, self.is_max)  # False for nan_prop
+                                                       self.padding, self.is_max)
             elif x.shape[0] != self.handle.batchsize or out_shape_h != self.handle.pooled_height or \
                     out_shape_w != self.handle.pooled_width:
                 self.handle = singa.CudnnPoolingHandle(x.data, self.kernel_size, self.stride,
-                                                       self.padding, self.is_max)  # False for nan_prop
+                                                       self.padding, self.is_max)
 
         self.handle.device_id = x.device.id()
 
@@ -906,19 +908,19 @@ class Pooling2d(Layer):
         return y
 
 
-class MaxPool2d(Pooling2D):
+class MaxPool2d(Pooling2d):
 
     def __init__(self, kernel_size, stride=None, padding=0):
         super(MaxPool2d, self).__init__(kernel_size, stride, padding, True)
 
 
-class AvgPool2d(Pooling2D):
+class AvgPool2d(Pooling2d):
 
     def __init__(self, kernel_size, stride=None, padding=0):
         super(AvgPool2d, self).__init__(kernel_size, stride, padding, False)
 
 
-class MaxPool1d(Pooling2D):
+class MaxPool1d(Pooling2d):
 
     def __init__(self, kernel_size, stride=None, padding=0):
         if stride is None:
@@ -927,7 +929,7 @@ class MaxPool1d(Pooling2D):
             (1, kernel_size), (0, stride), (0, padding), True)
 
 
-class AvgPool1d(Pooling2D):
+class AvgPool1d(Pooling2d):
 
     def __init__(self, kernel_size, stride=None, padding=0):
         if stride is None:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e651c1ae/python/singa/opt.py
----------------------------------------------------------------------
diff --git a/python/singa/opt.py b/python/singa/opt.py
index bf04b09..6c59f28 100644
--- a/python/singa/opt.py
+++ b/python/singa/opt.py
@@ -29,9 +29,10 @@ class Optimizer(object):
     """
 
     def __init__(self, config):
-        self.config = config
+        self.default_config = config
         self.step = 0
         self.param2config = {}
+        self.param2state = {}
 
     def update(self, param, grad):
         r"""Update the param values with given gradients.
@@ -126,7 +127,9 @@ class SGD(Optimizer):
                 grad(Tensor): param gradients; the values may be updated
                         in this function; cannot use it anymore
         """
-        group = self.param2group[param]
+        group = self.default_config
+        if param in self.param2config:
+            group = self.param2config[param]
         weight_decay = group['weight_decay']
         momentum = group['momentum']
         dampening = group['dampening']
@@ -135,7 +138,9 @@ class SGD(Optimizer):
         if weight_decay != 0:
             grad += param * weight_decay
         if momentum != 0:
-            param_state = self.state[param]
+            if param not in self.param2state:
+                self.param2state[param] = {}
+            param_state = self.param2state[param]
             if 'momentum_buffer' not in param_state:
                 buf = param_state[
                     'momentum_buffer'] = tensor.zeros_like(param)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e651c1ae/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 3d9bdbe..435ff1c 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -41,14 +41,14 @@ class BatchNormHandle{
 
 class PoolingHandle {
  public:
-  PoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
-                const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+  PoolingHandle(const Tensor &input, const std::vector<int>& kernel_size,
+                const std::vector<int>& stride, const std::vector<int>& padding,
                 const bool is_max=true);
 
-  size_t batchsize;
+  int batchsize;
 
-  size_t pooled_height;
-  size_t pooled_width;
+  int pooled_height;
+  int pooled_width;
 };
 
 
@@ -92,14 +92,14 @@ const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
 
 class CudnnPoolingHandle : public PoolingHandle {
  public:
-  CudnnPoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
-                     const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+  CudnnPoolingHandle(const Tensor &input, const std::vector<int>& kernel_size,
+                     const std::vector<int>& stride, const std::vector<int>& padding,
                      const bool is_max=true);
 
-  size_t batchsize;
-  
-  size_t pooled_height;
-  size_t pooled_width;
+  int batchsize;
+
+  int pooled_height;
+  int pooled_width;
 };
 
 Tensor GpuPoolingForward(const CudnnPoolingHandle &cph, const Tensor &x);
@@ -108,4 +108,4 @@ Tensor GpuPoolingBackward(const CudnnPoolingHandle &cph, const Tensor &dy, const
 
 #endif  // USE_CUDNN
 
-}  //namespace singa
\ No newline at end of file
+}  //namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e651c1ae/src/model/operation/pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/pooling.cc b/src/model/operation/pooling.cc
index 03ff804..efc03ff 100755
--- a/src/model/operation/pooling.cc
+++ b/src/model/operation/pooling.cc
@@ -4,8 +4,8 @@
 namespace singa {
 
 PoolingHandle::PoolingHandle(const Tensor &input,
-                             const std::vector<size_t>& kernel_size,
-                             const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+                             const std::vector<int>& kernel_size,
+                             const std::vector<int>& stride, const std::vector<int>& padding,
                              const bool is_max) {
   kernel_h = kernel_size[0];
   kernel_w = kernel_size[1];
@@ -24,18 +24,19 @@ PoolingHandle::PoolingHandle(const Tensor &input,
   pooled_height = 1;
 
   if (stride_h > 0)
-    pooled_height =
-      static_cast<size_t>((height + 2 * pad_h - kernel_h) / stride_h) + 1;
-  pooled_width =
-    static_cast<size_t>((width + 2 * pad_w - kernel_w) / stride_w) + 1;
+    pooled_height = std::floor(
+      ((height + 2 * pad_h - kernel_h) / stride_h)) + 1;
+  pooled_width = std::floor(
+    ((width + 2 * pad_w - kernel_w) / stride_w)) + 1;
   is_max_pooling = is_max;
 }
 
 #ifdef USE_CUDNN
 
 CudnnPoolingHandle::CudnnPoolingHandle(const Tensor &input,
-                                       const std::vector<size_t>& kernel_size,
-                                       const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+                                       const std::vector<int>& kernel_size,
+                                       const std::vector<int>& stride,
+                                       const std::vector<int>& padding,
                                        const bool is_max)
   : PoolingHandle(input, kernel_size, stride, padding, is_max) {
 
@@ -51,14 +52,13 @@ CudnnPoolingHandle::CudnnPoolingHandle(const Tensor &input,
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW,
                                          GetCudnnDataType(dtype), batchsize,
                                          channels, height, width));
+  // LOG(ERROR) << batchsize << " " << channels << " " << pooled_height << " " << pooled_width;
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(
                 y_desc, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize, channels,
                 pooled_height, pooled_width));
-  auto pool_method = CUDNN_POOLING_MAX;
+  auto pool_method = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
   if (is_max)
     pool_method = CUDNN_POOLING_MAX;
-  else
-    pool_method = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
 
   CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc, pool_method, nan_prop,
                                           kernel_h, kernel_w, pad_h, pad_w,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e651c1ae/src/model/operation/pooling.h
----------------------------------------------------------------------
diff --git a/src/model/operation/pooling.h b/src/model/operation/pooling.h
index a4d1051..b6a4d21 100644
--- a/src/model/operation/pooling.h
+++ b/src/model/operation/pooling.h
@@ -13,24 +13,24 @@ namespace singa {
 
 class PoolingHandle {
  public:
-  PoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
-                const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+  PoolingHandle(const Tensor &input, const std::vector<int>& kernel_size,
+                const std::vector<int>& stride, const std::vector<int>& padding,
                 const bool is_max = true);
 
-  size_t kernel_w;
-  size_t pad_w;
-  size_t stride_w;
-  size_t kernel_h;
-  size_t pad_h;
-  size_t stride_h;
+  int kernel_w;
+  int pad_w;
+  int stride_w;
+  int kernel_h;
+  int pad_h;
+  int stride_h;
 
-  size_t batchsize;
-  size_t channels;
-  size_t height;
-  size_t width;
+  int batchsize;
+  int channels;
+  int height;
+  int width;
 
-  size_t pooled_height;
-  size_t pooled_width;
+  int pooled_height;
+  int pooled_width;
 
   bool is_max_pooling;
 };
@@ -38,8 +38,8 @@ class PoolingHandle {
 #ifdef USE_CUDNN
 class CudnnPoolingHandle : public PoolingHandle {
  public:
-  CudnnPoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
-                     const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+  CudnnPoolingHandle(const Tensor &input, const std::vector<int>& kernel_size,
+                     const std::vector<int>& stride, const std::vector<int>& padding,
                      const bool is_max = true);
   ~CudnnPoolingHandle();
 
@@ -59,4 +59,4 @@ Tensor GpuPoolingBackward(const CudnnPoolingHandle &cph, const Tensor &dy,
 
 }  // namespace singa
 
-#endif  // SINGA_MODEL_OPERATION_POOLING_H_
\ No newline at end of file
+#endif  // SINGA_MODEL_OPERATION_POOLING_H_