You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/12 19:20:10 UTC

[GitHub] piiswrong closed pull request #9220: Signum optimizer

piiswrong closed pull request #9220: Signum optimizer
URL: https://github.com/apache/incubator-mxnet/pull/9220
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/cpp-package/include/mxnet-cpp/optimizer.h b/cpp-package/include/mxnet-cpp/optimizer.h
index f3763bbd6e..320b13eebf 100644
--- a/cpp-package/include/mxnet-cpp/optimizer.h
+++ b/cpp-package/include/mxnet-cpp/optimizer.h
@@ -146,6 +146,20 @@ class SGDOptimizer : public Optimizer {
   AtomicSymbolCreator mom_update_handle_;
 };
 
+class SignumOptimizer : public Optimizer {
+ public:
+  explicit SignumOptimizer(unsigned begin_num_update = 0);
+  std::string GetType() const override;
+  void Update(int index, NDArray weight, NDArray grad) override;
+ private:
+  virtual ~SignumOptimizer();
+  void CreateState_(int index, NDArray weight) override;
+  std::map<int, NDArray*> states_;
+  AtomicSymbolCreator update_handle_;
+  AtomicSymbolCreator mom_update_handle_;
+};
+
+
 class RMSPropOptimizer : public Optimizer {
  public:
   explicit RMSPropOptimizer(unsigned begin_num_update = 0);
diff --git a/cpp-package/include/mxnet-cpp/optimizer.hpp b/cpp-package/include/mxnet-cpp/optimizer.hpp
index f9c885fc1f..4b9e69bcca 100644
--- a/cpp-package/include/mxnet-cpp/optimizer.hpp
+++ b/cpp-package/include/mxnet-cpp/optimizer.hpp
@@ -131,6 +131,7 @@ inline Optimizer* OptimizerRegistry::Find(const std::string& name) {
   MXNETCPP_REGISTER_OPTIMIZER(adam, AdamOptimizer);
   MXNETCPP_REGISTER_OPTIMIZER(adagrad, AdaGradOptimizer);
   MXNETCPP_REGISTER_OPTIMIZER(adadelta, AdaDeltaOptimizer);
+  MXNETCPP_REGISTER_OPTIMIZER(signum, SignumOptimizer);
   auto it = cmap().find(name);
   if (it == cmap().end())
     return nullptr;
@@ -200,6 +201,69 @@ inline void SGDOptimizer::CreateState_(int index, NDArray weight) {
   }
 }
 
+// implementing Signum optimizer
+
+inline SignumOptimizer::SignumOptimizer(unsigned begin_num_update)
+  : Optimizer(begin_num_update) {
+  update_handle_ = op_map()->GetSymbolCreator("signsgd_update");
+  mom_update_handle_ = op_map()->GetSymbolCreator("signum_update");
+}
+
+inline std::string SignumOptimizer::GetType() const {
+  return "signum";
+}
+
+inline SignumOptimizer::~SignumOptimizer() {
+  for (auto &it : states_) {
+    delete it.second;
+  }
+}
+
+inline void SignumOptimizer::Update(int index, NDArray weight, NDArray grad) {
+  if (states_.count(index) == 0) {
+    CreateState_(index, weight);
+  }
+
+  params_["lr"] = std::to_string(GetLR_(index));
+  params_["wd"] = std::to_string(GetWD_(index));
+  UpdateCount_(index);
+  auto keys = GetParamKeys_();
+  auto values = GetParamValues_();
+  CHECK_EQ(keys.size(), values.size());
+
+  NDArrayHandle inputs[3];
+  inputs[0] = weight.GetHandle();
+  inputs[1] = grad.GetHandle();
+
+  int num_outputs = 1;
+  NDArrayHandle output = weight.GetHandle();
+  NDArrayHandle *outputs = &output;
+
+  if (states_[index] == nullptr) {
+    MXImperativeInvoke(update_handle_, 2, inputs,
+        &num_outputs, &outputs,
+        keys.size(), keys.data(), values.data());
+  } else {
+    inputs[2] = states_[index]->GetHandle();
+    MXImperativeInvoke(mom_update_handle_, 3, inputs,
+        &num_outputs, &outputs,
+        keys.size(), keys.data(), values.data());
+  }
+}
+
+inline void SignumOptimizer::CreateState_(int index, NDArray weight) {
+  if (params_.count("momentum") == 0) {
+    states_[index] = nullptr;
+  } else {
+    states_[index] = new NDArray(weight.GetShape(), weight.GetContext());
+    *states_[index] = 0;
+  }
+}
+
+// finish implementing Signum
+
+
+
 inline RMSPropOptimizer::RMSPropOptimizer(unsigned begin_num_update)
   : Optimizer(begin_num_update) {
   update_handle_ = op_map()->GetSymbolCreator("rmsprop_update");
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index feff87e0ba..4285aecef1 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -25,7 +25,8 @@
 from .base import py_str
 from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs)
 from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
-                      mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update)
+                      mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
+                      signsgd_update, signum_update)
 from .ndarray import _internal
 from .ndarray import op
 from .ndarray import sparse
