You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/19 20:41:58 UTC
[incubator-mxnet] branch master updated: Fix nadam (#9127)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9272d9c Fix nadam (#9127)
9272d9c is described below
commit 9272d9c43e69e3253d4148181f701bbc4b4f031c
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Tue Dec 19 12:41:54 2017 -0800
Fix nadam (#9127)
---
python/mxnet/optimizer.py | 2 +-
tests/python/unittest/test_optimizer.py | 28 ++++++++++++++++++++++++++++
2 files changed, 29 insertions(+), 1 deletion(-)
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 2da6452..7e8e7c2 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -1098,7 +1098,7 @@ class Nadam(Optimizer):
t = self._index_update_count[index]
# preprocess grad
- grad *= self.rescale_grad + wd * weight
+ grad = grad * self.rescale_grad + wd * weight
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index ec4fbfd..6178cbe 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -18,6 +18,7 @@
import numpy as np
import mxnet as mx
import mxnet.lr_scheduler as lr_scheduler
+from mxnet import gluon
import unittest
from nose.tools import raises
import math
@@ -644,6 +645,33 @@ def test_ftrl():
compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
np.float32, w_stype='row_sparse', g_stype='row_sparse')
+def test_nadam():
+
+ def get_net(num_hidden, flatten=True):
+ data = mx.symbol.Variable('data')
+ fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128, flatten=flatten)
+ act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
+ fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64, flatten=flatten)
+ act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
+ fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=num_hidden, flatten=flatten)
+ return fc3
+ np.random.seed(1234)
+ N = 20
+ data = mx.random.uniform(-1, 1, shape=(N, 10))
+ label = mx.random.uniform(-1, 1, shape=(N, 1))
+ data_iter = mx.io.NDArrayIter(data, label, batch_size=5, label_name='label', shuffle=True)
+ output = get_net(1)
+ l = mx.symbol.Variable('label')
+ Loss = gluon.loss.L1Loss()
+ loss = Loss(output, l)
+ loss = mx.sym.make_loss(loss)
+ mod = mx.mod.Module(loss, data_names=('data',), label_names=('label',))
+ mod.fit(data_iter, num_epoch=30, optimizer_params={'learning_rate': 0.005, 'wd': 0.0005},
+ initializer=mx.init.Xavier(magnitude=2), eval_metric=mx.metric.Loss(),
+ optimizer='nadam')
+ assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.1
+
+
if __name__ == '__main__':
import nose
nose.runmodule()
--
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.