You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by sx...@apache.org on 2020/08/08 02:57:14 UTC

[incubator-mxnet] branch master updated: [Numpy][Bugfix] Add hybridization test to loss layers (#18876)

This is an automated email from the ASF dual-hosted git repository.

sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cf908fd  [Numpy][Bugfix] Add hybridization test to loss layers (#18876)
cf908fd is described below

commit cf908fd48384b8b0d3f2ecea1495a007b9aedfdb
Author: Xingjian Shi <xs...@connect.ust.hk>
AuthorDate: Fri Aug 7 19:55:36 2020 -0700

    [Numpy][Bugfix] Add hybridization test to loss layers (#18876)
    
    * Test for hybridization
    
    * fix typo
    
    * fix
    
    * fix test
    
    * update
    
    * Update loss.py
    
    * fix bug of sum
---
 python/mxnet/gluon/loss.py               | 45 ++++++++-------
 python/mxnet/symbol/numpy/_symbol.py     |  2 +-
 tests/python/unittest/test_numpy_loss.py | 94 +++++++++++++++++++++++++++++---
 3 files changed, 114 insertions(+), 27 deletions(-)

diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index bc447b0..75d8981 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -77,18 +77,28 @@ def _reshape_like(F, x, y):
 def _batch_mean(F, loss, batch_axis):
     """Return mean on the specified batch axis, not keeping the axis"""
     if is_np_array():
-        axes = list(range(loss.ndim))
-        del axes[batch_axis]
-        return F.np.mean(loss, axis=axes)
+        if F is ndarray:
+            axes = list(range(loss.ndim))
+            del axes[batch_axis]
+            return F.np.mean(loss, axis=axes)
+        else:
+            assert batch_axis == 0, 'Currently, we have not supported the "exclude" ' \
+                                    'flag in mean. So we only support batch_axis=0.'
+            return F.npx.batch_flatten(loss).mean(axis=1)
     else:
         return F.mean(loss, axis=batch_axis, exclude=True)
 
 def _batch_sum(F, loss, batch_axis):
     """Return sum on the specified batch axis, not keeping the axis"""
     if is_np_array():
-        axes = list(range(loss.ndim))
-        del axes[batch_axis]
-        return F.np.sum(loss, axis=axes)
+        if F is ndarray:
+            axes = list(range(loss.ndim))
+            del axes[batch_axis]
+            return F.np.sum(loss, axis=axes)
+        else:
+            assert batch_axis == 0, 'Currently, we have not supported the "exclude" ' \
+                                    'flag in mean. So we only support batch_axis=0.'
+            return F.npx.batch_flatten(loss).sum(axis=1)
     else:
         return F.sum(loss, axis=batch_axis, exclude=True)
 
@@ -899,8 +909,8 @@ class PoissonNLLLoss(Loss):
             stirling_factor = target * \
                 log_fn(target) - target + 0.5 * log_fn(2 * target * np.pi)
             target_gt_1 = target > 1
-            stirling_factor *= target_gt_1
-            loss += stirling_factor
+            stirling_factor = stirling_factor * target_gt_1
+            loss = loss + stirling_factor
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
         return _batch_mean(F, loss, self._batch_axis)
 
@@ -1023,7 +1033,8 @@ class SDMLLoss(Loss):
     def __init__(self, smoothing_parameter=0.3, weight=1., batch_axis=0, **kwargs):
         super(SDMLLoss, self).__init__(weight, batch_axis, **kwargs)
         self.kl_loss = KLDivLoss(from_logits=True)
-        self.smoothing_parameter = smoothing_parameter # Smoothing probability mass
+        # Smoothing probability mass
+        self.smoothing_parameter = smoothing_parameter
 
     def _compute_distances(self, F, x1, x2):
         """
@@ -1032,17 +1043,13 @@ class SDMLLoss(Loss):
         """
         if is_np_array():
             expand_dims_fn = F.np.expand_dims
-            broadcast_to_fn = F.np.broadcast_to
         else:
             expand_dims_fn = F.expand_dims
-            broadcast_to_fn = F.broadcast_to
-
-        # extracting sizes expecting [batch_size, dim]
-        assert x1.shape == x2.shape
-        batch_size, dim = x1.shape
-        # expanding both tensor form [batch_size, dim] to [batch_size, batch_size, dim]
-        x1_ = broadcast_to_fn(expand_dims_fn(x1, 1), [batch_size, batch_size, dim])
-        x2_ = broadcast_to_fn(expand_dims_fn(x2, 0), [batch_size, batch_size, dim])
+
+        # expanding x1 form [batch_size, dim] to [batch_size, 1, dim]
+        # and x2 to [1, batch_size, dim]
+        x1_ = expand_dims_fn(x1, 1)
+        x2_ = expand_dims_fn(x2, 0)
         # pointwise squared differences
         squared_diffs = (x1_ - x2_)**2
         # sum of squared differences distance
@@ -1073,7 +1080,6 @@ class SDMLLoss(Loss):
         labels = gold * (1 - self.smoothing_parameter) + (1 - gold) * self.smoothing_parameter / (batch_size - 1)
         return labels
 
-
     def hybrid_forward(self, F, x1, x2):
         """
         the function computes the kl divergence between the negative distances
@@ -1092,6 +1098,7 @@ class SDMLLoss(Loss):
         learn to predict french president comparing it with all the other
         vectors in batch 2
         """
+        assert F is ndarray, 'SDMLLoss does not support symbolic '
         if is_np_array():
             log_softmax_fn = F.npx.log_softmax
         else:
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 2df7357..ee46544 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -8012,7 +8012,7 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
 
 # pylint:disable=redefined-outer-name, too-many-arguments
 @set_module('mxnet.symbol.numpy')
-def sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None):
+def sum(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None):
     r"""
     Sum of array elements over a given axis.
 
diff --git a/tests/python/unittest/test_numpy_loss.py b/tests/python/unittest/test_numpy_loss.py
index 6c63546..14f46f0 100644
--- a/tests/python/unittest/test_numpy_loss.py
+++ b/tests/python/unittest/test_numpy_loss.py
@@ -20,57 +20,94 @@ import numpy as np
 from mxnet import gluon, autograd
 from mxnet.test_utils import assert_almost_equal, default_context, use_np
 from common import setup_module, with_seed, teardown_module, xfail_when_nonstandard_decimal_separator
-import unittest
+import pytest
 
 
 @xfail_when_nonstandard_decimal_separator
 @with_seed()
 @use_np
-def test_loss_np_ndarray():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_loss_np_ndarray(hybridize):
     output = mx.np.array([1, 2, 3, 4])
     label = mx.np.array([1, 3, 5, 7])
     weighting = mx.np.array([0.5, 1, 0.5, 1])
 
     loss = gluon.loss.L1Loss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 6.
     loss = gluon.loss.L1Loss(weight=0.5)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 3.
     loss = gluon.loss.L1Loss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 5.
 
     loss = gluon.loss.L2Loss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 7.
     loss = gluon.loss.L2Loss(weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 1.75
     loss = gluon.loss.L2Loss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 6
 
     loss = gluon.loss.HuberLoss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 4.5
     loss = gluon.loss.HuberLoss(weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 1.125
     loss = gluon.loss.HuberLoss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 3.75
 
     loss = gluon.loss.HingeLoss(margin=10)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 13.
     loss = gluon.loss.HingeLoss(margin=8, weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 2.25
     loss = gluon.loss.HingeLoss(margin=7)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 4.
 
     loss = gluon.loss.SquaredHingeLoss(margin=10)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 97.
     loss = gluon.loss.SquaredHingeLoss(margin=8, weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 13.25
     loss = gluon.loss.SquaredHingeLoss(margin=7)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 19.
 
     loss = gluon.loss.TripletLoss(margin=10)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, -label)) == 6.
     loss = gluon.loss.TripletLoss(margin=8, weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, -label)) == 1.
     loss = gluon.loss.TripletLoss(margin=7)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, -label, weighting)) == 1.5
 
     output = mx.np.array([[0, 2], [1, 4]])
@@ -78,30 +115,48 @@ def test_loss_np_ndarray():
     weighting = mx.np.array([[0.5], [1.0]])
 
     loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    if hybridize:
+        loss.hybridize()
     L = loss(output, label).asnumpy()
     assert_almost_equal(L, np.array([ 2.12692809,  0.04858733]), rtol=1e-3, atol=1e-4)
 
+    loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    if hybridize:
+        loss.hybridize()
     L = loss(output, label, weighting).asnumpy()
     assert_almost_equal(L, np.array([ 1.06346405,  0.04858733]), rtol=1e-3, atol=1e-4)
 
 
 @with_seed()
 @use_np
-def test_bce_equal_ce2():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_bce_equal_ce2(hybridize):
     N = 100
     loss1 = gluon.loss.SigmoidBCELoss(from_sigmoid=True)
+    if hybridize:
+        loss1.hybridize()
     loss2 = gluon.loss.SoftmaxCELoss(from_logits=True)
+    if hybridize:
+        loss2.hybridize()
     out1 = mx.np.random.uniform(0.1, 0.9, size=(N, 1))
     out2 = mx.np.log(mx.np.concatenate((1-out1, out1), axis=1) + 1e-8)
     label = mx.np.round(mx.np.random.uniform(0, 1, size=(N, 1)))
     assert_almost_equal(loss1(out1, label).asnumpy(), loss2(out2, label).asnumpy())
 
+
 @use_np
-def test_logistic_loss_equal_bce():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_logistic_loss_equal_bce(hybridize):
     N = 100
     loss_binary = gluon.loss.LogisticLoss(label_format='binary')
+    if hybridize:
+        loss_binary.hybridize()
     loss_signed = gluon.loss.LogisticLoss(label_format='signed')
+    if hybridize:
+        loss_signed.hybridize()
     loss_bce = gluon.loss.SigmoidBCELoss(from_sigmoid=False)
+    if hybridize:
+        loss_bce.hybridize()
     data = mx.np.random.uniform(-10, 10, size=(N, 1))
     label = mx.np.round(mx.np.random.uniform(0, 1, size=(N, 1)))
     assert_almost_equal(loss_binary(data, label), loss_bce(data, label), atol=1e-6)
@@ -110,28 +165,41 @@ def test_logistic_loss_equal_bce():
 
 @with_seed()
 @use_np
-def test_ctc_loss():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_ctc_loss(hybridize):
     loss = gluon.loss.CTCLoss()
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((2,20,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))
 
     loss = gluon.loss.CTCLoss(layout='TNC')
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((20,2,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))
 
     loss = gluon.loss.CTCLoss(layout='TNC', label_layout='TN')
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((20,2,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]).T)
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))
 
     loss = gluon.loss.CTCLoss()
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((2,20,4)), mx.np.array([[2,1,2,2],[3,2,2,2]]), None, mx.np.array([2,3]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))
 
     loss = gluon.loss.CTCLoss()
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((2,25,4)), mx.np.array([[2,1,-1,-1],[3,2,2,-1]]), mx.np.array([20,20]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))
 
     loss = gluon.loss.CTCLoss()
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((2,25,4)), mx.np.array([[2,1,3,3],[3,2,2,3]]), mx.np.array([20,20]), mx.np.array([2,3]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))
 
@@ -171,15 +239,19 @@ def test_sdml_loss():
     avg_loss = loss.sum()/len(loss)
     assert(avg_loss < 0.05)
 
+
 @with_seed()
 @use_np
-def test_cosine_loss():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_cosine_loss(hybridize):
     #Generating samples
     input1 = mx.np.random.randn(3, 2)
     input2 = mx.np.random.randn(3, 2)
     label = mx.np.sign(mx.np.random.randn(input1.shape[0]))
     #Calculating loss from cosine embedding loss function in Gluon
     Loss = gluon.loss.CosineEmbeddingLoss()
+    if hybridize:
+        Loss.hybridize()
     loss = Loss(input1, input2, label)
 
     # Calculating the loss Numpy way
@@ -192,9 +264,11 @@ def test_cosine_loss():
         mx.np.where(label == 1, 1-x, mx.npx.relu(x)), (-1,))
     assert_almost_equal(loss.asnumpy(), numpy_loss.asnumpy(), rtol=1e-3, atol=1e-5)
 
+
 @xfail_when_nonstandard_decimal_separator
 @use_np
-def test_poisson_nllloss():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_poisson_nllloss(hybridize):
     shape=(3, 4)
     not_axis0 = tuple(range(1, len(shape)))
     pred = mx.np.random.normal(size=shape)
@@ -209,7 +283,11 @@ def test_poisson_nllloss():
     target[:] += mx.np.abs(min_target)
 
     Loss = gluon.loss.PoissonNLLLoss(from_logits=True)
+    if hybridize:
+        Loss.hybridize()
     Loss_no_logits = gluon.loss.PoissonNLLLoss(from_logits=False)
+    if hybridize:
+        Loss_no_logits.hybridize()
     #Calculating by brute formula for default value of from_logits = True
 
     # 1) Testing for flag logits = True
@@ -230,6 +308,8 @@ def test_poisson_nllloss():
     np_compute_full = mx.np.mean((np_pred - np_target * mx.np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target)-\
      np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)), axis=1)
     Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
+    if hybridize:
+        Loss_compute_full.hybridize()
     loss_compute_full = Loss_compute_full(np_pred, np_target)
     assert_almost_equal(np_compute_full, loss_compute_full)