You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2016/06/13 13:20:15 UTC

[22/50] [abbrv] incubator-singa git commit: SINGA0-183 Add the base classes for optimizer, constraint and regularizer

SINGA0-183 Add the base classes for optimizer, constraint and regularizer

Draft the base optimizer, constraint and regularizer classes. The API for local all-reduce is also added (in comments).
Test sgd with/without momentum using cpp and cuda.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/2dac3808
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/2dac3808
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/2dac3808

Branch: refs/heads/master
Commit: 2dac380872402e72b4250981cd99c6c59d66184d
Parents: 7d149ec
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Tue May 24 22:09:24 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Mon May 30 20:48:32 2016 +0800

----------------------------------------------------------------------
 include/singa/model/optimizer.h         | 222 +++++++++++++++++++++++++++
 src/CMakeLists.txt                      |   1 +
 src/model/optimizer/local_all_reduce.cc |  25 +++
 src/model/optimizer/optimizer.cc        |  93 +++++++++++
 src/model/optimizer/sgd.cc              |  49 ++++++
 src/proto/model.proto                   |  35 ++++-
 test/singa/test_sgd.cc                  | 150 ++++++++++++++++++
 7 files changed, 574 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2dac3808/include/singa/model/optimizer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/optimizer.h b/include/singa/model/optimizer.h