@@ -534,6 +535,67 @@ def update_multi_precision(self, index, weight, grad, state):
         self._update_impl(index, weight, grad, state,
                           multi_precision=use_multi_precision)
 
+@register
+class Signum(Optimizer):
+    """The Signum optimizer that takes the sign of gradient or momentum.
+
+    The optimizer updates the weight by:
+
+        rescaled_grad = clip(rescale_grad * grad, clip_gradient) + wd * weight
+        state = momentum * state + (1-momentum)*rescaled_grad
+        weight = (1 - lr * wd_lh) * weight - lr * sign(state)
+
+    See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf
+
+    For details of the update algorithm see
+    :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    momentum : float, optional
+       The momentum value.
+    wd_lh : float, optional
+       The amount of decoupled weight decay regularization, see details in the original paper at:\
+       https://arxiv.org/abs/1711.05101
+    """
+    def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs):
+        super(Signum, self).__init__(learning_rate=learning_rate, **kwargs)
+        self.momentum = momentum
+        self.wd_lh = wd_lh
+
+    def create_state(self, index, weight):
+        momentum = None
+        if self.momentum != 0.0:
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+        return momentum
+
+    def _update_impl(self, index, weight, grad, state):
+        assert(isinstance(weight, NDArray))
+        assert(isinstance(grad, NDArray))
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+
+        kwargs = {'rescale_grad': self.rescale_grad}
+        if self.momentum > 0:
+            kwargs['momentum'] = self.momentum
+        if self.clip_gradient:
+            kwargs['clip_gradient'] = self.clip_gradient
+        if self.wd_lh:
+            kwargs['wd_lh'] = self.wd_lh
+
+        if state is not None:
+            signum_update(weight, grad, state, out=weight,
+                          lr=lr, wd=wd, **kwargs)
+        else:
+            signsgd_update(weight, grad, out=weight,
+                           lr=lr, wd=wd, **kwargs)
+
+    def update(self, index, weight, grad, state):
+        self._update_impl(index, weight, grad, state)
 
 @register
 class FTML(Optimizer):
@@ -702,8 +764,7 @@ def update(self, index, weight, grad, state):
         if self.clip_gradient is not None:
             grad = clip(grad, -self.clip_gradient, self.clip_gradient)
         weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr),
-                                                            shape=weight.shape,
-                                                            ctx=weight.context)
+                                                            weight.shape, weight.context)
 
 
 @register  # pylint: disable=invalid-name
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 33b7dd5fe5..c2564db0f0 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -66,6 +66,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
   }
 };
 
+
 struct SGDKernel {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
@@ -228,6 +229,7 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
   }
 };
 
+
 struct SGDMomKernel {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
@@ -1281,6 +1283,146 @@ inline void FtrlUpdateEx(const nnvm::NodeAttrs& attrs,
   }
 }
 
