You are viewing a plain-text version of this content; the canonical link to it (originally a "here" hyperlink) is not preserved in this version.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/01/31 17:10:25 UTC
[incubator-mxnet] branch master updated: avoid per-batch blocking in metric (#9636)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 3fe694e avoid per-batch blocking in metric (#9636)
3fe694e is described below
commit 3fe694e7b1ed7fa6a2dcfeddeac44c14ab77b015
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Wed Jan 31 09:10:21 2018 -0800
avoid per-batch blocking in metric (#9636)
---
python/mxnet/metric.py | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index fc2b901..e91fd3b 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -28,6 +28,7 @@ import numpy
from .base import numeric_types, string_types
from . import ndarray
from . import registry
+from .context import cpu
def check_label_shapes(labels, preds, shape=0):
@@ -388,6 +389,7 @@ class Accuracy(EvalMetric):
"""
check_label_shapes(labels, preds)
+ results = []
for label, pred_label in zip(labels, preds):
if pred_label.shape != label.shape:
pred_label = ndarray.argmax(pred_label, axis=self.axis)
@@ -399,8 +401,10 @@ class Accuracy(EvalMetric):
if pred_label.context != label.context:
pred_label = pred_label.as_in_context(label.context)
- self.sum_metric += (pred_label.reshape((-1,)) == label.reshape((-1,))).sum().asscalar()
- self.num_inst += numpy.prod(pred_label.shape)
+ self.num_inst += pred_label.size
+ results.append((pred_label.reshape((-1,)) == label.reshape((-1,)))
+ .sum().as_in_context(cpu()))
+ self.sum_metric += ndarray.add_n(*results).asscalar()
@register
--
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.