Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/10 20:40:54 UTC
[incubator-mxnet] branch master updated: expose group2ctx to module (#8539)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new c9bde1b expose group2ctx to module (#8539)
c9bde1b is described below
commit c9bde1b02b89b5c4b987be1d3950d716b3a21692
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Sat Nov 11 04:40:51 2017 +0800
expose group2ctx to module (#8539)
* expose group2ctx to module
* Update test_module.py
* address comments
* update
---
python/mxnet/module/bucketing_module.py | 9 ++++--
python/mxnet/module/executor_group.py | 12 +++++--
python/mxnet/module/module.py | 8 +++--
tests/python/unittest/test_module.py | 57 +++++++++++++++++++++++++++++++++
4 files changed, 79 insertions(+), 7 deletions(-)
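
For illustration, a minimal sketch of how the new `group2ctxs` argument is used with `Module`. It mirrors the `test_module_ctx_group` unit test added below; the CPU device ids and shapes are placeholders (in practice the ctx_group entries would typically map to GPUs):

    import mxnet as mx

    # Annotate variables with ctx_group names via AttrScope.
    with mx.AttrScope(ctx_group='dev1'):
        a = mx.symbol.Variable('a') * 2
    with mx.AttrScope(ctx_group='dev2'):
        b = mx.symbol.Variable('b')
    c = a + b

    shape = (2, 5)
    # One group2ctx dict per entry in `context`, mapping each ctx_group to a device.
    mod = mx.mod.Module(c, data_names=['a', 'b'], label_names=None,
                        context=[mx.cpu(0)],
                        group2ctxs=[{'dev1': mx.cpu(1), 'dev2': mx.cpu(2)}])
    mod.bind(data_shapes=[('a', shape), ('b', shape)], inputs_need_grad=True)
    mod.init_params()
    mod.forward(mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
    mod.backward([mx.nd.ones(shape)])
    print(mod.get_input_grads()[0].asnumpy())
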
diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index f3c7ecb..dd6cafb 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -52,10 +52,12 @@ class BucketingModule(BaseModule):
state_names : list of str
States are similar to data and label, but not provided by data iterator.
Instead they are initialized to 0 and can be set by set_states()
+ group2ctxs : list of dict of str to context
+ Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
"""
def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
context=ctx.cpu(), work_load_list=None,
- fixed_param_names=None, state_names=None):
+ fixed_param_names=None, state_names=None, group2ctxs=None):
super(BucketingModule, self).__init__(logger=logger)
assert default_bucket_key is not None
@@ -77,6 +79,7 @@ class BucketingModule(BaseModule):
self._state_names = state_names
self._context = context
self._work_load_list = work_load_list
+ self._group2ctxs = group2ctxs
self._buckets = {}
self._curr_module = None
@@ -319,7 +322,7 @@ class BucketingModule(BaseModule):
module = Module(symbol, data_names, label_names, logger=self.logger,
context=self._context, work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
- state_names=self._state_names)
+ state_names=self._state_names, group2ctxs=self._group2ctxs)
module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
force_rebind=False, shared_module=None, grad_req=grad_req)
self._curr_module = module
@@ -349,7 +352,7 @@ class BucketingModule(BaseModule):
logger=self.logger, context=self._context,
work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
- state_names=self._state_names)
+ state_names=self._state_names, group2ctxs=self._group2ctxs)
module.bind(data_shapes, label_shapes, self._curr_module.for_training,
self._curr_module.inputs_need_grad,
force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
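
The same option is forwarded to every per-bucket `Module` created by `BucketingModule`. A small sketch with a toy `sym_gen` and CPU placeholders (it parallels the new `test_bucket_module_ctx_group` test further down):

    import mxnet as mx

    def sym_gen(seq_len):
        # Toy bucketed network; seq_len is ignored here. The ctx_group names
        # must match the keys in group2ctxs.
        with mx.AttrScope(ctx_group='dev1'):
            data = mx.symbol.Variable('data')
            fc = mx.symbol.FullyConnected(data=data, num_hidden=10, name='fc')
        with mx.AttrScope(ctx_group='dev2'):
            label = mx.symbol.Variable('label')
            sym = mx.symbol.SoftmaxOutput(fc, label, name='softmax')
        return sym, ('data',), ('label',)

    mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10,
                                 context=[mx.cpu(0)],
                                 group2ctxs=[{'dev1': mx.cpu(1), 'dev2': mx.cpu(2)}])
    mod.bind(data_shapes=[('data', (5, 10))], label_shapes=[('label', (5,))],
             for_training=True)
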
diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index 0f3c079..ea7651b 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -139,10 +139,12 @@ class DataParallelExecutorGroup(object):
Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
(default to 'write').
Can be specified globally (str) or for each argument (list, dict).
+ group2ctxs : list of dict of str to context
+ Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
"""
def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names,
for_training, inputs_need_grad, shared_group=None, logger=logging,
- fixed_param_names=None, grad_req='write', state_names=None):
+ fixed_param_names=None, grad_req='write', state_names=None, group2ctxs=None):
self.param_names = param_names
self.arg_names = symbol.list_arguments()
self.aux_names = symbol.list_auxiliary_states()
@@ -150,6 +152,10 @@ class DataParallelExecutorGroup(object):
self.symbol = symbol
self.contexts = contexts
self.workload = workload
+ if group2ctxs is None:
+ group2ctxs = [None] * len(self.contexts)
+ assert len(group2ctxs) == len(self.contexts)
+ self.group2ctxs = group2ctxs
self.for_training = for_training
self.inputs_need_grad = inputs_need_grad
@@ -597,9 +603,11 @@ class DataParallelExecutorGroup(object):
if label_shapes is not None:
input_types.update({x.name: x.dtype for x in label_shapes})
+ group2ctx = self.group2ctxs[i]
+
executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
type_dict=input_types, shared_arg_names=self.param_names,
- shared_exec=shared_exec,
+ shared_exec=shared_exec, group2ctx=group2ctx,
shared_buffer=shared_data_arrays, **input_shapes)
self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
return executor
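
Note on the executor-group change above: `group2ctxs` is expected to be a list with one dict per entry in `contexts` (it is expanded to `[None] * len(contexts)` when omitted), and `self.group2ctxs[i]` is applied to the i-th per-device executor. So a data-parallel setup over several devices supplies one mapping per executor. A hedged sketch, with CPU device ids standing in for real devices:

    import mxnet as mx

    with mx.AttrScope(ctx_group='dev1'):
        a = mx.symbol.Variable('a') * 2
    with mx.AttrScope(ctx_group='dev2'):
        b = mx.symbol.Variable('b')
    c = a + b

    # Two data-parallel contexts, hence two group2ctx dicts, one per executor.
    contexts = [mx.cpu(0), mx.cpu(1)]
    group2ctxs = [{'dev1': mx.cpu(2), 'dev2': mx.cpu(3)},
                  {'dev1': mx.cpu(4), 'dev2': mx.cpu(5)}]

    mod = mx.mod.Module(c, data_names=['a', 'b'], label_names=None,
                        context=contexts, group2ctxs=group2ctxs)
    mod.bind(data_shapes=[('a', (4, 5)), ('b', (4, 5))])
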
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index 4c20a6f..8301330 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -59,10 +59,12 @@ class Module(BaseModule):
state_names : list of str
states are similar to data and label, but not provided by data iterator.
Instead they are initialized to 0 and can be set by `set_states()`.
+ group2ctxs : list of dict of str to context
+ Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
"""
def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
logger=logging, context=ctx.cpu(), work_load_list=None,
- fixed_param_names=None, state_names=None):
+ fixed_param_names=None, state_names=None, group2ctxs=None):
super(Module, self).__init__(logger=logger)
if isinstance(context, ctx.Context):
@@ -73,6 +75,8 @@ class Module(BaseModule):
assert len(work_load_list) == len(self._context)
self._work_load_list = work_load_list
+ self._group2ctxs = group2ctxs
+
self._symbol = symbol
data_names = list(data_names) if data_names is not None else []
@@ -413,7 +417,7 @@ class Module(BaseModule):
for_training, inputs_need_grad,
shared_group, logger=self.logger,
fixed_param_names=self._fixed_param_names,
- grad_req=grad_req,
+ grad_req=grad_req, group2ctxs=self._group2ctxs,
state_names=self._state_names)
self._total_exec_bytes = self._exec_group._total_exec_bytes
if shared_module is not None:
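
Under the hood, each per-device executor built by `DataParallelExecutorGroup` now receives its `group2ctx` dict and passes it to `simple_bind`; the effect is roughly the following standalone call (devices and shapes are illustrative):

    import mxnet as mx

    with mx.AttrScope(ctx_group='dev1'):
        a = mx.symbol.Variable('a') * 2
    with mx.AttrScope(ctx_group='dev2'):
        b = mx.symbol.Variable('b')
    c = a + b

    # simple_bind places each ctx_group on the device given in group2ctx;
    # everything not assigned to a group stays on the default ctx.
    exe = c.simple_bind(ctx=mx.cpu(0),
                        group2ctx={'dev1': mx.cpu(1), 'dev2': mx.cpu(2)},
                        a=(2, 5), b=(2, 5))
    print([arr.context for arr in exe.arg_arrays])
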
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 180d2ee..722ba98 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -70,6 +70,63 @@ def test_module_input_grads():
assert np.all(c_grad == 3), c_grad
+def test_module_ctx_group():
+ with mx.AttrScope(ctx_group='dev1'):
+ a = mx.symbol.Variable('a')
+ a = a * 2
+ with mx.AttrScope(ctx_group='dev2'):
+ b = mx.symbol.Variable('b')
+ c = a + b
+ shape = (2, 5)
+ mod1 = mx.mod.Module(c, context=[mx.cpu(0)], data_names=['a', 'b'], label_names=None,
+ group2ctxs=[{'dev1':mx.cpu(1),'dev2':mx.cpu(2)}])
+ mod1.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
+ mod1.init_params()
+ mod1.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
+ mod1.backward([mx.nd.ones(shape)])
+ mod1_input_grads = mod1.get_input_grads()
+
+ mod2 = mx.mod.Module(c, data_names=['a', 'b'], label_names=None)
+ mod2.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
+ mod2.init_params()
+ mod2.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
+ mod2.backward([mx.nd.ones(shape)])
+ mod2_input_grads = mod2.get_input_grads()
+
+ assert np.all(mod1_input_grads[0].asnumpy() == mod2_input_grads[0].asnumpy())
+ assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy())
+
+
+def test_bucket_module_ctx_group():
+ num_hidden = 10
+ batch_size = 5
+ def sym_gen(seq_len):
+ with mx.AttrScope(ctx_group='dev1'):
+ data = mx.symbol.Variable('data')
+ weight = mx.symbol.Variable('dev1_weight')
+ bias = mx.symbol.Variable('dev1_bias')
+ fc = data
+ for i in range(seq_len):
+ fc = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
+ name='dev1_fc_%d' % i, num_hidden=num_hidden)
+ with mx.AttrScope(ctx_group='dev2'):
+ label = mx.symbol.Variable('label')
+ weight = mx.symbol.Variable('dev2_weight')
+ bias = mx.symbol.Variable('dev2_bias')
+ for i in range(seq_len):
+ fc = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
+ name='dev2_fc_%d' % i, num_hidden=num_hidden)
+ sym = mx.symbol.SoftmaxOutput(fc, label, name='softmax')
+
+ return sym, ('data',), ('label',)
+
+ mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10, context=[mx.cpu(0)],
+ group2ctxs=[{'dev1':mx.cpu(1), 'dev2':mx.cpu(2)}])
+ mod.bind(data_shapes=[['data', (batch_size, num_hidden)]],
+ label_shapes=[['label', (batch_size,)]],
+ for_training=True, inputs_need_grad=True)
+ assert(mod.binded)
+
def test_module_layout():
sym = mx.sym.Variable('data')
sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC')
--
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.