+
+// Implementation for signSGD and Signum
+
+struct SignSGDParam : public dmlc::Parameter<SignSGDParam> {
+  float lr;
+  float wd;
+  float rescale_grad;
+  float clip_gradient;
+  DMLC_DECLARE_PARAMETER(SignSGDParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(wd)
+    .set_default(0.0f)
+    .describe("Weight decay augments the objective function with a "
+              "regularization term that penalizes large weights. "
+              "The penalty scales with the square of the magnitude of each weight.");
+    DMLC_DECLARE_FIELD(rescale_grad)
+    .set_default(1.0f)
+    .describe("Rescale gradient to grad = rescale_grad*grad.");
+    DMLC_DECLARE_FIELD(clip_gradient)
+    .set_default(-1.0f)
+    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+              "If clip_gradient <= 0, gradient clipping is turned off. "
+              "grad = max(min(grad, clip_gradient), -clip_gradient).");
+  }
+};
+
+
+struct SignSGDKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
+    const DType* grad_data, const DType param_clip_gradient,
+    const DType param_lr, const DType param_wd, const DType param_rescale_grad,
+    const OpReqType req) {
+
+    // param_clip_gradient has no effect for SignSGD
+    KERNEL_ASSIGN(out_data[i], req,
+             (1.f-param_lr*param_wd)*weight_data[i]
+               - (param_lr)*((grad_data[i] > 0) - (grad_data[i] < 0)));
+  }
+};
+
+template<typename xpu>
+inline void SignSGDUpdate(const nnvm::NodeAttrs& attrs,
+                      const OpContext &ctx,
+                      const std::vector<TBlob> &inputs,
+                      const std::vector<OpReqType> &req,
+                      const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const SignSGDParam& param = nnvm::get<SignSGDParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+    Kernel<SignSGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
+      grad.dptr_, static_cast<DType>(param.clip_gradient),
+      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+      static_cast<DType>(param.rescale_grad), req[0]);
+  });
+}
+
+
+struct SignumParam : public dmlc::Parameter<SignumParam> {
+  float lr;
+  float momentum;
+  float wd;
+  float rescale_grad;
+  float clip_gradient;
+  float wd_lh;  // the amount of algorithmic weight decay by Loshchilov and Frank Hutter
+  DMLC_DECLARE_PARAMETER(SignumParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(momentum)
+    .set_default(0.0f)
+    .describe("The decay rate of momentum estimates at each epoch.");
+    DMLC_DECLARE_FIELD(wd)
+    .set_default(0.0f)
+    .describe("Weight decay augments the objective function with a "
+              "regularization term that penalizes large weights. "
+              "The penalty scales with the square of the magnitude of each weight.");
+    DMLC_DECLARE_FIELD(rescale_grad)
+    .set_default(1.0f)
+    .describe("Rescale gradient to grad = rescale_grad*grad.");
+    DMLC_DECLARE_FIELD(clip_gradient)
+    .set_default(-1.0f)
+    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+              "If clip_gradient <= 0, gradient clipping is turned off. "
+              "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    DMLC_DECLARE_FIELD(wd_lh)
+    .set_default(0.0f)
+    .describe("The amount of weight decay that does not go into gradient/momentum calculations; "
+              "otherwise do weight decay algorithmically only.");
+  }
+};
+
+struct SignumKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
+    const DType* grad_data, const DType param_clip_gradient, const DType param_momentum,
+    const DType param_lr, const DType param_wd, const DType param_rescale_grad,
+    const DType param_wd_lh, const OpReqType req) {
+    if (param_clip_gradient >= 0.0f) {
+      mom_data[i] = param_momentum*mom_data[i]
+              - (1-param_momentum)*param_wd*weight_data[i]
+              - (1-param_momentum)
+              *mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient);
+    } else {
+      mom_data[i] = param_momentum*mom_data[i]
+                - (1-param_momentum)*param_wd*weight_data[i]
+                - (1-param_momentum)*param_rescale_grad*grad_data[i];
+    }
+    KERNEL_ASSIGN(out_data[i], req, (1.f-param_lr*param_wd_lh)*weight_data[i]
+      + (param_lr)*((mom_data[i] > 0) - (mom_data[i] < 0)));
+  }
+};
+
+template<typename xpu>
+inline void SignumUpdate(const nnvm::NodeAttrs& attrs,
+                         const OpContext &ctx,
+                         const std::vector<TBlob> &inputs,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  SignumParam param = nnvm::get<SignumParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+    Kernel<SignumKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_,
+      grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
+      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+      static_cast<DType>(param.rescale_grad), static_cast<DType>(param.wd_lh), req[0]);
+    });
+}
+
+
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index dda809255d..8760fe94a5 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -36,6 +36,67 @@ DMLC_REGISTER_PARAMETER(AdamParam);
 DMLC_REGISTER_PARAMETER(RMSPropParam);
 DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
 DMLC_REGISTER_PARAMETER(FtrlParam);
