Posted to commits@mxnet.apache.org by gi...@git.apache.org on 2017/08/25 03:17:44 UTC

[GitHub] smolix commented on a change in pull request #7605: Many loss functions

smolix commented on a change in pull request #7605: Many loss functions 
URL: https://github.com/apache/incubator-mxnet/pull/7605#discussion_r135177021
 
 

 ##########
 File path: python/mxnet/gluon/loss.py
 ##########
 @@ -385,3 +377,511 @@ def hybrid_forward(self, F, data, label,
                                  data_lengths=data_lengths, label_lengths=label_lengths,
                                  padding_mask=self._padding_mask)
         return _apply_weighting(F, loss, self._weight, sample_weight)
+
+class Huber(Loss):
+    """Calculates Huber's robust loss function yielding a trimmed mean estimator, i.e.
+       L2 loss in the center and L1 loss for deviations beyond rho:
+
+    .. math::
+        L = \\begin{cases} \\frac{1}{2 \\rho} ({output}_i - {label}_i)^2 &
+                           \\text{ if } |{output}_i - {label}_i| < \\rho \\\\
+                           |{output}_i - {label}_i| - \\frac{\\rho}{2} &
+                           \\text{ otherwise }
+            \\end{cases}
+
+    Output and label must have the same shape. This is a scalar loss function.
+
+    Parameters
+    ----------
+    rho : float
+        Threshold for the trimmed mean estimator; deviations beyond rho
+        are penalized linearly. Defaults to 1.
+    weight : float or None
+        Global scalar weight for loss.
+    sample_weight : Symbol or None
+        Per sample weighting. Must be broadcastable to
+        the same shape as loss. For example, if loss has
+        shape (64, 10) and you want to weight each sample
+        in the batch, `sample_weight` should have shape (64, 1).
+    batch_axis : int, default 0
+        The axis that represents mini-batch.
+    """
+    def __init__(self, rho=1, weight=None, batch_axis=0, **kwargs):
+        super(Huber, self).__init__(weight, batch_axis, **kwargs)
+        self._rho = rho
+
+    def hybrid_forward(self, F, output, label, sample_weight=None):
+        label = _reshape_label_as_output(F, output, label)
+        loss = F.abs(output - label)
+        # Quadratic penalty inside the rho band, linear penalty outside it;
+        # the two masks are disjoint, so exactly one term is active per element.
+        loss = ((loss > self._rho) * (loss - 0.5 * self._rho) +
+                (0.5 / self._rho) * (loss <= self._rho) * loss ** 2)
+        loss = _apply_weighting(F, loss, self._weight, sample_weight)
+        return F.mean(loss, axis=self._batch_axis, exclude=True)
 
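 For context, a minimal usage sketch of the class above, assuming it ships
 as mxnet.gluon.loss.Huber with the constructor shown in the diff (an
 illustration only, not part of the patch):

     # Hedged usage sketch; assumes the Huber class from the diff is
     # exposed under mxnet.gluon.loss once the PR is merged.
     import mxnet as mx
     from mxnet.gluon import loss as gloss

     huber = gloss.Huber(rho=1.0)        # quadratic for |error| <= 1, linear beyond
     output = mx.nd.array([0.1, 2.5, -3.0])
     label = mx.nd.array([0.0, 0.0, 0.0])
     # Per-sample losses: [0.5*0.1^2, 2.5-0.5, 3.0-0.5] = [0.005, 2.0, 2.5]
     print(huber(output, label))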
 Review comment:
   Elegant! Now I learned that we have a where operator ...
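  
  For reference, the masked-sum expression in the diff could also be written
  with the where operator alluded to above: F.where(cond, x, y) selects x
  elementwise where cond is true and y elsewhere. A hedged sketch of that
  equivalent formulation, not the code that was merged:

      # Equivalent branch selection via F.where (sketch; the PR's final
      # code may differ).
      err = F.abs(output - label)
      loss = F.where(err > self._rho,
                     err - 0.5 * self._rho,           # L1 tail beyond rho
                     (0.5 / self._rho) * err ** 2)    # L2 region within rho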
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services