You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text version).
Posted to commits@spark.apache.org by rx...@apache.org on 2014/07/24 21:37:08 UTC

git commit: [SPARK-2479 (partial)][MLLIB] fix binary metrics unit tests

Repository: spark
Updated Branches:
  refs/heads/master b352ef175 -> c960b5051


[SPARK-2479 (partial)][MLLIB] fix binary metrics unit tests

Allow small floating-point errors when comparing computed metrics against their expected values.

@dbtsai, this unit test is blocking https://github.com/apache/spark/pull/1562 . I may need to merge this one first. We can change it to use the tools in https://github.com/apache/spark/pull/1425 after that PR is merged.

Author: Xiangrui Meng <me...@databricks.com>

Closes #1576 from mengxr/fix-binary-metrics-unit-tests and squashes the following commits:

5076a7f [Xiangrui Meng] fix binary metrics unit tests


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c960b505
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c960b505
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c960b505

Branch: refs/heads/master
Commit: c960b5051853f336fb01ea3f16567b9958baa1b6
Parents: b352ef1
Author: Xiangrui Meng <me...@databricks.com>
Authored: Thu Jul 24 12:37:02 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Thu Jul 24 12:37:02 2014 -0700

----------------------------------------------------------------------
 .../BinaryClassificationMetricsSuite.scala      | 36 +++++++++++++++-----
 1 file changed, 27 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c960b505/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index 9d16182..94db1dc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -20,8 +20,26 @@ package org.apache.spark.mllib.evaluation
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.TestingUtils.DoubleWithAlmostEquals
 
 class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
+
+  // TODO: move utility functions to TestingUtils.
+
+  def elementsAlmostEqual(actual: Seq[Double], expected: Seq[Double]): Boolean = {
+    actual.zip(expected).forall { case (x1, x2) =>
+      x1.almostEquals(x2)
+    }
+  }
+
+  def elementsAlmostEqual(
+      actual: Seq[(Double, Double)],
+      expected: Seq[(Double, Double)])(implicit dummy: DummyImplicit): Boolean = {
+    actual.zip(expected).forall { case ((x1, y1), (x2, y2)) =>
+      x1.almostEquals(x2) && y1.almostEquals(y2)
+    }
+  }
+
   test("binary evaluation metrics") {
     val scoreAndLabels = sc.parallelize(
       Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
@@ -41,14 +59,14 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
     val prCurve = Seq((0.0, 1.0)) ++ pr
     val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) }
     val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
-    assert(metrics.thresholds().collect().toSeq === threshold)
-    assert(metrics.roc().collect().toSeq === rocCurve)
-    assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve))
-    assert(metrics.pr().collect().toSeq === prCurve)
-    assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve))
-    assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1))
-    assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2))
-    assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision))
-    assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall))
+    assert(elementsAlmostEqual(metrics.thresholds().collect(), threshold))
+    assert(elementsAlmostEqual(metrics.roc().collect(), rocCurve))
+    assert(metrics.areaUnderROC().almostEquals(AreaUnderCurve.of(rocCurve)))
+    assert(elementsAlmostEqual(metrics.pr().collect(), prCurve))
+    assert(metrics.areaUnderPR().almostEquals(AreaUnderCurve.of(prCurve)))
+    assert(elementsAlmostEqual(metrics.fMeasureByThreshold().collect(), threshold.zip(f1)))
+    assert(elementsAlmostEqual(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2)))
+    assert(elementsAlmostEqual(metrics.precisionByThreshold().collect(), threshold.zip(precision)))
+    assert(elementsAlmostEqual(metrics.recallByThreshold().collect(), threshold.zip(recall)))
   }
 }