You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2018/07/16 03:13:28 UTC

[2/4] incubator-singa git commit: SINGA-385 Add new python module for optimizers

SINGA-385 Add new python module for optimizers

Add the base optimizer and SGD (with momentum).


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/117dfcfd
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/117dfcfd
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/117dfcfd

Branch: refs/heads/master
Commit: 117dfcfd052bb92142a30b59fc173a2ef6480332
Parents: 2b5c3f7
Author: Wang Wei <wa...@gmail.com>
Authored: Sat Jul 14 13:07:52 2018 +0800
Committer: Wang Wei <wa...@gmail.com>
Committed: Mon Jul 16 10:04:54 2018 +0800

----------------------------------------------------------------------
 examples/autograd/resnet.py | 117 ++++++++++++++++++++++++++++--
 python/singa/autograd.py    |  13 ++++
 python/singa/opt.py         | 152 +++++++++++++++++++++++++++++++++++++++
 python/singa/tensor.py      |  12 ++++
 4 files changed, 287 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/117dfcfd/examples/autograd/resnet.py
----------------------------------------------------------------------
diff --git a/examples/autograd/resnet.py b/examples/autograd/resnet.py
index 930d9e0..f1fb9d6 100644
--- a/examples/autograd/resnet.py
+++ b/examples/autograd/resnet.py
@@ -23,6 +23,10 @@
 from singa import autograd
 from singa import tensor
 from singa import device
+from singa import utils
+from singa import optimizer
+
+import numpy as np
 
 
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@@ -60,7 +64,7 @@ class BasicBlock(autograd.Layer):
         if self.downsample is not None:
             residual = self.downsample(x)
 
-        out += residual
+        out = autograd.add(out, residual)
         out = autograd.relu(out)
 
         return out
@@ -101,7 +105,7 @@ class Bottleneck(autograd.Layer):
         if self.downsample is not None:
             residual = self.downsample(x)
 
-        out += residual
+        out = autograd.add(out, residual)
         out = autograd.relu(out)
 
         return out
@@ -217,10 +221,109 @@ def resnet152(pretrained=False, **kwargs):
     return model
 
 
