You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/04/06 00:52:34 UTC

[incubator-mxnet] branch master updated: [MXNET-241] Module API for distributed training w/ row_sparse weight (#10285)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f600940  [MXNET-241] Module API for distributed training w/ row_sparse weight (#10285)
f600940 is described below

commit f60094053e5f2c233bb39b63bdd16503875a4551
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Thu Apr 5 17:52:29 2018 -0700

    [MXNET-241] Module API for distributed training w/ row_sparse weight (#10285)
    
    * initial commit
    
    update example
    
    update MF example
    
    +deterministic
    
    update LR
    
    revert
    
    * fix lint
    
    * fix lint 2
    
    * update doc
    
    * update doc
    
    * fix lint
    
    * replace row-id-fn with sparse_row_id_fn and update doc
---
 CONTRIBUTORS.md                               |  1 +
 example/sparse/linear_classification/train.py | 31 +++++++------
 example/sparse/matrix_factorization/README.md | 27 +++++++++--
 example/sparse/matrix_factorization/data.py   | 14 ++----
 example/sparse/matrix_factorization/model.py  | 32 +++++++------
 example/sparse/matrix_factorization/train.py  | 58 +++++++++++++++--------
 python/mxnet/io.py                            |  2 +-
 python/mxnet/module/base_module.py            | 66 +++++++++++++++++++++++----
 python/mxnet/module/bucketing_module.py       | 25 ++++++++--
 python/mxnet/module/module.py                 | 58 +++++++++++++++++++++++
 10 files changed, 240 insertions(+), 74 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 829d836..a32e33e 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -29,6 +29,7 @@ The committers are the granted write access to the project.
 * [Chiyuan Zhang](https://github.com/pluskid)
   - Chiyuan is the creator of MXNet Julia Package.
 * [Junyuan Xie](https://github.com/piiswrong)
+* [Haibin Lin](https://github.com/eric-haibin-lin)
 * [Qiang Kou](https://github.com/thirdwing)
   - KK is a R ninja, he makes mxnet available for R users.
 * [Tong He](https://github.com/hetong007)
diff --git a/example/sparse/linear_classification/train.py b/example/sparse/linear_classification/train.py
index eb7871b..cde40dd 100644
--- a/example/sparse/linear_classification/train.py
+++ b/example/sparse/linear_classification/train.py
@@ -31,7 +31,7 @@ parser.add_argument('--batch-size', type=int, default=8192,
                     help='number of examples per batch')
 parser.add_argument('--kvstore', type=str, default=None,
                     help='what kvstore to use',
-                    choices=["dist_async", "local"])
+                    choices=["dist_sync", "dist_async", "local"])
 parser.add_argument('--optimizer', type=str, default='ftrl',
                     help='what optimizer to use',
                     choices=["ftrl", "sgd", "adam"])
@@ -44,6 +44,15 @@ AVAZU = {
     'num_features': 1000001,
 }
 
+def batch_row_ids(data_batch):
+    """ Row ids to pull: the sparse feature indices present in `data_batch` """
+    return {'weight': data_batch.data[0].indices}
+
+def all_row_ids(data_batch):
+    """ Request every row of the sparse weight; data_batch is ignored """
+    num_rows = AVAZU['num_features']
+    return {'weight': mx.nd.arange(0, num_rows, dtype='int64')}
+
 if __name__ == '__main__':
     import logging
     head = '%(asctime)-15s %(message)s'
@@ -94,9 +103,6 @@ if __name__ == '__main__':
     metric = mx.metric.create(['nll_loss'])
 
     # get the sparse weight parameter
-    weight_index = mod._exec_group.param_names.index('weight')
-    weight_param = mod._exec_group.param_arrays[weight_index]
-    all_row_ids = mx.nd.arange(0, num_features, dtype='int64')
     speedometer = mx.callback.Speedometer(batch_size, 100)
 
     logging.info('Training started ...')
@@ -106,10 +112,7 @@ if __name__ == '__main__':
         for batch in train_data:
             nbatch += 1
             # for distributed training, we need to manually pull sparse weights from kvstore
-            if kv:
-                row_ids = batch.data[0].indices
-                kv.row_sparse_pull('weight', weight_param, row_ids=[row_ids],
-                                   priority=-weight_index)
+            mod.prepare(batch, sparse_row_id_fn=batch_row_ids)
             mod.forward_backward(batch)
             # update all parameters (including the weight parameter)
             mod.update()
@@ -118,15 +121,15 @@ if __name__ == '__main__':
             speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                        eval_metric=metric, locals=locals())
             speedometer(speedometer_param)
-        # pull all rows before making a checkpoint
-        if kv:
-            kv.row_sparse_pull('weight', weight_param, row_ids=[all_row_ids],
-                               priority=-weight_index)
+
+        # prepare the module weight with all row ids for inference. Alternatively, one could call
+        # score = mod.score(eval_data, ['nll_loss'], sparse_row_id_fn=batch_row_ids)
+        # to fetch the weight per mini-batch
+        mod.prepare(None, all_row_ids)
         # evaluate metric on validation dataset
         score = mod.score(eval_data, ['nll_loss'])
         logging.info('epoch %d, eval nll = %s ' % (epoch, score[0][1]))
-        save_optimizer_states = 'dist' not in kv.type if kv else True
-        mod.save_checkpoint("checkpoint", epoch, save_optimizer_states=save_optimizer_states)
+        mod.save_checkpoint("checkpoint", epoch)
         # reset the iterator for next pass of data
         train_data.reset()
         eval_data.reset()
diff --git a/example/sparse/matrix_factorization/README.md b/example/sparse/matrix_factorization/README.md
index 3ada5e8..5c4beef 100644
--- a/example/sparse/matrix_factorization/README.md
+++ b/example/sparse/matrix_factorization/README.md
@@ -1,8 +1,27 @@
 Matrix Factorization w/ Sparse Embedding
 ===========
 The example demonstrates the basic usage of the SparseEmbedding operator in MXNet, adapted based on @leopd's recommender examples.
-The operator is available on both CPU and GPU. This is for demonstration purpose only.
+This is for demonstration purposes only.
 
-- `python train.py`
-- To compare the training speed with (dense) Embedding, run `python train.py --use-dense`
-- To run the example on the GPU, run `python train.py --use-gpu`
+```
+usage: train.py [-h] [--num-epoch NUM_EPOCH] [--seed SEED]
+                [--batch-size BATCH_SIZE] [--log-interval LOG_INTERVAL]
+                [--factor-size FACTOR_SIZE] [--gpus GPUS] [--dense]
+
+Run matrix factorization with sparse embedding
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --num-epoch NUM_EPOCH
+                        number of epochs to train (default: 3)
+  --seed SEED           random seed (default: 1)
+  --batch-size BATCH_SIZE
+                        number of examples per batch (default: 128)
+  --log-interval LOG_INTERVAL
+                        logging interval (default: 100)
+  --factor-size FACTOR_SIZE
+                        the factor size of the embedding operation (default: 128)
+  --gpus GPUS           list of gpus to run, e.g. 0 or 0,2. empty means using
+                        cpu(). (default: None)
+  --dense               whether to use dense embedding (default: False)
+```
diff --git a/example/sparse/matrix_factorization/data.py b/example/sparse/matrix_factorization/data.py
index c897165..049f5c2 100644
--- a/example/sparse/matrix_factorization/data.py
+++ b/example/sparse/matrix_factorization/data.py
@@ -15,9 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import os
+import os, logging
 import mxnet as mx
-from mxnet.test_utils import DummyIter
 
 def get_movielens_data(data_dir, prefix):
     if not os.path.exists(os.path.join(data_dir, "ml-10M100K")):
@@ -27,11 +26,11 @@ def get_movielens_data(data_dir, prefix):
         assert os.path.exists(os.path.join(data_dir, "ml-10M100K"))
         os.system("cd data/ml-10M100K; chmod +x allbut.pl; sh split_ratings.sh; cd -;")
 
-def get_movielens_iter(filename, batch_size, dummy_iter):
+def get_movielens_iter(filename, batch_size):
     """Not particularly fast code to parse the text file and load into NDArrays.
     return two data iters, one for train, the other for validation.
     """
-    print("Preparing data iterators for " + filename + " ... ")
+    logging.info("Preparing data iterators for " + filename + " ... ")
     user = []
     item = []
     score = []
@@ -45,18 +44,15 @@ def get_movielens_iter(filename, batch_size, dummy_iter):
             user.append((tks[0]))
             item.append((tks[1]))
             score.append((tks[2]))
-            if dummy_iter and num_samples > batch_size * 10:
-                break
     # convert to ndarrays
     user = mx.nd.array(user, dtype='int32')
     item = mx.nd.array(item)
     score = mx.nd.array(score)
     # prepare data iters
-    data_train = {'user':user, 'item':item}
-    label_train = {'score':score}
+    data_train = {'user': user, 'item': item}
+    label_train = {'score': score}
     iter_train = mx.io.NDArrayIter(data=data_train,label=label_train,
                                    batch_size=batch_size, shuffle=True)
-    iter_train = DummyIter(iter_train) if dummy_iter else iter_train
     return mx.io.PrefetchingIter(iter_train)
 
 
diff --git a/example/sparse/matrix_factorization/model.py b/example/sparse/matrix_factorization/model.py
index d2d8de5..672c392 100644
--- a/example/sparse/matrix_factorization/model.py
+++ b/example/sparse/matrix_factorization/model.py
@@ -17,34 +17,38 @@
 
 import mxnet as mx
 
-def matrix_fact_net(factor_size, num_hidden, max_user, max_item, sparse_embed=True):
+def matrix_fact_net(factor_size, num_hidden, max_user, max_item, dense):
     # input
     user = mx.symbol.Variable('user')
     item = mx.symbol.Variable('item')
     score = mx.symbol.Variable('score')
-    if sparse_embed:
+    stype = 'default' if dense else 'row_sparse'
+    user_weight = mx.symbol.Variable('user_weight', stype=stype)
+    item_weight = mx.symbol.Variable('item_weight', stype=stype)
+    if not dense:
+        embed = mx.symbol.contrib.SparseEmbedding
         # user feature lookup
-        user_weight = mx.symbol.Variable('user_weight', stype='row_sparse')
-        user = mx.symbol.contrib.SparseEmbedding(data=user, weight=user_weight,
-                                                 input_dim=max_user, output_dim=factor_size)
+        user = embed(data=user, weight=user_weight,
+                     input_dim=max_user, output_dim=factor_size, deterministic=True)
         # item feature lookup
-        item_weight = mx.symbol.Variable('item_weight', stype='row_sparse')
-        item = mx.symbol.contrib.SparseEmbedding(data=item, weight=item_weight,
-                                                 input_dim=max_item, output_dim=factor_size)
+        item = embed(data=item, weight=item_weight,
+                     input_dim=max_item, output_dim=factor_size, deterministic=True)
     else:
         # user feature lookup
-        user = mx.symbol.Embedding(data=user, input_dim=max_user, output_dim=factor_size)
+        user = mx.symbol.Embedding(data=user, weight=user_weight,
+                                   input_dim=max_user, output_dim=factor_size)
         # item feature lookup
-        item = mx.symbol.Embedding(data=item, input_dim=max_item, output_dim=factor_size)
+        item = mx.symbol.Embedding(data=item, weight=item_weight,
+                                   input_dim=max_item, output_dim=factor_size)
     # non-linear transformation of user features
     user = mx.symbol.Activation(data=user, act_type='relu')
-    user = mx.symbol.FullyConnected(data=user, num_hidden=num_hidden)
+    user_act = mx.symbol.FullyConnected(data=user, num_hidden=num_hidden)
     # non-linear transformation of item features
     item = mx.symbol.Activation(data=item, act_type='relu')
-    item = mx.symbol.FullyConnected(data=item, num_hidden=num_hidden)
+    item_act = mx.symbol.FullyConnected(data=item, num_hidden=num_hidden)
     # predict by the inner product, which is elementwise product and then sum
-    pred = user * item
-    pred = mx.symbol.sum(data=pred, axis = 1)
+    pred = user_act * item_act
+    pred = mx.symbol.sum(data=pred, axis=1)
     pred = mx.symbol.Flatten(data=pred)
     # loss layer
     pred = mx.symbol.LinearRegressionOutput(data=pred, label=score)
diff --git a/example/sparse/matrix_factorization/train.py b/example/sparse/matrix_factorization/train.py
index 0db58ad..44bab2c 100644
--- a/example/sparse/matrix_factorization/train.py
+++ b/example/sparse/matrix_factorization/train.py
@@ -29,18 +29,17 @@ parser = argparse.ArgumentParser(description="Run matrix factorization with spar
                                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 parser.add_argument('--num-epoch', type=int, default=3,
                     help='number of epochs to train')
+parser.add_argument('--seed', type=int, default=1,
+                    help='random seed')
 parser.add_argument('--batch-size', type=int, default=128,
                     help='number of examples per batch')
-parser.add_argument('--print-every', type=int, default=100,
-                    help='logging frequency')
+parser.add_argument('--log-interval', type=int, default=100,
+                    help='logging interval')
 parser.add_argument('--factor-size', type=int, default=128,
                     help="the factor size of the embedding operation")
-parser.add_argument('--use-dense', action='store_true',
-                    help="use the dense embedding operator")
-parser.add_argument('--use-gpu', action='store_true',
-                    help="use gpu")
-parser.add_argument('--dummy-iter', action='store_true',
-                    help="use the dummy data iterator for speed test")
+parser.add_argument('--gpus', type=str,
+                    help="list of gpus to run, e.g. 0 or 0,2. empty means using cpu().")
+parser.add_argument('--dense', action='store_true', help="whether to use dense embedding")
 
 MOVIELENS = {
     'dataset': 'ml-10m',
@@ -50,6 +49,19 @@ MOVIELENS = {
     'max_movie': 65135,
 }
 
+def batch_row_ids(data_batch):
+    """ Row ids to pull for this mini-batch: the user and item ids it contains """
+    # assumes inputs are ordered by name ('item', 'user'), cf. sorted() in io._init_data
+    item_ids, user_ids = data_batch.data[0], data_batch.data[1]
+    return {'user_weight': user_ids.astype(np.int64),
+            'item_weight': item_ids.astype(np.int64)}
+
+def all_row_ids(data_batch):
+    """ Request every user and movie row; data_batch is ignored """
+    num_users, num_movies = MOVIELENS['max_user'], MOVIELENS['max_movie']
+    return {'user_weight': mx.nd.arange(0, num_users, dtype='int64'),
+            'item_weight': mx.nd.arange(0, num_movies, dtype='int64')}
+
 if __name__ == '__main__':
     head = '%(asctime)-15s %(message)s'
     logging.basicConfig(level=logging.INFO, format=head)
@@ -60,43 +72,44 @@ if __name__ == '__main__':
     num_epoch = args.num_epoch
     batch_size = args.batch_size
     optimizer = 'sgd'
-    use_sparse = not args.use_dense
     factor_size = args.factor_size
-    dummy_iter = args.dummy_iter
-    print_every = args.print_every
+    log_interval = args.log_interval
 
     momentum = 0.9
-    ctx = mx.gpu(0) if args.use_gpu else mx.cpu(0)
+    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',')] if args.gpus else [mx.cpu()]
     learning_rate = 0.1
+    mx.random.seed(args.seed)
+    np.random.seed(args.seed)
 
     # prepare dataset and iterators
     max_user = MOVIELENS['max_user']
     max_movies = MOVIELENS['max_movie']
     data_dir = os.path.join(os.getcwd(), 'data')
     get_movielens_data(data_dir, MOVIELENS['dataset'])
-    train_iter = get_movielens_iter(MOVIELENS['train'], batch_size, dummy_iter)
-    val_iter = get_movielens_iter(MOVIELENS['val'], batch_size, dummy_iter)
+    train_iter = get_movielens_iter(MOVIELENS['train'], batch_size)
+    val_iter = get_movielens_iter(MOVIELENS['val'], batch_size)
 
     # construct the model
-    net = matrix_fact_net(factor_size, factor_size, max_user, max_movies, sparse_embed=use_sparse)
+    net = matrix_fact_net(factor_size, factor_size, max_user, max_movies, dense=args.dense)
 
     # initialize the module
-    mod = mx.module.Module(symbol=net, context=ctx, data_names=['user', 'item'],
+    mod = mx.module.Module(net, context=ctx, data_names=['user', 'item'],
                            label_names=['score'])
     mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
     mod.init_params(initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
-    optim = mx.optimizer.create(optimizer, learning_rate=learning_rate, momentum=momentum,
-                                wd=1e-4, rescale_grad=1.0/batch_size)
-    mod.init_optimizer(optimizer=optim)
+    optim = mx.optimizer.create(optimizer, learning_rate=learning_rate,
+                                rescale_grad=1.0/batch_size)
+    mod.init_optimizer(optimizer=optim, kvstore='device')
     # use MSE as the metric
     metric = mx.metric.create(['MSE'])
-    speedometer = mx.callback.Speedometer(batch_size, print_every)
+    speedometer = mx.callback.Speedometer(batch_size, log_interval)
     logging.info('Training started ...')
     for epoch in range(num_epoch):
         nbatch = 0
         metric.reset()
         for batch in train_iter:
             nbatch += 1
+            mod.prepare(batch, sparse_row_id_fn=batch_row_ids)
             mod.forward_backward(batch)
             # update all parameters
             mod.update()
@@ -105,6 +118,11 @@ if __name__ == '__main__':
             speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                        eval_metric=metric, locals=locals())
             speedometer(speedometer_param)
+
+        # prepare the module weight with all row ids for inference. Alternatively, one could call
+        # score = mod.score(val_iter, ['MSE'], sparse_row_id_fn=batch_row_ids)
+        # to fetch the weight per mini-batch
+        mod.prepare(None, sparse_row_id_fn=all_row_ids)
         # evaluate metric on validation dataset
         score = mod.score(val_iter, ['MSE'])
         logging.info('epoch %d, eval MSE = %s ' % (epoch, score[0][1]))
diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 2bace6f..884e929 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -517,7 +517,7 @@ def _init_data(data, allow_empty, default_name):
                 raise TypeError(("Invalid type '%s' for %s, "  % (type(v), k)) + \
                                 "should be NDArray, numpy.ndarray or h5py.Dataset")
 
-    return list(data.items())
+    return list(sorted(data.items()))
 
 def _has_instance(data, dtype):
     """Return True if ``data`` has instance of ``dtype``.
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index df3fcc5..c03f8e7 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -15,7 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# pylint: disable=fixme, too-many-arguments, too-many-locals, too-many-public-methods, too-many-branches
+# pylint: disable=fixme, too-many-arguments, too-many-locals
+# pylint: disable=too-many-public-methods, too-many-branches, too-many-lines
 """`BaseModule` defines an API for modules."""
 
 import time
@@ -136,6 +137,7 @@ class BaseModule(object):
     - setup
         - `bind()`: prepare environment for computation.
         - `init_optimizer()`: install optimizer for parameter updating.
+        - `prepare()`: prepare the module based on the current data batch.
 
     - computation
         - `forward(data_batch)`: forward operation.
@@ -193,7 +195,7 @@ class BaseModule(object):
 
     def score(self, eval_data, eval_metric, num_batch=None, batch_end_callback=None,
               score_end_callback=None,
-              reset=True, epoch=0):
+              reset=True, epoch=0, sparse_row_id_fn=None):
         """Runs prediction on ``eval_data`` and evaluates the performance according to
         the given ``eval_metric``.
 
@@ -217,6 +219,11 @@ class BaseModule(object):
         epoch : int
             Defaults to 0. For compatibility, this will be passed to callbacks (if any).
             During training, this will correspond to the training epoch number.
+        sparse_row_id_fn : A callback function
+            The function  takes `data_batch` as an input and returns a dict of
+            str -> NDArray. The resulting dict is used for pulling row_sparse
+            parameters from the kvstore, where the str key is the name of the param,
+            and the value is the row id of the param to pull.
 
         Examples
         --------
@@ -240,7 +247,7 @@ class BaseModule(object):
         for nbatch, eval_batch in enumerate(eval_data):
             if num_batch is not None and nbatch == num_batch:
                 break
-
+            self.prepare(eval_batch, sparse_row_id_fn=sparse_row_id_fn)
             self.forward(eval_batch, is_train=False)
             self.update_metric(eval_metric, eval_batch.label)
 
@@ -263,7 +270,7 @@ class BaseModule(object):
 
         return eval_metric.get_name_value()
 
-    def iter_predict(self, eval_data, num_batch=None, reset=True):
+    def iter_predict(self, eval_data, num_batch=None, reset=True, sparse_row_id_fn=None):
         """Iterates over predictions.
 
         Example Usage:
@@ -282,6 +289,11 @@ class BaseModule(object):
         reset : bool
             Default is ``True``, indicating whether we should reset the data iter before start
             doing prediction.
+        sparse_row_id_fn : A callback function
+            The function  takes `data_batch` as an input and returns a dict of
+            str -> NDArray. The resulting dict is used for pulling row_sparse
+            parameters from the kvstore, where the str key is the name of the param,
+            and the value is the row id of the param to pull.
         """
         assert self.binded and self.params_initialized
 
@@ -291,6 +303,7 @@ class BaseModule(object):
         for nbatch, eval_batch in enumerate(eval_data):
             if num_batch is not None and nbatch == num_batch:
                 break
+            self.prepare(eval_batch, sparse_row_id_fn=sparse_row_id_fn)
             self.forward(eval_batch, is_train=False)
             pad = eval_batch.pad
             outputs = [out[0:out.shape[0]-pad] for out in self.get_outputs()]
@@ -298,7 +311,7 @@ class BaseModule(object):
             yield (outputs, nbatch, eval_batch)
 
     def predict(self, eval_data, num_batch=None, merge_batches=True, reset=True,
-                always_output_list=False):
+                always_output_list=False, sparse_row_id_fn=None):
         """Runs prediction and collects the outputs.
 
         When `merge_batches` is ``True`` (by default), the return value will be a list
@@ -327,6 +340,11 @@ class BaseModule(object):
             doing prediction.
         always_output_list : bool
             Defaults to ``False``, see above for return values.
+        sparse_row_id_fn : A callback function
+            The function  takes `data_batch` as an input and returns a dict of
+            str -> NDArray. The resulting dict is used for pulling row_sparse
+            parameters from the kvstore, where the str key is the name of the param,
+            and the value is the row id of the param to pull.
 
         Returns
         -------
@@ -349,6 +367,7 @@ class BaseModule(object):
         for nbatch, eval_batch in enumerate(eval_data):
             if num_batch is not None and nbatch == num_batch:
                 break
+            self.prepare(eval_batch, sparse_row_id_fn=sparse_row_id_fn)
             self.forward(eval_batch, is_train=False)
             pad = eval_batch.pad
             outputs = [out[0:out.shape[0]-pad].copy() for out in self.get_outputs()]
@@ -380,7 +399,7 @@ class BaseModule(object):
             eval_batch_end_callback=None, initializer=Uniform(0.01),
             arg_params=None, aux_params=None, allow_missing=False,
             force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
-            validation_metric=None, monitor=None):
+            validation_metric=None, monitor=None, sparse_row_id_fn=None):
         """Trains the module parameters.
 
         Checkout `Module Tutorial <http://mxnet.io/tutorials/basic/module.html>`_ to see
@@ -442,6 +461,11 @@ class BaseModule(object):
             N+1.
         num_epoch : int
             Number of epochs for training.
+        sparse_row_id_fn : A callback function
+            The function  takes `data_batch` as an input and returns a dict of
+            str -> NDArray. The resulting dict is used for pulling row_sparse
+            parameters from the kvstore, where the str key is the name of the param,
+            and the value is the row id of the param to pull.
 
         Examples
         --------
@@ -489,7 +513,7 @@ class BaseModule(object):
                 try:
                     # pre fetch next batch
                     next_data_batch = next(data_iter)
-                    self.prepare(next_data_batch)
+                    self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
                 except StopIteration:
                     end_of_batch = True
 
@@ -740,16 +764,34 @@ class BaseModule(object):
     ################################################################################
     # Computations
     ################################################################################
-    def prepare(self, data_batch):
+    # pylint: disable=unused-argument
+    def prepare(self, data_batch, sparse_row_id_fn=None):
         '''Prepares the module for processing a data batch.
 
         Usually involves switching bucket and reshaping.
+        For modules that contain `row_sparse` parameters in KVStore,
+        it prepares the `row_sparse` parameters based on the sparse_row_id_fn.
+
+        When KVStore is used to update parameters for multi-device or multi-machine training,
+        a copy of the parameters is stored in KVStore. Note that for `row_sparse` parameters,
+        the `update()` updates the copy of parameters in KVStore, but doesn't broadcast
+        the updated parameters to all devices / machines. The `prepare` function is used to
+        broadcast `row_sparse` parameters with the next batch of data.
 
         Parameters
         ----------
         data_batch : DataBatch
+            The current batch of data for forward computation.
+
+        sparse_row_id_fn : A callback function
+            The function  takes `data_batch` as an input and returns a dict of
+            str -> NDArray. The resulting dict is used for pulling row_sparse
+            parameters from the kvstore, where the str key is the name of the param,
+            and the value is the row id of the param to pull.
         '''
-        pass
+        if sparse_row_id_fn is not None:
+            warnings.warn(UserWarning("sparse_row_id_fn is not invoked for BaseModule."))
+    # pylint: enable=unused-argument
 
     def forward(self, data_batch, is_train=None):
         """Forward computation. It supports data batches with different shapes, such as
@@ -877,6 +919,12 @@ class BaseModule(object):
         """Updates parameters according to the installed optimizer and the gradients computed
         in the previous forward-backward batch.
 
+        When KVStore is used to update parameters for multi-device or multi-machine training,
+        a copy of the parameters is stored in KVStore. Note that for `row_sparse` parameters,
+        this function does update the copy of parameters in KVStore, but doesn't broadcast the
+        updated parameters to all devices / machines. Please call `prepare` to broadcast
+        `row_sparse` parameters with the next batch of data.
+
         Examples
         --------
         >>> # An example of updating module parameters.
diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index 2f5cc9e..18cec29 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -412,13 +412,24 @@ class BucketingModule(BaseModule):
 
         self.optimizer_initialized = True
 
-    def prepare(self, data_batch):
-        """Prepares a data batch for forward.
+    def prepare(self, data_batch, sparse_row_id_fn=None):
+        '''Prepares the module for processing a data batch.
+
+        Usually involves switching bucket and reshaping.
+        For modules that contain `row_sparse` parameters in KVStore,
+        it prepares the `row_sparse` parameters based on the sparse_row_id_fn.
 
         Parameters
         ----------
         data_batch : DataBatch
-        """
+            The current batch of data for forward computation.
+
+        sparse_row_id_fn : A callback function
+            The function  takes `data_batch` as an input and returns a dict of
+            str -> NDArray. The resulting dict is used for pulling row_sparse
+            parameters from the kvstore, where the str key is the name of the param,
+            and the value is the row id of the param to pull.
+        '''
         # perform bind if haven't done so
         assert self.binded and self.params_initialized
         bucket_key = data_batch.bucket_key
@@ -426,6 +437,7 @@ class BucketingModule(BaseModule):
         data_shapes = data_batch.provide_data
         label_shapes = data_batch.provide_label
         self.switch_bucket(bucket_key, data_shapes, label_shapes)
+        self._curr_module.prepare(data_batch, sparse_row_id_fn=sparse_row_id_fn)
         # switch back
         self.switch_bucket(original_bucket_key, None, None)
 
@@ -451,6 +463,13 @@ class BucketingModule(BaseModule):
     def update(self):
         """Updates parameters according to installed optimizer and the gradient computed
         in the previous forward-backward cycle.
+
+        When KVStore is used to update parameters for multi-device or multi-machine training,
+        a copy of the parameters is stored in KVStore. Note that for `row_sparse` parameters,
+        this function does update the copy of parameters in KVStore, but doesn't broadcast the
+        updated parameters to all devices / machines. Please call `prepare` to broadcast
+        `row_sparse` parameters with the next batch of data.
+
         """
         assert self.binded and self.params_initialized and self.optimizer_initialized
         self._params_dirty = True
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index 1cf7040..21d9b56 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -26,6 +26,7 @@ import warnings
 
 from .. import context as ctx
 from .. import optimizer as opt
+from .. import ndarray as nd
 
 from .executor_group import DataParallelExecutorGroup
 from ..model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore
@@ -630,6 +631,12 @@ class Module(BaseModule):
         """Updates parameters according to the installed optimizer and the gradients computed
         in the previous forward-backward batch.
 
+        When KVStore is used to update parameters for multi-device or multi-machine training,
+        a copy of the parameters is stored in KVStore. Note that for `row_sparse` parameters,
+        this function does update the copy of parameters in KVStore, but doesn't broadcast the
+        updated parameters to all devices / machines. Please call `prepare` to broadcast
+        `row_sparse` parameters with the next batch of data.
+
         See Also
         ----------
         :meth:`BaseModule.update`.
@@ -752,8 +759,16 @@ class Module(BaseModule):
         """Synchronizes parameters from devices to CPU. This function should be called after
         calling `update` that updates the parameters on the devices, before one can read the
         latest parameters from ``self._arg_params`` and ``self._aux_params``.
+
+        With update_on_kvstore, row_sparse params are pulled from KVStore with all row ids.
+
         """
         self._exec_group.get_params(self._arg_params, self._aux_params)
+        if self._kvstore and self._update_on_kvstore:
+            for param_name, param_val in sorted(self._arg_params.items()):
+                if param_val.stype == 'row_sparse':
+                    row_ids = nd.arange(0, param_val.shape[0], dtype='int64')
+                    self._kvstore.row_sparse_pull(param_name, param_val, row_ids=row_ids)
         self._params_dirty = False
 
     def save_optimizer_states(self, fname):
@@ -791,3 +806,46 @@ class Module(BaseModule):
         """Installs monitor on all executors. """
         assert self.binded
         self._exec_group.install_monitor(mon)
+
+    def prepare(self, data_batch, sparse_row_id_fn=None):
+        '''Prepares the module for processing a data batch.
+
+        This implementation performs no bucket switching or reshaping.
+        For modules that contain `row_sparse` parameters in KVStore,
+        it prepares the `row_sparse` parameters based on the sparse_row_id_fn.
+
+        When KVStore is used to update parameters for multi-device or multi-machine training,
+        a copy of the parameters is stored in KVStore. Note that for `row_sparse` parameters,
+        the `update()` updates the copy of parameters in KVStore, but doesn't broadcast
+        the updated parameters to all devices / machines. The `prepare` function is used to
+        broadcast `row_sparse` parameters with the next batch of data.
+
+        Parameters
+        ----------
+        data_batch : DataBatch
+            The current batch of data; may be None if sparse_row_id_fn does not use it.
+
+        sparse_row_id_fn : A callback function
+            The function takes `data_batch` as an input and returns a dict of
+            str -> NDArray. The resulting dict is used for pulling row_sparse
+            parameters from the kvstore, where the str key is the name of the param,
+            and the value is the row id of the param to pull.
+        '''
+        assert self.binded
+        if sparse_row_id_fn is not None:
+            if not self._kvstore or not self._update_on_kvstore:
+                warnings.warn(UserWarning("Parameters are not updated in the KVStore. "
+                                          "No need to call sparse_row_id_fn."))
+            else:
+                row_ids = sparse_row_id_fn(data_batch)
+                assert(isinstance(row_ids, dict)), "Expected dict output from sparse_row_id_fn"
+                for param_name, row_id in row_ids.items():
+                    param_idx = self._exec_group.param_names.index(param_name)
+                    param_val = self._exec_group.param_arrays[param_idx]
+                    assert(isinstance(param_val, (tuple, list)))
+                    if param_val[0].stype != 'row_sparse':
+                        warnings.warn(UserWarning("%s.stype is not 'row_sparse'. No need to "
+                                                  "perform row_sparse_pull." % param_name))
+                    else:
+                        self._kvstore.row_sparse_pull(param_name, param_val, row_ids=row_id,
+                                                      priority=-param_idx)

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.