You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2017/08/30 10:36:03 UTC
spark git commit: [SPARK-21806][MLLIB] BinaryClassificationMetrics pr(): first point (0.0, 1.0) is misleading
Repository: spark
Updated Branches:
refs/heads/master 8f0df6bc1 -> 734ed7a7b
[SPARK-21806][MLLIB] BinaryClassificationMetrics pr(): first point (0.0, 1.0) is misleading
## What changes were proposed in this pull request?
Prepend (0.0, p) to the precision-recall curve instead of (0.0, 1.0), where p is the precision at the curve's lowest-recall point.
## How was this patch tested?
Updated tests.
Author: Sean Owen <so...@cloudera.com>
Closes #19038 from srowen/SPARK-21806.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/734ed7a7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/734ed7a7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/734ed7a7
Branch: refs/heads/master
Commit: 734ed7a7b397578f16549070f350215bde369b3c
Parents: 8f0df6b
Author: Sean Owen <so...@cloudera.com>
Authored: Wed Aug 30 11:36:00 2017 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Wed Aug 30 11:36:00 2017 +0100
----------------------------------------------------------------------
.../BinaryClassificationMetrics.scala | 8 +++----
.../BinaryClassificationMetricsSuite.scala | 22 +++++++++-----------
2 files changed, 14 insertions(+), 16 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/734ed7a7/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index 9b7cd04..2cfcf38 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -98,16 +98,16 @@ class BinaryClassificationMetrics @Since("1.3.0") (
/**
* Returns the precision-recall curve, which is an RDD of (recall, precision),
- * NOT (precision, recall), with (0.0, 1.0) prepended to it.
+ * NOT (precision, recall), with (0.0, p) prepended to it, where p is the precision
+ * associated with the lowest recall on the curve.
* @see <a href="http://en.wikipedia.org/wiki/Precision_and_recall">
* Precision and recall (Wikipedia)</a>
*/
@Since("1.0.0")
def pr(): RDD[(Double, Double)] = {
val prCurve = createCurve(Recall, Precision)
- val sc = confusions.context
- val first = sc.makeRDD(Seq((0.0, 1.0)), 1)
- first.union(prCurve)
+ val (_, firstPrecision) = prCurve.first()
+ confusions.context.parallelize(Seq((0.0, firstPrecision)), 1).union(prCurve)
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/734ed7a7/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index 99d52fa..a08917a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -23,18 +23,16 @@ import org.apache.spark.mllib.util.TestingUtils._
class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
- private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
-
- private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean =
- (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)
-
- private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
- assert(left.zip(right).forall(areWithinEpsilon))
+ private def assertSequencesMatch(actual: Seq[Double], expected: Seq[Double]): Unit = {
+ actual.zip(expected).foreach { case (a, e) => assert(a ~== e absTol 1.0e-5) }
}
- private def assertTupleSequencesMatch(left: Seq[(Double, Double)],
- right: Seq[(Double, Double)]): Unit = {
- assert(left.zip(right).forall(pairsWithinEpsilon))
+ private def assertTupleSequencesMatch(actual: Seq[(Double, Double)],
+ expected: Seq[(Double, Double)]): Unit = {
+ actual.zip(expected).foreach { case ((ax, ay), (ex, ey)) =>
+ assert(ax ~== ex absTol 1.0e-5)
+ assert(ay ~== ey absTol 1.0e-5)
+ }
}
private def validateMetrics(metrics: BinaryClassificationMetrics,
@@ -44,7 +42,7 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
expectedFMeasures1: Seq[Double],
expectedFmeasures2: Seq[Double],
expectedPrecisions: Seq[Double],
- expectedRecalls: Seq[Double]) = {
+ expectedRecalls: Seq[Double]): Unit = {
assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds)
assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve)
@@ -111,7 +109,7 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
val fpr = Seq(1.0)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
val pr = recalls.zip(precisions)
- val prCurve = Seq((0.0, 1.0)) ++ pr
+ val prCurve = Seq((0.0, 0.0)) ++ pr
val f1 = pr.map {
case (0, 0) => 0.0
case (r, p) => 2.0 * (p * r) / (p + r)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org