-if __name__ == '__main__':
+def load_dataset(filepath):
+    '''Load one pickled CIFAR-10 batch file.
+
+    Args:
+        filepath (str): path to a python-format CIFAR-10 batch file (a
+            pickled dict with 'data' and 'labels' entries).
+
+    Returns:
+        (image, label): `image` is a uint8 array reshaped to
+        (-1, 3, 32, 32); `label` is a uint8 array of shape (N, 1).
+    '''
+    # Local import keeps this helper self-contained; pickle is not among
+    # the module-level imports of this example.
+    import pickle
+    print('Loading data file %s' % filepath)
+    with open(filepath, 'rb') as fd:
+        # NOTE(review): pickle.load can execute arbitrary code - only load
+        # dataset files from a trusted source.
+        try:
+            # encoding='latin1' lets Python 3 read pickles written by Python 2.
+            cifar10 = pickle.load(fd, encoding='latin1')
+        except TypeError:
+            # Python 2's pickle.load does not accept `encoding`.
+            cifar10 = pickle.load(fd)
+    image = cifar10['data'].astype(dtype=np.uint8)
+    image = image.reshape((-1, 3, 32, 32))
+    label = np.asarray(cifar10['labels'], dtype=np.uint8)
+    label = label.reshape(label.size, 1)
+    return image, label
+
+
+def load_train_data(dir_path, num_batches=5):
+    '''Load and concatenate the CIFAR-10 training batches.
+
+    Reads data_batch_1 ... data_batch_<num_batches> from `dir_path`
+    (10000 samples per batch file).
+
+    Returns:
+        (images, labels): float32 images of shape
+        (num_batches * 10000, 3, 32, 32) and int32 labels.
+    '''
+    per_file = 10000
+    images = np.empty((num_batches * per_file, 3, 32, 32), dtype=np.uint8)
+    labels = []
+    for i in range(num_batches):
+        batch_images, batch_labels = load_dataset(
+            dir_path + "/data_batch_{}".format(i + 1))
+        start = i * per_file
+        images[start:start + per_file] = batch_images
+        labels.extend(batch_labels)
+    return (np.array(images, dtype=np.float32),
+            np.array(labels, dtype=np.int32))
+
+
+def load_test_data(dir_path):
+    '''Load the CIFAR-10 `test_batch` file as (float32 images, int32 labels).'''
+    test_images, test_labels = load_dataset(dir_path + "/test_batch")
+    test_images = np.array(test_images, dtype=np.float32)
+    test_labels = np.array(test_labels, dtype=np.int32)
+    return test_images, test_labels
+
+
+def accuracy(pred, target):
+    '''Fraction of samples whose arg-max prediction matches the target.
+
+    Args:
+        pred (np.ndarray): scores of shape (N, num_classes).
+        target (np.ndarray): integer class labels of shape (N,) or (N, 1),
+            or one-hot labels of shape (N, num_classes).
+
+    Returns:
+        float: accuracy in [0, 1].
+    '''
+    y = np.argmax(pred, axis=1)
+    target = np.asarray(target)
+    if target.ndim == 2 and target.shape[1] > 1:
+        # one-hot labels
+        t = np.argmax(target, axis=1)
+    else:
+        # integer class labels; argmax over a single column would always
+        # return 0, so use the label values directly
+        t = target.reshape(-1)
+    return np.mean(y == t)
+
+
+def train(data, net, max_epoch, get_lr, weight_decay=1e-5, batch_size=100):
+    '''Train `net` with SGD and report test loss/accuracy after each epoch.
+
+    Args:
+        data: tuple (train_x, train_y, test_x, test_y) of numpy arrays.
+        net: callable model mapping an input Tensor to a score Tensor.
+        max_epoch (int): number of epochs to run.
+        get_lr (callable): maps the epoch index to the learning rate.
+        weight_decay (float): weight decay passed to the optimizer.
+        batch_size (int): mini-batch size; trailing samples that do not
+            fill a whole batch are skipped.
+    '''
+    print('Start intialization............')
+    dev = device.create_cuda_gpu()
+
+    opt = optimizer.SGD(momentum=0.9, weight_decay=weight_decay)
+
+    # NOTE(review): nothing here moves net's parameters onto `dev`;
+    # presumably they already live there - confirm.
+    tx = tensor.Tensor((batch_size, 3, 32, 32), dev)
+    ty = tensor.Tensor((batch_size,), dev, tensor.int32)
+    train_x, train_y, test_x, test_y = data
+    num_train_batch = train_x.shape[0] // batch_size
+    num_test_batch = test_x.shape[0] // batch_size
+    idx = np.arange(train_x.shape[0], dtype=np.int32)
+    for epoch in range(max_epoch):
+        np.random.shuffle(idx)
+        train_acc = 0.0
+        print('Epoch %d' % epoch)
+        autograd.training = True
+        for b in range(num_train_batch):
+            x = train_x[idx[b * batch_size: (b + 1) * batch_size]]
+            y = train_y[idx[b * batch_size: (b + 1) * batch_size]]
+            tx.copy_from_numpy(x)
+            ty.copy_from_numpy(y)
+            out = net(tx)
+            loss = autograd.softmax_cross_entropy(out, ty)
+            np_loss = tensor.to_numpy(loss)
+            train_acc += accuracy(tensor.to_numpy(out), y)
+
+            for p, g in autograd.backwards(loss):
+                opt.apply_with_lr(epoch, get_lr(epoch), g, p)
+            # update progress bar
+            utils.update_progress(b * 1.0 / num_train_batch,
+                                  'training loss = %f' % (np_loss[0]))
+        print('train accuracy = %f' % (train_acc / max(num_train_batch, 1)))
+
+        # Evaluate with training mode switched off; no updates are applied.
+        autograd.training = False
+        test_loss, test_acc = 0.0, 0.0
+        for b in range(num_test_batch):
+            x = test_x[b * batch_size: (b + 1) * batch_size]
+            y = test_y[b * batch_size: (b + 1) * batch_size]
+            tx.copy_from_numpy(x)
+            ty.copy_from_numpy(y)
+            out = net(tx)
+            l = autograd.softmax_cross_entropy(out, ty)
+            test_loss += tensor.to_numpy(l)[0]
+            # accuracy() works on numpy arrays, not Tensors
+            test_acc += accuracy(tensor.to_numpy(out), y)
+
+        # max(..., 1) avoids division by zero when test_x < batch_size
+        n_test = max(num_test_batch, 1)
+        print('test loss = %f, test accuracy = %f' %
+              ((test_loss / n_test), (test_acc / n_test)))
+
+
+def resnet_lr(epoch):
+    '''Step learning-rate schedule: 0.1 before epoch 81, 0.01 before
+    epoch 122, and 0.001 afterwards.'''
+    for boundary, rate in ((81, 0.1), (122, 0.01)):
+        if epoch < boundary:
+            return rate
+    return 0.001
 
