You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/01/27 05:38:06 UTC

[incubator-mxnet] branch master updated: Python BucketingModule bind() with grad_req = 'add' (#13984)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new bc98c0d  Python BucketingModule bind() with grad_req = 'add' (#13984)
bc98c0d is described below

commit bc98c0d03666bdafae7954180074125c7329a42d
Author: slyforce <mi...@gmail.com>
AuthorDate: Sun Jan 27 06:37:50 2019 +0100

    Python BucketingModule bind() with grad_req = 'add' (#13984)
    
    * remember grad_req from bind and apply it to sub-modules
    
    * unit-test for gradient accumulation with bucketing modules
---
 python/mxnet/module/bucketing_module.py |  7 ++++--
 tests/python/unittest/test_module.py    | 42 +++++++++++++++++++++++++++++++++
 2 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index 9b56861..66c6666 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -95,6 +95,7 @@ class BucketingModule(BaseModule):
         self._curr_bucket_key = None
         self._params_dirty = False
         self._monitor = None
+        self._grad_req = None
 
     def _reset_bind(self):
         """Internal utility function to reset binding."""
@@ -331,6 +332,7 @@ class BucketingModule(BaseModule):
         self.for_training = for_training
         self.inputs_need_grad = inputs_need_grad
         self.binded = True
+        self._grad_req = grad_req
 
         symbol, data_names, label_names = self._call_sym_gen(self._default_bucket_key)
         module = Module(symbol, data_names, label_names, logger=self.logger,
@@ -340,7 +342,7 @@ class BucketingModule(BaseModule):
                         group2ctxs=self._group2ctxs,
                         compression_params=self._compression_params)
         module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
-                    force_rebind=False, shared_module=None, grad_req=grad_req)
+                    force_rebind=False, shared_module=None, grad_req=self._grad_req)
         self._curr_module = module
         self._curr_bucket_key = self._default_bucket_key
         self._buckets[self._default_bucket_key] = module
@@ -373,7 +375,8 @@ class BucketingModule(BaseModule):
                             compression_params=self._compression_params)
             module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                         self._curr_module.inputs_need_grad,
-                        force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
+                        force_rebind=False, shared_module=self._buckets[self._default_bucket_key],
+                        grad_req=self._grad_req)
             if self._monitor is not None:
                 module.install_monitor(self._monitor)
             self._buckets[bucket_key] = module
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index d9d7175..ae38a22 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -874,6 +874,48 @@ def test_reference_single_batch_during_fit():
     train_data = MockTrainData(batches=2)
     mod.fit(train_data, num_epoch=1)
 
+@with_seed()
+def test_bucket_module_grad_req():
+    batch_size = 2
+    def sym_gen(_):
+        data = mx.symbol.Variable('data')
+        weight = mx.symbol.Variable('a', shape=(1,), init=mx.init.One())
+        sym = mx.sym.make_loss(mx.sym.broadcast_mul(data, weight))
+        return sym, ('data',), None
+
+    mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10)
+    mod.bind(data_shapes=[['data', (batch_size, )]], for_training=True, grad_req='write')
+    mod.init_params()
+
+    mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))],
+                                         label=None,
+                                         provide_data=[mx.io.DataDesc(name='data', shape=(batch_size, ), layout='N')],
+                                         bucket_key=10))
+    assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == batch_size)
+
+    mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))],
+                                         label=None,
+                                         provide_data=[mx.io.DataDesc(name='data', shape=(batch_size, ), layout='N')],
+                                         bucket_key=5))
+    assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == batch_size)
+
+    mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10)
+    mod.bind(data_shapes=[['data', (batch_size, )]], for_training=True, grad_req='add')
+    mod.init_params()
+
+    mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))],
+                                         label=None,
+                                         provide_data=[mx.io.DataDesc(name='data', shape=(batch_size,), layout='N')],
+                                         bucket_key=10))
+    assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == batch_size)
+
+    mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))],
+                                         label=None,
+                                         provide_data=[mx.io.DataDesc(name='data', shape=(batch_size,), layout='N')],
+                                         bucket_key=5))
+    assert mod._curr_module._grad_req == 'add'
+    assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == 2 * batch_size)
+
 
 if __name__ == '__main__':
     import nose