You are viewing a plain text version of this content. The canonical link for it was a hyperlink on the original page and is not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/02/21 05:32:00 UTC

[incubator-mxnet] branch master updated: fix update params (#14218)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new fdbc433  fix update params (#14218)
fdbc433 is described below

commit fdbc433ffb29c3442b523d32d59f4b44989eac7b
Author: Lai Wei <ro...@gmail.com>
AuthorDate: Wed Feb 20 21:31:36 2019 -0800

    fix update params (#14218)
---
 python/mxnet/model.py                |  6 ++++--
 tests/python/unittest/test_module.py | 14 ++++++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index c08077c..efb5109 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -181,8 +181,10 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
             w, g = p
             updates[k].append((index*num_device+k, g, w))
     for dev_updates in updates:
-        i, w, g = zip(*dev_updates)
-        updater(i, w, g)
+        # update params if param_arrays and grad_arrays are not empty
+        if dev_updates:
+            i, w, g = zip(*dev_updates)
+            updater(i, w, g)
 
 
 def _multiple_callbacks(callbacks, *args, **kwargs):
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index ae38a22..36c1993 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -917,6 +917,20 @@ def test_bucket_module_grad_req():
     assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == 2 * batch_size)
 
 
+def test_module_update_no_pragram():  # NOTE(review): "pragram" is a typo for "param"
+    """Module.update() must succeed on a network that has no learnable parameters."""
+    data_shape = (10, 10)
+    data = mx.sym.Variable('data')
+    out = mx.sym.Dropout(data, 0.5)  # parameter-free layer: nothing for the optimizer to update
+    mod = mx.mod.Module(out)
+    mod.bind(data_shapes=[('data', data_shape)])
+    mod.init_params()
+    mod.init_optimizer()
+    data_batch = mx.io.DataBatch([nd.ones(data_shape)])
+    mod.forward_backward(data_batch)
+    mod.update()  # presumably reaches the empty-dev_updates guard in model._update_params
+    assert(mod.get_outputs()[0].shape == data_shape)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()