Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/10/17 22:07:45 UTC

[GitHub] [incubator-mxnet] eric-haibin-lin commented on a change in pull request #16122: Add fast implementation of LARS

eric-haibin-lin commented on a change in pull request #16122: Add fast implementation of LARS
URL: https://github.com/apache/incubator-mxnet/pull/16122#discussion_r336248596
 
 

 ##########
 File path: python/mxnet/optimizer/optimizer.py
 ##########
 @@ -781,6 +784,266 @@ def update(self, index, weight, grad, state):
         ftml_update(weight, grad, prev_d, prev_v, prev_z, out=weight,
                     lr=lr, wd=wd, **kwargs)
 
+@register
+class LARS(Optimizer):
+    """the LARS optimizer from 'Large Batch Training of Convolution Networks' \
+    (https://arxiv.org/abs/1708.03888)
+
+    Behave mostly like SGD with momentum and weight decay but is scaling \
+    adaptively the learning for each layer (except bias and batch norm parameters):
+    w_norm = L2norm(weights)
+    g_norm = L2norm(gradients)
+    if w_norm > 0 and g_norm > 0:
+        lr_layer = lr * lr_mult * eta * w_norm / (g_norm + weight_decay * w_norm + eps)
+    else:
+        lr_layer = lr * lr_mult
+
+    Parameters
+    ----------
+    momentum : float, optional
+        The momentum value.
+    lazy_update : bool, optional
+        Default is True. If True, lazy updates are applied \
+        if the storage types of weight and grad are both ``row_sparse``.
+    eta : float, optional
+        LARS coefficient used to scale the learning rate. Default set to 0.001.
+    eps : float, optional
+        Optional epsilon in case of very small gradients. Default set to 0.
+    momentum_correction : bool, optional
+        If True, scale momentum w.r.t. the global learning rate change (with an lr_scheduler), \
+        as indicated in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour' \
+        (https://arxiv.org/pdf/1706.02677.pdf).
+        Default set to True.
+    """
+    def __init__(self, momentum=0.0, lazy_update=True, eta=0.001, eps=0,
+                 momentum_correction=True, **kwargs):
+        super(LARS, self).__init__(**kwargs)
+        self.momentum = momentum
+        self.momentum_correction = momentum_correction
+        self.lazy_update = lazy_update
+        self.aggregate_num = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "4"))
+        self.eta = eta
+        self.eps = eps
+        self.skip = 0
+        self.last_lr = None
+        self.cur_lr = None
+
+
+    def _get_lrs(self, indices):
+        """Gets the learning rates given the indices of the weights.
+
+        Parameters
+        ----------
+        indices : list of int
+            Indices corresponding to weights.
+
+        Returns
+        -------
+        lrs : list of float
+            Learning rates for those indices.
+        """
+        if self.cur_lr is not None:
+            self.last_lr = self.cur_lr
+
+        if self.lr_scheduler is not None:
+            lr = self.lr_scheduler(self.num_update)
+        else:
+            lr = self.lr
+
+        if self.cur_lr is None:
+            self.last_lr = lr
+        self.cur_lr = lr
+
+        lrs = [lr for _ in indices]
+        for i, index in enumerate(indices):
+            if index in self.param_dict:
+                lrs[i] *= self.param_dict[index].lr_mult
+            elif index in self.lr_mult:
+                lrs[i] *= self.lr_mult[index]
+            elif index in self.idx2name:
+                lrs[i] *= self.lr_mult.get(self.idx2name[index], 1.0)
+        return lrs
+
+    def set_wd_mult(self, args_wd_mult):
+        self.wd_mult = {}
+        for n in self.idx2name.values():
+            is_weight = n.endswith('_weight')
+
+            if not is_weight:
+                self.wd_mult[n] = 0.0
+
+        if self.sym_info:
+            attr, arg_names = self.sym_info
+            for name in arg_names:
+                if name in attr and '__wd_mult__' in attr[name]:
+                    self.wd_mult[name] = float(attr[name]['__wd_mult__'])
+        self.wd_mult.update(args_wd_mult)
+
+    def create_state_multi_precision(self, index, weight):
+        weight_master_copy = None
+        if self.multi_precision and weight.dtype == numpy.float16:
+            weight_master_copy = weight.astype(numpy.float32)
+            return (self.create_state(index, weight_master_copy), weight_master_copy)
+        if weight.dtype == numpy.float16 and not self.multi_precision:
+            warnings.warn("Accumulating with float16 in optimizer can lead to "
+                          "poor accuracy or slow convergence. "
+                          "Consider using multi_precision=True option of the "
+                          "LARS optimizer")
+        return self.create_state(index, weight)
+
+    def create_state(self, index, weight):
+        momentum = None
+        if self.momentum != 0.0:
+            stype = weight.stype if self.lazy_update else 'default'
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
+        return momentum
+
+    def _l2norm(self, v, rescale=False):
+        """L2 Norm implementation"""
+        v = v.astype('float32')
+        if rescale:
+            v *= self.rescale_grad
+        norm = NDnorm(v).asnumpy()[0]
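
For readers skimming the archive, here is a minimal stand-alone sketch of the layer-wise scaling described in the docstring above (plain NumPy rather than the PR's MXNet code; the function names and values are illustrative, not from the PR):

    # Minimal sketch of the LARS layer-wise learning-rate scaling from the
    # docstring above. Plain NumPy stand-in; names and values are illustrative.
    import numpy as np

    def lars_layer_lr(weight, grad, lr, lr_mult=1.0, eta=0.001, wd=0.0, eps=0.0):
        """Adaptive learning rate for a single layer."""
        w_norm = np.linalg.norm(weight)
        g_norm = np.linalg.norm(grad)
        if w_norm > 0 and g_norm > 0:
            return lr * lr_mult * eta * w_norm / (g_norm + wd * w_norm + eps)
        # bias / batch-norm style parameters keep the plain learning rate
        return lr * lr_mult

    def momentum_correction_factor(cur_lr, last_lr):
        # Hedged reading of momentum_correction (Goyal et al., 2017): rescale the
        # momentum buffer by the ratio of the new and previous global learning rates.
        return cur_lr / last_lr if last_lr else 1.0

    w = np.random.randn(256, 128).astype(np.float32)
    g = 0.01 * np.random.randn(256, 128).astype(np.float32)
    print(lars_layer_lr(w, g, lr=0.1, wd=1e-4))   # noticeably smaller than the base lr of 0.1

Layers whose gradient norm is large relative to their weight norm get a proportionally smaller step, which is what makes very large batch sizes trainable.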
 
 Review comment:
   Is this intended? I thought having a blocking call here is bad.
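
For context on the concern: in MXNet, calling .asnumpy() copies the NDArray to host memory and waits for all pending operations on it, so it synchronizes the Python thread with the backend engine once per call. A small illustration (array sizes are made up):

    # Why .asnumpy() blocks: it forces a device-to-host copy and waits for all
    # pending operations on the array (sizes here are illustrative).
    import mxnet as mx

    x = mx.nd.random.uniform(shape=(1024, 1024))
    y = mx.nd.dot(x, x)      # enqueued asynchronously on the backend engine
    norm = mx.nd.norm(y)     # also asynchronous; returns an NDArray handle immediately

    # This line stalls the Python thread until the norm has actually been computed.
    host_norm = norm.asnumpy()[0]
    print(host_norm)

    # Keeping the value as an NDArray and synchronizing only once later
    # (e.g. with mx.nd.waitall()) avoids a per-layer stall.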

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services