You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/07 18:10:44 UTC

[GitHub] szha closed pull request #13121: Set correct update on kvstore flag in dist_device_sync mode (v1.3.x)

szha closed pull request #13121: Set correct update on kvstore flag in dist_device_sync mode (v1.3.x)
URL: https://github.com/apache/incubator-mxnet/pull/13121
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 028e6607510..c4d49e82c90 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -194,14 +194,18 @@ def _init_kvstore(self):
 
             if config['update_on_kvstore'] is not None:
                 update_on_kvstore = config['update_on_kvstore']
+
         if kvstore:
             if self._compression_params:
                 kvstore.set_gradient_compression(self._compression_params)
             self._distributed = 'dist' in kvstore.type
             if self._distributed:
                 # kv.pull(row_sparse_grad) is not supported for dist kvstore
+                # Captures condition for dist_async, dist_device_sync or based on config for
+                # update_on_kvstore
                 update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad \
-                                    or 'async' in kvstore.type
+                                    or 'device' in kvstore.type or 'async' in kvstore.type \
+                                    or config['update_on_kvstore']
             if update_on_kvstore:
                 # optimizer preferably needs to be set before init for multiprecision
                 kvstore.set_optimizer(self._optimizer)
@@ -269,13 +273,20 @@ def step(self, batch_size, ignore_stale_grad=False):
             If true, ignores Parameters with stale gradient (gradient that has not
             been updated by `backward` after last step) and skip update.
         """
+        rescale_grad = self._scale / batch_size
+        if self._update_on_kvstore and self._distributed and \
+           self._optimizer.rescale_grad != rescale_grad:
+            raise UserWarning('Possible change in the `batch_size` from previous `step` detected.' \
+                            'Optimizer gradient normalizing factor will not change w.r.t new batch_size when ' \
+                            'update_on_kvstore=True and when distributed `kvstore` is used.')
+
+        self._optimizer.rescale_grad = rescale_grad
+
         if not self._kv_initialized:
             self._init_kvstore()
         if self._params_to_init:
             self._init_params()
 
-        self._optimizer.rescale_grad = self._scale / batch_size
-
         self._allreduce_grads()
         self._update(ignore_stale_grad)
 
diff --git a/tests/nightly/dist_device_sync_kvstore.py b/tests/nightly/dist_device_sync_kvstore.py
index 75b48f42c5e..7fd0333aea7 100644
--- a/tests/nightly/dist_device_sync_kvstore.py
+++ b/tests/nightly/dist_device_sync_kvstore.py
@@ -90,6 +90,25 @@ def check_init(kv, cur_keys, cur_shape, device=False):
     my_rank = kv.rank
     print('worker ' + str(my_rank) + ' is initialized')
 
+def test_gluon_trainer_type():
+    def check_trainer_kv_update(update_on_kv):
+        params = mx.gluon.ParameterDict()
+        x = params.get('x', shape=(10,1), lr_mult=1.0)
+        params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+        try:
+            trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv)
+            trainer._init_kvstore()
+            assert trainer._kv_initialized
+            assert trainer._update_on_kvstore is True
+        except ValueError:
+            assert update_on_kv is False
+
+    check_trainer_kv_update(False)
+    check_trainer_kv_update(True)
+    check_trainer_kv_update(None)
+    my_rank = kv.rank
+    print('worker ' + str(my_rank) + ' passed test_gluon_trainer_type')
+
 if __name__ == "__main__":
     test_sync_init()
     test_sync_push_pull()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services