new file mode 100644
index 0000000..7ca9f53
--- /dev/null
+++ b/include/singa/model/optimizer.h
@@ -0,0 +1,222 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SINGA_MODEL_OPTIMIZER_H_
+#define SINGA_MODEL_OPTIMIZER_H_
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "singa/core/tensor.h"
+#include "singa/proto/model.pb.h"
+
+using std::string;
+using std::vector;
+using std::unordered_map;
+namespace singa {
+class Constraint;
+class Regularizer;
+/// The base class for gradient descent algorithms used to update the model
+/// parameters in order to optimize the objective (loss) function.
+/// It updates parameters based on the gradients of the loss w.r.t each
+/// parameter. Most sub-classes use first order gradients.
+/// An overview of gradient descent algorithms,
+/// http://sebastianruder.com/optimizing-gradient-descent/
+///
+/// The optimizer owns the Constraint and Regularizer objects created by
+/// Register() and deletes them in its destructor; hence it is non-copyable.
+class Optimizer {
+ public:
+  Optimizer() = default;
+  /// Virtual because sub-classes (e.g., SGD) may be destroyed through a base
+  /// pointer. Defined below, after Constraint and Regularizer, because the
+  /// delete expressions need their complete types.
+  virtual ~Optimizer();
+  Optimizer(const Optimizer&) = delete;
+  Optimizer& operator=(const Optimizer&) = delete;
+
+  /// Setup the optimizer using configurations from serialized string (for
+  /// binding languages).
+  void Setup(const string& str) {
+    OptimizerConf conf;
+    conf.ParseFromString(str);
+    this->Setup(conf);
+  }
+
+  /// Setup the meta fields of the optimizer
+  virtual void Setup(const OptimizerConf& conf) {}
+  /// Register the parameter, e.g., create Constraint and Regularizers.
+  /// If there is no constraint or regularizer, then no need to register the
+  /// parameter.
+  virtual void Register(const string& name, const ParamSpec& specs);
+
+  /// Apply the updating algorithm.
+  /// No learning rate scaling, gradient constraints/regularization will be
+  /// conducted. It assumes all these operations are done either by users or
+  /// by Apply(int, const string&, Tensor*, Tensor*).
+  /// All sub-classes should override this function.
+  virtual void Apply(int step, float lr, const string& name, Tensor* grad,
+                     Tensor* value) = 0;
+
+  /// Apply the updating algorithm.
+  /// It will apply regularization and constraint to the parameters if
+  /// configured during Register(). It will also scale the learning rate if
+  /// configured in ParamSpecs (see Register).
+  void Apply(int step, const string& name, Tensor* grad, Tensor* value);
+
+  /// The argument is a function that returns the learning rate given the
+  /// current step (i.e., current running iteration).
+  void SetLearningRateGenerator(std::function<float(int)> func) {
+    learning_rate_generator_ = func;
+  }
+
+ protected:
+  std::function<float(int)> learning_rate_generator_;
+  std::unordered_map<std::string, float> learning_rate_multplier_;
+  /// Owned: allocated by Register(), deleted by ~Optimizer().
+  std::unordered_map<std::string, Constraint*> constraints_;
+  /// Owned: allocated by Register(), deleted by ~Optimizer().
+  std::unordered_map<std::string, Regularizer*> regularizers_;
+};
+
+/// Apply constraints for parameters (gradient).
+/// E.g., restrict the norm of parameter gradients to be within a threshold.
+/// \ref http://keras.io/constraints/
+/// TODO(wangwei) implement a sub-class for each type of constraint
+class Constraint {
+ public:
+  Constraint() = default;
+  explicit Constraint(const ConstraintConf& conf) { Setup(conf); }
+  Constraint(const string& type, float threshold)
+      : type_(type), threshold_(threshold) {}
+  void Setup(const ConstraintConf& conf);
+  void Setup(const string& conf_str) {
+    ConstraintConf conf;
+    conf.ParseFromString(conf_str);
+    Setup(conf);
+  }
+  /// Apply the constraint to a single parameter object, e.g., W, or b
+  /// e.g., clip each gradient if it is too large w.r.t the threshold,
+  /// \ref
+  /// https://www.reddit.com/r/MachineLearning/comments/31b6x8/gradient_clipping_rnns/
+  /// Note the argument order: the parameter value first, then its gradient.
+  void Apply(int step, Tensor* value, Tensor* grad);
+  /// Apply the constraint for multiple parameter objects together.
+  /// \ref https://github.com/Lasagne/Lasagne/blob/master/lasagne/updates.py
+  void Apply(int step, const vector<Tensor*>& values,
+             const vector<Tensor*>& grads);
+
+ private:
+  /// Intended to support the "L2" norm constraint, i.e., the norm should be
+  /// less than the configured threshold_, otherwise, the parameters would be
+  /// clipped to make the norm within that threshold (not implemented yet).
+  /// "NotSet" means that no constraint has been configured.
+  /// TODO(wangwei) consider other constraint, e.g., hard clip and unitnorm.
+  string type_ = "NotSet";
+  float threshold_ = 0.0f;
+};
+
+/// Apply regularization for parameters (gradient), e.g., L1 norm and L2 norm.
+/// TODO(wangwei) implement a sub-class for each type of regularizer
+class Regularizer {
+ public:
+  Regularizer() = default;
+  explicit Regularizer(const RegularizerConf& conf) { Setup(conf); }
+  Regularizer(const string& type, float coefficient)
+      : type_(type), coefficient_(coefficient) {}
+  void Setup(const RegularizerConf& conf);
+  void Setup(const string& conf_str) {
+    RegularizerConf conf;
+    conf.ParseFromString(conf_str);
+    Setup(conf);
+  }
+
+  /// Apply the regularizer to a single parameter object, e.g., W, or b.
+  /// The gradient is modified in place based on the parameter value, e.g.,
+  /// L2 regularization (weight decay) adjusts grad by coefficient_ * value.
+  /// Note the argument order: the parameter value first, then its gradient.
+  void Apply(int step, Tensor* value, Tensor* grad);
+  /// Apply the regularizer for multiple parameter objects together.
+  /// \ref https://github.com/Lasagne/Lasagne/blob/master/lasagne/updates.py
+  void Apply(int step, const vector<Tensor*>& values,
+             const vector<Tensor*>& grads);
+
+ private:
+  /// currently only support "L2" regularizer. type_ is case insensitive.
+  /// "NotSet" means that no regularizer has been configured.
+  /// TODO(wangwei) add more regularizer, e.g., L1.
+  string type_ = "NotSet";
+  float coefficient_ = 0.0f;
+};
+
+/// Release the Constraint and Regularizer objects allocated by Register().
+inline Optimizer::~Optimizer() {
+  for (const auto& kv : constraints_) delete kv.second;
+  for (const auto& kv : regularizers_) delete kv.second;
+}
+
+// =============Vanilla SGD with Momentum=====================================
+/// Stochastic gradient descent with optional momentum.
+/// Without momentum: value -= lr * grad.
+/// With a non-zero momentum m for the current step (see
+/// SetMomentumGenerator() and OptimizerConf::momentum), a per-parameter
+/// history is kept: history = m * history + lr * grad; value -= history.
+/// Note that Apply() scales the passed-in grad by lr in place.
+class SGD : public Optimizer {
+ public:
+  /// Re-expose the base-class overloads (Setup(const string&) and the
+  /// learning-rate-scheduled Apply(int, const string&, Tensor*, Tensor*)),
+  /// which the declarations below would otherwise hide.
+  using Optimizer::Setup;
+  using Optimizer::Apply;
+
+  void Setup(const OptimizerConf& conf) override;
+  /// Apply the updating algorithm.
+  void Apply(int step, float lr, const string& name, Tensor* grad,
+             Tensor* value) override;
+
+  /// The argument function returns the momentum value given the current running
+  /// step (i.e., iterations/mini-batches).
+  void SetMomentumGenerator(std::function<float(int)> func) {
+    momentum_generator_ = func;
+  }
+
+ private:
+  /// Per-parameter momentum history, keyed by parameter name.
+  std::unordered_map<string, Tensor> history_gradient_;
+  std::function<float(int)> momentum_generator_;
+};
+
+// ============LocalAllReduce for single node multiple workers ==============
+/// Updater for training models on a single node with multiple devices (workers)
+/// All model parameters are partitioned such that each parameter is updated on
+/// one device. In specific, each worker has a model replica. All workers share
+/// the same LocalAllReduce instance. Parameters are registered at first, and
+/// then after every iteration, the gradients are aggregated by one worker (or
+/// device) for parameter updating.
+/*
+class LocalAllReduce : public Optimizer{
+ public:
+  LocalAllReduce(Optimizer* opt);
+  void Setup(const string& str) {
+    AllReduce conf;
+    conf.ParseFromString(str);
+    this->Setup(conf);
+  }
+  void Setup(const AllReduce& conf) {}
+
+  /// Register all model parameters.
+  /// Instructions include:
+  /// 1. Copy parameters from the master worker (who initialized the parameters)
+  /// to others.
+  /// 2. Partition parameters onto worker devices. For example, model parameter
+  /// set is {A, B, C}, nb_workers = 3, then worker 0/1/2 would be in charge of
+  /// updating A/B/C respectively. A gradient Tensor for A/B/C would be created
+  /// on device 0/1/2, denoted as GA/GB/GC. 0/1/2 would call the internal opt to register the specs
+  /// for A/B/C.
+  void Register(const vector<string>& names,
+                const vector<Tensor>& values,
+                const vector<ParamSpecs>& specs) override;
+
+  /// Aggregate parameter gradients and call internal opt to do the update.
+  /// Continue with the example for Register(), worker 0 would copy B's gradient
+  /// to device 1 and add it with GB.  A callback func is added to
+  /// 1. check UpdateNow() and call opt to do the real update.
+  /// 2. broadcast the new parameters back to worker 0 and 2.
+  void Update(int step, float lr, const string& name, const Tensor& grad,
+              Tensor* param) override;
+
+  /// Decide when to call the internal Optimizer for real update.
+  /// One simple implementation would return true until all workers has
+  /// aggregated their gradients. We can also add a user configuration field
+  /// to control this, e.g., if do it when 80% workers has aggregated.
+  bool UpdateNow();
+
+ private:
+  int nb_workers_;
+  vector<Tensor> aggregated_gradients_;
+};
+*/
+}
+#endif  // SINGA_MODEL_OPTIMIZER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2dac3808/src/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index df8b22b..28066de 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -35,6 +35,7 @@ LIST(APPEND SINGA_LINKER_LIBS singa_core)
 #FILE(GLOB_RECURSE model_source ${CMAKE_CURRENT_SOURCE_DIR}/model/ "*.cc")
 AUX_SOURCE_DIRECTORY(model model_source)
 AUX_SOURCE_DIRECTORY(model/layer model_source)