+DMLC_REGISTER_PARAMETER(SignSGDParam);
+DMLC_REGISTER_PARAMETER(SignumParam);
+
+NNVM_REGISTER_OP(signsgd_update)
+.describe(R"code(Update function for SignSGD optimizer.
+.. math::
+
+ g_t = \nabla J(W_{t-1})\\
+ W_t = W_{t-1} - \eta_t \text{sign}(g_t)
+
+It updates the weights using::
+
+ weight = weight - learning_rate * sign(gradient)
+
+.. note:: 
+   - sparse ndarray not supported for this optimizer yet.
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SignSGDParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<FCompute>("FCompute<cpu>", SignSGDUpdate<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_arguments(SignSGDParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(signum_update)
+.describe(R"code(SIGN momentUM (Signum) optimizer.
+
+.. math::
+
+ g_t = \nabla J(W_{t-1})\\
+ m_t = \beta m_{t-1} + (1 - \beta) g_t\\
+ W_t = W_{t-1} - \eta_t \text{sign}(m_t)
+
+It updates the weights using::
+ state = momentum * state + (1-momentum) * gradient
+ weight = weight - learning_rate * sign(state)
+
+Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
+
+.. note:: 
+   - sparse ndarray not supported for this optimizer yet.
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SignumParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2};
+  })
+.set_attr<FCompute>("FCompute<cpu>", SignumUpdate<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("mom", "NDArray-or-Symbol", "Momentum")
+.add_arguments(SignumParam::__FIELDS__());
+
 
 template<>
 void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index 9512e92a80..891f24fe79 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -94,6 +94,13 @@ void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& param,
   });
 }
 
+
+NNVM_REGISTER_OP(signsgd_update)
+.set_attr<FCompute>("FCompute<gpu>", SignSGDUpdate<gpu>);
+
+NNVM_REGISTER_OP(signum_update)
+.set_attr<FCompute>("FCompute<gpu>", SignumUpdate<gpu>);
+
 NNVM_REGISTER_OP(sgd_update)
 .set_attr<FCompute>("FCompute<gpu>", SGDUpdate<gpu>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", SGDUpdateEx<gpu>);
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index ae248b0d0b..2d22391879 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -524,6 +524,87 @@ def test_adam():
                             compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
                                           dtype, w_stype='row_sparse', g_stype='row_sparse')
 
+
+# Signum
+class PySignum(mx.optimizer.Optimizer):
+    """The python reference of Signum optimizer.
+
+    The optimizer updates the weight by:
+
+        rescaled_grad = clip(rescale_grad * grad, clip_gradient) + wd * weight
+        state = momentum * state + (1-momentum)*rescaled_grad
+        weight = (1 - lr * wd_lh) * weight - lr * sign(state)
+
+    See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf
+
+    For details of the update algorithm see
+    :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    momentum : float, optional
+       The momentum value.
+    wd_lh : float, optional
+       The amount of decoupled weight decay regularization.
+    """
+    def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh = 0.0, **kwargs):
+        super(PySignum, self).__init__(learning_rate = learning_rate, **kwargs)
+        self.momentum = momentum
+        self.wd_lh = wd_lh
+
+    def create_state(self, index, weight):
+        momentum = None
+        if self.momentum != 0.0:
+            momentum = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+        return momentum
+
+    def update(self, index, weight, grad, state):
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+
+        if state is not None:
+            mom = state
+            if self.clip_gradient is not None:
+              mom[:] = (self.momentum*mom - (1-self.momentum)*(wd*weight +
+                  mx.nd.clip(grad*self.rescale_grad, -self.clip_gradient, self.clip_gradient)))
+            else:
+              mom[:] = self.momentum*mom - (1-self.momentum)*wd*weight - (1-self.momentum)*self.rescale_grad*grad
+            weight[:] = (1 - lr*self.wd_lh)*weight + lr*mx.nd.sign(mom)
+        else:
+            weight[:] = (1 - lr*(wd+self.wd_lh))*weight - lr*mx.nd.sign(grad)
+
+def test_signum():
+    mx.random.seed(0)
+    opt1 = PySignum
+    opt2 = mx.optimizer.Signum
+    shape = (3, 4, 5)
+    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+    wd_lh_options = [{}, {'wd_lh': 0.015}, {'wd_lh': 0.0}]
+    mom_options = [{}, {'momentum': 0.9}]
+    lr_options = [{'learning_rate': 0.05},{'learning_rate': 0.01}]
+    for dtype in [np.float32, np.float64]:
+        for cg_option in cg_options:
+            for rg_option in rg_options:
+                for wd_option in wd_options:
+                    for mp_option in wd_lh_options:
+                        for lr_option in lr_options:
+                            for mom_option in mom_options:
+                                kwarg = {}
+                                kwarg.update(cg_option)
+                                kwarg.update(rg_option)
+                                kwarg.update(wd_option)
+                                kwarg.update(mp_option)
+                                kwarg.update(lr_option)
+                                kwarg.update(mom_option)
+                                compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
+
+
 # RMSProp
 class PyRMSProp(mx.optimizer.Optimizer):
     """RMSProp optimizer of Tieleman & Hinton, 2012,


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services