You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") that is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/01/31 17:10:25 UTC

[incubator-mxnet] branch master updated: avoid per-batch blocking in metric (#9636)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3fe694e  avoid per-batch blocking in metric (#9636)
3fe694e is described below

commit 3fe694e7b1ed7fa6a2dcfeddeac44c14ab77b015
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Wed Jan 31 09:10:21 2018 -0800

    avoid per-batch blocking in metric (#9636)
---
 python/mxnet/metric.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index fc2b901..e91fd3b 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -28,6 +28,7 @@ import numpy
 from .base import numeric_types, string_types
 from . import ndarray
 from . import registry
+from .context import cpu
 
 
 def check_label_shapes(labels, preds, shape=0):
@@ -388,6 +389,7 @@ class Accuracy(EvalMetric):
         """
         check_label_shapes(labels, preds)
 
+        results = []
         for label, pred_label in zip(labels, preds):
             if pred_label.shape != label.shape:
                 pred_label = ndarray.argmax(pred_label, axis=self.axis)
@@ -399,8 +401,10 @@ class Accuracy(EvalMetric):
             if pred_label.context != label.context:
                 pred_label = pred_label.as_in_context(label.context)
 
-            self.sum_metric += (pred_label.reshape((-1,)) == label.reshape((-1,))).sum().asscalar()
-            self.num_inst += numpy.prod(pred_label.shape)
+            self.num_inst += pred_label.size
+            results.append((pred_label.reshape((-1,)) == label.reshape((-1,)))
+                           .sum().as_in_context(cpu()))
+        self.sum_metric += ndarray.add_n(*results).asscalar()
 
 
 @register

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.