Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/12 19:20:09 UTC
[incubator-mxnet] branch master updated: Signum optimizer (#9220)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 5251b86 Signum optimizer (#9220)
5251b86 is described below
commit 5251b861693402a3e6394da990a5af183b6f7247
Author: Yu-Xiang Wang <lu...@gmail.com>
AuthorDate: Fri Jan 12 11:20:05 2018 -0800
Signum optimizer (#9220)
* the c++ version of signum and signsgd optimizer
* optimizer signum, tested working with mac on cpu using mnist
* unit test for signum
* fix lint and incorporate haibin's code review
* rerun jenkins
* adding link to the Loshchilov and Hutter paper to the documentation
---
cpp-package/include/mxnet-cpp/optimizer.h | 14 +++
cpp-package/include/mxnet-cpp/optimizer.hpp | 64 +++++++++++++
python/mxnet/optimizer.py | 67 ++++++++++++-
src/operator/optimizer_op-inl.h | 142 ++++++++++++++++++++++++++++
src/operator/optimizer_op.cc | 61 ++++++++++++
src/operator/optimizer_op.cu | 7 ++
tests/python/unittest/test_optimizer.py | 81 ++++++++++++++++
7 files changed, 433 insertions(+), 3 deletions(-)
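For a quick sense of the new Python API added by this commit, a minimal usage sketch (illustrative only, assuming an MXNet build that includes this change; the shapes and hyper-parameters are arbitrary):

    import mxnet as mx

    # 'Signum' is registered as an optimizer class, so it can be constructed directly
    opt = mx.optimizer.Signum(learning_rate=0.01, momentum=0.9, wd_lh=1e-4)
    updater = mx.optimizer.get_updater(opt)

    weight = mx.nd.ones((3, 4))
    grad = mx.nd.ones((3, 4)) * 0.1
    updater(0, grad, weight)  # updates `weight` in place using the Signum rule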
diff --git a/cpp-package/include/mxnet-cpp/optimizer.h b/cpp-package/include/mxnet-cpp/optimizer.h
index f3763bb..320b13e 100644
--- a/cpp-package/include/mxnet-cpp/optimizer.h
+++ b/cpp-package/include/mxnet-cpp/optimizer.h
@@ -146,6 +146,20 @@ class SGDOptimizer : public Optimizer {
AtomicSymbolCreator mom_update_handle_;
};
+class SignumOptimizer : public Optimizer {
+ public:
+ explicit SignumOptimizer(unsigned begin_num_update = 0);
+ std::string GetType() const override;
+ void Update(int index, NDArray weight, NDArray grad) override;
+ private:
+ virtual ~SignumOptimizer();
+ void CreateState_(int index, NDArray weight) override;
+ std::map<int, NDArray*> states_;
+ AtomicSymbolCreator update_handle_;
+ AtomicSymbolCreator mom_update_handle_;
+};
+
+
class RMSPropOptimizer : public Optimizer {
public:
explicit RMSPropOptimizer(unsigned begin_num_update = 0);
diff --git a/cpp-package/include/mxnet-cpp/optimizer.hpp b/cpp-package/include/mxnet-cpp/optimizer.hpp
index cb8442d..e3d47d1 100644
--- a/cpp-package/include/mxnet-cpp/optimizer.hpp
+++ b/cpp-package/include/mxnet-cpp/optimizer.hpp
@@ -131,6 +131,7 @@ inline Optimizer* OptimizerRegistry::Find(const std::string& name) {
MXNETCPP_REGISTER_OPTIMIZER(adam, AdamOptimizer);
MXNETCPP_REGISTER_OPTIMIZER(adagrad, AdaGradOptimizer);
MXNETCPP_REGISTER_OPTIMIZER(adadelta, AdaDeltaOptimizer);
+ MXNETCPP_REGISTER_OPTIMIZER(signum, SignumOptimizer);
auto it = cmap().find(name);
if (it == cmap().end())
return nullptr;
@@ -200,6 +201,69 @@ inline void SGDOptimizer::CreateState_(int index, NDArray weight) {
}
}
+// implementing Signum optimizer
+
+inline SignumOptimizer::SignumOptimizer(unsigned begin_num_update)
+ : Optimizer(begin_num_update) {
+ update_handle_ = op_map()->GetSymbolCreator("signsgd_update");
+ mom_update_handle_ = op_map()->GetSymbolCreator("signum_update");
+}
+
+inline std::string SignumOptimizer::GetType() const {
+ return "signum";
+}
+
+inline SignumOptimizer::~SignumOptimizer() {
+ for (auto &it : states_) {
+ delete it.second;
+ }
+}
+
+inline void SignumOptimizer::Update(int index, NDArray weight, NDArray grad) {
+ if (states_.count(index) == 0) {
+ CreateState_(index, weight);
+ }
+
+ params_["lr"] = std::to_string(GetLR_(index));
+ params_["wd"] = std::to_string(GetWD_(index));
+ UpdateCount_(index);
+ auto keys = GetParamKeys_();
+ auto values = GetParamValues_();
+ CHECK_EQ(keys.size(), values.size());
+
+ NDArrayHandle inputs[3];
+ inputs[0] = weight.GetHandle();
+ inputs[1] = grad.GetHandle();
+
+ int num_outputs = 1;
+ NDArrayHandle output = weight.GetHandle();
+ NDArrayHandle *outputs = &output;
+
+ if (states_[index] == nullptr) {
+ MXImperativeInvoke(update_handle_, 2, inputs,
+ &num_outputs, &outputs,
+ keys.size(), keys.data(), values.data());
+ } else {
+ inputs[2] = states_[index]->GetHandle();
+ MXImperativeInvoke(mom_update_handle_, 3, inputs,
+ &num_outputs, &outputs,
+ keys.size(), keys.data(), values.data());
+ }
+}
+
+inline void SignumOptimizer::CreateState_(int index, NDArray weight) {
+ if (params_.count("momentum") == 0) {
+ states_[index] = nullptr;
+ } else {
+ states_[index] = new NDArray(weight.GetShape(), weight.GetContext());
+ *states_[index] = 0;
+ }
+}
+
+// finish implementing Signum
+
+
+
inline RMSPropOptimizer::RMSPropOptimizer(unsigned begin_num_update)
: Optimizer(begin_num_update) {
update_handle_ = op_map()->GetSymbolCreator("rmsprop_update");
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index feff87e..4285aec 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -25,7 +25,8 @@ import numpy
from .base import py_str
from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs)
from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
- mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update)
+ mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
+ signsgd_update, signum_update)
from .ndarray import _internal
from .ndarray import op
from .ndarray import sparse
@@ -534,6 +535,67 @@ class SGD(Optimizer):
self._update_impl(index, weight, grad, state,
multi_precision=use_multi_precision)
+@register
+class Signum(Optimizer):
+ """The Signum optimizer that takes the sign of gradient or momentum.
+
+ The optimizer updates the weight by:
+
+ rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
+ state = momentum * state + (1-momentum)*rescaled_grad
+ weight = (1 - lr * wd_lh) * weight - lr * sign(state)
+
+ See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf
+
+ For details of the update algorithm see
+ :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.
+
+ This optimizer accepts the following parameters in addition to those accepted
+ by :class:`.Optimizer`.
+
+ Parameters
+ ----------
+ momentum : float, optional
+ The momentum value.
+ wd_lh : float, optional
+ The amount of decoupled weight decay regularization, see details in the original paper at:\
+ https://arxiv.org/abs/1711.05101
+ """
+ def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs):
+ super(Signum, self).__init__(learning_rate=learning_rate, **kwargs)
+ self.momentum = momentum
+ self.wd_lh = wd_lh
+
+ def create_state(self, index, weight):
+ momentum = None
+ if self.momentum != 0.0:
+ momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+ return momentum
+
+ def _update_impl(self, index, weight, grad, state):
+ assert(isinstance(weight, NDArray))
+ assert(isinstance(grad, NDArray))
+ self._update_count(index)
+ lr = self._get_lr(index)
+ wd = self._get_wd(index)
+
+ kwargs = {'rescale_grad': self.rescale_grad}
+ if self.momentum > 0:
+ kwargs['momentum'] = self.momentum
+ if self.clip_gradient:
+ kwargs['clip_gradient'] = self.clip_gradient
+ if self.wd_lh:
+ kwargs['wd_lh'] = self.wd_lh
+
+ if state is not None:
+ signum_update(weight, grad, state, out=weight,
+ lr=lr, wd=wd, **kwargs)
+ else:
+ signsgd_update(weight, grad, out=weight,
+ lr=lr, wd=wd, **kwargs)
+
+ def update(self, index, weight, grad, state):
+ self._update_impl(index, weight, grad, state)
@register
class FTML(Optimizer):
@@ -702,8 +764,7 @@ class SGLD(Optimizer):
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr),
- shape=weight.shape,
- ctx=weight.context)
+ weight.shape, weight.context)
@register # pylint: disable=invalid-name
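The Signum class above dispatches to the new `signsgd_update` and `signum_update` operators; a minimal sketch of calling those operators directly (illustrative only, assuming a build with this commit, arbitrary shapes and values):

    import mxnet as mx

    weight = mx.nd.uniform(shape=(3, 4))
    grad = mx.nd.uniform(shape=(3, 4))
    mom = mx.nd.zeros((3, 4))

    # momentum-free variant: weight <- weight - lr * sign(grad)
    mx.nd.signsgd_update(weight, grad, out=weight, lr=0.01, wd=0.0)

    # Signum variant: updates `mom` in place, then steps along sign(mom)
    mx.nd.signum_update(weight, grad, mom, out=weight,
                        lr=0.01, momentum=0.9, wd=0.0, wd_lh=1e-4)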
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 33b7dd5..c2564db 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -66,6 +66,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
}
};
+
struct SGDKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
@@ -228,6 +229,7 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
}
};
+
struct SGDMomKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
@@ -1281,6 +1283,146 @@ inline void FtrlUpdateEx(const nnvm::NodeAttrs& attrs,
}
}
+
+// Implementation for signSGD and Signum
+
+struct SignSGDParam : public dmlc::Parameter<SignSGDParam> {
+ float lr;
+ float wd;
+ float rescale_grad;
+ float clip_gradient;
+ DMLC_DECLARE_PARAMETER(SignSGDParam) {
+ DMLC_DECLARE_FIELD(lr)
+ .describe("Learning rate");
+ DMLC_DECLARE_FIELD(wd)
+ .set_default(0.0f)
+ .describe("Weight decay augments the objective function with a "
+ "regularization term that penalizes large weights. "
+ "The penalty scales with the square of the magnitude of each weight.");
+ DMLC_DECLARE_FIELD(rescale_grad)
+ .set_default(1.0f)
+ .describe("Rescale gradient to grad = rescale_grad*grad.");
+ DMLC_DECLARE_FIELD(clip_gradient)
+ .set_default(-1.0f)
+ .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+ "If clip_gradient <= 0, gradient clipping is turned off. "
+ "grad = max(min(grad, clip_gradient), -clip_gradient).");
+ }
+};
+
+
+struct SignSGDKernel {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
+ const DType* grad_data, const DType param_clip_gradient,
+ const DType param_lr, const DType param_wd, const DType param_rescale_grad,
+ const OpReqType req) {
+
+ // param_clip_gradient has no effect for SignSGD
+ KERNEL_ASSIGN(out_data[i], req,
+ (1.f-param_lr*param_wd)*weight_data[i]
+ - (param_lr)*((grad_data[i] > 0) - (grad_data[i] < 0)));
+ }
+};
+
+template<typename xpu>
+inline void SignSGDUpdate(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<TBlob> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &outputs) {
+ using namespace mxnet_op;
+ const SignSGDParam& param = nnvm::get<SignSGDParam>(attrs.parsed);
+ Stream<xpu>* s = ctx.get_stream<xpu>();
+ MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+ Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+ Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+ Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+ Kernel<SignSGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
+ grad.dptr_, static_cast<DType>(param.clip_gradient),
+ static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+ static_cast<DType>(param.rescale_grad), req[0]);
+ });
+}
+
+
+struct SignumParam : public dmlc::Parameter<SignumParam> {
+ float lr;
+ float momentum;
+ float wd;
+ float rescale_grad;
+ float clip_gradient;
+ float wd_lh; // the amount of decoupled (algorithmic) weight decay, as in Loshchilov and Hutter
+ DMLC_DECLARE_PARAMETER(SignumParam) {
+ DMLC_DECLARE_FIELD(lr)
+ .describe("Learning rate");
+ DMLC_DECLARE_FIELD(momentum)
+ .set_default(0.0f)
+ .describe("The decay rate of momentum estimates at each epoch.");
+ DMLC_DECLARE_FIELD(wd)
+ .set_default(0.0f)
+ .describe("Weight decay augments the objective function with a "
+ "regularization term that penalizes large weights. "
+ "The penalty scales with the square of the magnitude of each weight.");
+ DMLC_DECLARE_FIELD(rescale_grad)
+ .set_default(1.0f)
+ .describe("Rescale gradient to grad = rescale_grad*grad.");
+ DMLC_DECLARE_FIELD(clip_gradient)
+ .set_default(-1.0f)
+ .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+ "If clip_gradient <= 0, gradient clipping is turned off. "
+ "grad = max(min(grad, clip_gradient), -clip_gradient).");
+ DMLC_DECLARE_FIELD(wd_lh)
+ .set_default(0.0f)
+ .describe("The amount of weight decay that does not go into the gradient/momentum calculations "
+ "and is instead applied directly to the weights (decoupled weight decay).");
+ }
+};
+
+struct SignumKernel {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
+ const DType* grad_data, const DType param_clip_gradient, const DType param_momentum,
+ const DType param_lr, const DType param_wd, const DType param_rescale_grad,
+ const DType param_wd_lh, const OpReqType req) {
+ if (param_clip_gradient >= 0.0f) {
+ mom_data[i] = param_momentum*mom_data[i]
+ - (1-param_momentum)*param_wd*weight_data[i]
+ - (1-param_momentum)
+ *mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient);
+ } else {
+ mom_data[i] = param_momentum*mom_data[i]
+ - (1-param_momentum)*param_wd*weight_data[i]
+ - (1-param_momentum)*param_rescale_grad*grad_data[i];
+ }
+ KERNEL_ASSIGN(out_data[i], req, (1.f-param_lr*param_wd_lh)*weight_data[i]
+ + (param_lr)*((mom_data[i] > 0) - (mom_data[i] < 0)));
+ }
+};
+
+template<typename xpu>
+inline void SignumUpdate(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<TBlob> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &outputs) {
+ using namespace mxnet_op;
+ SignumParam param = nnvm::get<SignumParam>(attrs.parsed);
+ Stream<xpu>* s = ctx.get_stream<xpu>();
+ MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+ Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+ Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+ Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
+ Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+ Kernel<SignumKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_,
+ grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
+ static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+ static_cast<DType>(param.rescale_grad), static_cast<DType>(param.wd_lh), req[0]);
+ });
+}
+
+
+
} // namespace op
} // namespace mxnet
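For reference, the per-element arithmetic of SignumKernel written out as a NumPy sketch (the same math, not the actual CPU/GPU kernel; note the kernel accumulates the negated rescaled gradient into `mom`, so the final step adds lr * sign(mom), which is equivalent to subtracting lr * sign(state) in the docstring's convention):

    import numpy as np

    def signum_step(weight, grad, mom, lr, momentum, wd,
                    rescale_grad=1.0, clip_gradient=-1.0, wd_lh=0.0):
        g = rescale_grad * grad
        if clip_gradient >= 0:
            g = np.clip(g, -clip_gradient, clip_gradient)
        # momentum buffer: smoothed, negated (gradient + wd * weight)
        mom[:] = momentum * mom - (1 - momentum) * (wd * weight + g)
        # decoupled weight decay (wd_lh), then a step of size lr along sign(mom)
        return (1 - lr * wd_lh) * weight + lr * np.sign(mom)

    w = np.random.randn(5).astype(np.float32)
    g = np.random.randn(5).astype(np.float32)
    m = np.zeros_like(w)
    w = signum_step(w, g, m, lr=0.01, momentum=0.9, wd=1e-4)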
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index dda8092..8760fe9 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -36,6 +36,67 @@ DMLC_REGISTER_PARAMETER(AdamParam);
DMLC_REGISTER_PARAMETER(RMSPropParam);
DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
DMLC_REGISTER_PARAMETER(FtrlParam);
+DMLC_REGISTER_PARAMETER(SignSGDParam);
+DMLC_REGISTER_PARAMETER(SignumParam);
+
+NNVM_REGISTER_OP(signsgd_update)
+.describe(R"code(Update function for SignSGD optimizer.
+
+.. math::
+
+ g_t = \nabla J(W_{t-1})\\
+ W_t = W_{t-1} - \eta_t \text{sign}(g_t)
+
+It updates the weights using::
+
+ weight = weight - learning_rate * sign(gradient)
+
+.. note::
+ - sparse ndarray not supported for this optimizer yet.
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SignSGDParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<FCompute>("FCompute<cpu>", SignSGDUpdate<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_arguments(SignSGDParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(signum_update)
+.describe(R"code(SIGN momentUM (Signum) optimizer.
+
+.. math::
+
+ g_t = \nabla J(W_{t-1})\\
+ m_t = \beta m_{t-1} + (1 - \beta) g_t\\
+ W_t = W_{t-1} - \eta_t \text{sign}(m_t)
+
+It updates the weights using::
+
+ state = momentum * state + (1-momentum) * gradient
+ weight = weight - learning_rate * sign(state)
+
+Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
+
+.. note::
+ - sparse ndarray not supported for this optimizer yet.
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SignumParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+ [](const nnvm::NodeAttrs& attrs) {
+ return std::vector<uint32_t>{2};
+ })
+.set_attr<FCompute>("FCompute<cpu>", SignumUpdate<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("mom", "NDArray-or-Symbol", "Momentum")
+.add_arguments(SignumParam::__FIELDS__());
+
template<>
void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index 9512e92..891f24f 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -94,6 +94,13 @@ void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& param,
});
}
+
+NNVM_REGISTER_OP(signsgd_update)
+.set_attr<FCompute>("FCompute<gpu>", SignSGDUpdate<gpu>);
+
+NNVM_REGISTER_OP(signum_update)
+.set_attr<FCompute>("FCompute<gpu>", SignumUpdate<gpu>);
+
NNVM_REGISTER_OP(sgd_update)
.set_attr<FCompute>("FCompute<gpu>", SGDUpdate<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", SGDUpdateEx<gpu>);
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index ae248b0..2d22391 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -524,6 +524,87 @@ def test_adam():
compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
dtype, w_stype='row_sparse', g_stype='row_sparse')
+
+# Signum
+class PySignum(mx.optimizer.Optimizer):
+ """The python reference of Signum optimizer.
+
+ The optimizer updates the weight by:
+
+ rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
+ state = momentum * state + (1-momentum)*rescaled_grad
+ weight = (1 - lr * wd_lh) * weight - lr * sign(state)
+
+ See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf
+
+ For details of the update algorithm see
+ :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.
+
+ This optimizer accepts the following parameters in addition to those accepted
+ by :class:`.Optimizer`.
+
+ Parameters
+ ----------
+ momentum : float, optional
+ The momentum value.
+ wd_lh : float, optional
+ The amount of decoupled weight decay regularization.
+ """
+ def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs):
+ super(PySignum, self).__init__(learning_rate=learning_rate, **kwargs)
+ self.momentum = momentum
+ self.wd_lh = wd_lh
+
+ def create_state(self, index, weight):
+ momentum = None
+ if self.momentum != 0.0:
+ momentum = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+ return momentum
+
+ def update(self, index, weight, grad, state):
+ self._update_count(index)
+ lr = self._get_lr(index)
+ wd = self._get_wd(index)
+
+ if state is not None:
+ mom = state
+ if self.clip_gradient is not None:
+ mom[:] = (self.momentum*mom - (1-self.momentum)*(wd*weight +
+ mx.nd.clip(grad*self.rescale_grad, -self.clip_gradient, self.clip_gradient)))
+ else:
+ mom[:] = self.momentum*mom - (1-self.momentum)*wd*weight - (1-self.momentum)*self.rescale_grad*grad
+ weight[:] = (1 - lr*self.wd_lh)*weight + lr*mx.nd.sign(mom)
+ else:
+ weight[:] = (1 - lr*(wd+self.wd_lh))*weight - lr*mx.nd.sign(grad)
+
+def test_signum():
+ mx.random.seed(0)
+ opt1 = PySignum
+ opt2 = mx.optimizer.Signum
+ shape = (3, 4, 5)
+ cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+ rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+ wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+ wd_lh_options = [{}, {'wd_lh': 0.015}, {'wd_lh': 0.0}]
+ mom_options = [{}, {'momentum': 0.9}]
+ lr_options = [{'learning_rate': 0.05},{'learning_rate': 0.01}]
+ for dtype in [np.float32, np.float64]:
+ for cg_option in cg_options:
+ for rg_option in rg_options:
+ for wd_option in wd_options:
+ for mp_option in wd_lh_options:
+ for lr_option in lr_options:
+ for mom_option in mom_options:
+ kwarg = {}
+ kwarg.update(cg_option)
+ kwarg.update(rg_option)
+ kwarg.update(wd_option)
+ kwarg.update(mp_option)
+ kwarg.update(lr_option)
+ kwarg.update(mom_option)
+ compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
+
+
# RMSProp
class PyRMSProp(mx.optimizer.Optimizer):
"""RMSProp optimizer of Tieleman & Hinton, 2012,
--