You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/12 19:20:09 UTC

[incubator-mxnet] branch master updated: Signum optimizer (#9220)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 5251b86  Signum optimizer (#9220)
5251b86 is described below

commit 5251b861693402a3e6394da990a5af183b6f7247
Author: Yu-Xiang Wang <lu...@gmail.com>
AuthorDate: Fri Jan 12 11:20:05 2018 -0800

    Signum optimizer (#9220)
    
    * the c++ version of signum and signsgd optimizer
    
    * optimizer signum, tested working with mac on cpu using mnist
    
    * unit test for signum
    
    * fix lint and incorporate haibin's code review
    
    * rerun jenkins
    
    * adding link to the Loshchilov and Hutter paper to the documentation
---
 cpp-package/include/mxnet-cpp/optimizer.h   |  14 +++
 cpp-package/include/mxnet-cpp/optimizer.hpp |  64 +++++++++++++
 python/mxnet/optimizer.py                   |  67 ++++++++++++-
 src/operator/optimizer_op-inl.h             | 142 ++++++++++++++++++++++++++++
 src/operator/optimizer_op.cc                |  61 ++++++++++++
 src/operator/optimizer_op.cu                |   7 ++
 tests/python/unittest/test_optimizer.py     |  81 ++++++++++++++++
 7 files changed, 433 insertions(+), 3 deletions(-)

diff --git a/cpp-package/include/mxnet-cpp/optimizer.h b/cpp-package/include/mxnet-cpp/optimizer.h
index f3763bb..320b13e 100644
--- a/cpp-package/include/mxnet-cpp/optimizer.h
+++ b/cpp-package/include/mxnet-cpp/optimizer.h
@@ -146,6 +146,20 @@ class SGDOptimizer : public Optimizer {
   AtomicSymbolCreator mom_update_handle_;
 };
 
+class SignumOptimizer : public Optimizer {  // sign-based SGD; Signum when momentum is set
+ public:
+  explicit SignumOptimizer(unsigned begin_num_update = 0);
+  std::string GetType() const override;  // returns "signum"
+  void Update(int index, NDArray weight, NDArray grad) override;
+ private:
+  virtual ~SignumOptimizer();
+  void CreateState_(int index, NDArray weight) override;
+  std::map<int, NDArray*> states_;  // per-index momentum; nullptr when no "momentum" param
+  AtomicSymbolCreator update_handle_;  // "signsgd_update" (stateless path)
+  AtomicSymbolCreator mom_update_handle_;  // "signum_update" (momentum path)
+};
+
+
 class RMSPropOptimizer : public Optimizer {
  public:
   explicit RMSPropOptimizer(unsigned begin_num_update = 0);
diff --git a/cpp-package/include/mxnet-cpp/optimizer.hpp b/cpp-package/include/mxnet-cpp/optimizer.hpp
index cb8442d..e3d47d1 100644
--- a/cpp-package/include/mxnet-cpp/optimizer.hpp
+++ b/cpp-package/include/mxnet-cpp/optimizer.hpp
@@ -131,6 +131,7 @@ inline Optimizer* OptimizerRegistry::Find(const std::string& name) {
   MXNETCPP_REGISTER_OPTIMIZER(adam, AdamOptimizer);
   MXNETCPP_REGISTER_OPTIMIZER(adagrad, AdaGradOptimizer);
   MXNETCPP_REGISTER_OPTIMIZER(adadelta, AdaDeltaOptimizer);
+  MXNETCPP_REGISTER_OPTIMIZER(signum, SignumOptimizer);
   auto it = cmap().find(name);
   if (it == cmap().end())
     return nullptr;
@@ -200,6 +201,69 @@ inline void SGDOptimizer::CreateState_(int index, NDArray weight) {
   }
 }
 
+// implementing Signum optimizer (signsgd_update / signum_update ops)
+
+inline SignumOptimizer::SignumOptimizer(unsigned begin_num_update)
+  : Optimizer(begin_num_update) {
+  update_handle_ = op_map()->GetSymbolCreator("signsgd_update");  // used without momentum
+  mom_update_handle_ = op_map()->GetSymbolCreator("signum_update");  // used with momentum
+}
+
+inline std::string SignumOptimizer::GetType() const {
+  return "signum";
+}
+
+inline SignumOptimizer::~SignumOptimizer() {
+  for (auto &it : states_) {
+    delete it.second;  // owned momentum buffer; deleting a nullptr entry is a no-op
+  }
+}
+
+inline void SignumOptimizer::Update(int index, NDArray weight, NDArray grad) {
+  if (states_.count(index) == 0) {
+    CreateState_(index, weight);  // lazily create state on first update of this index
+  }
+
+  params_["lr"] = std::to_string(GetLR_(index));
+  params_["wd"] = std::to_string(GetWD_(index));
+  UpdateCount_(index);
+  auto keys = GetParamKeys_();  // NOTE(review): forwards all params_; signsgd_update lacks wd_lh
+  auto values = GetParamValues_();
+  CHECK_EQ(keys.size(), values.size());
+
+  NDArrayHandle inputs[3];
+  inputs[0] = weight.GetHandle();
+  inputs[1] = grad.GetHandle();
+
+  int num_outputs = 1;
+  NDArrayHandle output = weight.GetHandle();  // result is written back into weight
+  NDArrayHandle *outputs = &output;
+
+  if (states_[index] == nullptr) {  // no momentum state: plain signSGD
+    MXImperativeInvoke(update_handle_, 2, inputs,
+        &num_outputs, &outputs,
+        keys.size(), keys.data(), values.data());
+  } else {  // Signum: signum_update also mutates the momentum input in place
+    inputs[2] = states_[index]->GetHandle();
+    MXImperativeInvoke(mom_update_handle_, 3, inputs,
+        &num_outputs, &outputs,
+        keys.size(), keys.data(), values.data());
+  }
+}
+
+inline void SignumOptimizer::CreateState_(int index, NDArray weight) {
+  if (params_.count("momentum") == 0) {  // no "momentum" param => stateless signSGD
+    states_[index] = nullptr;
+  } else {
+    states_[index] = new NDArray(weight.GetShape(), weight.GetContext());  // freed in dtor
+    *states_[index] = 0;  // momentum starts at zero
+  }
+}
+
+// finish implementing Signum
+
+
+
 inline RMSPropOptimizer::RMSPropOptimizer(unsigned begin_num_update)
   : Optimizer(begin_num_update) {
   update_handle_ = op_map()->GetSymbolCreator("rmsprop_update");
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index feff87e..4285aec 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -25,7 +25,8 @@ import numpy
 from .base import py_str
 from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs)
 from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
-                      mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update)
+                      mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
+                      signsgd_update, signum_update)
 from .ndarray import _internal
 from .ndarray import op
 from .ndarray import sparse
@@ -534,6 +535,67 @@ class SGD(Optimizer):
         self._update_impl(index, weight, grad, state,
                           multi_precision=use_multi_precision)
 
+@register
+class Signum(Optimizer):
+    """The Signum optimizer that takes the sign of gradient or momentum.
+
+    The optimizer updates the weight by:
+
+        rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
+        state = momentum * state + (1-momentum)*rescaled_grad
+        weight = (1 - lr * wd_lh) * weight - lr * sign(state)
+
+    See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf
+
+    For details of the update algorithm see
+    :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    momentum : float, optional
+       The momentum value.
+    wd_lh : float, optional
+       The amount of decoupled weight decay regularization, see details in the original paper at:\
+       https://arxiv.org/abs/1711.05101
+    """
+    def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs):
+        super(Signum, self).__init__(learning_rate=learning_rate, **kwargs)
+        self.momentum = momentum
+        self.wd_lh = wd_lh
+
+    def create_state(self, index, weight):
+        momentum = None
+        if self.momentum != 0.0:
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+        return momentum
+
+    def _update_impl(self, index, weight, grad, state):
+        assert(isinstance(weight, NDArray))
+        assert(isinstance(grad, NDArray))
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+
+        kwargs = {'rescale_grad': self.rescale_grad}
+        if self.momentum > 0:
+            kwargs['momentum'] = self.momentum
+        if self.clip_gradient:
+            kwargs['clip_gradient'] = self.clip_gradient
+        if self.wd_lh:
+            kwargs['wd_lh'] = self.wd_lh
+
+        if state is not None:
+            signum_update(weight, grad, state, out=weight,
+                          lr=lr, wd=wd, **kwargs)
+        else:
+            signsgd_update(weight, grad, out=weight,
+                           lr=lr, wd=wd, **kwargs)
+
+    def update(self, index, weight, grad, state):
+        self._update_impl(index, weight, grad, state)
 
 @register
 class FTML(Optimizer):
