You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/05/11 01:08:50 UTC
[incubator-mxnet] branch master updated: [MXNET-412] Loss update performance (#10892)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 4967feb [MXNET-412] Loss update performance (#10892)
4967feb is described below
commit 4967feb0860b768a5da88ce8c11ea8f41fa458b4
Author: Alexander Zai <az...@gmail.com>
AuthorDate: Thu May 10 18:08:42 2018 -0700
[MXNET-412] Loss update performance (#10892)
* convert pred to list if ndarray is input
* remove unused line
* remove unused line
* test loss.update can accept list and ndarray
* do not store return values when checking shape in accuracy.update
---
python/mxnet/metric.py | 6 +++++-
tests/python/unittest/test_metric.py | 12 +++++++++++-
2 files changed, 16 insertions(+), 2 deletions(-)
diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 76118cc..aa3ab44 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -421,7 +421,7 @@ class Accuracy(EvalMetric):
label = label.flat
pred_label = pred_label.flat
- labels, preds = check_label_shapes(label, pred_label)
+ check_label_shapes(label, pred_label)
self.sum_metric += (pred_label == label).sum()
self.num_inst += len(pred_label)
@@ -1159,6 +1159,10 @@ class Loss(EvalMetric):
name, output_names=output_names, label_names=label_names)
def update(self, _, preds):
+
+ if isinstance(preds, ndarray.ndarray.NDArray):
+ preds = [preds]
+
for pred in preds:
self.sum_metric += ndarray.sum(pred).asscalar()
self.num_inst += pred.size
diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py
index 1571a0b..7bc9c10 100644
--- a/tests/python/unittest/test_metric.py
+++ b/tests/python/unittest/test_metric.py
@@ -32,6 +32,7 @@ def test_metrics():
check_metric('perplexity', -1)
check_metric('pearsonr')
check_metric('nll_loss')
+ check_metric('loss')
composite = mx.metric.create(['acc', 'f1'])
check_metric(composite)
@@ -41,7 +42,6 @@ def test_nll_loss():
label = mx.nd.array([2, 1])
metric.update([label], [pred])
_, loss = metric.get()
- expected_loss = 0.0
expected_loss = -(np.log(pred[0][2].asscalar()) + np.log(pred[1][1].asscalar())) / 2
assert loss == expected_loss
@@ -65,6 +65,16 @@ def test_acc_2d_label():
float(label.asnumpy().ravel().size)
assert acc == expected_acc
+def test_loss_update():
+ pred = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])
+ metric1 = mx.metric.create('loss')
+ metric2 = mx.metric.create('loss')
+ metric1.update(None, [pred])
+ metric2.update(None, pred)
+ _, acc1 = metric1.get()
+ _, acc2 = metric2.get()
+ assert acc1 == acc2
+
def test_f1():
microF1 = mx.metric.create("f1", average="micro")
macroF1 = mx.metric.F1(average="macro")
--
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.