You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by li...@apache.org on 2017/11/22 01:44:18 UTC

[incubator-mxnet] branch master updated: [scala] EvalMetric sumMetric is now a Double instead of a Float (#8297)

This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8df20a2  [scala] EvalMetric sumMetric is now a Double instead of a Float (#8297)
8df20a2 is described below

commit 8df20a2bd074c4ab55a9b61e0ec04da48bec6426
Author: Benoît Quartier <be...@a3.epfl.ch>
AuthorDate: Wed Nov 22 02:44:15 2017 +0100

    [scala] EvalMetric sumMetric is now a Double instead of a Float (#8297)
    
    When the difference in magnitude between the total
    accuracy and 1 becomes too big, the accuracy is not updated anymore
    due to the low precision of float numbers.
---
 .../core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala    | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
index 6b993d7..98a09d2 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
@@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer
 abstract class EvalMetric(protected val name: String) {
 
   protected var numInst: Int = 0
-  protected var sumMetric: Float = 0.0f
+  protected var sumMetric: Double = 0.0d
 
   /**
    * Update the internal evaluation.
@@ -41,7 +41,7 @@ abstract class EvalMetric(protected val name: String) {
    */
   def reset(): Unit = {
     this.numInst = 0
-    this.sumMetric = 0.0f
+    this.sumMetric = 0.0d
   }
 
   /**
@@ -50,7 +50,7 @@ abstract class EvalMetric(protected val name: String) {
    *         value, Value of the evaluation
    */
   def get: (Array[String], Array[Float]) = {
-    (Array(this.name), Array(this.sumMetric / this.numInst))
+    (Array(this.name), Array((this.sumMetric / this.numInst).toFloat))
   }
 }
 
@@ -111,11 +111,10 @@ class Accuracy extends EvalMetric("accuracy") {
       require(label.shape == predLabel.shape,
         s"label ${label.shape} and prediction ${predLabel.shape}" +
         s"should have the same length.")
-      for ((labelElem, predElem) <- label.toArray zip predLabel.toArray) {
-        if (labelElem == predElem) {
-          this.sumMetric += 1
-        }
-      }
+
+      this.sumMetric += label.toArray.zip(predLabel.toArray)
+        .filter{ case (labelElem: Float, predElem: Float) => labelElem == predElem }
+        .size
       this.numInst += predLabel.shape(0)
       predLabel.dispose()
     }

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].