You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/01/13 22:47:22 UTC

[incubator-mxnet] 02/03: Fix nadam (#9127)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.0.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit fd199b4bc81e8b0e3d75ffef36acc050564b5b3c
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Tue Dec 19 12:41:54 2017 -0800

    Fix nadam (#9127)
---
 python/mxnet/optimizer.py               | 2 +-
 tests/python/unittest/test_optimizer.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 2bdebff..5b5941e 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -1101,7 +1101,7 @@ class Nadam(Optimizer):
         t = self._index_update_count[index]
 
         # preprocess grad
-        grad *= self.rescale_grad + wd * weight
+        grad = grad * self.rescale_grad + wd * weight
         if self.clip_gradient is not None:
             grad = clip(grad, -self.clip_gradient, self.clip_gradient)
 
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 95097b3..b2f80c0 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -18,6 +18,7 @@
 import numpy as np
 import mxnet as mx
 import mxnet.lr_scheduler as lr_scheduler
+from mxnet import gluon
 import unittest
 from nose.tools import raises
 import math
@@ -664,7 +665,7 @@ def test_nadam():
     loss = Loss(output, l)
     loss = mx.sym.make_loss(loss)
     mod = mx.mod.Module(loss, data_names=('data',), label_names=('label',))
-    mod.fit(data_iter, num_epoch=60, optimizer_params={'learning_rate': 0.0005, 'wd': 0.0005},
+    mod.fit(data_iter, num_epoch=30, optimizer_params={'learning_rate': 0.005, 'wd': 0.0005},
             initializer=mx.init.Xavier(magnitude=2), eval_metric=mx.metric.Loss(),
             optimizer='nadam')
     assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.1

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <commits@mxnet.apache.org>.