You are viewing a plain-text version of this content. (The canonical link to it was not preserved in this copy.)
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/04/30 14:47:26 UTC

[incubator-mxnet] branch master updated: [FEATURE] AdaBelief operator (#20065)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 059a055  [FEATURE] AdaBelief operator (#20065)
059a055 is described below

commit 059a05549e7892007a9490e9ca2987d22ddec816
Author: khaotik <kh...@users.noreply.github.com>
AuthorDate: Fri Apr 30 22:46:04 2021 +0800

    [FEATURE] AdaBelief operator (#20065)
    
    * copycat from adamw to adabelief
    
    * fix py lint
    
    * fix py lint #2
    
    * fix cpp lint
    
    * add adabelief to amp list
---
 python/mxnet/amp/lists/symbol_fp16.py              |   4 +
 python/mxnet/ndarray/contrib.py                    |  51 +++++
 python/mxnet/optimizer/__init__.py                 |  12 +-
 python/mxnet/optimizer/adabelief.py                | 231 +++++++++++++++++++++
 .../contrib/{adamw-inl.h => adabelief-inl.h}       | 114 +++++-----
 src/operator/contrib/{adamw.cc => adabelief.cc}    | 132 ++++++------
 src/operator/contrib/{adamw.cu => adabelief.cu}    |  27 +--
 src/operator/contrib/adamw-inl.h                   |  10 +-
 src/operator/contrib/adamw.cc                      |  13 +-
 src/operator/contrib/adamw.cu                      |  10 +-
 tests/python/unittest/test_contrib_optimizer.py    | 147 ++++++++-----
 tests/python/unittest/test_optimizer.py            |  25 +++
 12 files changed, 573 insertions(+), 203 deletions(-)

diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index f78d32d..6359384 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -82,6 +82,7 @@ FP16_FP32_FUNCS = [
     '_FusedOpHelper',
     '_FusedOpOutHelper',
     '_NoGradient',
+    '_adabelief_update',
     '_adamw_update',
     '_arange',
     '_cond',
@@ -153,11 +154,14 @@ FP16_FP32_FUNCS = [
     '_minimum_scalar',
     '_minus_scalar',
     '_mod_scalar',
+    '_mp_adabelief_update',
     '_mp_adamw_update',
     '_mul_scalar',
+    '_multi_adabelief_update',
     '_multi_adamw_update',
     '_multi_lamb_update',
     '_multi_lans_update',
+    '_multi_mp_adabelief_update',
     '_multi_mp_adamw_update',
     '_multi_mp_lamb_update',
     '_multi_mp_lans_update',
diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index 204d694..ed70f8c 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -679,6 +679,57 @@ def multi_mp_lamb_update(weights, grads, mean, var, weights32, step_count,
                                                    wds=wds,
                                                    **kwargs)
 
+def adabelief_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
+                     epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
+    rescale_grad = _get_rescale_grad(rescale_grad, ctx=weight.context)
+    return ndarray._internal._adabelief_update(weight=weight, grad=grad, mean=mean, var=var,
+                                               rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                               beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                               wd=wd, clip_gradient=clip_gradient, out=out,
+                                               name=name, **kwargs)
+
+def mp_adabelief_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
+                        beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
+                        name=None, **kwargs):
+    rescale_grad = _get_rescale_grad(rescale_grad, ctx=weight.context)
+    return ndarray._internal._mp_adabelief_update(weight=weight, grad=grad, mean=mean, var=var,
+                                                  weight32=weight32,
+                                                  rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                                  beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                                  wd=wd, clip_gradient=clip_gradient, out=out,
+                                                  name=name, **kwargs)
+
+def multi_adabelief_update(weights, grads, mean, var, rescale_grad, lrs, wds, etas,
+                           out=None, name=None, size=0, **kwargs):
+    if not size:
+        size = len(weights)
+
+    rescale_grad = _get_rescale_grad(rescale_grad, ctx=weights[0].context)
+    temp_list = _flatten_list(zip(weights, grads, mean, var)) + [rescale_grad]
+    return ndarray._internal._multi_adabelief_update(*temp_list,
+                                                     out=out,
+                                                     num_weights=size,
+                                                     lrs=lrs,
+                                                     wds=wds,
+                                                     etas=etas,
+                                                     name=name,
+                                                     **kwargs)
+
+def multi_mp_adabelief_update(weights, grads, mean, var, weights32, rescale_grad, lrs, wds, etas,
+                              out=None, name=None, size=0, **kwargs):
+    if not size:
+        size = len(weights)
+
+    rescale_grad = _get_rescale_grad(rescale_grad, ctx=weights[0].context)
+    temp_list = _flatten_list(zip(weights, grads, mean, var, weights32)) + [rescale_grad]
+    return ndarray._internal._multi_mp_adabelief_update(*temp_list,
+                                                        out=out,
+                                                        num_weights=size,
+                                                        lrs=lrs,
+                                                        wds=wds,
+                                                        etas=etas,
+                                                        name=name,
+                                                        **kwargs)
 
 def multi_lans_update(weights, grads, mean, var, step_count,
                       lrs, wds, out=None, num_tensors=0, **kwargs):
diff --git a/python/mxnet/optimizer/__init__.py b/python/mxnet/optimizer/__init__.py
index 9bf0c1f..fba34a3 100644
--- a/python/mxnet/optimizer/__init__.py
+++ b/python/mxnet/optimizer/__init__.py
@@ -19,8 +19,11 @@
 from . import (optimizer, contrib, updater, utils, sgd,
                sgld, signum, dcasgd, nag, adagrad,
                adadelta, adam, adamax, nadam, ftrl,
-               ftml, lars, lamb, rmsprop, lans, adamW)
+               ftml, lars, lamb, rmsprop, lans, adamW,
+               adabelief)
 # pylint: disable=wildcard-import
+from .adabelief import *
+
 from .adamW import *
 
 from .optimizer import *
@@ -62,6 +65,7 @@ from .rmsprop import *
 from .lans import *
 
 __all__ = optimizer.__all__ + updater.__all__ + ['contrib'] + sgd.__all__ + sgld.__all__ \
-          + signum.__all__ + dcasgd.__all__ + nag.__all__ + adagrad.__all__ + adadelta.__all__ \
-          + adam.__all__ + adamax.__all__ + nadam.__all__ + ftrl.__all__ + ftml.__all__ \
-          + lars.__all__ + lamb.__all__ + rmsprop.__all__ + lans.__all__
+          + signum.__all__ + dcasgd.__all__ + nag.__all__ + adabelief.__all__ \
+          + adagrad.__all__ + adadelta.__all__ + adam.__all__ + adamax.__all__ \
+          + nadam.__all__ + ftrl.__all__ + ftml.__all__ + lars.__all__ \
+          + lamb.__all__ + rmsprop.__all__ + lans.__all__
diff --git a/python/mxnet/optimizer/adabelief.py b/python/mxnet/optimizer/adabelief.py
new file mode 100644
index 0000000..c224ebf
--- /dev/null
+++ b/python/mxnet/optimizer/adabelief.py
@@ -0,0 +1,231 @@
+# coding: utf-8
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""AdaBelief optimizer."""
+import math
+import os
+import numpy as np
+from .optimizer import Optimizer, register
+from ..ndarray import (zeros, clip, sqrt, square, full, NDArray)
+from ..ndarray.contrib import mp_adabelief_update, adabelief_update,\
+    multi_mp_adabelief_update, multi_adabelief_update
+
+
+__all__ = ['AdaBelief']
+
+
+@register
+class AdaBelief(Optimizer):
+    """The AdaBelief optimizer.
+
+    This class implements the optimizer described in *Adapting Stepsizes by the Belief in Observed Gradients*,
+    available at https://arxiv.org/pdf/2010.07468.pdf.
+
+    Updates are applied by::
+
+        grad = clip(grad * rescale_grad + wd * w, clip_gradient)
+        m = beta1 * m + (1 - beta1) * grad
+        s = beta2 * s + (1 - beta2) * ((grad - m)**2) + epsilon
+        lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t)
+        w = w - lr * (m / (sqrt(s) + epsilon))
+
+
+    Also, we can turn off the bias correction term and the updates are as follows::
+
+        grad = clip(grad * rescale_grad + wd * w, clip_gradient)
+        m = beta1 * m + (1 - beta1) * grad
+        s = beta2 * s + (1 - beta2) * ((grad - m)**2) + epsilon
+        lr = learning_rate
+        w = w - lr * (m / (sqrt(s) + epsilon))
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    learning_rate : float, default 0.001
+        The initial learning rate. If None, the optimization will use the
+        learning rate from ``lr_scheduler``. If not None, it will overwrite
+        the learning rate in ``lr_scheduler``. If None and ``lr_scheduler``
+        is also None, then it will be set to 0.01 by default.
+    beta1 : float, default 0.9
+        Exponential decay rate for the first moment estimates.
+    beta2 : float, default 0.999
+        Exponential decay rate for the second moment estimates.
+    epsilon : float, default 1e-6
+        Small value to avoid division by 0.
+    correct_bias : bool, default True
+       Can be set to False to skip bias correction (as is done for Adam in the BERT TF repository).
+       Default True.
+    use_fused_step : bool, default True
+        Whether or not to use fused kernels for optimizer.
+        When use_fused_step=False, step is called,
+        otherwise, fused_step is called.
+    """
+    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
+                 correct_bias=True, use_fused_step=True, **kwargs):
+        super().__init__(use_fused_step=use_fused_step,
+                         learning_rate=learning_rate,
+                         **kwargs)
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.correct_bias = correct_bias
+        self.aggregate_num = max(1, min(50,
+                                        int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', '4'))))
+
+    def create_state(self, index, weight):
+        """state creation function."""
+        return (zeros(weight.shape, weight.context, dtype=weight.dtype),  # mean
+                zeros(weight.shape, weight.context, dtype=weight.dtype))  # variance
+
+    def step(self, indices, weights, grads, states):
+        """Perform an optimization step using gradients and states.
+
+        Parameters
+        ----------
+        indices : list of int
+            List of unique indices of the parameters into the individual learning rates
+            and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
+            and `set_wd_mult()`, respectively.
+        weights : list of NDArray
+            List of parameters to be updated.
+        grads : list of NDArray
+            List of gradients of the objective with respect to this parameter.
+        states : List of any obj
+            List of state returned by `create_state()`.
+        """
+        for index, weight, grad, state in zip(indices, weights, grads, states):
+            self._update_count(index)
+            lr = self._get_lr(index)
+            wd = self._get_wd(index)
+            eps = self.epsilon
+            t = self._index_update_count[index]
+
+            # preprocess grad
+            grad *= self.rescale_grad
+            grad += wd * weight
+            if self.clip_gradient is not None:
+                grad = clip(grad, -self.clip_gradient, self.clip_gradient)
+            if self.correct_bias:
+                coef1 = 1. - self.beta1**t
+                coef2 = 1. - self.beta2**t
+                lr *= math.sqrt(coef2) / coef1
+
+            # update mean and var
+            mean, var = state
+            mean[:] *= self.beta1
+            mean[:] += (1. - self.beta1) * grad
+            var[:] *= self.beta2
+            var[:] += (1. - self.beta2) * square(grad - mean)
+            var[:] += eps
+
+            # update weight
+            d = mean / (sqrt(var) + eps)
+            weight[:] -= lr * d
+
+    def fused_step(self, indices, weights, grads, states):
+        """Perform a fused optimization step using gradients and states.
+        Fused kernel is used for update.
+
+        Parameters
+        ----------
+        indices : list of int
+            List of unique indices of the parameters into the individual learning rates
+            and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
+            and `set_wd_mult()`, respectively.
+        weights : list of NDArray
+            List of parameters to be updated.
+        grads : list of NDArray
+            List of gradients of the objective with respect to this parameter.
+        states : List of any obj
+            List of state returned by `create_state()`.
+        """
+        multi_precision = self.multi_precision and weights[0].dtype == np.float16
+        aggregate = self.aggregate_num > 1
+        if not isinstance(indices, (tuple, list)):
+            indices = [indices]
+            weights = [weights]
+            grads = [grads]
+            states = [states]
+        for w_i, g_i in zip(weights, grads):
+            assert(isinstance(w_i, NDArray))
+            assert(isinstance(g_i, NDArray))
+            aggregate = (aggregate and
+                         w_i.stype == 'default' and
+                         g_i.stype == 'default')
+        self._update_count(indices)
+        lrs = self._get_lrs(indices)
+        wds = self._get_wds(indices)
+        if self.correct_bias:
+            new_lrs = []
+            for idx, lr in zip(indices, lrs):
+                t = self._index_update_count[idx]
+                coef1 = 1. - self.beta1 ** t
+                coef2 = 1. - self.beta2 ** t
+                new_lrs.append(lr * math.sqrt(coef2) / coef1)
+            lrs = new_lrs
+        if not isinstance(self.rescale_grad, NDArray):
+            self.rescale_grad = full(shape=(1,), val=self.rescale_grad, ctx=weights[0].context)
+        else:
+            self.rescale_grad = self.rescale_grad.as_in_context(weights[0].context)
+        kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
+                  'rescale_grad': self.rescale_grad}
+        if self.clip_gradient:
+            kwargs['clip_gradient'] = self.clip_gradient
+
+        if aggregate:
+            current_index = 0
+            while current_index < len(indices):
+                sidx = current_index
+                eidx = min(current_index + self.aggregate_num, len(indices))
+                if not multi_precision:
+                    mean, var = list(zip(*states[sidx:eidx]))
+                    multi_adabelief_update(weights[sidx:eidx], grads[sidx:eidx],
+                                           mean, var,
+                                           out=weights[sidx:eidx],
+                                           size=len(weights[sidx:eidx]),
+                                           lrs=list(np.ones(len(weights[sidx:eidx]))),
+                                           wds=wds[sidx:eidx],
+                                           etas=lrs[sidx:eidx],
+                                           **kwargs)
+                else:
+                    mean_var = list(zip(*states[sidx:eidx]))[0]
+                    tmean_var = list(zip(*mean_var))
+                    mean = tmean_var[0]
+                    var = tmean_var[1]
+                    multi_mp_adabelief_update(weights[sidx:eidx],
+                                              grads[sidx:eidx],
+                                              mean, var,
+                                              list(zip(*states[sidx:eidx]))[1],
+                                              out=weights[sidx:eidx],
+                                              size=len(weights[sidx:eidx]),
+                                              lrs=list(np.ones(len(weights[sidx:eidx]))),
+                                              wds=wds[sidx:eidx],
+                                              etas=lrs[sidx:eidx],
+                                              **kwargs)
+                current_index += self.aggregate_num
+        else:
+            for w_i, g_i, s_i, lr, wd in zip(weights, grads, states, lrs, wds):
+                if not multi_precision:
+                    mean, var = s_i
+                    adabelief_update(w_i, g_i, mean, var, out=w_i,
+                                     lr=1, wd=wd, eta=lr, **kwargs)
+                else:
+                    mean, var = s_i[0]
+                    mp_adabelief_update(w_i, g_i, mean, var, s_i[1], out=w_i,
+                                        lr=1, wd=wd, eta=lr, **kwargs)
diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adabelief-inl.h
similarity index 85%
copy from src/operator/contrib/adamw-inl.h
copy to src/operator/contrib/adabelief-inl.h
index 6f48333..2f28215 100644
--- a/src/operator/contrib/adamw-inl.h
+++ b/src/operator/contrib/adabelief-inl.h
@@ -18,13 +18,13 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
- * \file adamw-inl.h
+ *  Copyright (c) 2021 by Contributors
+ * \file adabelief-inl.h
  * \brief Optimizer operators
- * \author Haibin Lin, Moises Hernandez, Andrei Ivanov
+ * \author khaotik
  */
-#ifndef MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
-#define MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
+#ifndef MXNET_OPERATOR_CONTRIB_ADABELIEF_INL_H_
+#define MXNET_OPERATOR_CONTRIB_ADABELIEF_INL_H_
 #include <mxnet/operator.h>
 #include <vector>
 #include "../mshadow_op.h"
@@ -32,8 +32,9 @@
 
 namespace mxnet {
 namespace op {
+namespace adabelief {
 
-struct AdamWParam : public dmlc::Parameter<AdamWParam> {
+struct AdaBeliefParam : public dmlc::Parameter<AdaBeliefParam> {
   float lr;
   float beta1;
   float beta2;
@@ -41,7 +42,7 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
   float wd;
   float eta;
   float clip_gradient;
-  DMLC_DECLARE_PARAMETER(AdamWParam) {
+  DMLC_DECLARE_PARAMETER(AdaBeliefParam) {
     DMLC_DECLARE_FIELD(lr)
     .describe("Learning rate");
     DMLC_DECLARE_FIELD(beta1)
@@ -98,7 +99,7 @@ inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs,
 }
 
 template<int req>
-struct MPAdamWKernel {
+struct MPAdaBeliefKernel {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mean_data,
     float* var_data, const DType* weight_data, const DType* grad_data, float* weight32,
@@ -107,22 +108,24 @@ struct MPAdamWKernel {
     const float param_rescale_grad, const float param_epsilon) {
     float w = weight32[i];
     float scaled_grad = param_rescale_grad*static_cast<float>(grad_data[i]);
-    if (param_clip_gradient >= 0.0f)
+    scaled_grad += param_wd * w;
+    if (param_clip_gradient >= 0.f)
       scaled_grad = mshadow_op::clip::Map(scaled_grad, param_clip_gradient);
 
-    float mean = mean_data[i] = param_beta1 * mean_data[i] + (1.0f - param_beta1) * scaled_grad;
-    float var = var_data[i] = param_beta2 * var_data[i] +
-                  (1.0f - param_beta2) * mshadow_op::square::Map(scaled_grad);
+    const float mean = param_beta1 * (mean_data[i] - scaled_grad) + scaled_grad;
+    const float adj = mshadow_op::square::Map(scaled_grad - mean);
+    const float var = param_beta2*(var_data[i] - adj) + adj + param_epsilon;
 
-    w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
-                         + param_wd * w);
+    w -= param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon));
+    mean_data[i] = mean;
+    var_data[i] = var;
     weight32[i] = w;
     KERNEL_ASSIGN(out_data[i], req, w);
   }
 };
 
 template<typename xpu>
-struct MPAdamWUpdate {
+struct MPAdaBeliefUpdate {
   static inline void Forward(const nnvm::NodeAttrs& attrs,
                const OpContext &ctx,
                const std::vector<TBlob> &inputs,
@@ -130,7 +133,7 @@ struct MPAdamWUpdate {
                const std::vector<TBlob> &outputs,
                const float rescale_grad) {
     using namespace mxnet_op;
-    const auto& param = nnvm::get<AdamWParam>(attrs.parsed);
+    const auto& param = nnvm::get<AdaBeliefParam>(attrs.parsed);
     Stream<xpu>* s = ctx.get_stream<xpu>();
     MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
@@ -140,19 +143,22 @@ struct MPAdamWUpdate {
       Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
       Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
       MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        Kernel<MPAdamWKernel<req_type>, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mean.dptr_,
-          var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.beta1,
-          param.beta2, param.eta, param.lr, param.wd, rescale_grad, param.epsilon);
+        Kernel<MPAdaBeliefKernel<req_type>, xpu>::Launch(
+            s, weight.shape_.Size(), out.dptr_, mean.dptr_, var.dptr_,
+            weight.dptr_, grad.dptr_, weight32.dptr_,
+            param.clip_gradient, param.beta1, param.beta2, param.eta,
+            param.lr, param.wd, rescale_grad, param.epsilon);
       });
     });
   }
 };
 
 /*
- * \brief adam_w update.
+ * \brief adabelief update.
+ *
  */
 template<typename xpu>
-struct AdamWUpdate {
+struct AdaBeliefUpdate {
   static inline void Forward(const nnvm::NodeAttrs& attrs,
                              const OpContext &ctx,
                              const std::vector<TBlob> &inputs,
@@ -162,7 +168,7 @@ struct AdamWUpdate {
     using namespace mshadow;
     using namespace mshadow::expr;
     using namespace mshadow_op;
-    const auto &param = nnvm::get<AdamWParam>(attrs.parsed);
+    const auto &param = nnvm::get<AdaBeliefParam>(attrs.parsed);
     Stream<xpu>* s = ctx.get_stream<xpu>();
     MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       const Tensor<xpu, 2, DType> &weight = inputs[0].FlatTo2D<xpu, DType>(s);
@@ -171,18 +177,19 @@ struct AdamWUpdate {
       Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
       Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
 
-      grad = scalar<DType>(rescale_grad) * grad;
+      grad = scalar<DType>(rescale_grad) * grad + scalar<DType>(param.wd) * weight;
       if (param.clip_gradient >= 0.0f)
         grad = F<clip>(grad, DType(param.clip_gradient));
 
       mean = scalar<DType>(param.beta1) * mean + scalar<DType>(1.f-param.beta1) * grad;
-      var = scalar<DType>(param.beta2) * var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
+      var = scalar<DType>(param.beta2) * var +
+            scalar<DType>(1.f-param.beta2) * F<square>(grad - mean) +
+            scalar<DType>(param.epsilon);
 
       Assign(out, req[0],
              weight -
              scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
-             mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
-             (scalar<DType>(param.wd) * weight)));
+             mean / (F<square_root>(var) + scalar<DType>(param.epsilon))));
     });
   }
 };
@@ -190,7 +197,7 @@ struct AdamWUpdate {
 ////
 // Multiple gradients in single kernel
 ////
-struct MultiAdamWParam : public dmlc::Parameter<MultiAdamWParam> {
+struct MultiAdaBeliefParam : public dmlc::Parameter<MultiAdaBeliefParam> {
   mxnet::Tuple<float> lrs;
   mxnet::Tuple<float> wds;
   mxnet::Tuple<float> etas;
@@ -199,7 +206,7 @@ struct MultiAdamWParam : public dmlc::Parameter<MultiAdamWParam> {
   float epsilon;
   float clip_gradient;
   int num_weights;
-  DMLC_DECLARE_PARAMETER(MultiAdamWParam) {
+  DMLC_DECLARE_PARAMETER(MultiAdaBeliefParam) {
     DMLC_DECLARE_FIELD(lrs)
     .describe("Learning rates");
     DMLC_DECLARE_FIELD(beta1)
@@ -230,7 +237,7 @@ struct MultiAdamWParam : public dmlc::Parameter<MultiAdamWParam> {
 
 
 template<typename ParamType, int input_stride>
-inline bool MP_MultiAdamW_InferShape(const nnvm::NodeAttrs& attrs,
+inline bool MP_MultiAdaBelief_InferShape(const nnvm::NodeAttrs& attrs,
                                           mxnet::ShapeVector *in_attrs,
                                           mxnet::ShapeVector *out_attrs) {
   const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
@@ -272,7 +279,7 @@ inline bool MP_MultiAdamW_InferShape(const nnvm::NodeAttrs& attrs,
 }
 
 template <typename ParamType, int input_stride, int num_fp32_inputs>
-inline bool MP_MultiAdamW_InferType(const nnvm::NodeAttrs& attrs,
+inline bool MP_MultiAdaBelief_InferType(const nnvm::NodeAttrs& attrs,
                                     std::vector<int> *in_attrs,
                                     std::vector<int> *out_attrs) {
   const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
@@ -312,20 +319,20 @@ inline bool MP_MultiAdamW_InferType(const nnvm::NodeAttrs& attrs,
 
 
 template<typename T>
-class Adam_type_identity {
+class _type_identity {
  public:
   using type = T;
 };
 
 
 template<typename T>
-class Adam_single_precision {
+class _single_precision {
  public:
   using type = float;
 };
 
 template<typename DType, typename MPDType>
-struct MultiAdamKernelParam {
+struct MultiKernelParam {
   static const int N = 50;
   int count;
   size_t max_size;
@@ -346,10 +353,10 @@ struct MultiAdamKernelParam {
 };
 
 template<typename MPDType, bool has_mixed_precision>
-struct MultiMPAdamWKernel {
+struct MultiMPAdaBeliefKernel {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, const MultiAdamKernelParam<DType, MPDType>& param,
-                                  const OpReqType req, const float rescale_grad){
+  MSHADOW_XINLINE static void Map(int i, const MultiKernelParam<DType, MPDType>& param,
+                                  const OpReqType req, const float rescale_grad) {
     for (int index = 0; index < param.count; ++index) {
       if ((size_t)i < param.sizes[index]) {
         MPDType w = has_mixed_precision ? param.weights32[index][i]:
@@ -357,18 +364,18 @@ struct MultiMPAdamWKernel {
         MPDType scaled_grad = static_cast<MPDType>(rescale_grad)*
                               static_cast<MPDType>(param.grad_data[index][i]);
 
-        if (param.clip_gradient >= 0.0f)
+        scaled_grad += param.wds[index] * w;
+        if (param.clip_gradient >= 0.f)
           scaled_grad = mshadow_op::clip::Map(scaled_grad, param.clip_gradient);
 
-        const auto mean = param.beta1 * (param.mean_data[index][i]- scaled_grad) + scaled_grad;
-        const auto adj = mshadow_op::square::Map(scaled_grad);
-        const auto var = param.beta2 * (param.var_data[index][i] - adj) + adj;
+        const auto mean = param.beta1 * (param.mean_data[index][i] - scaled_grad) + scaled_grad;
+        const auto adj = mshadow_op::square::Map(mean - scaled_grad);
+        const auto var = param.beta2 * (param.var_data[index][i] - adj) + adj + param.epsilon;
 
         param.mean_data[index][i] = mean;
         param.var_data[index][i] = var;
         w = w - param.etas[index] * (param.lrs[index] *
-            mean / (mshadow_op::square_root::Map(var) + param.epsilon)
-            + param.wds[index] * w);
+            mean / (mshadow_op::square_root::Map(var) + param.epsilon));
         if (has_mixed_precision)
           param.weights32[index][i] = w;
 
@@ -381,13 +388,13 @@ struct MultiMPAdamWKernel {
 template<typename xpu,
          typename DType,
          typename MPDType,
-         typename ParamType = MultiAdamWParam,
+         typename ParamType = MultiAdaBeliefParam,
          int input_stride = 4>
-void FillMultiAdamKernelParam(const nnvm::NodeAttrs& attrs,
+void FillMultiKernelParam(const nnvm::NodeAttrs& attrs,
                               const OpContext &ctx,
                               const std::vector<TBlob> &inputs,
                               const std::vector<TBlob> &outputs,
-                              MultiAdamKernelParam<DType, MPDType> *pParam) {
+                              MultiKernelParam<DType, MPDType> *pParam) {
   const ParamType& p = nnvm::get<ParamType>(attrs.parsed);
   mxnet_op::Stream<xpu>* s = ctx.get_stream<xpu>();
   pParam->clip_gradient = p.clip_gradient;
@@ -422,7 +429,7 @@ void FillMultiAdamKernelParam(const nnvm::NodeAttrs& attrs,
 }
 
 template<typename xpu, template<typename> class MPTypeChooser, int input_stride>
-static inline void MultiAdamWUpdate(const nnvm::NodeAttrs& attrs,
+static inline void MultiAdaBeliefUpdate(const nnvm::NodeAttrs& attrs,
                                     const OpContext &ctx,
                                     const std::vector<TBlob> &inputs,
                                     const std::vector<OpReqType> &req,
@@ -432,11 +439,11 @@ static inline void MultiAdamWUpdate(const nnvm::NodeAttrs& attrs,
   Stream<xpu>* s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     using MPDType = typename MPTypeChooser<DType>::type;
-    MultiAdamKernelParam<DType, MPDType> param;
-    FillMultiAdamKernelParam<xpu, DType, MPDType, MultiAdamWParam, input_stride>
+    MultiKernelParam<DType, MPDType> param;
+    FillMultiKernelParam<xpu, DType, MPDType, MultiAdaBeliefParam, input_stride>
             (attrs, ctx, inputs, outputs, &param);
 
-    Kernel<MultiMPAdamWKernel<MPDType, !std::is_same<DType, MPDType>::value>, xpu>::
+    Kernel<MultiMPAdaBeliefKernel<MPDType, !std::is_same<DType, MPDType>::value>, xpu>::
                               Launch(s, param.max_size, param, req[0], rescale_grad);
   });
 }
@@ -450,7 +457,7 @@ bool PrepareInputBlobs(const OpContext &ctx,
                        std::vector<TBlob> *inputs_wo_scale,
                        float *pScalef) {
   const size_t num_in = inputs.size() - 1;
-  GetScaleFloat<xpu>(ctx.get_stream<xpu>(), inputs[num_in], pScalef);
+  adabelief::GetScaleFloat<xpu>(ctx.get_stream<xpu>(), inputs[num_in], pScalef);
   if (!std::isfinite(*pScalef) || *pScalef == 0)
     return false;
 
@@ -487,14 +494,15 @@ inline void multiMPUpdate(const nnvm::NodeAttrs& attrs,
     return;
 
   if (!MP)
-    MultiAdamWUpdate<xpu, Adam_type_identity, 4>
+    MultiAdaBeliefUpdate<xpu, _type_identity, 4>
       (attrs, ctx, inputs_wo_scale, req, outputs, scalef);
   else
-    MultiAdamWUpdate<xpu, Adam_single_precision, 5>
+    MultiAdaBeliefUpdate<xpu, _single_precision, 5>
       (attrs, ctx, inputs_wo_scale, req, outputs, scalef);
 }
 
+}  // namespace adabelief
 }  // namespace op
 }  // namespace mxnet
 
-#endif  // MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
+#endif  // MXNET_OPERATOR_CONTRIB_ADABELIEF_INL_H_
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adabelief.cc
similarity index 68%
copy from src/operator/contrib/adamw.cc
copy to src/operator/contrib/adabelief.cc
index effae5c..06be748 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adabelief.cc
@@ -18,54 +18,55 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
- * \file adamw.cc
+ *  Copyright (c) 2021 by Contributors
+ * \file adabelief.cc
  * \brief Optimizer operators
- * \author Haibin Lin, Moises Hernandez, Andrei Ivanov
+ * \author khaotik
  */
-#include "./adamw-inl.h"
+#include "./adabelief-inl.h"
 
 namespace mxnet {
 namespace op {
+namespace adabelief {
 
-DMLC_REGISTER_PARAMETER(AdamWParam);
-DMLC_REGISTER_PARAMETER(MultiAdamWParam);
+DMLC_REGISTER_PARAMETER(AdaBeliefParam);
+DMLC_REGISTER_PARAMETER(MultiAdaBeliefParam);
 
-NNVM_REGISTER_OP(_mp_adamw_update)
-.describe(R"code(Update function for multi-precision AdamW optimizer.
+NNVM_REGISTER_OP(_mp_adabelief_update)
+.describe(R"code(Update function for multi-precision AdaBelief optimizer.
 
-AdamW is seen as a modification of Adam by decoupling the weight decay from the
-optimization steps taken w.r.t. the loss function.
+AdaBelief is a modification of Adam whose 2nd moment estimate tracks (g - m)^2
+instead of g^2; weight decay is added to the gradient (grad += wd * w) first.
 
-Adam update consists of the following steps, where g represents gradient and m, v
+AdaBelief update consists of the following steps, where g represents gradient and m, s
 are 1st and 2nd order moment estimates (mean and variance).
 
 .. math::
 
- g_t = \nabla J(W_{t-1})\\
+ g_t = \nabla J(W_{t-1}) + wd W_{t-1}\\
  m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
- v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
- W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})
+ s_t = \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon\\
+ W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ s_t } + \epsilon })
 
 It updates the weights using::
 
  m = beta1*m + (1-beta1)*grad
- v = beta2*v + (1-beta2)*(grad**2)
- w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
+ s = beta2*s + (1-beta2)*((grad-m)**2) + epsilon
+ w -= eta * (learning_rate * m / (sqrt(s) + epsilon))
 
 Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
 the update is skipped.
 )code" ADD_FILELINE)
 .set_num_inputs(6)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<AdamWParam>)
+.set_attr_parser(ParamParser<AdaBeliefParam>)
 .set_attr<mxnet::FInferShape>("FInferShape", MPUpdateInferShape<2, 1, 6>)
 .set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<2, 1, 6>)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2, 3, 4};
   })
-.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, MPAdamWUpdate<cpu>>)
+.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, MPAdaBeliefUpdate<cpu>>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
@@ -74,41 +75,43 @@ the update is skipped.
 .add_argument("rescale_grad", "NDArray-or-Symbol",
               "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
               "the update is skipped.")
-.add_arguments(AdamWParam::__FIELDS__());
+.add_arguments(AdaBeliefParam::__FIELDS__());
 
-NNVM_REGISTER_OP(_adamw_update)
-.describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of
-Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function.
+NNVM_REGISTER_OP(_adabelief_update)
+.describe(R"code(Update function for AdaBelief optimizer.
 
-Adam update consists of the following steps, where g represents gradient and m, v
+AdaBelief is a modification of Adam whose 2nd moment estimate tracks (g - m)^2
+instead of g^2; weight decay is added to the gradient (grad += wd * w) first.
+
+AdaBelief update consists of the following steps, where g represents gradient and m, s
 are 1st and 2nd order moment estimates (mean and variance).
 
 .. math::
 
- g_t = \nabla J(W_{t-1})\\
+ g_t = \nabla J(W_{t-1}) + wd W_{t-1}\\
  m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
- v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
- W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})
+ s_t = \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon\\
+ W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ s_t } + \epsilon })
 
 It updates the weights using::
 
  m = beta1*m + (1-beta1)*grad
- v = beta2*v + (1-beta2)*(grad**2)
- w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
+ s = beta2*s + (1-beta2)*((grad-m)**2) + epsilon
+ w -= eta * (learning_rate * m / (sqrt(s) + epsilon))
 
 Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
 the update is skipped.
-)code" ADD_FILELINE)
+)code" ADD_FILELINE)
 .set_num_inputs(5)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<AdamWParam>)
+.set_attr_parser(ParamParser<AdaBeliefParam>)
 .set_attr<mxnet::FInferShape>("FInferShape", MPUpdateInferShape<4, 1, 5>)
 .set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<4, 1, 5>)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2, 3};
   })
-.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, AdamWUpdate<cpu>>)
+.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, AdaBeliefUpdate<cpu>>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
@@ -116,7 +119,7 @@ the update is skipped.
 .add_argument("rescale_grad", "NDArray-or-Symbol",
               "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
               "the update is skipped.")
-.add_arguments(AdamWParam::__FIELDS__());
+.add_arguments(AdaBeliefParam::__FIELDS__());
 
 template<>
 void GetScaleFloat<cpu>(mshadow::Stream<cpu> *s, const TBlob &scale_blob, float *pScalef) {
@@ -125,7 +128,8 @@ void GetScaleFloat<cpu>(mshadow::Stream<cpu> *s, const TBlob &scale_blob, float
   )
 }
 
-std::vector<std::string> ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
+static std::vector<std::string>
+ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
   std::vector<std::string> ret;
   for (uint32_t i = 0; i < num_args; ++i) {
     const auto idx = std::to_string(i);
@@ -137,42 +141,42 @@ std::vector<std::string> ParamToVector(uint32_t num_args, const char *pName[], s
 }
 
 inline uint32_t num_weights(const nnvm::NodeAttrs& attrs) {
-  return static_cast<uint32_t>(dmlc::get<MultiAdamWParam>(attrs.parsed).num_weights);
+  return static_cast<uint32_t>(dmlc::get<MultiAdaBeliefParam>(attrs.parsed).num_weights);
 }
 
-NNVM_REGISTER_OP(_multi_adamw_update)
-.describe(R"code(Update function for AdamW optimizer.
+NNVM_REGISTER_OP(_multi_adabelief_update)
+.describe(R"code(Update function for AdaBelief optimizer.
 
-AdamW is seen as a modification of Adam by decoupling the weight decay from the
-optimization steps taken w.r.t. the loss function.
+AdaBelief is a modification of Adam whose 2nd moment estimate tracks (g - m)^2
+instead of g^2; weight decay is added to the gradient (grad += wd * w) first.
 
-Adam update consists of the following steps, where g represents gradient and m, v
+AdaBelief update consists of the following steps, where g represents gradient and m, s
 are 1st and 2nd order moment estimates (mean and variance).
 
 .. math::
 
- g_t = \nabla J(W_{t-1})\\
+ g_t = \nabla J(W_{t-1}) + wd W_{t-1}\\
  m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
- v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
- W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})
+ s_t = \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon\\
+ W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ s_t } + \epsilon })
 
 It updates the weights using::
 
  m = beta1*m + (1-beta1)*grad
- v = beta2*v + (1-beta2)*(grad**2)
- w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
+ s = beta2*s + (1-beta2)*((grad-m)**2) + epsilon
+ w -= eta * (learning_rate * m / (sqrt(s) + epsilon))
 
 Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
 the update is skipped.
-)code" ADD_FILELINE)
+)code" ADD_FILELINE)
 .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
     return num_weights(attrs) * 4 + 1;
   })
 .set_num_outputs([](const nnvm::NodeAttrs& attrs) {
     return num_weights(attrs);
   })
-.set_attr_parser(ParamParser<MultiAdamWParam>)
-.set_attr<mxnet::FInferShape>("FInferShape", MP_MultiAdamW_InferShape<MultiAdamWParam, 4>)
+.set_attr_parser(ParamParser<MultiAdaBeliefParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", MP_MultiAdaBelief_InferShape<MultiAdaBeliefParam, 4>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, -1>)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
@@ -193,43 +197,43 @@ the update is skipped.
 
 .set_attr<FCompute>("FCompute<cpu>", multiMPUpdate<cpu, false>)
 .add_argument("data", "NDArray-or-Symbol[]", "data")
-.add_arguments(MultiAdamWParam::__FIELDS__());
+.add_arguments(MultiAdaBeliefParam::__FIELDS__());
 
 
-NNVM_REGISTER_OP(_multi_mp_adamw_update)
-.describe(R"code(Update function for multi-precision AdamW optimizer.
+NNVM_REGISTER_OP(_multi_mp_adabelief_update)
+.describe(R"code(Update function for multi-precision AdaBelief optimizer.
 
-AdamW is seen as a modification of Adam by decoupling the weight decay from the
-optimization steps taken w.r.t. the loss function.
+AdaBelief is a modification of Adam whose 2nd moment estimate tracks (g - m)^2
+instead of g^2; weight decay is added to the gradient (grad += wd * w) first.
 
-Adam update consists of the following steps, where g represents gradient and m, v
+AdaBelief update consists of the following steps, where g represents gradient and m, s
 are 1st and 2nd order moment estimates (mean and variance).
 
 .. math::
 
- g_t = \nabla J(W_{t-1})\\
+ g_t = \nabla J(W_{t-1}) + wd W_{t-1}\\
  m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
- v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
- W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})
+ s_t = \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon\\
+ W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ s_t } + \epsilon })
 
 It updates the weights using::
 
  m = beta1*m + (1-beta1)*grad
- v = beta2*v + (1-beta2)*(grad**2)
- w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
+ s = beta2*s + (1-beta2)*((grad-m)**2) + epsilon
+ w -= eta * (learning_rate * m / (sqrt(s) + epsilon))
 
 Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
 the update is skipped.
-)code" ADD_FILELINE)
+)code" ADD_FILELINE)
 .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
     return num_weights(attrs) * 5 + 1;
   })
 .set_num_outputs([](const nnvm::NodeAttrs& attrs) {
     return num_weights(attrs);
   })
-.set_attr_parser(ParamParser<MultiAdamWParam>)
-.set_attr<mxnet::FInferShape>("FInferShape", MP_MultiAdamW_InferShape<MultiAdamWParam, 5>)
-.set_attr<nnvm::FInferType>("FInferType", MP_MultiAdamW_InferType<MultiAdamWParam, 5, 1>)
+.set_attr_parser(ParamParser<MultiAdaBeliefParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", MP_MultiAdaBelief_InferShape<MultiAdaBeliefParam, 5>)
+.set_attr<nnvm::FInferType>("FInferType", MP_MultiAdaBelief_InferType<MultiAdaBeliefParam, 5, 1>)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
     const char *paramName[] = {"weight_", "grad_", "mean_", "var_", "weight32_", "rescale_grad_"};
@@ -250,8 +254,8 @@ the update is skipped.
 
 .set_attr<FCompute>("FCompute<cpu>", multiMPUpdate<cpu, true>)
 .add_argument("data", "NDArray-or-Symbol[]", "data")
-.add_arguments(MultiAdamWParam::__FIELDS__());
-
+.add_arguments(MultiAdaBeliefParam::__FIELDS__());
 
+}  // namespace adabelief
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adabelief.cu
similarity index 67%
copy from src/operator/contrib/adamw.cu
copy to src/operator/contrib/adabelief.cu
index 2b0040e..e64dcb4 100644
--- a/src/operator/contrib/adamw.cu
+++ b/src/operator/contrib/adabelief.cu
@@ -18,16 +18,16 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
- * \file adamw.cu
+ *  Copyright (c) 2021 by Contributors
+ * \file adabelief.cu
  * \brief Optimizer operators
- * \author Haibin Lin, Moises Hernandez, Andrei Ivanov
+ * \author khaotik
  */
-#include "./adamw-inl.h"
+#include "./adabelief-inl.h"
 
 namespace mxnet {
 namespace op {
-
+namespace adabelief {
 template<>
 void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float *pScalef) {
   MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
@@ -39,18 +39,19 @@ void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float
     *pScalef = static_cast<float>(scale);
   })
 }
+}  // namespace adabelief
 
-NNVM_REGISTER_OP(_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", MPUpdate<gpu, AdamWUpdate<gpu>>);
+NNVM_REGISTER_OP(_adabelief_update)
+.set_attr<FCompute>("FCompute<gpu>", adabelief::MPUpdate<gpu, adabelief::AdaBeliefUpdate<gpu>>);
 
-NNVM_REGISTER_OP(_mp_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", MPUpdate<gpu, MPAdamWUpdate<gpu>>);
+NNVM_REGISTER_OP(_mp_adabelief_update)
+.set_attr<FCompute>("FCompute<gpu>", adabelief::MPUpdate<gpu, adabelief::MPAdaBeliefUpdate<gpu>>);
 
-NNVM_REGISTER_OP(_multi_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", multiMPUpdate<gpu, false>);
+NNVM_REGISTER_OP(_multi_adabelief_update)
+.set_attr<FCompute>("FCompute<gpu>", adabelief::multiMPUpdate<gpu, false>);
 
-NNVM_REGISTER_OP(_multi_mp_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", multiMPUpdate<gpu, true>);
+NNVM_REGISTER_OP(_multi_mp_adabelief_update)
+.set_attr<FCompute>("FCompute<gpu>", adabelief::multiMPUpdate<gpu, true>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h
index 6f48333..56c5ea2 100644
--- a/src/operator/contrib/adamw-inl.h
+++ b/src/operator/contrib/adamw-inl.h
@@ -32,6 +32,7 @@
 
 namespace mxnet {
 namespace op {
+namespace adamw {
 
 struct AdamWParam : public dmlc::Parameter<AdamWParam> {
   float lr;
@@ -114,7 +115,7 @@ struct MPAdamWKernel {
     float var = var_data[i] = param_beta2 * var_data[i] +
                   (1.0f - param_beta2) * mshadow_op::square::Map(scaled_grad);
 
-    w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
+    w -= param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
                          + param_wd * w);
     weight32[i] = w;
     KERNEL_ASSIGN(out_data[i], req, w);
@@ -349,7 +350,7 @@ template<typename MPDType, bool has_mixed_precision>
 struct MultiMPAdamWKernel {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, const MultiAdamKernelParam<DType, MPDType>& param,
-                                  const OpReqType req, const float rescale_grad){
+                                  const OpReqType req, const float rescale_grad) {
     for (int index = 0; index < param.count; ++index) {
       if ((size_t)i < param.sizes[index]) {
         MPDType w = has_mixed_precision ? param.weights32[index][i]:
@@ -442,7 +443,7 @@ static inline void MultiAdamWUpdate(const nnvm::NodeAttrs& attrs,
 }
 
 template<typename xpu>
-void GetScaleFloat(mshadow::Stream<xpu> *s, const TBlob &scale_blob, float *pScalef);
+static void GetScaleFloat(mshadow::Stream<xpu> *s, const TBlob &scale_blob, float *pScalef);
 
 template<typename xpu>
 bool PrepareInputBlobs(const OpContext &ctx,
@@ -450,7 +451,7 @@ bool PrepareInputBlobs(const OpContext &ctx,
                        std::vector<TBlob> *inputs_wo_scale,
                        float *pScalef) {
   const size_t num_in = inputs.size() - 1;
-  GetScaleFloat<xpu>(ctx.get_stream<xpu>(), inputs[num_in], pScalef);
+  adamw::GetScaleFloat<xpu>(ctx.get_stream<xpu>(), inputs[num_in], pScalef);
   if (!std::isfinite(*pScalef) || *pScalef == 0)
     return false;
 
@@ -494,6 +495,7 @@ inline void multiMPUpdate(const nnvm::NodeAttrs& attrs,
       (attrs, ctx, inputs_wo_scale, req, outputs, scalef);
 }
 
+}  // namespace adamw
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc
index effae5c..e662502 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adamw.cc
@@ -27,6 +27,7 @@
 
 namespace mxnet {
 namespace op {
+namespace adamw {
 
 DMLC_REGISTER_PARAMETER(AdamWParam);
 DMLC_REGISTER_PARAMETER(MultiAdamWParam);
@@ -65,7 +66,7 @@ the update is skipped.
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2, 3, 4};
   })
-.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, MPAdamWUpdate<cpu>>)
+.set_attr<FCompute>("FCompute<cpu>", adamw::MPUpdate<cpu, MPAdamWUpdate<cpu>>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
@@ -108,7 +109,7 @@ the update is skipped.
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2, 3};
   })
-.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, AdamWUpdate<cpu>>)
+.set_attr<FCompute>("FCompute<cpu>", adamw::MPUpdate<cpu, AdamWUpdate<cpu>>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
@@ -125,7 +126,8 @@ void GetScaleFloat<cpu>(mshadow::Stream<cpu> *s, const TBlob &scale_blob, float
   )
 }
 
-std::vector<std::string> ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
+static std::vector<std::string>
+ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
   std::vector<std::string> ret;
   for (uint32_t i = 0; i < num_args; ++i) {
     const auto idx = std::to_string(i);
@@ -191,7 +193,7 @@ the update is skipped.
     return ret;
   })
 
-.set_attr<FCompute>("FCompute<cpu>", multiMPUpdate<cpu, false>)
+.set_attr<FCompute>("FCompute<cpu>", adamw::multiMPUpdate<cpu, false>)
 .add_argument("data", "NDArray-or-Symbol[]", "data")
 .add_arguments(MultiAdamWParam::__FIELDS__());
 
@@ -248,10 +250,11 @@ the update is skipped.
     return ret;
   })
 
-.set_attr<FCompute>("FCompute<cpu>", multiMPUpdate<cpu, true>)
+.set_attr<FCompute>("FCompute<cpu>", adamw::multiMPUpdate<cpu, true>)
 .add_argument("data", "NDArray-or-Symbol[]", "data")
 .add_arguments(MultiAdamWParam::__FIELDS__());
 
 
+}  // namespace adamw
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu
index 2b0040e..95fcffb 100644
--- a/src/operator/contrib/adamw.cu
+++ b/src/operator/contrib/adamw.cu
@@ -27,6 +27,7 @@
 
 namespace mxnet {
 namespace op {
+namespace adamw {
 
 template<>
 void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float *pScalef) {
@@ -41,16 +42,17 @@ void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float
 }
 
 NNVM_REGISTER_OP(_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", MPUpdate<gpu, AdamWUpdate<gpu>>);
+.set_attr<FCompute>("FCompute<gpu>", adamw::MPUpdate<gpu, AdamWUpdate<gpu>>);
 
 NNVM_REGISTER_OP(_mp_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", MPUpdate<gpu, MPAdamWUpdate<gpu>>);
+.set_attr<FCompute>("FCompute<gpu>", adamw::MPUpdate<gpu, MPAdamWUpdate<gpu>>);
 
 NNVM_REGISTER_OP(_multi_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", multiMPUpdate<gpu, false>);
+.set_attr<FCompute>("FCompute<gpu>", adamw::multiMPUpdate<gpu, false>);
 
 NNVM_REGISTER_OP(_multi_mp_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", multiMPUpdate<gpu, true>);
+.set_attr<FCompute>("FCompute<gpu>", adamw::multiMPUpdate<gpu, true>);
 
+}  // namespace adamw
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_contrib_optimizer.py b/tests/python/unittest/test_contrib_optimizer.py
index f0fbb7b..b6f624f 100644
--- a/tests/python/unittest/test_contrib_optimizer.py
+++ b/tests/python/unittest/test_contrib_optimizer.py
@@ -61,27 +61,27 @@ def test_group_adagrad():
                 dtype,
                 g_stype='row_sparse')
 
-
-@xfail_when_nonstandard_decimal_separator
-@pytest.mark.serial
-def test_adamw():
-    def get_refs(m, v, weight, grad_rescale, beta1, beta2, lr, eta, wd, epsilon, clip_grad=-1):
-        if clip_grad >= 0:
-            grad_rescale = mx.nd.clip(grad_rescale, -clip_grad, clip_grad)
-
-        mean_ref = beta1*m + (1-beta1)*grad_rescale
-        v_ref = beta2*v + (1-beta2)*(grad_rescale**2)
-        weight_ref = weight - eta * (lr * mean_ref / (v_ref.sqrt() + epsilon) + weight * wd)
-        return mean_ref, v_ref, weight_ref
-
-    def run_adamw_test(nElem=1, aggregate=False):
-        aggregate = aggregate or nElem > 1
+def _fn_noimpl(*args, **kwargs):
+    raise NotImplementedError()
+
+class _AdamLikeTestHelper:
+    fn_update = _fn_noimpl
+    fn_multi_update = _fn_noimpl
+    fn_mp_update = _fn_noimpl
+    fn_multi_mp_update = _fn_noimpl
+    @staticmethod
+    def ref_impl(m, v, weight, grad_rescale, beta1, beta2, lr, eta, wd, epsilon, clip_grad=-1):
+        '''Returns (mean_ref, v_ref, weight_ref)'''
+        raise NotImplementedError()
+    @classmethod
+    def run_test(cls, num_elem=1, aggregate=False):
+        aggregate = aggregate or num_elem > 1
         rescale_factor = 10
         eta, lr, wd, epsilon = 1, 1, 0.1, 1e-8
         beta1, beta2 = 0.9, 0.999
         clip_gradient = np.random.uniform(rescale_factor, rescale_factor)
         weight, grad, m, v, etas, lrs, wds, weight_ref = [], [], [], [], [], [], [], []
-        for i in range(nElem):
+        for i in range(num_elem):
             shape = (np.random.randint(3, high=10), np.random.randint(3, high=10))
             weight.append(mx.nd.random.uniform(shape=shape))
             grad.append(mx.nd.random.uniform(-1.0, 1.0, shape=shape))
@@ -107,95 +107,130 @@ def test_adamw():
 
         for rescaled_grad in tested_rescaled_grad:
             if aggregate:
-                mx.nd.contrib.multi_adamw_update(weight, grad, m, v,
-                                                 rescaled_grad, out=weight, **kwargs)
+                cls.fn_multi_update(weight, grad, m, v,
+                                     rescaled_grad, out=weight, **kwargs)
             else:
-                mx.nd.contrib.adamw_update(weight[0], grad[0], m[0], v[0],
-                                           rescaled_grad, out=weight[0], **kwargs)
-
+                cls.fn_update(weight[0], grad[0], m[0], v[0],
+                               rescaled_grad, out=weight[0], **kwargs)
             # weights should remain unchanged
-            for j in range(nElem):
+            for j in range(num_elem):
                 assert_almost_equal(weight_ref[j], weight[j])
 
-
         # Test 2: Same as Test 1 for multi-precision update
         weight_fp16, grad_fp16, weight_fp16_refs = [], [], []
-        for i in range(nElem):
+        for i in range(num_elem):
             weight_fp16.append(weight[i].astype('float16'))
             grad_fp16.append(grad[i].astype('float16'))
             weight_fp16_refs.append(weight_fp16[i].copy())
 
         for rescaled_grad in tested_grad:
             if aggregate:
-                mx.nd.contrib.multi_mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
-                                                    rescaled_grad, out=weight_fp16, **kwargs)
+                cls.fn_multi_mp_update(weight_fp16, grad_fp16, m, v, weight,
+                                       rescaled_grad, out=weight_fp16, **kwargs)
             else:
-                mx.nd.contrib.mp_adamw_update(weight_fp16[0], grad_fp16[0], m[0], v[0], weight[0],
-                                              rescaled_grad, out=weight_fp16[0], **kwargs)
-
+                cls.fn_mp_update(weight_fp16[0], grad_fp16[0], m[0], v[0], weight[0],
+                                 rescaled_grad, out=weight_fp16[0], **kwargs)
             # weights should remain unchanged
-            for i in range(nElem):
+            for i in range(num_elem):
                 assert_almost_equal(weight_ref[i], weight[i])
                 assert_almost_equal(weight_fp16_refs[i], weight_fp16[i])
 
-
         # Test 3: Reference normal update
         grad_rescale, weight_test, m_refs, v_refs, weight_refs = [], [], [], [], []
-        for i in range(nElem):
+        for i in range(num_elem):
             grad_rescale.append(rescale_grad * grad[i])
-            m_ref, v_ref, weight_ref = get_refs(m[i], v[i], weight[i], grad_rescale[i], beta1, beta2, lrs[i], etas[i], wds[i], epsilon, clip_gradient)
+            m_ref, v_ref, weight_ref = cls.ref_impl(
+                m[i], v[i], weight[i], grad_rescale[i],
+                beta1, beta2, lrs[i], etas[i], wds[i], epsilon, clip_gradient)
             m_refs.append(m_ref)
             v_refs.append(v_ref)
             weight_refs.append(weight_ref)
             weight_test.append(weight[i].copy())
-
         # op normal update
         if aggregate:
-            mx.nd.contrib.multi_adamw_update(weight_test, grad, m, v,
-                                             rescale_grad, out=weight_test, **kwargs)
+            cls.fn_multi_update(weight_test, grad, m, v,
+                                rescale_grad, out=weight_test, **kwargs)
         else:
-            mx.nd.contrib.adamw_update(weight_test[0], grad[0], m[0], v[0],
-                                       rescale_grad, out=weight_test[0], **kwargs)
-
+            cls.fn_update(weight_test[0], grad[0], m[0], v[0],
+                          rescale_grad, out=weight_test[0], **kwargs)
         # Compare results
         atol = 1e-4 if aggregate else 1e-5
         rtol = 1e-4 if aggregate else None
-        for i in range(nElem):
+        for i in range(num_elem):
             assert_almost_equal(weight_refs[i], weight_test[i], rtol=rtol, atol=atol)
             assert_almost_equal(m_refs[i], m[i], rtol=rtol, atol=atol)
             assert_almost_equal(v_refs[i], v[i], atol=atol)
 
-
         # Test 4: Reference normal multi-precision update
         grad_rescale, m_refs, v_refs, weight_refs, weight_fp16_refs = [], [], [], [], []
-        for i in range(nElem):
+        for i in range(num_elem):
             grad_rescale.append(rescale_grad * grad_fp16[i].astype('float32'))
-            m_ref, v_ref, weight_ref = get_refs(m[i], v[i], weight[i], grad_rescale[i], beta1, beta2, lrs[i], etas[i], wds[i], epsilon, clip_gradient)
+            m_ref, v_ref, weight_ref = cls.ref_impl(
+                m[i], v[i], weight[i], grad_rescale[i],
+                beta1, beta2, lrs[i], etas[i], wds[i], epsilon, clip_gradient)
             m_refs.append(m_ref)
             v_refs.append(v_ref)
             weight_refs.append(weight_ref)
             weight_fp16_refs.append(weight_ref.astype('float16'))
-
         # op normal multi-precision update
         if aggregate:
-            mx.nd.contrib.multi_mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
-                                                rescale_grad, out=weight_fp16, **kwargs)
+            cls.fn_multi_mp_update(weight_fp16, grad_fp16, m, v, weight,
+                                   rescale_grad, out=weight_fp16, **kwargs)
         else:
-            mx.nd.contrib.mp_adamw_update(weight_fp16[0], grad_fp16[0], m[0], v[0], weight[0],
-                                          rescale_grad, out=weight_fp16[0], **kwargs)
-
+            cls.fn_mp_update(weight_fp16[0], grad_fp16[0], m[0], v[0], weight[0],
+                             rescale_grad, out=weight_fp16[0], **kwargs)
         # Compare results
-        for i in range(nElem):
+        for i in range(num_elem):
             assert_almost_equal(m_refs[i], m[i], rtol=rtol, atol=atol)
             assert_almost_equal(v_refs[i], v[i], atol=atol)
             assert_almost_equal(weight_refs[i], weight[i], rtol=rtol, atol=atol)
             assert_almost_equal(weight_fp16_refs[i], weight_fp16[i], rtol=1e-3, atol=atol)
 
-    # Testing aggregated Adam update for one element
-    run_adamw_test(1, aggregate=True)
+    def __call__(self):
+        # Testing aggregated Adam update for one element
+        self.run_test(1, aggregate=True)
+        # Testing Adam update, if num_elem == 0, OR
+        #         aggregated Adam update, if num_elem > 0
+        for num_elem in reversed(range(6)):
+            self.run_test(num_elem+1)
+
+class _AdamWTestHelper(_AdamLikeTestHelper):
+    fn_update = mx.nd.contrib.adamw_update
+    fn_multi_update = mx.nd.contrib.multi_adamw_update
+    fn_mp_update = mx.nd.contrib.mp_adamw_update
+    fn_multi_mp_update = mx.nd.contrib.multi_mp_adamw_update
+    @staticmethod
+    def ref_impl(m, v, weight, grad_rescale, beta1, beta2, lr, eta, wd, epsilon, clip_grad=-1):
+        if clip_grad >= 0:
+            grad_rescale = mx.nd.clip(grad_rescale, -clip_grad, clip_grad)
 
-    # Testing Adam update, if nElem = 0, OR
-    #         aggregated Adam update, if nElem > 0
-    for nElem in range(6):
-        run_adamw_test(nElem+1)
+        mean_ref = beta1*m + (1.-beta1)*grad_rescale
+        v_ref = beta2*v + (1.-beta2)*(grad_rescale**2)
+        weight_ref = weight - eta * (lr * mean_ref / (v_ref.sqrt() + epsilon) + weight * wd)
+        return mean_ref, v_ref, weight_ref
+
+class _AdaBeliefTestHelper(_AdamLikeTestHelper):
+    fn_update = mx.nd.contrib.adabelief_update
+    fn_multi_update = mx.nd.contrib.multi_adabelief_update
+    fn_mp_update = mx.nd.contrib.mp_adabelief_update
+    fn_multi_mp_update = mx.nd.contrib.multi_mp_adabelief_update
+    @staticmethod
+    def ref_impl(m, v, weight, grad_rescale, beta1, beta2, lr, eta, wd, epsilon, clip_grad=-1):
+        grad_rescale += wd * weight
+        if clip_grad >= 0:
+            grad_rescale = mx.nd.clip(grad_rescale, -clip_grad, clip_grad)
 
+        mean_ref = beta1*m + (1.-beta1)*grad_rescale
+        v_ref = beta2*v + (1.-beta2)*((grad_rescale-mean_ref)**2) + epsilon
+        weight_ref = weight - eta * (lr * mean_ref / (v_ref.sqrt() + epsilon))
+        return mean_ref, v_ref, weight_ref
+
+@xfail_when_nonstandard_decimal_separator
+@pytest.mark.serial
+def test_adamw():
+    _AdamWTestHelper()()
+
+@xfail_when_nonstandard_decimal_separator
+@pytest.mark.serial
+def test_adabelief():
+    _AdaBeliefTestHelper()()
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 8927bcd..7ccb8f1 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -923,6 +923,31 @@ def test_adamW():
                               opt2(use_fused_step=True, **kwarg), shapes, dtype,
                               rtol=1e-3, atol=2e-3)
 
+def test_adabelief():
+    opt1 = mx.optimizer.AdaBelief
+    opt2 = mx.optimizer.AdaBelief
+    shapes = [(3, 4, 5), (10, 4), (7,)]
+    beta1_options = [{}, {'beta1': 0.5}, {'beta1': 0.7}]
+    beta2_options = [{}, {'beta2': 0.8}, {'beta2': 0.9}]
+    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+    mp_options = [{'multi_precision': False}, {'multi_precision': True}]
+    agg_options = [{'aggregate_num': 0}, {'aggregate_num': 1},
+                   {'aggregate_num': 4}, {'aggregate_num': np.inf}]
+    correct_bias_options = [{'correct_bias': True}, {'correct_bias': False}]
+    for dtype in [np.float16, np.float32]:
+        for params in itertools.product(beta1_options, beta2_options, cg_options,
+                                        rg_options, wd_options, mp_options,
+                                        agg_options, correct_bias_options):
+            kwarg = {k: v for param in params for k, v in param.items()}
+            if (dtype == np.float16 and ('multi_precision' not in kwarg or
+                                         not kwarg['multi_precision'])):
+                continue
+            compare_optimizer(opt1(use_fused_step=False, **kwarg),
+                              opt2(use_fused_step=True, **kwarg), shapes, dtype,
+                              rtol=1e-3, atol=2e-3)
+
 def test_factor_scheduler():
     base_lr = 1
     step = 100