Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/10 20:40:54 UTC

[incubator-mxnet] branch master updated: expose group2ctx to module (#8539)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new c9bde1b  expose group2ctx to module (#8539)
c9bde1b is described below

commit c9bde1b02b89b5c4b987be1d3950d716b3a21692
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Sat Nov 11 04:40:51 2017 +0800

    expose group2ctx to module (#8539)
    
    * expose group2ctx to module
    
    * Update test_module.py
    
    * address comments
    
    * update
---
 python/mxnet/module/bucketing_module.py |  9 ++++--
 python/mxnet/module/executor_group.py   | 12 +++++--
 python/mxnet/module/module.py           |  8 +++--
 tests/python/unittest/test_module.py    | 57 +++++++++++++++++++++++++++++++++
 4 files changed, 79 insertions(+), 7 deletions(-)
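
For context, here is a minimal usage sketch of the new `group2ctxs` argument (illustrative only, not part of the commit), mirroring the unit test added below: operators tagged with a `ctx_group` attribute are placed on the contexts named in the dict, while everything else stays on the module's `context`. The cpu(1)/cpu(2) placements are stand-ins for real devices.

    import mxnet as mx

    # Tag parts of the graph with ctx_group attributes.
    with mx.AttrScope(ctx_group='dev1'):
        a = mx.symbol.Variable('a') * 2
    with mx.AttrScope(ctx_group='dev2'):
        b = mx.symbol.Variable('b')
        c = a + b

    shape = (2, 5)
    # group2ctxs: one dict per device in `context`, mapping ctx_group -> context.
    mod = mx.mod.Module(c, context=[mx.cpu(0)], data_names=['a', 'b'], label_names=None,
                        group2ctxs=[{'dev1': mx.cpu(1), 'dev2': mx.cpu(2)}])
    mod.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
    mod.init_params()
    mod.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
    mod.backward([mx.nd.ones(shape)])

As the new test asserts, get_input_grads() then matches a plain Module built without group2ctxs.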

diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index f3c7ecb..dd6cafb 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -52,10 +52,12 @@ class BucketingModule(BaseModule):
     state_names : list of str
         States are similar to data and label, but not provided by data iterator.
         Instead they are initialized to 0 and can be set by set_states()
+    group2ctxs : list of dict of str to context
+        Default is `None`. One dict per device in `context`, mapping each `ctx_group` attribute to its context assignment.
     """
     def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
                  context=ctx.cpu(), work_load_list=None,
-                 fixed_param_names=None, state_names=None):
+                 fixed_param_names=None, state_names=None, group2ctxs=None):
         super(BucketingModule, self).__init__(logger=logger)
 
         assert default_bucket_key is not None
@@ -77,6 +79,7 @@ class BucketingModule(BaseModule):
         self._state_names = state_names
         self._context = context
         self._work_load_list = work_load_list
+        self._group2ctxs = group2ctxs
 
         self._buckets = {}
         self._curr_module = None
@@ -319,7 +322,7 @@ class BucketingModule(BaseModule):
         module = Module(symbol, data_names, label_names, logger=self.logger,
                         context=self._context, work_load_list=self._work_load_list,
                         fixed_param_names=self._fixed_param_names,
-                        state_names=self._state_names)
+                        state_names=self._state_names, group2ctxs=self._group2ctxs)
         module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
                     force_rebind=False, shared_module=None, grad_req=grad_req)
         self._curr_module = module
@@ -349,7 +352,7 @@ class BucketingModule(BaseModule):
                             logger=self.logger, context=self._context,
                             work_load_list=self._work_load_list,
                             fixed_param_names=self._fixed_param_names,
-                            state_names=self._state_names)
+                            state_names=self._state_names, group2ctxs=self._group2ctxs)
             module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                         self._curr_module.inputs_need_grad,
                         force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index 0f3c079..ea7651b 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -139,10 +139,12 @@ class DataParallelExecutorGroup(object):
         Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
         (default to 'write').
         Can be specified globally (str) or for each argument (list, dict).
+    group2ctxs : list of dict of str to context
+        Default is `None`. One dict per entry in `contexts`, mapping each `ctx_group` attribute to its context assignment.
     """
     def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names,
                  for_training, inputs_need_grad, shared_group=None, logger=logging,
-                 fixed_param_names=None, grad_req='write', state_names=None):
+                 fixed_param_names=None, grad_req='write', state_names=None, group2ctxs=None):
         self.param_names = param_names
         self.arg_names = symbol.list_arguments()
         self.aux_names = symbol.list_auxiliary_states()
@@ -150,6 +152,10 @@ class DataParallelExecutorGroup(object):
         self.symbol = symbol
         self.contexts = contexts
         self.workload = workload
+        if group2ctxs is None:
+            group2ctxs = [None] * len(self.contexts)
+        assert len(group2ctxs) == len(self.contexts)
+        self.group2ctxs = group2ctxs
 
         self.for_training = for_training
         self.inputs_need_grad = inputs_need_grad
@@ -597,9 +603,11 @@ class DataParallelExecutorGroup(object):
         if label_shapes is not None:
             input_types.update({x.name: x.dtype for x in label_shapes})
 
+        group2ctx = self.group2ctxs[i]
+
         executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
                                            type_dict=input_types, shared_arg_names=self.param_names,
-                                           shared_exec=shared_exec,
+                                           shared_exec=shared_exec, group2ctx=group2ctx,
                                            shared_buffer=shared_data_arrays, **input_shapes)
         self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
         return executor
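
A note on the per-device semantics introduced above (an illustrative sketch, not part of the commit): `group2ctxs` is a list parallel to the module's list of contexts, and entry i is forwarded as `group2ctx` to `simple_bind` for the i-th executor, so each data-parallel replica gets its own group placement. For example, assuming two CPU contexts stand in for two devices:

    import mxnet as mx

    with mx.AttrScope(ctx_group='dev1'):
        data = mx.symbol.Variable('data')
    with mx.AttrScope(ctx_group='dev2'):
        out = mx.symbol.FullyConnected(data=data, num_hidden=10, name='fc')

    # One dict per entry in `context`; len(group2ctxs) must equal len(context),
    # matching the assert in DataParallelExecutorGroup above.
    mod = mx.mod.Module(out, data_names=['data'], label_names=None,
                        context=[mx.cpu(0), mx.cpu(1)],
                        group2ctxs=[{'dev1': mx.cpu(2), 'dev2': mx.cpu(3)},
                                    {'dev1': mx.cpu(4), 'dev2': mx.cpu(5)}])
    mod.bind(data_shapes=[['data', (4, 8)]])
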
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index 4c20a6f..8301330 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -59,10 +59,12 @@ class Module(BaseModule):
     state_names : list of str
         states are similar to data and label, but not provided by data iterator.
         Instead they are initialized to 0 and can be set by `set_states()`.
+    group2ctxs : list of dict of str to context
+        Default is `None`. One dict per device in `context`, mapping each `ctx_group` attribute to its context assignment.
     """
     def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
                  logger=logging, context=ctx.cpu(), work_load_list=None,
-                 fixed_param_names=None, state_names=None):
+                 fixed_param_names=None, state_names=None, group2ctxs=None):
         super(Module, self).__init__(logger=logger)
 
         if isinstance(context, ctx.Context):
@@ -73,6 +75,8 @@ class Module(BaseModule):
         assert len(work_load_list) == len(self._context)
         self._work_load_list = work_load_list
 
+        self._group2ctxs = group2ctxs
+
         self._symbol = symbol
 
         data_names = list(data_names) if data_names is not None else []
@@ -413,7 +417,7 @@ class Module(BaseModule):
                                                      for_training, inputs_need_grad,
                                                      shared_group, logger=self.logger,
                                                      fixed_param_names=self._fixed_param_names,
-                                                     grad_req=grad_req,
+                                                     grad_req=grad_req, group2ctxs=self._group2ctxs,
                                                      state_names=self._state_names)
         self._total_exec_bytes = self._exec_group._total_exec_bytes
         if shared_module is not None:
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 180d2ee..722ba98 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -70,6 +70,63 @@ def test_module_input_grads():
     assert np.all(c_grad == 3), c_grad
 
 
+def test_module_ctx_group():
+    with mx.AttrScope(ctx_group='dev1'):
+        a = mx.symbol.Variable('a')
+        a = a * 2
+    with mx.AttrScope(ctx_group='dev2'):
+        b = mx.symbol.Variable('b')
+        c = a + b
+    shape = (2, 5)
+    mod1 = mx.mod.Module(c, context=[mx.cpu(0)], data_names=['a', 'b'], label_names=None,
+                         group2ctxs=[{'dev1':mx.cpu(1),'dev2':mx.cpu(2)}])
+    mod1.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
+    mod1.init_params()
+    mod1.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
+    mod1.backward([mx.nd.ones(shape)])
+    mod1_input_grads = mod1.get_input_grads()
+
+    mod2 = mx.mod.Module(c, data_names=['a', 'b'], label_names=None)
+    mod2.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
+    mod2.init_params()
+    mod2.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
+    mod2.backward([mx.nd.ones(shape)])
+    mod2_input_grads = mod2.get_input_grads()
+
+    assert np.all(mod1_input_grads[0].asnumpy() == mod2_input_grads[0].asnumpy())
+    assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy())
+
+
+def test_bucket_module_ctx_group():
+    num_hidden = 10
+    batch_size = 5
+    def sym_gen(seq_len):
+        with mx.AttrScope(ctx_group='dev1'):
+            data = mx.symbol.Variable('data')
+            weight = mx.symbol.Variable('dev1_weight')
+            bias = mx.symbol.Variable('dev1_bias')
+            fc = data
+            for i in range(seq_len):
+                fc  = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
+                                               name='dev1_fc_%d' % i, num_hidden=num_hidden)
+        with mx.AttrScope(ctx_group='dev2'):
+            label = mx.symbol.Variable('label')
+            weight = mx.symbol.Variable('dev2_weight')
+            bias = mx.symbol.Variable('dev2_bias')
+            for i in range(seq_len):
+                fc  = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
+                                               name='dev2_fc_%d' % i, num_hidden=num_hidden)
+            sym = mx.symbol.SoftmaxOutput(fc, label, name='softmax')
+        
+        return sym, ('data',), ('label',)
+
+    mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10, context=[mx.cpu(0)],
+                                 group2ctxs=[{'dev1':mx.cpu(1), 'dev2':mx.cpu(2)}])
+    mod.bind(data_shapes=[['data', (batch_size, num_hidden)]],
+             label_shapes=[['label', (batch_size,)]],
+             for_training=True, inputs_need_grad=True)
+    assert(mod.binded)
+
 def test_module_layout():
     sym = mx.sym.Variable('data')
     sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC')

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].