Posted to commits@singa.apache.org by wa...@apache.org on 2016/10/06 06:31:24 UTC

[1/2] incubator-singa git commit: SINGA-254 Implement Adam for V1

Repository: incubator-singa
Updated Branches:
  refs/heads/master 17ac16025 -> 3a64342d0


SINGA-254 Implement Adam for V1

Implemented Adam for PySINGA.

Tested Adam with alexnet on cifar10; the accuracy was 0.8 (SGD reaches
0.82). It was also not as good as SGD on the VGG net (0.92). The
learning rate may need tuning.

Added one more argument, 'step', to most functions in the optimizer
module for the iteration ID within one epoch. Some optimization
algorithms, e.g. Adam, use it.

Updated the batchnorm layer's backward function to return empty tensors
for the mean and variance variables. The optimizers now skip the
updating procedure if the gradient tensor is empty.
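
For reference, the update computed by the new Adam class in
python/singa/optimizer.py follows the standard Adam rule. Restating it here
as math (m and v are the per-parameter first- and second-moment tensors, t
the accumulated iteration counter, g_t the regularized gradient, eta the
learning rate, and beta_1, beta_2, epsilon the constructor arguments):

    m_t      = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t      = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    \alpha_t = \eta \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)
    \theta_t = \theta_{t-1} - \alpha_t \, m_t / (\sqrt{v_t} + \epsilon)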


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/5716105b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/5716105b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/5716105b

Branch: refs/heads/master
Commit: 5716105be00b260b57fd6a35c72d065eccea5544
Parents: deb187b
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Thu Sep 29 15:06:37 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Wed Oct 5 21:40:47 2016 +0800

----------------------------------------------------------------------
 examples/cifar10/train.py          |  18 +--
 include/singa/core/tensor.h        |   2 +
 include/singa/model/optimizer.h    |  46 ++++---
 python/singa/layer.py              |   2 +-
 python/singa/optimizer.py          | 211 ++++++++++++++++++++------------
 python/singa/tensor.py             |  24 +++-
 src/api/model_optimizer.i          |   8 +-
 src/model/layer/batchnorm.cc       |   2 -
 src/model/layer/cudnn_batchnorm.cc |   2 -
 src/model/optimizer/adagrad.cc     |  10 +-
 src/model/optimizer/nesterov.cc    |   9 +-
 src/model/optimizer/optimizer.cc   |  47 +++----
 src/model/optimizer/rmsprop.cc     |  10 +-
 src/model/optimizer/sgd.cc         |  10 +-
 14 files changed, 254 insertions(+), 147 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/examples/cifar10/train.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/train.py b/examples/cifar10/train.py
index 671c861..7494b8b 100644
--- a/examples/cifar10/train.py
+++ b/examples/cifar10/train.py
@@ -35,6 +35,7 @@ import alexnet
 import vgg
 import resnet
 
+
 def load_dataset(filepath):
     print 'Loading data file %s' % filepath
     with open(filepath, 'rb') as fd:
@@ -94,14 +95,16 @@ def alexnet_lr(epoch):
     else:
         return 0.00001
 
+
 def resnet_lr(epoch):
-    if epoch < 80:
-        return 0.02
-    elif epoch < 120:
-        return 0.005
+    if epoch < 81:
+        return 0.1
+    elif epoch < 122:
+        return 0.01
     else:
         return 0.001
 
+
 def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
           use_cpu=False):
     print 'Start intialization............'
@@ -136,7 +139,7 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
             loss += l
             acc += a
             for (s, p, g) in zip(net.param_names(), net.param_values(), grads):
-                opt.apply_with_lr(epoch, get_lr(epoch), g, p, str(s))
+                opt.apply_with_lr(epoch, get_lr(epoch), g, p, str(s), b)
             # update progress bar
             utils.update_progress(b * 1.0 / num_train_batch,
                                   'training loss = %f, accuracy = %f' % (l, a))
@@ -159,8 +162,9 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
     net.save('model', 20)  # save model params into checkpoint file
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Train vgg/alexnet for cifar10')
-    parser.add_argument('model', choices=['vgg', 'alexnet', 'resnet'], default='alexnet')
+    parser = argparse.ArgumentParser(description='Train dcnn for cifar10')
+    parser.add_argument('model', choices=['vgg', 'alexnet', 'resnet'],
+                        default='alexnet')
     parser.add_argument('data', default='cifar-10-batches-py')
     parser.add_argument('--use_cpu', action='store_true')
     args = parser.parse_args()
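
As a self-contained sanity check of the extra argument added in the training
loop above, the following minimal sketch calls the updated PySINGA API
directly; the shapes, values, the parameter name 'w' and the batch index are
illustrative only:

    import numpy as np
    from singa import optimizer, tensor

    sgd = optimizer.SGD(lr=0.01, momentum=0.9)
    w = tensor.from_numpy(np.ones((4,), dtype=np.float32))
    g = tensor.from_numpy(np.full((4,), 0.1, dtype=np.float32))
    # apply_with_lr(epoch, lr, grad, value, name, step); the batch index b is
    # now forwarded as the trailing 'step' argument, as in train() above
    w = sgd.apply_with_lr(0, 0.01, g, w, 'w', 3)
    print tensor.to_numpy(w)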

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 2075b5d..a41afbc 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -102,6 +102,8 @@ class Tensor {
 
   size_t nDim() const { return shape_.size(); }
 
+  bool empty() const { return nDim() == 0; }
+
   bool transpose() const { return transpose_; }
 
   /// return number of total elements

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/include/singa/model/optimizer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/optimizer.h b/include/singa/model/optimizer.h
index e6e6d1c..d5260cb 100644
--- a/include/singa/model/optimizer.h
+++ b/include/singa/model/optimizer.h
@@ -57,19 +57,24 @@ class Optimizer {
   /// parameter.
   virtual void Register(const string& name, const ParamSpec& specs);
 
-  /// Apply the updating algorithm.
+
+  virtual void ApplyRegularizerConstraint(int epoch, const string& name,
+      const Tensor& value, Tensor& grad, int step = -1);
+
+  /// Apply the updating algorithm if the gradient is not empty.
   /// No learning rate scaling, gradient constraints/regularization will be
   /// conducted. It assumes all these operations are done either by users or
   /// by Apply(int, const string&, Tensor*, Tensor*).
   /// All sub-classes should override this function.
-  virtual void Apply(int step, float lr, const string& name, const Tensor& grad,
-                     Tensor& value) = 0;
+  virtual void Apply(int epoch, float lr, const string& name,
+                     Tensor& grad, Tensor& value, int step = -1) = 0;
 
-  /// Apply the updating algorithm.
+  /// Apply the updating algorithm if the gradient is not empty.
   /// It will apply regularization and constraint to the parameters if
   /// configured during Register(). If will also scale the learning rate if
   /// configured in ParamSpecs (see Register).
-  void Apply(int step, const string& name, Tensor& grad, Tensor& value);
+  void Apply(int epoch, const string& name, Tensor& grad, Tensor& value,
+      int step = -1);
 
   /// The argument is a function that returns the learning rate given the
   /// current step (i.e., curren running iteration).
@@ -86,11 +91,12 @@ class Optimizer {
  protected:
   function<float(int)> learning_rate_generator_;
   std::unordered_map<std::string, float> learning_rate_multplier_;
-  std::unordered_map<std::string, float> weight_decay_multplier_;
   std::unordered_map<std::string, Constraint*> constraints_;
   std::unordered_map<std::string, Regularizer*> regularizers_;
   Constraint* constraint_ = nullptr;
   Regularizer* regularizer_ = nullptr;
+
+  OptimizerConf conf_;
 };
 
 /// Apply constraints for parameters (gradient).
@@ -113,11 +119,11 @@ class Constraint {
   /// e.g., clip each gradient if it is too large w.r.t the threshold,
   /// \ref
   /// https://www.reddit.com/r/MachineLearning/comments/31b6x8/gradient_clipping_rnns/
-  void Apply(int step, Tensor& grad, Tensor& value);
+  void Apply(int epoch, const Tensor& value, Tensor& grad, int step = -1);
   /// Apply the constraint for multiple parameter objects together.
   /// \ref https://github.com/Lasagne/Lasagne/blob/master/lasagne/updates.py
-  void Apply(int step, const vector<Tensor>& grads,
-             const vector<Tensor>& values);
+  void Apply(int epoch, const vector<Tensor>& values,
+             const vector<Tensor>& grads, int step = -1);
 
  private:
   /// currently only support "L2" norm constraint, i.e., the norm should be less
@@ -150,11 +156,11 @@ class Regularizer {
   /// e.g., clip each gradient if it is too large w.r.t the threshold,
   /// \ref
   /// https://www.reddit.com/r/MachineLearning/comments/31b6x8/gradient_clipping_rnns/
-  void Apply(int step, Tensor& grad, Tensor& value, float scale = 1.0f);
+  void Apply(int epoch, const Tensor& value, Tensor& grad, int step = -1);
   /// Apply the regularizer for multiple parameter objects together.
   /// \ref https://github.com/Lasagne/Lasagne/blob/master/lasagne/updates.py
-  void Apply(int step, const vector<Tensor>& grads,
-             const vector<Tensor>& values);
+  void Apply(int epoch, const vector<Tensor>& values,
+             const vector<Tensor>& grads, int step = -1);
 
  private:
   /// currently only support "L2" regularizer. type_ is case insensitive.
@@ -173,8 +179,8 @@ class SGD : public Optimizer {
  public:
   void Setup(const OptimizerConf& conf);
   /// Apply the updating algorithm.
-  void Apply(int step, float lr, const string& name, const Tensor& grad,
-             Tensor& value) override;
+  void Apply(int epoch, float lr, const string& name, Tensor& grad,
+             Tensor& value, int step = -1) override;
 
   /// The argument function returns the momentum value given the current running
   /// step (i.e., iterations/mini-batches).
@@ -192,8 +198,8 @@ class Nesterov : public Optimizer {
  public:
   void Setup(const OptimizerConf& conf);
   /// Apply the updating algorithm.
-  void Apply(int step, float lr, const string& name, const Tensor& grad,
-             Tensor& value) override;
+  void Apply(int epoch, float lr, const string& name, Tensor& grad,
+             Tensor& value, int step = -1) override;
 
   /// The argument function returns the momentum value given the current running
   /// step (i.e., iterations/mini-batches).
@@ -211,8 +217,8 @@ class AdaGrad : public Optimizer {
  public:
   void Setup(const OptimizerConf& conf);
   /// Apply the updating algorithm.
-  void Apply(int step, float lr, const string& name, const Tensor& grad,
-             Tensor& value) override;
+  void Apply(int epoch, float lr, const string& name, Tensor& grad,
+             Tensor& value, int step = -1) override;
 
  private:
   std::unordered_map<string, Tensor> history_gradient_;
@@ -223,8 +229,8 @@ class RMSProp : public Optimizer {
  public:
   void Setup(const OptimizerConf& conf);
   /// Apply the updating algorithm.
-  void Apply(int step, float lr, const string& name, const Tensor& grad,
-             Tensor& value) override;
+  void Apply(int epoch, float lr, const string& name, Tensor& grad,
+             Tensor& value, int step = -1) override;
   virtual ~RMSProp() = default;
 
  private:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/python/singa/layer.py b/python/singa/layer.py
index 8caf2bb..950f26d 100644
--- a/python/singa/layer.py
+++ b/python/singa/layer.py
@@ -659,7 +659,7 @@ class Merge(Layer):
 
     def backward(self, flag, grad):
         assert isinstance(grad, tensor.Tensor), 'The input must be Tensor'
-        return [grad], []  # * self.num_input
+        return [grad] * self.num_input, []  # * self.num_input
 
 
 class Split(Layer):
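
A small illustration of the Merge fix above: backward() now hands one copy of
the shared gradient to every input instead of a single-element list. The
constructor arguments and the manual num_input assignment are assumptions
made to keep the example compact; normally num_input is set during the
forward pass:

    from singa import layer, tensor

    m = layer.Merge('merge', input_sample_shape=(3,))
    m.num_input = 2                # assumed shortcut; set by forward() in real use
    g = tensor.Tensor((3,))
    g.set_value(1.0)
    in_grads, _ = m.backward(True, g)
    assert len(in_grads) == 2      # a single-element list before this fix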

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/python/singa/optimizer.py
----------------------------------------------------------------------
diff --git a/python/singa/optimizer.py b/python/singa/optimizer.py
index 00380e0..164921f 100644
--- a/python/singa/optimizer.py
+++ b/python/singa/optimizer.py
@@ -32,6 +32,7 @@ Example usage::
   sgd.apply_with_lr(2, 0.03, g, p, 'param')  # use lr=0.03 for epoch 2
 '''
 
+import math
 from . import singa_wrap as singa
 import tensor
 from proto import model_pb2
@@ -51,36 +52,26 @@ class Optimizer(object):
     parameter udpate.
 
     Args:
-        lr (float): a constant for the learning rate, mutually exclusive with
-            'lr_gen'.
-        momentum (float): a constant for the momentum value
+        lr (float): a constant value for the learning rate
+        momentum (float): a constant value for the momentum value
         weight_decay (float): the coefficent for L2 regularizer, which is
             mutually exclusive with 'regularizer'.
-        lr_gen (function): a function returns the learning rate given
-            the current training step/epoch. It is mutually exclusive with lr.
-            If both are not set, the apply_with_lr function should be used for
-            param updating.
         regularizer: an instance of Regularizer or RegularizerConf; If set,
             regularization would be applied in apply_with_lr().
             Users can also do regularization outside.
         constraint: an instance of Constraint or ConstraintConf; If set,
             constraint would be applied inside apply_with_lr(). Users can
-            also do regularization outside.
+            also apply constraint outside.
     '''
-
-    def __init__(self, lr=None, momentum=None, weight_decay=None, lr_gen=None,
+    def __init__(self, lr=None, momentum=None, weight_decay=None,
                  regularizer=None, constraint=None):
-        if lr is not None:
-            assert lr_gen is None, 'Cannot set lr and lr_gen at the same time'
-
-            def lr_gen(epoch):
-                return lr
-        self.lr_gen = lr_gen
+        self.lr = lr
         self.momentum = momentum
         if weight_decay is not None:
             assert regularizer is None, \
                 'Cannot set weight_decay and regularizer at the same time'
             regularizer = L2Regularizer(weight_decay)
+
         if regularizer is not None:
             if isinstance(regularizer, model_pb2.RegularizerConf):
                 self.regularizer = CppRegularizer(regularizer)
@@ -103,7 +94,9 @@ class Optimizer(object):
     def register(self, name, specs):
         '''Register the param specs, including creating regularizer and
         constraint per param object. Param specific regularizer and constraint
-        have higher priority than the global ones.
+        have higher priority than the global ones. If all parameters share the
+        same setting for learning rate, regularizer and constraint, then there
+        is no need to call this function.
 
         Args:
             name (str): parameter name
@@ -124,44 +117,49 @@ class Optimizer(object):
         if specs.lr_mult != 1:
             self.learning_rate_multiplier[name] = specs.lr_mult
 
-    def apply_regularizer_constraint(self, epoch, value, grad, name=None):
+    def apply_regularizer_constraint(self, epoch, value, grad, name=None,
+                                     step=-1):
         '''Apply regularization and constraint if available.
 
         If there are both global regularizer (constraint) and param specific
         regularizer (constraint), it would use the param specific one.
 
         Args:
+            epoch (int): training epoch ID
             value (Tensor): parameter value Tensor
             grad (Tensor): parameter gradient Tensor
             name (string): to get parameter specific regularizer or constraint
-            epoch (int): some regularizer or constraint would use epoch
+            step (int): iteration ID within one epoch
 
         Returns:
             the updated gradient Tensor
         '''
         if name is not None and name in self.constraints:
-            self.constraints[name].apply(epoch, value, grad)
+            grad = self.constraints[name].apply(epoch, value, grad, step)
         elif self.constraint is not None:
-            self.constraint.apply(epoch, value, grad)
+            grad = self.constraint.apply(epoch, value, grad, step)
 
         if name is not None and name in self.regularizers:
-            self.regularizers[name].apply(epoch, value, grad)
+            grad = self.regularizers[name].apply(epoch, value, grad, step)
         elif self.regularizer is not None:
-            self.regularizer.apply(epoch, value, grad)
+            grad = self.regularizer.apply(epoch, value, grad, step)
         return grad
 
-    def apply_with_lr(self, epoch, lr, grad, value, name=None):
-        '''Do update with given learning rate.
+    def apply_with_lr(self, epoch, lr, grad, value, name=None, step=-1):
+        '''Update the parameters with the given learning rate if the grad is
+        not empty.
 
         The subclass optimizer must override this function.
+        This function does nothing if the grad is empty.
 
         Args:
-            epoch (int): training epoch (could be iteration or epoch)
+            epoch (int): training epoch ID
             lr (float): learning rate
             grad (Tensor): parameter gradient
             value (Tesnor): parameter value
-            name (string): paramter name to retrieval parameter specific
+            name (string): parameter name to index parameter specific
                 updating rules (including regularizer and constraint)
+            step (int): iteration ID within one epoch
 
         Returns:
             updated parameter value
@@ -169,25 +167,24 @@ class Optimizer(object):
         assert False, 'This is the base function, pls call the subclass func'
         return value
 
-    def apply(self, epoch, grad, value, name=None):
+    def apply(self, epoch, grad, value, name=None, step=-1):
         '''Do update assuming the learning rate generator is set.
 
         The subclass optimizer does not need to override this function.
 
         Args:
-            epoch (int): training epoch (could be iteration or epoch)
+            epoch (int): training epoch ID
             grad (Tensor): parameter gradient
             value (Tesnor): parameter value
             name (string): paramter name to retrieval parameter specific
                 updating rules (including regularizer and constraint)
+            step (int): training iteration ID within one epoch
 
         Return:
             updated parameter value
         '''
-        assert self.lr_gen is not None, 'Learning rate generator is not set.'\
-            'Either set the lr_gen in constructor or call apply_with_lr'
-        lr = self.lr_gen(epoch)
-        return self.apply_with_lr(epoch, lr, grad, value, name)
+        assert self.lr is not None, 'Must set the learning rate, i.e. "lr"'
+        return self.apply_with_lr(epoch, self.lr, grad, value, name, step)
 
 
 class SGD(Optimizer):
@@ -196,10 +193,10 @@ class SGD(Optimizer):
     See the base Optimizer for all arguments.
     '''
 
-    def __init__(self, lr=None, momentum=None, weight_decay=None, lr_gen=None,
+    def __init__(self, lr=None, momentum=None, weight_decay=None,
                  regularizer=None, constraint=None):
-        super(SGD, self).__init__(lr, momentum, weight_decay, lr_gen,
-                                  regularizer, constraint)
+        super(SGD, self).__init__(lr, momentum, weight_decay, regularizer,
+                                  constraint)
         conf = model_pb2.OptimizerConf()
         if self.momentum is not None:
             conf.momentum = self.momentum
@@ -207,8 +204,10 @@ class SGD(Optimizer):
         self.opt = singa.CreateOptimizer('SGD')
         self.opt.Setup(conf.SerializeToString())
 
-    def apply_with_lr(self, epoch, lr, grad, value, name):
-        self.apply_regularizer_constraint(epoch, value, grad, name)
+    def apply_with_lr(self, epoch, lr, grad, value, name, step=-1):
+        if grad.is_empty():
+            return value
+        grad = self.apply_regularizer_constraint(epoch, value, grad, name, step)
         if name is not None and name in self.learning_rate_multiplier:
             lr = lr * self.learning_rate_multiplier[name]
         self.opt.Apply(epoch, lr, name, grad.singa_tensor, value.singa_tensor)
@@ -221,9 +220,9 @@ class Nesterov(Optimizer):
     See the base Optimizer for all arguments.
     '''
 
-    def __init__(self, lr=None, momentum=0.9, weight_decay=None, lr_gen=None,
+    def __init__(self, lr=None, momentum=0.9, weight_decay=None,
                  regularizer=None, constraint=None):
-        super(Nesterov, self).__init__(lr, momentum, weight_decay, lr_gen,
+        super(Nesterov, self).__init__(lr, momentum, weight_decay,
                                        regularizer, constraint)
         conf = model_pb2.OptimizerConf()
         if self.momentum is not None:
@@ -232,14 +231,48 @@ class Nesterov(Optimizer):
         self.opt = singa.CreateOptimizer('Nesterov')
         self.opt.Setup(conf.SerializeToString())
 
-    def apply_with_lr(self, epoch, lr, grad, value, name):
-        self.apply_regularizer_constraint(epoch, value, grad, name)
+    def apply_with_lr(self, epoch, lr, grad, value, name, step=-1):
+        if grad.is_empty():
+            return value
+
+        grad = self.apply_regularizer_constraint(epoch, value, grad, name, step)
         if name is not None and name in self.learning_rate_multiplier:
             lr = lr * self.learning_rate_multiplier[name]
         self.opt.Apply(epoch, lr, name, grad.singa_tensor, value.singa_tensor)
         return value
 
 
+class RMSProp(Optimizer):
+    '''RMSProp optimizer.
+
+    See the base Optimizer for all constructor args.
+
+    Args:
+        rho (float): float within [0, 1]
+        epsilon (float): small value for preventing numeric error
+    '''
+
+    def __init__(self, rho=0.9, epsilon=1e-8, lr=None, weight_decay=None,
+                 regularizer=None, constraint=None):
+        super(RMSProp, self).__init__(lr, None, weight_decay, regularizer,
+                                      constraint)
+        conf = model_pb2.OptimizerConf()
+        conf.rho = rho
+        conf.delta = epsilon
+        self.opt = singa.CreateOptimizer('RMSProp')
+        self.opt.Setup(conf.SerializeToString())
+
+    def apply_with_lr(self, epoch, lr, grad, value, name, step=-1):
+        if grad.is_empty():
+            return value
+
+        grad = self.apply_regularizer_constraint(epoch, value, grad, name, step)
+        if name is not None and name in self.learning_rate_multiplier:
+            lr = lr * self.learning_rate_multiplier[name]
+        self.opt.Apply(step, lr,  name, grad.singa_tensor, value.singa_tensor)
+        return value
+
+
 class AdaGrad(Optimizer):
     '''AdaGrad optimizer.
 
@@ -251,7 +284,7 @@ class AdaGrad(Optimizer):
 
     def __init__(self, epsilon=1e-8, lr=None, weight_decay=None, lr_gen=None,
                  regularizer=None, constraint=None):
-        super(RMSProp, self).__init__(lr, weight_decay, lr_gen, regularizer,
+        super(AdaGrad, self).__init__(lr, None, weight_decay, regularizer,
                                       constraint)
         conf = model_pb2.OptimizerConf()
         conf.delta = epsilon
@@ -259,46 +292,77 @@ class AdaGrad(Optimizer):
         self.opt = singa.CreateOptimizer('AdaGrad')
         self.opt.Setup(conf.SerializeToString())
 
-    def apply_with_lr(self, epoch, lr, grad, value, name):
-        grad = self.apply_regularizer_constraint(epoch, value, grad, name)
+    def apply_with_lr(self, epoch, lr, grad, value, name, step=-1):
+        if grad.is_empty():
+            return value
+
+        grad = self.apply_regularizer_constraint(epoch, value, grad, name, step)
         if name is not None and name in self.learning_rate_multiplier:
             lr = lr * self.learning_rate_multiplier[name]
         self.opt.Apply(epoch, lr,  name, grad.singa_tensor, value.singa_tensor)
         return value
 
 
-class RMSProp(Optimizer):
-    '''RMSProp optimizer.
+class Adam(Optimizer):
+    '''Adam optimizer.
 
     See the base Optimizer for all constructor args.
 
     Args:
-        rho (float): float within [0, 1]
+        beta_1(float): coefficient of momentum
+        beta_2(float): coefficient of aggregated squared gradient
         epsilon (float): small value for preventing numeric error
     '''
 
-    def __init__(self, rho=0.9, epsilon=1e-8, lr=None, weight_decay=None,
-                 lr_gen=None, regularizer=None, constraint=None):
-        super(RMSProp, self).__init__(lr, weight_decay, lr_gen, regularizer,
-                                      constraint)
-        conf = model_pb2.OptimizerConf()
-        conf.rho = rho
-        conf.delta = epsilon
-        self.opt = singa.CreateOptimizer('RMSProp')
-        self.opt.Setup(conf.SerializeToString())
+    def __init__(self, beta_1=0.9, beta_2=0.999, epsilon=1e-8, lr=None,
+                 weight_decay=None, regularizer=None, constraint=None):
+        super(Adam, self).__init__(lr, None, weight_decay, regularizer,
+                                   constraint)
+        self.beta_1 = beta_1
+        self.beta_2 = beta_2
+        self.epsilon = epsilon
+        self.m = {}
+        self.v = {}
+        self.t = 1
+        self.last_epoch = 0
+        self.last_step = 0
+
+    def apply_with_lr(self, epoch, lr, grad, value, name, step):
+        '''Update one parameter object.
 
-    def apply_with_lr(self, epoch, lr, grad, value, name):
-        grad = self.apply_regularizer_constraint(epoch, value, grad, name)
+        Args:
+            step(int): the accumulated training iterations, not the iteration ID
+        '''
+        if grad.is_empty():
+            return value
+
+        assert step != -1, 'step should >= 0'
+        if epoch != self.last_epoch or step != self.last_step:
+            self.t += 1
+        grad = self.apply_regularizer_constraint(epoch, value, grad, name, step)
         if name is not None and name in self.learning_rate_multiplier:
             lr = lr * self.learning_rate_multiplier[name]
-        self.opt.Apply(epoch, lr,  name, grad.singa_tensor, value.singa_tensor)
+        if name not in self.m or name not in self.v:
+            self.m[name] = tensor.Tensor(grad.shape, grad.device, grad.dtype)
+            self.m[name].set_value(0)
+            self.v[name] = tensor.Tensor(grad.shape, grad.device, grad.dtype)
+            self.v[name].set_value(0)
+
+        self.m[name] *= self.beta_1
+        tensor.axpy(1 - self.beta_1, grad, self.m[name])
+        self.v[name] *= self.beta_2
+        tensor.axpy(1 - self.beta_2, tensor.square(grad), self.v[name])
+        alpha = lr * math.sqrt(1 - math.pow(self.beta_2, self.t)) \
+            / (1 - math.pow(self.beta_1, self.t))
+        value -= alpha * self.m[name] / (tensor.sqrt(self.v[name]) +
+                                         self.epsilon)
         return value
 
 
 class Regularizer(object):
     '''Base Python regularizer for parameter gradients.'''
 
-    def apply(self, value, grad):
+    def apply(self, epoch, value, grad, step=-1):
         assert False, 'Not Implemented. Call the subclass function.'
         return grad
 
@@ -314,7 +378,7 @@ class CppRegularizer(Regularizer):
         self.reg = singa.CreateRegularizer(conf.type)
         self.reg.Setup(conf.SerializeToString())
 
-    def apply(self, epoch, value, grad):
+    def apply(self, epoch, value, grad, step=-1):
         self.reg.Apply(epoch, value.singa_tensor, grad.singa_tensor)
         return grad
 
@@ -329,20 +393,17 @@ class L2Regularizer(Regularizer):
     def __init__(self, coefficient):
         self.coefficient = coefficient
 
-    def apply(self, epoch, value, grad, coefficient=None):
-        if coefficient is None:
-            assert self.coefficient is not None, 'Must set the coefficient'
-            coefficient = self.coefficient
+    def apply(self, epoch, value, grad, step=-1):
         # print coefficient, value.l1(), grad.l1()
-        if coefficient != 0:
-            tensor.axpy(coefficient, value, grad)
+        if self.coefficient != 0:
+            tensor.axpy(self.coefficient, value, grad)
         return grad
 
 
 class Constraint(object):
     '''Base Python constraint class for paramter gradients'''
 
-    def apply(self, epoch, value, grad):
+    def apply(self, epoch, value, grad, step=-1):
         return grad
 
 
@@ -357,8 +418,9 @@ class CppConstraint(Constraint):
         self.constraint = singa.CreateConstraint(conf.type)
         self.constraint.Setup(conf.SerializeToString())
 
-    def apply(self, epoch, value, grad):
-        self.constraint.Apply(epoch, value.singa_tensor, grad.singa_tensor)
+    def apply(self, epoch, value, grad, step=-1):
+        self.constraint.Apply(epoch, value.singa_tensor, grad.singa_tensor,
+                              step)
         return grad
 
 
@@ -368,10 +430,7 @@ class L2Constraint(Constraint):
     def __init__(self, threshold=None):
         self.threshold = threshold
 
-    def apply(self, epoch, value, grad, threshold=None):
-        if threshold is None:
-            assert self.threshold is not None, 'Must set the threshold'
-            threshold = self.threshold
+    def apply(self, epoch, value, grad, step=-1):
         nrm = grad.l2()
-        grad *= threshold / nrm
+        grad *= self.threshold / nrm
         return grad
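
To complement the diff above, a minimal end-to-end sketch of the new Adam
class; the shapes, values and the parameter name are made up, and running it
requires the compiled singa package. The class advances its internal counter
t whenever the (epoch, step) pair it sees changes, which is why call sites
now pass the batch index as 'step':

    import numpy as np
    from singa import optimizer, tensor

    adam = optimizer.Adam(lr=0.001)
    w = tensor.from_numpy(np.ones((4,), dtype=np.float32))
    g = tensor.from_numpy(np.full((4,), 0.1, dtype=np.float32))
    for b in range(3):
        # the bias-corrected step size alpha is recomputed from t internally
        w = adam.apply_with_lr(0, 0.001, g, w, 'w', b)
    print tensor.to_numpy(w)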

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index 1024483..d08b6cb 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -95,6 +95,13 @@ class Tensor(object):
         '''
         return self.singa_tensor.nDim()
 
+    def is_empty(self):
+        '''
+        Returns:
+            True if the tensor is empty according to its shape
+        '''
+        return self.ndim() == 0
+
     def is_transpose(self):
         '''
         Returns:
@@ -441,7 +448,6 @@ class Tensor(object):
         else:
             return _call_singa_func(singa.GE_Tf, self.singa_tensor, rhs)
 
-
     def __radd__(self, lhs):
         lhs = float(lhs)
         return _call_singa_func(singa.Add_Tf, self.singa_tensor, lhs)
@@ -461,9 +467,8 @@ class Tensor(object):
         one = Tensor(self.shape, self.device, self.dtype)
         one.set_value(1)
         one *= lhs
-        return _call_singa_func(singa.Div_TT, one.singa_tensor,\
-                self.singa_tensor)
-
+        return _call_singa_func(singa.Div_TT, one.singa_tensor,
+                                self.singa_tensor)
 
 ''' python functions for global functions in Tensor.h
 '''
@@ -618,6 +623,17 @@ def sigmoid(t):
     return _call_singa_func(singa.Sigmoid, t.singa_tensor)
 
 
+def sqrt(t):
+    '''
+    Args:
+        t (Tensor): input Tensor
+
+    Returns:
+        a new Tensor whose element y = sqrt(x), x is an element of t
+    '''
+    return _call_singa_func(singa.Sqrt, t.singa_tensor)
+
+
 def square(t):
     '''
     Args:

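A quick sketch of the two tensor helpers this patch relies on, is_empty() and
sqrt(); the input values are arbitrary and from_numpy/to_numpy are the
existing conversion helpers in this module:

    import numpy as np
    from singa import tensor

    t = tensor.from_numpy(np.array([1., 4., 9.], dtype=np.float32))
    print t.is_empty()                      # False: the tensor has a shape
    print tensor.to_numpy(tensor.sqrt(t))   # -> [ 1.  2.  3.]
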
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/src/api/model_optimizer.i
----------------------------------------------------------------------
diff --git a/src/api/model_optimizer.i b/src/api/model_optimizer.i
index 78b30b8..793df28 100644
--- a/src/api/model_optimizer.i
+++ b/src/api/model_optimizer.i
@@ -46,8 +46,8 @@ class Optimizer {
   // Optimizer() = default;
   virtual ~Optimizer() = default;
   void Setup(const std::string& str);
-  virtual void Apply(int step, float lr, const std::string& name,
-    const Tensor& grad, Tensor& value) = 0;
+  virtual void Apply(int epoch, float lr, const std::string& name,
+      Tensor& grad, Tensor& value, int step = -1) = 0;
 };
 inline std::shared_ptr<Optimizer> CreateOptimizer(const std::string& type);
 
@@ -55,7 +55,7 @@ class Constraint {
  public:
   Constraint() = default;
   void Setup(const std::string& conf_str);
-  void Apply(int step, Tensor& grad, Tensor& value);
+  void Apply(int epoch, const Tensor& value, Tensor& grad, int step = -1);
 };
 
 inline std::shared_ptr<Constraint> CreateConstraint(const std::string& type);
@@ -64,7 +64,7 @@ class Regularizer {
  public:
   Regularizer() = default;
   void Setup(const std::string& conf_str);
-  void Apply(int step, Tensor& grad, Tensor& value);
+  void Apply(int epoch, const Tensor& value, Tensor& grad, int step = -1);
 };
 inline std::shared_ptr<Regularizer> CreateRegularizer(const std::string& type);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/src/model/layer/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.cc b/src/model/layer/batchnorm.cc
index b345c6b..2ca7742 100644
--- a/src/model/layer/batchnorm.cc
+++ b/src/model/layer/batchnorm.cc
@@ -185,8 +185,6 @@ const std::pair<Tensor, vector<Tensor>> BatchNorm::Backward(
     param_grad.push_back(dbnScale_);
     param_grad.push_back(dbnBias_);
     Tensor dummy;
-    dummy.ResetLike(runningMean_);
-    dummy.SetValue(.0f);
     param_grad.push_back(dummy);
     param_grad.push_back(dummy);
   } else {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/src/model/layer/cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.cc b/src/model/layer/cudnn_batchnorm.cc
index f29679c..a7f80be 100644
--- a/src/model/layer/cudnn_batchnorm.cc
+++ b/src/model/layer/cudnn_batchnorm.cc
@@ -217,8 +217,6 @@ const std::pair<Tensor, vector<Tensor>> CudnnBatchNorm::Backward(
   param_grad.push_back(dbnScale_);
   param_grad.push_back(dbnBias_);
   Tensor dummy;
-  dummy.ResetLike(dbnScale_);
-  dummy.SetValue(.0f);
   param_grad.push_back(dummy);
   param_grad.push_back(dummy);
   if (is_2d_)
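
The dummy tensors above are deliberately left unallocated: on the Python side
they show up as empty tensors, and the optimizers now return the parameter
value untouched for them. A minimal sketch, assuming a default-constructed
PySINGA Tensor carries an empty shape just like the C++ dummies:

    from singa import optimizer, tensor

    sgd = optimizer.SGD(lr=0.01)
    mean = tensor.Tensor((4,))
    mean.set_value(0.0)
    dummy_grad = tensor.Tensor()   # assumption: no shape, so is_empty() is True
    # the update is skipped and 'mean' comes back unchanged
    mean = sgd.apply_with_lr(0, 0.01, dummy_grad, mean, 'mean', 0)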

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/src/model/optimizer/adagrad.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/adagrad.cc b/src/model/optimizer/adagrad.cc
index cdb3fac..c12a85b 100644
--- a/src/model/optimizer/adagrad.cc
+++ b/src/model/optimizer/adagrad.cc
@@ -25,8 +25,14 @@ void AdaGrad::Setup(const OptimizerConf& conf) { delta_ = conf.delta(); }
 
 // history += grad*grad;
 // value = value - lr*grad/sqrt(history+delta)
-void AdaGrad::Apply(int step, float lr, const string& name, const Tensor& grad,
-                    Tensor& value) {
+void AdaGrad::Apply(int epoch, float lr, const string& name,
+    Tensor& grad, Tensor& value, int step) {
+  if (grad.empty())
+    return;
+  ApplyRegularizerConstraint(epoch, name, value, grad, step);
+  if (learning_rate_multplier_.find(name) != learning_rate_multplier_.end())
+    lr *= learning_rate_multplier_.at(name);
+
   if (history_gradient_.find(name) == history_gradient_.end()) {
     history_gradient_[name].ResetLike(value);
     history_gradient_[name].SetValue(0.0f);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/src/model/optimizer/nesterov.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/nesterov.cc b/src/model/optimizer/nesterov.cc
index 051499b..a95ddb0 100644
--- a/src/model/optimizer/nesterov.cc
+++ b/src/model/optimizer/nesterov.cc
@@ -30,8 +30,13 @@ void Nesterov::Setup(const OptimizerConf& conf) {
 // history = lr * grad + history * mom
 // tmp = (1+mom) * history - tmp * mom;
 // value = value - tmp;
-void Nesterov::Apply(int step, float lr, const string& name, const Tensor& grad,
-                     Tensor& value) {
+void Nesterov::Apply(int epoch, float lr, const string& name, Tensor& grad,
+                     Tensor& value, int step) {
+  if (grad.empty())
+    return;
+  ApplyRegularizerConstraint(epoch, name, value, grad, step);
+  if (learning_rate_multplier_.find(name) != learning_rate_multplier_.end())
+    lr *= learning_rate_multplier_.at(name);
   if (momentum_generator_) {
     float mom = momentum_generator_(step);
     if (history_gradient_.find(name) == history_gradient_.end()) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/src/model/optimizer/optimizer.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/optimizer.cc b/src/model/optimizer/optimizer.cc
index d098249..1340a35 100644
--- a/src/model/optimizer/optimizer.cc
+++ b/src/model/optimizer/optimizer.cc
@@ -31,6 +31,7 @@ void Optimizer::Setup(const OptimizerConf& conf) {
   if (conf.has_regularizer())
     regularizer_ = new Regularizer(conf.regularizer());
   if (conf.has_constraint()) constraint_ = new Constraint(conf.constraint());
+  conf_ = conf;
 }
 void Optimizer::Register(const string& name, const ParamSpec& specs) {
   if (specs.has_constraint()) {
@@ -44,9 +45,9 @@ void Optimizer::Register(const string& name, const ParamSpec& specs) {
     regularizers_[name] = new Regularizer(specs.regularizer());
   }
   if (specs.has_decay_mult()) {
-    CHECK(weight_decay_multplier_.find(name) == weight_decay_multplier_.end())
-        << "Parameter with name = " << name << " has already registered";
-    weight_decay_multplier_[name] = specs.decay_mult();
+    auto reg = specs.regularizer();
+    reg.set_coefficient(reg.coefficient() * conf_.regularizer().coefficient());
+    regularizers_[name] = new Regularizer(reg);
   }
   if (specs.has_lr_mult()) {
     CHECK(learning_rate_multplier_.find(name) == learning_rate_multplier_.end())
@@ -59,26 +60,25 @@ void Optimizer::Register(const string& name, const ParamSpec& specs) {
   }
   */
 }
-
-void Optimizer::Apply(int step, const string& name, Tensor& grad,
-                      Tensor& param) {
+void Optimizer::ApplyRegularizerConstraint(int epoch, const string& name,
+      const Tensor& value, Tensor& grad, int step) {
   // TODO(wangwei) need to consider the order of constraint and regularizer
   if (regularizers_.find(name) != regularizers_.end()) {
-    regularizers_.at(name)->Apply(step, param, grad);
+    regularizers_.at(name)->Apply(epoch, value, grad, step);
   } else if (regularizer_ != nullptr) {
-    float scale = 1.0f;
-    if (weight_decay_multplier_.find(name) != weight_decay_multplier_.end())
-      scale = weight_decay_multplier_.at(name);
-    regularizer_->Apply(step, param, grad, scale);
+    regularizer_->Apply(epoch, value, grad, step);
   }
   if (constraints_.find(name) != constraints_.end())
-    constraints_.at(name)->Apply(step, param, grad);
+    constraints_.at(name)->Apply(epoch, value, grad, step);
   else if (constraint_ != nullptr)
-    constraint_->Apply(step, param, grad);
+    constraint_->Apply(epoch, value, grad, step);
+}
+
+
+void Optimizer::Apply(int epoch, const string& name, Tensor& grad,
+                      Tensor& value, int step) {
   float lr = learning_rate_generator_(step);
-  if (learning_rate_multplier_.find(name) != learning_rate_multplier_.end())
-    lr *= learning_rate_multplier_.at(name);
-  Apply(step, lr, name, grad, param);
+  Apply(epoch, lr, name, grad, value, step);
 }
 
 void Regularizer::Setup(const RegularizerConf& conf) {
@@ -89,16 +89,17 @@ void Regularizer::Setup(const RegularizerConf& conf) {
   }
 }
 
-void Regularizer::Apply(int step, Tensor& value, Tensor& grad, float scale) {
+void Regularizer::Apply(int epoch, const Tensor& value, Tensor& grad, int step)
+{
   if (type_ == "L2" || type_ == "l2") {
-    Axpy(coefficient_ * scale, value, &grad);
+    Axpy(coefficient_, value, &grad);
   } else {
     CHECK(type_ == "NotSet") << "Unknown regularizer type = " << type_;
   }
 }
 
-void Regularizer::Apply(int step, const vector<Tensor>& values,
-                        const vector<Tensor>& grads) {
+void Regularizer::Apply(int epoch, const vector<Tensor>& values,
+                        const vector<Tensor>& grads, int step) {
   LOG(FATAL) << "Not implemented yet";
 }
 
@@ -107,13 +108,13 @@ void Constraint::Setup(const ConstraintConf& conf) {
   threshold_ = conf.threshold();
 }
 
-void Constraint::Apply(int step, Tensor& value, Tensor& grad) {
+void Constraint::Apply(int epoch, const Tensor& value, Tensor& grad, int step) {
   // TODO(wangwei) implement L2 and hard constraint
   CHECK(type_ == "NotSet") << "Unknown regularizer type = " << type_;
 }
 
-void Constraint::Apply(int step, const vector<Tensor>& values,
-                       const vector<Tensor>& grads) {
+void Constraint::Apply(int epoch, const vector<Tensor>& values,
+                       const vector<Tensor>& grads, int step) {
   LOG(FATAL) << "Not implemented yet";
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/src/model/optimizer/rmsprop.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/rmsprop.cc b/src/model/optimizer/rmsprop.cc
index 13e2a75..575bdb7 100644
--- a/src/model/optimizer/rmsprop.cc
+++ b/src/model/optimizer/rmsprop.cc
@@ -28,8 +28,14 @@ void RMSProp::Setup(const OptimizerConf& conf) {
 
 // history = history * rho + grad * grad * (1 - rho)
 // value = value - lr * grad / sqrt(history + delta)
-void RMSProp::Apply(int step, float lr, const string& name, const Tensor& grad,
-                    Tensor& value) {
+void RMSProp::Apply(int epoch, float lr, const string& name, Tensor& grad,
+                    Tensor& value, int step) {
+  if (grad.empty())
+    return;
+  ApplyRegularizerConstraint(epoch, name, value, grad, step);
+  if (learning_rate_multplier_.find(name) != learning_rate_multplier_.end())
+    lr *= learning_rate_multplier_.at(name);
+
   if (history_gradient_.find(name) == history_gradient_.end()) {
     history_gradient_[name].ResetLike(value);
     history_gradient_[name].SetValue(0.0f);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5716105b/src/model/optimizer/sgd.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/sgd.cc b/src/model/optimizer/sgd.cc
index ac453cd..79b2136 100644
--- a/src/model/optimizer/sgd.cc
+++ b/src/model/optimizer/sgd.cc
@@ -31,8 +31,14 @@ void SGD::Setup(const OptimizerConf& conf) {
 
 // history = history * momentum + grad * lr
 // value = value - history
-void SGD::Apply(int step, float lr, const string& name, const Tensor& grad,
-                Tensor& value) {
+void SGD::Apply(int epoch, float lr, const string& name, Tensor& grad,
+                Tensor& value, int step) {
+  if (grad.empty())
+    return;
+  ApplyRegularizerConstraint(epoch, name, value, grad, step);
+  if (learning_rate_multplier_.find(name) != learning_rate_multplier_.end())
+    lr *= learning_rate_multplier_.at(name);
+
   // LOG(INFO) << "param " << name  << " lr = " << lr << " grad = " << grad.L1() << " value = " << value.L1();
   if (momentum_generator_) {
     float mom = momentum_generator_(step);



[2/2] incubator-singa git commit: Merge pull request 262 into master

Posted by wa...@apache.org.
Merge pull request 262 into master


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/3a64342d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/3a64342d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/3a64342d

Branch: refs/heads/master
Commit: 3a64342d0044fb62c727dc5fc177cca5b7b9ad35
Parents: 5716105 17ac160
Author: WANG Sheng <wa...@gmail.com>
Authored: Thu Oct 6 14:29:57 2016 +0800
Committer: WANG Sheng <wa...@gmail.com>
Committed: Thu Oct 6 14:29:57 2016 +0800

----------------------------------------------------------------------
 python/singa/data.py       | 134 ++++++++++++++++++++++++++++++++++++++++
 python/singa/image_tool.py |  11 ++--
 python/singa/layer.py      |   1 -
 3 files changed, 138 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3a64342d/python/singa/layer.py
----------------------------------------------------------------------
diff --cc python/singa/layer.py
index 950f26d,6eaf329..a22af55
--- a/python/singa/layer.py
+++ b/python/singa/layer.py
@@@ -659,9 -659,9 +659,8 @@@ class Merge(Layer)
  
      def backward(self, flag, grad):
          assert isinstance(grad, tensor.Tensor), 'The input must be Tensor'
 -        return [grad] *self.num_input, []  # * self.num_input
 -
 +        return [grad] * self.num_input, []  # * self.num_input
  
- 
  class Split(Layer):
      '''Replicate the input tensor.