You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2016/08/15 16:15:27 UTC

[14/22] incubator-singa git commit: Fixed the bug leading to weird accuracy (NaN), which was caused by forgetting to average the gradient over the whole mini-batch. That is why we needed a lower learning rate and could not use momentum. Update the lr in opt...

Fixed the bug leading to weird accuracy (NaN), which was caused by forgetting
to average the gradient over the whole mini-batch. This is why a lower
learning rate was previously needed and momentum could not be used.
Update the lr in optimizer.py so that it is multiplied by the per-parameter lr multiplier (lr_mult)
Fix the bug caused by the wrong pooling type in alexnet.py (pool2 and pool3 changed from max to avg pooling)


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/6d4539ee
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/6d4539ee
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/6d4539ee

Branch: refs/heads/dev
Commit: 6d4539eed2ae200a3a904a70cb789fc1b39d0f38
Parents: 1db2784
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Mon Aug 15 13:13:19 2016 +0800
Committer: Wei Wang <wa...@gmail.com>
Committed: Mon Aug 15 20:16:30 2016 +0800

----------------------------------------------------------------------
 examples/cifar10/alexnet.cc   |  11 +-
 examples/cifar10/alexnet.py   |  13 +-
 examples/cifar10/train.py     |  19 ++-
 src/model/feed_forward_net.cc |   6 +-
 src/model/optimizer/sgd.cc    |   4 +-
 src/python/singa/__init__.py  | 240 -------------------------------------
 src/python/singa/layer.py     |  15 +--
 src/python/singa/net.py       |   8 +-
 src/python/singa/optimizer.py |  36 ++++--
 src/python/singa/tensor.py    |   8 +-
 10 files changed, 68 insertions(+), 292 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d4539ee/examples/cifar10/alexnet.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/alexnet.cc b/examples/cifar10/alexnet.cc
index e1363e4..8051d1b 100644
--- a/examples/cifar10/alexnet.cc
+++ b/examples/cifar10/alexnet.cc
@@ -134,7 +134,7 @@ FeedForwardNet CreateNet() {
   return net;
 }
 
