You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text view).
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/10 20:40:55 UTC

[GitHub] piiswrong closed pull request #8539: expose group2ctx to module

piiswrong closed pull request #8539: expose group2ctx to module
URL: https://github.com/apache/incubator-mxnet/pull/8539
 
 
   

This is a pull request (PR) merged from a forked repository.
Because GitHub hides the original diff on merge, the diff is displayed
below for the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is
supplied below; GitHub would not otherwise display it:

diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index f3c7ecbddc..dd6cafb277 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -52,10 +52,12 @@ class BucketingModule(BaseModule):
     state_names : list of str
         States are similar to data and label, but not provided by data iterator.
         Instead they are initialized to 0 and can be set by set_states()
+    group2ctxs : list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
     """
     def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
                  context=ctx.cpu(), work_load_list=None,
-                 fixed_param_names=None, state_names=None):
+                 fixed_param_names=None, state_names=None, group2ctxs=None):
         super(BucketingModule, self).__init__(logger=logger)
 
         assert default_bucket_key is not None
@@ -77,6 +79,7 @@ def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
         self._state_names = state_names
         self._context = context
         self._work_load_list = work_load_list
+        self._group2ctxs = group2ctxs
 
         self._buckets = {}
         self._curr_module = None
@@ -319,7 +322,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
         module = Module(symbol, data_names, label_names, logger=self.logger,
                         context=self._context, work_load_list=self._work_load_list,
                         fixed_param_names=self._fixed_param_names,
-                        state_names=self._state_names)
+                        state_names=self._state_names, group2ctxs=self._group2ctxs)
         module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
                     force_rebind=False, shared_module=None, grad_req=grad_req)
         self._curr_module = module
@@ -349,7 +352,7 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
                             logger=self.logger, context=self._context,
                             work_load_list=self._work_load_list,
                             fixed_param_names=self._fixed_param_names,
-                            state_names=self._state_names)
+                            state_names=self._state_names, group2ctxs=self._group2ctxs)
             module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                         self._curr_module.inputs_need_grad,
                         force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index 0f3c079f8f..ea7651b65d 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -139,10 +139,12 @@ class DataParallelExecutorGroup(object):
         Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
         (default to 'write').
         Can be specified globally (str) or for each argument (list, dict).
+    group2ctxs : list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
     """
     def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names,
                  for_training, inputs_need_grad, shared_group=None, logger=logging,
-                 fixed_param_names=None, grad_req='write', state_names=None):
+                 fixed_param_names=None, grad_req='write', state_names=None, group2ctxs=None):
         self.param_names = param_names
         self.arg_names = symbol.list_arguments()
         self.aux_names = symbol.list_auxiliary_states()
@@ -150,6 +152,10 @@ def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_
         self.symbol = symbol
         self.contexts = contexts
         self.workload = workload
+        if group2ctxs is None:
+            group2ctxs = [None] * len(self.contexts)
+        assert len(group2ctxs) == len(self.contexts)
+        self.group2ctxs = group2ctxs
 
         self.for_training = for_training
         self.inputs_need_grad = inputs_need_grad
@@ -597,9 +603,11 @@ def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
         if label_shapes is not None:
             input_types.update({x.name: x.dtype for x in label_shapes})
 
+        group2ctx = self.group2ctxs[i]
+
         executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
                                            type_dict=input_types, shared_arg_names=self.param_names,
-                                           shared_exec=shared_exec,
+                                           shared_exec=shared_exec, group2ctx=group2ctx,
                                            shared_buffer=shared_data_arrays, **input_shapes)
         self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
         return executor
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index 4c20a6fed5..8301330313 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -59,10 +59,12 @@ class Module(BaseModule):
     state_names : list of str
         states are similar to data and label, but not provided by data iterator.
         Instead they are initialized to 0 and can be set by `set_states()`.
+    group2ctxs : list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
     """
     def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
                  logger=logging, context=ctx.cpu(), work_load_list=None,
-                 fixed_param_names=None, state_names=None):
+                 fixed_param_names=None, state_names=None, group2ctxs=None):
         super(Module, self).__init__(logger=logger)
 
         if isinstance(context, ctx.Context):
@@ -73,6 +75,8 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
         assert len(work_load_list) == len(self._context)
         self._work_load_list = work_load_list
 
+        self._group2ctxs = group2ctxs
+
         self._symbol = symbol
 
         data_names = list(data_names) if data_names is not None else []
@@ -413,7 +417,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
                                                      for_training, inputs_need_grad,
                                                      shared_group, logger=self.logger,
                                                      fixed_param_names=self._fixed_param_names,
-                                                     grad_req=grad_req,
+                                                     grad_req=grad_req, group2ctxs=self._group2ctxs,
                                                      state_names=self._state_names)
         self._total_exec_bytes = self._exec_group._total_exec_bytes
         if shared_module is not None:
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 180d2ee052..722ba9885c 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -70,6 +70,63 @@ def test_module_input_grads():
     assert np.all(c_grad == 3), c_grad
 
 
+def test_module_ctx_group():
+    with mx.AttrScope(ctx_group='dev1'):
+        a = mx.symbol.Variable('a')
+        a = a * 2
+    with mx.AttrScope(ctx_group='dev2'):
+        b = mx.symbol.Variable('b')
+        c = a + b
+    shape = (2, 5)
+    mod1 = mx.mod.Module(c, context=[mx.cpu(0)], data_names=['a', 'b'], label_names=None,
+                         group2ctxs=[{'dev1':mx.cpu(1),'dev2':mx.cpu(2)}])
+    mod1.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
+    mod1.init_params()
+    mod1.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
+    mod1.backward([mx.nd.ones(shape)])
+    mod1_input_grads = mod1.get_input_grads()
+
+    mod2 = mx.mod.Module(c, data_names=['a', 'b'], label_names=None)
+    mod2.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
+    mod2.init_params()
+    mod2.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
+    mod2.backward([mx.nd.ones(shape)])
+    mod2_input_grads = mod2.get_input_grads()
+
+    assert np.all(mod1_input_grads[0].asnumpy() == mod2_input_grads[0].asnumpy())
+    assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy())
+
+
+def test_bucket_module_ctx_group():
+    num_hidden = 10
+    batch_size = 5
+    def sym_gen(seq_len):
+        with mx.AttrScope(ctx_group='dev1'):
+            data = mx.symbol.Variable('data')
+            weight = mx.symbol.Variable('dev1_weight')
+            bias = mx.symbol.Variable('dev1_bias')
+            fc = data
+            for i in range(seq_len):
+                fc  = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
+                                               name='dev1_fc_%d' % i, num_hidden=num_hidden)
+        with mx.AttrScope(ctx_group='dev2'):
+            label = mx.symbol.Variable('label')
+            weight = mx.symbol.Variable('dev2_weight')
+            bias = mx.symbol.Variable('dev2_bias')
+            for i in range(seq_len):
+                fc  = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
+                                               name='dev2_fc_%d' % i, num_hidden=num_hidden)
+            sym = mx.symbol.SoftmaxOutput(fc, label, name='softmax')
+        
+        return sym, ('data',), ('label',)
+
+    mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10, context=[mx.cpu(0)],
+                                 group2ctxs=[{'dev1':mx.cpu(1), 'dev2':mx.cpu(2)}])
+    mod.bind(data_shapes=[['data', (batch_size, num_hidden)]],
+             label_shapes=[['label', (batch_size,)]],
+             for_training=True, inputs_need_grad=True)
+    assert(mod.binded)
+
 def test_module_layout():
     sym = mx.sym.Variable('data')
     sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC')


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services