You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by ka...@apache.org on 2016/06/27 14:11:50 UTC

[3/6] incubator-singa git commit: SINGA-204 Support the training of feed-forward neural nets

SINGA-204 Support the training of feed-forward neural nets

Implement the AlexNet model for CIFAR-10, following cuda-convnet (https://code.google.com/p/cuda-convnet/).
However, the test accuracy is still low: 0.72, whereas it should be 0.82.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/71eb059c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/71eb059c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/71eb059c

Branch: refs/heads/dev
Commit: 71eb059cd13ea41e74195c7c115f927aaf143490
Parents: cf1d841
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Mon Jun 27 01:21:59 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Mon Jun 27 15:29:05 2016 +0800

----------------------------------------------------------------------
 examples/cifar10/alexnet.cc             | 123 ++++++++++++++++++---------
 examples/cifar10/cifar10.h              |   3 +-
 examples/cifar10/make.sh                |   2 +-
 include/singa/core/tensor.h             |   1 +
 include/singa/model/feed_forward_net.h  |   2 +-
 include/singa/model/initializer.h       |  26 +++++-
 include/singa/model/loss.h              |  12 ++-
 include/singa/model/metric.h            |   2 +-
 include/singa/model/optimizer.h         |  18 ++--
 include/singa/utils/string.h            |  11 +++
 src/core/tensor/math_kernel.cu          |  22 +++++
 src/core/tensor/math_kernel.h           |   2 +
 src/core/tensor/tensor.cc               |  69 ++++++++++-----
 src/core/tensor/tensor_math.h           |   5 ++
 src/core/tensor/tensor_math_cpp.h       |  14 +++
 src/core/tensor/tensor_math_cuda.h      |   9 ++
 src/model/feed_forward_net.cc           | 104 ++++++++++++++--------
 src/model/layer/cudnn_convolution.cc    |   5 +-
 src/model/layer/cudnn_dropout.cc        |   2 +-
 src/model/layer/dense.cc                |   4 +-
 src/model/loss/mse.cc                   |   5 +-
 src/model/loss/softmax_cross_entropy.cc |  11 ++-
 src/model/metric/accuracy.cc            |   1 +
 src/model/optimizer/optimizer.cc        |  30 ++++++-
 src/model/optimizer/sgd.cc              |   1 +
 src/proto/model.proto                   |   5 ++
 test/singa/test_cross_entropy.cc        |   8 +-
 test/singa/test_dense.cc                |   2 +-
 test/singa/test_mse.cc                  |   8 +-
 test/singa/test_tensor_math.cc          |   4 +-
 30 files changed, 370 insertions(+), 141 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/examples/cifar10/alexnet.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/alexnet.cc b/examples/cifar10/alexnet.cc
index 45d8571..d6541a3 100644
--- a/examples/cifar10/alexnet.cc
+++ b/examples/cifar10/alexnet.cc
@@ -28,12 +28,13 @@
 #include "../../src/model/layer/cudnn_convolution.h"
 #include "../../src/model/layer/cudnn_activation.h"
 #include "../../src/model/layer/cudnn_pooling.h"
+#include "../../src/model/layer/cudnn_lrn.h"
 #include "../../src/model/layer/dense.h"
 #include "../../src/model/layer/flatten.h"
 namespace singa {
 
 LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
-                      int pad) {
+                      int pad, float std) {
   LayerConf conf;
   conf.set_name(name);
   conf.set_type("CudnnConvolution");
@@ -42,13 +43,23 @@ LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
   conv->add_kernel_size(kernel);
   conv->add_stride(stride);
   conv->add_pad(pad);
+  conv->set_bias_term(true);
 
-  FillerConf *weight = conv->mutable_weight_filler();
-  weight->set_type("Xavier");
+  ParamSpec *wspec = conf.add_param();
+  wspec->set_name(name + "_weight");
+  auto wfill = wspec->mutable_filler();
+  wfill->set_type("Gaussian");
+  wfill->set_std(std);
+
+  ParamSpec *bspec = conf.add_param();
+  bspec->set_name(name + "_bias");
+  bspec->set_lr_mult(2);
+//  bspec->set_decay_mult(0);
   return conf;
 }
 
-LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride, int pad) {
+LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,
+                         int pad) {
   LayerConf conf;
   conf.set_name(name);
   conf.set_type("CudnnPooling");
@@ -56,8 +67,7 @@ LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride, int
   pool->set_kernel_size(kernel);
   pool->set_stride(stride);
   pool->set_pad(pad);
-  if (!max_pool)
-    pool->set_pool(PoolingConf_PoolMethod_AVE);
+  if (!max_pool) pool->set_pool(PoolingConf_PoolMethod_AVE);
   return conf;
 }
 
@@ -68,21 +78,38 @@ LayerConf GenReLUConf(string name) {
   return conf;
 }
 
-LayerConf GenDenseConf(string name, int num_output) {
+LayerConf GenDenseConf(string name, int num_output, float std, float wd) {
   LayerConf conf;
   conf.set_name(name);
   conf.set_type("Dense");
   DenseConf *dense = conf.mutable_dense_conf();
   dense->set_num_output(num_output);
-  FillerConf *weight = dense->mutable_weight_filler();
-  weight->set_type("Xavier");
+  FillerConf *bias = dense->mutable_bias_filler();
+
+  ParamSpec *wspec = conf.add_param();
+  wspec->set_name(name + "_weight");
+  wspec->set_decay_mult(wd);
+
+  auto wfill = wspec->mutable_filler();
+  wfill->set_type("Gaussian");
+  wfill->set_std(std);
+
+  ParamSpec *bspec = conf.add_param();
+  bspec->set_name(name + "_bias");
+  bspec->set_lr_mult(2);
+  bspec->set_decay_mult(0);
+
   return conf;
 }
 
-LayerConf GenSoftmaxConf(string name) {
+LayerConf GenLRNConf(string name) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnSoftmax");
+  conf.set_type("CudnnLRN");
+  LRNConf *lrn = conf.mutable_lrn_conf();
+  lrn->set_local_size(3);
+  lrn->set_alpha(5e-05);
+  lrn->set_beta(0.75);
   return conf;
 }
 
@@ -92,25 +119,25 @@ LayerConf GenFlattenConf(string name) {
   conf.set_type("Flatten");
   return conf;
 }
-FeedForwardNet CreateNet(Optimizer* opt, Loss<Tensor>* loss, Metric<Tensor>* metric) {
+
+FeedForwardNet CreateNet() {
   FeedForwardNet net;
   Shape s{3, 32, 32};
 
-  net.Add(new CudnnConvolution(), GenConvConf("conv1", 32, 5, 1, 2), &s);
+  net.Add(new CudnnConvolution(), GenConvConf("conv1", 32, 5, 1, 2, 0.0001),
+          &s);
   net.Add(new CudnnActivation(), GenReLUConf("relu1"));
-  net.Add(new CudnnPooling, GenPoolingConf("pool1", true, 3, 2, 0));
-  net.Add(new CudnnConvolution(), GenConvConf("conv2", 32, 5, 1, 2));
+  net.Add(new CudnnPooling(), GenPoolingConf("pool1", true, 3, 2, 1));
+  net.Add(new CudnnLRN(), GenLRNConf("lrn1"));
+  net.Add(new CudnnConvolution(), GenConvConf("conv2", 32, 5, 1, 2, 0.01));
   net.Add(new CudnnActivation(), GenReLUConf("relu2"));
-  net.Add(new CudnnPooling(), GenPoolingConf("pool2", true, 3, 2, 0));
-  net.Add(new CudnnConvolution, GenConvConf("conv3", 64, 5, 1, 2));
+  net.Add(new CudnnPooling(), GenPoolingConf("pool2", false, 3, 2, 1));
+  net.Add(new CudnnLRN(), GenLRNConf("lrn2"));
+  net.Add(new CudnnConvolution, GenConvConf("conv3", 64, 5, 1, 2, 0.01));
   net.Add(new CudnnActivation(), GenReLUConf("relu3"));
-  net.Add(new CudnnConvolution(), GenConvConf("pool3", true, 3, 2, 0));
+  net.Add(new CudnnPooling(), GenPoolingConf("pool3", false, 3, 2, 1));
   net.Add(new Flatten(), GenFlattenConf("flat"));
-  net.Add(new Dense(), GenDenseConf("ip1", 10));
-  OptimizerConf opt_conf;
-  opt_conf.set_momentum(0.9);
-  opt->Setup(opt_conf);
-  net.Compile(true, opt, loss, metric);
+  net.Add(new Dense(), GenDenseConf("ip", 10, 0.01, 250));
   return net;
 }
 
@@ -120,50 +147,62 @@ void Train(float lr, int num_epoch, string data_dir) {
   {
     auto train = data.ReadTrainData();
     size_t nsamples = train.first.shape(0);
-    auto matx = Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples});
+    auto matx =
+        Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples});
     const auto mean = Average(matx, 0);
     SubRow(mean, &matx);
     train_x = Reshape(matx, train.first.shape());
     train_y = train.second;
     auto test = data.ReadTestData();
     nsamples = test.first.shape(0);
