You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by li...@apache.org on 2017/11/22 01:44:18 UTC
[incubator-mxnet] branch master updated: [scala] EvalMetric
sumMetric is now a Double instead of a Float (#8297)
This is an automated email from the ASF dual-hosted git repository.
liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8df20a2 [scala] EvalMetric sumMetric is now a Double instead of a Float (#8297)
8df20a2 is described below
commit 8df20a2bd074c4ab55a9b61e0ec04da48bec6426
Author: Benoît Quartier <be...@a3.epfl.ch>
AuthorDate: Wed Nov 22 02:44:15 2017 +0100
[scala] EvalMetric sumMetric is now a Double instead of a Float (#8297)
This fixes the case where the difference in magnitude between the
accumulated accuracy sum and 1 becomes too large: the accuracy then stops
being updated, because adding 1 to a large single-precision Float no longer
changes its value.
---
.../core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala | 15 +++++++--------
1 file changed, 7 insertions(+), 8 deletions(-)
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
index 6b993d7..98a09d2 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
@@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer
abstract class EvalMetric(protected val name: String) {
protected var numInst: Int = 0
- protected var sumMetric: Float = 0.0f
+ protected var sumMetric: Double = 0.0d
/**
* Update the internal evaluation.
@@ -41,7 +41,7 @@ abstract class EvalMetric(protected val name: String) {
*/
def reset(): Unit = {
this.numInst = 0
- this.sumMetric = 0.0f
+ this.sumMetric = 0.0d
}
/**
@@ -50,7 +50,7 @@ abstract class EvalMetric(protected val name: String) {
* value, Value of the evaluation
*/
def get: (Array[String], Array[Float]) = {
- (Array(this.name), Array(this.sumMetric / this.numInst))
+ (Array(this.name), Array((this.sumMetric / this.numInst).toFloat))
}
}
@@ -111,11 +111,10 @@ class Accuracy extends EvalMetric("accuracy") {
require(label.shape == predLabel.shape,
s"label ${label.shape} and prediction ${predLabel.shape}" +
s"should have the same length.")
- for ((labelElem, predElem) <- label.toArray zip predLabel.toArray) {
- if (labelElem == predElem) {
- this.sumMetric += 1
- }
- }
+
+ this.sumMetric += label.toArray.zip(predLabel.toArray)
+ .filter{ case (labelElem: Float, predElem: Float) => labelElem == predElem }
+ .size
this.numInst += predLabel.shape(0)
predLabel.dispose()
}
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].