You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2017/08/30 10:36:03 UTC

spark git commit: [SPARK-21806][MLLIB] BinaryClassificationMetrics pr(): first point (0.0, 1.0) is misleading

Repository: spark
Updated Branches:
  refs/heads/master 8f0df6bc1 -> 734ed7a7b


[SPARK-21806][MLLIB] BinaryClassificationMetrics pr(): first point (0.0, 1.0) is misleading

## What changes were proposed in this pull request?

Prepend (0, p) to the precision-recall curve instead of (0, 1), where p is the precision at the lowest-recall point on the curve.

## How was this patch tested?

Updated tests.

Author: Sean Owen <so...@cloudera.com>

Closes #19038 from srowen/SPARK-21806.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/734ed7a7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/734ed7a7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/734ed7a7

Branch: refs/heads/master
Commit: 734ed7a7b397578f16549070f350215bde369b3c
Parents: 8f0df6b
Author: Sean Owen <so...@cloudera.com>
Authored: Wed Aug 30 11:36:00 2017 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Wed Aug 30 11:36:00 2017 +0100

----------------------------------------------------------------------
 .../BinaryClassificationMetrics.scala           |  8 +++----
 .../BinaryClassificationMetricsSuite.scala      | 22 +++++++++-----------
 2 files changed, 14 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/734ed7a7/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index 9b7cd04..2cfcf38 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -98,16 +98,16 @@ class BinaryClassificationMetrics @Since("1.3.0") (
 
   /**
    * Returns the precision-recall curve, which is an RDD of (recall, precision),
-   * NOT (precision, recall), with (0.0, 1.0) prepended to it.
+   * NOT (precision, recall), with (0.0, p) prepended to it, where p is the precision
+   * associated with the lowest recall on the curve.
    * @see <a href="http://en.wikipedia.org/wiki/Precision_and_recall">
    * Precision and recall (Wikipedia)</a>
    */
   @Since("1.0.0")
   def pr(): RDD[(Double, Double)] = {
     val prCurve = createCurve(Recall, Precision)
-    val sc = confusions.context
-    val first = sc.makeRDD(Seq((0.0, 1.0)), 1)
-    first.union(prCurve)
+    val (_, firstPrecision) = prCurve.first()
+    confusions.context.parallelize(Seq((0.0, firstPrecision)), 1).union(prCurve)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/734ed7a7/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index 99d52fa..a08917a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -23,18 +23,16 @@ import org.apache.spark.mllib.util.TestingUtils._
 
 class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
 
-  private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
-
-  private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean =
-    (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)
-
-  private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
-      assert(left.zip(right).forall(areWithinEpsilon))
+  private def assertSequencesMatch(actual: Seq[Double], expected: Seq[Double]): Unit = {
+    actual.zip(expected).foreach { case (a, e) => assert(a ~== e absTol 1.0e-5) }
   }
 
-  private def assertTupleSequencesMatch(left: Seq[(Double, Double)],
-       right: Seq[(Double, Double)]): Unit = {
-    assert(left.zip(right).forall(pairsWithinEpsilon))
+  private def assertTupleSequencesMatch(actual: Seq[(Double, Double)],
+       expected: Seq[(Double, Double)]): Unit = {
+    actual.zip(expected).foreach { case ((ax, ay), (ex, ey)) =>
+      assert(ax ~== ex absTol 1.0e-5)
+      assert(ay ~== ey absTol 1.0e-5)
+    }
   }
 
   private def validateMetrics(metrics: BinaryClassificationMetrics,
@@ -44,7 +42,7 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
       expectedFMeasures1: Seq[Double],
       expectedFmeasures2: Seq[Double],
       expectedPrecisions: Seq[Double],
-      expectedRecalls: Seq[Double]) = {
+      expectedRecalls: Seq[Double]): Unit = {
 
     assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds)
     assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve)
@@ -111,7 +109,7 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
     val fpr = Seq(1.0)
     val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
     val pr = recalls.zip(precisions)
-    val prCurve = Seq((0.0, 1.0)) ++ pr
+    val prCurve = Seq((0.0, 0.0)) ++ pr
     val f1 = pr.map {
       case (0, 0) => 0.0
       case (r, p) => 2.0 * (p * r) / (p + r)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org