Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/04/30 14:47:26 UTC
[incubator-mxnet] branch master updated: [FEATURE] AdaBelief operator (#20065)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 059a055 [FEATURE] AdaBelief operator (#20065)
059a055 is described below
commit 059a05549e7892007a9490e9ca2987d22ddec816
Author: khaotik <kh...@users.noreply.github.com>
AuthorDate: Fri Apr 30 22:46:04 2021 +0800
[FEATURE] AdaBelief operator (#20065)
* copycat from adamw to adabelief
* fix py lint
* fix py lint #2
* fix cpp lint
* add adabelief to amp list
---
python/mxnet/amp/lists/symbol_fp16.py | 4 +
python/mxnet/ndarray/contrib.py | 51 +++++
python/mxnet/optimizer/__init__.py | 12 +-
python/mxnet/optimizer/adabelief.py | 231 +++++++++++++++++++++
.../contrib/{adamw-inl.h => adabelief-inl.h} | 114 +++++-----
src/operator/contrib/{adamw.cc => adabelief.cc} | 132 ++++++------
src/operator/contrib/{adamw.cu => adabelief.cu} | 27 +--
src/operator/contrib/adamw-inl.h | 10 +-
src/operator/contrib/adamw.cc | 13 +-
src/operator/contrib/adamw.cu | 10 +-
tests/python/unittest/test_contrib_optimizer.py | 147 ++++++++-----
tests/python/unittest/test_optimizer.py | 25 +++
12 files changed, 573 insertions(+), 203 deletions(-)
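
For context, the new optimizer is registered under mx.optimizer and can be selected by name like the existing ones. A minimal usage sketch (illustrative only; the Gluon Trainer wiring and the hyper-parameter values below are assumptions, not part of this commit):

    import mxnet as mx
    from mxnet import gluon

    net = gluon.nn.Dense(1)
    net.initialize()
    # '@register' registers the class under its lowercase name, so 'adabelief'
    # should select the new mx.optimizer.AdaBelief
    trainer = gluon.Trainer(net.collect_params(), 'adabelief',
                            {'learning_rate': 1e-3, 'beta1': 0.9, 'beta2': 0.999,
                             'epsilon': 1e-6, 'correct_bias': True})
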
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index f78d32d..6359384 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -82,6 +82,7 @@ FP16_FP32_FUNCS = [
'_FusedOpHelper',
'_FusedOpOutHelper',
'_NoGradient',
+ '_adabelief_update',
'_adamw_update',
'_arange',
'_cond',
@@ -153,11 +154,14 @@ FP16_FP32_FUNCS = [
'_minimum_scalar',
'_minus_scalar',
'_mod_scalar',
+ '_mp_adabelief_update',
'_mp_adamw_update',
'_mul_scalar',
+ '_multi_adabelief_update',
'_multi_adamw_update',
'_multi_lamb_update',
'_multi_lans_update',
+ '_multi_mp_adabelief_update',
'_multi_mp_adamw_update',
'_multi_mp_lamb_update',
'_multi_mp_lans_update',
diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index 204d694..ed70f8c 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -679,6 +679,57 @@ def multi_mp_lamb_update(weights, grads, mean, var, weights32, step_count,
wds=wds,
**kwargs)
+def adabelief_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
+ epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
+ rescale_grad = _get_rescale_grad(rescale_grad, ctx=weight.context)
+ return ndarray._internal._adabelief_update(weight=weight, grad=grad, mean=mean, var=var,
+ rescale_grad=rescale_grad, lr=lr, eta=eta,
+ beta1=beta1, beta2=beta2, epsilon=epsilon,
+ wd=wd, clip_gradient=clip_gradient, out=out,
+ name=name, **kwargs)
+
+def mp_adabelief_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
+ beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
+ name=None, **kwargs):
+ rescale_grad = _get_rescale_grad(rescale_grad, ctx=weight.context)
+ return ndarray._internal._mp_adabelief_update(weight=weight, grad=grad, mean=mean, var=var,
+ weight32=weight32,
+ rescale_grad=rescale_grad, lr=lr, eta=eta,
+ beta1=beta1, beta2=beta2, epsilon=epsilon,
+ wd=wd, clip_gradient=clip_gradient, out=out,
+ name=name, **kwargs)
+
+def multi_adabelief_update(weights, grads, mean, var, rescale_grad, lrs, wds, etas,
+ out=None, name=None, size=0, **kwargs):
+ if not size:
+ size = len(weights)
+
+ rescale_grad = _get_rescale_grad(rescale_grad, ctx=weights[0].context)
+ temp_list = _flatten_list(zip(weights, grads, mean, var)) + [rescale_grad]
+ return ndarray._internal._multi_adabelief_update(*temp_list,
+ out=out,
+ num_weights=size,
+ lrs=lrs,
+ wds=wds,
+ etas=etas,
+ name=name,
+ **kwargs)
+
+def multi_mp_adabelief_update(weights, grads, mean, var, weights32, rescale_grad, lrs, wds, etas,
+ out=None, name=None, size=0, **kwargs):
+ if not size:
+ size = len(weights)
+
+ rescale_grad = _get_rescale_grad(rescale_grad, ctx=weights[0].context)
+ temp_list = _flatten_list(zip(weights, grads, mean, var, weights32)) + [rescale_grad]
+ return ndarray._internal._multi_mp_adabelief_update(*temp_list,
+ out=out,
+ num_weights=size,
+ lrs=lrs,
+ wds=wds,
+ etas=etas,
+ name=name,
+ **kwargs)
def multi_lans_update(weights, grads, mean, var, step_count,
lrs, wds, out=None, num_tensors=0, **kwargs):
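
A quick sketch of driving the new NDArray-level wrapper directly, in the same spirit as the unit tests further below (shapes and hyper-parameter values are illustrative; note that, as in the fused optimizer path, lr is set to 1 and the effective step size is passed as eta):

    import mxnet as mx

    shape = (3, 4)
    weight = mx.nd.random.uniform(shape=shape)
    grad = mx.nd.random.uniform(-1.0, 1.0, shape=shape)
    mean = mx.nd.zeros(shape)
    var = mx.nd.zeros(shape)
    # rescale_grad may be a float or an NDArray; the wrapper converts it to an
    # NDArray on the weight's context before dispatching to the fused kernel
    mx.nd.contrib.adabelief_update(weight, grad, mean, var, rescale_grad=1.0,
                                   lr=1.0, eta=1e-3, wd=0.0, out=weight)
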
diff --git a/python/mxnet/optimizer/__init__.py b/python/mxnet/optimizer/__init__.py
index 9bf0c1f..fba34a3 100644
--- a/python/mxnet/optimizer/__init__.py
+++ b/python/mxnet/optimizer/__init__.py
@@ -19,8 +19,11 @@
from . import (optimizer, contrib, updater, utils, sgd,
sgld, signum, dcasgd, nag, adagrad,
adadelta, adam, adamax, nadam, ftrl,
- ftml, lars, lamb, rmsprop, lans, adamW)
+ ftml, lars, lamb, rmsprop, lans, adamW,
+ adabelief)
# pylint: disable=wildcard-import
+from .adabelief import *
+
from .adamW import *
from .optimizer import *
@@ -62,6 +65,7 @@ from .rmsprop import *
from .lans import *
__all__ = optimizer.__all__ + updater.__all__ + ['contrib'] + sgd.__all__ + sgld.__all__ \
- + signum.__all__ + dcasgd.__all__ + nag.__all__ + adagrad.__all__ + adadelta.__all__ \
- + adam.__all__ + adamax.__all__ + nadam.__all__ + ftrl.__all__ + ftml.__all__ \
- + lars.__all__ + lamb.__all__ + rmsprop.__all__ + lans.__all__
+ + signum.__all__ + dcasgd.__all__ + nag.__all__ + adabelief.__all__ \
+ + adagrad.__all__ + adadelta.__all__ + adam.__all__ + adamax.__all__ \
+ + nadam.__all__ + ftrl.__all__ + ftml.__all__ + lars.__all__ \
+ + lamb.__all__ + rmsprop.__all__ + lans.__all__
diff --git a/python/mxnet/optimizer/adabelief.py b/python/mxnet/optimizer/adabelief.py
new file mode 100644
index 0000000..c224ebf
--- /dev/null
+++ b/python/mxnet/optimizer/adabelief.py
@@ -0,0 +1,231 @@
+# coding: utf-8
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""AdaBelief optimizer."""
+import math
+import os
+import numpy as np
+from .optimizer import Optimizer, register
+from ..ndarray import (zeros, clip, sqrt, square, full, NDArray)
+from ..ndarray.contrib import mp_adabelief_update, adabelief_update,\
+ multi_mp_adabelief_update, multi_adabelief_update
+
+
+__all__ = ['AdaBelief']
+
+
+@register
+class AdaBelief(Optimizer):
+ """The AdaBelief optimizer.
+
+ This class implements the optimizer described in *AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients*,
+ available at https://arxiv.org/pdf/2010.07468.pdf.
+
+ Updates are applied by::
+
+ grad = clip(grad * rescale_grad, clip_gradient) + wd * w
+ m = beta1 * m + (1 - beta1) * grad
+ s = beta2 * s + (1 - beta2) * ((grad - m)**2) + epsilon
+ lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t)
+ w = w - lr * (m / (sqrt(s) + epsilon))
+
+
+ The bias correction term can also be turned off, in which case the updates are::
+
+ grad = clip(grad * rescale_grad, clip_gradient) + wd * w
+ m = beta1 * m + (1 - beta1) * grad
+ s = beta2 * s + (1 - beta2) * ((grad - m)**2) + epsilon
+ lr = learning_rate
+ w = w - lr * (m / (sqrt(s) + epsilon))
+
+ This optimizer accepts the following parameters in addition to those accepted
+ by :class:`.Optimizer`.
+
+ Parameters
+ ----------
+ learning_rate : float, default 0.001
+ The initial learning rate. If None, the optimization will use the
+ learning rate from ``lr_scheduler``. If not None, it will overwrite
+ the learning rate in ``lr_scheduler``. If None and ``lr_scheduler``
+ is also None, then it will be set to 0.01 by default.
+ beta1 : float, default 0.9
+ Exponential decay rate for the first moment estimates.
+ beta2 : float, default 0.999
+ Exponential decay rate for the second moment estimates.
+ epsilon : float, default 1e-6
+ Small value to avoid division by 0.
+ correct_bias : bool, default True
+ Can be set to False to skip the bias correction used in Adam
+ (e.g., as in the BERT TF repository).
+ use_fused_step : bool, default True
+ Whether or not to use fused kernels for the optimizer.
+ When use_fused_step=False, step is called;
+ otherwise, fused_step is called.
+ """
+ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
+ correct_bias=True, use_fused_step=True, **kwargs):
+ super().__init__(use_fused_step=use_fused_step,
+ learning_rate=learning_rate,
+ **kwargs)
+ self.beta1 = beta1
+ self.beta2 = beta2
+ self.epsilon = epsilon
+ self.correct_bias = correct_bias
+ self.aggregate_num = max(1, min(50,
+ int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', '4'))))
+
+ def create_state(self, index, weight):
+ """state creation function."""
+ return (zeros(weight.shape, weight.context, dtype=weight.dtype), # mean
+ zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance
+
+ def step(self, indices, weights, grads, states):
+ """Perform an optimization step using gradients and states.
+
+ Parameters
+ ----------
+ indices : list of int
+ List of unique indices of the parameters into the individual learning rates
+ and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
+ and `set_wd_mult()`, respectively.
+ weights : list of NDArray
+ List of parameters to be updated.
+ grads : list of NDArray
+ List of gradients of the objective with respect to this parameter.
+ states : List of any obj
+ List of state returned by `create_state()`.
+ """
+ for index, weight, grad, state in zip(indices, weights, grads, states):
+ self._update_count(index)
+ lr = self._get_lr(index)
+ wd = self._get_wd(index)
+ eps = self.epsilon
+ t = self._index_update_count[index]
+
+ # preprocess grad
+ grad *= self.rescale_grad
+ grad += wd * weight
+ if self.clip_gradient is not None:
+ grad = clip(grad, -self.clip_gradient, self.clip_gradient)
+ if self.correct_bias:
+ coef1 = 1. - self.beta1**t
+ coef2 = 1. - self.beta2**t
+ lr *= math.sqrt(coef2) / coef1
+
+ # update mean and var
+ mean, var = state
+ mean[:] *= self.beta1
+ mean[:] += (1. - self.beta1) * grad
+ var[:] *= self.beta2
+ var[:] += (1. - self.beta2) * square(grad - mean)
+ var[:] += eps
+
+ # update weight
+ d = mean / (sqrt(var) + eps)
+ weight[:] -= lr * d
+
+ def fused_step(self, indices, weights, grads, states):
+ """Perform a fused optimization step using gradients and states.
+ Fused kernel is used for update.
+
+ Parameters
+ ----------
+ indices : list of int
+ List of unique indices of the parameters into the individual learning rates
+ and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
+ and `set_wd_mult()`, respectively.
+ weights : list of NDArray
+ List of parameters to be updated.
+ grads : list of NDArray
+ List of gradients of the objective with respect to this parameter.
+ states : List of any obj
+ List of state returned by `create_state()`.
+ """
+ multi_precision = self.multi_precision and weights[0].dtype == np.float16
+ aggregate = self.aggregate_num > 1
+ if not isinstance(indices, (tuple, list)):
+ indices = [indices]
+ weights = [weights]
+ grads = [grads]
+ states = [states]
+ for w_i, g_i in zip(weights, grads):
+ assert(isinstance(w_i, NDArray))
+ assert(isinstance(g_i, NDArray))
+ aggregate = (aggregate and
+ w_i.stype == 'default' and
+ g_i.stype == 'default')
+ self._update_count(indices)
+ lrs = self._get_lrs(indices)
+ wds = self._get_wds(indices)
+ if self.correct_bias:
+ new_lrs = []
+ for idx, lr in zip(indices, lrs):
+ t = self._index_update_count[idx]
+ coef1 = 1. - self.beta1 ** t
+ coef2 = 1. - self.beta2 ** t
+ new_lrs.append(lr * math.sqrt(coef2) / coef1)
+ lrs = new_lrs
+ if not isinstance(self.rescale_grad, NDArray):
+ self.rescale_grad = full(shape=(1,), val=self.rescale_grad, ctx=weights[0].context)
+ else:
+ self.rescale_grad = self.rescale_grad.as_in_context(weights[0].context)
+ kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
+ 'rescale_grad': self.rescale_grad}
+ if self.clip_gradient:
+ kwargs['clip_gradient'] = self.clip_gradient
+
+ if aggregate:
+ current_index = 0
+ while current_index < len(indices):
+ sidx = current_index
+ eidx = min(current_index + self.aggregate_num, len(indices))
+ if not multi_precision:
+ mean, var = list(zip(*states[sidx:eidx]))
+ multi_adabelief_update(weights[sidx:eidx], grads[sidx:eidx],
+ mean, var,
+ out=weights[sidx:eidx],
+ size=len(weights[sidx:eidx]),
+ lrs=list(np.ones(len(weights[sidx:eidx]))),
+ wds=wds[sidx:eidx],
+ etas=lrs[sidx:eidx],
+ **kwargs)
+ else:
+ mean_var = list(zip(*states[sidx:eidx]))[0]
+ tmean_var = list(zip(*mean_var))
+ mean = tmean_var[0]
+ var = tmean_var[1]
+ multi_mp_adabelief_update(weights[sidx:eidx],
+ grads[sidx:eidx],
+ mean, var,
+ list(zip(*states[sidx:eidx]))[1],
+ out=weights[sidx:eidx],
+ size=len(weights[sidx:eidx]),
+ lrs=list(np.ones(len(weights[sidx:eidx]))),
+ wds=wds[sidx:eidx],
+ etas=lrs[sidx:eidx],
+ **kwargs)
+ current_index += self.aggregate_num
+ else:
+ for w_i, g_i, s_i, lr, wd in zip(weights, grads, states, lrs, wds):
+ if not multi_precision:
+ mean, var = s_i
+ adabelief_update(w_i, g_i, mean, var, out=w_i,
+ lr=1, wd=wd, eta=lr, **kwargs)
+ else:
+ mean, var = s_i[0]
+ mp_adabelief_update(w_i, g_i, mean, var, s_i[1], out=w_i,
+ lr=1, wd=wd, eta=lr, **kwargs)
diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adabelief-inl.h
similarity index 85%
copy from src/operator/contrib/adamw-inl.h
copy to src/operator/contrib/adabelief-inl.h
index 6f48333..2f28215 100644
--- a/src/operator/contrib/adamw-inl.h
+++ b/src/operator/contrib/adabelief-inl.h
@@ -18,13 +18,13 @@
*/
/*!
- * Copyright (c) 2018 by Contributors
- * \file adamw-inl.h
+ * Copyright (c) 2021 by Contributors
+ * \file adabelief-inl.h
* \brief Optimizer operators
- * \author Haibin Lin, Moises Hernandez, Andrei Ivanov
+ * \author khaotik
*/
-#ifndef MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
-#define MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
+#ifndef MXNET_OPERATOR_CONTRIB_ADABELIEF_INL_H_
+#define MXNET_OPERATOR_CONTRIB_ADABELIEF_INL_H_
#include <mxnet/operator.h>
#include <vector>
#include "../mshadow_op.h"
@@ -32,8 +32,9 @@
namespace mxnet {
namespace op {
+namespace adabelief {
-struct AdamWParam : public dmlc::Parameter<AdamWParam> {
+struct AdaBeliefParam : public dmlc::Parameter<AdaBeliefParam> {
float lr;
float beta1;
float beta2;
@@ -41,7 +42,7 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
float wd;
float eta;
float clip_gradient;
- DMLC_DECLARE_PARAMETER(AdamWParam) {
+ DMLC_DECLARE_PARAMETER(AdaBeliefParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_FIELD(beta1)
@@ -98,7 +99,7 @@ inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs,
}
template<int req>
-struct MPAdamWKernel {
+struct MPAdaBeliefKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mean_data,
float* var_data, const DType* weight_data, const DType* grad_data, float* weight32,
@@ -107,22 +108,24 @@ struct MPAdamWKernel {
const float param_rescale_grad, const float param_epsilon) {
float w = weight32[i];
float scaled_grad = param_rescale_grad*static_cast<float>(grad_data[i]);
- if (param_clip_gradient >= 0.0f)
+ scaled_grad += param_wd * w;
+ if (param_clip_gradient >= 0.f)
scaled_grad = mshadow_op::clip::Map(scaled_grad, param_clip_gradient);
- float mean = mean_data[i] = param_beta1 * mean_data[i] + (1.0f - param_beta1) * scaled_grad;
- float var = var_data[i] = param_beta2 * var_data[i] +
- (1.0f - param_beta2) * mshadow_op::square::Map(scaled_grad);
+ const float mean = param_beta1 * (mean_data[i] - scaled_grad) + scaled_grad;
+ const float adj = mshadow_op::square::Map(scaled_grad - mean);
+ const float var = param_beta2*(var_data[i] - adj) + adj + param_epsilon;
- w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
- + param_wd * w);
+ w -= param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon));
+ mean_data[i] = mean;
+ var_data[i] = var;
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, w);
}
};
template<typename xpu>
-struct MPAdamWUpdate {
+struct MPAdaBeliefUpdate {
static inline void Forward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
@@ -130,7 +133,7 @@ struct MPAdamWUpdate {
const std::vector<TBlob> &outputs,
const float rescale_grad) {
using namespace mxnet_op;
- const auto& param = nnvm::get<AdamWParam>(attrs.parsed);
+ const auto& param = nnvm::get<AdaBeliefParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
@@ -140,19 +143,22 @@ struct MPAdamWUpdate {
Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
- Kernel<MPAdamWKernel<req_type>, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mean.dptr_,
- var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.beta1,
- param.beta2, param.eta, param.lr, param.wd, rescale_grad, param.epsilon);
+ Kernel<MPAdaBeliefKernel<req_type>, xpu>::Launch(
+ s, weight.shape_.Size(), out.dptr_, mean.dptr_, var.dptr_,
+ weight.dptr_, grad.dptr_, weight32.dptr_,
+ param.clip_gradient, param.beta1, param.beta2, param.eta,
+ param.lr, param.wd, rescale_grad, param.epsilon);
});
});
}
};
/*
- * \brief adam_w update.
+ * \brief adabelief update.
+ *
*/
template<typename xpu>
-struct AdamWUpdate {
+struct AdaBeliefUpdate {
static inline void Forward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
@@ -162,7 +168,7 @@ struct AdamWUpdate {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mshadow_op;
- const auto &param = nnvm::get<AdamWParam>(attrs.parsed);
+ const auto &param = nnvm::get<AdaBeliefParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
const Tensor<xpu, 2, DType> &weight = inputs[0].FlatTo2D<xpu, DType>(s);
@@ -171,18 +177,19 @@ struct AdamWUpdate {
Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
- grad = scalar<DType>(rescale_grad) * grad;
+ grad = scalar<DType>(rescale_grad) * grad + scalar<DType>(param.wd) * weight;
if (param.clip_gradient >= 0.0f)
grad = F<clip>(grad, DType(param.clip_gradient));
mean = scalar<DType>(param.beta1) * mean + scalar<DType>(1.f-param.beta1) * grad;
- var = scalar<DType>(param.beta2) * var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
+ var = scalar<DType>(param.beta2) * var +
+ scalar<DType>(1.f-param.beta2) * F<square>(grad - mean) +
+ scalar<DType>(param.epsilon);
Assign(out, req[0],
weight -
scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
- mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
- (scalar<DType>(param.wd) * weight)));
+ mean / (F<square_root>(var) + scalar<DType>(param.epsilon))));
});
}
};
@@ -190,7 +197,7 @@ struct AdamWUpdate {
////
// Multiple gradients in single kernel
////
-struct MultiAdamWParam : public dmlc::Parameter<MultiAdamWParam> {
+struct MultiAdaBeliefParam : public dmlc::Parameter<MultiAdaBeliefParam> {
mxnet::Tuple<float> lrs;
mxnet::Tuple<float> wds;
mxnet::Tuple<float> etas;
@@ -199,7 +206,7 @@ struct MultiAdamWParam : public dmlc::Parameter<MultiAdamWParam> {
float epsilon;
float clip_gradient;
int num_weights;
- DMLC_DECLARE_PARAMETER(MultiAdamWParam) {
+ DMLC_DECLARE_PARAMETER(MultiAdaBeliefParam) {
DMLC_DECLARE_FIELD(lrs)
.describe("Learning rates");
DMLC_DECLARE_FIELD(beta1)
@@ -230,7 +237,7 @@ struct MultiAdamWParam : public dmlc::Parameter<MultiAdamWParam> {
template<typename ParamType, int input_stride>
-inline bool MP_MultiAdamW_InferShape(const nnvm::NodeAttrs& attrs,
+inline bool MP_MultiAdaBelief_InferShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
@@ -272,7 +279,7 @@ inline bool MP_MultiAdamW_InferShape(const nnvm::NodeAttrs& attrs,
}
template <typename ParamType, int input_stride, int num_fp32_inputs>
-inline bool MP_MultiAdamW_InferType(const nnvm::NodeAttrs& attrs,
+inline bool MP_MultiAdaBelief_InferType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
@@ -312,20 +319,20 @@ inline bool MP_MultiAdamW_InferType(const nnvm::NodeAttrs& attrs,
template<typename T>
-class Adam_type_identity {
+class _type_identity {
public:
using type = T;
};
template<typename T>
-class Adam_single_precision {
+class _single_precision {
public:
using type = float;
};
template<typename DType, typename MPDType>
-struct MultiAdamKernelParam {
+struct MultiKernelParam {
static const int N = 50;
int count;
size_t max_size;
@@ -346,10 +353,10 @@ struct MultiAdamKernelParam {
};
template<typename MPDType, bool has_mixed_precision>
-struct MultiMPAdamWKernel {
+struct MultiMPAdaBeliefKernel {
template<typename DType>
- MSHADOW_XINLINE static void Map(int i, const MultiAdamKernelParam<DType, MPDType>& param,
- const OpReqType req, const float rescale_grad){
+ MSHADOW_XINLINE static void Map(int i, const MultiKernelParam<DType, MPDType>& param,
+ const OpReqType req, const float rescale_grad) {
for (int index = 0; index < param.count; ++index) {
if ((size_t)i < param.sizes[index]) {
MPDType w = has_mixed_precision ? param.weights32[index][i]:
@@ -357,18 +364,18 @@ struct MultiMPAdamWKernel {
MPDType scaled_grad = static_cast<MPDType>(rescale_grad)*
static_cast<MPDType>(param.grad_data[index][i]);
- if (param.clip_gradient >= 0.0f)
+ scaled_grad += param.wds[index] * w;
+ if (param.clip_gradient >= 0.f)
scaled_grad = mshadow_op::clip::Map(scaled_grad, param.clip_gradient);
- const auto mean = param.beta1 * (param.mean_data[index][i]- scaled_grad) + scaled_grad;
- const auto adj = mshadow_op::square::Map(scaled_grad);
- const auto var = param.beta2 * (param.var_data[index][i] - adj) + adj;
+ const auto mean = param.beta1 * (param.mean_data[index][i] - scaled_grad) + scaled_grad;
+ const auto adj = mshadow_op::square::Map(mean - scaled_grad);
+ const auto var = param.beta2 * (param.var_data[index][i] - adj) + adj + param.epsilon;
param.mean_data[index][i] = mean;
param.var_data[index][i] = var;
w = w - param.etas[index] * (param.lrs[index] *
- mean / (mshadow_op::square_root::Map(var) + param.epsilon)
- + param.wds[index] * w);
+ mean / (mshadow_op::square_root::Map(var) + param.epsilon));
if (has_mixed_precision)
param.weights32[index][i] = w;
@@ -381,13 +388,13 @@ struct MultiMPAdamWKernel {
template<typename xpu,
typename DType,
typename MPDType,
- typename ParamType = MultiAdamWParam,
+ typename ParamType = MultiAdaBeliefParam,
int input_stride = 4>
-void FillMultiAdamKernelParam(const nnvm::NodeAttrs& attrs,
+void FillMultiKernelParam(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
- MultiAdamKernelParam<DType, MPDType> *pParam) {
+ MultiKernelParam<DType, MPDType> *pParam) {
const ParamType& p = nnvm::get<ParamType>(attrs.parsed);
mxnet_op::Stream<xpu>* s = ctx.get_stream<xpu>();
pParam->clip_gradient = p.clip_gradient;
@@ -422,7 +429,7 @@ void FillMultiAdamKernelParam(const nnvm::NodeAttrs& attrs,
}
template<typename xpu, template<typename> class MPTypeChooser, int input_stride>
-static inline void MultiAdamWUpdate(const nnvm::NodeAttrs& attrs,
+static inline void MultiAdaBeliefUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
@@ -432,11 +439,11 @@ static inline void MultiAdamWUpdate(const nnvm::NodeAttrs& attrs,
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
using MPDType = typename MPTypeChooser<DType>::type;
- MultiAdamKernelParam<DType, MPDType> param;
- FillMultiAdamKernelParam<xpu, DType, MPDType, MultiAdamWParam, input_stride>
+ MultiKernelParam<DType, MPDType> param;
+ FillMultiKernelParam<xpu, DType, MPDType, MultiAdaBeliefParam, input_stride>
(attrs, ctx, inputs, outputs, &param);
- Kernel<MultiMPAdamWKernel<MPDType, !std::is_same<DType, MPDType>::value>, xpu>::
+ Kernel<MultiMPAdaBeliefKernel<MPDType, !std::is_same<DType, MPDType>::value>, xpu>::
Launch(s, param.max_size, param, req[0], rescale_grad);
});
}
@@ -450,7 +457,7 @@ bool PrepareInputBlobs(const OpContext &ctx,
std::vector<TBlob> *inputs_wo_scale,
float *pScalef) {
const size_t num_in = inputs.size() - 1;
- GetScaleFloat<xpu>(ctx.get_stream<xpu>(), inputs[num_in], pScalef);
+ adabelief::GetScaleFloat<xpu>(ctx.get_stream<xpu>(), inputs[num_in], pScalef);
if (!std::isfinite(*pScalef) || *pScalef == 0)
return false;
@@ -487,14 +494,15 @@ inline void multiMPUpdate(const nnvm::NodeAttrs& attrs,
return;
if (!MP)
- MultiAdamWUpdate<xpu, Adam_type_identity, 4>
+ MultiAdaBeliefUpdate<xpu, _type_identity, 4>
(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
else
- MultiAdamWUpdate<xpu, Adam_single_precision, 5>
+ MultiAdaBeliefUpdate<xpu, _single_precision, 5>
(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
}
+} // namespace adabelief
} // namespace op
} // namespace mxnet
-#endif // MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_
+#endif // MXNET_OPERATOR_CONTRIB_ADABELIEF_INL_H_
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adabelief.cc
similarity index 68%
copy from src/operator/contrib/adamw.cc
copy to src/operator/contrib/adabelief.cc
index effae5c..06be748 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adabelief.cc
@@ -18,54 +18,55 @@
*/
/*!
- * Copyright (c) 2018 by Contributors
- * \file adamw.cc
+ * Copyright (c) 2021 by Contributors
+ * \file adabelief.cc
* \brief Optimizer operators
- * \author Haibin Lin, Moises Hernandez, Andrei Ivanov
+ * \author khaotik
*/
-#include "./adamw-inl.h"
+#include "./adabelief-inl.h"
namespace mxnet {
namespace op {
+namespace adabelief {
-DMLC_REGISTER_PARAMETER(AdamWParam);
-DMLC_REGISTER_PARAMETER(MultiAdamWParam);
+DMLC_REGISTER_PARAMETER(AdaBeliefParam);
+DMLC_REGISTER_PARAMETER(MultiAdaBeliefParam);
-NNVM_REGISTER_OP(_mp_adamw_update)
-.describe(R"code(Update function for multi-precision AdamW optimizer.
+NNVM_REGISTER_OP(_mp_adabelief_update)
+.describe(R"code(Update function for multi-precision AdaBelief optimizer.
-AdamW is seen as a modification of Adam by decoupling the weight decay from the
-optimization steps taken w.r.t. the loss function.
+AdaBelief is seen as a modification of Adam with a different variance
+estimator.
-Adam update consists of the following steps, where g represents gradient and m, v
+The AdaBelief update consists of the following steps, where g represents gradient and m, s
are 1st and 2nd order moment estimates (mean and variance).
.. math::
- g_t = \nabla J(W_{t-1})\\
+ g_t = \nabla J(W_{t-1}) + wd \cdot W_{t-1}\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
- v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
- W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})
+ s_t = \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon\\
+ W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ s_t } + \epsilon })
It updates the weights using::
m = beta1*m + (1-beta1)*grad
- v = beta2*v + (1-beta2)*(grad**2)
- w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
+ s = beta2*s + (1-beta2)*((grad - m)**2) + epsilon
+ w -= eta * (learning_rate * m / (sqrt(s) + epsilon))
Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
the update is skipped.
)code" ADD_FILELINE)
.set_num_inputs(6)
.set_num_outputs(1)
-.set_attr_parser(ParamParser<AdamWParam>)
+.set_attr_parser(ParamParser<AdaBeliefParam>)
.set_attr<mxnet::FInferShape>("FInferShape", MPUpdateInferShape<2, 1, 6>)
.set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<2, 1, 6>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3, 4};
})
-.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, MPAdamWUpdate<cpu>>)
+.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, MPAdaBeliefUpdate<cpu>>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
@@ -74,41 +75,43 @@ the update is skipped.
.add_argument("rescale_grad", "NDArray-or-Symbol",
"Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
"the update is skipped.")
-.add_arguments(AdamWParam::__FIELDS__());
+.add_arguments(AdaBeliefParam::__FIELDS__());
-NNVM_REGISTER_OP(_adamw_update)
-.describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of
-Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function.
+NNVM_REGISTER_OP(_adabelief_update)
+.describe(R"code(Update function for AdaBelief optimizer.
-Adam update consists of the following steps, where g represents gradient and m, v
+AdaBelief is seen as a modification of Adam with a different variance
+estimator.
+
+The AdaBelief update consists of the following steps, where g represents gradient and m, s
are 1st and 2nd order moment estimates (mean and variance).
.. math::
- g_t = \nabla J(W_{t-1})\\
+ g_t = \nabla J(W_{t-1}) + wd \cdot W_{t-1}\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
- v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
- W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})
+ s_t = \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon\\
+ W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ s_t } + \epsilon })
It updates the weights using::
m = beta1*m + (1-beta1)*grad
- v = beta2*v + (1-beta2)*(grad**2)
- w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
+ s = beta2*s + (1-beta2)*((grad - m)**2) + epsilon
+ w -= eta * (learning_rate * m / (sqrt(s) + epsilon))
Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
the update is skipped.
-)code" ADD_FILELINE)
+))code" ADD_FILELINE)
.set_num_inputs(5)
.set_num_outputs(1)
-.set_attr_parser(ParamParser<AdamWParam>)
+.set_attr_parser(ParamParser<AdaBeliefParam>)
.set_attr<mxnet::FInferShape>("FInferShape", MPUpdateInferShape<4, 1, 5>)
.set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<4, 1, 5>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
-.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, AdamWUpdate<cpu>>)
+.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, AdaBeliefUpdate<cpu>>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
@@ -116,7 +119,7 @@ the update is skipped.
.add_argument("rescale_grad", "NDArray-or-Symbol",
"Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
"the update is skipped.")
-.add_arguments(AdamWParam::__FIELDS__());
+.add_arguments(AdaBeliefParam::__FIELDS__());
template<>
void GetScaleFloat<cpu>(mshadow::Stream<cpu> *s, const TBlob &scale_blob, float *pScalef) {
@@ -125,7 +128,8 @@ void GetScaleFloat<cpu>(mshadow::Stream<cpu> *s, const TBlob &scale_blob, float
)
}
-std::vector<std::string> ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
+static std::vector<std::string>
+ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
std::vector<std::string> ret;
for (uint32_t i = 0; i < num_args; ++i) {
const auto idx = std::to_string(i);
@@ -137,42 +141,42 @@ std::vector<std::string> ParamToVector(uint32_t num_args, const char *pName[], s
}
inline uint32_t num_weights(const nnvm::NodeAttrs& attrs) {
- return static_cast<uint32_t>(dmlc::get<MultiAdamWParam>(attrs.parsed).num_weights);
+ return static_cast<uint32_t>(dmlc::get<MultiAdaBeliefParam>(attrs.parsed).num_weights);
}
-NNVM_REGISTER_OP(_multi_adamw_update)
-.describe(R"code(Update function for AdamW optimizer.
+NNVM_REGISTER_OP(_multi_adabelief_update)
+.describe(R"code(Update function for AdaBelief optimizer.
-AdamW is seen as a modification of Adam by decoupling the weight decay from the
-optimization steps taken w.r.t. the loss function.
+AdaBelief is seen as a modification of Adam with a different variance
+estimator.
-Adam update consists of the following steps, where g represents gradient and m, v
+The AdaBelief update consists of the following steps, where g represents gradient and m, s
are 1st and 2nd order moment estimates (mean and variance).
.. math::
- g_t = \nabla J(W_{t-1})\\
+ g_t = \nabla J(W_{t-1}) + wd \cdot W_{t-1}\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
- v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
- W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})
+ s_t = \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon\\
+ W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ s_t } + \epsilon })
It updates the weights using::
m = beta1*m + (1-beta1)*grad
- v = beta2*v + (1-beta2)*(grad**2)
- w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
+ s = beta2*s + (1-beta2)*((grad - m)**2) + epsilon
+ w -= eta * (learning_rate * m / (sqrt(s) + epsilon))
Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
the update is skipped.
-)code" ADD_FILELINE)
+))code" ADD_FILELINE)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
return num_weights(attrs) * 4 + 1;
})
.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
return num_weights(attrs);
})
-.set_attr_parser(ParamParser<MultiAdamWParam>)
-.set_attr<mxnet::FInferShape>("FInferShape", MP_MultiAdamW_InferShape<MultiAdamWParam, 4>)
+.set_attr_parser(ParamParser<MultiAdaBeliefParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", MP_MultiAdaBelief_InferShape<MultiAdaBeliefParam, 4>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, -1>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
@@ -193,43 +197,43 @@ the update is skipped.
.set_attr<FCompute>("FCompute<cpu>", multiMPUpdate<cpu, false>)
.add_argument("data", "NDArray-or-Symbol[]", "data")
-.add_arguments(MultiAdamWParam::__FIELDS__());
+.add_arguments(MultiAdaBeliefParam::__FIELDS__());
-NNVM_REGISTER_OP(_multi_mp_adamw_update)
-.describe(R"code(Update function for multi-precision AdamW optimizer.
+NNVM_REGISTER_OP(_multi_mp_adabelief_update)
+.describe(R"code(Update function for multi-precision AdaBelief optimizer.
-AdamW is seen as a modification of Adam by decoupling the weight decay from the
-optimization steps taken w.r.t. the loss function.
+AdaBelief is seen as a modification of Adam with a different variance
+estimator.
-Adam update consists of the following steps, where g represents gradient and m, v
+The AdaBelief update consists of the following steps, where g represents gradient and m, s
are 1st and 2nd order moment estimates (mean and variance).
.. math::
- g_t = \nabla J(W_{t-1})\\
+ g_t = \nabla J(W_{t-1}) + wd \cdot W_{t-1}\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
- v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
- W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})
+ s_t = \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon\\
+ W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ s_t } + \epsilon })
It updates the weights using::
m = beta1*m + (1-beta1)*grad
- v = beta2*v + (1-beta2)*(grad**2)
- w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
+ s = beta2*s + (1-beta2)*((grad - m)**2) + epsilon
+ w -= eta * (learning_rate * m / (sqrt(s) + epsilon))
Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
the update is skipped.
-)code" ADD_FILELINE)
+))code" ADD_FILELINE)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
return num_weights(attrs) * 5 + 1;
})
.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
return num_weights(attrs);
})
-.set_attr_parser(ParamParser<MultiAdamWParam>)
-.set_attr<mxnet::FInferShape>("FInferShape", MP_MultiAdamW_InferShape<MultiAdamWParam, 5>)
-.set_attr<nnvm::FInferType>("FInferType", MP_MultiAdamW_InferType<MultiAdamWParam, 5, 1>)
+.set_attr_parser(ParamParser<MultiAdaBeliefParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", MP_MultiAdaBelief_InferShape<MultiAdaBeliefParam, 5>)
+.set_attr<nnvm::FInferType>("FInferType", MP_MultiAdaBelief_InferType<MultiAdaBeliefParam, 5, 1>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const char *paramName[] = {"weight_", "grad_", "mean_", "var_", "weight32_", "rescale_grad_"};
@@ -250,8 +254,8 @@ the update is skipped.
.set_attr<FCompute>("FCompute<cpu>", multiMPUpdate<cpu, true>)
.add_argument("data", "NDArray-or-Symbol[]", "data")
-.add_arguments(MultiAdamWParam::__FIELDS__());
-
+.add_arguments(MultiAdaBeliefParam::__FIELDS__());
+} // namespace adabelief
} // namespace op
} // namespace mxnet
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adabelief.cu
similarity index 67%
copy from src/operator/contrib/adamw.cu
copy to src/operator/contrib/adabelief.cu
index 2b0040e..e64dcb4 100644
--- a/src/operator/contrib/adamw.cu
+++ b/src/operator/contrib/adabelief.cu
@@ -18,16 +18,16 @@
*/
/*!
- * Copyright (c) 2018 by Contributors
- * \file adamw.cu
+ * Copyright (c) 2021 by Contributors
+ * \file adabelief.cu
* \brief Optimizer operators
- * \author Haibin Lin, Moises Hernandez, Andrei Ivanov
+ * \author khaotik
*/
-#include "./adamw-inl.h"
+#include "./adabelief-inl.h"
namespace mxnet {
namespace op {
-
+namespace adabelief {
template<>
void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float *pScalef) {
MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
@@ -39,18 +39,19 @@ void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float
*pScalef = static_cast<float>(scale);
})
}
+} // namespace adabelief
-NNVM_REGISTER_OP(_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", MPUpdate<gpu, AdamWUpdate<gpu>>);
+NNVM_REGISTER_OP(_adabelief_update)
+.set_attr<FCompute>("FCompute<gpu>", adabelief::MPUpdate<gpu, adabelief::AdaBeliefUpdate<gpu>>);
-NNVM_REGISTER_OP(_mp_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", MPUpdate<gpu, MPAdamWUpdate<gpu>>);
+NNVM_REGISTER_OP(_mp_adabelief_update)
+.set_attr<FCompute>("FCompute<gpu>", adabelief::MPUpdate<gpu, adabelief::MPAdaBeliefUpdate<gpu>>);
-NNVM_REGISTER_OP(_multi_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", multiMPUpdate<gpu, false>);
+NNVM_REGISTER_OP(_multi_adabelief_update)
+.set_attr<FCompute>("FCompute<gpu>", adabelief::multiMPUpdate<gpu, false>);
-NNVM_REGISTER_OP(_multi_mp_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", multiMPUpdate<gpu, true>);
+NNVM_REGISTER_OP(_multi_mp_adabelief_update)
+.set_attr<FCompute>("FCompute<gpu>", adabelief::multiMPUpdate<gpu, true>);
} // namespace op
} // namespace mxnet
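
The multi-precision variants keep an fp32 master copy of the weights while the model weights and gradients stay in fp16; the mean and var state is held in fp32 by the kernel. A sketch of calling the mixed-precision op directly, mirroring the pattern used by the tests (shapes and hyper-parameter values are illustrative):

    import mxnet as mx

    shape = (5,)
    weight32 = mx.nd.random.uniform(shape=shape)       # fp32 master weights
    weight16 = weight32.astype('float16')              # fp16 model weights
    grad16 = mx.nd.random.uniform(shape=shape).astype('float16')
    mean = mx.nd.zeros(shape)                          # fp32 optimizer state
    var = mx.nd.zeros(shape)
    mx.nd.contrib.mp_adabelief_update(weight16, grad16, mean, var, weight32,
                                      rescale_grad=1.0, lr=1.0, eta=1e-3, wd=0.0,
                                      out=weight16)
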
diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h
index 6f48333..56c5ea2 100644
--- a/src/operator/contrib/adamw-inl.h
+++ b/src/operator/contrib/adamw-inl.h
@@ -32,6 +32,7 @@
namespace mxnet {
namespace op {
+namespace adamw {
struct AdamWParam : public dmlc::Parameter<AdamWParam> {
float lr;
@@ -114,7 +115,7 @@ struct MPAdamWKernel {
float var = var_data[i] = param_beta2 * var_data[i] +
(1.0f - param_beta2) * mshadow_op::square::Map(scaled_grad);
- w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
+ w -= param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
+ param_wd * w);
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, w);
@@ -349,7 +350,7 @@ template<typename MPDType, bool has_mixed_precision>
struct MultiMPAdamWKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, const MultiAdamKernelParam<DType, MPDType>& param,
- const OpReqType req, const float rescale_grad){
+ const OpReqType req, const float rescale_grad) {
for (int index = 0; index < param.count; ++index) {
if ((size_t)i < param.sizes[index]) {
MPDType w = has_mixed_precision ? param.weights32[index][i]:
@@ -442,7 +443,7 @@ static inline void MultiAdamWUpdate(const nnvm::NodeAttrs& attrs,
}
template<typename xpu>
-void GetScaleFloat(mshadow::Stream<xpu> *s, const TBlob &scale_blob, float *pScalef);
+static void GetScaleFloat(mshadow::Stream<xpu> *s, const TBlob &scale_blob, float *pScalef);
template<typename xpu>
bool PrepareInputBlobs(const OpContext &ctx,
@@ -450,7 +451,7 @@ bool PrepareInputBlobs(const OpContext &ctx,
std::vector<TBlob> *inputs_wo_scale,
float *pScalef) {
const size_t num_in = inputs.size() - 1;
- GetScaleFloat<xpu>(ctx.get_stream<xpu>(), inputs[num_in], pScalef);
+ adamw::GetScaleFloat<xpu>(ctx.get_stream<xpu>(), inputs[num_in], pScalef);
if (!std::isfinite(*pScalef) || *pScalef == 0)
return false;
@@ -494,6 +495,7 @@ inline void multiMPUpdate(const nnvm::NodeAttrs& attrs,
(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
}
+} // namespace adamw
} // namespace op
} // namespace mxnet
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc
index effae5c..e662502 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adamw.cc
@@ -27,6 +27,7 @@
namespace mxnet {
namespace op {
+namespace adamw {
DMLC_REGISTER_PARAMETER(AdamWParam);
DMLC_REGISTER_PARAMETER(MultiAdamWParam);
@@ -65,7 +66,7 @@ the update is skipped.
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3, 4};
})
-.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, MPAdamWUpdate<cpu>>)
+.set_attr<FCompute>("FCompute<cpu>", adamw::MPUpdate<cpu, MPAdamWUpdate<cpu>>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
@@ -108,7 +109,7 @@ the update is skipped.
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
-.set_attr<FCompute>("FCompute<cpu>", MPUpdate<cpu, AdamWUpdate<cpu>>)
+.set_attr<FCompute>("FCompute<cpu>", adamw::MPUpdate<cpu, AdamWUpdate<cpu>>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
@@ -125,7 +126,8 @@ void GetScaleFloat<cpu>(mshadow::Stream<cpu> *s, const TBlob &scale_blob, float
)
}
-std::vector<std::string> ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
+static std::vector<std::string>
+ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
std::vector<std::string> ret;
for (uint32_t i = 0; i < num_args; ++i) {
const auto idx = std::to_string(i);
@@ -191,7 +193,7 @@ the update is skipped.
return ret;
})
-.set_attr<FCompute>("FCompute<cpu>", multiMPUpdate<cpu, false>)
+.set_attr<FCompute>("FCompute<cpu>", adamw::multiMPUpdate<cpu, false>)
.add_argument("data", "NDArray-or-Symbol[]", "data")
.add_arguments(MultiAdamWParam::__FIELDS__());
@@ -248,10 +250,11 @@ the update is skipped.
return ret;
})
-.set_attr<FCompute>("FCompute<cpu>", multiMPUpdate<cpu, true>)
+.set_attr<FCompute>("FCompute<cpu>", adamw::multiMPUpdate<cpu, true>)
.add_argument("data", "NDArray-or-Symbol[]", "data")
.add_arguments(MultiAdamWParam::__FIELDS__());
+} // namespace adamw
} // namespace op
} // namespace mxnet
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu
index 2b0040e..95fcffb 100644
--- a/src/operator/contrib/adamw.cu
+++ b/src/operator/contrib/adamw.cu
@@ -27,6 +27,7 @@
namespace mxnet {
namespace op {
+namespace adamw {
template<>
void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float *pScalef) {
@@ -41,16 +42,17 @@ void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float
}
NNVM_REGISTER_OP(_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", MPUpdate<gpu, AdamWUpdate<gpu>>);
+.set_attr<FCompute>("FCompute<gpu>", adamw::MPUpdate<gpu, AdamWUpdate<gpu>>);
NNVM_REGISTER_OP(_mp_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", MPUpdate<gpu, MPAdamWUpdate<gpu>>);
+.set_attr<FCompute>("FCompute<gpu>", adamw::MPUpdate<gpu, MPAdamWUpdate<gpu>>);
NNVM_REGISTER_OP(_multi_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", multiMPUpdate<gpu, false>);
+.set_attr<FCompute>("FCompute<gpu>", adamw::multiMPUpdate<gpu, false>);
NNVM_REGISTER_OP(_multi_mp_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", multiMPUpdate<gpu, true>);
+.set_attr<FCompute>("FCompute<gpu>", adamw::multiMPUpdate<gpu, true>);
+} // namespace adamw
} // namespace op
} // namespace mxnet
diff --git a/tests/python/unittest/test_contrib_optimizer.py b/tests/python/unittest/test_contrib_optimizer.py
index f0fbb7b..b6f624f 100644
--- a/tests/python/unittest/test_contrib_optimizer.py
+++ b/tests/python/unittest/test_contrib_optimizer.py
@@ -61,27 +61,27 @@ def test_group_adagrad():
dtype,
g_stype='row_sparse')
-
-@xfail_when_nonstandard_decimal_separator
-@pytest.mark.serial
-def test_adamw():
- def get_refs(m, v, weight, grad_rescale, beta1, beta2, lr, eta, wd, epsilon, clip_grad=-1):
- if clip_grad >= 0:
- grad_rescale = mx.nd.clip(grad_rescale, -clip_grad, clip_grad)
-
- mean_ref = beta1*m + (1-beta1)*grad_rescale
- v_ref = beta2*v + (1-beta2)*(grad_rescale**2)
- weight_ref = weight - eta * (lr * mean_ref / (v_ref.sqrt() + epsilon) + weight * wd)
- return mean_ref, v_ref, weight_ref
-
- def run_adamw_test(nElem=1, aggregate=False):
- aggregate = aggregate or nElem > 1
+def _fn_noimpl(*args, **kwargs):
+ raise NotImplementedError()
+
+class _AdamLikeTestHelper:
+ fn_update = _fn_noimpl
+ fn_multi_update = _fn_noimpl
+ fn_mp_update = _fn_noimpl
+ fn_multi_mp_update = _fn_noimpl
+ @staticmethod
+ def ref_impl(m, v, weight, grad_rescale, beta1, beta2, lr, eta, wd, epsilon, clip_grad=-1):
+ '''Returns (mean_ref, v_ref, weight_ref)'''
+ raise NotImplementedError()
+ @classmethod
+ def run_test(cls, num_elem=1, aggregate=False):
+ aggregate = aggregate or num_elem > 1
rescale_factor = 10
eta, lr, wd, epsilon = 1, 1, 0.1, 1e-8
beta1, beta2 = 0.9, 0.999
clip_gradient = np.random.uniform(rescale_factor, rescale_factor)
weight, grad, m, v, etas, lrs, wds, weight_ref = [], [], [], [], [], [], [], []
- for i in range(nElem):
+ for i in range(num_elem):
shape = (np.random.randint(3, high=10), np.random.randint(3, high=10))
weight.append(mx.nd.random.uniform(shape=shape))
grad.append(mx.nd.random.uniform(-1.0, 1.0, shape=shape))
@@ -107,95 +107,130 @@ def test_adamw():
for rescaled_grad in tested_rescaled_grad:
if aggregate:
- mx.nd.contrib.multi_adamw_update(weight, grad, m, v,
- rescaled_grad, out=weight, **kwargs)
+ cls.fn_multi_update(weight, grad, m, v,
+ rescaled_grad, out=weight, **kwargs)
else:
- mx.nd.contrib.adamw_update(weight[0], grad[0], m[0], v[0],
- rescaled_grad, out=weight[0], **kwargs)
-
+ cls.fn_update(weight[0], grad[0], m[0], v[0],
+ rescaled_grad, out=weight[0], **kwargs)
# weights should remain unchanged
- for j in range(nElem):
+ for j in range(num_elem):
assert_almost_equal(weight_ref[j], weight[j])
-
# Test 2: Same as Test 1 for multi-precision update
weight_fp16, grad_fp16, weight_fp16_refs = [], [], []
- for i in range(nElem):
+ for i in range(num_elem):
weight_fp16.append(weight[i].astype('float16'))
grad_fp16.append(grad[i].astype('float16'))
weight_fp16_refs.append(weight_fp16[i].copy())
for rescaled_grad in tested_grad:
if aggregate:
- mx.nd.contrib.multi_mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
- rescaled_grad, out=weight_fp16, **kwargs)
+ cls.fn_multi_mp_update(weight_fp16, grad_fp16, m, v, weight,
+ rescaled_grad, out=weight_fp16, **kwargs)
else:
- mx.nd.contrib.mp_adamw_update(weight_fp16[0], grad_fp16[0], m[0], v[0], weight[0],
- rescaled_grad, out=weight_fp16[0], **kwargs)
-
+ cls.fn_mp_update(weight_fp16[0], grad_fp16[0], m[0], v[0], weight[0],
+ rescaled_grad, out=weight_fp16[0], **kwargs)
# weights should remain unchanged
- for i in range(nElem):
+ for i in range(num_elem):
assert_almost_equal(weight_ref[i], weight[i])
assert_almost_equal(weight_fp16_refs[i], weight_fp16[i])
-
# Test 3: Reference normal update
grad_rescale, weight_test, m_refs, v_refs, weight_refs = [], [], [], [], []
- for i in range(nElem):
+ for i in range(num_elem):
grad_rescale.append(rescale_grad * grad[i])
- m_ref, v_ref, weight_ref = get_refs(m[i], v[i], weight[i], grad_rescale[i], beta1, beta2, lrs[i], etas[i], wds[i], epsilon, clip_gradient)
+ m_ref, v_ref, weight_ref = cls.ref_impl(
+ m[i], v[i], weight[i], grad_rescale[i],
+ beta1, beta2, lrs[i], etas[i], wds[i], epsilon, clip_gradient)
m_refs.append(m_ref)
v_refs.append(v_ref)
weight_refs.append(weight_ref)
weight_test.append(weight[i].copy())
-
# op normal update
if aggregate:
- mx.nd.contrib.multi_adamw_update(weight_test, grad, m, v,
- rescale_grad, out=weight_test, **kwargs)
+ cls.fn_multi_update(weight_test, grad, m, v,
+ rescale_grad, out=weight_test, **kwargs)
else:
- mx.nd.contrib.adamw_update(weight_test[0], grad[0], m[0], v[0],
- rescale_grad, out=weight_test[0], **kwargs)
-
+ cls.fn_update(weight_test[0], grad[0], m[0], v[0],
+ rescale_grad, out=weight_test[0], **kwargs)
# Compare results
atol = 1e-4 if aggregate else 1e-5
rtol = 1e-4 if aggregate else None
- for i in range(nElem):
+ for i in range(num_elem):
assert_almost_equal(weight_refs[i], weight_test[i], rtol=rtol, atol=atol)
assert_almost_equal(m_refs[i], m[i], rtol=rtol, atol=atol)
assert_almost_equal(v_refs[i], v[i], atol=atol)
-
# Test 4: Reference normal multi-precision update
grad_rescale, m_refs, v_refs, weight_refs, weight_fp16_refs = [], [], [], [], []
- for i in range(nElem):
+ for i in range(num_elem):
grad_rescale.append(rescale_grad * grad_fp16[i].astype('float32'))
- m_ref, v_ref, weight_ref = get_refs(m[i], v[i], weight[i], grad_rescale[i], beta1, beta2, lrs[i], etas[i], wds[i], epsilon, clip_gradient)
+ m_ref, v_ref, weight_ref = cls.ref_impl(
+ m[i], v[i], weight[i], grad_rescale[i],
+ beta1, beta2, lrs[i], etas[i], wds[i], epsilon, clip_gradient)
m_refs.append(m_ref)
v_refs.append(v_ref)
weight_refs.append(weight_ref)
weight_fp16_refs.append(weight_ref.astype('float16'))
-
# op normal multi-precision update
if aggregate:
- mx.nd.contrib.multi_mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
- rescale_grad, out=weight_fp16, **kwargs)
+ cls.fn_multi_mp_update(weight_fp16, grad_fp16, m, v, weight,
+ rescale_grad, out=weight_fp16, **kwargs)
else:
- mx.nd.contrib.mp_adamw_update(weight_fp16[0], grad_fp16[0], m[0], v[0], weight[0],
- rescale_grad, out=weight_fp16[0], **kwargs)
-
+ cls.fn_mp_update(weight_fp16[0], grad_fp16[0], m[0], v[0], weight[0],
+ rescale_grad, out=weight_fp16[0], **kwargs)
# Compare results
- for i in range(nElem):
+ for i in range(num_elem):
assert_almost_equal(m_refs[i], m[i], rtol=rtol, atol=atol)
assert_almost_equal(v_refs[i], v[i], atol=atol)
assert_almost_equal(weight_refs[i], weight[i], rtol=rtol, atol=atol)
assert_almost_equal(weight_fp16_refs[i], weight_fp16[i], rtol=1e-3, atol=atol)
- # Testing aggregated Adam update for one element
- run_adamw_test(1, aggregate=True)
+ def __call__(self):
+ # Testing the aggregated update for one element
+ self.run_test(1, aggregate=True)
+ # Testing the plain (non-aggregated) update when num_elem == 0, OR
+ # the aggregated update when num_elem > 0
+ for num_elem in reversed(range(6)):
+ self.run_test(num_elem+1)
+
+class _AdamWTestHelper(_AdamLikeTestHelper):
+ fn_update = mx.nd.contrib.adamw_update
+ fn_multi_update = mx.nd.contrib.multi_adamw_update
+ fn_mp_update = mx.nd.contrib.mp_adamw_update
+ fn_multi_mp_update = mx.nd.contrib.multi_mp_adamw_update
+ @staticmethod
+ def ref_impl(m, v, weight, grad_rescale, beta1, beta2, lr, eta, wd, epsilon, clip_grad=-1):
+ if clip_grad >= 0:
+ grad_rescale = mx.nd.clip(grad_rescale, -clip_grad, clip_grad)
- # Testing Adam update, if nElem = 0, OR
- # aggregated Adam update, if nElem > 0
- for nElem in range(6):
- run_adamw_test(nElem+1)
+ mean_ref = beta1*m + (1.-beta1)*grad_rescale
+ v_ref = beta2*v + (1.-beta2)*(grad_rescale**2)
+ weight_ref = weight - eta * (lr * mean_ref / (v_ref.sqrt() + epsilon) + weight * wd)
+ return mean_ref, v_ref, weight_ref
+
+class _AdaBeliefTestHelper(_AdamLikeTestHelper):
+ fn_update = mx.nd.contrib.adabelief_update
+ fn_multi_update = mx.nd.contrib.multi_adabelief_update
+ fn_mp_update = mx.nd.contrib.mp_adabelief_update
+ fn_multi_mp_update = mx.nd.contrib.multi_mp_adabelief_update
+ @staticmethod
+ def ref_impl(m, v, weight, grad_rescale, beta1, beta2, lr, eta, wd, epsilon, clip_grad=-1):
+ grad_rescale += wd * weight
+ if clip_grad >= 0:
+ grad_rescale = mx.nd.clip(grad_rescale, -clip_grad, clip_grad)
+ mean_ref = beta1*m + (1.-beta1)*grad_rescale
+ v_ref = beta2*v + (1.-beta2)*((grad_rescale-mean_ref)**2) + epsilon
+ weight_ref = weight - eta * (lr * mean_ref / (v_ref.sqrt() + epsilon))
+ return mean_ref, v_ref, weight_ref
+
+@xfail_when_nonstandard_decimal_separator
+@pytest.mark.serial
+def test_adamw():
+ _AdamWTestHelper()()
+
+@xfail_when_nonstandard_decimal_separator
+@pytest.mark.serial
+def test_adabelief():
+ _AdaBeliefTestHelper()()
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 8927bcd..7ccb8f1 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -923,6 +923,31 @@ def test_adamW():
opt2(use_fused_step=True, **kwarg), shapes, dtype,
rtol=1e-3, atol=2e-3)
+def test_adabelief():
+ opt1 = mx.optimizer.AdaBelief
+ opt2 = mx.optimizer.AdaBelief
+ shapes = [(3, 4, 5), (10, 4), (7,)]
+ beta1_options = [{}, {'beta1': 0.5}, {'beta1': 0.7}]
+ beta2_options = [{}, {'beta2': 0.8}, {'beta2': 0.9}]
+ cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+ rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+ wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+ mp_options = [{'multi_precision': False}, {'multi_precision': True}]
+ agg_options = [{'aggregate_num': 0}, {'aggregate_num': 1},
+ {'aggregate_num': 4}, {'aggregate_num': np.inf}]
+ correct_bias_options = [{'correct_bias': True}, {'correct_bias': False}]
+ for dtype in [np.float16, np.float32]:
+ for params in itertools.product(beta1_options, beta2_options, cg_options,
+ rg_options, wd_options, mp_options,
+ agg_options, correct_bias_options):
+ kwarg = {k: v for param in params for k, v in param.items()}
+ if (dtype == np.float16 and ('multi_precision' not in kwarg or
+ not kwarg['multi_precision'])):
+ continue
+ compare_optimizer(opt1(use_fused_step=False, **kwarg),
+ opt2(use_fused_step=True, **kwarg), shapes, dtype,
+ rtol=1e-3, atol=2e-3)
+
def test_factor_scheduler():
base_lr = 1
step = 100