Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/12 22:18:33 UTC

[GitHub] zhreshold commented on a change in pull request #8918: Added in Large-Batch SGD with a warmup, and a LARS strategy. Also add…

zhreshold commented on a change in pull request #8918: Added in Large-Batch SGD with a warmup, and a LARS strategy. Also add…
URL: https://github.com/apache/incubator-mxnet/pull/8918#discussion_r161338805
 
 

 ##########
 File path: python/mxnet/optimizer.py
 ##########
 @@ -531,6 +531,197 @@ def update_multi_precision(self, index, weight, grad, state):
         self._update_impl(index, weight, grad, state,
                           multi_precision=use_multi_precision)
 
+@register
+class LBSGD(Optimizer):
+    """The Large Batch SGD optimizer with momentum and weight decay.
+
+    The optimizer updates the weight by::
+
+        state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
+        weight = weight - state
+
+    For details of the update algorithm see :class:`~mxnet.ndarray.lbsgd_update` and
+    :class:`~mxnet.ndarray.lbsgd_mom_update`.
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    momentum : float, optional
+       The momentum value.
+    multi_precision: bool, optional
+       Flag to control the internal precision of the optimizer.
+       ``False`` results in using the same precision as the weights (default),
+       ``True`` makes an internal 32-bit copy of the weights and applies gradients
+                in 32-bit precision even if the actual weights used in the model have lower precision.
+                Turning this on can improve convergence and accuracy when training with float16.
+    warmup_strategy: string ('linear', 'power2', 'sqrt', 'lars'; default: 'linear')
+    warmup_epochs: unsigned, default: 5
+    batch_scale: unsigned, default: 1 (scale of the effective batch, i.e. batch size * number of workers)
+    updates_per_epoch: unsigned, default: 32 (the default may not reflect the true number of batches per epoch; used for warmup)
+    begin_epoch: unsigned, default: 0 (starting epoch)
+    """
+
+    def __init__(self, momentum=0.0, multi_precision=False, warmup_strategy='linear',
+                 warmup_epochs=5, batch_scale=1, updates_per_epoch=32, begin_epoch=0, num_epochs=60,
+                 **kwargs):
+        super(LBSGD, self).__init__(**kwargs)
+        logging.info('Running Large-Batch SGD Algorithm')
+        logging.info('(Batch_scale=%f, warmup_epochs=%d, warmup_strategy=%s, updates_per_epoch=%d)',
+                     batch_scale, warmup_epochs, warmup_strategy, updates_per_epoch)
+        self.momentum = momentum
+        self.multi_precision = multi_precision
+        # new user parameters for large batch
+        self.warmup_strategy = warmup_strategy
+        self.warmup_epochs = warmup_epochs
+        self.batch_scale = batch_scale
+        self.updates_per_epoch = updates_per_epoch
+        self.init_updates = begin_epoch * updates_per_epoch
+        self.num_epochs = num_epochs
+        # addl internal usage parameters and storage
+        self.lbmult = 1
+        self.cumgrads = {}
+        # for adaptive lr
+        self.adaptive = False
+        self.admult = 1  # adaptation constant
+
+    def create_state(self, index, weight):
+        momentum = None
+        weight_master_copy = None
+        if self.multi_precision and weight.dtype == numpy.float16:
+            weight_master_copy = array(weight, ctx=weight.context, dtype=numpy.float32)
+            if self.momentum != 0.0:
+                momentum = zeros(weight.shape, weight.context, dtype=numpy.float32,
+                                 stype=weight.stype)
+            return (momentum, weight_master_copy)
+        if weight.dtype == numpy.float16 and not self.multi_precision:
+            warnings.warn("Accumulating with float16 in optimizer can lead to "
+                          "poor accuracy or slow convergence. "
+                          "Consider using multi_precision=True option of the "
+                          "SGD optimizer")
+        if self.momentum != 0.0:
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+        return momentum
+
+    def _get_lbmult(self, nup):
+        """Returns lr scaling factor for large batch according to warmup schedule
+        """
+        nwup = self.warmup_epochs * self.updates_per_epoch
+        strategy = self.warmup_strategy
+        maxmult = float(self.batch_scale)
+        if nup >= nwup:
+            mult = maxmult
+        elif nwup <= 1:
+            mult = 1.0
+        else:
+            if (strategy == 'linear'):
+                mult = 1.0 + (maxmult - 1) * nup / nwup
+            elif (strategy == 'power2'):
+                mult = 1.0 + (maxmult - 1) * (nup * nup) / (nwup * nwup)
+            elif (strategy == 'power3'):
+                mult = 1.0 + (maxmult - 1) * (nup * nup) / (nwup * nwup)
 
 Review comment:
  Power3 is wrong: it reuses the quadratic (power2) formula instead of a cubic one.
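
For context, the 'power3' branch in the quoted hunk duplicates the quadratic formula from the 'power2' branch. Below is a minimal sketch of what a corrected cubic branch would presumably compute, assuming 'power3' is meant to ramp the multiplier from 1.0 up to batch_scale cubically over the warmup window; the helper name _warmup_multiplier is hypothetical and used here only for illustration, not part of the PR.

    def _warmup_multiplier(nup, nwup, maxmult, strategy):
        """Sketch of the warmup lr multiplier, mirroring _get_lbmult in the diff.

        nup: updates done so far, nwup: length of the warmup in updates,
        maxmult: target multiplier (batch_scale in the quoted code).
        """
        if nup >= nwup:
            return maxmult
        if nwup <= 1:
            return 1.0
        frac = float(nup) / nwup
        if strategy == 'linear':
            return 1.0 + (maxmult - 1) * frac
        if strategy == 'power2':
            return 1.0 + (maxmult - 1) * frac ** 2
        if strategy == 'power3':
            # cubic ramp: the branch the review flags, not a copy of the power2 formula
            return 1.0 + (maxmult - 1) * frac ** 3
        return 1.0

Halfway through warmup (nup = nwup / 2) this gives 1 + (maxmult - 1) / 8 for 'power3' versus 1 + (maxmult - 1) / 4 for 'power2', whereas the code as posted returns the 'power2' value for both strategies.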

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services