You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/04/06 00:52:34 UTC
[incubator-mxnet] branch master updated: [MXNET-241] Module API for distributed training w/ row_sparse weight (#10285)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new f600940 [MXNET-241] Module API for distributed training w/ row_sparse weight (#10285)
f600940 is described below
commit f60094053e5f2c233bb39b63bdd16503875a4551
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Thu Apr 5 17:52:29 2018 -0700
[MXNET-241] Module API for distributed training w/ row_sparse weight (#10285)
* initial commit
update example
update MF example
+deterministic
update LR
revert
* fix lint
* fix lint 2
* update doc
* update doc
* fix lint
* replace row-id-fn with sparse_row_id_fn. and update doc
---
CONTRIBUTORS.md | 1 +
example/sparse/linear_classification/train.py | 31 +++++++------
example/sparse/matrix_factorization/README.md | 27 +++++++++--
example/sparse/matrix_factorization/data.py | 14 ++----
example/sparse/matrix_factorization/model.py | 32 +++++++------
example/sparse/matrix_factorization/train.py | 58 +++++++++++++++--------
python/mxnet/io.py | 2 +-
python/mxnet/module/base_module.py | 66 +++++++++++++++++++++++----
python/mxnet/module/bucketing_module.py | 25 ++++++++--
python/mxnet/module/module.py | 58 +++++++++++++++++++++++
10 files changed, 240 insertions(+), 74 deletions(-)
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 829d836..a32e33e 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -29,6 +29,7 @@ The committers are the granted write access to the project.
* [Chiyuan Zhang](https://github.com/pluskid)
- Chiyuan is the creator of MXNet Julia Package.
* [Junyuan Xie](https://github.com/piiswrong)
+* [Haibin Lin](https://github.com/eric-haibin-lin)
* [Qiang Kou](https://github.com/thirdwing)
- KK is a R ninja, he makes mxnet available for R users.
* [Tong He](https://github.com/hetong007)
diff --git a/example/sparse/linear_classification/train.py b/example/sparse/linear_classification/train.py
index eb7871b..cde40dd 100644
--- a/example/sparse/linear_classification/train.py
+++ b/example/sparse/linear_classification/train.py
@@ -31,7 +31,7 @@ parser.add_argument('--batch-size', type=int, default=8192,
help='number of examples per batch')
parser.add_argument('--kvstore', type=str, default=None,
help='what kvstore to use',
- choices=["dist_async", "local"])
+ choices=["dist_sync", "dist_async", "local"])
parser.add_argument('--optimizer', type=str, default='ftrl',
help='what optimizer to use',
choices=["ftrl", "sgd", "adam"])
@@ -44,6 +44,15 @@ AVAZU = {
'num_features': 1000001,
}
+def batch_row_ids(data_batch):
+ """ Generate row ids based on the current mini-batch """
+ return {'weight': batch.data[0].indices}
+
+def all_row_ids(data_batch):
+ """ Generate row ids for all rows """
+ all_rows = mx.nd.arange(0, AVAZU['num_features'], dtype='int64')
+ return {'weight': all_rows}
+
if __name__ == '__main__':
import logging
head = '%(asctime)-15s %(message)s'
@@ -94,9 +103,6 @@ if __name__ == '__main__':
metric = mx.metric.create(['nll_loss'])
# get the sparse weight parameter
- weight_index = mod._exec_group.param_names.index('weight')
- weight_param = mod._exec_group.param_arrays[weight_index]
- all_row_ids = mx.nd.arange(0, num_features, dtype='int64')
speedometer = mx.callback.Speedometer(batch_size, 100)
logging.info('Training started ...')
@@ -106,10 +112,7 @@ if __name__ == '__main__':
for batch in train_data:
nbatch += 1
# for distributed training, we need to manually pull sparse weights from kvstore
- if kv:
- row_ids = batch.data[0].indices
- kv.row_sparse_pull('weight', weight_param, row_ids=[row_ids],
- priority=-weight_index)
+ mod.prepare(batch, sparse_row_id_fn=batch_row_ids)
mod.forward_backward(batch)
# update all parameters (including the weight parameter)
mod.update()
@@ -118,15 +121,15 @@ if __name__ == '__main__':
speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
eval_metric=metric, locals=locals())
speedometer(speedometer_param)
- # pull all rows before making a checkpoint
- if kv:
- kv.row_sparse_pull('weight', weight_param, row_ids=[all_row_ids],
- priority=-weight_index)
+
+ # prepare the module weight with all row ids for inference. Alternatively, one could call
+ # score = mod.score(val_iter, ['MSE'], sparse_row_id_fn=batch_row_ids)
+ # to fetch the weight per mini-batch
+ mod.prepare(None, all_row_ids)
# evaluate metric on validation dataset
score = mod.score(eval_data, ['nll_loss'])
logging.info('epoch %d, eval nll = %s ' % (epoch, score[0][1]))
- save_optimizer_states = 'dist' not in kv.type if kv else True
- mod.save_checkpoint("checkpoint", epoch, save_optimizer_states=save_optimizer_states)
+ mod.save_checkpoint("checkpoint", epoch)
# reset the iterator for next pass of data
train_data.reset()
eval_data.reset()
diff --git a/example/sparse/matrix_factorization/README.md b/example/sparse/matrix_factorization/README.md
index 3ada5e8..5c4beef 100644
--- a/example/sparse/matrix_factorization/README.md
+++ b/example/sparse/matrix_factorization/README.md
@@ -1,8 +1,27 @@
Matrix Factorization w/ Sparse Embedding
===========
The example demonstrates the basic usage of the SparseEmbedding operator in MXNet, adapted based on @leopd's recommender examples.
-The operator is available on both CPU and GPU. This is for demonstration purpose only.
+This is for demonstration purpose only.
-- `python train.py`
-- To compare the training speed with (dense) Embedding, run `python train.py --use-dense`
-- To run the example on the GPU, run `python train.py --use-gpu`
+```
+usage: train.py [-h] [--num-epoch NUM_EPOCH] [--seed SEED]
+ [--batch-size BATCH_SIZE] [--log-interval LOG_INTERVAL]
+ [--factor-size FACTOR_SIZE] [--gpus GPUS] [--dense]
+
+Run matrix factorization with sparse embedding
+
+optional arguments:
+ -h, --help show this help message and exit
+ --num-epoch NUM_EPOCH
+ number of epochs to train (default: 3)
+ --seed SEED random seed (default: 1)
+ --batch-size BATCH_SIZE
+ number of examples per batch (default: 128)
+ --log-interval LOG_INTERVAL
+ logging interval (default: 100)
+ --factor-size FACTOR_SIZE
+ the factor size of the embedding operation (default: 128)
+ --gpus GPUS list of gpus to run, e.g. 0 or 0,2. empty means using
+ cpu(). (default: None)
+ --dense whether to use dense embedding (default: False)
+```
diff --git a/example/sparse/matrix_factorization/data.py b/example/sparse/matrix_factorization/data.py
index c897165..049f5c2 100644
--- a/example/sparse/matrix_factorization/data.py
+++ b/example/sparse/matrix_factorization/data.py
@@ -15,9 +15,8 @@
# specific language governing permissions and limitations
# under the License.
-import os
+import os, logging
import mxnet as mx
-from mxnet.test_utils import DummyIter
def get_movielens_data(data_dir, prefix):
if not os.path.exists(os.path.join(data_dir, "ml-10M100K")):
@@ -27,11 +26,11 @@ def get_movielens_data(data_dir, prefix):
assert os.path.exists(os.path.join(data_dir, "ml-10M100K"))
os.system("cd data/ml-10M100K; chmod +x allbut.pl; sh split_ratings.sh; cd -;")
-def get_movielens_iter(filename, batch_size, dummy_iter):
+def get_movielens_iter(filename, batch_size):
"""Not particularly fast code to parse the text file and load into NDArrays.
return two data iters, one for train, the other for validation.
"""
- print("Preparing data iterators for " + filename + " ... ")
+ logging.info("Preparing data iterators for " + filename + " ... ")
user = []
item = []
score = []
@@ -45,18 +44,15 @@ def get_movielens_iter(filename, batch_size, dummy_iter):
user.append((tks[0]))
item.append((tks[1]))
score.append((tks[2]))
- if dummy_iter and num_samples > batch_size * 10:
- break
# convert to ndarrays
user = mx.nd.array(user, dtype='int32')
item = mx.nd.array(item)
score = mx.nd.array(score)
# prepare data iters
- data_train = {'user':user, 'item':item}
- label_train = {'score':score}
+ data_train = {'user': user, 'item': item}
+ label_train = {'score': score}
iter_train = mx.io.NDArrayIter(data=data_train,label=label_train,
batch_size=batch_size, shuffle=True)
- iter_train = DummyIter(iter_train) if dummy_iter else iter_train
return mx.io.PrefetchingIter(iter_train)
diff --git a/example/sparse/matrix_factorization/model.py b/example/sparse/matrix_factorization/model.py
index d2d8de5..672c392 100644
--- a/example/sparse/matrix_factorization/model.py
+++ b/example/sparse/matrix_factorization/model.py
@@ -17,34 +17,38 @@
import mxnet as mx
-def matrix_fact_net(factor_size, num_hidden, max_user, max_item, sparse_embed=True):
+def matrix_fact_net(factor_size, num_hidden, max_user, max_item, dense):
# input
user = mx.symbol.Variable('user')
item = mx.symbol.Variable('item')
score = mx.symbol.Variable('score')
- if sparse_embed:
+ stype = 'default' if dense else 'row_sparse'
+ user_weight = mx.symbol.Variable('user_weight', stype=stype)
+ item_weight = mx.symbol.Variable('item_weight', stype=stype)
+ if not dense:
+ embed = mx.symbol.contrib.SparseEmbedding
# user feature lookup
- user_weight = mx.symbol.Variable('user_weight', stype='row_sparse')
- user = mx.symbol.contrib.SparseEmbedding(data=user, weight=user_weight,
- input_dim=max_user, output_dim=factor_size)
+ user = embed(data=user, weight=user_weight,
+ input_dim=max_user, output_dim=factor_size, deterministic=True)
# item feature lookup
- item_weight = mx.symbol.Variable('item_weight', stype='row_sparse')
- item = mx.symbol.contrib.SparseEmbedding(data=item, weight=item_weight,
- input_dim=max_item, output_dim=factor_size)
+ item = embed(data=item, weight=item_weight,
+ input_dim=max_item, output_dim=factor_size, deterministic=True)
else:
# user feature lookup
- user = mx.symbol.Embedding(data=user, input_dim=max_user, output_dim=factor_size)
+ user = mx.symbol.Embedding(data=user, weight=user_weight,
+ input_dim=max_user, output_dim=factor_size)
# item feature lookup
- item = mx.symbol.Embedding(data=item, input_dim=max_item, output_dim=factor_size)
+ item = mx.symbol.Embedding(data=item, weight=item_weight,
+ input_dim=max_item, output_dim=factor_size)
# non-linear transformation of user features
user = mx.symbol.Activation(data=user, act_type='relu')
- user = mx.symbol.FullyConnected(data=user, num_hidden=num_hidden)
+ user_act = mx.symbol.FullyConnected(data=user, num_hidden=num_hidden)
# non-linear transformation of item features
item = mx.symbol.Activation(data=item, act_type='relu')
- item = mx.symbol.FullyConnected(data=item, num_hidden=num_hidden)
+ item_act = mx.symbol.FullyConnected(data=item, num_hidden=num_hidden)
# predict by the inner product, which is elementwise product and then sum
- pred = user * item
- pred = mx.symbol.sum(data=pred, axis = 1)
+ pred = user_act * item_act
+ pred = mx.symbol.sum(data=pred, axis=1)
pred = mx.symbol.Flatten(data=pred)
# loss layer
pred = mx.symbol.LinearRegressionOutput(data=pred, label=score)
diff --git a/example/sparse/matrix_factorization/train.py b/example/sparse/matrix_factorization/train.py
index 0db58ad..44bab2c 100644
--- a/example/sparse/matrix_factorization/train.py
+++ b/example/sparse/matrix_factorization/train.py
@@ -29,18 +29,17 @@ parser = argparse.ArgumentParser(description="Run matrix factorization with spar
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--num-epoch', type=int, default=3,
help='number of epochs to train')
+parser.add_argument('--seed', type=int, default=1,
+ help='random seed')
parser.add_argument('--batch-size', type=int, default=128,
help='number of examples per batch')
-parser.add_argument('--print-every', type=int, default=100,
- help='logging frequency')
+parser.add_argument('--log-interval', type=int, default=100,
+ help='logging interval')
parser.add_argument('--factor-size', type=int, default=128,
help="the factor size of the embedding operation")
-parser.add_argument('--use-dense', action='store_true',
- help="use the dense embedding operator")
-parser.add_argument('--use-gpu', action='store_true',
- help="use gpu")
-parser.add_argument('--dummy-iter', action='store_true',
- help="use the dummy data iterator for speed test")
+parser.add_argument('--gpus', type=str,
+ help="list of gpus to run, e.g. 0 or 0,2. empty means using cpu().")
+parser.add_argument('--dense', action='store_true', help="whether to use dense embedding")
MOVIELENS = {
'dataset': 'ml-10m',
@@ -50,6 +49,19 @@ MOVIELENS = {
'max_movie': 65135,
}
+def batch_row_ids(data_batch):
+ """ Generate row ids based on the current mini-batch """
+ item = data_batch.data[0]
+ user = data_batch.data[1]
+ return {'user_weight': user.astype(np.int64),
+ 'item_weight': item.astype(np.int64)}
+
+def all_row_ids(data_batch):
+ """ Generate row ids for all rows """
+ all_users = mx.nd.arange(0, MOVIELENS['max_user'], dtype='int64')
+ all_movies = mx.nd.arange(0, MOVIELENS['max_movie'], dtype='int64')
+ return {'user_weight': all_users, 'item_weight': all_movies}
+
if __name__ == '__main__':
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.INFO, format=head)
@@ -60,43 +72,44 @@ if __name__ == '__main__':
num_epoch = args.num_epoch
batch_size = args.batch_size
optimizer = 'sgd'
- use_sparse = not args.use_dense
factor_size = args.factor_size
- dummy_iter = args.dummy_iter
- print_every = args.print_every
+ log_interval = args.log_interval
momentum = 0.9
- ctx = mx.gpu(0) if args.use_gpu else mx.cpu(0)
+ ctx = [mx.gpu(int(i)) for i in args.gpus.split(',')] if args.gpus else [mx.cpu()]
learning_rate = 0.1
+ mx.random.seed(args.seed)
+ np.random.seed(args.seed)
# prepare dataset and iterators
max_user = MOVIELENS['max_user']
max_movies = MOVIELENS['max_movie']
data_dir = os.path.join(os.getcwd(), 'data')
get_movielens_data(data_dir, MOVIELENS['dataset'])
- train_iter = get_movielens_iter(MOVIELENS['train'], batch_size, dummy_iter)
- val_iter = get_movielens_iter(MOVIELENS['val'], batch_size, dummy_iter)
+ train_iter = get_movielens_iter(MOVIELENS['train'], batch_size)
+ val_iter = get_movielens_iter(MOVIELENS['val'], batch_size)
# construct the model
- net = matrix_fact_net(factor_size, factor_size, max_user, max_movies, sparse_embed=use_sparse)
+ net = matrix_fact_net(factor_size, factor_size, max_user, max_movies, dense=args.dense)
# initialize the module
- mod = mx.module.Module(symbol=net, context=ctx, data_names=['user', 'item'],
+ mod = mx.module.Module(net, context=ctx, data_names=['user', 'item'],
label_names=['score'])
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
mod.init_params(initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
- optim = mx.optimizer.create(optimizer, learning_rate=learning_rate, momentum=momentum,
- wd=1e-4, rescale_grad=1.0/batch_size)
- mod.init_optimizer(optimizer=optim)
+ optim = mx.optimizer.create(optimizer, learning_rate=learning_rate,
+ rescale_grad=1.0/batch_size)
+ mod.init_optimizer(optimizer=optim, kvstore='device')
# use MSE as the metric
metric = mx.metric.create(['MSE'])
- speedometer = mx.callback.Speedometer(batch_size, print_every)
+ speedometer = mx.callback.Speedometer(batch_size, log_interval)
logging.info('Training started ...')
for epoch in range(num_epoch):
nbatch = 0
metric.reset()
for batch in train_iter:
nbatch += 1
+ mod.prepare(batch, sparse_row_id_fn=batch_row_ids)
mod.forward_backward(batch)
# update all parameters
mod.update()
@@ -105,6 +118,11 @@ if __name__ == '__main__':
speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
eval_metric=metric, locals=locals())
speedometer(speedometer_param)
+
+ # prepare the module weight with all row ids for inference. Alternatively, one could call
+ # score = mod.score(val_iter, ['MSE'], sparse_row_id_fn=batch_row_ids)
+ # to fetch the weight per mini-batch
+ mod.prepare(None, sparse_row_id_fn=all_row_ids)
# evaluate metric on validation dataset
score = mod.score(val_iter, ['MSE'])
logging.info('epoch %d, eval MSE = %s ' % (epoch, score[0][1]))
diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 2bace6f..884e929 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -517,7 +517,7 @@ def _init_data(data, allow_empty, default_name):
raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \
"should be NDArray, numpy.ndarray or h5py.Dataset")
- return list(data.items())
+ return list(sorted(data.items()))
def _has_instance(data, dtype):
"""Return True if ``data`` has instance of ``dtype``.
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index df3fcc5..c03f8e7 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=fixme, too-many-arguments, too-many-locals, too-many-public-methods, too-many-branches
+# pylint: disable=fixme, too-many-arguments, too-many-locals
+# pylint: disable=too-many-public-methods, too-many-branches, too-many-lines
"""`BaseModule` defines an API for modules."""
import time
@@ -136,6 +137,7 @@ class BaseModule(object):
- setup
- `bind()`: prepare environment for computation.
- `init_optimizer()`: install optimizer for parameter updating.
+ - `prepare()`: prepare the module based on the current data batch.
- computation
- `forward(data_batch)`: forward operation.
@@ -193,7 +195,7 @@ class BaseModule(object):
def score(self, eval_data, eval_metric, num_batch=None, batch_end_callback=None,
score_end_callback=None,
- reset=True, epoch=0):
+ reset=True, epoch=0, sparse_row_id_fn=None):
"""Runs prediction on ``eval_data`` and evaluates the performance according to
the given ``eval_metric``.
@@ -217,6 +219,11 @@ class BaseModule(object):
epoch : int
Defaults to 0. For compatibility, this will be passed to callbacks (if any).
During training, this will correspond to the training epoch number.
+ sparse_row_id_fn : A callback function
+ The function takes `data_batch` as an input and returns a dict of
+ str -> NDArray. The resulting dict is used for pulling row_sparse
+ parameters from the kvstore, where the str key is the name of the param,
+ and the value is the row id of the param to pull.
Examples
--------
@@ -240,7 +247,7 @@ class BaseModule(object):
for nbatch, eval_batch in enumerate(eval_data):
if num_batch is not None and nbatch == num_batch:
break
-
+ self.prepare(eval_batch, sparse_row_id_fn=sparse_row_id_fn)
self.forward(eval_batch, is_train=False)
self.update_metric(eval_metric, eval_batch.label)
@@ -263,7 +270,7 @@ class BaseModule(object):
return eval_metric.get_name_value()
- def iter_predict(self, eval_data, num_batch=None, reset=True):
+ def iter_predict(self, eval_data, num_batch=None, reset=True, sparse_row_id_fn=None):
"""Iterates over predictions.
Example Usage:
@@ -282,6 +289,11 @@ class BaseModule(object):
reset : bool
Default is ``True``, indicating whether we should reset the data iter before start
doing prediction.
+ sparse_row_id_fn : A callback function
+ The function takes `data_batch` as an input and returns a dict of
+ str -> NDArray. The resulting dict is used for pulling row_sparse
+ parameters from the kvstore, where the str key is the name of the param,
+ and the value is the row id of the param to pull.
"""
assert self.binded and self.params_initialized
@@ -291,6 +303,7 @@ class BaseModule(object):
for nbatch, eval_batch in enumerate(eval_data):
if num_batch is not None and nbatch == num_batch:
break
+ self.prepare(eval_batch, sparse_row_id_fn=sparse_row_id_fn)
self.forward(eval_batch, is_train=False)
pad = eval_batch.pad
outputs = [out[0:out.shape[0]-pad] for out in self.get_outputs()]
@@ -298,7 +311,7 @@ class BaseModule(object):
yield (outputs, nbatch, eval_batch)
def predict(self, eval_data, num_batch=None, merge_batches=True, reset=True,
- always_output_list=False):
+ always_output_list=False, sparse_row_id_fn=None):
"""Runs prediction and collects the outputs.
When `merge_batches` is ``True`` (by default), the return value will be a list
@@ -327,6 +340,11 @@ class BaseModule(object):
doing prediction.
always_output_list : bool
Defaults to ``False``, see above for return values.
+ sparse_row_id_fn : A callback function
+ The function takes `data_batch` as an input and returns a dict of
+ str -> NDArray. The resulting dict is used for pulling row_sparse
+ parameters from the kvstore, where the str key is the name of the param,
+ and the value is the row id of the param to pull.
Returns
-------
@@ -349,6 +367,7 @@ class BaseModule(object):
for nbatch, eval_batch in enumerate(eval_data):
if num_batch is not None and nbatch == num_batch:
break
+ self.prepare(eval_batch, sparse_row_id_fn=sparse_row_id_fn)
self.forward(eval_batch, is_train=False)
pad = eval_batch.pad
outputs = [out[0:out.shape[0]-pad].copy() for out in self.get_outputs()]
@@ -380,7 +399,7 @@ class BaseModule(object):
eval_batch_end_callback=None, initializer=Uniform(0.01),
arg_params=None, aux_params=None, allow_missing=False,
force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
- validation_metric=None, monitor=None):
+ validation_metric=None, monitor=None, sparse_row_id_fn=None):
"""Trains the module parameters.
Checkout `Module Tutorial <http://mxnet.io/tutorials/basic/module.html>`_ to see
@@ -442,6 +461,11 @@ class BaseModule(object):
N+1.
num_epoch : int
Number of epochs for training.
+ sparse_row_id_fn : A callback function
+ The function takes `data_batch` as an input and returns a dict of
+ str -> NDArray. The resulting dict is used for pulling row_sparse
+ parameters from the kvstore, where the str key is the name of the param,
+ and the value is the row id of the param to pull.
Examples
--------
@@ -489,7 +513,7 @@ class BaseModule(object):
try:
# pre fetch next batch
next_data_batch = next(data_iter)
- self.prepare(next_data_batch)
+ self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
except StopIteration:
end_of_batch = True
@@ -740,16 +764,34 @@ class BaseModule(object):
################################################################################
# Computations
################################################################################
- def prepare(self, data_batch):
+ # pylint: disable=unused-argument
+ def prepare(self, data_batch, sparse_row_id_fn=None):
'''Prepares the module for processing a data batch.
Usually involves switching bucket and reshaping.
+ For modules that contain `row_sparse` parameters in KVStore,
+ it prepares the `row_sparse` parameters based on the sparse_row_id_fn.
+
+ When KVStore is used to update parameters for multi-device or multi-machine training,
+ a copy of the parameters are stored in KVStore. Note that for `row_sparse` parameters,
+ the `update()` updates the copy of parameters in KVStore, but doesn't broadcast
+ the updated parameters to all devices / machines. The `prepare` function is used to
+ broadcast `row_sparse` parameters with the next batch of data.
Parameters
----------
data_batch : DataBatch
+ The current batch of data for forward computation.
+
+ sparse_row_id_fn : A callback function
+ The function takes `data_batch` as an input and returns a dict of
+ str -> NDArray. The resulting dict is used for pulling row_sparse
+ parameters from the kvstore, where the str key is the name of the param,
+ and the value is the row id of the param to pull.
'''
- pass
+ if sparse_row_id_fn is not None:
+ warnings.warn(UserWarning("sparse_row_id_fn is not invoked for BaseModule."))
+ # pylint: enable=unused-argument
def forward(self, data_batch, is_train=None):
"""Forward computation. It supports data batches with different shapes, such as
@@ -877,6 +919,12 @@ class BaseModule(object):
"""Updates parameters according to the installed optimizer and the gradients computed
in the previous forward-backward batch.
+ When KVStore is used to update parameters for multi-device or multi-machine training,
+ a copy of the parameters are stored in KVStore. Note that for `row_sparse` parameters,
+ this function does update the copy of parameters in KVStore, but doesn't broadcast the
+ updated parameters to all devices / machines. Please call `prepare` to broadcast
+ `row_sparse` parameters with the next batch of data.
+
Examples
--------
>>> # An example of updating module parameters.
diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index 2f5cc9e..18cec29 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -412,13 +412,24 @@ class BucketingModule(BaseModule):
self.optimizer_initialized = True
- def prepare(self, data_batch):
- """Prepares a data batch for forward.
+ def prepare(self, data_batch, sparse_row_id_fn=None):
+ '''Prepares the module for processing a data batch.
+
+ Usually involves switching bucket and reshaping.
+ For modules that contain `row_sparse` parameters in KVStore,
+ it prepares the `row_sparse` parameters based on the sparse_row_id_fn.
Parameters
----------
data_batch : DataBatch
- """
+ The current batch of data for forward computation.
+
+ sparse_row_id_fn : A callback function
+ The function takes `data_batch` as an input and returns a dict of
+ str -> NDArray. The resulting dict is used for pulling row_sparse
+ parameters from the kvstore, where the str key is the name of the param,
+ and the value is the row id of the param to pull.
+ '''
# perform bind if haven't done so
assert self.binded and self.params_initialized
bucket_key = data_batch.bucket_key
@@ -426,6 +437,7 @@ class BucketingModule(BaseModule):
data_shapes = data_batch.provide_data
label_shapes = data_batch.provide_label
self.switch_bucket(bucket_key, data_shapes, label_shapes)
+ self._curr_module.prepare(data_batch, sparse_row_id_fn=sparse_row_id_fn)
# switch back
self.switch_bucket(original_bucket_key, None, None)
@@ -451,6 +463,13 @@ class BucketingModule(BaseModule):
def update(self):
"""Updates parameters according to installed optimizer and the gradient computed
in the previous forward-backward cycle.
+
+ When KVStore is used to update parameters for multi-device or multi-machine training,
+ a copy of the parameters are stored in KVStore. Note that for `row_sparse` parameters,
+ this function does update the copy of parameters in KVStore, but doesn't broadcast the
+ updated parameters to all devices / machines. Please call `prepare` to broadcast
+ `row_sparse` parameters with the next batch of data.
+
"""
assert self.binded and self.params_initialized and self.optimizer_initialized
self._params_dirty = True
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index 1cf7040..21d9b56 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -26,6 +26,7 @@ import warnings
from .. import context as ctx
from .. import optimizer as opt
+from .. import ndarray as nd
from .executor_group import DataParallelExecutorGroup
from ..model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore
@@ -630,6 +631,12 @@ class Module(BaseModule):
"""Updates parameters according to the installed optimizer and the gradients computed
in the previous forward-backward batch.
+ When KVStore is used to update parameters for multi-device or multi-machine training,
+ a copy of the parameters are stored in KVStore. Note that for `row_sparse` parameters,
+ this function does update the copy of parameters in KVStore, but doesn't broadcast the
+ updated parameters to all devices / machines. Please call `prepare` to broadcast
+ `row_sparse` parameters with the next batch of data.
+
See Also
----------
:meth:`BaseModule.update`.
@@ -752,8 +759,16 @@ class Module(BaseModule):
"""Synchronizes parameters from devices to CPU. This function should be called after
calling `update` that updates the parameters on the devices, before one can read the
latest parameters from ``self._arg_params`` and ``self._aux_params``.
+
+ For row_sparse parameters on devices, ther are pulled from KVStore with all row ids.
+
"""
self._exec_group.get_params(self._arg_params, self._aux_params)
+ if self._kvstore and self._update_on_kvstore:
+ for param_name, param_val in sorted(self._arg_params.items()):
+ if param_val.stype == 'row_sparse':
+ row_ids = nd.arange(0, param_val.shape[0], dtype='int64')
+ self._kvstore.row_sparse_pull(param_name, param_val, row_ids=row_ids)
self._params_dirty = False
def save_optimizer_states(self, fname):
@@ -791,3 +806,46 @@ class Module(BaseModule):
"""Installs monitor on all executors. """
assert self.binded
self._exec_group.install_monitor(mon)
+
+ def prepare(self, data_batch, sparse_row_id_fn=None):
+ '''Prepares the module for processing a data batch.
+
+ Usually involves switching bucket and reshaping.
+ For modules that contain `row_sparse` parameters in KVStore,
+ it prepares the `row_sparse` parameters based on the sparse_row_id_fn.
+
+ When KVStore is used to update parameters for multi-device or multi-machine training,
+ a copy of the parameters are stored in KVStore. Note that for `row_sparse` parameters,
+ the `update()` updates the copy of parameters in KVStore, but doesn't broadcast
+ the updated parameters to all devices / machines. The `prepare` function is used to
+ broadcast `row_sparse` parameters with the next batch of data.
+
+ Parameters
+ ----------
+ data_batch : DataBatch
+ The current batch of data for forward computation.
+
+ sparse_row_id_fn : A callback function
+ The function takes `data_batch` as an input and returns a dict of
+ str -> NDArray. The resulting dict is used for pulling row_sparse
+ parameters from the kvstore, where the str key is the name of the param,
+ and the value is the row id of the param to pull.
+ '''
+ assert self.binded
+ if sparse_row_id_fn is not None:
+ if not self._kvstore or not self._update_on_kvstore:
+ warnings.warn(UserWarning("Parameters are not updated in the KVStore. "
+ "No need to call sparse_row_id_fn."))
+ else:
+ row_ids = sparse_row_id_fn(data_batch)
+ assert(isinstance(row_ids, dict)), "Expected dict output from sparse_row_id_fn"
+ for param_name, row_id in row_ids.items():
+ param_idx = self._exec_group.param_names.index(param_name)
+ param_val = self._exec_group.param_arrays[param_idx]
+ assert(isinstance(param_val, (tuple, list)))
+ if param_val[0].stype != 'row_sparse':
+ warnings.warn(UserWarning("%s.stype is not 'row_sparse'. No need to "
+ "perform row_sparse_pull." % param_name))
+ else:
+ self._kvstore.row_sparse_pull(param_name, param_val, row_ids=row_id,
+ priority=-param_idx)
--
To stop receiving notification emails like this one, please contact
jxie@apache.org.