+AUX_SOURCE_DIRECTORY(model/optimizer model_source)
 #MESSAGE(STATUS "MODEL ${model_source}")
 ADD_LIBRARY(singa_model SHARED ${model_source})
 TARGET_LINK_LIBRARIES(singa_model ${SINGA_LINKER_LIBS})

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2dac3808/src/model/optimizer/local_all_reduce.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/local_all_reduce.cc b/src/model/optimizer/local_all_reduce.cc
new file mode 100644
index 0000000..ea03e39
--- /dev/null
+++ b/src/model/optimizer/local_all_reduce.cc
@@ -0,0 +1,25 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// Placeholder translation unit for the LocalAllReduce optimizer. In this
+// commit its API exists only as a commented-out sketch in
+// singa/model/optimizer.h, so the namespace below is intentionally empty.
+// NOTE(review): an include guard is unusual in a .cc file; it is harmless but
+// could be dropped once this file gets real content.
+#ifndef SRC_MODEL_OPTIMIZER_LOCAL_ALL_REDUCE_H_
+#define SRC_MODEL_OPTIMIZER_LOCAL_ALL_REDUCE_H_
+#include "singa/model/optimizer.h"
+
+namespace singa {
+}  // namespace singa
+
+#endif  // SRC_MODEL_OPTIMIZER_LOCAL_ALL_REDUCE_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2dac3808/src/model/optimizer/optimizer.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/optimizer.cc b/src/model/optimizer/optimizer.cc
new file mode 100644
index 0000000..92b6b3d
--- /dev/null
+++ b/src/model/optimizer/optimizer.cc
@@ -0,0 +1,93 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "singa/model/optimizer.h"
+#include "singa/utils/logging.h"
+
+namespace singa {
+
+/// Record the per-parameter settings found in `specs` under key `name`:
+/// a Constraint (allocated with new and stored in constraints_), a
+/// Regularizer (allocated with new and stored in regularizers_) and a
+/// learning-rate multiplier. Each kind of setting may be registered at most
+/// once per name; a duplicate registration is a fatal CHECK failure.
+/// Fields absent from `specs` leave the corresponding map untouched.
+void Optimizer::Register(const string& name, const ParamSpec& specs) {
+  if (specs.has_constraint()) {
+    CHECK(constraints_.find(name) == constraints_.end())
+        << "Parameter with name = " << name << " has already registered";
+    constraints_[name] = new Constraint(specs.constraint());
+  }
+  if (specs.has_regularizer()) {
+    CHECK(regularizers_.find(name) == regularizers_.end())
+        << "Parameter with name = " << name << " has already registered";
+    regularizers_[name] = new Regularizer(specs.regularizer());
+  }
+  if (specs.has_lr_mult()) {
+    CHECK(learning_rate_multplier_.find(name) == learning_rate_multplier_.end())
+        << "Parameter with name = " << name << " has already registered";
+    learning_rate_multplier_[name] = specs.lr_mult();
+  }
+  // Per-parameter learning-rate generators are not supported yet; the check
+  // below is kept disabled until ParamSpec gains such a field.
+  /*
+  if (specs.has_lr_generator()) {
+    LOG(FATAL) << "Not implemented yet";
+  }
+  */
+}
+
+/// Full update pipeline for the parameter registered as `name`:
+/// 1. the registered regularizer (if any) adjusts grad,
+/// 2. the registered constraint (if any) is applied,
+/// 3. the learning rate for `step` is taken from the learning-rate generator
+///    and scaled by the registered multiplier (if any),
+/// 4. the sub-class algorithm Apply(step, lr, name, grad, param) runs.
+/// A learning-rate generator must have been set via
+/// SetLearningRateGenerator(); otherwise this is a fatal CHECK failure
+/// (raised before grad is touched).
+void Optimizer::Apply(int step, const string& name, Tensor* grad,
+                      Tensor* param) {
+  CHECK(learning_rate_generator_ != nullptr)
+      << "Learning rate generator is not set for parameter " << name
+      << "; call SetLearningRateGenerator() first";
+  // TODO(wangwei) need to consider the order of constraint and regularizer
+  // Regularizer::Apply and Constraint::Apply take (value, grad) in that order.
+  auto reg = regularizers_.find(name);
+  if (reg != regularizers_.end())
+    reg->second->Apply(step, param, grad);
+  auto cons = constraints_.find(name);
+  if (cons != constraints_.end())
+    cons->second->Apply(step, param, grad);
+  float lr = learning_rate_generator_(step);
+  auto mult = learning_rate_multplier_.find(name);
+  if (mult != learning_rate_multplier_.end())
+    lr *= mult->second;
+  Apply(step, lr, name, grad, param);
+}
+
+/// Copy the regularizer coefficient and type from the proto configuration.
+void Regularizer::Setup(const RegularizerConf& conf) {
+  coefficient_ = conf.coefficient();
+  type_ = conf.type();
+}
+
+/// Apply the configured regularization to `grad` in place, given the current
+/// parameter `value`.
+/// For "L2" (or "l2") the gradient of the penalty
+/// (coefficient_ / 2) * ||value||^2, i.e. coefficient_ * value, is ADDED to
+/// grad (weight decay); the following descent step then shrinks the weights.
+/// Subtracting it would instead push the weights away from zero.
+/// For "NotSet" nothing happens; any other type is a fatal CHECK failure.
+void Regularizer::Apply(int step, Tensor* value, Tensor* grad) {
+  if (type_ == "L2" || type_ == "l2") {
+    (*grad) += (*value) * coefficient_;
+  } else {
+    CHECK(type_ == "NotSet") << "Unknown regularizer type = " << type_;
+  }
+}
+
+/// Joint regularization over several parameters. Not implemented yet: every
+/// call is a fatal error. Parameter names are commented out because they are
+/// unused.
+void Regularizer::Apply(int /*step*/, const vector<Tensor*>& /*values*/,
+                        const vector<Tensor*>& /*grads*/) {
+  LOG(FATAL) << "Not implemented yet";
+}
+
+/// Copy the constraint threshold and type from the proto configuration.
+void Constraint::Setup(const ConstraintConf& conf) {
+  threshold_ = conf.threshold();
+  type_ = conf.type();
+}
+
+/// Apply the configured constraint to `grad` given the parameter `value`.
+/// No constraint type is implemented yet: "NotSet" is a no-op and any other
+/// type is a fatal CHECK failure, reported as an unknown *constraint* type.
+void Constraint::Apply(int step, Tensor* value, Tensor* grad) {
+  // TODO(wangwei) implement L2 and hard constraint
+  CHECK(type_ == "NotSet") << "Unknown constraint type = " << type_;
+}
+
+/// Joint constraint over several parameters. Not implemented yet: every call
+/// is a fatal error. Parameter names are commented out because they are
+/// unused.
+void Constraint::Apply(int /*step*/, const vector<Tensor*>& /*values*/,
+                       const vector<Tensor*>& /*grads*/) {
+  LOG(FATAL) << "Not implemented yet";
+}
+
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2dac3808/src/model/optimizer/sgd.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/sgd.cc b/src/model/optimizer/sgd.cc
new file mode 100644
index 0000000..49c17c9
--- /dev/null
+++ b/src/model/optimizer/sgd.cc
@@ -0,0 +1,49 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SRC_MODEL_OPTIMIZER_SGD_H_
+#define SRC_MODEL_OPTIMIZER_SGD_H_
+#include "singa/model/optimizer.h"
+#include <functional>
+namespace singa {
+
+/// Configure SGD from `conf`. Only an explicitly set `momentum` field is
+/// used: it installs a generator that returns that constant for every step.
+/// When the field is not set, the current momentum generator is kept.
+void SGD::Setup(const OptimizerConf& conf) {
+  if (!conf.has_momentum()) return;
+  const float momentum = conf.momentum();
+  SetMomentumGenerator([momentum](int) { return momentum; });
+}
+
+/// Update `value` for parameter `name` using learning rate `lr`.
+/// `grad` is scaled by lr in place. If a momentum generator is set and yields
+/// a non-zero momentum m for `step`, a per-parameter history is maintained:
+///   history = m * history + lr * grad;  value -= history.
+/// The history starts at zero, so the first momentum step for a parameter is
+/// value -= lr * grad. Otherwise plain SGD is applied: value -= lr * grad.
+void SGD::Apply(int step, float lr, const string& name, Tensor* grad,
+                Tensor* value) {
+  (*grad) *= lr;
+  if (momentum_generator_) {
+    float mom = momentum_generator_(step);
+    if (mom != 0) {
+      auto it = history_gradient_.find(name);
+      if (it == history_gradient_.end()) {
+        // First momentum update for this parameter: with a zero history,
+        // m * 0 + lr * grad == lr * grad, so seed the history with a copy of
+        // the scaled gradient. This avoids feeding a freshly ResetLike()-ed
+        // buffer into `history *= mom`; that buffer is not shown to be
+        // zero-filled (NOTE(review): confirm Tensor::ResetLike semantics).
+        it = history_gradient_.emplace(name, grad->Clone()).first;
+      } else {
+        Tensor& history = it->second;
+        history *= mom;
+        history += *grad;
+      }
+      (*value) -= it->second;
+      return;
+    }
+  }
+  (*value) -= *grad;
+}
+}  // namespace singa
+#endif  // SRC_MODEL_OPTIMIZER_SGD_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2dac3808/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index 66296d5..1b18703 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -52,7 +52,7 @@ message BlobProto {
 }
 
 message FillerConf {
-  // The filler type.
+  // The filler type, case insensitive
   optional string type = 1 [default = 'constant'];
   optional float value = 2 [default = 0]; // the value in constant filler
   optional float min = 3 [default = 0]; // the min value in uniform filler
@@ -72,6 +72,37 @@ message FillerConf {
   optional VarianceNorm variance_norm = 8 [default = FAN_IN];
 }
 
+/// SINGA message
+message OptimizerConf {
+  // Name of the optimization algorithm; case insensitive.
+  optional string type = 1 [default = "sgd"];
+
+  // used by RMSprop and Adadelta
+  // NOTE(review): 0.001 is unusually small for a decay rate (RMSprop and
+  // Adadelta typically use values around 0.9-0.95) -- confirm intent.
+  optional float rho = 2 [default = 0.001];
+
+  // used by Adam and AdaMax
+  optional float beta_1 = 3 [default = 0.9];
+  optional float beta_2 = 4 [default = 0.999];
+
+  // used by vanilla sgd and nesterov.
+  // SGD::Setup() only reads this field when it is explicitly set
+  // (has_momentum()), so this default alone does not enable momentum.
+  optional float momentum = 5 [default = 0.9];
+}
+
+message ConstraintConf {
+  // Type of the constraint that limits the parameter value/gradient scale;
+  // case insensitive.
+  // NOTE(review): Constraint::Apply() in this commit only accepts "NotSet",
+  // so this "l2" default is rejected when applied -- confirm before use.
+  optional string type = 1 [default = "l2"];
+  // e.g., the threshold for limiting the parameter scale.
+  // No explicit default: an unset value reads as 0 (proto2 rule).
+  optional float threshold = 2;
+}
+
+/// SINGA message
+message RegularizerConf {
+  // Type of the regularizer applied to the parameters, e.g., L2; case
+  // insensitive. Regularizer::Apply() currently handles "L2"/"l2".
+  optional string type = 1 [default = "l2"];
+  // e.g., the weight decay for L2 regularizer.
+  // No explicit default: an unset value reads as 0 (proto2 rule), which makes
+  // the L2 term vanish.
+  optional float coefficient = 2;
+}
+
 // Specifies training parameters (multipliers on global learning constants,
 // and the name and other settings used for weight sharing).
 message ParamSpec {
@@ -101,6 +132,8 @@ message ParamSpec {
   // SINGA uses this filed internally. Users just configure the fillers in
   // Layer specific conf message as caffe (style).
   optional FillerConf filler = 20;
+  optional ConstraintConf constraint = 21;
+  optional RegularizerConf regularizer = 22;
 }
 
 enum Phase {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2dac3808/test/singa/test_sgd.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_sgd.cc b/test/singa/test_sgd.cc
new file mode 100644
index 0000000..a660556
--- /dev/null
+++ b/test/singa/test_sgd.cc
@@ -0,0 +1,150 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "gtest/gtest.h"
+#include "singa/model/optimizer.h"
+#include "singa_config.h"
+
+TEST(SGD, ApplyWithoutMomentum) {
+  // Two plain SGD steps (no momentum generator installed); the second step
+  // uses half the learning rate. After each step value == previous - lr * g.
+  singa::SGD sgd;
+  const float v[4] = {0.1, 0.2, 0.3, 0.4};
+  const float g[4] = {0.1, 0.1, 0.1, 0.1};
+
+  singa::Tensor value(singa::Shape{4}), grad(singa::Shape{4});
+  value.CopyDataFromHostPtr(v, 4);
+
+  float previous[4] = {v[0], v[1], v[2], v[3]};
+  float lr = 0.1f;
+  for (int step = 0; step < 2; step++) {
+    // Apply() scales grad in place, so reload it before every step.
+    grad.CopyDataFromHostPtr(g, 4);
+    sgd.Apply(step, lr, "xx", &grad, &value);
+    singa::Tensor snapshot = value.Clone();
+    const float* updated = snapshot.data<const float*>();
+    for (int i = 0; i < 4; i++) {
+      EXPECT_FLOAT_EQ(updated[i], previous[i] - g[i] * lr);
+      previous[i] = updated[i];
+    }
+    lr /= 2;
+  }
+}
+
+
+TEST(SGD, ApplyWithMomentum) {
+  // Momentum is 0.5 for steps <= 5 and 0.9 afterwards. Step 0 must behave
+  // like plain SGD; step 1 must subtract lr * g + momentum(1) * (lr * g).
+  singa::SGD sgd;
+  float lr = 0.1f;
+  auto func = [](int step) { return step <= 5 ? 0.5f : 0.9f; };
+  sgd.SetMomentumGenerator(func);
+  const float v[4] = {0.1, 0.2, 0.3, 0.4};
+  const float g[4] = {0.01, 0.02, 0.03, 0.04};
+
+  singa::Tensor value(singa::Shape{4}), grad(singa::Shape{4});
+  value.CopyDataFromHostPtr(v, 4);
+
+  grad.CopyDataFromHostPtr(g, 4);
+  sgd.Apply(0, lr, "xx", &grad, &value);
+  singa::Tensor first = value.Clone();
+  const float* after_first = first.data<const float*>();
+  for (int i = 0; i < 4; i++)
+    EXPECT_FLOAT_EQ(after_first[i], v[i] - g[i] * lr);
+
+  // Apply() scaled grad in place; reload it for the second step.
+  grad.CopyDataFromHostPtr(g, 4);
+  sgd.Apply(1, lr, "xx", &grad, &value);
+  singa::Tensor second = value.Clone();
+  const float* after_second = second.data<const float*>();
+  for (int i = 0; i < 4; i++) {
+    const float scaled = g[i] * lr;
+    EXPECT_FLOAT_EQ(after_second[i],
+                    after_first[i] - (scaled + scaled * func(1)));
+  }
+}
+
+// CUDA variants of the tests above. They must be compiled only when SINGA is
+// built with CUDA (USE_CUDA, presumably defined via singa_config.h), because
+// they construct a singa::CudaGPU device.
+#ifdef USE_CUDA
+TEST(SGD, ApplyWithoutMomentumCuda) {
+  singa::SGD sgd;
+  const float v[4] = {0.1, 0.2, 0.3, 0.4};
+  const float g[4] = {0.1, 0.1, 0.1, 0.1};
+
+  singa::CudaGPU dev;
+  singa::Tensor value(singa::Shape{4}, &dev), grad(singa::Shape{4}, &dev);
+  value.CopyDataFromHostPtr(v, 4);
+  grad.CopyDataFromHostPtr(g, 4);
+
+  float lr = 0.1f;
+  sgd.Apply(0, lr, "xx", &grad, &value);
+
+  // Copy the result back to host memory before reading it.
+  singa::Tensor v1 = value.Clone();
+  v1.ToHost();
+  const float* newv1 = v1.data<const float*>();
+  for (int i = 0; i < 4; i++) {
+    EXPECT_FLOAT_EQ(newv1[i], v[i] - g[i] * lr);
+  }
+
+
+  lr /= 2;
+  // Apply() scaled grad in place; reload it for the second step.
+  grad.CopyDataFromHostPtr(g, 4);
+  sgd.Apply(1, lr, "xx", &grad, &value);
+  singa::Tensor v2 = value.Clone();
+  v2.ToHost();
+  const float* newv2 = v2.data<const float*>();
+  for (int i = 0; i < 4; i++) {
+    EXPECT_FLOAT_EQ(newv2[i], newv1[i] - g[i] * lr);
+  }
+}
+
+
+TEST(SGD, ApplyWithMomentumCuda) {
+  singa::SGD sgd;
+  float lr = 0.1f;
+  auto func = [](int step) { return step <= 5 ? 0.5f : 0.9f; };
+  sgd.SetMomentumGenerator(func);
+  const float v[4] = {0.1, 0.2, 0.3, 0.4};
+  const float g[4] = {0.01, 0.02, 0.03, 0.04};
+
+  singa::CudaGPU dev;
+  singa::Tensor value(singa::Shape{4}, &dev), grad(singa::Shape{4}, &dev);
+  value.CopyDataFromHostPtr(v, 4);
+  grad.CopyDataFromHostPtr(g, 4);
+
+  sgd.Apply(0, lr, "xx", &grad, &value);
+
+  singa::Tensor v1 = value.Clone();
+  v1.ToHost();
+  const float* newv1 = v1.data<const float*>();
+  for (int i = 0; i < 4; i++) {
+    EXPECT_FLOAT_EQ(newv1[i], v[i] - g[i] * lr);
+  }
+
+  grad.CopyDataFromHostPtr(g, 4);
+  sgd.Apply(1, lr, "xx", &grad, &value);
+  singa::Tensor v2 = value.Clone();
+  v2.ToHost();
+  const float* newv2 = v2.data<const float*>();
+  for (int i = 0; i < 4; i++) {
+    EXPECT_FLOAT_EQ(newv2[i], newv1[i] - (g[i] * lr + g[i] * lr * func(1)));
+  }
+}
+#endif  // USE_CUDA