You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/12/29 21:10:43 UTC

[GitHub] eric-haibin-lin closed pull request #13721: Fixes for trainer with update_on_kvstore=False

eric-haibin-lin closed pull request #13721: Fixes for trainer with update_on_kvstore=False
URL: https://github.com/apache/incubator-mxnet/pull/13721
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index c4d49e82c90..f6c0a31b52e 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -28,6 +28,15 @@ class Trainer(object):
     """Applies an `Optimizer` on a set of Parameters. Trainer should
     be used together with `autograd`.
 
+    .. note::
+
+        For the following cases, updates will always happen on kvstore,
+        i.e., you cannot set update_on_kvstore=False.
+
+        - dist kvstore with sparse weights or sparse gradients
+        - dist async kvstore
+        - `optimizer.lr_scheduler` is not None
+
     Parameters
     ----------
     params : ParameterDict
@@ -115,11 +124,12 @@ def _init_optimizer(self, optimizer, optimizer_params):
                 "optimizer_params must be None if optimizer is an instance of " \
                 "Optimizer instead of str"
             self._optimizer = optimizer
+            # param_dict must not be deep copied, so that if user mutate the lr_mult
+            # or wd_mult of some parameters, it takes effect.
             self._optimizer.param_dict = param_dict
         else:
             self._optimizer = opt.create(optimizer, param_dict=param_dict,
                                          **optimizer_params)
-
         self._updaters = [opt.get_updater(self._optimizer) \
                             for _ in self._contexts]
 
@@ -158,59 +168,82 @@ def _reset_kvstore(self):
     def _init_kvstore(self):
         """Create kvstore."""
         config = self._kvstore_params
-        # if weight is sparse, the weight must be updated on KVStore.
-        # training loop contains:
-        #    - row_sparse_pull(sparse_weight)
-        #    - forward()
-        #    - backward()
-        #    - push(sparse_grad), push(dense_grad)
-        #    - pull(dense_weight)
+        # configure kvstore, update_on_kvstore and self._distributed on three cases:
         if self._contains_sparse_weight:
+            # If weight is sparse, kvstore must be present and the weight must be updated on kvstore.
+            # The training loop is the following:
+            #    - row_sparse_pull(sparse_weight)
+            #    - forward()
+            #    - backward()
+            #    - push_and_update(grad)
+            #    - pull(weight)
             kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore'])
-            # raise Error if update_on_kvstore is set to False by the user
+            self._distributed = 'dist' in kvstore.type
+            # raise err if user provides unsupported configs
             if config['update_on_kvstore'] is False:
-                raise RuntimeError("Cannot set update_on_kvstore to False when sparse weights "
-                                   "are present.")
-        # if weight is dense and grad is sparse, the weight better not be updated on KVStore.
-        # training loop contains:
-        #    - forward()
-        #    - backward()
-        #    - push(grad)
-        #    - pull(grad)
-        #    - update(grad, weight)
+                raise ValueError("Cannot set update_on_kvstore=False when sparse weights "
+                                 "are present.")
+
         elif self._contains_sparse_grad:
+            # For single node training with dense weight and sparse grad,
+            # we prefer update_on_kvstore=False because this is usually faster.
+            # This means we push and pull sparse gradients, and we do not store weight in kvstore.
+            # The training loop is the following:
+            #    - forward()
+            #    - backward()
+            #    - push(grad)
+            #    - pull(grad)
+            #    - update(grad, weight)
+            #
+            # For multi-node training with dense weight and sparse grad,
+            # only update_on_kvstore=True is supported, due to the fact that
+            # kv.row_sparse_pull(grad) is not implemented.
+            # Therefore, we push sparse gradients and pull dense weights.
+            # The training loop contains:
+            #    - forward()
+            #    - backward()
+            #    - push_and_update(grad)
+            #    - pull(weight)
             arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
             kvstore, _ = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays)
-            update_on_kvstore = False
-        # normal case
+            self._distributed = 'dist' in kvstore.type if kvstore else False
+            update_on_kvstore = self._distributed
+            # raise err if user provides unsupported configs
+            if config['update_on_kvstore'] is not None:
+                if config['update_on_kvstore'] is False and self._distributed:
+                    raise ValueError("Cannot set update_on_kvstore=False on dist kvstore "
+                                     "when sparse gradients are present.")
+                update_on_kvstore = config['update_on_kvstore']
+
         else:
+            # Training with dense weight and dense gradients.
+            # The only unsupported mode is async with update_on_kvstore=False
             arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
             kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts),
                                                          arg_arrays)
-            if kvstore and 'async' in kvstore.type and config['update_on_kvstore'] is not None\
-                    and not config['update_on_kvstore']:
-                raise ValueError("Please set update_on_kvstore to true "
-                                 "when training in async mode.")
-
+            self._distributed = 'dist' in kvstore.type if kvstore else False
+            if self._distributed and 'async' in kvstore.type:
+                update_on_kvstore = True
+                # raise err if user provides unsupported configs
+                if config['update_on_kvstore'] is False:
+                    raise ValueError("Please set update_on_kvstore=True "
+                                     "when training in async mode.")
             if config['update_on_kvstore'] is not None:
                 update_on_kvstore = config['update_on_kvstore']
 
+        # set grad compression and optimizers
         if kvstore:
             if self._compression_params:
                 kvstore.set_gradient_compression(self._compression_params)
-            self._distributed = 'dist' in kvstore.type
-            if self._distributed:
-                # kv.pull(row_sparse_grad) is not supported for dist kvstore
-                # Captures condition for dist_async, dist_device_sync or based on config for
-                # update_on_kvstore
-                update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad \
-                                    or 'device' in kvstore.type or 'async' in kvstore.type \
-                                    or config['update_on_kvstore']
             if update_on_kvstore:
                 # optimizer preferably needs to be set before init for multiprecision
                 kvstore.set_optimizer(self._optimizer)
             self._kvstore = kvstore
             self._update_on_kvstore = update_on_kvstore
+            if self._optimizer.lr_scheduler and not self._update_on_kvstore:
+                raise ValueError("update_on_kvstore=False does not support " \
+                                 "optimizer with LRScheduler. Please " \
+                                 "consider setting learning rate manually.")
         else:
             self._kvstore = None
             self._update_on_kvstore = None
@@ -255,6 +288,16 @@ def _row_sparse_pull(self, parameter, out, row_id, full_idx=False):
         else:
             self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)
 
+    def _check_and_rescale_grad(self, scale):
+        if self._update_on_kvstore and self._distributed and self._kv_initialized:
+            if self._optimizer.rescale_grad != scale:
+                raise UserWarning('Possible change in the `batch_size` from previous '
+                                  '`step` detected. Optimizer gradient normalizing '
+                                  'factor will not change w.r.t new batch_size when '
+                                  'update_on_kvstore=True and when distributed kvstore '
+                                  'is used.')
+        self._optimizer.rescale_grad = scale
+
     def step(self, batch_size, ignore_stale_grad=False):
         """Makes one step of parameter update. Should be called after
         `autograd.backward()` and outside of `record()` scope.
