Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/03 20:18:09 UTC

[incubator-mxnet] branch master updated: [cpp-package] add lr scheduler (#6885)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 76dee53  [cpp-package] add lr scheduler (#6885)
76dee53 is described below

commit 76dee53dc494b35ce7dd4ac88cdce817bf9aa1ce
Author: CNevd <CN...@users.noreply.github.com>
AuthorDate: Fri Aug 4 04:18:07 2017 +0800

    [cpp-package] add lr scheduler (#6885)
    
    * add lr scheduler
    
    * Update lr_scheduler.h
    
    * Update mlp_gpu.cpp
    
    * Update test_score.cpp
    
    * update optimizer.hpp
---
 cpp-package/example/alexnet.cpp               | 11 +++-
 cpp-package/example/charRNN.cpp               | 11 +++-
 cpp-package/example/googlenet.cpp             | 19 ++++---
 cpp-package/example/inception_bn.cpp          | 12 ++++-
 cpp-package/example/lenet.cpp                 | 17 ++++--
 cpp-package/example/lenet_with_mxdataiter.cpp | 20 ++++---
 cpp-package/example/mlp_cpu.cpp               | 20 ++++---
 cpp-package/example/mlp_gpu.cpp               | 43 ++++++++++-----
 cpp-package/example/resnet.cpp                | 11 +++-
 cpp-package/example/test_score.cpp            | 22 +++++---
 cpp-package/include/mxnet-cpp/executor.h      | 12 -----
 cpp-package/include/mxnet-cpp/executor.hpp    |  7 ---
 cpp-package/include/mxnet-cpp/lr_scheduler.h  | 78 +++++++++++++++++++++++++++
 cpp-package/include/mxnet-cpp/optimizer.h     | 22 ++++----
 cpp-package/include/mxnet-cpp/optimizer.hpp   | 42 +++++++++++----
 15 files changed, 254 insertions(+), 93 deletions(-)
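
For reference, the pattern that the examples in this commit switch to looks roughly like the sketch below. Names such as net, exec, ctx, args_map, batch_size, learning_rate and weight_decay are placeholders, and the input argument names to skip ("data"/"label", "X"/"label", "data"/"data_label") vary per example:

    Optimizer* opt = OptimizerRegistry::Find("sgd");
    opt->SetParam("rescale_grad", 1.0 / batch_size)
       ->SetParam("lr", learning_rate)    // previously passed to UpdateAll()
       ->SetParam("wd", weight_decay);    // previously passed to UpdateAll()

    // Bind once, outside the training loop, and reuse the executor.
    auto *exec = net.SimpleBind(ctx, args_map);
    auto arg_names = net.ListArguments();

    // Inside the loop, after exec->Forward(true) and exec->Backward():
    for (size_t i = 0; i < arg_names.size(); ++i) {
      if (arg_names[i] == "data" || arg_names[i] == "label") continue;  // skip inputs
      opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
    }
    NDArray::WaitAll();

The executor is then deleted once after all epochs finish, instead of once per batch.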

diff --git a/cpp-package/example/alexnet.cpp b/cpp-package/example/alexnet.cpp
index c0d8273..6a9e01a 100644
--- a/cpp-package/example/alexnet.cpp
+++ b/cpp-package/example/alexnet.cpp
@@ -199,6 +199,7 @@ int main(int argc, char const *argv[]) {
 
   /*with data and label, executor can be generated automatically*/
   auto *exec = Net.SimpleBind(ctx, args_map);
+  auto arg_names = Net.ListArguments();
   aux_map = exec->aux_dict();
   args_map = exec->arg_dict();
 
@@ -240,7 +241,9 @@ int main(int argc, char const *argv[]) {
   Optimizer* opt = OptimizerRegistry::Find("ccsgd");
   opt->SetParam("momentum", 0.9)
      ->SetParam("rescale_grad", 1.0 / batch_size)
-     ->SetParam("clip_gradient", 10);
+     ->SetParam("clip_gradient", 10)
+     ->SetParam("lr", learning_rate)
+     ->SetParam("wd", weight_decay);
 
   Accuracy acu_train, acu_val;
   LogLoss logloss_val;
@@ -258,7 +261,11 @@ int main(int argc, char const *argv[]) {
       batch.label.CopyTo(&args_map["label"]);
       exec->Forward(true);
       exec->Backward();
-      exec->UpdateAll(opt, learning_rate, weight_decay);
+      for (size_t i = 0; i < arg_names.size(); ++i) {
+        if (arg_names[i] == "data" || arg_names[i] == "label") continue;
+        opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+      }
+
       NDArray::WaitAll();
       acu_train.Update(batch.label, exec->outputs[0]);
     }
diff --git a/cpp-package/example/charRNN.cpp b/cpp-package/example/charRNN.cpp
index 5cb6382..d95c97d 100644
--- a/cpp-package/example/charRNN.cpp
+++ b/cpp-package/example/charRNN.cpp
@@ -451,6 +451,8 @@ void train(const string file, int batch_size, int max_epoch, int start_epoch) {
   mx_float learning_rate = 0.0002;
   mx_float weight_decay = 0.000002;
   Optimizer* opt = OptimizerRegistry::Find("ccsgd");
+  opt->SetParam("lr", learning_rate)
+     ->SetParam("wd", weight_decay);
 //  opt->SetParam("momentum", 0.9)->SetParam("rescale_grad", 1.0 / batch_size)
 //  ->SetParam("clip_gradient", 10);
 
@@ -470,7 +472,10 @@ void train(const string file, int batch_size, int max_epoch, int start_epoch) {
 
       exe->Forward(true);
       exe->Backward();
-      exe->UpdateAll(opt, learning_rate, weight_decay);
+      for (size_t i = 0; i < exe->arg_arrays.size(); ++i) {
+        opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
+      }
+
       NDArray::WaitAll();
     }
     auto toc = chrono::system_clock::now();
@@ -547,7 +552,9 @@ void trainWithBuiltInRNNOp(const string file, int batch_size, int max_epoch, int
 
       exe->Forward(true);
       exe->Backward();
-      exe->UpdateAll(opt, learning_rate, weight_decay);
+      for (size_t i = 0; i < exe->arg_arrays.size(); ++i) {
+        opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
+      }
       NDArray::WaitAll();
     }
     auto toc = chrono::system_clock::now();
diff --git a/cpp-package/example/googlenet.cpp b/cpp-package/example/googlenet.cpp
index a4dcbbd..2e59fbf 100644
--- a/cpp-package/example/googlenet.cpp
+++ b/cpp-package/example/googlenet.cpp
@@ -128,7 +128,13 @@ int main(int argc, char const *argv[]) {
   Optimizer* opt = OptimizerRegistry::Find("ccsgd");
   opt->SetParam("momentum", 0.9)
      ->SetParam("rescale_grad", 1.0 / batch_size)
-     ->SetParam("clip_gradient", 10);
+     ->SetParam("clip_gradient", 10)
+     ->SetParam("lr", learning_rate)
+     ->SetParam("wd", weight_decay);
+
+
+  auto *exec = googlenet.SimpleBind(Context::gpu(), args_map);
+  auto arg_names = googlenet.ListArguments();
 
   for (int iter = 0; iter < max_epoch; ++iter) {
     LG << "Epoch: " << iter;
@@ -138,11 +144,12 @@ int main(int argc, char const *argv[]) {
       args_map["data"] = data_batch.data.Copy(Context::gpu());
       args_map["data_label"] = data_batch.label.Copy(Context::gpu());
       NDArray::WaitAll();
-      auto *exec = googlenet.SimpleBind(Context::gpu(), args_map);
       exec->Forward(true);
       exec->Backward();
-      exec->UpdateAll(opt, learning_rate, weight_decay);
-      delete exec;
+      for (size_t i = 0; i < arg_names.size(); ++i) {
+        if (arg_names[i] == "data" || arg_names[i] == "data_label") continue;
+        opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+      }
     }
 
     Accuracy acu;
@@ -152,14 +159,14 @@ int main(int argc, char const *argv[]) {
       args_map["data"] = data_batch.data.Copy(Context::gpu());
       args_map["data_label"] = data_batch.label.Copy(Context::gpu());
       NDArray::WaitAll();
-      auto *exec = googlenet.SimpleBind(Context::gpu(), args_map);
       exec->Forward(false);
       NDArray::WaitAll();
       acu.Update(data_batch.label, exec->outputs[0]);
-      delete exec;
     }
     LG << "Accuracy: " << acu.Get();
   }
+
+  delete exec;
   MXNotifyShutdown();
   return 0;
 }
diff --git a/cpp-package/example/inception_bn.cpp b/cpp-package/example/inception_bn.cpp
index 5db4f81..4442e00 100644
--- a/cpp-package/example/inception_bn.cpp
+++ b/cpp-package/example/inception_bn.cpp
@@ -156,9 +156,12 @@ int main(int argc, char const *argv[]) {
   Optimizer* opt = OptimizerRegistry::Find("ccsgd");
   opt->SetParam("momentum", 0.9)
      ->SetParam("rescale_grad", 1.0 / batch_size)
-     ->SetParam("clip_gradient", 10);
+     ->SetParam("clip_gradient", 10)
+     ->SetParam("lr", learning_rate)
+     ->SetParam("wd", weight_decay);
 
   auto *exec = inception_bn_net.SimpleBind(Context::gpu(), args_map);
+  auto arg_names = inception_bn_net.ListArguments();
 
   for (int iter = 0; iter < max_epoch; ++iter) {
     LG << "Epoch: " << iter;
@@ -171,7 +174,12 @@ int main(int argc, char const *argv[]) {
 
       exec->Forward(true);
       exec->Backward();
-      exec->UpdateAll(opt, learning_rate, weight_decay);
+      // Update parameters
+      for (size_t i = 0; i < arg_names.size(); ++i) {
+        if (arg_names[i] == "data" || arg_names[i] == "data_label") continue;
+        opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+      }
+
       NDArray::WaitAll();
     }
 
diff --git a/cpp-package/example/lenet.cpp b/cpp-package/example/lenet.cpp
index 91b83a0..56f8d2c 100644
--- a/cpp-package/example/lenet.cpp
+++ b/cpp-package/example/lenet.cpp
@@ -118,7 +118,12 @@ class Lenet {
     Optimizer* opt = OptimizerRegistry::Find("ccsgd");
     opt->SetParam("momentum", 0.9)
        ->SetParam("rescale_grad", 1.0)
-       ->SetParam("clip_gradient", 10);
+       ->SetParam("clip_gradient", 10)
+       ->SetParam("lr", learning_rate)
+       ->SetParam("wd", weight_decay);
+
+    Executor *exe = lenet.SimpleBind(ctx_dev, args_map);
+    auto arg_names = lenet.ListArguments();
 
     for (int ITER = 0; ITER < max_epoch; ++ITER) {
       size_t start_index = 0;
@@ -135,17 +140,19 @@ class Lenet {
         start_index += batch_size;
         NDArray::WaitAll();
 
-        Executor *exe = lenet.SimpleBind(ctx_dev, args_map);
         exe->Forward(true);
         exe->Backward();
-        exe->UpdateAll(opt, learning_rate, weight_decay);
-
-        delete exe;
+        // Update parameters
+        for (size_t i = 0; i < arg_names.size(); ++i) {
+          if (arg_names[i] == "data" || arg_names[i] == "data_label") continue;
+          opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
+        }
       }
 
       LG << "Iter " << ITER
          << ", accuracy: " << ValAccuracy(batch_size * 10, lenet);
     }
+    delete exe;
   }
 
  private:
diff --git a/cpp-package/example/lenet_with_mxdataiter.cpp b/cpp-package/example/lenet_with_mxdataiter.cpp
index 85a4b20..f6301b5 100644
--- a/cpp-package/example/lenet_with_mxdataiter.cpp
+++ b/cpp-package/example/lenet_with_mxdataiter.cpp
@@ -85,7 +85,13 @@ int main(int argc, char const *argv[]) {
   Optimizer* opt = OptimizerRegistry::Find("ccsgd");
   opt->SetParam("momentum", 0.9)
      ->SetParam("rescale_grad", 1.0)
-     ->SetParam("clip_gradient", 10);
+     ->SetParam("clip_gradient", 10)
+     ->SetParam("lr", learning_rate)
+     ->SetParam("wd", weight_decay);
+
+
+  auto *exec = lenet.SimpleBind(Context::gpu(), args_map);
+  auto arg_names = lenet.ListArguments();
 
   for (int iter = 0; iter < max_epoch; ++iter) {
     LG << "Epoch: " << iter;
@@ -95,11 +101,13 @@ int main(int argc, char const *argv[]) {
       args_map["data"] = data_batch.data.Copy(Context::gpu());
       args_map["data_label"] = data_batch.label.Copy(Context::gpu());
       NDArray::WaitAll();
-      auto *exec = lenet.SimpleBind(Context::gpu(), args_map);
       exec->Forward(true);
       exec->Backward();
-      exec->UpdateAll(opt, learning_rate, weight_decay);
-      delete exec;
+      // Update parameters
+      for (size_t i = 0; i < arg_names.size(); ++i) {
+        if (arg_names[i] == "data" || arg_names[i] == "data_label") continue;
+        opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+      }
     }
 
     Accuracy acu;
@@ -109,14 +117,14 @@ int main(int argc, char const *argv[]) {
       args_map["data"] = data_batch.data.Copy(Context::gpu());
       args_map["data_label"] = data_batch.label.Copy(Context::gpu());
       NDArray::WaitAll();
-      auto *exec = lenet.SimpleBind(Context::gpu(), args_map);
       exec->Forward(false);
       NDArray::WaitAll();
       acu.Update(data_batch.label, exec->outputs[0]);
-      delete exec;
     }
     LG << "Accuracy: " << acu.Get();
   }
+
+  delete exec;
   MXNotifyShutdown();
   return 0;
 }
diff --git a/cpp-package/example/mlp_cpu.cpp b/cpp-package/example/mlp_cpu.cpp
index 6948649..358e834 100644
--- a/cpp-package/example/mlp_cpu.cpp
+++ b/cpp-package/example/mlp_cpu.cpp
@@ -70,7 +70,13 @@ int main(int argc, char** argv) {
 
   // Create sgd optimizer
   Optimizer* opt = OptimizerRegistry::Find("sgd");
-  opt->SetParam("rescale_grad", 1.0/batch_size);
+  opt->SetParam("rescale_grad", 1.0/batch_size)
+     ->SetParam("lr", learning_rate)
+     ->SetParam("wd", weight_decay);
+
+  // Create executor by binding parameters to the model
+  auto *exec = net.SimpleBind(ctx, args);
+  auto arg_names = net.ListArguments();
 
   // Start training
   for (int iter = 0; iter < max_epoch; ++iter) {
@@ -85,15 +91,14 @@ int main(int argc, char** argv) {
       args["X"] = data_batch.data;
       args["label"] = data_batch.label;
 
-      // Create executor by binding parameters to the model
-      auto *exec = net.SimpleBind(ctx, args);
       // Compute gradients
       exec->Forward(true);
       exec->Backward();
       // Update parameters
-      exec->UpdateAll(opt, learning_rate, weight_decay);
-      // Remember to free the memory
-      delete exec;
+      for (size_t i = 0; i < arg_names.size(); ++i) {
+        if (arg_names[i] == "X" || arg_names[i] == "label") continue;
+        opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+      }
     }
     auto toc = chrono::system_clock::now();
 
@@ -103,16 +108,15 @@ int main(int argc, char** argv) {
       auto data_batch = val_iter.GetDataBatch();
       args["X"] = data_batch.data;
       args["label"] = data_batch.label;
-      auto *exec = net.SimpleBind(ctx, args);
       // Forward pass is enough as no gradient is needed when evaluating
       exec->Forward(false);
       acc.Update(data_batch.label, exec->outputs[0]);
-      delete exec;
     }
     float duration = chrono::duration_cast<chrono::milliseconds>(toc - tic).count() / 1000.0;
     LG << "Epoch: " << iter << " " << samples/duration << " samples/sec Accuracy: " << acc.Get();
   }
 
+  delete exec;
   MXNotifyShutdown();
   return 0;
 }
diff --git a/cpp-package/example/mlp_gpu.cpp b/cpp-package/example/mlp_gpu.cpp
index 23be637..a6281c3 100644
--- a/cpp-package/example/mlp_gpu.cpp
+++ b/cpp-package/example/mlp_gpu.cpp
@@ -24,7 +24,7 @@ Symbol mlp(const vector<int> &layers) {
       weights[i],
       biases[i],
       layers[i]);
-    outputs[i] = i == layers.size()-1? fc : Activation(fc, ActivationActType::kRelu);
+    outputs[i] = i == layers.size()-1 ? fc : Activation(fc, ActivationActType::kRelu);
   }
 
   return SoftmaxOutput(outputs.back(), label);
@@ -70,12 +70,24 @@ int main(int argc, char** argv) {
 
   // Create sgd optimizer
   Optimizer* opt = OptimizerRegistry::Find("sgd");
-  opt->SetParam("rescale_grad", 1.0/batch_size);
+  opt->SetParam("rescale_grad", 1.0/batch_size)
+     ->SetParam("lr", learning_rate)
+     ->SetParam("wd", weight_decay);
+  std::unique_ptr<LRScheduler> lr_sch(new FactorScheduler(5000, 0.1));
+  opt->SetLRScheduler(std::move(lr_sch));
+
+  // Create executor by binding parameters to the model
+  auto *exec = net.SimpleBind(ctx, args);
+  auto arg_names = net.ListArguments();
+
+  // Create metrics
+  Accuracy train_acc, val_acc;
 
   // Start training
   for (int iter = 0; iter < max_epoch; ++iter) {
     int samples = 0;
     train_iter.Reset();
+    train_acc.Reset();
 
     auto tic = chrono::system_clock::now();
     while (train_iter.Next()) {
@@ -87,35 +99,40 @@ int main(int argc, char** argv) {
       // CopyTo is imperative, need to wait for it to complete.
       NDArray::WaitAll();
 
-      // Create executor by binding parameters to the model
-      auto *exec = net.SimpleBind(ctx, args);
       // Compute gradients
       exec->Forward(true);
       exec->Backward();
+
       // Update parameters
-      exec->UpdateAll(opt, learning_rate, weight_decay);
-      // Remember to free the memory
-      delete exec;
+      for (size_t i = 0; i < arg_names.size(); ++i) {
+        if (arg_names[i] == "X" || arg_names[i] == "label") continue;
+        opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+      }
+      // Update metric
+      train_acc.Update(data_batch.label, exec->outputs[0]);
     }
+    // one epoch of training is finished
     auto toc = chrono::system_clock::now();
+    float duration = chrono::duration_cast<chrono::milliseconds>(toc - tic).count() / 1000.0;
+    LG << "Epoch[" << iter << "] " << samples/duration \
+       << " samples/sec " << "Train-Accuracy=" << train_acc.Get();;
 
-    Accuracy acc;
     val_iter.Reset();
+    val_acc.Reset();
     while (val_iter.Next()) {
       auto data_batch = val_iter.GetDataBatch();
       data_batch.data.CopyTo(&args["X"]);
       data_batch.label.CopyTo(&args["label"]);
       NDArray::WaitAll();
-      auto *exec = net.SimpleBind(ctx, args);
+
       // Only forward pass is enough as no gradient is needed when evaluating
       exec->Forward(false);
-      acc.Update(data_batch.label, exec->outputs[0]);
-      delete exec;
+      val_acc.Update(data_batch.label, exec->outputs[0]);
     }
-    float duration = chrono::duration_cast<chrono::milliseconds>(toc - tic).count() / 1000.0;
-    LG << "Epoch: " << iter << " " << samples/duration << " samples/sec Accuracy: " << acc.Get();
+    LG << "Epoch[" << iter << "] Val-Accuracy=" << val_acc.Get();
   }
 
+  delete exec;
   MXNotifyShutdown();
   return 0;
 }
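
The mlp_gpu.cpp example above (and test_score.cpp further down) also attaches a FactorScheduler(5000, 0.1) to the optimizer. Roughly, that scheduler multiplies the current learning rate by 0.1 once every 5000 optimizer updates and never lets it fall below the stop_factor_lr floor (1e-8 by default). Assuming a base learning rate of 0.1 purely for illustration:

    // The FactorScheduler(5000, 0.1) values match the example above; the base
    // learning rate of 0.1 is an assumption for illustration only.
    opt->SetParam("lr", 0.1);                        // base rate the scheduler starts from
    std::unique_ptr<LRScheduler> lr_sch(new FactorScheduler(5000, 0.1));
    opt->SetLRScheduler(std::move(lr_sch));          // seeds the scheduler with lr = 0.1
    // updates     1 ..  5000  -> lr ~= 0.1
    // updates  5001 .. 10000  -> lr ~= 0.01
    // updates 10001 .. 15000  -> lr ~= 0.001, and so on, floored at 1e-8
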
diff --git a/cpp-package/example/resnet.cpp b/cpp-package/example/resnet.cpp
index 5521567..b9766c7 100644
--- a/cpp-package/example/resnet.cpp
+++ b/cpp-package/example/resnet.cpp
@@ -165,11 +165,14 @@ int main(int argc, char const *argv[]) {
       .CreateDataIter();
 
   Optimizer* opt = OptimizerRegistry::Find("ccsgd");
-  opt->SetParam("momentum", 0.9)
+  opt->SetParam("lr", learning_rate)
+     ->SetParam("wd", weight_decay)
+     ->SetParam("momentum", 0.9)
      ->SetParam("rescale_grad", 1.0 / batch_size)
      ->SetParam("clip_gradient", 10);
 
   auto *exec = resnet.SimpleBind(Context::gpu(), args_map);
+  auto arg_names = resnet.ListArguments();
 
   for (int iter = 0; iter < max_epoch; ++iter) {
     LG << "Epoch: " << iter;
@@ -182,7 +185,11 @@ int main(int argc, char const *argv[]) {
 
       exec->Forward(true);
       exec->Backward();
-      exec->UpdateAll(opt, learning_rate, weight_decay);
+
+      for (size_t i = 0; i < arg_names.size(); ++i) {
+        if (arg_names[i] == "data" || arg_names[i] == "data_label") continue;
+        opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+      }
       NDArray::WaitAll();
     }
 
diff --git a/cpp-package/example/test_score.cpp b/cpp-package/example/test_score.cpp
index 7dccd30..3534269 100644
--- a/cpp-package/example/test_score.cpp
+++ b/cpp-package/example/test_score.cpp
@@ -72,7 +72,15 @@ int main(int argc, char** argv) {
 
   // Create sgd optimizer
   Optimizer* opt = OptimizerRegistry::Find("sgd");
-  opt->SetParam("rescale_grad", 1.0/batch_size);
+  opt->SetParam("rescale_grad", 1.0/batch_size)
+     ->SetParam("lr", learning_rate)
+     ->SetParam("wd", weight_decay);
+  std::unique_ptr<LRScheduler> lr_sch(new FactorScheduler(5000, 0.1));
+  opt->SetLRScheduler(std::move(lr_sch));
+
+  // Create executor by binding parameters to the model
+  auto *exec = net.SimpleBind(ctx, args);
+  auto arg_names = net.ListArguments();
 
   float score = 0;
   // Start training
@@ -90,15 +98,14 @@ int main(int argc, char** argv) {
       // CopyTo is imperative, need to wait for it to complete.
       NDArray::WaitAll();
 
-      // Create executor by binding parameters to the model
-      auto *exec = net.SimpleBind(ctx, args);
       // Compute gradients
       exec->Forward(true);
       exec->Backward();
       // Update parameters
-      exec->UpdateAll(opt, learning_rate, weight_decay);
-      // Remember to free the memory
-      delete exec;
+      for (size_t i = 0; i < arg_names.size(); ++i) {
+        if (arg_names[i] == "X" || arg_names[i] == "label") continue;
+        opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+      }
     }
     auto toc = chrono::system_clock::now();
 
@@ -109,17 +116,16 @@ int main(int argc, char** argv) {
       data_batch.data.CopyTo(&args["X"]);
       data_batch.label.CopyTo(&args["label"]);
       NDArray::WaitAll();
-      auto *exec = net.SimpleBind(ctx, args);
       // Only forward pass is enough as no gradient is needed when evaluating
       exec->Forward(false);
       acc.Update(data_batch.label, exec->outputs[0]);
-      delete exec;
     }
     float duration = chrono::duration_cast<chrono::milliseconds>(toc - tic).count() / 1000.0;
     LG << "Epoch: " << iter << " " << samples/duration << " samples/sec Accuracy: " << acc.Get();
     score = acc.Get();
   }
 
+  delete exec;
   MXNotifyShutdown();
   return score >= MIN_SCORE ? 0 : 1;
 }
diff --git a/cpp-package/include/mxnet-cpp/executor.h b/cpp-package/include/mxnet-cpp/executor.h
index 822344b..67eec01 100644
--- a/cpp-package/include/mxnet-cpp/executor.h
+++ b/cpp-package/include/mxnet-cpp/executor.h
@@ -79,18 +79,6 @@ class Executor {
   */
   std::string DebugStr();
   /*!
-  * \brief update the arguments with given learning rate and optimizer
-  * \param opt the pointer to the optimizer
-  * \param lr learning rate
-  * \param wd weight decay
-  * \param arg_update_begin begin index of the arguments to be updated, it
-  * starts after the input data by default
-  * \param arg_update_end end index of the arguments to be updated, it ends
-  * before the label data by default
-  */
-  void UpdateAll(Optimizer *opt, float lr, float wd, int arg_update_begin = 1,
-                 int arg_update_end = -1);
-  /*!
   * \brief destructor, free the handle
   */
   ~Executor() { MXExecutorFree(handle_); }
diff --git a/cpp-package/include/mxnet-cpp/executor.hpp b/cpp-package/include/mxnet-cpp/executor.hpp
index 1a452a1..6887956 100644
--- a/cpp-package/include/mxnet-cpp/executor.hpp
+++ b/cpp-package/include/mxnet-cpp/executor.hpp
@@ -79,13 +79,6 @@ inline std::string Executor::DebugStr() {
   return std::string(output);
 }
 
-inline void Executor::UpdateAll(Optimizer *opt, float lr, float wd,
-                                int arg_update_begin, int arg_update_end) {
-  arg_update_end = arg_update_end < 0 ? arg_arrays.size() - 1 : arg_update_end;
-  for (int i = arg_update_begin; i < arg_update_end; ++i) {
-    opt->Update(i, arg_arrays[i], grad_arrays[i], lr, wd);
-  }
-}
 }  // namespace cpp
 }  // namespace mxnet
 
diff --git a/cpp-package/include/mxnet-cpp/lr_scheduler.h b/cpp-package/include/mxnet-cpp/lr_scheduler.h
new file mode 100644
index 0000000..91f9b3c
--- /dev/null
+++ b/cpp-package/include/mxnet-cpp/lr_scheduler.h
@@ -0,0 +1,78 @@
+/*!
+*  Copyright (c) 2017 by Contributors
+* \file lr_scheduler.h
+* \brief Scheduling learning rate
+*/
+
+#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_LR_SCHEDULER_H_
+#define CPP_PACKAGE_INCLUDE_MXNET_CPP_LR_SCHEDULER_H_
+
+#include "dmlc/logging.h"
+
+namespace mxnet {
+namespace cpp {
+
+/*!
+* \brief lr scheduler interface
+*/
+class LRScheduler {
+ public:
+  /*!
+  * \brief constructor
+  * \param base_lr the initial learning rate.
+  */
+  explicit LRScheduler(float base_lr = 0.01)
+      : base_lr_(base_lr) {}
+  /*!
+  * \brief set base lr
+  * \param lr learning rate from optimizer
+  */
+  void SetLR(const float lr) { base_lr_ = lr; }
+  /*!
+  * \brief get a new learning rate
+  */
+  virtual float GetLR(unsigned num_update) = 0;
+  /*!
+  * \brief destructor
+  */
+  virtual ~LRScheduler() {}
+
+ protected:
+  float base_lr_;
+};
+
+class FactorScheduler : public LRScheduler {
+ public:
+  explicit FactorScheduler(int step, float factor = 1, float stop_factor_lr = 1e-8)
+      : LRScheduler() {
+    step_ = step;
+    factor_ = factor;
+    stop_factor_lr_ = stop_factor_lr;
+  }
+
+  float GetLR(unsigned num_update) override {
+    while (num_update > unsigned(count_ + step_)) {
+      count_ += step_;
+      base_lr_ *= factor_;
+      if (base_lr_ < stop_factor_lr_) {
+        base_lr_ = stop_factor_lr_;
+        LG << "Update[" << num_update << "]: now learning rate arrived at " \
+           << base_lr_ << ", will not change in the future";
+      } else {
+        LG << "Update[" << num_update << "]: Change learning rate to " << base_lr_;
+      }
+    }
+    return base_lr_;
+  }
+
+ private:
+  int count_ = 0;
+  int step_;
+  float factor_;
+  float stop_factor_lr_;
+};
+
+}  // namespace cpp
+}  // namespace mxnet
+
+#endif  // CPP_PACKAGE_INCLUDE_MXNET_CPP_LR_SCHEDULER_H_
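
The new header only requires GetLR(num_update) from a scheduler, so schedules other than FactorScheduler can be plugged in. As a hypothetical illustration (this class is not part of the commit), a scheduler that halves the rate every n updates could look like:

    // Hypothetical user-defined scheduler, shown only to illustrate the interface.
    class HalveEveryN : public mxnet::cpp::LRScheduler {
     public:
      explicit HalveEveryN(int n, float base_lr = 0.01)
          : LRScheduler(base_lr), n_(n) {}
      float GetLR(unsigned num_update) override {
        while (num_update > static_cast<unsigned>(last_ + n_)) {
          last_ += n_;
          base_lr_ *= 0.5f;   // protected member inherited from LRScheduler
        }
        return base_lr_;
      }
     private:
      int n_;
      int last_ = 0;
    };

Note that Optimizer::SetLRScheduler() overwrites the scheduler's base rate with the optimizer's "lr" parameter via SetLR(), so the base_lr constructor argument mainly matters when a scheduler is used standalone.
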
diff --git a/cpp-package/include/mxnet-cpp/optimizer.h b/cpp-package/include/mxnet-cpp/optimizer.h
index 76f8a35..1bc36d5 100644
--- a/cpp-package/include/mxnet-cpp/optimizer.h
+++ b/cpp-package/include/mxnet-cpp/optimizer.h
@@ -17,6 +17,7 @@
 #include "dmlc/logging.h"
 #include "mxnet-cpp/ndarray.h"
 #include "mxnet-cpp/op_map.h"
+#include "mxnet-cpp/lr_scheduler.h"
 
 namespace mxnet {
 namespace cpp {
@@ -57,15 +58,16 @@ class Optimizer {
     return this;
   }
   /*!
-  *  \brief Update a weight with gradient.
-  *  \param index the unique index for the weight.
-  *  \param weight the weight to update.
-  *  \param grad gradient for the weight.
-  *  \param lr learning rate.
-  *  \param wd weight decay.
+  * \brief set the lr scheduler
+  * \param lrScheduler lr scheduler used for this optimizer
+  * \return reference to self
   */
-  void Update(int index, NDArray weight, NDArray grad, mx_float lr,
-              mx_float wd);
+  Optimizer *SetLRScheduler(std::unique_ptr<LRScheduler> lrScheduler) {
+    CHECK(lrScheduler);
+    lrScheduler_ = std::move(lrScheduler);
+    lrScheduler_->SetLR(std::stof(params_["lr"]));
+    return this;
+  }
   /*!
   *  \brief Update a weight with gradient.
   *  \param index the unique index for the weight.
@@ -92,7 +94,10 @@ class Optimizer {
   std::map<int, unsigned> count_;
   unsigned begin_num_update_, num_update_;
   unsigned UpdateCount_(int index);
+  float GetLR_(int index);
+  float GetWD_(int index);
   virtual void CreateState_(int index, NDArray weight);
+  std::unique_ptr<LRScheduler> lrScheduler_ = nullptr;
 };
 
 typedef std::function<Optimizer*()> OptimizerCreator;
@@ -172,7 +177,6 @@ class AdaDeltaOptimizer : public Optimizer {
   std::map<int, NDArray*> acc_g_, acc_delta_;
 };
 
-
 }  // namespace cpp
 }  // namespace mxnet
 
diff --git a/cpp-package/include/mxnet-cpp/optimizer.hpp b/cpp-package/include/mxnet-cpp/optimizer.hpp
index 9dcb158..0d6a7be 100644
--- a/cpp-package/include/mxnet-cpp/optimizer.hpp
+++ b/cpp-package/include/mxnet-cpp/optimizer.hpp
@@ -42,6 +42,8 @@ namespace cpp {
 inline Optimizer::Optimizer(unsigned begin_num_update)
   : begin_num_update_(begin_num_update),
     num_update_(begin_num_update_) {
+  params_["lr"] = "0.01f";
+  params_["wd"] = "0.f";
 }
 
 inline std::map<std::string, OptimizerCreator>& OptimizerRegistry::cmap() {
@@ -56,14 +58,6 @@ inline OpMap*& Optimizer::op_map() {
 
 inline Optimizer::~Optimizer() {}
 
-inline void Optimizer::Update(int index, NDArray weight, NDArray grad, mx_float lr,
-                       mx_float wd) {
-  params_["lr"] = std::to_string(lr);
-  params_["wd"] = std::to_string(wd);
-  UpdateCount_(index);
-  Update(index, weight, grad);
-}
-
 inline void Optimizer::CreateState_(int index, NDArray weight) {
 }
 
@@ -100,6 +94,18 @@ inline unsigned Optimizer::UpdateCount_(int index) {
   return new_count;
 }
 
+inline float Optimizer::GetLR_(int index) {
+  if (nullptr != lrScheduler_) {
+    return lrScheduler_->GetLR(num_update_);
+  }
+  return std::stof(params_["lr"]);
+}
+
+inline float Optimizer::GetWD_(int index) {
+  float wd = std::stof(params_["wd"]);
+  return wd;
+}
+
 inline Optimizer* OptimizerRegistry::Find(const std::string& name) {
   MXNETCPP_REGISTER_OPTIMIZER(sgd, SGDOptimizer);
   MXNETCPP_REGISTER_OPTIMIZER(ccsgd, SGDOptimizer);  // For backward compatibility
@@ -140,6 +146,9 @@ inline void SGDOptimizer::Update(int index, NDArray weight, NDArray grad) {
     CreateState_(index, weight);
   }
 
+  params_["lr"] = std::to_string(GetLR_(index));
+  params_["wd"] = std::to_string(GetWD_(index));
+  UpdateCount_(index);
   auto keys = GetParamKeys_();
   auto values = GetParamValues_();
   CHECK_EQ(keys.size(), values.size());
@@ -203,6 +212,9 @@ inline void RMSPropOptimizer::Update(int index, NDArray weight, NDArray grad) {
     CreateState_(index, weight);
   }
 
+  params_["lr"] = std::to_string(GetLR_(index));
+  params_["wd"] = std::to_string(GetWD_(index));
+  UpdateCount_(index);
   auto keys = GetParamKeys_();
   auto values = GetParamValues_();
   CHECK_EQ(keys.size(), values.size());
@@ -257,6 +269,10 @@ inline void AdamOptimizer::Update(int index, NDArray weight, NDArray grad) {
   if (mean_.count(index) == 0) {
     CreateState_(index, weight);
   }
+
+  params_["lr"] = std::to_string(GetLR_(index));
+  params_["wd"] = std::to_string(GetWD_(index));
+  UpdateCount_(index);
   auto keys = GetParamKeys_();
   auto values = GetParamValues_();
   CHECK_EQ(keys.size(), values.size());
@@ -306,9 +322,11 @@ inline void AdaGradOptimizer::Update(int index, NDArray weight, NDArray grad) {
   if (history_.count(index) == 0) {
     CreateState_(index, weight);
   }
-  float lr = std::stof(params_["lr"]);
-  float wd = std::stof(params_["wd"]);
+
   float eps = std::stof(params_["eps"]);
+  float lr = GetLR_(index);
+  float wd = GetWD_(index);
+  UpdateCount_(index);
   if (params_.count("rescale_grad") > 0) {
     grad *= std::stof(params_["rescale_grad"]);
   }
@@ -345,9 +363,11 @@ inline void AdaDeltaOptimizer::Update(int index, NDArray weight, NDArray grad) {
   if (acc_g_.count(index) == 0) {
     CreateState_(index, weight);
   }
-  float wd = std::stof(params_["wd"]);
+
   float rho = std::stof(params_["rho"]);
   float epsilon = std::stof(params_["epsilon"]);
+  float wd = GetWD_(index);
+  UpdateCount_(index);
 
   if (params_.count("rescale_grad") > 0) {
     grad *= std::stof(params_["rescale_grad"]);
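
With the optimizer.hpp changes above, every Update() call now resolves the learning rate through GetLR_() (which asks the scheduler when one is set and otherwise falls back to params_["lr"]), resolves the weight decay through GetWD_(), and advances the update counter. Put together, the intended usage is roughly the following sketch (the concrete values are illustrative, not taken from the commit):

    Optimizer* opt = OptimizerRegistry::Find("sgd");
    opt->SetParam("lr", 0.1)->SetParam("wd", 1e-4);
    opt->SetLRScheduler(std::unique_ptr<LRScheduler>(new FactorScheduler(100, 0.5)));

    // Each opt->Update(i, weight, grad) now effectively does:
    //   lr = lrScheduler_ ? lrScheduler_->GetLR(num_update_) : std::stof(params_["lr"]);
    //   wd = std::stof(params_["wd"]);
    //   UpdateCount_(i);   // advances num_update_, so the schedule progresses
    // before applying the SGD step, so the learning rate decays automatically
    // as training proceeds.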

-- 
To stop receiving notification emails like this one, please contact commits@mxnet.apache.org.