@@ -702,8 +764,7 @@ class SGLD(Optimizer):
         if self.clip_gradient is not None:
             grad = clip(grad, -self.clip_gradient, self.clip_gradient)
         weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr),
-                                                            shape=weight.shape,
-                                                            ctx=weight.context)
+                                                            weight.shape, weight.context)
 
 
 @register  # pylint: disable=invalid-name
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 33b7dd5..c2564db 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -66,6 +66,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
   }
 };
 
+
 struct SGDKernel {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
@@ -228,6 +229,7 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
   }
 };
 
+
 struct SGDMomKernel {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
@@ -1281,6 +1283,146 @@ inline void FtrlUpdateEx(const nnvm::NodeAttrs& attrs,
   }
 }
 
+
+// Implementation for signSGD and Signum
+
+struct SignSGDParam : public dmlc::Parameter<SignSGDParam> {  // parameters of signsgd_update
+  float lr;
+  float wd;
+  float rescale_grad;  // accepted, but not read by SignSGDKernel
+  float clip_gradient;  // accepted, but not read by SignSGDKernel
+  DMLC_DECLARE_PARAMETER(SignSGDParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(wd)
+    .set_default(0.0f)
+    .describe("Weight decay augments the objective function with a "
+              "regularization term that penalizes large weights. "
+              "The penalty scales with the square of the magnitude of each weight.");
+    DMLC_DECLARE_FIELD(rescale_grad)
+    .set_default(1.0f)
+    .describe("Rescale gradient to grad = rescale_grad*grad.");
+    DMLC_DECLARE_FIELD(clip_gradient)
+    .set_default(-1.0f)
+    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+              "If clip_gradient <= 0, gradient clipping is turned off. "
+              "grad = max(min(grad, clip_gradient), -clip_gradient).");
+  }
+};
+
+
+struct SignSGDKernel {  // out = (1 - lr*wd) * weight - lr * sign(grad)
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
+    const DType* grad_data, const DType param_clip_gradient,
+    const DType param_lr, const DType param_wd, const DType param_rescale_grad,
+    const OpReqType req) {
+
+    // a positive clip_gradient or rescale_grad cannot change sign(grad): both unused
+    KERNEL_ASSIGN(out_data[i], req,
+             (1.f-param_lr*param_wd)*weight_data[i]
+               - (param_lr)*((grad_data[i] > 0) - (grad_data[i] < 0)));  // sign(grad)
+  }
+};
+
+template<typename xpu>
+inline void SignSGDUpdate(const nnvm::NodeAttrs& attrs,  // inputs: weight, grad
+                      const OpContext &ctx,
+                      const std::vector<TBlob> &inputs,
+                      const std::vector<OpReqType> &req,
+                      const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const SignSGDParam& param = nnvm::get<SignSGDParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);  // updated weight
+    Kernel<SignSGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
+      grad.dptr_, static_cast<DType>(param.clip_gradient),
+      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+      static_cast<DType>(param.rescale_grad), req[0]);
+  });
+}
+
+
+struct SignumParam : public dmlc::Parameter<SignumParam> {  // parameters of signum_update
+  float lr;
+  float momentum;
+  float wd;
+  float rescale_grad;
+  float clip_gradient;
+  float wd_lh;  // decoupled weight decay of Loshchilov & Hutter (applied outside momentum)
+  DMLC_DECLARE_PARAMETER(SignumParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(momentum)
+    .set_default(0.0f)
+    .describe("The decay rate of momentum estimates at each epoch.");
+    DMLC_DECLARE_FIELD(wd)
+    .set_default(0.0f)
+    .describe("Weight decay augments the objective function with a "
+              "regularization term that penalizes large weights. "
+              "The penalty scales with the square of the magnitude of each weight.");
+    DMLC_DECLARE_FIELD(rescale_grad)
+    .set_default(1.0f)
+    .describe("Rescale gradient to grad = rescale_grad*grad.");
+    DMLC_DECLARE_FIELD(clip_gradient)
+    .set_default(-1.0f)
+    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+              "If clip_gradient <= 0, gradient clipping is turned off. "
+              "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    DMLC_DECLARE_FIELD(wd_lh)
+    .set_default(0.0f)
+    .describe("The amount of decoupled weight decay: it is applied directly to the weights "
+              "and does not go into the gradient/momentum calculations.");
+  }
+};
+
+struct SignumKernel {  // mom keeps the NEGATED momentum, so the step adds lr*sign(mom)
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
+    const DType* grad_data, const DType param_clip_gradient, const DType param_momentum,
+    const DType param_lr, const DType param_wd, const DType param_rescale_grad,
+    const DType param_wd_lh, const OpReqType req) {
+    if (param_clip_gradient >= 0.0f) {  // clip the rescaled gradient
+      mom_data[i] = param_momentum*mom_data[i]
+              - (1-param_momentum)*param_wd*weight_data[i]
+              - (1-param_momentum)
+              *mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient);
+    } else {  // clipping disabled
+      mom_data[i] = param_momentum*mom_data[i]
+                - (1-param_momentum)*param_wd*weight_data[i]
+                - (1-param_momentum)*param_rescale_grad*grad_data[i];
+    }
+    KERNEL_ASSIGN(out_data[i], req, (1.f-param_lr*param_wd_lh)*weight_data[i]
+      + (param_lr)*((mom_data[i] > 0) - (mom_data[i] < 0)));  // i.e. -lr*sign(momentum)
+  }
+};
+
+template<typename xpu>
+inline void SignumUpdate(const nnvm::NodeAttrs& attrs,  // inputs: weight, grad, mom
+                         const OpContext &ctx,
+                         const std::vector<TBlob> &inputs,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const SignumParam& param = nnvm::get<SignumParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);  // mutated in place
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+    Kernel<SignumKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_,
+      grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
+      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+      static_cast<DType>(param.rescale_grad), static_cast<DType>(param.wd_lh), req[0]);
+  });
+}
+
+
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index dda8092..8760fe9 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -36,6 +36,67 @@ DMLC_REGISTER_PARAMETER(AdamParam);
 DMLC_REGISTER_PARAMETER(RMSPropParam);
 DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
 DMLC_REGISTER_PARAMETER(FtrlParam);