@@ -274,13 +317,7 @@ def step(self, batch_size, ignore_stale_grad=False):
             been updated by `backward` after last step) and skip update.
         """
         rescale_grad = self._scale / batch_size
-        if self._update_on_kvstore and self._distributed and \
-           self._optimizer.rescale_grad != rescale_grad:
-            raise UserWarning('Possible change in the `batch_size` from previous `step` detected.' \
-                            'Optimizer gradient normalizing factor will not change w.r.t new batch_size when ' \
-                            'update_on_kvstore=True and when distributed `kvstore` is used.')
-
-        self._optimizer.rescale_grad = rescale_grad
+        self._check_and_rescale_grad(rescale_grad)
 
         if not self._kv_initialized:
             self._init_kvstore()
@@ -352,7 +389,7 @@ def update(self, batch_size, ignore_stale_grad=False):
                 'is not supported. Try setting `update_on_kvstore` ' \
                 'to False when creating trainer.'
 
-        self._optimizer.rescale_grad = self._scale / batch_size
+        self._check_and_rescale_grad(self._scale / batch_size)
         self._update(ignore_stale_grad)
 
     def _update(self, ignore_stale_grad=False):
@@ -387,10 +424,16 @@ def _update(self, ignore_stale_grad=False):
     def save_states(self, fname):
         """Saves trainer states (e.g. optimizer, momentum) to a file.
 
