You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/19 20:41:58 UTC

[incubator-mxnet] branch master updated: Fix nadam (#9127)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9272d9c  Fix nadam (#9127)
9272d9c is described below

commit 9272d9c43e69e3253d4148181f701bbc4b4f031c
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Tue Dec 19 12:41:54 2017 -0800

    Fix nadam (#9127)
---
 python/mxnet/optimizer.py               |  2 +-
 tests/python/unittest/test_optimizer.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 2da6452..7e8e7c2 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -1098,7 +1098,7 @@ class Nadam(Optimizer):
         t = self._index_update_count[index]
 
         # preprocess grad
-        grad *= self.rescale_grad + wd * weight
+        grad = grad * self.rescale_grad + wd * weight
         if self.clip_gradient is not None:
             grad = clip(grad, -self.clip_gradient, self.clip_gradient)
 
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index ec4fbfd..6178cbe 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -18,6 +18,7 @@
 import numpy as np
 import mxnet as mx
 import mxnet.lr_scheduler as lr_scheduler
+from mxnet import gluon
 import unittest
 from nose.tools import raises
 import math
@@ -644,6 +645,33 @@ def test_ftrl():
         compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
                           np.float32, w_stype='row_sparse', g_stype='row_sparse')
 
+def test_nadam():
+
+    def get_net(num_hidden, flatten=True):
+        data = mx.symbol.Variable('data')
+        fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128, flatten=flatten)
+        act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
+        fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64, flatten=flatten)
+        act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
+        fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=num_hidden, flatten=flatten)
+        return fc3
+    np.random.seed(1234)
+    N = 20
+    data = mx.random.uniform(-1, 1, shape=(N, 10))
+    label = mx.random.uniform(-1, 1, shape=(N, 1))
+    data_iter = mx.io.NDArrayIter(data, label, batch_size=5, label_name='label', shuffle=True)
+    output = get_net(1)
+    l = mx.symbol.Variable('label')
+    Loss = gluon.loss.L1Loss()
+    loss = Loss(output, l)
+    loss = mx.sym.make_loss(loss)
+    mod = mx.mod.Module(loss, data_names=('data',), label_names=('label',))
+    mod.fit(data_iter, num_epoch=30, optimizer_params={'learning_rate': 0.005, 'wd': 0.0005},
+            initializer=mx.init.Xavier(magnitude=2), eval_metric=mx.metric.Loss(),
+            optimizer='nadam')
+    assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.1
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.