+
+if __name__ == '__main__':
     model = resnet18()
-    x = tensor.Tensor((16, 3, 224, 224), device.create_cuda_gpu())
-    x.set_value(float(0.1))
-    autograd.training = True
-    y = model(x)
+    import sys
+    # Directory holding the CIFAR-10 python batches (data_batch_1..5 and
+    # test_batch); may be overridden by the first command-line argument.
+    data_dir = sys.argv[1] if len(sys.argv) > 1 else 'cifar-10-batches-py'
+    train_x, train_y = load_train_data(data_dir)
+    test_x, test_y = load_test_data(data_dir)
+    mean = np.average(train_x, axis=0)
+    train_x -= mean
+    test_x -= mean
+    # train() takes (data, net, max_epoch, get_lr, ...)
+    train((train_x, train_y, test_x, test_y), model, 10, resnet_lr)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/117dfcfd/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index c77c174..63e3771 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -347,6 +347,19 @@ def add_bias(x, b, axis=0):
     return AddBias(axis)(x, b)[0]
 
 
+class Add(Operation):
+    '''Element-wise sum of two inputs, recorded for autograd.
+
+    NOTE(review): backward hands `dy` to both inputs unchanged (no
+    broadcast reduction), so a and b presumably share one shape - confirm.
+    '''
+
+    def forward(self, a, b):
+        total = a + b
+        return total
+
+    def backward(self, dy):
+        # d(a + b)/da == d(a + b)/db == 1
+        grads = (dy, dy)
+        return grads
+
+
+def add(a, b):
+    '''Return a + b computed through the `Add` operation.'''
+    # Operation calls return a sequence of outputs; Add has exactly one.
+    outputs = Add()(a, b)
+    return outputs[0]
+
+
 class SoftMax(Operation):
     '''
     Apply SoftMax for each row of the Tensor or each column of the Tensor

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/117dfcfd/python/singa/opt.py
----------------------------------------------------------------------
diff --git a/python/singa/opt.py b/python/singa/opt.py
new file mode 100644
index 0000000..bf04b09
--- /dev/null
+++ b/python/singa/opt.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+'''This module includes a set of optimizers for updating model parameters.
+It replaces the old optimizers from optimizer.py'''
+
+from singa import tensor
+
+
+class Optimizer(object):
+    r"""Base optimizer.
+
+    Subclasses override :meth:`update` to modify a parameter in place
+    from its gradient.
+
+    Args:
+        config (Dict): specify the default values of configurable variables.
+    """
+
+    def __init__(self, config):
+        self.config = config
+        # Must not be named `step`: an instance attribute with that name
+        # would shadow the step() method and make it uncallable.
+        self.step_counter = 0
+        self.param2config = {}
+
+    def update(self, param, grad):
+        r"""Update the param values with given gradients.
+
+        Args:
+            param(Tensor): param values to be updated in-place
+            grad(Tensor): param gradients; the values may be updated
+                    in this function; do not use it anymore
+
+        Raises:
+            NotImplementedError: subclasses must implement this method.
+        """
+        raise NotImplementedError
+
+    def step(self):
+        r"""Increment the step counter."""
+        self.step_counter += 1
+
+    def register(self, param_group, config):
+        r"""Associate `config` with every param in `param_group`.
+
+        Args:
+            param_group: iterable of params
+            config (Dict): configuration for these params
+        """
+        for param in param_group:
+            assert param not in self.param2config, 'param is already registered'
+
+            self.param2config[param] = config
+
+    def load(self):
+        pass
+
+    def save(self):
+        pass
+
+
+class SGD(Optimizer):
+    r"""Implements stochastic gradient descent (optionally with momentum).
+
+    Nesterov momentum is based on the formula from
+    `On the importance of initialization and momentum in deep learning`__.
+
+    Args:
+        lr(float): learning rate
+        momentum(float, optional): momentum factor(default: 0)
+        weight_decay(float, optional): weight decay(L2 penalty)(default: 0)
+        dampening(float, optional): dampening for momentum(default: 0)
+        nesterov(bool, optional): enables Nesterov momentum(default: False)
+
+    Example:
+        >>> from singa import opt
+        >>> optimizer = opt.SGD(lr=0.1, momentum=0.9)
+        >>> optimizer.update(param, grad)
+
+    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
+
+    .. note::
+        The implementation of SGD with Momentum / Nesterov subtly differs from
+        Sutskever et al. and implementations in some other frameworks.
+
+        Considering the specific case of Momentum, the update can be written as
+
+        .. math::
+                  v = \rho * v + g \\
+                  p = p - lr * v
+
+        where :math:`p`, :math:`g`, :math:`v` and :math:`\rho` denote the
+        parameters, gradient, velocity, and momentum respectively.
+
+        This is in contrast to Sutskever et al. and
+        other frameworks which employ an update of the form
+
+        .. math::
+             v = \rho * v + lr * g \\
+             p = p - v
+
+        The Nesterov version is analogously modified.
+    """
+
+    def __init__(self, lr=0.1, momentum=0, dampening=0,
+                 weight_decay=0, nesterov=False):
+        if momentum < 0.0:
+            raise ValueError("Invalid momentum value: {}".format(momentum))
+        if weight_decay < 0.0:
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay))
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError(
+                "Nesterov momentum requires a momentum and zero dampening")
+
+        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
+                        weight_decay=weight_decay, nesterov=nesterov)
+        super(SGD, self).__init__(defaults)
+        # Per-param optimizer state, e.g. {param: {'momentum_buffer': buf}}.
+        self.state = {}
+
+    def update(self, param, grad):
+        """Performs a single optimization step.
+
+        Arguments:
+                param(Tensor): param values to be updated in-place
+                grad(Tensor): param gradients; the values may be updated
+                        in this function; cannot use it anymore
+        """
+        # Start from the constructor defaults and overlay any config
+        # registered for this param; unregistered params use the defaults.
+        group = dict(self.config)
+        group.update(self.param2config.get(param, {}))
+        weight_decay = group['weight_decay']
+        momentum = group['momentum']
+        dampening = group['dampening']
+        nesterov = group['nesterov']
+
+        if weight_decay != 0:
+            grad += param * weight_decay
+        if momentum != 0:
+            param_state = self.state.setdefault(param, {})
+            if 'momentum_buffer' not in param_state:
+                # First step: the buffer starts as a copy of the gradient.
+                buf = tensor.zeros_like(param)
+                buf += grad
+                param_state['momentum_buffer'] = buf
+            else:
+                buf = param_state['momentum_buffer']
+                buf *= momentum
+                buf += grad * (1 - dampening)
+            if nesterov:
+                grad += buf * momentum
+            else:
+                grad = buf
+        param -= grad * group['lr']

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/117dfcfd/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index 46a47b7..441431f 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -602,6 +602,18 @@ def from_raw_tensors(tt):
     return ret
 
 
+def zeros_like(t):
+    '''Return a new Tensor with the shape, device and dtype of `t`, filled
+    with zeros.'''
+    out = Tensor(t.shape, t.device, t.dtype)
+    out.set_value(0.0)
+    return out
+
+
+def ones_like(t):
+    '''Return a new Tensor with the shape, device and dtype of `t`, filled
+    with ones.'''
+    out = Tensor(t.shape, t.device, t.dtype)
+    out.set_value(1.0)
+    return out
+
+
 def product(shape):
     return reduce(lambda x, y: x * y, shape)