+
         Parameters
         ----------
         fname : str
             Path to output states file.
+
+        Note
+        ----
+        `optimizer.param_dict`, which contains Parameter information (such as
+        `lr_mult` and `wd_mult`) will not be saved.
         """
         assert self._optimizer is not None
 
@@ -414,6 +457,12 @@ def load_states(self, fname):
         ----------
         fname : str
             Path to input states file.
+
+        Note
+        ----
+        `optimizer.param_dict`, which contains Parameter information (such as
+        `lr_mult` and `wd_mult`) will not be loaded from the file, but rather set
+        based on current Trainer's parameters.
         """
         if not self._kv_initialized:
             self._init_kvstore()
@@ -423,8 +472,6 @@ def load_states(self, fname):
         if self._update_on_kvstore:
             self._kvstore.load_optimizer_states(fname)
             self._optimizer = self._kvstore._updater.optimizer
-            param_dict = {i: param for i, param in enumerate(self._params)}
-            self._optimizer.param_dict = param_dict
         else:
             with open(fname, 'rb') as f:
                 states = f.read()
@@ -432,3 +479,5 @@ def load_states(self, fname):
                 updater.set_states(states)
                 updater.optimizer = self._updaters[0].optimizer
             self._optimizer = self._updaters[0].optimizer
+        param_dict = {i: param for i, param in enumerate(self._params)}
+        self._optimizer.param_dict = param_dict
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index 2666f8bbcd4..38fe739154d 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -62,6 +62,11 @@ def _create_sparse_kvstore(kvstore):
     ----------
     kvstore : KVStore or str
         The kvstore.
