Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/03 20:18:09 UTC
[incubator-mxnet] branch master updated: [cpp-package] add lr scheduler (#6885)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 76dee53 [cpp-package] add lr scheduler (#6885)
76dee53 is described below
commit 76dee53dc494b35ce7dd4ac88cdce817bf9aa1ce
Author: CNevd <CN...@users.noreply.github.com>
AuthorDate: Fri Aug 4 04:18:07 2017 +0800
[cpp-package] add lr scheduler (#6885)
* add lr scheduler
* Update lr_scheduler.h
* Update mlp_gpu.cpp
* Update test_score.cpp
* update optimizer.hpp
---
cpp-package/example/alexnet.cpp | 11 +++-
cpp-package/example/charRNN.cpp | 11 +++-
cpp-package/example/googlenet.cpp | 19 ++++---
cpp-package/example/inception_bn.cpp | 12 ++++-
cpp-package/example/lenet.cpp | 17 ++++--
cpp-package/example/lenet_with_mxdataiter.cpp | 20 ++++---
cpp-package/example/mlp_cpu.cpp | 20 ++++---
cpp-package/example/mlp_gpu.cpp | 43 ++++++++++-----
cpp-package/example/resnet.cpp | 11 +++-
cpp-package/example/test_score.cpp | 22 +++++---
cpp-package/include/mxnet-cpp/executor.h | 12 -----
cpp-package/include/mxnet-cpp/executor.hpp | 7 ---
cpp-package/include/mxnet-cpp/lr_scheduler.h | 78 +++++++++++++++++++++++++++
cpp-package/include/mxnet-cpp/optimizer.h | 22 ++++----
cpp-package/include/mxnet-cpp/optimizer.hpp | 42 +++++++++++----
15 files changed, 254 insertions(+), 93 deletions(-)
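Taken together, the patch removes Executor::UpdateAll, moves the learning rate ("lr") and weight decay ("wd") onto the Optimizer itself, adds an optional LRScheduler, and rewrites the examples to bind one executor up front and step each parameter explicitly. Below is a minimal sketch of the resulting training-loop pattern; `net`, `ctx`, `args_map`, and the hyper-parameter names stand in for each example's own symbols:

    Optimizer* opt = OptimizerRegistry::Find("sgd");
    opt->SetParam("rescale_grad", 1.0 / batch_size)
       ->SetParam("lr", learning_rate)
       ->SetParam("wd", weight_decay);

    auto *exec = net.SimpleBind(ctx, args_map);   // bind once, reuse for every batch
    auto arg_names = net.ListArguments();

    // inside the batch loop
    exec->Forward(true);
    exec->Backward();
    for (size_t i = 0; i < arg_names.size(); ++i) {
      if (arg_names[i] == "data" || arg_names[i] == "data_label") continue;  // skip inputs
      opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
    }
    NDArray::WaitAll();

    // after training
    delete exec;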
diff --git a/cpp-package/example/alexnet.cpp b/cpp-package/example/alexnet.cpp
index c0d8273..6a9e01a 100644
--- a/cpp-package/example/alexnet.cpp
+++ b/cpp-package/example/alexnet.cpp
@@ -199,6 +199,7 @@ int main(int argc, char const *argv[]) {
/*with data and label, executor can be generated automatically*/
auto *exec = Net.SimpleBind(ctx, args_map);
+ auto arg_names = Net.ListArguments();
aux_map = exec->aux_dict();
args_map = exec->arg_dict();
@@ -240,7 +241,9 @@ int main(int argc, char const *argv[]) {
Optimizer* opt = OptimizerRegistry::Find("ccsgd");
opt->SetParam("momentum", 0.9)
->SetParam("rescale_grad", 1.0 / batch_size)
- ->SetParam("clip_gradient", 10);
+ ->SetParam("clip_gradient", 10)
+ ->SetParam("lr", learning_rate)
+ ->SetParam("wd", weight_decay);
Accuracy acu_train, acu_val;
LogLoss logloss_val;
@@ -258,7 +261,11 @@ int main(int argc, char const *argv[]) {
batch.label.CopyTo(&args_map["label"]);
exec->Forward(true);
exec->Backward();
- exec->UpdateAll(opt, learning_rate, weight_decay);
+ for (size_t i = 0; i < arg_names.size(); ++i) {
+ if (arg_names[i] == "data" || arg_names[i] == "label") continue;
+ opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+ }
+
NDArray::WaitAll();
acu_train.Update(batch.label, exec->outputs[0]);
}
diff --git a/cpp-package/example/charRNN.cpp b/cpp-package/example/charRNN.cpp
index 5cb6382..d95c97d 100644
--- a/cpp-package/example/charRNN.cpp
+++ b/cpp-package/example/charRNN.cpp
@@ -451,6 +451,8 @@ void train(const string file, int batch_size, int max_epoch, int start_epoch) {
mx_float learning_rate = 0.0002;
mx_float weight_decay = 0.000002;
Optimizer* opt = OptimizerRegistry::Find("ccsgd");
+ opt->SetParam("lr", learning_rate)
+ ->SetParam("wd", weight_decay);
// opt->SetParam("momentum", 0.9)->SetParam("rescale_grad", 1.0 / batch_size)
// ->SetParam("clip_gradient", 10);
@@ -470,7 +472,10 @@ void train(const string file, int batch_size, int max_epoch, int start_epoch) {
exe->Forward(true);
exe->Backward();
- exe->UpdateAll(opt, learning_rate, weight_decay);
+ for (size_t i = 0; i < exe->arg_arrays.size(); ++i) {
+ opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
+ }
+
NDArray::WaitAll();
}
auto toc = chrono::system_clock::now();
@@ -547,7 +552,9 @@ void trainWithBuiltInRNNOp(const string file, int batch_size, int max_epoch, int
exe->Forward(true);
exe->Backward();
- exe->UpdateAll(opt, learning_rate, weight_decay);
+ for (size_t i = 0; i < exe->arg_arrays.size(); ++i) {
+ opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
+ }
NDArray::WaitAll();
}
auto toc = chrono::system_clock::now();
diff --git a/cpp-package/example/googlenet.cpp b/cpp-package/example/googlenet.cpp
index a4dcbbd..2e59fbf 100644
--- a/cpp-package/example/googlenet.cpp
+++ b/cpp-package/example/googlenet.cpp
@@ -128,7 +128,13 @@ int main(int argc, char const *argv[]) {
Optimizer* opt = OptimizerRegistry::Find("ccsgd");
opt->SetParam("momentum", 0.9)
->SetParam("rescale_grad", 1.0 / batch_size)
- ->SetParam("clip_gradient", 10);
+ ->SetParam("clip_gradient", 10)
+ ->SetParam("lr", learning_rate)
+ ->SetParam("wd", weight_decay);
+
+
+ auto *exec = googlenet.SimpleBind(Context::gpu(), args_map);
+ auto arg_names = googlenet.ListArguments();
for (int iter = 0; iter < max_epoch; ++iter) {
LG << "Epoch: " << iter;
@@ -138,11 +144,12 @@ int main(int argc, char const *argv[]) {
args_map["data"] = data_batch.data.Copy(Context::gpu());
args_map["data_label"] = data_batch.label.Copy(Context::gpu());
NDArray::WaitAll();
- auto *exec = googlenet.SimpleBind(Context::gpu(), args_map);
exec->Forward(true);
exec->Backward();
- exec->UpdateAll(opt, learning_rate, weight_decay);
- delete exec;
+ for (size_t i = 0; i < arg_names.size(); ++i) {
+ if (arg_names[i] == "data" || arg_names[i] == "data_label") continue;
+ opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+ }
}
Accuracy acu;
@@ -152,14 +159,14 @@ int main(int argc, char const *argv[]) {
args_map["data"] = data_batch.data.Copy(Context::gpu());
args_map["data_label"] = data_batch.label.Copy(Context::gpu());
NDArray::WaitAll();
- auto *exec = googlenet.SimpleBind(Context::gpu(), args_map);
exec->Forward(false);
NDArray::WaitAll();
acu.Update(data_batch.label, exec->outputs[0]);
- delete exec;
}
LG << "Accuracy: " << acu.Get();
}
+
+ delete exec;
MXNotifyShutdown();
return 0;
}
diff --git a/cpp-package/example/inception_bn.cpp b/cpp-package/example/inception_bn.cpp
index 5db4f81..4442e00 100644
--- a/cpp-package/example/inception_bn.cpp
+++ b/cpp-package/example/inception_bn.cpp
@@ -156,9 +156,12 @@ int main(int argc, char const *argv[]) {
Optimizer* opt = OptimizerRegistry::Find("ccsgd");
opt->SetParam("momentum", 0.9)
->SetParam("rescale_grad", 1.0 / batch_size)
- ->SetParam("clip_gradient", 10);
+ ->SetParam("clip_gradient", 10)
+ ->SetParam("lr", learning_rate)
+ ->SetParam("wd", weight_decay);
auto *exec = inception_bn_net.SimpleBind(Context::gpu(), args_map);
+ auto arg_names = inception_bn_net.ListArguments();
for (int iter = 0; iter < max_epoch; ++iter) {
LG << "Epoch: " << iter;
@@ -171,7 +174,12 @@ int main(int argc, char const *argv[]) {
exec->Forward(true);
exec->Backward();
- exec->UpdateAll(opt, learning_rate, weight_decay);
+ // Update parameters
+ for (size_t i = 0; i < arg_names.size(); ++i) {
+ if (arg_names[i] == "data" || arg_names[i] == "data_label") continue;
+ opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+ }
+
NDArray::WaitAll();
}
diff --git a/cpp-package/example/lenet.cpp b/cpp-package/example/lenet.cpp
index 91b83a0..56f8d2c 100644
--- a/cpp-package/example/lenet.cpp
+++ b/cpp-package/example/lenet.cpp
@@ -118,7 +118,12 @@ class Lenet {
Optimizer* opt = OptimizerRegistry::Find("ccsgd");
opt->SetParam("momentum", 0.9)
->SetParam("rescale_grad", 1.0)
- ->SetParam("clip_gradient", 10);
+ ->SetParam("clip_gradient", 10)
+ ->SetParam("lr", learning_rate)
+ ->SetParam("wd", weight_decay);
+
+ Executor *exe = lenet.SimpleBind(ctx_dev, args_map);
+ auto arg_names = lenet.ListArguments();
for (int ITER = 0; ITER < max_epoch; ++ITER) {
size_t start_index = 0;
@@ -135,17 +140,19 @@ class Lenet {
start_index += batch_size;
NDArray::WaitAll();
- Executor *exe = lenet.SimpleBind(ctx_dev, args_map);
exe->Forward(true);
exe->Backward();
- exe->UpdateAll(opt, learning_rate, weight_decay);
-
- delete exe;
+ // Update parameters
+ for (size_t i = 0; i < arg_names.size(); ++i) {
+ if (arg_names[i] == "data" || arg_names[i] == "data_label") continue;
+ opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
+ }
}
LG << "Iter " << ITER
<< ", accuracy: " << ValAccuracy(batch_size * 10, lenet);
}
+ delete exe;
}
private:
diff --git a/cpp-package/example/lenet_with_mxdataiter.cpp b/cpp-package/example/lenet_with_mxdataiter.cpp
index 85a4b20..f6301b5 100644
--- a/cpp-package/example/lenet_with_mxdataiter.cpp
+++ b/cpp-package/example/lenet_with_mxdataiter.cpp
@@ -85,7 +85,13 @@ int main(int argc, char const *argv[]) {
Optimizer* opt = OptimizerRegistry::Find("ccsgd");
opt->SetParam("momentum", 0.9)
->SetParam("rescale_grad", 1.0)
- ->SetParam("clip_gradient", 10);
+ ->SetParam("clip_gradient", 10)
+ ->SetParam("lr", learning_rate)
+ ->SetParam("wd", weight_decay);
+
+
+ auto *exec = lenet.SimpleBind(Context::gpu(), args_map);
+ auto arg_names = lenet.ListArguments();
for (int iter = 0; iter < max_epoch; ++iter) {
LG << "Epoch: " << iter;
@@ -95,11 +101,13 @@ int main(int argc, char const *argv[]) {
args_map["data"] = data_batch.data.Copy(Context::gpu());
args_map["data_label"] = data_batch.label.Copy(Context::gpu());
NDArray::WaitAll();
- auto *exec = lenet.SimpleBind(Context::gpu(), args_map);
exec->Forward(true);
exec->Backward();
- exec->UpdateAll(opt, learning_rate, weight_decay);
- delete exec;
+ // Update parameters
+ for (size_t i = 0; i < arg_names.size(); ++i) {
+ if (arg_names[i] == "data" || arg_names[i] == "data_label") continue;
+ opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+ }
}
Accuracy acu;
@@ -109,14 +117,14 @@ int main(int argc, char const *argv[]) {
args_map["data"] = data_batch.data.Copy(Context::gpu());
args_map["data_label"] = data_batch.label.Copy(Context::gpu());
NDArray::WaitAll();
- auto *exec = lenet.SimpleBind(Context::gpu(), args_map);
exec->Forward(false);
NDArray::WaitAll();
acu.Update(data_batch.label, exec->outputs[0]);
- delete exec;
}
LG << "Accuracy: " << acu.Get();
}
+
+ delete exec;
MXNotifyShutdown();
return 0;
}
diff --git a/cpp-package/example/mlp_cpu.cpp b/cpp-package/example/mlp_cpu.cpp
index 6948649..358e834 100644
--- a/cpp-package/example/mlp_cpu.cpp
+++ b/cpp-package/example/mlp_cpu.cpp
@@ -70,7 +70,13 @@ int main(int argc, char** argv) {
// Create sgd optimizer
Optimizer* opt = OptimizerRegistry::Find("sgd");
- opt->SetParam("rescale_grad", 1.0/batch_size);
+ opt->SetParam("rescale_grad", 1.0/batch_size)
+ ->SetParam("lr", learning_rate)
+ ->SetParam("wd", weight_decay);
+
+ // Create executor by binding parameters to the model
+ auto *exec = net.SimpleBind(ctx, args);
+ auto arg_names = net.ListArguments();
// Start training
for (int iter = 0; iter < max_epoch; ++iter) {
@@ -85,15 +91,14 @@ int main(int argc, char** argv) {
args["X"] = data_batch.data;
args["label"] = data_batch.label;
- // Create executor by binding parameters to the model
- auto *exec = net.SimpleBind(ctx, args);
// Compute gradients
exec->Forward(true);
exec->Backward();
// Update parameters
- exec->UpdateAll(opt, learning_rate, weight_decay);
- // Remember to free the memory
- delete exec;
+ for (size_t i = 0; i < arg_names.size(); ++i) {
+ if (arg_names[i] == "X" || arg_names[i] == "label") continue;
+ opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+ }
}
auto toc = chrono::system_clock::now();
@@ -103,16 +108,15 @@ int main(int argc, char** argv) {
auto data_batch = val_iter.GetDataBatch();
args["X"] = data_batch.data;
args["label"] = data_batch.label;
- auto *exec = net.SimpleBind(ctx, args);
// Forward pass is enough as no gradient is needed when evaluating
exec->Forward(false);
acc.Update(data_batch.label, exec->outputs[0]);
- delete exec;
}
float duration = chrono::duration_cast<chrono::milliseconds>(toc - tic).count() / 1000.0;
LG << "Epoch: " << iter << " " << samples/duration << " samples/sec Accuracy: " << acc.Get();
}
+ delete exec;
MXNotifyShutdown();
return 0;
}
diff --git a/cpp-package/example/mlp_gpu.cpp b/cpp-package/example/mlp_gpu.cpp
index 23be637..a6281c3 100644
--- a/cpp-package/example/mlp_gpu.cpp
+++ b/cpp-package/example/mlp_gpu.cpp
@@ -24,7 +24,7 @@ Symbol mlp(const vector<int> &layers) {
weights[i],
biases[i],
layers[i]);
- outputs[i] = i == layers.size()-1? fc : Activation(fc, ActivationActType::kRelu);
+ outputs[i] = i == layers.size()-1 ? fc : Activation(fc, ActivationActType::kRelu);
}
return SoftmaxOutput(outputs.back(), label);
@@ -70,12 +70,24 @@ int main(int argc, char** argv) {
// Create sgd optimizer
Optimizer* opt = OptimizerRegistry::Find("sgd");
- opt->SetParam("rescale_grad", 1.0/batch_size);
+ opt->SetParam("rescale_grad", 1.0/batch_size)
+ ->SetParam("lr", learning_rate)
+ ->SetParam("wd", weight_decay);
+ std::unique_ptr<LRScheduler> lr_sch(new FactorScheduler(5000, 0.1));
+ opt->SetLRScheduler(std::move(lr_sch));
+
+ // Create executor by binding parameters to the model
+ auto *exec = net.SimpleBind(ctx, args);
+ auto arg_names = net.ListArguments();
+
+ // Create metrics
+ Accuracy train_acc, val_acc;
// Start training
for (int iter = 0; iter < max_epoch; ++iter) {
int samples = 0;
train_iter.Reset();
+ train_acc.Reset();
auto tic = chrono::system_clock::now();
while (train_iter.Next()) {
@@ -87,35 +99,40 @@ int main(int argc, char** argv) {
// CopyTo is imperative, need to wait for it to complete.
NDArray::WaitAll();
- // Create executor by binding parameters to the model
- auto *exec = net.SimpleBind(ctx, args);
// Compute gradients
exec->Forward(true);
exec->Backward();
+
// Update parameters
- exec->UpdateAll(opt, learning_rate, weight_decay);
- // Remember to free the memory
- delete exec;
+ for (size_t i = 0; i < arg_names.size(); ++i) {
+ if (arg_names[i] == "X" || arg_names[i] == "label") continue;
+ opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+ }
+ // Update metric
+ train_acc.Update(data_batch.label, exec->outputs[0]);
}
+ // one epoch of training is finished
auto toc = chrono::system_clock::now();
+ float duration = chrono::duration_cast<chrono::milliseconds>(toc - tic).count() / 1000.0;
+ LG << "Epoch[" << iter << "] " << samples/duration \
+ << " samples/sec " << "Train-Accuracy=" << train_acc.Get();;
- Accuracy acc;
val_iter.Reset();
+ val_acc.Reset();
while (val_iter.Next()) {
auto data_batch = val_iter.GetDataBatch();
data_batch.data.CopyTo(&args["X"]);
data_batch.label.CopyTo(&args["label"]);
NDArray::WaitAll();
- auto *exec = net.SimpleBind(ctx, args);
+
// Only forward pass is enough as no gradient is needed when evaluating
exec->Forward(false);
- acc.Update(data_batch.label, exec->outputs[0]);
- delete exec;
+ val_acc.Update(data_batch.label, exec->outputs[0]);
}
- float duration = chrono::duration_cast<chrono::milliseconds>(toc - tic).count() / 1000.0;
- LG << "Epoch: " << iter << " " << samples/duration << " samples/sec Accuracy: " << acc.Get();
+ LG << "Epoch[" << iter << "] Val-Accuracy=" << val_acc.Get();
}
+ delete exec;
MXNotifyShutdown();
return 0;
}
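mlp_gpu.cpp above (and test_score.cpp below) also exercises the new scheduler hook: the base "lr" is set on the optimizer and a FactorScheduler is handed over with SetLRScheduler. A self-contained sketch of just that wiring, using the same FactorScheduler(5000, 0.1) as the example; the batch size and "wd" value are placeholders for illustration:

    #include <memory>
    #include "mxnet-cpp/MxNetCpp.h"

    using namespace mxnet::cpp;

    int main() {
      const int batch_size = 100;   // placeholder value for illustration
      Optimizer* opt = OptimizerRegistry::Find("sgd");
      opt->SetParam("rescale_grad", 1.0 / batch_size)
         ->SetParam("lr", 0.1)      // base learning rate handed to the scheduler
         ->SetParam("wd", 1e-2);

      // Decay the learning rate by a factor of 0.1 every 5000 optimizer updates.
      std::unique_ptr<LRScheduler> lr_sch(new FactorScheduler(5000, 0.1));
      opt->SetLRScheduler(std::move(lr_sch));

      // From here on, each opt->Update(index, weight, grad) call derives its
      // effective learning rate from the scheduler via the optimizer's internal
      // update count; no lr/wd arguments are passed per call.
      MXNotifyShutdown();
      return 0;
    }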
diff --git a/cpp-package/example/resnet.cpp b/cpp-package/example/resnet.cpp
index 5521567..b9766c7 100644
--- a/cpp-package/example/resnet.cpp
+++ b/cpp-package/example/resnet.cpp
@@ -165,11 +165,14 @@ int main(int argc, char const *argv[]) {
.CreateDataIter();
Optimizer* opt = OptimizerRegistry::Find("ccsgd");
- opt->SetParam("momentum", 0.9)
+ opt->SetParam("lr", learning_rate)
+ ->SetParam("wd", weight_decay)
+ ->SetParam("momentum", 0.9)
->SetParam("rescale_grad", 1.0 / batch_size)
->SetParam("clip_gradient", 10);
auto *exec = resnet.SimpleBind(Context::gpu(), args_map);
+ auto arg_names = resnet.ListArguments();
for (int iter = 0; iter < max_epoch; ++iter) {
LG << "Epoch: " << iter;
@@ -182,7 +185,11 @@ int main(int argc, char const *argv[]) {
exec->Forward(true);
exec->Backward();
- exec->UpdateAll(opt, learning_rate, weight_decay);
+
+ for (size_t i = 0; i < arg_names.size(); ++i) {
+ if (arg_names[i] == "data" || arg_names[i] == "data_label") continue;
+ opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+ }
NDArray::WaitAll();
}
diff --git a/cpp-package/example/test_score.cpp b/cpp-package/example/test_score.cpp
index 7dccd30..3534269 100644
--- a/cpp-package/example/test_score.cpp
+++ b/cpp-package/example/test_score.cpp
@@ -72,7 +72,15 @@ int main(int argc, char** argv) {
// Create sgd optimizer
Optimizer* opt = OptimizerRegistry::Find("sgd");
- opt->SetParam("rescale_grad", 1.0/batch_size);
+ opt->SetParam("rescale_grad", 1.0/batch_size)
+ ->SetParam("lr", learning_rate)
+ ->SetParam("wd", weight_decay);
+ std::unique_ptr<LRScheduler> lr_sch(new FactorScheduler(5000, 0.1));
+ opt->SetLRScheduler(std::move(lr_sch));
+
+ // Create executor by binding parameters to the model
+ auto *exec = net.SimpleBind(ctx, args);
+ auto arg_names = net.ListArguments();
float score = 0;
// Start training
@@ -90,15 +98,14 @@ int main(int argc, char** argv) {
// CopyTo is imperative, need to wait for it to complete.
NDArray::WaitAll();
- // Create executor by binding parameters to the model
- auto *exec = net.SimpleBind(ctx, args);
// Compute gradients
exec->Forward(true);
exec->Backward();
// Update parameters
- exec->UpdateAll(opt, learning_rate, weight_decay);
- // Remember to free the memory
- delete exec;
+ for (size_t i = 0; i < arg_names.size(); ++i) {
+ if (arg_names[i] == "X" || arg_names[i] == "label") continue;
+ opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+ }
}
auto toc = chrono::system_clock::now();
@@ -109,17 +116,16 @@ int main(int argc, char** argv) {
data_batch.data.CopyTo(&args["X"]);
data_batch.label.CopyTo(&args["label"]);
NDArray::WaitAll();
- auto *exec = net.SimpleBind(ctx, args);
// Only forward pass is enough as no gradient is needed when evaluating
exec->Forward(false);
acc.Update(data_batch.label, exec->outputs[0]);
- delete exec;
}
float duration = chrono::duration_cast<chrono::milliseconds>(toc - tic).count() / 1000.0;
LG << "Epoch: " << iter << " " << samples/duration << " samples/sec Accuracy: " << acc.Get();
score = acc.Get();
}
+ delete exec;
MXNotifyShutdown();
return score >= MIN_SCORE ? 0 : 1;
}
diff --git a/cpp-package/include/mxnet-cpp/executor.h b/cpp-package/include/mxnet-cpp/executor.h
index 822344b..67eec01 100644
--- a/cpp-package/include/mxnet-cpp/executor.h
+++ b/cpp-package/include/mxnet-cpp/executor.h
@@ -79,18 +79,6 @@ class Executor {
*/
std::string DebugStr();
/*!
- * \brief update the arguments with given learning rate and optimizer
- * \param opt the pointer to the optimizer
- * \param lr learning rate
- * \param wd weight decay
- * \param arg_update_begin begin index of the arguments to be updated, it
- * starts after the input data by default
- * \param arg_update_end end index of the arguments to be updated, it ends
- * before the label data by default
- */
- void UpdateAll(Optimizer *opt, float lr, float wd, int arg_update_begin = 1,
- int arg_update_end = -1);
- /*!
* \brief destructor, free the handle
*/
~Executor() { MXExecutorFree(handle_); }
diff --git a/cpp-package/include/mxnet-cpp/executor.hpp b/cpp-package/include/mxnet-cpp/executor.hpp
index 1a452a1..6887956 100644
--- a/cpp-package/include/mxnet-cpp/executor.hpp
+++ b/cpp-package/include/mxnet-cpp/executor.hpp
@@ -79,13 +79,6 @@ inline std::string Executor::DebugStr() {
return std::string(output);
}
-inline void Executor::UpdateAll(Optimizer *opt, float lr, float wd,
- int arg_update_begin, int arg_update_end) {
- arg_update_end = arg_update_end < 0 ? arg_arrays.size() - 1 : arg_update_end;
- for (int i = arg_update_begin; i < arg_update_end; ++i) {
- opt->Update(i, arg_arrays[i], grad_arrays[i], lr, wd);
- }
-}
} // namespace cpp
} // namespace mxnet
diff --git a/cpp-package/include/mxnet-cpp/lr_scheduler.h b/cpp-package/include/mxnet-cpp/lr_scheduler.h
new file mode 100644
index 0000000..91f9b3c
--- /dev/null
+++ b/cpp-package/include/mxnet-cpp/lr_scheduler.h
@@ -0,0 +1,78 @@
+/*!
+* Copyright (c) 2017 by Contributors
+* \file lr_scheduler.h
+* \brief Scheduling learning rate
+*/
+
+#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_LR_SCHEDULER_H_
+#define CPP_PACKAGE_INCLUDE_MXNET_CPP_LR_SCHEDULER_H_
+
+#include "dmlc/logging.h"
+
+namespace mxnet {
+namespace cpp {
+
+/*!
+* \brief lr scheduler interface
+*/
+class LRScheduler {
+ public:
+ /*!
+ * \brief constructor
+ * \param base_lr the initial learning rate.
+ */
+ explicit LRScheduler(float base_lr = 0.01)
+ : base_lr_(base_lr) {}
+ /*!
+ * \brief set base lr
+ * \param lr learning rate from optimizer
+ */
+ void SetLR(const float lr) { base_lr_ = lr; }
+ /*!
+ * \brief get a new learning rate
+ */
+ virtual float GetLR(unsigned num_update) = 0;
+ /*!
+ * \brief destructor
+ */
+ virtual ~LRScheduler() {}
+
+ protected:
+ float base_lr_;
+};
+
+class FactorScheduler : public LRScheduler {
+ public:
+ explicit FactorScheduler(int step, float factor = 1, float stop_factor_lr = 1e-8)
+ : LRScheduler() {
+ step_ = step;
+ factor_ = factor;
+ stop_factor_lr_ = stop_factor_lr;
+ }
+
+ float GetLR(unsigned num_update) override {
+ while (num_update > unsigned(count_ + step_)) {
+ count_ += step_;
+ base_lr_ *= factor_;
+ if (base_lr_ < stop_factor_lr_) {
+ base_lr_ = stop_factor_lr_;
+ LG << "Update[" << num_update << "]: now learning rate arrived at " \
+ << base_lr_ << ", will not change in the future";
+ } else {
+ LG << "Update[" << num_update << "]: Change learning rate to " << base_lr_;
+ }
+ }
+ return base_lr_;
+ }
+
+ private:
+ int count_ = 0;
+ int step_;
+ float factor_;
+ float stop_factor_lr_;
+};
+
+} // namespace cpp
+} // namespace mxnet
+
+#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_LR_SCHEDULER_H_
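FactorScheduler::GetLR is stateful: each time num_update crosses another step boundary it advances count_ and multiplies base_lr_ by factor, flooring at stop_factor_lr. A small standalone sketch of the schedule this produces (the printed values simply follow from the class above; MxNetCpp.h is assumed to pull in lr_scheduler.h through optimizer.h):

    #include <iostream>
    #include "mxnet-cpp/MxNetCpp.h"

    using namespace mxnet::cpp;

    int main() {
      FactorScheduler sched(5000, 0.1);   // step = 5000 updates, decay factor = 0.1
      sched.SetLR(0.1f);                  // base learning rate

      // lr stays at 0.1 until the update count passes the step size,
      // then becomes 0.01 after 5000 updates, 0.001 after 10000, and so on.
      for (unsigned u : {1u, 5001u, 10001u}) {
        std::cout << "update " << u << ": lr = " << sched.GetLR(u) << std::endl;
      }
      return 0;
    }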
diff --git a/cpp-package/include/mxnet-cpp/optimizer.h b/cpp-package/include/mxnet-cpp/optimizer.h
index 76f8a35..1bc36d5 100644
--- a/cpp-package/include/mxnet-cpp/optimizer.h
+++ b/cpp-package/include/mxnet-cpp/optimizer.h
@@ -17,6 +17,7 @@
#include "dmlc/logging.h"
#include "mxnet-cpp/ndarray.h"
#include "mxnet-cpp/op_map.h"
+#include "mxnet-cpp/lr_scheduler.h"
namespace mxnet {
namespace cpp {
@@ -57,15 +58,16 @@ class Optimizer {
return this;
}
/*!
- * \brief Update a weight with gradient.
- * \param index the unique index for the weight.
- * \param weight the weight to update.
- * \param grad gradient for the weight.
- * \param lr learning rate.
- * \param wd weight decay.
+ * \bried set the lr scheduler
+ * \param lrScheduler lr scheduler used for this optimizer
+ * \return reference if self
*/
- void Update(int index, NDArray weight, NDArray grad, mx_float lr,
- mx_float wd);
+ Optimizer *SetLRScheduler(std::unique_ptr<LRScheduler> lrScheduler) {
+ CHECK(lrScheduler);
+ lrScheduler_ = std::move(lrScheduler);
+ lrScheduler_->SetLR(std::stof(params_["lr"]));
+ return this;
+ }
/*!
* \brief Update a weight with gradient.
* \param index the unique index for the weight.
@@ -92,7 +94,10 @@ class Optimizer {
std::map<int, unsigned> count_;
unsigned begin_num_update_, num_update_;
unsigned UpdateCount_(int index);
+ float GetLR_(int index);
+ float GetWD_(int index);
virtual void CreateState_(int index, NDArray weight);
+ std::unique_ptr<LRScheduler> lrScheduler_ = nullptr;
};
typedef std::function<Optimizer*()> OptimizerCreator;
@@ -172,7 +177,6 @@ class AdaDeltaOptimizer : public Optimizer {
std::map<int, NDArray*> acc_g_, acc_delta_;
};
-
} // namespace cpp
} // namespace mxnet
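One detail worth noting from SetLRScheduler above: it seeds the scheduler's base rate from the optimizer's current "lr" parameter. A short sketch of the intended ordering (the concrete step and factor values are illustrative only):

    // SetLRScheduler reads params_["lr"] when the scheduler is attached, so set
    // "lr" first; otherwise the default of 0.01 from the Optimizer constructor
    // becomes the scheduler's base learning rate.
    Optimizer* opt = OptimizerRegistry::Find("sgd");
    opt->SetParam("lr", 0.05)    // becomes the scheduler's base learning rate
       ->SetParam("wd", 1e-4);
    opt->SetLRScheduler(std::unique_ptr<LRScheduler>(new FactorScheduler(2000, 0.5)));
    // Once a scheduler is attached, GetLR_ consults it on every Update, so a later
    // SetParam("lr", ...) no longer changes the effective learning rate.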
diff --git a/cpp-package/include/mxnet-cpp/optimizer.hpp b/cpp-package/include/mxnet-cpp/optimizer.hpp
index 9dcb158..0d6a7be 100644
--- a/cpp-package/include/mxnet-cpp/optimizer.hpp
+++ b/cpp-package/include/mxnet-cpp/optimizer.hpp
@@ -42,6 +42,8 @@ namespace cpp {
inline Optimizer::Optimizer(unsigned begin_num_update)
: begin_num_update_(begin_num_update),
num_update_(begin_num_update_) {
+ params_["lr"] = "0.01f";
+ params_["wd"] = "0.f";
}
inline std::map<std::string, OptimizerCreator>& OptimizerRegistry::cmap() {
@@ -56,14 +58,6 @@ inline OpMap*& Optimizer::op_map() {
inline Optimizer::~Optimizer() {}
-inline void Optimizer::Update(int index, NDArray weight, NDArray grad, mx_float lr,
- mx_float wd) {
- params_["lr"] = std::to_string(lr);
- params_["wd"] = std::to_string(wd);
- UpdateCount_(index);
- Update(index, weight, grad);
-}
-
inline void Optimizer::CreateState_(int index, NDArray weight) {
}
@@ -100,6 +94,18 @@ inline unsigned Optimizer::UpdateCount_(int index) {
return new_count;
}
+inline float Optimizer::GetLR_(int index) {
+ if (nullptr != lrScheduler_) {
+ return lrScheduler_->GetLR(num_update_);
+ }
+ return std::stof(params_["lr"]);
+}
+
+inline float Optimizer::GetWD_(int index) {
+ float wd = std::stof(params_["wd"]);
+ return wd;
+}
+
inline Optimizer* OptimizerRegistry::Find(const std::string& name) {
MXNETCPP_REGISTER_OPTIMIZER(sgd, SGDOptimizer);
MXNETCPP_REGISTER_OPTIMIZER(ccsgd, SGDOptimizer); // For backward compatibility
@@ -140,6 +146,9 @@ inline void SGDOptimizer::Update(int index, NDArray weight, NDArray grad) {
CreateState_(index, weight);
}
+ params_["lr"] = std::to_string(GetLR_(index));
+ params_["wd"] = std::to_string(GetWD_(index));
+ UpdateCount_(index);
auto keys = GetParamKeys_();
auto values = GetParamValues_();
CHECK_EQ(keys.size(), values.size());
@@ -203,6 +212,9 @@ inline void RMSPropOptimizer::Update(int index, NDArray weight, NDArray grad) {
CreateState_(index, weight);
}
+ params_["lr"] = std::to_string(GetLR_(index));
+ params_["wd"] = std::to_string(GetWD_(index));
+ UpdateCount_(index);
auto keys = GetParamKeys_();
auto values = GetParamValues_();
CHECK_EQ(keys.size(), values.size());
@@ -257,6 +269,10 @@ inline void AdamOptimizer::Update(int index, NDArray weight, NDArray grad) {
if (mean_.count(index) == 0) {
CreateState_(index, weight);
}
+
+ params_["lr"] = std::to_string(GetLR_(index));
+ params_["wd"] = std::to_string(GetWD_(index));
+ UpdateCount_(index);
auto keys = GetParamKeys_();
auto values = GetParamValues_();
CHECK_EQ(keys.size(), values.size());
@@ -306,9 +322,11 @@ inline void AdaGradOptimizer::Update(int index, NDArray weight, NDArray grad) {
if (history_.count(index) == 0) {
CreateState_(index, weight);
}
- float lr = std::stof(params_["lr"]);
- float wd = std::stof(params_["wd"]);
+
float eps = std::stof(params_["eps"]);
+ float lr = GetLR_(index);
+ float wd = GetWD_(index);
+ UpdateCount_(index);
if (params_.count("rescale_grad") > 0) {
grad *= std::stof(params_["rescale_grad"]);
}
@@ -345,9 +363,11 @@ inline void AdaDeltaOptimizer::Update(int index, NDArray weight, NDArray grad) {
if (acc_g_.count(index) == 0) {
CreateState_(index, weight);
}
- float wd = std::stof(params_["wd"]);
+
float rho = std::stof(params_["rho"]);
float epsilon = std::stof(params_["epsilon"]);
+ float wd = GetWD_(index);
+ UpdateCount_(index);
if (params_.count("rescale_grad") > 0) {
grad *= std::stof(params_["rescale_grad"]);
--