Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/03/17 06:48:56 UTC

[incubator-mxnet] branch master updated: Correct update count with Gluon trainer and update_on_kvstore=False (#14377)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 63ed258  Correct update count with Gluon trainer and update_on_kvstore=False (#14377)
63ed258 is described below

commit 63ed258063137421e6d4def30435014ab57fb468
Author: Przemyslaw Tredak <pt...@gmail.com>
AuthorDate: Sat Mar 16 23:48:30 2019 -0700

    Correct update count with Gluon trainer and update_on_kvstore=False (#14377)
    
    * LRScheduler with update_on_kvstore=False
    
    * Cleaning trainer.py
    
    * Retrigger CI
    
    * Fixes from review
---
 python/mxnet/gluon/trainer.py               |  4 ----
 python/mxnet/optimizer/optimizer.py         | 17 ++++++++++++++++-
 tests/python/unittest/test_gluon_trainer.py | 21 ++++++++++++---------
 3 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 8060f38..45a44d8 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -241,10 +241,6 @@ class Trainer(object):
                 kvstore.set_optimizer(self._optimizer)
             self._kvstore = kvstore
             self._update_on_kvstore = update_on_kvstore
-            if self._optimizer.lr_scheduler and not self._update_on_kvstore:
-                raise ValueError("update_on_kvstore=False does not support " \
-                                 "optimizer with LRScheduler. Please " \
-                                 "consider setting learning rate manually.")
         else:
             self._kvstore = None
             self._update_on_kvstore = None
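
The guard removed above is no longer needed: with the per-device update
counting added in optimizer.py below, an LRScheduler behaves correctly when
update_on_kvstore=False, so the ValueError is dropped. A minimal sketch of the
configuration this commit enables (it mirrors the updated test at the bottom of
the patch; shapes and hyperparameters are illustrative):

    import mxnet as mx
    from mxnet import gluon

    x = gluon.Parameter('x', shape=(10,))
    x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
    sched = mx.lr_scheduler.FactorScheduler(2, factor=0.1, base_lr=1)
    trainer = gluon.Trainer([x], 'sgd',
                            {'learning_rate': 1, 'lr_scheduler': sched},
                            update_on_kvstore=False)  # no longer raises ValueError
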
diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index def2c95..2e7fe86 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -106,7 +106,8 @@ class Optimizer(object):
         self.wd_mult = {}
         self.begin_num_update = begin_num_update
         self.num_update = begin_num_update
-        self._index_update_count = {}
+        self._all_index_update_counts = {0 : {}}
+        self._index_update_count = self._all_index_update_counts[0]
         self.clip_gradient = clip_gradient
         self.multi_precision = multi_precision
         self.aggregate_num = 0
@@ -380,6 +381,18 @@ class Optimizer(object):
                     self.wd_mult[name] = float(attr[name]['__wd_mult__'])
         self.wd_mult.update(args_wd_mult)
 
+    def _set_current_context(self, device_id):
+        """Sets the number of the currently handled device.
+
+        Parameters
+        ----------
+        device_id : int
+            The ID of the current device.
+        """
+        if device_id not in self._all_index_update_counts:
+            self._all_index_update_counts[device_id] = {}
+        self._index_update_count = self._all_index_update_counts[device_id]
+
     def _update_count(self, index):
         """Updates num_update.
 
@@ -1623,6 +1636,8 @@ class Updater(object):
             indices = index
             grads = grad
             weights = weight
+        if weights:
+            self.optimizer._set_current_context(weights[0].context.device_id)
         for i, idx in enumerate(indices):
             # convert ctypes.char_p.value back to python str if needed
             if isinstance(idx, bytes):
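
Why this corrects the update count: with update_on_kvstore=False the Trainer
runs one updater per device against the same shared Optimizer, so previously
every device incremented the same _index_update_count entry and num_update
advanced once per device per step, driving an LRScheduler too fast. Keeping one
count dict per device and switching it via _set_current_context restores one
logical increment per step. A simplified standalone sketch of that bookkeeping
(assumes begin_num_update=0; the real logic lives in Optimizer._update_count):

    class CountSketch:
        def __init__(self):
            self._all = {0: {}}              # device_id -> {index: count}
            self._cur = self._all[0]
            self.num_update = 0

        def _set_current_context(self, device_id):
            if device_id not in self._all:
                self._all[device_id] = {}
            self._cur = self._all[device_id]

        def _update_count(self, index):
            self._cur[index] = self._cur.get(index, 0) + 1
            self.num_update = max(self._cur[index], self.num_update)

    opt = CountSketch()
    for step in range(3):
        for dev in (0, 1):                   # two devices update the same index
            opt._set_current_context(dev)
            opt._update_count(0)
    assert opt.num_update == 3               # one per step, not one per device
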
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index 9f190a0..2d5874a 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -272,19 +272,22 @@ def test_trainer_lr_sched():
             lr *= factor
     mx.nd.waitall()
 
-@with_seed()
-def test_trainer_invalid_lr_sched():
+    # Repeat the schedule check with update_on_kvstore=False
     x = gluon.Parameter('x', shape=(10,))
     x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
     freq = 2
     factor = 0.1
     lr = 1
     lr_sched = mx.lr_scheduler.FactorScheduler(freq, factor=factor, base_lr=lr)
-    invalid_trainer = gluon.Trainer([x], 'sgd', {'learning_rate': lr, 'lr_scheduler': lr_sched},
-                                    update_on_kvstore=False)
-    with mx.autograd.record():
-        for w in x.list_data():
-            y = w + 1
-            y.backward()
-    assert_raises(ValueError, invalid_trainer.step, 1)
+    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': lr, 'lr_scheduler': lr_sched},
+                            update_on_kvstore=False)
+    for i in range(10):
+        with mx.autograd.record():
+            for w in x.list_data():
+                y = w + 1
+                y.backward()
+        trainer.step(1)
+        if i % freq == 0:
+            assert trainer.learning_rate == lr, (lr, trainer.learning_rate, i)
+            lr *= factor
     mx.nd.waitall()
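
For reference, the decay sequence the updated test expects can be reproduced
with the scheduler alone (a small sketch; num_update is fed in by hand):

    import mxnet as mx

    sched = mx.lr_scheduler.FactorScheduler(2, factor=0.1, base_lr=1)
    lrs = [sched(n) for n in range(1, 7)]
    # approximately [1.0, 1.0, 0.1, 0.1, 0.01, 0.01]: one decay every 2 updates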