-    auto maty = Reshape(test.first, Shape{nsamples, test.first.Size() / nsamples});
+    auto maty =
+        Reshape(test.first, Shape{nsamples, test.first.Size() / nsamples});
     SubRow(mean, &maty);
     test_x = Reshape(maty, test.first.shape());
     test_y = test.second;
   }
-  LOG(ERROR) << "creating net";
+  LOG(INFO) << "Training samples = " << train_y.shape(0)
+            << " Test samples =" << test_y.shape(0);
+  auto net = CreateNet();
+  SGD sgd;
+  OptimizerConf opt_conf;
+  opt_conf.set_momentum(0.9);
+  auto reg = opt_conf.mutable_regularizer();
+  reg->set_coefficient(0.004);
+  sgd.Setup(opt_conf);
+  sgd.SetLearningRateGenerator([lr](int step) {
+    if (step <= 120)
+      return 0.001;
+    else if (step <= 130)
+      return 0.0001;
+    else if (step <= 140)
+      return 0.00001;
+  });
   SoftmaxCrossEntropy loss;
   Accuracy acc;
-  SGD sgd;
-  sgd.SetLearningRateGenerator([lr](int step) {return lr;});
-  auto net = CreateNet(&sgd, &loss, &acc);
-
+  net.Compile(true, &sgd, &loss, &acc);
   auto cuda = std::make_shared<CudaGPU>();
   net.ToDevice(cuda);
 
   train_x.ToDevice(cuda);
   train_y.ToDevice(cuda);
-  net.Train(50, num_epoch, train_x, train_y); // test_x, test_y);
+  test_x.ToDevice(cuda);
+  test_y.ToDevice(cuda);
+  net.Train(100, num_epoch, train_x, train_y, test_x, test_y);
 }
-
-
 }
 
