You are viewing a plain text version of this content. (The canonical link to the original message was not preserved in this plain-text copy.)
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/03/20 20:02:40 UTC

[incubator-mxnet] branch master updated: [MXNET-89] Bug fix for bucketing module (#10094)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new dd85860  [MXNET-89] Bug fix for bucketing module (#10094)
dd85860 is described below

commit dd85860be914a1e7aa10a9ebebc18546fd262425
Author: Anirudh Subramanian <an...@gmail.com>
AuthorDate: Tue Mar 20 13:02:34 2018 -0700

    [MXNET-89] Bug fix for bucketing module (#10094)
    
    * Bug fix for bucketing module
    
    * Use NameManager
---
 python/mxnet/module/bucketing_module.py | 15 ++++++++++-----
 tests/python/unittest/test_module.py    |  6 +++++-
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index d93ef3b..2f5cc9e 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -31,6 +31,7 @@ from ..initializer import Uniform
 
 from .base_module import BaseModule, _check_input_names
 from .module import Module
+from ..name import NameManager
 
 class BucketingModule(BaseModule):
     """This module helps to deal efficiently with varying-length inputs.
@@ -71,7 +72,7 @@ class BucketingModule(BaseModule):
         self._default_bucket_key = default_bucket_key
         self._sym_gen = sym_gen
 
-        symbol, data_names, label_names = sym_gen(default_bucket_key)
+        symbol, data_names, label_names = self._call_sym_gen(default_bucket_key)
         data_names = list(data_names) if data_names is not None else []
         label_names = list(label_names) if label_names is not None else []
         state_names = list(state_names) if state_names is not None else []
@@ -102,13 +103,17 @@ class BucketingModule(BaseModule):
         self._curr_module = None
         self._curr_bucket_key = None
 
+    def _call_sym_gen(self, *args, **kwargs):
+        with NameManager():
+            return self._sym_gen(*args, **kwargs)
+
     @property
     def data_names(self):
         """A list of names for data required by this module."""
         if self.binded:
             return self._curr_module.data_names
         else:
-            _, data_names, _ = self._sym_gen(self._default_bucket_key)
+            _, data_names, _ = self._call_sym_gen(self._default_bucket_key)
             return data_names
 
     @property
@@ -117,7 +122,7 @@ class BucketingModule(BaseModule):
         if self.binded:
             return self._curr_module.output_names
         else:
-            symbol, _, _ = self._sym_gen(self._default_bucket_key)
+            symbol, _, _ = self._call_sym_gen(self._default_bucket_key)
             return symbol.list_outputs()
 
     @property
@@ -327,7 +332,7 @@ class BucketingModule(BaseModule):
         self.inputs_need_grad = inputs_need_grad
         self.binded = True
 
-        symbol, data_names, label_names = self._sym_gen(self._default_bucket_key)
+        symbol, data_names, label_names = self._call_sym_gen(self._default_bucket_key)
         module = Module(symbol, data_names, label_names, logger=self.logger,
                         context=self._context, work_load_list=self._work_load_list,
                         fixed_param_names=self._fixed_param_names,
@@ -358,7 +363,7 @@ class BucketingModule(BaseModule):
         """
         assert self.binded, 'call bind before switching bucket'
         if not bucket_key in self._buckets:
-            symbol, data_names, label_names = self._sym_gen(bucket_key)
+            symbol, data_names, label_names = self._call_sym_gen(bucket_key)
             module = Module(symbol, data_names, label_names,
                             logger=self.logger, context=self._context,
                             work_load_list=self._work_load_list,
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index d6e15c2..ae95045 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -274,7 +274,7 @@ def test_module_switch_bucket():
         data = mx.sym.Variable('data')
         label = mx.sym.Variable('softmax_label')
         embed = mx.sym.Embedding(data=data, input_dim=vocab_dim,
-                                 output_dim=num_embedding, name='embed')
+                                 output_dim=num_embedding)
         stack = mx.rnn.SequentialRNNCell()
         for i in range(num_layer):
             stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_'%i))
@@ -299,6 +299,10 @@ def test_module_switch_bucket():
         return model
     #initialize the bucketing module with the default bucket key
     bucketing_model = create_bucketing_module(default_key)
+    #check name
+    assert bucketing_model.symbol.list_arguments()[1] == "embedding0_weight",\
+        "Error in assigning names for args in BucketingModule"
+
     #switch to test_key
     bucketing_model.switch_bucket(test_key, [('data', (batch_size, test_key))],
                                   [('softmax_label', (batch_size, test_key))])

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.