You are viewing a plain-text version of this content. Its canonical link was a hyperlink in the original page and is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/01/27 05:38:06 UTC
[incubator-mxnet] branch master updated: Python BucketingModule bind() with grad_req = 'add' (#13984)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new bc98c0d Python BucketingModule bind() with grad_req = 'add' (#13984)
bc98c0d is described below
commit bc98c0d03666bdafae7954180074125c7329a42d
Author: slyforce <mi...@gmail.com>
AuthorDate: Sun Jan 27 06:37:50 2019 +0100
Python BucketingModule bind() with grad_req = 'add' (#13984)
* remember grad_req from bind and apply it to sub-modules
* unit-test for gradient accumulation with bucketing modules
---
python/mxnet/module/bucketing_module.py | 7 ++++--
tests/python/unittest/test_module.py | 42 +++++++++++++++++++++++++++++++++
2 files changed, 47 insertions(+), 2 deletions(-)
diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index 9b56861..66c6666 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -95,6 +95,7 @@ class BucketingModule(BaseModule):
self._curr_bucket_key = None
self._params_dirty = False
self._monitor = None
+ self._grad_req = None
def _reset_bind(self):
"""Internal utility function to reset binding."""
@@ -331,6 +332,7 @@ class BucketingModule(BaseModule):
self.for_training = for_training
self.inputs_need_grad = inputs_need_grad
self.binded = True
+ self._grad_req = grad_req
symbol, data_names, label_names = self._call_sym_gen(self._default_bucket_key)
module = Module(symbol, data_names, label_names, logger=self.logger,
@@ -340,7 +342,7 @@ class BucketingModule(BaseModule):
group2ctxs=self._group2ctxs,
compression_params=self._compression_params)
module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
- force_rebind=False, shared_module=None, grad_req=grad_req)
+ force_rebind=False, shared_module=None, grad_req=self._grad_req)
self._curr_module = module
self._curr_bucket_key = self._default_bucket_key
self._buckets[self._default_bucket_key] = module
@@ -373,7 +375,8 @@ class BucketingModule(BaseModule):
compression_params=self._compression_params)
module.bind(data_shapes, label_shapes, self._curr_module.for_training,
self._curr_module.inputs_need_grad,
- force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
+ force_rebind=False, shared_module=self._buckets[self._default_bucket_key],
+ grad_req=self._grad_req)
if self._monitor is not None:
module.install_monitor(self._monitor)
self._buckets[bucket_key] = module
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index d9d7175..ae38a22 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -874,6 +874,48 @@ def test_reference_single_batch_during_fit():
train_data = MockTrainData(batches=2)
mod.fit(train_data, num_epoch=1)
+@with_seed()
+def test_bucket_module_grad_req():
+ batch_size = 2
+ def sym_gen(_):
+ data = mx.symbol.Variable('data')
+ weight = mx.symbol.Variable('a', shape=(1,), init=mx.init.One())
+ sym = mx.sym.make_loss(mx.sym.broadcast_mul(data, weight))
+ return sym, ('data',), None
+
+ mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10)
+ mod.bind(data_shapes=[['data', (batch_size, )]], for_training=True, grad_req='write')
+ mod.init_params()
+
+ mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))],
+ label=None,
+ provide_data=[mx.io.DataDesc(name='data', shape=(batch_size, ), layout='N')],
+ bucket_key=10))
+ assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == batch_size)
+
+ mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))],
+ label=None,
+ provide_data=[mx.io.DataDesc(name='data', shape=(batch_size, ), layout='N')],
+ bucket_key=5))
+ assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == batch_size)
+
+ mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10)
+ mod.bind(data_shapes=[['data', (batch_size, )]], for_training=True, grad_req='add')
+ mod.init_params()
+
+ mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))],
+ label=None,
+ provide_data=[mx.io.DataDesc(name='data', shape=(batch_size,), layout='N')],
+ bucket_key=10))
+ assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == batch_size)
+
+ mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))],
+ label=None,
+ provide_data=[mx.io.DataDesc(name='data', shape=(batch_size,), layout='N')],
+ bucket_key=5))
+ assert mod._curr_module._grad_req == 'add'
+ assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == 2 * batch_size)
+
if __name__ == '__main__':
import nose