Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/19 04:54:43 UTC

[GitHub] solin319 closed pull request #8107: "add warmup lr_scheduler" create a new pr

solin319 closed pull request #8107:  "add warmup lr_scheduler" create a new pr
URL: https://github.com/apache/incubator-mxnet/pull/8107

This pull request comes from a forked repository. Because GitHub hides the
original diff once a fork-based pull request is closed, it is reproduced
below for the sake of provenance:

diff --git a/python/mxnet/lr_scheduler.py b/python/mxnet/lr_scheduler.py
index e4af77aa869..cdb75e96a71 100644
--- a/python/mxnet/lr_scheduler.py
+++ b/python/mxnet/lr_scheduler.py
@@ -100,16 +100,27 @@ class MultiFactorScheduler(LRScheduler):
 
     Then calculate the new learning rate by::
 
-       base_lr * pow(factor, k+1)
+        base_lr * pow(factor, k+1)
+
+    When warmup_step > 1, warm the learning rate up by a constant increment for the
+    first warmup_step updates. During warmup the learning rate is computed by::
+
+        warmup_start_lr + (num_update - 1) * const_update
 
     Parameters
     ----------
     step: list of int
-        The list of steps to schedule a change
+        The list of steps to schedule a change.
     factor: float
         The factor to change the learning rate.
+    warmup_step : int, optional
+        Number of warmup updates; the learning rate grows by a constant increment during them.
+    warmup_start_lr : float, optional
+        The learning rate that the warmup starts from.
+    warmup_stop_lr : float, optional
+        The learning rate reached after the first 'warmup_step' updates.
     """
-    def __init__(self, step, factor=1):
+    def __init__(self, step, factor=1, warmup_step=0, warmup_start_lr=0, warmup_stop_lr=0):
         super(MultiFactorScheduler, self).__init__()
         assert isinstance(step, list) and len(step) >= 1
         for i, _step in enumerate(step):
@@ -119,20 +130,53 @@ def __init__(self, step, factor=1):
                 raise ValueError("Schedule step must be greater or equal than 1 round")
         if factor > 1.0:
             raise ValueError("Factor must be no more than 1 to make lr reduce")
+
+        # MultiFactor parameters
         self.step = step
         self.cur_step_ind = 0
         self.factor = factor
         self.count = 0
 
+        # warmup parameters
+        self.warmup_step = warmup_step
+        if warmup_step > 1:
+            if step[0] <= warmup_step:
+                raise ValueError("Schedule step must be greater than warmup_step")
+            if warmup_stop_lr <= warmup_start_lr:
+                raise ValueError("Stop lr must be greater than begin lr")
+            self.warmup_start_lr = warmup_start_lr
+            self.warmup_stop_lr = warmup_stop_lr
+            self.const_update = (self.warmup_stop_lr - self.warmup_start_lr) / \
+                                (self.warmup_step - 1)
+            self.cur_step = 0
+
     def __call__(self, num_update):
-        # NOTE: use while rather than if  (for continuing training via load_epoch)
-        while self.cur_step_ind <= len(self.step)-1:
-            if num_update > self.step[self.cur_step_ind]:
-                self.count = self.step[self.cur_step_ind]
-                self.cur_step_ind += 1
-                self.base_lr *= self.factor
-                logging.info("Update[%d]: Change learning rate to %0.5e",
-                             num_update, self.base_lr)
+        """
+        Return the learning rate scheduled for the given number of updates.
+        Parameters
+        ----------
+        num_update: int
+            The maximal number of updates applied to a weight.
+        """
+        if self.warmup_step > 1 and num_update <= self.warmup_step:
+            if num_update > self.cur_step:
+                self.base_lr = (num_update - 1) * self.const_update + self.warmup_start_lr
+                self.cur_step = num_update
+                if num_update == self.warmup_step or self.base_lr >= self.warmup_stop_lr:
+                    self.base_lr = self.warmup_stop_lr
+                    logging.info("Update[%d]: learning rate reached %0.5e, "
+                                 "warmup finished", num_update, self.base_lr)
             else:
                 return self.base_lr
+        else:
+            # NOTE: use while rather than if  (for continuing training via load_epoch)
+            while self.cur_step_ind <= len(self.step)-1:
+                if num_update > self.step[self.cur_step_ind]:
+                    self.count = self.step[self.cur_step_ind]
+                    self.cur_step_ind += 1
+                    self.base_lr *= self.factor
+                    logging.info("Update[%d]: Change learning rate to %0.5e",
+                                 num_update, self.base_lr)
+                else:
+                    return self.base_lr
         return self.base_lr
diff --git a/tests/python/unittest/test_lr_scheduler.py b/tests/python/unittest/test_lr_scheduler.py
new file mode 100644
index 00000000000..56a586d83ea
--- /dev/null
+++ b/tests/python/unittest/test_lr_scheduler.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+import mxnet as mx
+import mxnet.optimizer as opt
+
+def multi_lr_scheduler(lr, steps, lr_factor=1, warmup_step=0, warmup_lr=0):
+    # Build the scheduler, enabling warmup only when a valid warmup target is given.
+    if warmup_step > 0 and warmup_lr > lr:
+        lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=lr_factor,
+                    warmup_step=warmup_step, warmup_start_lr=lr, warmup_stop_lr=warmup_lr)
+    else:
+        lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=lr_factor)
+
+    optimizer_params = {
+            'learning_rate': lr,
+            'lr_scheduler': lr_scheduler}
+
+    optimizer = opt.create('sgd', **optimizer_params)
+    updater = opt.get_updater(optimizer)
+
+    x = [[[[i * 10 + j for j in range(10)] for i in range(10)]]]
+    x = mx.nd.array(x, dtype='float32')
+    y = mx.nd.ones(shape=x.shape, dtype='float32')
+
+    res_lr = []
+    for i in range(1, steps[-1] + 5):
+        updater(0, y, x)
+        cur_lr = optimizer._get_lr(0)
+        res_lr.append(cur_lr)
+
+    if warmup_step > 1:
+        assert mx.test_utils.almost_equal(res_lr[warmup_step], warmup_lr, 1e-10)
+        lr = warmup_lr
+    for i in range(len(steps)):
+        assert mx.test_utils.almost_equal(res_lr[steps[i]], lr * pow(lr_factor, i + 1), 1e-10)
+
+def test_multi_lr_scheduler():
+    # Legal input
+    multi_lr_scheduler(lr=0.02, steps=[100, 200])
+    multi_lr_scheduler(lr=0.2, steps=[8, 12], lr_factor=0.1, warmup_step=0, warmup_lr=0.1)
+    multi_lr_scheduler(lr=0.02, steps=[8, 12], lr_factor=0.1, warmup_step=1, warmup_lr=0.1)
+    multi_lr_scheduler(lr=0.02, steps=[8, 12], lr_factor=0.3, warmup_step=5, warmup_lr=0.1)
+    multi_lr_scheduler(lr=0.002, steps=[8, 12], lr_factor=0.1, warmup_step=7, warmup_lr=0.1)
+    # Illegal input (kept disabled)
+    """
+    # Schedule step must be greater than warmup_step
+    multi_lr_scheduler(lr=0.02, steps=[8, 12], lr_factor=0.1, warmup_step=10, warmup_lr=0.1)
+    # warmup_stop_lr must be larger than warmup_start_lr
+    multi_lr_scheduler(lr=0.02, steps=[8, 12], lr_factor=0.1, warmup_step=10, warmup_lr=0.001)
+    # Schedule step must be a list
+    multi_lr_scheduler(lr=0.02, steps=8, lr_factor=0.1, warmup_step=5, warmup_lr=0.1)
+    # Factor must be no more than 1 to make lr reduce
+    multi_lr_scheduler(lr=0.02, steps=[8, 12], lr_factor=2, warmup_step=5, warmup_lr=0.1)
+    # Schedule step must be an increasing integer list
+    multi_lr_scheduler(lr=0.02, steps=[12, 8], lr_factor=0.1, warmup_step=5, warmup_lr=0.1)
+    """
+
+if __name__ == "__main__":
+    import nose
+    nose.runmodule()
+    
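
For readers of this thread, here is a minimal, illustrative sketch of the warmup behaviour the patch above introduces. It assumes the patched MultiFactorScheduler is available; the step points and learning-rate values are made up for the example and are not taken from the PR.

# Illustrative sketch only: assumes the patched MultiFactorScheduler above.
import mxnet as mx

# Linear warmup from 0.01 to 0.1 over the first 5 updates, then the usual
# MultiFactor decay: multiply the learning rate by 0.5 after updates 10 and 20.
sched = mx.lr_scheduler.MultiFactorScheduler(
    step=[10, 20], factor=0.5,
    warmup_step=5, warmup_start_lr=0.01, warmup_stop_lr=0.1)

# During warmup: lr = warmup_start_lr + (num_update - 1) * const_update,
# with const_update = (0.1 - 0.01) / (5 - 1) = 0.0225.
for num_update in [1, 3, 5, 11, 21]:
    print(num_update, sched(num_update))
# Expected values, approximately:
#   1  -> 0.01    (warmup start)
#   3  -> 0.055   (0.01 + 2 * 0.0225)
#   5  -> 0.1     (warmup_stop_lr reached)
#   11 -> 0.05    (0.1 * 0.5 after step 10)
#   21 -> 0.025   (0.05 * 0.5 after step 20)

When the scheduler is attached to an optimizer, as in the unit test above, the optimizer passes its internal update count as num_update, so the same schedule is applied during training.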


 
