Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/25 18:39:52 UTC

[GitHub] leezu closed pull request #11223: Allow specifying AdaGrad initial accumulator value

URL: https://github.com/apache/incubator-mxnet/pull/11223

This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is displayed below for the sake of provenance.
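
In short, the change lets AdaGrad's history state start from a non-zero
value. A minimal usage sketch (illustrative, not part of the patch; it
assumes the Python API as modified in the diff that follows):

    import mxnet as mx

    # Before this change the AdaGrad history always started at zero;
    # initial_accumulator_value pre-fills it, damping the first updates.
    opt = mx.optimizer.AdaGrad(learning_rate=0.01, eps=1e-7,
                               initial_accumulator_value=0.1)

    weight = mx.nd.uniform(shape=(3, 4))
    grad = mx.nd.uniform(shape=(3, 4))
    state = opt.create_state(0, weight)   # history pre-filled with 0.1
    opt.update(0, weight, grad, state)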

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 0c3fc904fb1..e7727b7e586 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -1091,14 +1091,20 @@ class AdaGrad(Optimizer):
     ----------
     eps: float, optional
         Small value to avoid division by 0.
+    initial_accumulator_value: float, default 0
+        The Adagrad state is initially set to this value.
 
     """
-    def __init__(self, eps=1e-7, **kwargs):
+    def __init__(self, eps=1e-7, initial_accumulator_value=0, **kwargs):
         super(AdaGrad, self).__init__(**kwargs)
         self.float_stable_eps = eps
+        self.initial_accumulator_value = initial_accumulator_value
 
     def create_state(self, index, weight):
-        return zeros(weight.shape, weight.context, stype=weight.stype)  # history
+        history = zeros(weight.shape, weight.context, stype=weight.stype)
+        if self.initial_accumulator_value:
+            history[:] = self.initial_accumulator_value
+        return history
 
     def update(self, index, weight, grad, state):
         assert(isinstance(weight, NDArray))
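
For context, the update method (truncated above) applies the standard
AdaGrad step. A rough NumPy sketch of the dense case, ignoring gradient
clipping and rescaling, and assuming eps sits inside the square root as
the name float_stable_eps suggests:

    import numpy as np

    def adagrad_step(weight, grad, history, lr=0.01, wd=0.0, eps=1e-7):
        # Accumulate squared gradients, then scale the step by the
        # accumulated history. A non-zero initial_accumulator_value
        # simply means history does not start at zero here.
        history += grad * grad
        weight -= lr * (grad / np.sqrt(history + eps) + wd * weight)

    w = np.ones((3, 4), dtype=np.float32)
    g = np.full((3, 4), 0.5, dtype=np.float32)
    h = np.full((3, 4), 1.0, dtype=np.float32)  # initial_accumulator_value=1.0
    adagrad_step(w, g, h)

With history pre-filled, the effective step size lr / sqrt(history + eps)
is bounded from the very first iteration rather than starting near lr / eps.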
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index fba10fb522a..cd516738130 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import itertools
 import numpy as np
 import mxnet as mx
 import mxnet.lr_scheduler as lr_scheduler
@@ -991,12 +992,16 @@ class PyAdaGrad(mx.optimizer.Optimizer):
         Small value to avoid division by 0.
 
     """
-    def __init__(self, eps=1e-7, **kwargs):
+    def __init__(self, eps=1e-7, initial_accumulator_value=0, **kwargs):
         super(PyAdaGrad, self).__init__(**kwargs)
         self.float_stable_eps = eps
+        self.initial_accumulator_value = initial_accumulator_value
 
     def create_state(self, index, weight):
-        return mx.nd.zeros(weight.shape, weight.context, stype=weight.stype)
+        history = mx.nd.zeros(weight.shape, weight.context, stype=weight.stype)
+        if self.initial_accumulator_value:
+            history[:] = self.initial_accumulator_value
+        return history
 
     def update(self, index, weight, grad, state):
         self._update_count(index)
@@ -1020,21 +1025,21 @@ def test_adagrad():
     cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
     rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
     wd_options = [{}, {'wd': 0.0}]
-    for dtype in [np.float32]:
-        for eps_option in eps_options:
-            for cg_option in cg_options:
-                for rg_option in rg_options:
-                    for wd_option in wd_options:
-                        kwarg = {}
-                        kwarg.update(eps_option)
-                        kwarg.update(cg_option)
-                        kwarg.update(rg_option)
-                        kwarg.update(wd_option)
-                        compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
-                        if wd_option.get('wd', 0.0) == 0.0:
-                            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
-                                              w_stype='row_sparse', g_stype='row_sparse')
+    acc_options = [{}, {'initial_accumulator_value': 1.0}]
 
+    for dtype in [np.float32]:
+        for eps_option, cg_option, rg_option, wd_option, acc_option in itertools.product(
+                eps_options, cg_options, rg_options, wd_options, acc_options):
+            kwarg = {}
+            kwarg.update(eps_option)
+            kwarg.update(cg_option)
+            kwarg.update(rg_option)
+            kwarg.update(wd_option)
+            kwarg.update(acc_option)
+            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
+            if wd_option.get('wd', 0.0) == 0.0:
+                compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+                                  w_stype='row_sparse', g_stype='row_sparse')
 
 
 if __name__ == '__main__':
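
Aside from exercising the new option, the test refactor above replaces
four nested loops with a single itertools.product loop; a standalone
illustration of the equivalence (with made-up option lists):

    import itertools

    eps_options = [{}, {'eps': 1e-8}]
    wd_options = [{}, {'wd': 0.0}]

    # itertools.product yields every combination, so one flat loop
    # replaces one nested for-loop per option list; adding a new list
    # such as acc_options becomes a one-line change rather than
    # another indentation level.
    for eps_option, wd_option in itertools.product(eps_options, wd_options):
        kwarg = {}
        kwarg.update(eps_option)
        kwarg.update(wd_option)
        print(kwarg)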