+DMLC_REGISTER_PARAMETER(SignSGDParam);
+DMLC_REGISTER_PARAMETER(SignumParam);
+
+NNVM_REGISTER_OP(signsgd_update)  // inputs: weight, grad; output: updated weight
+.describe(R"code(Update function for SignSGD optimizer.
+
+.. math::
+
+ g_t = \nabla J(W_{t-1})\\
+ W_t = W_{t-1} - \eta_t \text{sign}(g_t)
+
+It updates the weights using::
+
+ weight = weight - learning_rate * sign(gradient)
+
+.. note:: sparse ndarray not supported for this optimizer yet.
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SignSGDParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<FCompute>("FCompute<cpu>", SignSGDUpdate<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_arguments(SignSGDParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(signum_update)  // inputs: weight, grad, mom; output: updated weight
+.describe(R"code(SIGN momentUM (Signum) optimizer.
+
+.. math::
+
+ g_t = \nabla J(W_{t-1})\\
+ m_t = \beta m_{t-1} + (1 - \beta) g_t\\
+ W_t = W_{t-1} - \eta_t \text{sign}(m_t)
+
+It updates the weights using::
+
+ state = momentum * state + (1-momentum) * gradient
+ weight = weight - learning_rate * sign(state)
+
+Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
+
+.. note:: sparse ndarray not supported for this optimizer yet.
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SignumParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2};  // input 2 ("mom") is updated in place
+  })
+.set_attr<FCompute>("FCompute<cpu>", SignumUpdate<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("mom", "NDArray-or-Symbol", "Momentum")
+.add_arguments(SignumParam::__FIELDS__());
+
 
 template<>
 void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index 9512e92..891f24f 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -94,6 +94,13 @@ void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& param,
   });
 }
 