-int main(int argc, char** argv) {
+int main(int argc, char **argv) {
   singa::InitChannel(nullptr);
   int pos = singa::ArgPos(argc, argv, "-epoch");
-  int nEpoch = 5;
-  if (pos != -1)
-    nEpoch = atoi(argv[pos + 1]);
+  int nEpoch = 140;
+  if (pos != -1) nEpoch = atoi(argv[pos + 1]);
   pos = singa::ArgPos(argc, argv, "-lr");
-  float lr = 0.01;
-  if (pos != -1)
-    lr = atof(argv[pos + 1]);
+  float lr = 0.001;
+  if (pos != -1) lr = atof(argv[pos + 1]);
   pos = singa::ArgPos(argc, argv, "-data");
   string data = "cifar-10-batches-bin";
-  if (pos != -1)
-    data = argv[pos + 1];
+  if (pos != -1) data = argv[pos + 1];
 
   LOG(INFO) << "Start training";
   singa::Train(lr, nEpoch, data);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/examples/cifar10/cifar10.h
----------------------------------------------------------------------
diff --git a/examples/cifar10/cifar10.h b/examples/cifar10/cifar10.h
index 261c048..7f10153 100644
--- a/examples/cifar10/cifar10.h
+++ b/examples/cifar10/cifar10.h
@@ -40,11 +40,12 @@ class Cifar10 {
   const std::pair<Tensor, Tensor> ReadFile(string file, bool shuffle = false);
 
   void ReadImage(std::ifstream* file, int* label, char* buffer);
+
  private:
   const size_t kImageSize = 32;
   const size_t kImageVol = 3072;
   const size_t kBatchSize = 10000;
-  const size_t kTrainFiles = 1;
+  const size_t kTrainFiles = 5;
 
   string dir_path_;
   bool normalize_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/examples/cifar10/make.sh
----------------------------------------------------------------------
diff --git a/examples/cifar10/make.sh b/examples/cifar10/make.sh
index 17e4b39..5a41612 100755
--- a/examples/cifar10/make.sh
+++ b/examples/cifar10/make.sh
@@ -1 +1 @@
-g++ -g --std=c++11 alexnet.cc -o alexnet -I../../include -I../../build/include -I/home/wangwei/local/cudnn4/include -I/home/wangwei/local/include -I/usr/local/cuda/include/ -I../../lib/cnmem/include -L../../build/lib/ -lsinga_core -lsinga_model  -lsinga_utils -lcudart -lcublas -lcurand -lcudnn -L/usr/local/cuda/lib64 -L/home/wangwei/local/cudnn4/lib64 ../../build/lib/libproto.a -lprotobuf
+g++ -g --std=c++11 alexnet.cc -o alexnet -I../../include -I../../build/include -I/home/wangwei/local/cudnn5/include -I/home/wangwei/local/include -I/usr/local/cuda/include/ -I../../lib/cnmem/include -L../../build/lib/ -lsinga_core -lsinga_model  -lsinga_utils -lcudart -lcublas -lcurand -lcudnn -L/home/wangwei/local/cudnn5/lib64 -L/usr/local/cuda/lib64 ../../build/lib/libproto.a -lprotobuf

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 3b496d9..18aa7ef 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -180,6 +180,7 @@ class Tensor {
   template <typename SType>
   Tensor &operator/=(const SType x);
 
+  float L1() const;
   float L2() const;
 
  protected:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/model/feed_forward_net.h
----------------------------------------------------------------------
diff --git a/include/singa/model/feed_forward_net.h b/include/singa/model/feed_forward_net.h
index 9beeb7a..1ca417c 100644
--- a/include/singa/model/feed_forward_net.h
+++ b/include/singa/model/feed_forward_net.h
@@ -72,7 +72,7 @@ class FeedForwardNet {
   void Train(size_t batchsize, int nb_epoch, const Tensor& x, const Tensor& y,
              const Tensor& val_x, const Tensor& val_y);
   /// Train the neural net over one batch of training data.
-  const std::pair<float, float> TrainOnBatch(const Tensor& x, const Tensor& y);
+  const std::pair<float, float> TrainOnBatch(int epoch, const Tensor& x, const Tensor& y);
 
   /// Evaluate the neural net with given data.
   /// Returns one tensor for loss values and one tensor for metric values;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/model/initializer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/initializer.h b/include/singa/model/initializer.h
index 302fc97..7024f70 100644
--- a/include/singa/model/initializer.h
+++ b/include/singa/model/initializer.h
@@ -21,8 +21,8 @@
 #include <string>
 #include "singa/core/tensor.h"
 #include "singa/proto/model.pb.h"
+#include "singa/utils/string.h"
 namespace singa {
-namespace init {
 /// Base class for initializing parameter values.
 using InitializerConf = FillerConf;
 class Initializer {
@@ -40,6 +40,7 @@ class Initializer {
   virtual void Fill(Tensor* t) = 0;
 };
 
+namespace init {
 class Constant : public Initializer {
 public:
   Constant() = default;
@@ -76,7 +77,7 @@ public:
   void Fill(Tensor* t) override { singa::Gaussian(mean_, std_, t); }
 
  private:
-  float mean_ = 0, std_ = 0.01;
+  float mean_ = 0, std_ = 1;
 };
 
 /// Ref: [Bengio and Glorot 2010] Understanding the difficulty of training deep
@@ -86,6 +87,7 @@ public:
   void Fill(Tensor* t) override {
     CHECK_EQ(t->nDim(), 2u);
     float scale = sqrt(6.0f / (t->shape(0) + t->shape(1)));
+    LOG(INFO) << "xavier scale " << scale;
     singa::Uniform(-scale, scale, t);
   }
 };
@@ -100,6 +102,26 @@ class MSRA : public Initializer {
     singa::Gaussian(0.0f, std, t);
   }
 };
+
 }  // namespace init
+
+std::shared_ptr<Initializer> CreateInitializer(const InitializerConf& conf) {
+  std::shared_ptr<Initializer> init;
+  if (ToLowerCase(conf.type()) == "constant") {
+    init = std::make_shared<init::Constant>();
+  } else if (ToLowerCase(conf.type()) == "uniform") {
+    init = std::make_shared<init::Uniform>();
+  } else if (ToLowerCase(conf.type()) == "gaussian") {
+    init = std::make_shared<init::Gaussian>();
+  } else if (ToLowerCase(conf.type()) == "xavier") {
+    init = std::make_shared<init::Xavier>();
+  } else if (ToLowerCase(conf.type()) == "msra") {
+    init = std::make_shared<init::MSRA>();
+  } else {
+    LOG(FATAL) << "Unknown initialization type: " << conf.type();
+  }
+  init->Setup(conf);
+  return init;
+}
 }  // namespace singa
 #endif  // SINGA_MODEL_INITIALIZER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/model/loss.h
----------------------------------------------------------------------
diff --git a/include/singa/model/loss.h b/include/singa/model/loss.h
index 41ec701..f400768 100644
--- a/include/singa/model/loss.h
+++ b/include/singa/model/loss.h
@@ -43,13 +43,13 @@ class Loss {
 
   /// Compute the loss values for each sample/instance given the prediction
   /// and the target.
-  virtual Tensor Forward(const Tensor& prediction, const T& target) = 0;
+  virtual Tensor Forward(int flag, const Tensor& prediction, const T& target) = 0;
 
   /// Average loss values for all samples in the mini-batch
   /// It calls Forward() internally. The calling pattern should be
   /// [Evaluate|Forward] Backward.
-  float Evaluate(const Tensor& prediction, const T& target) {
-    Tensor loss = Forward(prediction, target);
+  float Evaluate(int flag, const Tensor& prediction, const T& target) {
+    Tensor loss = Forward(flag, prediction, target);
     loss.ToHost();
     return Sum<float>(loss) / (1.0f * loss.Size());
   }
@@ -68,7 +68,7 @@ class MSE : public Loss<Tensor> {
   /// and the target, which is 0.5/||prediction-target||^2
   /// Users can call Average(const Tensor&) to get the average
   /// loss value over all samples in the batch.
-  Tensor Forward(const Tensor& prediction, const Tensor& target) override;
+  Tensor Forward(int flag, const Tensor& prediction, const Tensor& target) override;
 
   /// Compute the gradients of the loss values w.r.t. the prediction,
   /// which is (prediction-target)/batchsize
@@ -90,7 +90,7 @@ class SoftmaxCrossEntropy : public Loss<Tensor> {
   /// from Softmax(prediction).
   /// Users can call Average(const Tensor&) to get the average
   /// loss value over all samples in the batch.
-  Tensor Forward(const Tensor& prediction, const Tensor& target) override;
+  Tensor Forward(int flag, const Tensor& prediction, const Tensor& target) override;
 
   /// Compute the gradients of the loss values w.r.t. the prediction,
   /// which is: p[idx] - 1 if idx is the truth category's index; else,
@@ -106,5 +106,3 @@ class SoftmaxCrossEntropy : public Loss<Tensor> {
 }  // namespace singa
 
 #endif  // SINGA_MODEL_LOSS_H_
-
-

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/model/metric.h
----------------------------------------------------------------------
diff --git a/include/singa/model/metric.h b/include/singa/model/metric.h
index d013fa4..b100435 100644
--- a/include/singa/model/metric.h
+++ b/include/singa/model/metric.h
@@ -48,7 +48,7 @@ class Metric {
 
   /// Comptue the metric value averaged over all samples (in a batch)
   float Evaluate(const Tensor& prediction, const T& target) {
-    const Tensor& metric = Forward(prediction, target);
+    const Tensor metric = Forward(prediction, target);
     return Sum<float>(metric) / (1.0f * metric.Size());
   }
 };

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/model/optimizer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/optimizer.h b/include/singa/model/optimizer.h
index a268126..2ec68fe 100644
--- a/include/singa/model/optimizer.h
+++ b/include/singa/model/optimizer.h
@@ -41,7 +41,7 @@ class Regularizer;
 class Optimizer {
  public:
   Optimizer() = default;
-
+  virtual ~Optimizer();
   /// Setup the optimzier using configurations from serialized string (for
   /// binding languages).
   void Setup(const string& str) {
@@ -51,7 +51,7 @@ class Optimizer {
   }
 
   /// Setup the meta fields of the optimizer
-  virtual void Setup(const OptimizerConf& conf) {}
+  virtual void Setup(const OptimizerConf& conf);
   /// Register the parameter, e.g., create Constraint and Regularizers.
   /// If there is no constraint or regularizer, then no need to register the
   /// parameter.
@@ -76,15 +76,21 @@ class Optimizer {
   void SetLearningRateGenerator(function<float(int)> func) {
     learning_rate_generator_ = func;
   }
-  /// Since Optimizer base layer has pure virtual function, a virtual
-  /// deconstructor is needed.
-  virtual ~Optimizer() = default;
+  float GetLearningRate(int step) {
+    if (learning_rate_generator_)
+      return learning_rate_generator_(step);
+    else
+      return 0;
+  }
 
  protected:
   function<float(int)> learning_rate_generator_;
   std::unordered_map<std::string, float> learning_rate_multplier_;
+  std::unordered_map<std::string, float> weight_decay_multplier_;
   std::unordered_map<std::string, Constraint*> constraints_;
   std::unordered_map<std::string, Regularizer*> regularizers_;
+  Constraint* constraint_ = nullptr;
+  Regularizer* regularizer_ = nullptr;
 };
 
 /// Apply constraints for parameters (gradient).
@@ -141,7 +147,7 @@ class Regularizer {
   /// e.g., clip each gradient if it is too large w.r.t the threshold,
   /// \ref
   /// https://www.reddit.com/r/MachineLearning/comments/31b6x8/gradient_clipping_rnns/
-  void Apply(int step, Tensor* grad, Tensor* value);
+  void Apply(int step, Tensor* grad, Tensor* value, float scale = 1.0f);
   /// Apply the regularizer for multiple parameter objects together.
   /// \ref https://github.com/Lasagne/Lasagne/blob/master/lasagne/updates.py
   void Apply(int step, const vector<Tensor*>& grads,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/utils/string.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/string.h b/include/singa/utils/string.h
index cbfb28b..b4c7c24 100644
--- a/include/singa/utils/string.h
+++ b/include/singa/utils/string.h
@@ -51,6 +51,17 @@ inline int ArgPos(int argc, char** arglist, const char* arg) {
   return -1;
 }
 
+template<typename T>
+inline std::string VecToStr(const std::vector<T> & in) {
+  std::string out = "(";
+
+  for (auto x : in) {
+    out += std::to_string(x) + ", ";
+  }
+  out += ")";
+  return out;
+}
+
 /**
  * Tokenize a string.
  *

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/core/tensor/math_kernel.cu
----------------------------------------------------------------------
diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu
index 4135ab8..13005af 100644
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@ -265,6 +265,19 @@ __global__ void KernelLT(const size_t num, const float *in, const float x,
     out[idx] = in[idx] < x ? 1.0f : 0.0f;
   }
 }
+
+__global__ void KernelRowMax(const size_t nrow, const size_t ncol, const float *inPtr,
+    float *outPtr) {
+  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < nrow;
+       idx += blockDim.x * gridDim.x) {
+    int offset = idx * ncol;
+    float maxval = inPtr[offset];
+    for (size_t k = 1; k < ncol; k++) {
+      maxval = max(maxval, inPtr[offset + k]);
+    }
+    outPtr[idx] = maxval;
+  }
+}
 __global__ void KernelComputeCrossEntropy(const size_t batchsize,
                                           const size_t dim, const float *p,
                                           const int *t, float *loss) {
@@ -286,6 +299,9 @@ __global__ void KernelSoftmaxCrossEntropyBwd(const size_t batchsize,
     grad[pos] = p[pos] - 1.0f;  // TODO(wangwei) Consider p and grad are diff
   }
 }
+
+
+
 // ********************************
 // Functions call kernels
 // ********************************
@@ -421,6 +437,12 @@ void SoftmaxCrossEntropyBwd(size_t batchsize, const size_t dim, const float *p,
   KernelSoftmaxCrossEntropyBwd <<<ceil(batchsize / CU1DBLOCKF), CU1DBLOCKF>>>
       (batchsize, dim, p, t, grad);
 }
+
+void RowMax(const size_t nrow, const size_t ncol, const float *inPtr,
+    float *outPtr, cudaStream_t stream) {
+  KernelRowMax <<<ceil(nrow / CU1DBLOCKF), CU1DBLOCKF>>>(nrow, ncol, inPtr, outPtr);
+}
+
 /*
 void square_grad(int n, const float *in, float *out, cudaStream_t s) {
   kernel_square_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/core/tensor/math_kernel.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h
index d4087e5..63b0d82 100644
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@ -98,6 +98,8 @@ void SoftmaxCrossEntropyBwd(const size_t batchsize, const size_t dim,
                             const float *p, const int *t, float *grad,
                             cudaStream_t stream);
 
+void RowMax(const size_t nrow, const size_t ncol, const float *inPtr,
+    float *outPtr, cudaStream_t stream);
 }  // cuda
 
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 898cdc6..b07a23c 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -42,7 +42,8 @@ Tensor::Tensor(Shape &&shape, DataType dtype)
   device_ = defaultDevice;
   block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_));
 }
-Tensor::Tensor(const Shape &shape, std::shared_ptr<Device> device, DataType dtype)
+Tensor::Tensor(const Shape &shape, std::shared_ptr<Device> device,
+               DataType dtype)
     : data_type_(dtype), device_(device), shape_(shape) {
   block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_));
 }
@@ -68,11 +69,10 @@ Tensor::Tensor(Tensor &&in)
   in.block_ = nullptr;
 }
 
-void Tensor::SetBlock(Block* block) {
+void Tensor::SetBlock(Block *block) {
   LOG(WARNING) << "Pls avoid using this function, which may have side-effect.";
   if (block_ != nullptr)
-    if (block_->DecRefCount())
-      device_->FreeBlock(block_);
+    if (block_->DecRefCount()) device_->FreeBlock(block_);
   block_ = block;
 }
 
@@ -118,8 +118,7 @@ void Tensor::ToDevice(std::shared_ptr<Device> dst) {
   // TODO(wangwei) the comparison is very strict. May compare against device ID?
   if (device_ != dst) {
     Tensor tmp(shape_, dst, data_type_);
-    if (block_ != nullptr && Size())
-      tmp.CopyData(*this);
+    if (block_ != nullptr && Size()) tmp.CopyData(*this);
     if (block_ != nullptr && block_->DecRefCount() == 0)
       device_->FreeBlock(block_);
     block_ = tmp.block_;
@@ -132,13 +131,13 @@ void Tensor::ToHost() { ToDevice(device_->host()); }
 
 template <typename DType>
 void Tensor::CopyDataFromHostPtr(const DType *src, const size_t num,
-    const size_t offset) {
+                                 const size_t offset) {
   CHECK_EQ(sizeof(DType), SizeOf(data_type_))
       << "data_type is " << DataType_Name(data_type_)
       << " user given type is of size " << sizeof(DType);
   if (src != nullptr) {
     device_->CopyDataFromHostPtr(block(), src, sizeof(DType) * num,
-        sizeof(DType) * offset);
+                                 sizeof(DType) * offset);
   } else {
     LOG(WARNING) << "Copy data from null host ptr";
   }
@@ -161,8 +160,7 @@ void Tensor::CopyData(const Tensor &src) {
 }
 
 Tensor Tensor::Clone(std::shared_ptr<Device> device) const {
-  if (device == nullptr)
-    device = device_;
+  if (device == nullptr) device = device_;
   Tensor t(shape_, device_, data_type_);
   t.transpose_ = transpose_;
   t.CopyData(*this);
@@ -244,8 +242,6 @@ GenUnaryScalarArgMemberFn(operator+=, Add);
 GenUnaryScalarArgMemberFn(operator*=, EltwiseMult);
 GenUnaryScalarArgMemberFn(operator/=, Div);
 
-
-
 // ====================Tensor Operations=======================================
 void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
                     const size_t dst_offset, const size_t src_offset) {
@@ -336,6 +332,18 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
   } while (0)
 
 // =============Element-wise operations====================================
+float Tensor::L1() const {
+  float nrm = 0.0f;
+  TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
+    device_->Exec([&nrm, this](Context *ctx) {
+      DType ret;
+      Asum<DType, Lang>(this->Size(), this->block(), &ret, ctx);
+      nrm = TypeCast<DType, float>(ret);
+    }, {this->block()}, {});
+  });
+  return nrm / Size();
+}
+
 /// L2 norm, Do not use Nrm2 (name conflict).
 float Tensor::L2() const {
   float nrm = 0.0f;
@@ -346,8 +354,10 @@ float Tensor::L2() const {
       nrm = TypeCast<DType, float>(ret);
     }, {this->block()}, {});
   });
-  return nrm;
+  return nrm / Size();
 }
+
+
 template <typename SType>
 void Tensor::SetValue(const SType x) {
   CHECK_EQ(sizeof(SType), SizeOf(data_type_));
@@ -525,18 +535,35 @@ Tensor SoftMax(const Tensor &in) {
   return out;
 }
 
+Tensor RowMax(const Tensor &in) {
+  Tensor ret({in.shape(0)}, in.device(), in.data_type());
+  TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
+    in.device()->Exec([in, ret](Context *ctx) {
+      size_t nrow = 1;
+      if (in.nDim() > 1) nrow = in.shape(0);
+      size_t ncol = in.Size() / nrow;
+      RowMax<DType, Lang>(nrow, ncol, in.block(), ret.block(), ctx);
+    }, {in.block()}, {ret.block()});
+  });
+  return ret;
+}
+
 void SoftMax(const Tensor &in, Tensor *out) {
   CHECK_LE(in.nDim(), 2u);
-  Exp(in, out);
+  out->CopyData(in);
   size_t nrow = 1, ncol = in.Size(), size = ncol;
   if (in.nDim() == 2u) {
     nrow = in.shape(0);
     ncol = size / nrow;
     out->Reshape(Shape{nrow, ncol});
   }
-  Tensor sum(Shape{nrow}, in.device(), in.data_type());
-  SumColumns(*out, &sum);
-  DivColumn(sum, out);
+  Tensor tmp = RowMax(*out);
+  SubColumn(tmp, out);
+  Exp(*out, out);
+
+  SumColumns(*out, &tmp);
+  DivColumn(tmp, out);
+  out->Reshape(in.shape());
 }
 
 void AddColumn(const Tensor &v, Tensor *M) { AddColumn(1, 1, v, M); }
@@ -582,8 +609,8 @@ void AddRow(const SType alpha, const SType beta, const Tensor &v, Tensor *M) {
     Mult(alpha, one, vmat, beta, M);
   }
 }
-template
-void AddRow(const float alpha, const float beta, const Tensor &v, Tensor *M);
+template void AddRow(const float alpha, const float beta, const Tensor &v,
+                     Tensor *M);
 
 /// Divide column 'v' by each column of matrix M; write results into 'out'
 void DivColumn(const Tensor &v, Tensor *M) {
@@ -699,7 +726,7 @@ void MultRow(const Tensor &v, Tensor *M) {
   });
 }
 
-Tensor SliceRows(const Tensor& in, const size_t start, const size_t end) {
+Tensor SliceRows(const Tensor &in, const size_t start, const size_t end) {
   LOG(FATAL) << "Tensor::SliceRows is not implemented";
   Tensor ret;
   /*
@@ -788,6 +815,7 @@ void Gaussian(const SType mean, const SType std, Tensor *out) {
 template void Gaussian<float>(const float mean, const float std, Tensor *out);
 
 // ================Blas operations============================================
+
 template <typename SType>
 void Axpy(const SType alpha, const Tensor &in, Tensor *out) {
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
@@ -869,5 +897,4 @@ void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p) {
   });
 }
 
-
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/core/tensor/tensor_math.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index 57ccb88..7732dd2 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -339,6 +339,11 @@ void SoftmaxCrossEntropyBwd(const size_t batchsize, const size_t dim,
   LOG(FATAL) << "Not Implemented";
 }
 
+template <typename DType, typename Lang>
+void RowMax(const size_t nrow, const size_t ncol, const Block *in,
+    const Block *ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
 // **************************************
 // Matrix functions
 // **************************************

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/core/tensor/tensor_math_cpp.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index 4717b5f..3e0c8ad 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -549,6 +549,20 @@ void SoftmaxCrossEntropyBwd<float, lang::Cpp>(const size_t batchsize,
   }
 }
 
+template <>  // CPU float backend: out[r] = max of the ncol values in row r
+void RowMax<float, lang::Cpp>(const size_t nrow, const size_t ncol,
+                              const Block *in, const Block *out, Context *ctx) {
+  const float *inPtr = static_cast<const float *>(in->data());
+  float *outPtr = static_cast<float *>(out->mutable_data());
+  for (size_t r = 0; r < nrow; r++) {
+    const size_t offset = r * ncol;  // row start (row-major); size_t avoids int wrap
+    float maxval = inPtr[offset];  // NOTE(review): assumes ncol >= 1
+    for (size_t c = 1; c < ncol; c++)
+      maxval = std::max(maxval, inPtr[offset + c]);
+    outPtr[r] = maxval;
+  }
+}
+
 // =========Matrix operations ================================================
 /*
 template <>

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index 67ee861..43bfa1b 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -421,6 +421,15 @@ void SoftmaxCrossEntropyBwd<float, lang::Cuda>(const size_t batchsize,
   cuda::SoftmaxCrossEntropyBwd(batchsize, dim, pPtr, tPtr, gradPtr,
                                ctx->stream);
 }
+
+template <>  // CUDA float backend for RowMax
+void RowMax<float, lang::Cuda>(const size_t nrow, const size_t ncol,
+                               const Block* in, const Block* out,
+                               Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in->data());
+  float* outPtr = static_cast<float*>(out->mutable_data());
+  cuda::RowMax(nrow, ncol, inPtr, outPtr, ctx->stream);  // the work is delegated to cuda::RowMax on ctx->stream
+}
 }  // namespace singa
 
 #endif  // USE_CUDA

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/feed_forward_net.cc
----------------------------------------------------------------------
diff --git a/src/model/feed_forward_net.cc b/src/model/feed_forward_net.cc
index a24d36a..e682918 100644
--- a/src/model/feed_forward_net.cc
+++ b/src/model/feed_forward_net.cc
@@ -1,22 +1,26 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
 
 #include "singa/model/feed_forward_net.h"
+#include "singa/model/initializer.h"
 #include "singa/utils/logging.h"
 #include "singa/utils/channel.h"
 namespace singa {
@@ -37,12 +41,15 @@ Layer* FeedForwardNet::Add(const LayerConf& conf, const Shape* sample_shape) {
   return layer;
 }
 
-Layer* FeedForwardNet::Add(Layer* layer, const LayerConf& conf, const Shape* sample_shape) {
+Layer* FeedForwardNet::Add(Layer* layer, const LayerConf& conf,  // sets up `layer`, appends it, returns it
+                           const Shape* sample_shape) {  // nullptr: use the last added layer's output shape (NOTE(review): layers_ must then be non-empty)
+  CHECK(conf.has_name()) << "Must set layer name";  // fail fast; the layer name is logged below
   if (sample_shape == nullptr)
     layer->Setup(layers_.back()->GetOutputSampleShape(), conf);
   else
     layer->Setup(*sample_shape, conf);
   Add(layer);
+  LOG(INFO) << layer->name() << VecToStr(layer->GetOutputSampleShape());  // log the layer name and its output sample shape
   return layer;
 }
 
@@ -75,12 +82,19 @@ void FeedForwardNet::Compile(bool shuffle, Optimizer* opt, Loss<Tensor>* loss,
   opt_ = opt;
   loss_ = loss;
   metric_ = metric;
-  // init params and register them to sgd
+  const auto specs = GetParamSpecs();
+  const auto params = GetParamValues();
+  CHECK_EQ(specs.size(), params.size());
+  for (size_t k = 0; k < specs.size(); k++) {
+    opt_->Register(specs[k].name(), specs[k]);
+    auto init = CreateInitializer(specs[k].filler());
+    init->Fill(params[k]);
+    LOG(INFO) << specs[k].name() << " : " << params[k]->L1();
+  }
 }
 
 void FeedForwardNet::ToDevice(std::shared_ptr<Device> device) {
-  for (auto layer: layers_)
-    layer->ToDevice(device);
+  for (auto layer : layers_) layer->ToDevice(device);
   /*
   opt_->ToDevice(device);
   loss_->ToDevice(device);
@@ -129,7 +143,6 @@ void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x,
 void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x,
                            const Tensor& y, const Tensor& val_x,
                            const Tensor& val_y) {
-  InitNetParams();
   CHECK_EQ(x.shape(0), y.shape(0)) << "Diff num of sampels in x and y";
   int num_extra_samples = x.shape(0) % batchsize;
   if (num_extra_samples != 0)
@@ -137,13 +150,18 @@ void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x,
   Channel* train_ch = GetChannel("train_perf");
   train_ch->EnableDestStderr(true);
   Channel* val_ch = GetChannel("val_perf");
+  val_ch->EnableDestStderr(true);
+  std::vector<size_t> index;
+  for (size_t i = 0; i < x.shape(0) / batchsize; i++) index.push_back(i);
   for (int epoch = 0; epoch < nb_epoch; epoch++) {
+    if (shuffle_) std::random_shuffle(index.begin(), index.end());
     float loss = 0.0f, metric = 0.0f;
     size_t b = 0;
     for (; b < x.shape(0) / batchsize; b++) {
-      const Tensor bx = CopyRows(x, b * batchsize, b * batchsize + batchsize);
-      const Tensor by = CopyRows(y, b * batchsize, b * batchsize + batchsize);
-      const auto ret = TrainOnBatch(bx, by);
+      size_t idx = index[b];
+      const Tensor bx = CopyRows(x, idx * batchsize, (idx + 1) * batchsize);
+      const Tensor by = CopyRows(y, idx * batchsize, (idx + 1) * batchsize);
+      const auto ret = TrainOnBatch(epoch, bx, by);
       loss += ret.first;
       metric += ret.second;
     }
@@ -151,7 +169,8 @@ void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x,
     metric /= b;
     train_ch->Send("Epoch " + std::to_string(epoch) + ", training loss = " +
                    std::to_string(loss) + ", accuracy = " +
-                   std::to_string(metric));
+                   std::to_string(metric) + ", lr = " +
+                   std::to_string(opt_->GetLearningRate(epoch)));
     if (val_x.Size() && val_y.Size()) {
       const auto val_perf = Evaluate(val_x, val_y, batchsize);
       val_ch->Send("Epoch " + std::to_string(epoch) + ", val loss = " +
@@ -162,22 +181,28 @@ void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x,
   }
 }
 
-const std::pair<float, float> FeedForwardNet::TrainOnBatch(const Tensor& x,
+const std::pair<float, float> FeedForwardNet::TrainOnBatch(int epoch,  // epoch is passed to the optimizer as its step
+                                                           const Tensor& x,
                                                            const Tensor& y) {
   int flag = kTrain;
   const Tensor fea = Forward(flag, x);
-  float loss = loss_->Evaluate(fea, y);
+  float loss = loss_->Evaluate(flag, fea, y);  // flag is kTrain here, so the loss keeps what Backward() needs
   float metric = metric_->Evaluate(fea, y);
   const Tensor grad = loss_->Backward();
-  const auto grads = Backward(kTrain, grad);
+  auto grads = Backward(kTrain, grad / static_cast<float>(x.shape(0)));  // loss gradient is divided by x.shape(0) before back-propagation
+  auto names = GetParamNames();  // NOTE(review): names/values are assumed to line up index-for-index with grads
+  auto values = GetParamValues();
+  for (size_t k = 0; k < grads.size(); k++) {  // one optimizer update per parameter gradient
+    opt_->Apply(epoch, names[k], &grads[k], values.at(k));
+  }
   return std::make_pair(loss, metric);
 }
 
 const Tensor FeedForwardNet::Forward(int flag, const Tensor& data) {
   Tensor input = data, output;
   for (auto layer : layers_) {
-//    LOG(INFO) << layer->name();
     output = layer->Forward(flag, input);
+    // LOG(INFO) << layer->name() << ": " << output.L2();
     input = output;
   }
   return output;
@@ -185,13 +210,22 @@ const Tensor FeedForwardNet::Forward(int flag, const Tensor& data) {
 
 const vector<Tensor> FeedForwardNet::Backward(int flag, const Tensor& grad) {
   vector<Tensor> param_grads;
+  std::stack<Tensor> buf;  // LIFO buffer: undoes the last-to-first layer traversal below
   Tensor tmp = grad;
   for (int i = layers_.size() - 1; i >= 0; i--) {
- //   LOG(INFO) << layers_.at(i)->name();
+    // LOG(INFO) << layers_.at(i)->name() << " : " << tmp.L2();
     auto ret = layers_.at(i)->Backward(flag, tmp);
     tmp = ret.first;
-    if (ret.second.size())
-      for (const auto x : ret.second) param_grads.push_back(x);
+    if (ret.second.size()) {
+      for (int k = ret.second.size() - 1; k >= 0; k--) {  // push this layer's grads last-to-first
+        buf.push(ret.second[k]);
+        // LOG(INFO) <<  "      " << buf.top().L2();
+      }
+    }
+  }
+  while (!buf.empty()) {  // pops give first layer first, each layer's grads in their original order
+    param_grads.push_back(buf.top());
+    buf.pop();
   }
   return param_grads;
 }
@@ -230,8 +264,8 @@ std::pair<Tensor, Tensor> FeedForwardNet::EvaluateOnBatch(const Tensor& x,
   int flag = kEval;
   const Tensor fea = Forward(flag, x);
   const Tensor m = metric_->Forward(fea, y);
-  const Tensor l = loss_->Forward(fea, y);
-  return std::make_pair(m, l);
+  const Tensor l = loss_->Forward(flag, fea, y);
+  return std::make_pair(l, m);
 }
 
 const Tensor FeedForwardNet::Predict(const Tensor& x, size_t batchsize) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index eb507b2..3dca28a 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -72,8 +72,8 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
       num_filters_, conv_height_, conv_width_));
   if (bias_term_)
     CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW,
-                                           GetCudnnDataType(dtype), 1, 1,
-                                           num_filters_, 1));
+                                           GetCudnnDataType(dtype), 1, 1, 1,
+                                           num_filters_));
   CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, pad_h_, pad_w_,
                                               stride_h_, stride_w_, 1, 1,
                                               CUDNN_CROSS_CORRELATION));
@@ -244,6 +244,7 @@ const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward(
   }, {grad.block(), weight_.block()}, {dx.block(), workspace_.block()});
   param_grad.push_back(dw);
   param_grad.push_back(db);
+  LOG(INFO) << "bias nrm " << db.L1();
   return std::make_pair(dx, param_grad);
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/layer/cudnn_dropout.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_dropout.cc b/src/model/layer/cudnn_dropout.cc
index f9b9dbf..ab83226 100644
--- a/src/model/layer/cudnn_dropout.cc
+++ b/src/model/layer/cudnn_dropout.cc
@@ -108,7 +108,7 @@ const std::pair<Tensor, vector<Tensor>> CudnnDropout::Backward(
 }
 void CudnnDropout::ToDevice(std::shared_ptr<Device> device) {
   Dropout::ToDevice(device);
-  state.ToDevice(device);
+  state_.ToDevice(device);
 }
 }  // namespace singa
 #endif  // CUDNN_VERSION_MAJOR>=5

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/layer/dense.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/dense.cc b/src/model/layer/dense.cc
index c6a9f8a..338409c 100644
--- a/src/model/layer/dense.cc
+++ b/src/model/layer/dense.cc
@@ -41,13 +41,15 @@ void Dense::Setup(const Shape& in_sample, const LayerConf &conf) {
   bias_.Reshape(Shape{hdim_});
   param_values_.push_back(&weight_);
   param_values_.push_back(&bias_);
+  for (auto specs: conf.param())
+    param_specs_.push_back(specs);
 }
 
 /// \copydoc Layer::Forward(int flag, const Tensor&)
 const Tensor Dense::Forward(int flag, const Tensor &input) {
   CHECK(buf_.empty());
   Tensor output;
-  CHECK_EQ(input.nDim(), 2);
+  CHECK_EQ(input.nDim(), 2u);
   if (transpose_)  // use the transposed version of weight_ for computing
     output = Mult(input, weight_);
   else

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/loss/mse.cc
----------------------------------------------------------------------
diff --git a/src/model/loss/mse.cc b/src/model/loss/mse.cc
index a4bbb72..6e19059 100644
--- a/src/model/loss/mse.cc
+++ b/src/model/loss/mse.cc
@@ -20,7 +20,7 @@
 
 namespace singa {
 
-Tensor MSE::Forward(const Tensor& prediction, const Tensor& target) {
+Tensor MSE::Forward(int flag, const Tensor& prediction, const Tensor& target) {
   CHECK(buf_.empty()) << "Do not call Forward successively for more than twice."
                       << " The calling pattern is [Forward|Evaluate] Backward";
   Tensor t = prediction - target;
@@ -28,7 +28,8 @@ Tensor MSE::Forward(const Tensor& prediction, const Tensor& target) {
   if (t.nDim() > 1) batchsize = t.shape().at(0);
   size_t dim = t.Size() / batchsize;
   t.Reshape(Shape{batchsize, dim});
-  buf_.push(t);
+  if (kTrain & flag)
+    buf_.push(t);
   // TODO(wangwei) use CastType for operator/
   return Sum(Square(t), 1) * 0.5f;
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/loss/softmax_cross_entropy.cc
----------------------------------------------------------------------
diff --git a/src/model/loss/softmax_cross_entropy.cc b/src/model/loss/softmax_cross_entropy.cc
index bed3348..3411fbe 100644
--- a/src/model/loss/softmax_cross_entropy.cc
+++ b/src/model/loss/softmax_cross_entropy.cc
@@ -21,7 +21,7 @@
 
 namespace singa {
 
-Tensor SoftmaxCrossEntropy::Forward(const Tensor& prediction,
+Tensor SoftmaxCrossEntropy::Forward(int flag, const Tensor& prediction,
                                     const Tensor& target) {
   CHECK(buf_.empty()) << "Do not call Forward successively for more than twice."
                       << " The calling pattern is [Forward|Evaluate] Backward";
@@ -30,13 +30,17 @@ Tensor SoftmaxCrossEntropy::Forward(const Tensor& prediction,
   size_t dim = prediction.Size() / batchsize;
   const Tensor& input = Reshape(prediction, Shape{batchsize, dim});
   Tensor prob = SoftMax(input);
+  // LOG(INFO) << "prob: " << prob.L2();
 
   // buffer intermediate data
-  buf_.push(prob);
-  buf_.push(target);
+  if (flag & kTrain) {
+    buf_.push(prob);
+    buf_.push(target);
+  }
   Tensor loss(Shape{batchsize}, prob.device(), prob.data_type());
 
   ComputeCrossEntropy(prob, target, &loss);
+
   return loss;
 }
 
@@ -50,4 +54,3 @@ Tensor SoftmaxCrossEntropy::Backward() {
 }
 }  // namespace singa
 
-

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/metric/accuracy.cc
----------------------------------------------------------------------
diff --git a/src/model/metric/accuracy.cc b/src/model/metric/accuracy.cc
index 1b667b1..ffda938 100644
--- a/src/model/metric/accuracy.cc
+++ b/src/model/metric/accuracy.cc
@@ -30,6 +30,7 @@ Tensor Accuracy::Match(const Tensor& predict, const vector<int>& target) {
   // TODO(wangwei) CloneToDevice(host);
   const float* prob = prediction.data<float>();
   float* score = new float[batchsize];
+  memset(score, 0, batchsize * sizeof(float));
   for (size_t b = 0; b < batchsize; b++) {
     vector<std::pair<float, int>> prob_class;
     for (size_t c = 0; c < nb_classes; c++) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/optimizer/optimizer.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/optimizer.cc b/src/model/optimizer/optimizer.cc
index c9e7a72..9be47c8 100644
--- a/src/model/optimizer/optimizer.cc
+++ b/src/model/optimizer/optimizer.cc
@@ -21,6 +21,17 @@
 
 namespace singa {
 
+Optimizer::~Optimizer() {  // frees the per-parameter and optimizer-wide regularizers/constraints
+  for (const auto& entry : regularizers_) delete entry.second;  // by reference: no per-entry copy
+  for (const auto& entry : constraints_) delete entry.second;
+  delete constraint_;  // deleting a null pointer is a no-op, so no guard is needed
+  delete regularizer_;
+}
+void Optimizer::Setup(const OptimizerConf& conf) {  // builds the optimizer-wide regularizer/constraint from conf
+  if (conf.has_regularizer())
+    regularizer_ = new Regularizer(conf.regularizer());  // released in ~Optimizer
+  if (conf.has_constraint()) constraint_ = new Constraint(conf.constraint());  // NOTE(review): a second Setup call would overwrite the earlier pointers without freeing them - confirm
+}
 void Optimizer::Register(const string& name, const ParamSpec& specs) {
   if (specs.has_constraint()) {
     CHECK(constraints_.find(name) == constraints_.end())
@@ -32,6 +43,11 @@ void Optimizer::Register(const string& name, const ParamSpec& specs) {
         << "Parameter with name = " << name << " has already registered";
     regularizers_[name] = new Regularizer(specs.regularizer());
   }
+  if (specs.has_decay_mult()) {
+    CHECK(weight_decay_multplier_.find(name) == weight_decay_multplier_.end())
+        << "Parameter with name = " << name << " has already registered";
+    weight_decay_multplier_[name] = specs.decay_mult();
+  }
   if (specs.has_lr_mult()) {
     CHECK(learning_rate_multplier_.find(name) == learning_rate_multplier_.end())
         << "Parameter with name = " << name << " has already registered";
@@ -47,10 +63,18 @@ void Optimizer::Register(const string& name, const ParamSpec& specs) {
 void Optimizer::Apply(int step, const string& name, Tensor* grad,
                       Tensor* param) {
   // TODO(wangwei) need to consider the order of constraint and regularizer
-  if (regularizers_.find(name) != regularizers_.end())
+  if (regularizers_.find(name) != regularizers_.end()) {
     regularizers_.at(name)->Apply(step, param, grad);
+  } else if (regularizer_ != nullptr) {
+    float scale = 1.0f;
+    if (weight_decay_multplier_.find(name) != weight_decay_multplier_.end())
+      scale = weight_decay_multplier_.at(name);
+    regularizer_->Apply(step, param, grad, scale);
+  }
   if (constraints_.find(name) != constraints_.end())
     constraints_.at(name)->Apply(step, param, grad);
+  else if (constraint_ != nullptr)
+    constraint_->Apply(step, param, grad);
   float lr = learning_rate_generator_(step);
   if (learning_rate_multplier_.find(name) != learning_rate_multplier_.end())
     lr *= learning_rate_multplier_.at(name);
@@ -62,9 +86,9 @@ void Regularizer::Setup(const RegularizerConf& conf) {
   coefficient_ = conf.coefficient();
 }
 
-void Regularizer::Apply(int step, Tensor* value, Tensor* grad) {
+void Regularizer::Apply(int step, Tensor* value, Tensor* grad, float scale) {
   if (type_ == "L2" || type_ == "l2") {
-    (*grad) -= (*value) * coefficient_;
+    Axpy(coefficient_ * scale, *value, grad);
   } else {
     CHECK(type_ == "NotSet") << "Unknown regularizer type = " << type_;
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/optimizer/sgd.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/sgd.cc b/src/model/optimizer/sgd.cc
index a5c66a1..71071ff 100644
--- a/src/model/optimizer/sgd.cc
+++ b/src/model/optimizer/sgd.cc
@@ -22,6 +22,7 @@
 namespace singa {
 
 void SGD::Setup(const OptimizerConf& conf) {
+  Optimizer::Setup(conf);
   if (conf.has_momentum()) {
     float m = conf.momentum();
     SetMomentumGenerator([m](int step) { return m; });

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index c06deec..b1318d9 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -89,6 +89,11 @@ message OptimizerConf {
 
   // delta is used to avoid dividing zero
   optional float delta = 6 [default = 1e-8];
+
+  // global regularizer; takes lower priority than a ParamSpec regularizer
+  optional RegularizerConf regularizer = 10;
+  // global constraint; takes lower priority than a ParamSpec constraint
+  optional ConstraintConf constraint = 11;
 }
 
 message ConstraintConf {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/test/singa/test_cross_entropy.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cross_entropy.cc b/test/singa/test_cross_entropy.cc
index d73591f..c7fa2fb 100644
--- a/test/singa/test_cross_entropy.cc
+++ b/test/singa/test_cross_entropy.cc
@@ -44,7 +44,7 @@ TEST_F(TestSoftmaxCrossEntropy, CppForward) {
   t.CopyDataFromHostPtr(tdat, 2);
 
   singa::SoftmaxCrossEntropy cross_entropy;
-  const Tensor& loss = cross_entropy.Forward(p, t);
+  const Tensor& loss = cross_entropy.Forward(singa::kEval, p, t);
   auto ldat = loss.data<float>();
 
   const float result_test = -log(0.25);
@@ -58,7 +58,7 @@ TEST_F(TestSoftmaxCrossEntropy, CppBackward) {
   t.CopyDataFromHostPtr(tdat, 2);
 
   singa::SoftmaxCrossEntropy cross_entropy;
-  cross_entropy.Forward(p, t);
+  cross_entropy.Forward(singa::kTrain, p, t);
   const Tensor& grad = cross_entropy.Backward();
 
   auto gdat = grad.data<float>();
@@ -82,7 +82,7 @@ TEST_F(TestSoftmaxCrossEntropy, CudaForward) {
   p.CopyDataFromHostPtr(pdat, 8);
   t.CopyDataFromHostPtr(tdat, 2);
 
-  Tensor loss = cross_entropy.Forward(p, t);
+  Tensor loss = cross_entropy.Forward(singa::kEval, p, t);
   loss.ToHost();
   auto ldat = loss.data<float>();
 
@@ -99,7 +99,7 @@ TEST_F(TestSoftmaxCrossEntropy, CudaBackward) {
   p.CopyDataFromHostPtr(pdat, 8);
   t.CopyDataFromHostPtr(tdat, 2);
 
-  cross_entropy.Forward(p, t);
+  cross_entropy.Forward(singa::kTrain, p, t);
   Tensor grad = cross_entropy.Backward();
 
   grad.ToHost();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/test/singa/test_dense.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_dense.cc b/test/singa/test_dense.cc
index 363fb6e..e80384f 100644
--- a/test/singa/test_dense.cc
+++ b/test/singa/test_dense.cc
@@ -207,7 +207,7 @@ TEST(Dense, BackwardCuda) {
   singa::Tensor grad(singa::Shape{batchsize, hdim}, cuda);
   grad.CopyDataFromHostPtr(dy, batchsize * hdim);
 
-  const auto ret = dense.Backward(singa::kTrain, grad);
+  auto ret = dense.Backward(singa::kTrain, grad);
   singa::Tensor in_grad = ret.first;
   singa::Tensor dweight = ret.second.at(0);
   singa::Tensor dbias = ret.second.at(1);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/test/singa/test_mse.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_mse.cc b/test/singa/test_mse.cc
index 788652f..640caf4 100644
--- a/test/singa/test_mse.cc
+++ b/test/singa/test_mse.cc
@@ -42,7 +42,7 @@ class TestMSE : public ::testing::Test {
 #ifdef USE_CBLAS
 TEST_F(TestMSE, CppForward) {
   singa::MSE mse;
-  const Tensor& loss = mse.Forward(p, t);
+  const Tensor& loss = mse.Forward(singa::kEval, p, t);
   auto ldat = loss.data<float>();
 
   for (size_t i = 0, k = 0; i < loss.Size(); i++) {
@@ -57,7 +57,7 @@ TEST_F(TestMSE, CppForward) {
 
 TEST_F(TestMSE, CppBackward) {
   singa::MSE mse;
-  mse.Forward(p, t);
+  mse.Forward(singa::kTrain, p, t);
   const Tensor& grad = mse.Backward();
 
   auto gdat = grad.data<float>();
@@ -72,7 +72,7 @@ TEST_F(TestMSE, CudaForward) {
   auto dev = std::make_shared<singa::CudaGPU>();
   p.ToDevice(dev);
   t.ToDevice(dev);
-  Tensor loss = mse->Forward(p, t);
+  Tensor loss = mse->Forward(singa::kEval, p, t);
 
   loss.ToHost();
   auto ldat = loss.data<float>();
@@ -94,7 +94,7 @@ TEST_F(TestMSE, CudaBackward) {
   auto dev = std::make_shared<singa::CudaGPU>();
   p.ToDevice(dev);
   t.ToDevice(dev);
-  mse.Forward(p, t);
+  mse.Forward(singa::kTrain, p, t);
   Tensor grad = mse.Backward();
   grad.ToHost();
   auto gdat = grad.data<float>();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/test/singa/test_tensor_math.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_tensor_math.cc b/test/singa/test_tensor_math.cc
index f8d0351..2a0df0d 100644
--- a/test/singa/test_tensor_math.cc
+++ b/test/singa/test_tensor_math.cc
@@ -346,7 +346,7 @@ TEST_F(TestTensorMath, L2Cpp) {
   float l2 = a.L2();
   float target = 0.0f;
   for (size_t i = 0; i < a.Size(); i++) target += dat1[i] * dat1[i];
-  EXPECT_FLOAT_EQ(l2, sqrt(target));
+  EXPECT_FLOAT_EQ(l2, sqrt(target) / a.Size());
 }
 TEST_F(TestTensorMath, MultCpp) {
   const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
@@ -514,7 +514,7 @@ TEST_F(TestTensorMath, L2Cuda) {
   float l2 = t.L2();
   float target = 0.0f;
   for (size_t i = 0; i < t.Size(); i++) target += dat1[i] * dat1[i];
-  EXPECT_FLOAT_EQ(l2, sqrt(target));
+  EXPECT_FLOAT_EQ(l2, sqrt(target) / t.Size());
 }
 TEST_F(TestTensorMath, MultCuda) {
   const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};