You are viewing a plain-text version of this content. (The hyperlink to the canonical version was lost in the plain-text conversion.)
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/02/27 22:46:50 UTC

[incubator-mxnet] branch master updated: [op] add back support for scalar type rescale_grad argument for adamw_update/mp_adamw_update (#14221)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e3a51b5  [op] add back support for scalar type rescale_grad argument for adamw_update/mp_adamw_update (#14221)
e3a51b5 is described below

commit e3a51b5a3ed989bf1e9c9f53b56819b32957527f
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Wed Feb 27 14:46:33 2019 -0800

    [op] add back support for scalar type rescale_grad argument for adamw_update/mp_adamw_update (#14221)
    
    * support scalar
    
    * remove two copies of documentation for adamw
    
    * fix lint
---
 python/mxnet/ndarray/contrib.py                 | 26 +++++++++++++++++++++++++
 python/mxnet/ndarray/register.py                | 11 +++++++++++
 python/mxnet/symbol/contrib.py                  | 22 +++++++++++++++++++++
 python/mxnet/symbol/register.py                 | 11 +++++++++++
 src/operator/contrib/adamw.cc                   | 10 ++++++----
 src/operator/contrib/adamw.cu                   |  4 ++--
 tests/python/unittest/test_contrib_optimizer.py | 12 ++++++++++++
 7 files changed, 90 insertions(+), 6 deletions(-)

diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index 6bbee8a..74c355d 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -542,3 +542,29 @@ def isnan(data):
     <NDArray 2 @cpu(0)>
     """
     return data != data
+
+def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
+                 epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
+    if not isinstance(rescale_grad, ndarray.NDArray):
+        rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
+    else:
+        rescale_grad = rescale_grad.as_in_context(weight.context)
+    return ndarray._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var,
+                                           rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                           beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                           wd=wd, clip_gradient=clip_gradient, out=out,
+                                           name=name, **kwargs)
+
+def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
+                    beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
+                    name=None, **kwargs):
+    if not isinstance(rescale_grad, ndarray.NDArray):
+        rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
+    else:
+        rescale_grad = rescale_grad.as_in_context(weight.context)
+    return ndarray._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var,
+                                              weight32=weight32,
+                                              rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                              beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                              wd=wd, clip_gradient=clip_gradient, out=out,
+                                              name=name, **kwargs)
diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py
index 3b19a77..05d7f17 100644
--- a/python/mxnet/ndarray/register.py
+++ b/python/mxnet/ndarray/register.py
@@ -167,3 +167,14 @@ def _make_ndarray_function(handle, name, func_name):
     return ndarray_function
 
 _init_op_module('mxnet', 'ndarray', _make_ndarray_function)
+
+# Update operator documentation with added float support
+# Note that we can only do this after the op module is initialized
+# Otherwise the backend operators cannot be found
+# pylint: disable=wrong-import-position
+from .contrib import adamw_update, mp_adamw_update
+from ._internal import _adamw_update, _mp_adamw_update
+adamw_update.__doc__ = _adamw_update.__doc__.replace("rescale_grad : NDArray",
+                                                     "rescale_grad : NDArray or float")
+mp_adamw_update.__doc__ = _mp_adamw_update.__doc__.replace("rescale_grad : NDArray",
+                                                           "rescale_grad : NDArray or float")
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index a83227a..d1048df 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -727,3 +727,25 @@ def cond(pred, then_func, else_func, name="cond"):
     outputs = [result[i] for i in range(then_num_outputs)]
     outputs, _ = _regroup(outputs, then_fmt)
     return outputs
+
+def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
+                 epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
+    if not isinstance(rescale_grad, Symbol):
+        rescale_grad = symbol.full(shape=(1,), val=rescale_grad)
+    return symbol._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var,
+                                          rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                          beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                          wd=wd, clip_gradient=clip_gradient, out=out,
+                                          name=name, **kwargs)
+
+def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
+                    beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
+                    name=None, **kwargs):
+    if not isinstance(rescale_grad, Symbol):
+        rescale_grad = symbol.full(shape=(1,), val=rescale_grad)
+    return symbol._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var,
+                                             weight32=weight32,
+                                             rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                             beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                             wd=wd, clip_gradient=clip_gradient, out=out,
+                                             name=name, **kwargs)
diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py
index c147914..15c8e5e 100644
--- a/python/mxnet/symbol/register.py
+++ b/python/mxnet/symbol/register.py
@@ -208,3 +208,14 @@ def _make_symbol_function(handle, name, func_name):
     return symbol_function
 
 _init_op_module('mxnet', 'symbol', _make_symbol_function)
+
+# Update operator documentation with added float support
+# Note that we can only do this after the op module is initialized
+# Otherwise the backend operators cannot be found
+# pylint: disable=wrong-import-position
+from .contrib import adamw_update, mp_adamw_update
+from ._internal import _adamw_update, _mp_adamw_update
+adamw_update.__doc__ = _adamw_update.__doc__.replace("rescale_grad : Symbol",
+                                                     "rescale_grad : Symbol or float")
+mp_adamw_update.__doc__ = _mp_adamw_update.__doc__.replace("rescale_grad : Symbol",
+                                                           "rescale_grad : Symbol or float")
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc
index 2fbc397..874cce8 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adamw.cc
@@ -50,7 +50,7 @@ inline void MPUpdateCPU(const nnvm::NodeAttrs& attrs,
   });
 }
 
-NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+NNVM_REGISTER_OP(_mp_adamw_update)
 .describe(R"code(Update function for multi-precision AdamW optimizer.
 
 AdamW is seen as a modification of Adam by decoupling the weight decay from the
@@ -91,10 +91,11 @@ the update is skipped.
 .add_argument("var", "NDArray-or-Symbol", "Moving variance")
 .add_argument("weight32", "NDArray-or-Symbol", "Weight32")
 .add_argument("rescale_grad", "NDArray-or-Symbol",
-              "Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
+              "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
+              "the update is skipped.")
 .add_arguments(AdamWParam::__FIELDS__());
 
-NNVM_REGISTER_OP(_contrib_adamw_update)
+NNVM_REGISTER_OP(_adamw_update)
 .describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of
 Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function.
 
@@ -132,7 +133,8 @@ the update is skipped.
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
 .add_argument("var", "NDArray-or-Symbol", "Moving variance")
 .add_argument("rescale_grad", "NDArray-or-Symbol",
-              "Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
+              "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
+              "the update is skipped.")
 .add_arguments(AdamWParam::__FIELDS__());
 
 }  // namespace op
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu
index e21b83b..1521749 100644
--- a/src/operator/contrib/adamw.cu
+++ b/src/operator/contrib/adamw.cu
@@ -50,10 +50,10 @@ inline void MPUpdateGPU(const nnvm::NodeAttrs& attrs,
   });
 }
 
-NNVM_REGISTER_OP(_contrib_adamw_update)
+NNVM_REGISTER_OP(_adamw_update)
 .set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<AdamWUpdate>);
 
-NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+NNVM_REGISTER_OP(_mp_adamw_update)
 .set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<MPAdamWUpdate>);
 
 }  // namespace op
diff --git a/tests/python/unittest/test_contrib_optimizer.py b/tests/python/unittest/test_contrib_optimizer.py
index dad7bed..675cc94 100644
--- a/tests/python/unittest/test_contrib_optimizer.py
+++ b/tests/python/unittest/test_contrib_optimizer.py
@@ -107,6 +107,12 @@ def test_adamw():
     kwargs = {'eta': eta, 'lr': lr, 'wd': wd, 'epsilon': epsilon,
               'beta1': beta1, 'beta2': beta2}
 
+    # update is skipped for rescale = nan scalar
+    mx.nd.contrib.adamw_update(weight, grad, m, v,
+                               np.nan, out=weight, **kwargs)
+    # weight remains unchanged
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+
     # update is skipped for rescale = 0
     mx.nd.contrib.adamw_update(weight, grad, m, v,
                                rescale_grad * 0, out=weight, **kwargs)
@@ -134,6 +140,12 @@ def test_adamw():
     mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
     mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
 
+    # multi-precision update is skipped for rescale = nan scalar
+    mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
+                                  np.nan, out=weight_fp16, **kwargs)
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+    mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
+
     # multi-precision update is skipped for rescale = inf
     mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
                                   rescale_grad * np.inf, out=weight_fp16, **kwargs)