You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/13 00:00:40 UTC

[GitHub] sandeep-krishnamurthy closed pull request #12697: [MXNET -1004] Poisson NegativeLog Likelihood loss

sandeep-krishnamurthy closed pull request #12697: [MXNET -1004] Poisson NegativeLog Likelihood loss
URL: https://github.com/apache/incubator-mxnet/pull/12697
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/docs/api/python/gluon/loss.md b/docs/api/python/gluon/loss.md
index 1aeb340a3db..3747a0f89bf 100644
--- a/docs/api/python/gluon/loss.md
+++ b/docs/api/python/gluon/loss.md
@@ -25,6 +25,7 @@ This package includes several commonly used loss functions in neural networks.
     LogisticLoss
     TripletLoss
     CTCLoss
+    PoissonNLLLoss
 ```
 
 
diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 2be43981a64..fb5dd9d610c 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -23,8 +23,9 @@
            'SigmoidBinaryCrossEntropyLoss', 'SigmoidBCELoss',
            'SoftmaxCrossEntropyLoss', 'SoftmaxCELoss',
            'KLDivLoss', 'CTCLoss', 'HuberLoss', 'HingeLoss',
-           'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss']
+           'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss', 'PoissonNLLLoss']
 
+import numpy as np
 from .. import ndarray
 from ..base import numeric_types
 from .block import HybridBlock
@@ -706,3 +707,63 @@ def hybrid_forward(self, F, pred, positive, negative):
                      axis=self._batch_axis, exclude=True)
         loss = F.relu(loss + self._margin)
         return _apply_weighting(F, loss, self._weight, None)
+
+
+class PoissonNLLLoss(Loss):
+    r"""For a target (Random Variable) in a Poisson distribution, the function calculates the Negative
+    Log likelihood loss.
+    PoissonNLLLoss measures the loss accrued from a poisson regression prediction made by the model.
+
+    .. math::
+        L = \text{pred} - \text{target} * \log(\text{pred}) +\log(\text{target!})
+
+    `pred`, `target` can have arbitrary shape as long as they have the same number of elements.
+
+    Parameters
+    ----------
+    from_logits : boolean, default True
+        indicating whether log(predicted) value has already been computed. If True, the loss is computed as
+        :math:`\exp(\text{pred}) - \text{target} * \text{pred}`, and if False, then loss is computed as
+        :math:`\text{pred} - \text{target} * \log(\text{pred}+\text{epsilon})`.The default value
+    weight : float or None
+        Global scalar weight for loss.
+    batch_axis : int, default 0
+        The axis that represents mini-batch.
+    compute_full: boolean, default False
+        Indicates whether to add an approximation(Stirling factor) for the Factorial term in the formula for the loss.
+        The Stirling factor is:
+        :math:`\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})`
+    epsilon: float, default 1e-08
+        This is to avoid calculating log(0) which is not defined.
+
+
+    Inputs:
+        - **pred**:   Predicted value
+        - **target**: Random variable(count or number) which belongs to a Poisson distribution.
+        - **sample_weight**: element-wise weighting tensor. Must be broadcastable
+          to the same shape as pred. For example, if pred has shape (64, 10)
+          and you want to weigh each sample in the batch separately,
+          sample_weight should have shape (64, 1).
+
+    Outputs:
+        - **loss**: Average loss (shape=(1,1)) of the loss tensor with shape (batch_size,).
+    """
+    def __init__(self, weight=None, from_logits=True, batch_axis=0, compute_full=False, **kwargs):
+        super(PoissonNLLLoss, self).__init__(weight, batch_axis, **kwargs)
+        self._from_logits = from_logits
+        self._compute_full = compute_full
+
+    def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08):
+        target = _reshape_like(F, target, pred)
+        if self._from_logits:
+            loss = F.exp(pred) - target * pred
+        else:
+            loss = pred - target * F.log(pred + epsilon)
+        if self._compute_full:
+            # Using numpy's pi value
+            stirling_factor = target * F.log(target)- target + 0.5 * F.log(2 * target * np.pi)
+            target_gt_1 = target > 1
+            stirling_factor *= target_gt_1
+            loss += stirling_factor
+        loss = _apply_weighting(F, loss, self._weight, sample_weight)
+        return F.mean(loss)
diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py
index 24cc747a308..dc0a6359bb2 100644
--- a/tests/python/unittest/test_loss.py
+++ b/tests/python/unittest/test_loss.py
@@ -348,6 +348,61 @@ def test_triplet_loss():
             optimizer='adam')
     assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05
 
+@with_seed()
+def test_poisson_nllloss():
+    """Check PoissonNLLLoss against numpy formulas: from_logits=True/False and compute_full=True."""
+    pred = mx.nd.random.normal(shape=(3, 4))
+    min_pred = mx.nd.min(pred)
+    # Shift by |min| so that no prediction is negative,
+    # to avoid an invalid log calculation
+    pred[:] = pred + mx.nd.abs(min_pred)
+    target = mx.nd.random.normal(shape=(3, 4))
+    min_target = mx.nd.min(target)
+    # Same shift for the target, so that no target value is negative
+    target[:] += mx.nd.abs(min_target)
+
+    Loss = gluon.loss.PoissonNLLLoss(from_logits=True)
+    Loss_no_logits = gluon.loss.PoissonNLLLoss(from_logits=False)
+
+    # 1) from_logits=True: compare against the brute-force formula exp(pred) - target * pred
+    brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy())
+    loss_withlogits = Loss(pred, target)
+    assert_almost_equal(brute_loss, loss_withlogits.asscalar())
+
+    # 2) from_logits=False: compare against pred - target * log(pred + epsilon)
+    loss_no_logits = Loss_no_logits(pred, target)
+    np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08))
+    # If the loss is NaN, only check that the numpy reference is NaN as well
+    if np.isnan(loss_no_logits.asscalar()):
+        assert_almost_equal(np.isnan(np_loss_no_logits), np.isnan(loss_no_logits.asscalar()))
+    else:
+        assert_almost_equal(np_loss_no_logits, loss_no_logits.asscalar())
+
+    # 3) compute_full=True: Stirling approximation, with pred and target drawn from [1, 5)
+    np_pred = np.random.uniform(1, 5, (2, 3))
+    np_target = np.random.uniform(1, 5, (2, 3))
+    np_compute_full = np.mean((np_pred - np_target * np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target)-\
+     np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)))
+    Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
+    loss_compute_full = Loss_compute_full(mx.nd.array(np_pred), mx.nd.array(np_target))
+    assert_almost_equal(np_compute_full, loss_compute_full.asscalar())
+
+@with_seed()
+def test_poisson_nllloss_mod():  # fits a Module on Poisson-sampled data using PoissonNLLLoss
+    N = 1000
+    data = mx.random.poisson(shape=(N, 2))
+    label = mx.random.poisson(lam=4, shape=(N, 1))
+    data_iter = mx.io.NDArrayIter(data, label, batch_size=20, label_name='label', shuffle=True)
+    output = mx.sym.exp(get_net(1))  # exp makes the prediction positive; the loss takes log(pred)
+    l = mx.symbol.Variable('label')
+    Loss = gluon.loss.PoissonNLLLoss(from_logits=False)  # pred is passed as a value, not as a log
+    loss = Loss(output, l)
+    loss = mx.sym.make_loss(loss)
+    mod = mx.mod.Module(loss, data_names=('data',), label_names=('label',))
+    mod.fit(data_iter, num_epoch=20, optimizer_params={'learning_rate': 0.01},
+            initializer=mx.init.Normal(sigma=0.1), eval_metric=mx.metric.Loss(),
+            optimizer='adam')
+    assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05  # NOTE(review): same bound as test_triplet_loss
 
 if __name__ == '__main__':
     import nose


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services