-void Train(float lr, int num_epoch, string data_dir) {
+void Train(int num_epoch, string data_dir) {
   Cifar10 data(data_dir);
   Tensor train_x, train_y, test_x, test_y;
   {
@@ -161,11 +161,11 @@ void Train(float lr, int num_epoch, string data_dir) {
   auto net = CreateNet();
   SGD sgd;
   OptimizerConf opt_conf;
-  opt_conf.set_momentum(0.9);
+  // opt_conf.set_momentum(0.9);
   auto reg = opt_conf.mutable_regularizer();
   reg->set_coefficient(0.004);
   sgd.Setup(opt_conf);
-  sgd.SetLearningRateGenerator([lr](int step) {
+  sgd.SetLearningRateGenerator([](int step) {
     if (step <= 120)
       return 0.001;
     else if (step <= 130)
@@ -193,14 +193,11 @@ int main(int argc, char **argv) {
   int pos = singa::ArgPos(argc, argv, "-epoch");
   int nEpoch = 1;
   if (pos != -1) nEpoch = atoi(argv[pos + 1]);
-  pos = singa::ArgPos(argc, argv, "-lr");
-  float lr = 0.001;
-  if (pos != -1) lr = atof(argv[pos + 1]);
   pos = singa::ArgPos(argc, argv, "-data");
   string data = "cifar-10-batches-bin";
   if (pos != -1) data = argv[pos + 1];
 
   LOG(INFO) << "Start training";
-  singa::Train(lr, nEpoch, data);
+  singa::Train(nEpoch, data);
   LOG(INFO) << "End training";
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d4539ee/examples/cifar10/alexnet.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/alexnet.py b/examples/cifar10/alexnet.py
index ddad1d5..dae129f 100644
--- a/examples/cifar10/alexnet.py
+++ b/examples/cifar10/alexnet.py
@@ -20,9 +20,6 @@ Following the same setting for hyper-parameters and data pre-processing, the fin
 validation accuracy would be about 82%.
 '''
 
-import sys
-import os
-
 # sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
 from singa import layer
 from singa import initializer
@@ -39,18 +36,18 @@ def create_net(use_cpu=False):
     W0_specs = {'init': 'gaussian', 'mean': 0, 'std': 0.0001}
     W1_specs = {'init': 'gaussian', 'mean': 0, 'std': 0.01}
     W2_specs = {'init': 'gaussian', 'mean': 0, 'std': 0.01, 'decay_mult': 250}
-    b_specs = {'init': 'constant', 'value': 0, 'lt_mult': 2}
+    b_specs = {'init': 'constant', 'value': 0, 'lr_mult': 2, 'decay_mult': 0}
     net.add(layer.Conv2D('conv1', 32, 5, 1, W_specs=W0_specs.copy(), b_specs=b_specs.copy(), pad=2, input_sample_shape=(3,32,32,)))
     net.add(layer.MaxPooling2D('pool1', 3, 2, pad=1))
     net.add(layer.Activation('relu1'))
-    net.add(layer.LRN(name='lrn1'))
+    net.add(layer.LRN(name='lrn1', size=3, alpha=5e-5))
     net.add(layer.Conv2D('conv2', 32, 5, 1, W_specs=W1_specs.copy(), b_specs=b_specs.copy(), pad=2))
     net.add(layer.Activation('relu2'))
-    net.add(layer.MaxPooling2D('pool2', 3, 2,  pad=1))
-    net.add(layer.LRN('lrn2'))
+    net.add(layer.AvgPooling2D('pool2', 3, 2,  pad=1))
+    net.add(layer.LRN('lrn2', size=3, alpha=5e-5))
     net.add(layer.Conv2D('conv3', 64, 5, 1, W_specs=W1_specs.copy(), b_specs=b_specs.copy(), pad=2))
     net.add(layer.Activation('relu3'))
-    net.add(layer.MaxPooling2D('pool3', 3, 2, pad=1))
+    net.add(layer.AvgPooling2D('pool3', 3, 2, pad=1))
     net.add(layer.Flatten('flat'))
     net.add(layer.Dense('dense', 10, W_specs=W2_specs.copy(), b_specs=b_specs.copy()))
     for (p, specs) in zip(net.param_values(), net.param_specs()):

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d4539ee/examples/cifar10/train.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/train.py b/examples/cifar10/train.py
index de03750..2091ee5 100644
--- a/examples/cifar10/train.py
+++ b/examples/cifar10/train.py
@@ -22,7 +22,6 @@ includes 1 label & 3072 pixels.  3072 pixels are 3 channels of a 32x32 image
 import cPickle
 import numpy as np
 import os
-import sys
 import argparse
 
 # sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
@@ -84,7 +83,7 @@ def normalize_for_alexnet(train_x, test_x):
 
 
 def vgg_lr(epoch):
-    return 0.01 / float(1 << ((epoch / 30)))
+    return 0.1 / float(1 << ((epoch / 25)))
 
 
 def alexnet_lr(epoch):
@@ -92,7 +91,7 @@ def alexnet_lr(epoch):
         return 0.001
     elif epoch < 130:
         return 0.0001
-    elif epoch < 140:
+    else:
         return 0.00001
 
 
@@ -107,8 +106,8 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
         dev = device.create_cuda_gpu()
 
     net.to_device(dev)
-    opt = optimizer.SGD(momentum=0.9, weight_decay=weight_decay)
-    for (p, specs) in zip(net.param_values(), net.param_specs()):
+    opt = optimizer.SGD(momentum=0.9, decay=weight_decay)
+    for (p, specs) in zip(net.param_names(), net.param_specs()):
         opt.register(p, specs)
 
     tx = tensor.Tensor((batch_size, 3, 32, 32), dev)
@@ -129,13 +128,13 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
             grads, (l, a) = net.train(tx, ty)
             loss += l
             acc += a
-            for (s, p, g) in zip(net.param_specs(), net.param_values(), grads):
-                opt.apply_with_lr(epoch, get_lr(epoch), g, p, str(s.name))
+            for (s, p, g) in zip(net.param_names(), net.param_values(), grads):
+                opt.apply_with_lr(epoch, get_lr(epoch), g, p, str(s))
             # update progress bar
             utils.update_progress(b * 1.0 / num_train_batch,
                                   'training loss = %f, accuracy = %f' % (l, a))
-        info = '\ntraining loss = %f, training accuracy = %f' \
-            % (loss / num_train_batch, acc / num_train_batch)
+        info = '\ntraining loss = %f, training accuracy = %f, lr = %f' \
+            % (loss / num_train_batch, acc / num_train_batch, get_lr(epoch))
         print info
 
         loss, acc = 0.0, 0.0
@@ -167,7 +166,7 @@ if __name__ == '__main__':
     if args.model == 'alexnet':
         train_x, test_x = normalize_for_alexnet(train_x, test_x)
         net = alexnet.create_net(args.use_cpu)
-        train((train_x, train_y, test_x, test_y), net, 140, alexnet_lr, 0.004,
+        train((train_x, train_y, test_x, test_y), net, 160, alexnet_lr, 0.004,
               use_cpu=args.use_cpu)
     else:
         train_x, test_x = normalize_for_vgg(train_x, test_x)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d4539ee/src/model/feed_forward_net.cc
----------------------------------------------------------------------
diff --git a/src/model/feed_forward_net.cc b/src/model/feed_forward_net.cc
index 514d6e2..3875430 100644
--- a/src/model/feed_forward_net.cc
+++ b/src/model/feed_forward_net.cc
@@ -206,8 +206,8 @@ const std::pair<float, float> FeedForwardNet::TrainOnBatch(int epoch,
 
 const Tensor FeedForwardNet::Forward(int flag, const Tensor& data) {
   Tensor input = data, output;
+  // LOG(INFO) << data.L1();
   for (auto layer : layers_) {
-    //    LOG(INFO) << layer->name() << ": " << input.L1();
     output = layer->Forward(flag, input);
     // LOG(INFO) << layer->name() << ": " << output.L2();
     input = output;
@@ -220,13 +220,13 @@ const vector<Tensor> FeedForwardNet::Backward(int flag, const Tensor& grad) {
   std::stack<Tensor> buf;
   Tensor tmp = grad;
   for (int i = layers_.size() - 1; i >= 0; i--) {
-    //   LOG(INFO) << layers_.at(i)->name() << " : " << tmp.L1();
+    // LOG(INFO) << layers_.at(i)->name() << " : " << tmp.L1();
     auto ret = layers_.at(i)->Backward(flag, tmp);
     tmp = ret.first;
     if (ret.second.size()) {
       for (int k = ret.second.size() - 1; k >= 0; k--) {
         buf.push(ret.second[k]);
-        //       LOG(INFO) <<  "      " << buf.top().L1();
+        // LOG(INFO) <<  "      " << buf.top().L1();
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d4539ee/src/model/optimizer/sgd.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/sgd.cc b/src/model/optimizer/sgd.cc
index d78d5b8..ac453cd 100644
--- a/src/model/optimizer/sgd.cc
+++ b/src/model/optimizer/sgd.cc
@@ -33,6 +33,7 @@ void SGD::Setup(const OptimizerConf& conf) {
 // value = value - history
 void SGD::Apply(int step, float lr, const string& name, const Tensor& grad,
                 Tensor& value) {
+  // LOG(INFO) << "param " << name  << " lr = " << lr << " grad = " << grad.L1() << " value = " << value.L1();
   if (momentum_generator_) {
     float mom = momentum_generator_(step);
     if (mom != 0) {
@@ -46,9 +47,8 @@ void SGD::Apply(int step, float lr, const string& name, const Tensor& grad,
       value -= history;
       return;
     }
-  } else {
-    Axpy(-lr, grad, &value);
   }
+  Axpy(-lr, grad, &value);
 }
 }  // namespace singa
 #endif  // SRC_MODEL_OPTIMIZER_SGD_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d4539ee/src/python/singa/__init__.py
----------------------------------------------------------------------
diff --git a/src/python/singa/__init__.py b/src/python/singa/__init__.py
index f14c8c5..e69de29 100644
--- a/src/python/singa/__init__.py
+++ b/src/python/singa/__init__.py
@@ -1,240 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# =============================================================================
-
-'''
-This script is the main entrance for user to run singa inside a model workspace
-
-To use this script, user sudo install these dependencies: flask pillow and protobuf
-'''
-
-import sys, glob, os, random, shutil, time
-from flask import Flask, request, redirect, url_for
-import numpy as np
-import ConfigParser
-import urllib, traceback
-
-
-from argparse import ArgumentParser
-from argparse import RawDescriptionHelpFormatter
-sys.path.append(os.getcwd())
-
-__all__ = []
-__version__ = 0.1
-__date__ = '2016-07-20'
-__updated__ = '2016-07-20'
-__shortdesc__ = '''
-welcome to singa
-'''
-
-app = Flask(__name__)
-config = ConfigParser.RawConfigParser()
-service = {}
-data_path = "data_"
-parameter_path = "parameter_"
-
-debug = False
-
-class CLIError(Exception):
-    '''Generic exception to raise and log different fatal errors.'''
-    def __init__(self, msg):
-        super(CLIError).__init__(type(self))
-        self.msg = "E: %s" % msg
-    def __str__(self):
-        return self.msg
-    def __unicode__(self):
-        return self.msg
-
-def main(argv=None): # IGNORE:C0111
-    '''Command line options.'''
-
-    from . import device
-
-    if argv is None:
-        argv = sys.argv
-    else:
-        sys.argv.extend(argv)
-
-    program_name = os.path.basename(sys.argv[0])
-    program_version = "v%s" % __version__
-    program_build_date = str(__updated__)
-    program_version_message = '%%(prog)s %s (%s)' % (program_version, program_build_date)
-    program_shortdesc = __shortdesc__
-    program_license = '''%s
-
-  Created by dbsystem group on %s.
-  Copyright 2016 NUS School of Computing. All rights reserved.
-
-  Licensed under the Apache License 2.0
-  http://www.apache.org/licenses/LICENSE-2.0
-
-  Distributed on an "AS IS" basis without warranties
-  or conditions of any kind, either express or implied.
-
-USAGE
-''' % (program_shortdesc, str(__date__))
-
-    global debug
-
-    try:
-        # Setup argument parser
-        parser = ArgumentParser(description=program_license, formatter_class=RawDescriptionHelpFormatter)
-        parser.add_argument("-p", "--port", dest="port", default=5000, help="the port to listen to, default is 5000")
-        parser.add_argument("-param", "--parameter", dest="parameter",  help="the parameter file path to be loaded")
-        parser.add_argument("-D", "--debug", dest="debug", action="store_true", help="whether need to debug")
-        parser.add_argument("-R", "--reload", dest="reload_data", action="store_true", help="whether need to reload data")
-        parser.add_argument("-C", "--cpu", dest="use_cpu", action="store_true", help="Using cpu or not, default is using gpu")
-        parser.add_argument("-m", "--mode", dest="mode", choices=['train','test','serve'], default='serve', help="On Which mode (train,test,serve) to run singa")
-        parser.add_argument('-V', '--version', action='version', version=program_version_message)
-
-        # Process arguments
-        args = parser.parse_args()
-
-        port = args.port
-        parameter_file = args.parameter
-        mode = args.mode
-        need_reload = args.reload_data
-        use_cpu = args.use_cpu
-        debug = args.debug
-
-        #prepare data files
-        config.read('file.cfg')
-        file_prepare(need_reload)
-
-
-        import network as net
-        model = net.create()
-
-        #load parameter
-        parameter_file=get_parameter(parameter_file)
-
-        if parameter_file:
-            print "load parameter file: %s" % parameter_file
-            model.load(parameter_file)
-
-        if use_cpu:
-            raise CLIError("Currently cpu is not support!")
-        else:
-            print "runing with gpu"
-            d = device.create_cuda_gpu()
-
-        model.to_device(d)
-
-        if mode == "serve":
-            print "runing singa in serve mode, listen to  port: %s " % port
-            global service
-            from serve import Service
-            service =Service(model,d)
-
-            app.debug = debug
-            app.run(host='0.0.0.0', port= port)
-        elif mode == "train":
-            print "runing singa in train mode"
-            global trainer
-            from train import Trainer
-            trainer= Trainer(model,d)
-            if not parameter_file:
-                trainer.initialize()
-            trainer.train()
-        else:
-            raise CLIError("Currently only serve mode is surpported!")
-        return 0
-    except KeyboardInterrupt:
-        ### handle keyboard interrupt ###
-        return 0
-    except Exception, e:
-        if debug:
-            traceback.print_exc()
-            raise(e)
-        indent = len(program_name) * " "
-        sys.stderr.write(program_name + ": " + str(e) + "\n")
-        sys.stderr.write(indent + "  for help use --help \n\n")
-        return 2
-
-def file_prepare(reload_data=False):
-    '''
-        download all files and generate data.py
-    '''
-    if not reload_data and os.path.exists("data_.py"):
-        return
-
-    print "download file"
-    #clean data
-    shutil.rmtree("data_.py",ignore_errors=True)
-    shutil.rmtree("data_",ignore_errors=True)
-
-    data_py=open("data_.py",'w')
-    data_py.write("#%s" % "This file is Generated by SINGA, please don't edit\n\n")
-    if config.has_section("data"):
-        file_list = config.items("data")
-        #download files
-        for f in file_list:
-            name,path=download_file(f[0],f[1],data_path)
-            data_py.write("%s=\"%s\"\n" % (name,path))
-
-    data_py.flush()
-    data_py.close()
-
-    if config.has_section("parameter"):
-        parameter_list = config.items("parameter")
-        for p in parameter_list:
-            download_file(p[0],p[1],parameter_path)
-
-def download_file(name,path,dest):
-    '''
-    download one file to dest
-    '''
-    if not os.path.exists(dest):
-        os.makedirs(dest)
-    if (path.startswith('http')):
-        file_name = path.split('/')[-1]
-        target = os.path.join(dest,file_name)
-        urllib.urlretrieve(path,target)
-    return name,target
-
-
-def get_parameter(file_name=None):
-    '''
-    get the paticular file name or get the last parameter file
-    '''
-    if not os.path.exists(parameter_path):
-        os.makedirs(parameter_path)
-        return
-
-    if file_name:
-	return os.path.join(parameter_path,file_name)
-
-    parameter_list = [ os.path.join(parameter_path,f) for f in os.listdir(parameter_path)]
-    if len(parameter_list)==0:
-        return
-    parameter_list.sort()
-
-    return parameter_list[-1]
-
-@app.route("/")
-def index():
-    return "Hello SINGA User!"
-
-@app.route('/predict', methods=['POST'])
-def predict():
-    if request.method == 'POST':
-        try:
-            response=service.serve(request)
-        except Exception as e:
-            return e
-        return response
-    return "error, should be post request"

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d4539ee/src/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/src/python/singa/layer.py b/src/python/singa/layer.py
index c8c8c05..1e9caeb 100644
--- a/src/python/singa/layer.py
+++ b/src/python/singa/layer.py
@@ -362,8 +362,8 @@ class BatchNormalization(Layer):
 
 
 class LRN(Layer):
-    def __init__(self, name, size=5, alpha=1, beta=0.75, mode='cross_channel',
-                 k=1, input_sample_shape=None):
+    def __init__(self, name, size=5, alpha=1e-4, beta=0.75,
+                 mode='cross_channel', k=1, input_sample_shape=None):
         """Local response normalization.
 
         Args:
@@ -391,7 +391,7 @@ class Dense(Layer):
 
     def __init__(self, name, num_output, use_bias=True,
                  W_specs=None, b_specs=None,
-                 W_transpose=True, input_sample_shape=None):
+                 W_transpose=False, input_sample_shape=None):
         """Apply linear/affine transformation, also called inner-product or
         fully connected layer.
 
@@ -424,10 +424,10 @@ class Dense(Layer):
             W_specs['name'] = name + '_weight'
         if 'name' not in b_specs:
             b_specs['name'] = name + '_bias'
-        self.conf.param.extend([_construct_param_specs_from_dict(W_specs)])
-        self.param_specs.append(_construct_param_specs_from_dict(W_specs))
-        self.conf.param.extend([_construct_param_specs_from_dict(b_specs)])
-        self.param_specs.append(_construct_param_specs_from_dict(b_specs))
+        wspecs = _construct_param_specs_from_dict(W_specs)
+        bspecs = _construct_param_specs_from_dict(b_specs)
+        self.conf.param.extend([wspecs, bspecs])
+        self.param_specs.extend([wspecs, bspecs])
         # dense layer is transparent to engine.
         self.layer = _create_layer('singa', 'Dense')
         if input_sample_shape is not None:
@@ -712,6 +712,7 @@ def _construct_param_specs_from_dict(specs):
         a ParamSpec object
     """
     conf = model_pb2.ParamSpec()
+    print 'convert', specs
     if 'name' in specs:
         conf.name = specs['name']
     if 'lr_mult' in specs:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d4539ee/src/python/singa/net.py
----------------------------------------------------------------------
diff --git a/src/python/singa/net.py b/src/python/singa/net.py
index f040378..3a1732c 100644
--- a/src/python/singa/net.py
+++ b/src/python/singa/net.py
@@ -95,16 +95,22 @@ class FeedForwardNet(object):
         # print x.l1()
         for lyr in self.layers:
             x = lyr.forward(flag, x)
-        #    print lyr.name, x.l1()
+            # print lyr.name, x.l1()
         return x
 
     def backward(self):
         grad = self.loss.backward()
+        if len(grad.shape) > 1:
+            grad /= grad.shape[0]  # average across the batch
+        # print 'grad', grad.l1()
         pgrads = []
         for lyr in reversed(self.layers):
             grad, _pgrads = lyr.backward(kTrain, grad)
+            # disp = '%f ' % grad.l1()
             for g in reversed(_pgrads):
                 pgrads.append(g)
+                # disp = disp + ', %f ' % g.l1()
+            # print disp
         return reversed(pgrads)
 
     def save(self, f):

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d4539ee/src/python/singa/optimizer.py
----------------------------------------------------------------------
diff --git a/src/python/singa/optimizer.py b/src/python/singa/optimizer.py
index aa6bdd1..32f03d4 100644
--- a/src/python/singa/optimizer.py
+++ b/src/python/singa/optimizer.py
@@ -102,16 +102,19 @@ class Optimizer(object):
             name (str): parameter name
             specs (ParamSpec): protobuf obj
         """
-	assert type(specs) == model_pb2.ParamSpec, \
-		'specs should be model_pb2.ParamSpec instance'
+        assert type(specs) == model_pb2.ParamSpec, \
+            'specs should be model_pb2.ParamSpec instance'
         if specs.HasField('regularizer'):
             self.regularizers[name] = CppRegularizer(specs.regularizer)
+        elif specs.decay_mult != 1:
+            self.regularizers[name] = L2Regularizer(
+                specs.decay_mult * self.regularizer.coefficient)
+
         if specs.HasField('constraint'):
             self.constraints[name] = CppConstraint(specs.constraint)
+
         if specs.lr_mult != 1:
             self.learning_rate_multiplier[name] = specs.lr_mult
-        if specs.decay_mult != 1:
-            self.decay_multiplier[name] = specs.decay_mult
 
     def apply_regularizer_constraint(self, value, grad, name=None, step=None):
         """Apply regularization and constraint if available.
@@ -129,12 +132,12 @@ class Optimizer(object):
             the updated gradient Tensor
         """
         if name is not None and name in self.constraints:
-            self.constraints[name].apply(value, grad, step)
+            self.constraints[name].apply(step, value, grad)
         elif self.constraint is not None:
             self.constraint.apply(step, value, grad)
 
         if name is not None and name in self.regularizers:
-            self.regularizers[name].apply(value, grad, step)
+            self.regularizers[name].apply(step, value, grad)
         elif self.regularizer is not None:
             self.regularizer.apply(step, value, grad)
         return grad
@@ -175,24 +178,29 @@ class Optimizer(object):
         assert self.lr_gen is not None, 'Learning rate generator is not set.'\
             'Either set the lr_gen in constructor or call apply_with_lr'
         lr = self.lr_gen(step)
+        if name is not None and name in self.learning_rate_multiplier:
+            lr = lr * self.learning_rate_multiplier[name]
         return self.apply_with_lr(step, lr, grad, value, name)
 
 
 class SGD(Optimizer):
 
-    def __init__(self, lr=None, momentum=None, decay=None, **kwargs):
+    def __init__(self, lr=None, momentum=None, decay=None):
         """The vallina Stochasitc Gradient Descent algorithm.
 
         See the base Optimizer for all arguments.
         """
         super(SGD, self).__init__(lr, momentum, decay)
         conf = model_pb2.OptimizerConf()
-        conf.momentum = momentum
+        if momentum is not None:
+            conf.momentum = momentum
         self.opt = singa.CreateOptimizer('SGD')
         self.opt.Setup(conf.SerializeToString())
 
     def apply_with_lr(self, step, lr, grad, value, name):
-        self.apply_regularizer_constraint(step, value, grad, name)
+        self.apply_regularizer_constraint(value, grad, name, step)
+        if name is not None and name in self.learning_rate_multiplier:
+            lr = lr * self.learning_rate_multiplier[name]
         self.opt.Apply(step, lr, name, grad.singa_tensor, value.singa_tensor)
         return value
 
@@ -206,6 +214,8 @@ class Nesterov(Optimizer):
         """
         super(Nesterov, self).__init__(lr, momentum, decay, kwargs)
         conf = model_pb2.OptimizerConf()
+        if momentum is not None:
+            conf.momentum = momentum
         self.opt = singa.CreateOptimizer('Nesterov')
         self.opt.Setup(conf.SerializeToString())
 
@@ -232,6 +242,8 @@ class AdaGrad(Optimizer):
 
     def apply_with_lr(self, step, lr, grad, value, name):
         grad = self.apply_regularizer_constraint(step, value, grad, name)
+        if name is not None and name in self.learning_rate_multiplier:
+            lr = lr * self.learning_rate_multiplier[name]
         self.opt.Apply(step, lr,  name, grad.singa_tensor, value.singa_tensor)
         return value
 
@@ -255,6 +267,8 @@ class RMSProp(Optimizer):
 
     def apply_with_lr(self, step, lr, grad, value, name):
         grad = self.apply_regularizer_constraint(step, value, grad, name)
+        if name is not None and name in self.learning_rate_multiplier:
+            lr = lr * self.learning_rate_multiplier[name]
         self.opt.Apply(step, lr,  name, grad.singa_tensor, value.singa_tensor)
         return value
 
@@ -300,7 +314,9 @@ class L2Regularizer(Regularizer):
         if coefficient is None:
             assert self.coefficient is not None, 'Must set the coefficient'
             coefficient = self.coefficient
-        tensor.axpy(coefficient, value, grad)
+        # print coefficient, value.l1(), grad.l1()
+        if coefficient != 0:
+            tensor.axpy(coefficient, value, grad)
         return grad
 
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6d4539ee/src/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/src/python/singa/tensor.py b/src/python/singa/tensor.py
index ed651e9..1d04cdf 100644
--- a/src/python/singa/tensor.py
+++ b/src/python/singa/tensor.py
@@ -177,28 +177,28 @@ class Tensor(object):
         if isinstance(x, Tensor):
             self.singa_tensor += x.singa_tensor
         else:
-            self.singa_tensor += x
+            self.singa_tensor += float(x)
         return self
 
     def __isub__(self, x):
         if isinstance(x, Tensor):
             self.singa_tensor -= x.singa_tensor
         else:
-            self.singa_tensor -= x
+            self.singa_tensor -= float(x)
         return self
 
     def __imul__(self, x):
         if isinstance(x, Tensor):
             self.singa_tensor *= x.singa_tensor
         else:
-            self.singa_tensor *= x
+            self.singa_tensor *= float(x)
         return self
 
     def __idiv__(self, x):
         if isinstance(x, Tensor):
             self.singa_tensor /= x.singa_tensor
         else:
-            self.singa_tensor /= x
+            self.singa_tensor /= float(x)
         return self
 
     '''