You are viewing a plain-text version of this content. The canonical link for it (originally a hyperlink on the word "here") was not preserved in this copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/01/15 01:06:10 UTC

[incubator-mxnet] branch master updated: Use a better formula to calculate sigmoid_bce and logistic_loss (#9404)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 134b205  Use a better formula to calculate sigmoid_bce and logistic_loss (#9404)
134b205 is described below

commit 134b2053a40846d07698a07d3408d46c5823c0ff
Author: Xingjian Shi <xs...@ust.hk>
AuthorDate: Sun Jan 14 17:06:01 2018 -0800

    Use a better formula to calculate sigmoid_bce and logistic_loss (#9404)
    
    * use a more stable formula to calculate sigmoid_bce and logistic_loss
    
    * fix docstring
---
 python/mxnet/gluon/loss.py         | 15 ++++++++-------
 tests/python/unittest/test_loss.py | 17 +++++++++++++++++
 2 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 435230e..2be4398 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -229,8 +229,8 @@ class SigmoidBinaryCrossEntropyLoss(Loss):
     def hybrid_forward(self, F, pred, label, sample_weight=None):
         label = _reshape_like(F, label, pred)
         if not self._from_sigmoid:
-            max_val = F.relu(-pred)
-            loss = pred - pred*label + max_val + F.log(F.exp(-max_val)+F.exp(-pred-max_val))
+            # We use the stable formula: max(x, 0) - x * z + log(1 + exp(-abs(x)))
+            loss = F.relu(pred) - pred * label + F.Activation(-F.abs(pred), act_type='softrelu')
         else:
             loss = -(F.log(pred+1e-12)*label + F.log(1.-pred+1e-12)*(1.-label))
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
@@ -635,8 +635,8 @@ class LogisticLoss(Loss):
 
     Inputs:
         - **pred**: prediction tensor with arbitrary shape.
-        - **label**: truth tensor with values -1 or 1. Must have the same size
-          as pred.
+        - **label**: truth tensor with values -1/1 (label_format is 'signed')
+          or 0/1 (label_format is 'binary'). Must have the same size as pred.
         - **sample_weight**: element-wise weighting tensor. Must be broadcastable
           to the same shape as pred. For example, if pred has shape (64, 10)
           and you want to weigh each sample in the batch separately,
@@ -655,9 +655,10 @@ class LogisticLoss(Loss):
 
     def hybrid_forward(self, F, pred, label, sample_weight=None):
         label = _reshape_like(F, label, pred)
-        if self._label_format == 'binary':
-            label = 2 * label - 1  # Transform label to be either -1 or 1
-        loss = F.log(1.0 + F.exp(-pred * label))
+        if self._label_format == 'signed':
+            label = (label + 1.0) / 2.0  # Transform label to be either 0 or 1
+        # Use a stable formula in computation
+        loss = F.relu(pred) - pred * label + F.Activation(-F.abs(pred), act_type='softrelu')
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
         return F.mean(loss, axis=self._batch_axis, exclude=True)
 
diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py
index e044df0..9fb3033 100644
--- a/tests/python/unittest/test_loss.py
+++ b/tests/python/unittest/test_loss.py
@@ -97,6 +97,14 @@ def test_bce_loss():
             eval_metric=mx.metric.Loss(), optimizer='adam',
             initializer=mx.init.Xavier(magnitude=2))
     assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.01
+    # Test against npy
+    data = mx.random.uniform(-5, 5, shape=(10,))
+    label = mx.random.uniform(0, 1, shape=(10,))
+    mx_bce_loss = Loss(data, label).asnumpy()
+    prob_npy = 1.0 / (1.0 + np.exp(-data.asnumpy()))
+    label_npy = label.asnumpy()
+    npy_bce_loss = - label_npy * np.log(prob_npy) - (1 - label_npy) * np.log(1 - prob_npy)
+    assert_almost_equal(mx_bce_loss, npy_bce_loss)
 
 def test_bce_equal_ce2():
     N = 100
@@ -107,6 +115,15 @@ def test_bce_equal_ce2():
     label = mx.nd.round(mx.random.uniform(0, 1, shape=(N, 1)))
     assert_almost_equal(loss1(out1, label).asnumpy(), loss2(out2, label).asnumpy())
 
+def test_logistic_loss_equal_bce():
+    N = 100
+    loss_binary = gluon.loss.LogisticLoss(label_format='binary')
+    loss_signed = gluon.loss.LogisticLoss(label_format='signed')
+    loss_bce = gluon.loss.SigmoidBCELoss(from_sigmoid=False)
+    data = mx.random.uniform(-10, 10, shape=(N, 1))
+    label = mx.nd.round(mx.random.uniform(0, 1, shape=(N, 1)))
+    assert_almost_equal(loss_binary(data, label).asnumpy(), loss_bce(data, label).asnumpy())
+    assert_almost_equal(loss_signed(data, 2 * label - 1).asnumpy(), loss_bce(data, label).asnumpy())
 
 def test_kl_loss():
     np.random.seed(1234)

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.