You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/02/21 05:32:00 UTC
[incubator-mxnet] branch master updated: fix update params (#14218)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new fdbc433 fix update params (#14218)
fdbc433 is described below
commit fdbc433ffb29c3442b523d32d59f4b44989eac7b
Author: Lai Wei <ro...@gmail.com>
AuthorDate: Wed Feb 20 21:31:36 2019 -0800
fix update params (#14218)
---
python/mxnet/model.py | 6 ++++--
tests/python/unittest/test_module.py | 14 ++++++++++++++
2 files changed, 18 insertions(+), 2 deletions(-)
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index c08077c..efb5109 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -181,8 +181,10 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
w, g = p
updates[k].append((index*num_device+k, g, w))
for dev_updates in updates:
- i, w, g = zip(*dev_updates)
- updater(i, w, g)
+ # update params if param_arrays and grad_arrays are not empty
+ if dev_updates:
+ i, w, g = zip(*dev_updates)
+ updater(i, w, g)
def _multiple_callbacks(callbacks, *args, **kwargs):
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index ae38a22..36c1993 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -917,6 +917,20 @@ def test_bucket_module_grad_req():
assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == 2 * batch_size)
+def test_module_update_no_pragram():
+ # test module to do update on layers without params
+ data_shape = (10, 10)
+ data = mx.sym.Variable('data')
+ out = mx.sym.Dropout(data, 0.5)
+ mod = mx.mod.Module(out)
+ mod.bind(data_shapes=[('data', data_shape)])
+ mod.init_params()
+ mod.init_optimizer()
+ data_batch = mx.io.DataBatch([nd.ones(data_shape)])
+ mod.forward_backward(data_batch)
+ mod.update()
+ assert(mod.get_outputs()[0].shape == data_shape)
+
if __name__ == '__main__':
import nose
nose.runmodule()