Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/05/11 01:08:50 UTC

[incubator-mxnet] branch master updated: [MXNET-412] Loss update performance (#10892)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4967feb  [MXNET-412] Loss update performance (#10892)
4967feb is described below

commit 4967feb0860b768a5da88ce8c11ea8f41fa458b4
Author: Alexander Zai <az...@gmail.com>
AuthorDate: Thu May 10 18:08:42 2018 -0700

    [MXNET-412] Loss update performance (#10892)
    
    * convert pred to a list if an NDArray is passed as input (see the sketch below)

    * remove unused line

    * remove unused line

    * test that loss.update accepts both a list and an NDArray

    * do not store return values when checking shapes in Accuracy.update
---
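A minimal usage sketch of the behavior this commit enables (not part of the
commit itself; it assumes a standard `import mxnet as mx` and uses only the
public calls exercised by the new test further below):

    import mxnet as mx

    # Hypothetical values; any NDArray of loss outputs works here.
    pred = mx.nd.array([[0.3, 0.7], [0.4, 0.6]])

    metric = mx.metric.create('loss')
    metric.update(None, [pred])   # list of NDArrays: supported before this commit
    metric.update(None, pred)     # bare NDArray: accepted as of this commit
    print(metric.get())           # ('loss', mean of all values seen so far)
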
 python/mxnet/metric.py               |  6 +++++-
 tests/python/unittest/test_metric.py | 12 +++++++++++-
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 76118cc..aa3ab44 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -421,7 +421,7 @@ class Accuracy(EvalMetric):
             label = label.flat
             pred_label = pred_label.flat
 
-            labels, preds = check_label_shapes(label, pred_label)
+            check_label_shapes(label, pred_label)
 
             self.sum_metric += (pred_label == label).sum()
             self.num_inst += len(pred_label)
@@ -1159,6 +1159,10 @@ class Loss(EvalMetric):
             name, output_names=output_names, label_names=label_names)
 
     def update(self, _, preds):
+
+        if isinstance(preds, ndarray.ndarray.NDArray):
+            preds = [preds]
+
         for pred in preds:
             self.sum_metric += ndarray.sum(pred).asscalar()
             self.num_inst += pred.size
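
A note on the two metric.py hunks above before moving to the tests: in
Accuracy.update, check_label_shapes is now called only for its side effect of
validating shapes, since the previously stored return values went unused; in
Loss.update, a bare NDArray is wrapped in a one-element list before iteration.
Reconstructed from the hunk context above (the enclosing class body is elided
in the diff), the updated method reads roughly:

    def update(self, _, preds):
        # Labels are ignored; wrap a single NDArray so the loop below
        # can treat both call styles uniformly.
        if isinstance(preds, ndarray.ndarray.NDArray):
            preds = [preds]

        for pred in preds:
            # Running totals; EvalMetric.get() reports sum_metric / num_inst.
            self.sum_metric += ndarray.sum(pred).asscalar()
            self.num_inst += pred.size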
diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py
index 1571a0b..7bc9c10 100644
--- a/tests/python/unittest/test_metric.py
+++ b/tests/python/unittest/test_metric.py
@@ -32,6 +32,7 @@ def test_metrics():
     check_metric('perplexity', -1)
     check_metric('pearsonr')
     check_metric('nll_loss')
+    check_metric('loss')
     composite = mx.metric.create(['acc', 'f1'])
     check_metric(composite)
 
@@ -41,7 +42,6 @@ def test_nll_loss():
     label = mx.nd.array([2, 1])
     metric.update([label], [pred])
     _, loss = metric.get()
-    expected_loss = 0.0
     expected_loss = -(np.log(pred[0][2].asscalar()) + np.log(pred[1][1].asscalar())) / 2
     assert loss == expected_loss
 
@@ -65,6 +65,16 @@ def test_acc_2d_label():
                    float(label.asnumpy().ravel().size)
     assert acc == expected_acc
 
+def test_loss_update():
+    pred = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])
+    metric1 = mx.metric.create('loss')
+    metric2 = mx.metric.create('loss')
+    metric1.update(None, [pred])
+    metric2.update(None, pred)
+    _, acc1 = metric1.get()
+    _, acc2 = metric2.get()
+    assert acc1 == acc2
+
 def test_f1():
     microF1 = mx.metric.create("f1", average="micro")
     macroF1 = mx.metric.F1(average="macro")
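
Since the new test only exercises a single NDArray versus a one-element list,
it may help to note that the list path continues to accept multiple output
arrays per call. A hypothetical example with the resulting mean worked out
(values are made up):

    import mxnet as mx

    metric = mx.metric.create('loss')
    a = mx.nd.array([0.5, 1.5])
    b = mx.nd.array([[1.0, 2.0]])
    metric.update(None, [a, b])   # multiple output arrays in one call
    # get() -> ('loss', (0.5 + 1.5 + 1.0 + 2.0) / 4) == ('loss', 1.25)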

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.