+
+    Returns
+    -------
+    kvstore : KVStore
+    update_on_kvstore : bool. Always True.
     """
     # always update on kvstore
     update_on_kvstore = True
diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index a085b6fe2ef..ba16132ab08 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -43,33 +43,33 @@ class Optimizer(object):
 
     Parameters
     ----------
-    rescale_grad : float, optional
+    rescale_grad : float, optional, default 1.0
         Multiply the gradient with `rescale_grad` before updating. Often
         choose to be ``1.0/batch_size``.
 
-    param_idx2name : dict from int to string, optional
+    param_idx2name : dict from int to string, optional, default None
         A dictionary that maps int index to string name.
 
-    clip_gradient : float, optional
+    clip_gradient : float, optional, default None
         Clip the gradient by projecting onto the box ``[-clip_gradient, clip_gradient]``.
 
-    learning_rate : float, optional
+    learning_rate : float, optional, default 0.01
         The initial learning rate.
 
-    lr_scheduler : LRScheduler, optional
+    lr_scheduler : LRScheduler, optional, default None
         The learning rate scheduler.
 
-    wd : float, optional
+    wd : float, optional, default 0.0
         The weight decay (or L2 regularization) coefficient. Modifies objective
         by adding a penalty for having large weights.
 
-    sym: Symbol, optional
+    sym: Symbol, optional, default None
         The Symbol this optimizer is applying to.
 
-    begin_num_update : int, optional
+    begin_num_update : int, optional, default 0
         The initial number of updates.
 
-    multi_precision : bool, optional
+    multi_precision : bool, optional, default False
        Flag to control the internal precision of the optimizer.::
 
            False: results in using the same precision as the weights (default),
@@ -77,6 +77,10 @@ class Optimizer(object):
            in 32-bit precision even if actual weights used in the model have lower precision.
            Turning this on can improve convergence and accuracy when training with float16.
 
+    param_dict : dict of int -> gluon.Parameter, default None
+        Dictionary of parameter index to gluon.Parameter, used to lookup parameter attributes
+        such as lr_mult, wd_mult, etc. param_dict shall not be deep copied.
+
     Properties
     ----------
     learning_rate : float
diff --git a/tests/nightly/dist_async_kvstore.py b/tests/nightly/dist_async_kvstore.py
index 3e400eafa04..b990b6b3f13 100644
--- a/tests/nightly/dist_async_kvstore.py
+++ b/tests/nightly/dist_async_kvstore.py
@@ -27,22 +27,26 @@
 nworker = kv.num_workers
 
 def test_gluon_trainer_type():
-    def check_trainer_kv_update(update_on_kv):
+    def check_trainer_kv_update(weight_stype, update_on_kv):
         params = mx.gluon.ParameterDict()
-        x = params.get('x', shape=(10,1), lr_mult=1.0)
+        x = params.get('x', shape=(10,1), lr_mult=1.0, stype=weight_stype)
         params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
         try:
-            trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv)
+            trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1},
+                                       kvstore=kv, update_on_kvstore=update_on_kv)
             trainer._init_kvstore()
             assert trainer._kv_initialized
             assert trainer._update_on_kvstore is True
         except ValueError:
             assert update_on_kv is False
 
-    check_trainer_kv_update(False)
-    check_trainer_kv_update(True)
-    check_trainer_kv_update(None)
+    check_trainer_kv_update('default', False)
+    check_trainer_kv_update('default', True)
+    check_trainer_kv_update('default', None)
+    check_trainer_kv_update('row_sparse', False)
+    check_trainer_kv_update('row_sparse', True)
+    check_trainer_kv_update('row_sparse', None)
     print('worker ' + str(my_rank) + ' passed test_gluon_trainer_type')
 
 if __name__ == "__main__":
-    test_gluon_trainer_type()
\ No newline at end of file
+    test_gluon_trainer_type()
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index 861b85913ac..4523a361cf8 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -376,18 +376,26 @@ def check_invalid_pull():
     check_invalid_pull()
 
 def test_gluon_trainer_type():
-    def check_trainer_kv_type(stype, grad_stype, update_on_kv):
+    def check_trainer_kv_type(stype, grad_stype, update_on_kv, expected):
         params = mx.gluon.ParameterDict()
         x = params.get('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype)
         params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
-        trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv)
-        trainer._init_kvstore()
-        assert trainer._kv_initialized
-        assert trainer._update_on_kvstore is update_on_kv
-
-    check_trainer_kv_type('default', 'default', False)
-    check_trainer_kv_type('default', 'row_sparse', True)
-    check_trainer_kv_type('row_sparse', 'row_sparse', True)
+        trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1},
+                                   kvstore=kv, update_on_kvstore=update_on_kv)
+        try:
+            trainer._init_kvstore()
+            assert trainer._kv_initialized
+            assert trainer._update_on_kvstore is expected
+        except Exception as err:
+            assert isinstance(err, expected)
+
+    check_trainer_kv_type('default', 'default', None, True)
+    check_trainer_kv_type('default', 'default', True, True)
+    check_trainer_kv_type('default', 'default', False, False)
+    check_trainer_kv_type('default', 'row_sparse', None, True)
+    check_trainer_kv_type('default', 'row_sparse', False, ValueError)
+    check_trainer_kv_type('row_sparse', 'row_sparse', None, True)
+    check_trainer_kv_type('row_sparse', 'row_sparse', False, ValueError)
     print('worker ' + str(my_rank) + ' passed test_gluon_trainer_type')
 
 def test_gluon_trainer_step():
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index b4bfe4c47f0..985c38c3135 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -55,16 +55,15 @@ def dict_equ(a, b):
             y.backward()
     trainer.step(1)
 
+    assert trainer._optimizer.param_dict == trainer._optimizer.param_dict
     assert (x.data(mx.cpu(1)).asnumpy() == -2).all()
 
     x.lr_mult = 0.5
-
     with mx.autograd.record():
         for w in x.list_data():
             y = w + 1
             y.backward()
     trainer.step(1)
-
     assert (x.data(mx.cpu(1)).asnumpy() == -4).all()
 
     trainer.save_states('test_trainer.states')
@@ -212,28 +211,74 @@ def check_trainer_reset_kv(kv):
 
 @with_seed()
 def test_trainer_sparse_kv():
-    def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv):
+    def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv, expected):
         params = gluon.ParameterDict()
         x = params.get('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype)
         params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
-        trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv)
+        trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1},
+                                kvstore=kv, update_on_kvstore=update_on_kv)
         all_rows = mx.nd.arange(0, 10, ctx=mx.cpu(0))
-        ws = x.list_data() if stype == 'default' else x.list_row_sparse_data(all_rows)
+        try:
+            ws = x.list_data() if stype == 'default' else x.list_row_sparse_data(all_rows)
+            with mx.autograd.record():
+                for w in ws:
+                    y = w + 1
+                    y.backward()
+            trainer.step(1)
+            assert trainer._kvstore.type == kv
+            assert trainer._kv_initialized
+            assert trainer._update_on_kvstore is expected
+            # the updated parameter should be based on the loaded checkpoint
+            mx.nd.waitall()
+            updated_w = x.data(mx.cpu(0)) if stype == 'default' else x.row_sparse_data(all_rows)
+            assert (updated_w == -0.2).asnumpy().all()
+        except Exception as err:
+            assert isinstance(err, expected)
+
+    kvs = ['local', 'device']
+    for kv in kvs:
+        check_trainer_sparse_kv(kv, 'default', 'default', True, True)
+        check_trainer_sparse_kv(kv, 'default', 'default', False, False)
+        check_trainer_sparse_kv(kv, 'default', 'default', None, True)
+        check_trainer_sparse_kv(kv, 'default', 'row_sparse', None, False)
+        check_trainer_sparse_kv(kv, 'default', 'row_sparse', True, True)
+        check_trainer_sparse_kv(kv, 'default', 'row_sparse', False, False)
+        check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', None, True)
+        check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', False, ValueError)
+
+@with_seed()
+def test_trainer_lr_sched():
+    x = gluon.Parameter('x', shape=(10,))
+    x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+    freq = 2
+    factor = 0.1
+    lr = 1
+    lr_sched = mx.lr_scheduler.FactorScheduler(freq, factor=factor, base_lr=lr)
+    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': lr, 'lr_scheduler': lr_sched})
+    for i in range(10):
         with mx.autograd.record():
-            for w in ws:
+            for w in x.list_data():
                 y = w + 1
                 y.backward()
         trainer.step(1)
-        assert trainer._kvstore.type == kv
-        assert trainer._kv_initialized
-        assert trainer._update_on_kvstore is update_on_kv
-        # the updated parameter should be based on the loaded checkpoint
-        mx.nd.waitall()
-        updated_w = x.data(mx.cpu(0)) if stype == 'default' else x.row_sparse_data(all_rows)
-        assert (updated_w == -0.2).asnumpy().all()
+        if i % freq == 0:
+            assert trainer.learning_rate == lr, (lr, trainer.learning_rate, i)
+            lr *= factor
+    mx.nd.waitall()
 
-    kvs = ['local', 'device']
-    for kv in kvs:
-        check_trainer_sparse_kv(kv, 'default', 'default', True)
-        check_trainer_sparse_kv(kv, 'default', 'row_sparse', False)
-        check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', True)
+@with_seed()
+def test_trainer_invalid_lr_sched():
+    x = gluon.Parameter('x', shape=(10,))
+    x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+    freq = 2
+    factor = 0.1
+    lr = 1
+    lr_sched = mx.lr_scheduler.FactorScheduler(freq, factor=factor, base_lr=lr)
+    invalid_trainer = gluon.Trainer([x], 'sgd', {'learning_rate': lr, 'lr_scheduler': lr_sched},
+                                    update_on_kvstore=False)
+    with mx.autograd.record():
+        for w in x.list_data():
+            y = w + 1
+            y.backward()
+    assert_raises(ValueError, invalid_trainer.step, 1)
+    mx.nd.waitall()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services