Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/11/01 18:23:25 UTC

[incubator-mxnet] branch master updated: Set correct update on kvstore flag in dist_device_sync mode (#12786)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 5bd6f10  Set correct update on kvstore flag in dist_device_sync mode (#12786)
5bd6f10 is described below

commit 5bd6f109c475a2763d8bec55ff019f6a7ff97519
Author: Sandeep Krishnamurthy <sa...@gmail.com>
AuthorDate: Thu Nov 1 11:23:07 2018 -0700

    Set correct update on kvstore flag in dist_device_sync mode (#12786)
    
    * Set correct update on kvstore flag in dist_device_sync mode
    
    * Add warning message for batch-size change in dist mode
    
    * Empty commit
    
    * Fix lint issues
---
 python/mxnet/gluon/trainer.py             | 17 ++++++++++++++---
 tests/nightly/dist_device_sync_kvstore.py | 19 +++++++++++++++++++
 2 files changed, 33 insertions(+), 3 deletions(-)
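
Some context before the diff: with a `dist_device_sync` store, gradients are aggregated and weights updated on the parameter servers, so a Gluon `Trainer` has to end up with `update_on_kvstore=True`; before this commit it could silently come out `False`. A minimal sketch of the affected path (not part of the patch; it assumes the script runs under a distributed launcher, since a `dist_device_sync` store needs live servers):

    import mxnet as mx

    # A 'dist_device_sync' store only comes up inside a launched distributed job.
    kv = mx.kv.create('dist_device_sync')

    net = mx.gluon.nn.Dense(10, in_units=20)
    net.initialize(ctx=[mx.cpu(0), mx.cpu(1)])

    trainer = mx.gluon.Trainer(net.collect_params(), 'sgd',
                               {'learning_rate': 0.1}, kvstore=kv)
    trainer._init_kvstore()            # normally runs lazily on the first step()
    assert trainer._update_on_kvstore  # True after this fix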

diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 028e660..c4d49e8 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -194,14 +194,18 @@ class Trainer(object):
 
             if config['update_on_kvstore'] is not None:
                 update_on_kvstore = config['update_on_kvstore']
+
         if kvstore:
             if self._compression_params:
                 kvstore.set_gradient_compression(self._compression_params)
             self._distributed = 'dist' in kvstore.type
             if self._distributed:
                 # kv.pull(row_sparse_grad) is not supported for dist kvstore
+                # Update on kvstore is required for dist_async and dist_device_sync,
+                # or when explicitly requested via config['update_on_kvstore']
                 update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad \
-                                    or 'async' in kvstore.type
+                                    or 'device' in kvstore.type or 'async' in kvstore.type \
+                                    or config['update_on_kvstore']
             if update_on_kvstore:
                 # optimizer preferably needs to be set before init for multiprecision
                 kvstore.set_optimizer(self._optimizer)
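
The rewritten condition keys off substrings of the kvstore type string. A standalone illustration of which store types trip each check (plain Python, no cluster needed; this snippet is only for this email, not part of the patch):

    # How the substring tests above classify the common kvstore types.
    for t in ('local', 'device', 'dist_sync', 'dist_device_sync', 'dist_async'):
        distributed = 'dist' in t
        forces_update = distributed and ('device' in t or 'async' in t)
        print('%-16s distributed=%-5s forces update_on_kvstore=%s'
              % (t, distributed, forces_update))
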
@@ -269,13 +273,20 @@ class Trainer(object):
             If true, ignores Parameters with stale gradient (gradient that has not
             been updated by `backward` after last step) and skip update.
         """
+        rescale_grad = self._scale / batch_size
+        if self._update_on_kvstore and self._distributed and \
+           self._optimizer.rescale_grad != rescale_grad:
+            raise UserWarning('Possible change in the `batch_size` from previous `step` detected. ' \
+                            'Optimizer gradient normalizing factor will not change w.r.t. new batch_size when ' \
+                            'update_on_kvstore=True and when distributed `kvstore` is used.')
+
+        self._optimizer.rescale_grad = rescale_grad
+
         if not self._kv_initialized:
             self._init_kvstore()
         if self._params_to_init:
             self._init_params()
 
-        self._optimizer.rescale_grad = self._scale / batch_size
-
         self._allreduce_grads()
         self._update(ignore_stale_grad)
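
The second hunk moves the `rescale_grad` assignment ahead of the lazy kvstore setup and adds a guard: once updates run on a distributed kvstore, the optimizer's gradient-normalization factor lives on the servers and cannot track a new `batch_size`. A sketch of what a caller now sees (assuming `trainer` uses a distributed store with update_on_kvstore=True; note the code raises UserWarning, an Exception subclass, rather than emitting warnings.warn):

    # The first step pins rescale_grad to scale / 32 on the kvstore.
    trainer.step(batch_size=32)
    try:
        # A different batch size no longer matches the server-side factor.
        trainer.step(batch_size=64)
    except UserWarning as err:
        print('batch size change detected:', err)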
 
diff --git a/tests/nightly/dist_device_sync_kvstore.py b/tests/nightly/dist_device_sync_kvstore.py
index 75b48f4..7fd0333 100644
--- a/tests/nightly/dist_device_sync_kvstore.py
+++ b/tests/nightly/dist_device_sync_kvstore.py
@@ -90,6 +90,25 @@ def test_sync_init():
     my_rank = kv.rank
     print('worker ' + str(my_rank) + ' is initialized')
 
+def test_gluon_trainer_type():
+    def check_trainer_kv_update(update_on_kv):
+        params = mx.gluon.ParameterDict()
+        x = params.get('x', shape=(10,1), lr_mult=1.0)
+        params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+        try:
+            trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv)
+            trainer._init_kvstore()
+            assert trainer._kv_initialized
+            assert trainer._update_on_kvstore is True
+        except ValueError:
+            assert update_on_kv is False
+
+    check_trainer_kv_update(False)
+    check_trainer_kv_update(True)
+    check_trainer_kv_update(None)
+    my_rank = kv.rank
+    print('worker ' + str(my_rank) + ' passed test_gluon_trainer_type')
+
 if __name__ == "__main__":
     test_sync_init()
     test_sync_push_pull()
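
A note for running this: the module creates its kvstore at import time, so the new test only works under a distributed launcher; invoking the file directly will block waiting for servers. A hedged local invocation, wrapped in Python to stay self-contained (the `tools/launch.py` flags are assumptions, verify with `--help`):

    import subprocess

    # Hypothetical local run from the repository root: 2 workers, 2 servers.
    subprocess.check_call(
        'python tools/launch.py -n 2 -s 2 --launcher local '
        'python tests/nightly/dist_device_sync_kvstore.py',
        shell=True)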