Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/29 17:16:16 UTC

[GitHub] zhreshold closed pull request #8918: Added in Large-Batch SGD with a warmup, and a LARS strategy. Also add?

URL: https://github.com/apache/incubator-mxnet/pull/8918
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the
original diff once a pull request from a fork is merged, the diff is
reproduced below for the sake of provenance:

diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py
index 2b002c7702..d9f96d0eba 100755
--- a/example/image-classification/common/fit.py
+++ b/example/image-classification/common/fit.py
@@ -15,10 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import mxnet as mx
+""" example train fit utility """
 import logging
 import os
 import time
+import re
+import math
+import mxnet as mx
+
 
 def _get_lr_scheduler(args, kv):
     if 'lr_factor' not in args or args.lr_factor >= 1:
@@ -27,17 +31,26 @@ def _get_lr_scheduler(args, kv):
     if 'dist' in args.kv_store:
         epoch_size /= kv.num_workers
     begin_epoch = args.load_epoch if args.load_epoch else 0
+    if 'pow' in args.lr_step_epochs:
+        lr = args.lr
+        max_up = args.num_epochs * epoch_size
+        pwr = float(re.sub('pow[- ]*', '', args.lr_step_epochs))
+        poly_sched = mx.lr_scheduler.PolyScheduler(max_up, lr, pwr)
+        return (lr, poly_sched)
     step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
     lr = args.lr
     for s in step_epochs:
         if begin_epoch >= s:
             lr *= args.lr_factor
     if lr != args.lr:
-        logging.info('Adjust learning rate to %e for epoch %d' %(lr, begin_epoch))
+        logging.info('Adjust learning rate to %e for epoch %d',
+                     lr, begin_epoch)
 
-    steps = [epoch_size * (x-begin_epoch) for x in step_epochs if x-begin_epoch > 0]
+    steps = [epoch_size * (x - begin_epoch)
+             for x in step_epochs if x - begin_epoch > 0]
     return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))
 
+
 def _load_model(args, rank=0):
     if 'load_epoch' not in args or args.load_epoch is None:
         return (None, None, None)
@@ -50,6 +63,7 @@ def _load_model(args, rank=0):
     logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch)
     return (sym, arg_params, aux_params)
 
+
 def _save_model(args, rank=0):
     if args.model_prefix is None:
         return None
@@ -59,6 +73,7 @@ def _save_model(args, rank=0):
     return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
         args.model_prefix, rank))
 
+
 def add_fit_args(parser):
     """
     parser : argparse.ArgumentParser
@@ -68,7 +83,8 @@ def add_fit_args(parser):
     train.add_argument('--network', type=str,
                        help='the neural network to use')
     train.add_argument('--num-layers', type=int,
-                       help='number of layers in the neural network, required by some networks such as resnet')
+                       help='number of layers in the neural network, \
+                             required by some networks such as resnet')
     train.add_argument('--gpus', type=str,
                        help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')
     train.add_argument('--kv-store', type=str, default='device',
@@ -81,6 +97,8 @@ def add_fit_args(parser):
                        help='the ratio to reduce lr on each step')
     train.add_argument('--lr-step-epochs', type=str,
                        help='the epochs to reduce the lr, e.g. 30,60')
+    train.add_argument('--initializer', type=str, default='default',
+                       help='the initializer type')
     train.add_argument('--optimizer', type=str, default='sgd',
                        help='the optimizer type')
     train.add_argument('--mom', type=float, default=0.9,
@@ -108,8 +126,16 @@ def add_fit_args(parser):
                              takes `2bit` or `none` for now')
     train.add_argument('--gc-threshold', type=float, default=0.5,
                        help='threshold for 2bit gradient compression')
+    # additional parameters for large batch sgd
+    train.add_argument('--macrobatch-size', type=int, default=0,
+                       help='distributed effective batch size')
+    train.add_argument('--warmup-epochs', type=int, default=5,
+                       help='the epochs to ramp-up lr to scaled large-batch value')
+    train.add_argument('--warmup-strategy', type=str, default='linear',
+                       help='the ramping-up strategy for large batch sgd')
     return train
 
+
 def fit(args, network, data_loader, **kwargs):
     """
     train a model
@@ -135,14 +161,13 @@ def fit(args, network, data_loader, **kwargs):
         for i, batch in enumerate(train):
             for j in batch.data:
                 j.wait_to_read()
-            if (i+1) % args.disp_batches == 0:
-                logging.info('Batch [%d]\tSpeed: %.2f samples/sec' % (
-                    i, args.disp_batches*args.batch_size/(time.time()-tic)))
+            if (i + 1) % args.disp_batches == 0:
+                logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
+                             args.disp_batches * args.batch_size / (time.time() - tic))
                 tic = time.time()
 
         return
 
-
     # load model
     if 'arg_params' in kwargs and 'aux_params' in kwargs:
         arg_params = kwargs['arg_params']
@@ -156,7 +181,7 @@ def fit(args, network, data_loader, **kwargs):
     checkpoint = _save_model(args, kv.rank)
 
     # devices for training
-    devs = mx.cpu() if args.gpus is None or args.gpus is '' else [
+    devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
         mx.gpu(int(i)) for i in args.gpus.split(',')]
 
     # learning rate
@@ -164,14 +189,14 @@ def fit(args, network, data_loader, **kwargs):
 
     # create model
     model = mx.mod.Module(
-        context       = devs,
-        symbol        = network
+        context=devs,
+        symbol=network
     )
 
-    lr_scheduler  = lr_scheduler
+    lr_scheduler = lr_scheduler
     optimizer_params = {
         'learning_rate': lr,
-        'wd' : args.wd,
+        'wd': args.wd,
         'lr_scheduler': lr_scheduler,
         'multi_precision': True}
 
@@ -180,40 +205,81 @@ def fit(args, network, data_loader, **kwargs):
     if args.optimizer in has_momentum:
         optimizer_params['momentum'] = args.mom
 
-    monitor = mx.mon.Monitor(args.monitor, pattern=".*") if args.monitor > 0 else None
+    monitor = mx.mon.Monitor(
+        args.monitor, pattern=".*") if args.monitor > 0 else None
 
-    if args.network == 'alexnet':
-        # AlexNet will not converge using Xavier
-        initializer = mx.init.Normal()
-    else:
-        initializer = mx.init.Xavier(
-            rnd_type='gaussian', factor_type="in", magnitude=2)
+    # A limited number of optimizers have a warmup period
+    has_warmup = {'lbsgd', 'lbnag'}
+    if args.optimizer in has_warmup:
+        if 'dist' in args.kv_store:
+            nworkers = kv.num_workers
+        else:
+            nworkers = 1
+        epoch_size = args.num_examples / args.batch_size / nworkers
+        if epoch_size < 1:
+            epoch_size = 1
+        macrobatch_size = args.macrobatch_size
+        if macrobatch_size < args.batch_size * nworkers:
+            macrobatch_size = args.batch_size * nworkers
+        #batch_scale = round(float(macrobatch_size) / args.batch_size / nworkers +0.4999)
+        batch_scale = math.ceil(
+            float(macrobatch_size) / args.batch_size / nworkers)
+        optimizer_params['updates_per_epoch'] = epoch_size
+        optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
+        optimizer_params['batch_scale'] = batch_scale
+        optimizer_params['warmup_strategy'] = args.warmup_strategy
+        optimizer_params['warmup_epochs'] = args.warmup_epochs
+        optimizer_params['num_epochs'] = args.num_epochs
+
+    if args.initializer == 'default':
+        if args.network == 'alexnet':
+            # AlexNet will not converge using Xavier
+            initializer = mx.init.Normal()
+        else:
+            initializer = mx.init.Xavier(
+                rnd_type='gaussian', factor_type="in", magnitude=2)
     # initializer   = mx.init.Xavier(factor_type="in", magnitude=2.34),
+    elif args.initializer == 'xavier':
+        initializer = mx.init.Xavier()
+    elif args.initializer == 'msra':
+        initializer = mx.init.MSRAPrelu()
+    elif args.initializer == 'orthogonal':
+        initializer = mx.init.Orthogonal()
+    elif args.initializer == 'normal':
+        initializer = mx.init.Normal()
+    elif args.initializer == 'uniform':
+        initializer = mx.init.Uniform()
+    elif args.initializer == 'one':
+        initializer = mx.init.One()
+    elif args.initializer == 'zero':
+        initializer = mx.init.Zero()
 
     # evaluation metrices
     eval_metrics = ['accuracy']
     if args.top_k > 0:
-        eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=args.top_k))
+        eval_metrics.append(mx.metric.create(
+            'top_k_accuracy', top_k=args.top_k))
 
     # callbacks that run after each batch
-    batch_end_callbacks = [mx.callback.Speedometer(args.batch_size, args.disp_batches)]
+    batch_end_callbacks = [mx.callback.Speedometer(
+        args.batch_size, args.disp_batches)]
     if 'batch_end_callback' in kwargs:
         cbs = kwargs['batch_end_callback']
         batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]
 
     # run
     model.fit(train,
-              begin_epoch        = args.load_epoch if args.load_epoch else 0,
-              num_epoch          = args.num_epochs,
-              eval_data          = val,
-              eval_metric        = eval_metrics,
-              kvstore            = kv,
-              optimizer          = args.optimizer,
-              optimizer_params   = optimizer_params,
-              initializer        = initializer,
-              arg_params         = arg_params,
-              aux_params         = aux_params,
-              batch_end_callback = batch_end_callbacks,
-              epoch_end_callback = checkpoint,
-              allow_missing      = True,
-              monitor            = monitor)
+              begin_epoch=args.load_epoch if args.load_epoch else 0,
+              num_epoch=args.num_epochs,
+              eval_data=val,
+              eval_metric=eval_metrics,
+              kvstore=kv,
+              optimizer=args.optimizer,
+              optimizer_params=optimizer_params,
+              initializer=initializer,
+              arg_params=arg_params,
+              aux_params=aux_params,
+              batch_end_callback=batch_end_callbacks,
+              epoch_end_callback=checkpoint,
+              allow_missing=True,
+              monitor=monitor)
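
The large-batch bookkeeping in the has_warmup branch above boils down to a few
lines of arithmetic. A minimal sketch, assuming hypothetical values for
args.num_examples, args.batch_size, kv.num_workers and --macrobatch-size (these
numbers are placeholders, not values taken from the PR):

    import math

    num_examples = 1281167    # hypothetical training-set size
    batch_size = 256          # per-worker batch size
    nworkers = 8              # kv.num_workers under a 'dist' kv-store, else 1
    macrobatch_size = 8192    # --macrobatch-size

    # same arithmetic as the has_warmup branch in fit()
    epoch_size = max(num_examples / batch_size / nworkers, 1)
    macrobatch_size = max(macrobatch_size, batch_size * nworkers)
    batch_scale = math.ceil(float(macrobatch_size) / batch_size / nworkers)

    print(epoch_size)   # about 625 updates per epoch per worker
    print(batch_scale)  # 4: gradients are accumulated over 4 micro-batches

These values are what end up in optimizer_params as updates_per_epoch and
batch_scale when --optimizer lbsgd (or lbnag) is selected.
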
diff --git a/python/mxnet/lr_scheduler.py b/python/mxnet/lr_scheduler.py
index e4af77aa86..963560d178 100644
--- a/python/mxnet/lr_scheduler.py
+++ b/python/mxnet/lr_scheduler.py
@@ -136,3 +136,35 @@ def __call__(self, num_update):
             else:
                 return self.base_lr
         return self.base_lr
+
+class PolyScheduler(LRScheduler):
+    """ Reduce the learning rate by given a list of steps.
+
+    Calculate the new learning rate by::
+
+       base_lr * (1-nup/max_nup)^pwr
+       if nup < max_nup, 0 otherwise.
+
+    Parameters
+    ----------
+       max_update: maximum number of updates before the decay reaches 0.
+       base_lr:    base learning rate
+       pwr:   power of the decay term as a function of the current number of updates.
+
+    """
+
+    def __init__(self, max_update, base_lr=0.01, pwr=2):
+        super(PolyScheduler, self).__init__(base_lr)
+        assert isinstance(max_update, int)
+        if max_update < 1:
+            raise ValueError("maximum number of updates must be strictly positive")
+        self.base_lr_orig = self.base_lr
+        self.max_update = max_update
+        self.power = pwr
+        self.base_lr = self.base_lr_orig
+
+    def __call__(self, num_update):
+        if num_update <= self.max_update:
+            self.base_lr = self.base_lr_orig * pow(1.0 - float(num_update) / float(self.max_update),
+                                                   self.power)
+        return self.base_lr
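
A quick usage sketch of the new scheduler (illustrative only; it assumes
PolyScheduler is importable from mxnet.lr_scheduler once the hunk above is
applied):

    from mxnet.lr_scheduler import PolyScheduler

    # lr(t) = base_lr * (1 - t / max_update) ** pwr   for t <= max_update
    sched = PolyScheduler(max_update=1000, base_lr=0.1, pwr=2)
    for t in (0, 250, 500, 750, 1000):
        print(t, sched(t))
    # 0.1, 0.05625, 0.025, 0.00625, 0.0

In fit.py above, this scheduler is selected by passing a --lr-step-epochs value
containing 'pow' (e.g. 'pow2'), from which the exponent pwr is parsed.
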
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 57340be280..c3338f4c76 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -645,6 +645,195 @@ def update(self, index, weight, grad, state):
         ftml_update(weight, grad, prev_d, prev_v, prev_z, out=weight,
                     lr=lr, wd=wd, **kwargs)
 
+@register
+class LBSGD(Optimizer):
+    """The Large Batch SGD optimizer with momentum and weight decay.
+
+    The optimizer updates the weight by::
+
+        state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
+        weight = weight - state
+
+    For details of the update algorithm see :class:`~mxnet.ndarray.lbsgd_update` and
+    :class:`~mxnet.ndarray.lbsgd_mom_update`.
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    momentum : float, optional
+       The momentum value.
+    multi_precision: bool, optional
+       Flag to control the internal precision of the optimizer.
+       ``False`` results in using the same precision as the weights (default),
+       ``True`` makes internal 32-bit copy of the weights and applies gradients
+                in 32-bit precision even if actual weights used in the model have lower precision.
+                Turning this on can improve convergence and accuracy when training with float16.
+    warmup_strategy: string ('linear', 'power2', 'sqrt', or 'lars'; default: 'linear')
+    warmup_epochs: unsigned, default: 5
+    batch_scale: unsigned, default: 1 (same as batch_size * num_workers)
+    updates_per_epoch: unsigned, default: 32 (the default may not reflect the true number of batches per epoch; used for warmup)
+    begin_epoch: unsigned, default 0, starting epoch.
+    """
+
+    def __init__(self, momentum=0.0, multi_precision=False, warmup_strategy='linear',
+                 warmup_epochs=5, batch_scale=1, updates_per_epoch=32, begin_epoch=0, num_epochs=60,
+                 **kwargs):
+        super(LBSGD, self).__init__(**kwargs)
+        logging.info('Running Large-Batch SGD Algorithm')
+        logging.info('(Batch_scale=%f, warmup_epochs=%d, warmup_strategy=%s, updates_per_epoch=%d)',
+                     batch_scale, warmup_epochs, warmup_strategy, updates_per_epoch)
+        self.momentum = momentum
+        self.multi_precision = multi_precision
+        # new user parameters for large batch
+        self.warmup_strategy = warmup_strategy
+        self.warmup_epochs = warmup_epochs
+        self.batch_scale = batch_scale
+        self.updates_per_epoch = updates_per_epoch
+        self.init_updates = begin_epoch * updates_per_epoch
+        self.num_epochs = num_epochs
+        # addl internal usage parameters and storage
+        self.lbmult = 1
+        self.cumgrads = {}
+        # for adaptive lr
+        self.adaptive = False
+        self.admult = 1  # adaptation constant
+
+    def create_state(self, index, weight):
+        momentum = None
+        weight_master_copy = None
+        if self.multi_precision and weight.dtype == numpy.float16:
+            weight_master_copy = array(weight, ctx=weight.context, dtype=numpy.float32)
+            if self.momentum != 0.0:
+                momentum = zeros(weight.shape, weight.context, dtype=numpy.float32,
+                                 stype=weight.stype)
+            return (momentum, weight_master_copy)
+        if weight.dtype == numpy.float16 and not self.multi_precision:
+            warnings.warn("Accumulating with float16 in optimizer can lead to "
+                          "poor accuracy or slow convergence. "
+                          "Consider using multi_precision=True option of the "
+                          "SGD optimizer")
+        if self.momentum != 0.0:
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+        return momentum
+
+    def _get_lbmult(self, nup):
+        """Returns lr scaling factor for large batch according to warmup schedule
+        (to be implemented)
+        """
+        nwup = self.warmup_epochs * self.updates_per_epoch
+        strategy = self.warmup_strategy
+        maxmult = float(self.batch_scale)
+        if nup >= nwup:
+            mult = maxmult
+        elif nwup <= 1:
+            mult = 1.0
+        else:
+            if (strategy == 'linear'):
+                mult = 1.0 + (maxmult - 1) * nup / nwup
+            elif (strategy == 'power2'):
+                mult = 1.0 + (maxmult-1) * (nup*nup)/(nwup*nwup)
+            elif (strategy == 'sqrt'):
+                mult = 1.0 + (maxmult - 1) * math.sqrt(float(nup) / nwup)
+            else:
+                mult = 1.0
+        return mult
+
+    def _get_lars(self, weight, g, wd):
+        """Returns a scaling factor for the learning rate for this layer
+        default is 1
+        """
+        weight2 = self._l2norm(weight)
+        grad2 = self._l2norm(g)
+        lars = math.sqrt(weight2 / (grad2 + wd * weight2 + 1e-18))
+        if lars < 0.01:
+            lars = 0.01
+        elif lars > 100:
+            lars = 100
+        return lars
+
+    def _l2norm(self, v):
+        "inner product implementation"
+        norm = multiply(v, v).asnumpy().sum()
+        return norm
+
+    def _reset_cum_gradient(self, index):
+        "called every macro-batch to reset cumulated gradients to 0 for a given index"
+        self.cumgrads[index]['cum_grad'] = 0
+
+    def _get_cum_gradient(self, index):
+        "get the cumulated gradient for index"
+        if index in self.cumgrads:
+            return self.cumgrads[index]
+        else:
+            return {}
+
+    def _put_cum_gradient(self, index, cgrad):
+        "store cumulated gradient for index"
+        self.cumgrads[index] = cgrad
+
+    def _cumulate_gradient(self, grad, index):
+        "Cumulate gradients for large-batch emulation. Cumulated by index (layer)"
+        cgrad = self._get_cum_gradient(index)
+        if cgrad:
+            num_cums = cgrad['num_cums']
+            if num_cums > 0:
+                cum_grad = cgrad['cum_grad'] + grad
+                num_cums += 1
+            else:
+                cum_grad = grad
+                num_cums = self.init_updates + 1
+        else:
+            cum_grad = grad
+            num_cums = self.init_updates + 1
+        cgrad = {'cum_grad': cum_grad, 'num_cums': num_cums}
+        self._put_cum_gradient(index, cgrad)
+        return cgrad
+
+    def update(self, index, weight, grad, state):
+        assert (isinstance(weight, NDArray))
+        assert (isinstance(grad, NDArray))
+
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+        self._update_count(index)
+
+        # new stuff for large batch
+        cgrad = self._cumulate_gradient(grad, index)
+        if (cgrad['num_cums'] % self.batch_scale) == 0:
+            grad = cgrad['cum_grad'] / self.batch_scale
+            if self.warmup_strategy == 'lars':
+                lbmult = self._get_lars(weight, grad, wd)
+            else:
+                lbmult = self._get_lbmult(cgrad['num_cums'])
+            lr = lr * lbmult
+            # do the regular sgd update flow
+            kwargs = {'rescale_grad': self.rescale_grad}
+            if self.momentum > 0:
+                kwargs['momentum'] = self.momentum
+            if self.clip_gradient:
+                kwargs['clip_gradient'] = self.clip_gradient
+            use_multi_precision = isinstance(state, (list, tuple))
+
+            if not use_multi_precision:
+                if state is not None:
+                    sgd_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
+                else:
+                    sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
+            else:
+                if state[0] is not None:
+                    mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight, lr=lr, wd=wd,
+                                      **kwargs)
+                else:
+                    mp_sgd_update(weight, grad, state[1], out=weight, lr=lr, wd=wd, **kwargs)
+            # reset update count and cumulated gradient per large batch
+            self._reset_cum_gradient(index)
+        else:
+            lr = 0.0
+            kwargs = {}
+            sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
+
 # pylint: enable=line-too-long
 @register
 class DCASGD(Optimizer):
@@ -1282,6 +1471,7 @@ def __call__(self, index, grad, weight):
         self.optimizer.update_multi_precision(index, weight, grad, self.states[index])
 
     def sync_state_context(self, state, context):
+        """sync state context."""
         if isinstance(state, NDArray):
             return state.as_in_context(context)
         elif isinstance(state, (tuple, list)):
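
The warmup and LARS logic that drives the learning-rate multiplier in
LBSGD.update can be summarised outside the optimizer. A standalone sketch of
what _get_lbmult and _get_lars compute, using NumPy arrays with made-up values
rather than the optimizer's NDArray code path:

    import math
    import numpy as np

    def warmup_mult(nup, warmup_epochs=5, updates_per_epoch=32,
                    batch_scale=4, strategy='linear'):
        # mirrors LBSGD._get_lbmult: ramp the multiplier from 1 up to batch_scale
        nwup = warmup_epochs * updates_per_epoch
        maxmult = float(batch_scale)
        if nup >= nwup:
            return maxmult
        if nwup <= 1:
            return 1.0
        if strategy == 'linear':
            return 1.0 + (maxmult - 1) * nup / float(nwup)
        if strategy == 'power2':
            return 1.0 + (maxmult - 1) * (nup * nup) / float(nwup * nwup)
        if strategy == 'sqrt':
            return 1.0 + (maxmult - 1) * math.sqrt(float(nup) / nwup)
        return 1.0

    def lars_mult(weight, grad, wd):
        # mirrors LBSGD._get_lars: layer-wise lr scaling, clipped to [0.01, 100]
        w2 = float((weight * weight).sum())
        g2 = float((grad * grad).sum())
        lars = math.sqrt(w2 / (g2 + wd * w2 + 1e-18))
        return min(max(lars, 0.01), 100.0)

    print(warmup_mult(80))                                  # 2.5, halfway through warmup
    print(lars_mult(np.ones(10), 0.1 * np.ones(10), 1e-4))  # ~9.95

During warmup the effective learning rate grows from lr to lr * batch_scale;
with the 'lars' strategy the per-layer ratio of weight norm to gradient norm is
used as the multiplier instead.
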


 
