You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/03/20 20:02:40 UTC
[incubator-mxnet] branch master updated: [MXNET-89] Bug fix for bucketing module (#10094)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new dd85860 [MXNET-89] Bug fix for bucketing module (#10094)
dd85860 is described below
commit dd85860be914a1e7aa10a9ebebc18546fd262425
Author: Anirudh Subramanian <an...@gmail.com>
AuthorDate: Tue Mar 20 13:02:34 2018 -0700
[MXNET-89] Bug fix for bucketing module (#10094)
* Bug fix for bucketing module
* Use NameManager
---
python/mxnet/module/bucketing_module.py | 15 ++++++++++-----
tests/python/unittest/test_module.py | 6 +++++-
2 files changed, 15 insertions(+), 6 deletions(-)
diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index d93ef3b..2f5cc9e 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -31,6 +31,7 @@ from ..initializer import Uniform
from .base_module import BaseModule, _check_input_names
from .module import Module
+from ..name import NameManager
class BucketingModule(BaseModule):
"""This module helps to deal efficiently with varying-length inputs.
@@ -71,7 +72,7 @@ class BucketingModule(BaseModule):
self._default_bucket_key = default_bucket_key
self._sym_gen = sym_gen
- symbol, data_names, label_names = sym_gen(default_bucket_key)
+ symbol, data_names, label_names = self._call_sym_gen(default_bucket_key)
data_names = list(data_names) if data_names is not None else []
label_names = list(label_names) if label_names is not None else []
state_names = list(state_names) if state_names is not None else []
@@ -102,13 +103,17 @@ class BucketingModule(BaseModule):
self._curr_module = None
self._curr_bucket_key = None
+ def _call_sym_gen(self, *args, **kwargs):
+ with NameManager():
+ return self._sym_gen(*args, **kwargs)
+
@property
def data_names(self):
"""A list of names for data required by this module."""
if self.binded:
return self._curr_module.data_names
else:
- _, data_names, _ = self._sym_gen(self._default_bucket_key)
+ _, data_names, _ = self._call_sym_gen(self._default_bucket_key)
return data_names
@property
@@ -117,7 +122,7 @@ class BucketingModule(BaseModule):
if self.binded:
return self._curr_module.output_names
else:
- symbol, _, _ = self._sym_gen(self._default_bucket_key)
+ symbol, _, _ = self._call_sym_gen(self._default_bucket_key)
return symbol.list_outputs()
@property
@@ -327,7 +332,7 @@ class BucketingModule(BaseModule):
self.inputs_need_grad = inputs_need_grad
self.binded = True
- symbol, data_names, label_names = self._sym_gen(self._default_bucket_key)
+ symbol, data_names, label_names = self._call_sym_gen(self._default_bucket_key)
module = Module(symbol, data_names, label_names, logger=self.logger,
context=self._context, work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
@@ -358,7 +363,7 @@ class BucketingModule(BaseModule):
"""
assert self.binded, 'call bind before switching bucket'
if not bucket_key in self._buckets:
- symbol, data_names, label_names = self._sym_gen(bucket_key)
+ symbol, data_names, label_names = self._call_sym_gen(bucket_key)
module = Module(symbol, data_names, label_names,
logger=self.logger, context=self._context,
work_load_list=self._work_load_list,
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index d6e15c2..ae95045 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -274,7 +274,7 @@ def test_module_switch_bucket():
data = mx.sym.Variable('data')
label = mx.sym.Variable('softmax_label')
embed = mx.sym.Embedding(data=data, input_dim=vocab_dim,
- output_dim=num_embedding, name='embed')
+ output_dim=num_embedding)
stack = mx.rnn.SequentialRNNCell()
for i in range(num_layer):
stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_'%i))
@@ -299,6 +299,10 @@ def test_module_switch_bucket():
return model
#initialize the bucketing module with the default bucket key
bucketing_model = create_bucketing_module(default_key)
+ #check name
+ assert bucketing_model.symbol.list_arguments()[1] == "embedding0_weight",\
+ "Error in assigning names for args in BucketingModule"
+
#switch to test_key
bucketing_model.switch_bucket(test_key, [('data', (batch_size, test_key))],
[('softmax_label', (batch_size, test_key))])
--
To stop receiving notification emails like this one, please contact
jxie@apache.org.