+
+NNVM_REGISTER_OP(signsgd_update)  // GPU dense kernel (FCompute only)
+.set_attr<FCompute>("FCompute<gpu>", SignSGDUpdate<gpu>);
+
+NNVM_REGISTER_OP(signum_update)  // GPU dense kernel (FCompute only)
+.set_attr<FCompute>("FCompute<gpu>", SignumUpdate<gpu>);
+
 NNVM_REGISTER_OP(sgd_update)
 .set_attr<FCompute>("FCompute<gpu>", SGDUpdate<gpu>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", SGDUpdateEx<gpu>);
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index ae248b0..2d22391 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -524,6 +524,87 @@ def test_adam():
                             compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
                                           dtype, w_stype='row_sparse', g_stype='row_sparse')
 
+
+# Signum
+class PySignum(mx.optimizer.Optimizer):
+    """The python reference of Signum optimizer.
+
+    The optimizer updates the weight by:
+
+        rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
+        state = momentum * state + (1-momentum)*rescaled_grad
+        weight = (1 - lr * wd_lh) * weight - lr * sign(state)
+
+    See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf
+
+    For details of the update algorithm see
+    :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    momentum : float, optional
+       The momentum value.
+    wd_lh : float, optitional
+       The amount of decoupled weight decay regularization.
+    """
+    def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh = 0.0, **kwargs):
+        super(PySignum, self).__init__(learning_rate = learning_rate, **kwargs)
+        self.momentum = momentum
+        self.wd_lh = wd_lh
+
+    def create_state(self, index, weight):
+        momentum = None
+        if self.momentum != 0.0:
+            momentum = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+        return momentum
+
+    def update(self, index, weight, grad, state):
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+
+        if state is not None:
+            mom = state
+            if self.clip_gradient is not None:
+              mom[:] = (self.momentum*mom - (1-self.momentum)*(wd*weight +
+                  mx.nd.clip(grad*self.rescale_grad, -self.clip_gradient, self.clip_gradient)))
+            else:
+              mom[:] = self.momentum*mom - (1-self.momentum)*wd*weight - (1-self.momentum)*self.rescale_grad*grad
+            weight[:] = (1 - lr*self.wd_lh)*weight + lr*mx.nd.sign(mom)
+        else:
+            weight[:] = (1 - lr*(wd+self.wd_lh))*weight - lr*mx.nd.sign(grad)
+
+def test_signum():
+    mx.random.seed(0)
+    opt1 = PySignum
+    opt2 = mx.optimizer.Signum
+    shape = (3, 4, 5)
+    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+    wd_lh_options = [{}, {'wd_lh': 0.015}, {'wd_lh': 0.0}]
+    mom_options = [{}, {'momentum': 0.9}]
+    lr_options = [{'learning_rate': 0.05},{'learning_rate': 0.01}]
+    for dtype in [np.float32, np.float64]:
+        for cg_option in cg_options:
+            for rg_option in rg_options:
+                for wd_option in wd_options:
+                    for mp_option in wd_lh_options:
+                        for lr_option in lr_options:
+                            for mom_option in mom_options:
+                                kwarg = {}
+                                kwarg.update(cg_option)
+                                kwarg.update(rg_option)
+                                kwarg.update(wd_option)
+                                kwarg.update(mp_option)
+                                kwarg.update(lr_option)
+                                kwarg.update(mom_option)
+                                compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
+
+
 # RMSProp
 class PyRMSProp(mx.optimizer.Optimizer):
     """RMSProp optimizer of Tieleman & Hinton, 2012,

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].