Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/02/27 22:46:50 UTC
[incubator-mxnet] branch master updated: [op] add back support for
scalar type rescale_grad argument for adamw_update/mp_adamw_update (#14221)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e3a51b5 [op] add back support for scalar type rescale_grad argument for adamw_update/mp_adamw_update (#14221)
e3a51b5 is described below
commit e3a51b5a3ed989bf1e9c9f53b56819b32957527f
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Wed Feb 27 14:46:33 2019 -0800
[op] add back support for scalar type rescale_grad argument for adamw_update/mp_adamw_update (#14221)
* support scalar
* remove two copies of documentation for adamw
* fix lint
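
In practice, this restores call sites that pass a plain Python float, e.g. (a minimal
sketch against the wrappers added below; shapes and hyperparameters are illustrative):

    import mxnet as mx

    weight = mx.nd.ones((2,))
    grad = mx.nd.ones((2,)) * 0.1
    mean, var = mx.nd.zeros((2,)), mx.nd.zeros((2,))
    # rescale_grad may again be a plain float; the contrib wrapper turns it
    # into a 1-element NDArray on weight.context before dispatching
    mx.nd.contrib.adamw_update(weight, grad, mean, var, 0.5,
                               lr=0.01, eta=1.0, out=weight)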
---
 python/mxnet/ndarray/contrib.py                 | 26 ++++++++++++++++++++++++++
 python/mxnet/ndarray/register.py                | 11 +++++++++++
 python/mxnet/symbol/contrib.py                  | 22 ++++++++++++++++++++++
 python/mxnet/symbol/register.py                 | 11 +++++++++++
 src/operator/contrib/adamw.cc                   | 10 ++++++----
 src/operator/contrib/adamw.cu                   |  4 ++--
 tests/python/unittest/test_contrib_optimizer.py | 12 ++++++++++++
 7 files changed, 90 insertions(+), 6 deletions(-)
diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index 6bbee8a..74c355d 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -542,3 +542,29 @@ def isnan(data):
     <NDArray 2 @cpu(0)>
     """
     return data != data
+
+def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
+                 epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
+    if not isinstance(rescale_grad, ndarray.NDArray):
+        rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
+    else:
+        rescale_grad = rescale_grad.as_in_context(weight.context)
+    return ndarray._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var,
+                                           rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                           beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                           wd=wd, clip_gradient=clip_gradient, out=out,
+                                           name=name, **kwargs)
+
+def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
+                    beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
+                    name=None, **kwargs):
+    if not isinstance(rescale_grad, ndarray.NDArray):
+        rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
+    else:
+        rescale_grad = rescale_grad.as_in_context(weight.context)
+    return ndarray._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var,
+                                              weight32=weight32,
+                                              rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                              beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                              wd=wd, clip_gradient=clip_gradient, out=out,
+                                              name=name, **kwargs)
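
The wrapper above accepts either a float or an NDArray for rescale_grad and eagerly
normalizes it onto the weight's context. A minimal sketch of both call styles (names
and shapes illustrative):

    import mxnet as mx

    w, g = mx.nd.ones((2,)), mx.nd.ones((2,)) * 0.1
    m, v = mx.nd.zeros((2,)), mx.nd.zeros((2,))
    # float: wrapped via ndarray.full into a (1,) NDArray on w.context
    mx.nd.contrib.adamw_update(w, g, m, v, 1.0, lr=0.01, eta=1.0, out=w)
    # NDArray: moved onto w.context with as_in_context if necessary
    rescale = mx.nd.array([1.0])
    mx.nd.contrib.adamw_update(w, g, m, v, rescale, lr=0.01, eta=1.0, out=w)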
diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py
index 3b19a77..05d7f17 100644
--- a/python/mxnet/ndarray/register.py
+++ b/python/mxnet/ndarray/register.py
@@ -167,3 +167,14 @@ def _make_ndarray_function(handle, name, func_name):
     return ndarray_function
 
 _init_op_module('mxnet', 'ndarray', _make_ndarray_function)
+
+# Update operator documentation with added float support
+# Note that we can only do this after the op module is initialized
+# Otherwise the backend operators cannot be found
+# pylint: disable=wrong-import-position
+from .contrib import adamw_update, mp_adamw_update
+from ._internal import _adamw_update, _mp_adamw_update
+adamw_update.__doc__ = _adamw_update.__doc__.replace("rescale_grad : NDArray",
+                                                     "rescale_grad : NDArray or float")
+mp_adamw_update.__doc__ = _mp_adamw_update.__doc__.replace("rescale_grad : NDArray",
+                                                           "rescale_grad : NDArray or float")
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index a83227a..d1048df 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -727,3 +727,25 @@ def cond(pred, then_func, else_func, name="cond"):
     outputs = [result[i] for i in range(then_num_outputs)]
     outputs, _ = _regroup(outputs, then_fmt)
     return outputs
+
+def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
+                 epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
+    if not isinstance(rescale_grad, Symbol):
+        rescale_grad = symbol.full(shape=(1,), val=rescale_grad)
+    return symbol._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var,
+                                          rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                          beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                          wd=wd, clip_gradient=clip_gradient, out=out,
+                                          name=name, **kwargs)
+
+def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
+                    beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
+                    name=None, **kwargs):
+    if not isinstance(rescale_grad, Symbol):
+        rescale_grad = symbol.full(shape=(1,), val=rescale_grad)
+    return symbol._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var,
+                                             weight32=weight32,
+                                             rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                             beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                             wd=wd, clip_gradient=clip_gradient, out=out,
+                                             name=name, **kwargs)
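
Unlike the NDArray wrapper, the Symbol wrapper cannot consult a device context at
graph-construction time, so a scalar is folded into the graph as a (1,) constant via
symbol.full and placement is deferred to bind time. A minimal sketch (variable names
illustrative):

    import mxnet as mx

    w, g = mx.sym.Variable('weight'), mx.sym.Variable('grad')
    m, v = mx.sym.Variable('mean'), mx.sym.Variable('var')
    # the float 0.5 becomes a 1-element constant node in the graph
    update = mx.sym.contrib.adamw_update(w, g, m, v, 0.5, lr=0.01, eta=1.0)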
diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py
index c147914..15c8e5e 100644
--- a/python/mxnet/symbol/register.py
+++ b/python/mxnet/symbol/register.py
@@ -208,3 +208,14 @@ def _make_symbol_function(handle, name, func_name):
     return symbol_function
 
 _init_op_module('mxnet', 'symbol', _make_symbol_function)
+
+# Update operator documentation with added float support
+# Note that we can only do this after the op module is initialized
+# Otherwise the backend operators cannot be found
+# pylint: disable=wrong-import-position
+from .contrib import adamw_update, mp_adamw_update
+from ._internal import _adamw_update, _mp_adamw_update
+adamw_update.__doc__ = _adamw_update.__doc__.replace("rescale_grad : Symbol",
+                                                     "rescale_grad : Symbol or float")
+mp_adamw_update.__doc__ = _mp_adamw_update.__doc__.replace("rescale_grad : Symbol",
+                                                           "rescale_grad : Symbol or float")
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc
index 2fbc397..874cce8 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adamw.cc
@@ -50,7 +50,7 @@ inline void MPUpdateCPU(const nnvm::NodeAttrs& attrs,
   });
 }
 
-NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+NNVM_REGISTER_OP(_mp_adamw_update)
 .describe(R"code(Update function for multi-precision AdamW optimizer.
 
 AdamW is seen as a modification of Adam by decoupling the weight decay from the
@@ -91,10 +91,11 @@ the update is skipped.
 .add_argument("var", "NDArray-or-Symbol", "Moving variance")
 .add_argument("weight32", "NDArray-or-Symbol", "Weight32")
 .add_argument("rescale_grad", "NDArray-or-Symbol",
-              "Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
+              "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
+              "the update is skipped.")
 .add_arguments(AdamWParam::__FIELDS__());
 
-NNVM_REGISTER_OP(_contrib_adamw_update)
+NNVM_REGISTER_OP(_adamw_update)
 .describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of
 Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function.
 
@@ -132,7 +133,8 @@ the update is skipped.
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
 .add_argument("var", "NDArray-or-Symbol", "Moving variance")
 .add_argument("rescale_grad", "NDArray-or-Symbol",
-              "Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
+              "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
+              "the update is skipped.")
 .add_arguments(AdamWParam::__FIELDS__());
 
 }  // namespace op
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu
index e21b83b..1521749 100644
--- a/src/operator/contrib/adamw.cu
+++ b/src/operator/contrib/adamw.cu
@@ -50,10 +50,10 @@ inline void MPUpdateGPU(const nnvm::NodeAttrs& attrs,
   });
 }
 
-NNVM_REGISTER_OP(_contrib_adamw_update)
+NNVM_REGISTER_OP(_adamw_update)
 .set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<AdamWUpdate>);
 
-NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+NNVM_REGISTER_OP(_mp_adamw_update)
 .set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<MPAdamWUpdate>);
 
 }  // namespace op
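
Renaming _contrib_adamw_update to _adamw_update (and likewise for the mp variant)
registers the backend ops as internal, underscore-prefixed operators. They then
surface under the front end's _internal modules rather than directly under contrib,
which frees the public contrib names for the hand-written Python wrappers above.
A minimal sketch, assuming a build with this commit:

    from mxnet.ndarray import _internal
    # the raw internal op expects rescale_grad as an NDArray; the
    # scalar convenience lives only in the contrib wrappers
    assert hasattr(_internal, '_adamw_update')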
diff --git a/tests/python/unittest/test_contrib_optimizer.py b/tests/python/unittest/test_contrib_optimizer.py
index dad7bed..675cc94 100644
--- a/tests/python/unittest/test_contrib_optimizer.py
+++ b/tests/python/unittest/test_contrib_optimizer.py
@@ -107,6 +107,12 @@ def test_adamw():
     kwargs = {'eta': eta, 'lr': lr, 'wd': wd, 'epsilon': epsilon,
               'beta1': beta1, 'beta2': beta2}
 
+    # update is skipped for rescale = nan scalar
+    mx.nd.contrib.adamw_update(weight, grad, m, v,
+                               np.nan, out=weight, **kwargs)
+    # weight remains unchanged
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+
     # update is skipped for rescale = 0
     mx.nd.contrib.adamw_update(weight, grad, m, v,
                                rescale_grad * 0, out=weight, **kwargs)
@@ -134,6 +140,12 @@ def test_adamw():
     mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
     mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
 
+    # multi-precision update is skipped for rescale = nan scalar
+    mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
+                                  np.nan, out=weight_fp16, **kwargs)
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+    mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
+
     # multi-precision update is skipped for rescale = inf
     mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
                                   rescale_grad * np.inf, out=weight_fp16, **kwargs)
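
The new tests pin down the skip-on-NaN behavior for the scalar path. A minimal
sketch of the same check outside the test suite (shapes and hyperparameters
illustrative):

    import mxnet as mx
    import numpy as np

    w = mx.nd.ones((2,))
    g, m, v = mx.nd.ones((2,)), mx.nd.zeros((2,)), mx.nd.zeros((2,))
    before = w.asnumpy()
    # a NaN scalar rescale_grad makes the backend skip the update entirely
    mx.nd.contrib.adamw_update(w, g, m, v, np.nan, lr=0.01, eta=1.0, out=w)
    mx.test_utils.assert_almost_equal(before, w.asnumpy())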