Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/10/13 00:00:55 UTC

[incubator-mxnet] branch master updated: [MXNET-1004] Poisson Negative Log Likelihood loss (#12697)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new efa7d3a  [MXNET-1004] Poisson Negative Log Likelihood loss (#12697)
efa7d3a is described below

commit efa7d3ad96408fe4c5d290dcf19b296ab2ce0fd9
Author: Gaurav Gireesh <Ga...@fox.com>
AuthorDate: Fri Oct 12 17:00:38 2018 -0700

    [MXNET-1004] Poisson Negative Log Likelihood loss (#12697)
    
    * PoissonNLLLoss function to compute negative log likelihood loss
    
    * Removing debugging print statements
    
    * Pylint code formatting problems addressed
    
    * Added Stirling approximation for factorial term in the denominator and test case for the same
    
    * Separated the test cases for the flag values of logits and compute_full
    
    * Added comments for package- numpy inclusion and some pylint formatting
    
    * Trigger CI
    
    * Markdown file updated. Added entry for Poisson NLLLoss
    
    * Fixing pending documentation issue
    
    * Documentation docstring changed
    
    * Addressed PR comment: removed extra newline
    
    * Symbol PI corrected
    
    * epsilon spelling correction
    
    * More unit tests added - testing with mod.score() and mod.fit()
    
    * changed the number of epochs
    
    * PR comments addressed: added mod score tests and a newline
    
    * Empty line added
    
    * Adding hybridized test
    
    * Trigger CI
    
    * Variable names changed
---
 docs/api/python/gluon/loss.md      |  1 +
 python/mxnet/gluon/loss.py         | 63 +++++++++++++++++++++++++++++++++++++-
 tests/python/unittest/test_loss.py | 55 +++++++++++++++++++++++++++++++++
 3 files changed, 118 insertions(+), 1 deletion(-)

diff --git a/docs/api/python/gluon/loss.md b/docs/api/python/gluon/loss.md
index 1aeb340..3747a0f 100644
--- a/docs/api/python/gluon/loss.md
+++ b/docs/api/python/gluon/loss.md
@@ -25,6 +25,7 @@ This package includes several commonly used loss functions in neural networks.
     LogisticLoss
     TripletLoss
     CTCLoss
+    PoissonNLLLoss
 ```
 
 
diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 7e4d345..5d3ebb7 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -23,8 +23,9 @@ __all__ = ['Loss', 'L2Loss', 'L1Loss',
            'SigmoidBinaryCrossEntropyLoss', 'SigmoidBCELoss',
            'SoftmaxCrossEntropyLoss', 'SoftmaxCELoss',
            'KLDivLoss', 'CTCLoss', 'HuberLoss', 'HingeLoss',
-           'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss']
+           'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss', 'PoissonNLLLoss']
 
+import numpy as np
 from .. import ndarray
 from ..base import numeric_types
 from .block import HybridBlock
@@ -706,3 +707,63 @@ class TripletLoss(Loss):
                      axis=self._batch_axis, exclude=True)
         loss = F.relu(loss + self._margin)
         return _apply_weighting(F, loss, self._weight, None)
+
+
+class PoissonNLLLoss(Loss):
+    r"""For a target (Random Variable) in a Poisson distribution, the function calculates the Negative
+    Log likelihood loss.
+    PoissonNLLLoss measures the loss accrued from a poisson regression prediction made by the model.
+
+    .. math::
+        L = \text{pred} - \text{target} * \log(\text{pred}) + \log(\text{target!})
+
+    `pred`, `target` can have arbitrary shape as long as they have the same number of elements.
+
+    Parameters
+    ----------
+    from_logits : boolean, default True
+        Indicates whether the log(predicted) value has already been computed. If True, the loss is computed as
+        :math:`\exp(\text{pred}) - \text{target} * \text{pred}`, and if False, the loss is computed as
+        :math:`\text{pred} - \text{target} * \log(\text{pred}+\text{epsilon})`. The default value is True.
+    weight : float or None
+        Global scalar weight for loss.
+    batch_axis : int, default 0
+        The axis that represents mini-batch.
+    compute_full: boolean, default False
+        Indicates whether to add an approximation (Stirling factor) for the factorial term in the formula for the loss.
+        The Stirling factor is:
+        :math:`\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})`
+    epsilon: float, default 1e-08
+        This is to avoid calculating log(0) which is not defined.
+
+
+    Inputs:
+        - **pred**:   Predicted value
+        - **target**: Random variable (count or number) which belongs to a Poisson distribution.
+        - **sample_weight**: element-wise weighting tensor. Must be broadcastable
+          to the same shape as pred. For example, if pred has shape (64, 10)
+          and you want to weigh each sample in the batch separately,
+          sample_weight should have shape (64, 1).
+
+    Outputs:
+        - **loss**: Average loss (shape=(1,1)) of the loss tensor with shape (batch_size,).
+    """
+    def __init__(self, weight=None, from_logits=True, batch_axis=0, compute_full=False, **kwargs):
+        super(PoissonNLLLoss, self).__init__(weight, batch_axis, **kwargs)
+        self._from_logits = from_logits
+        self._compute_full = compute_full
+
+    def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08):
+        target = _reshape_like(F, target, pred)
+        if self._from_logits:
+            loss = F.exp(pred) - target * pred
+        else:
+            loss = pred - target * F.log(pred + epsilon)
+        if self._compute_full:
+            # Using numpy's pi value
+            stirling_factor = target * F.log(target) - target + 0.5 * F.log(2 * target * np.pi)
+            target_gt_1 = target > 1
+            stirling_factor *= target_gt_1
+            loss += stirling_factor
+        loss = _apply_weighting(F, loss, self._weight, sample_weight)
+        return F.mean(loss)
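
For reference, a minimal usage sketch of the new loss follows. It is not part of this commit and simply assumes an MXNet build that includes the patch above; the tensors and values are illustrative only.

    import mxnet as mx
    from mxnet import gluon

    # With from_logits=True (the default) the predictions are treated as
    # log-rates, so the loss averages exp(pred) - target * pred.
    pred = mx.nd.array([[0.1, 0.5], [0.2, 0.3]])
    target = mx.nd.array([[1.0, 2.0], [0.0, 1.0]])
    loss_fn = gluon.loss.PoissonNLLLoss()
    print(loss_fn(pred, target))          # scalar mean loss

    # With raw (positive) rate predictions, pass from_logits=False;
    # compute_full=True additionally adds the Stirling term for log(target!).
    loss_fn_rates = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
    print(loss_fn_rates(pred + 1.0, target))
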
diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py
index a931a30..2b062fb 100644
--- a/tests/python/unittest/test_loss.py
+++ b/tests/python/unittest/test_loss.py
@@ -348,6 +348,61 @@ def test_triplet_loss():
             optimizer='adam')
     assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05
 
+@with_seed()
+def test_poisson_nllloss():
+    pred = mx.nd.random.normal(shape=(3, 4))
+    min_pred = mx.nd.min(pred)
+    # This is necessary to ensure only positive random values are generated for prediction,
+    # to avoid invalid log calculation
+    pred[:] = pred + mx.nd.abs(min_pred)
+    target = mx.nd.random.normal(shape=(3, 4))
+    min_target = mx.nd.min(target)
+    # This is necessary to ensure only positive random values are generated for the target,
+    # to avoid invalid log calculation
+    target[:] += mx.nd.abs(min_target)
+
+    Loss = gluon.loss.PoissonNLLLoss(from_logits=True)
+    Loss_no_logits = gluon.loss.PoissonNLLLoss(from_logits=False)
+    # Calculating with the brute-force formula for the default value of from_logits = True
+
+    # 1) Testing for flag from_logits = True
+    brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy())
+    loss_withlogits = Loss(pred, target)
+    assert_almost_equal(brute_loss, loss_withlogits.asscalar())
+
+    # 2) Testing for flag from_logits = False
+    loss_no_logits = Loss_no_logits(pred, target)
+    np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08))
+    if np.isnan(loss_no_logits.asscalar()):
+        assert_almost_equal(np.isnan(np_loss_no_logits), np.isnan(loss_no_logits.asscalar()))
+    else:
+        assert_almost_equal(np_loss_no_logits, loss_no_logits.asscalar())
+
+    # 3) Testing for Stirling approximation
+    np_pred = np.random.uniform(1, 5, (2, 3))
+    np_target = np.random.uniform(1, 5, (2, 3))
+    np_compute_full = np.mean((np_pred - np_target * np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target)-\
+     np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)))
+    Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
+    loss_compute_full = Loss_compute_full(mx.nd.array(np_pred), mx.nd.array(np_target))
+    assert_almost_equal(np_compute_full, loss_compute_full.asscalar())
+
+@with_seed()
+def test_poisson_nllloss_mod():
+    N = 1000
+    data = mx.random.poisson(shape=(N, 2))
+    label = mx.random.poisson(lam=4, shape=(N, 1))
+    data_iter = mx.io.NDArrayIter(data, label, batch_size=20, label_name='label', shuffle=True)
+    output = mx.sym.exp(get_net(1))
+    l = mx.symbol.Variable('label')
+    Loss = gluon.loss.PoissonNLLLoss(from_logits=False)
+    loss = Loss(output, l)
+    loss = mx.sym.make_loss(loss)
+    mod = mx.mod.Module(loss, data_names=('data',), label_names=('label',))
+    mod.fit(data_iter, num_epoch=20, optimizer_params={'learning_rate': 0.01},
+            initializer=mx.init.Normal(sigma=0.1), eval_metric=mx.metric.Loss(),
+            optimizer='adam')
+    assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05
 
 if __name__ == '__main__':
     import nose
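
As a companion to the Module-based test above, a Gluon-native training sketch might look like the following. This is not part of the commit; it assumes the same kind of synthetic Poisson data and the standard Gluon APIs (nn.Dense, Trainer, autograd), with hyperparameters chosen only for illustration.

    import mxnet as mx
    from mxnet import autograd, gluon, nd

    # The network predicts a log-rate, which matches from_logits=True,
    # i.e. the exp(pred) - target * pred form of the loss.
    net = gluon.nn.Dense(1)
    net.initialize(mx.init.Normal(sigma=0.1))
    loss_fn = gluon.loss.PoissonNLLLoss(from_logits=True)
    trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 0.01})

    data = nd.random.poisson(shape=(1000, 2))
    label = nd.random.poisson(lam=4, shape=(1000, 1))

    for epoch in range(5):
        with autograd.record():
            loss = loss_fn(net(data), label)   # scalar mean loss
        loss.backward()
        trainer.step(data.shape[0])
        print(epoch, loss.asscalar())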