You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2019/05/07 13:45:17 UTC

[spark] branch branch-2.3 updated: [SPARK-27577][MLLIB] Correct thresholds downsampled in BinaryClassificationMetrics

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-2.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.3 by this push:
     new 3c5fca1  [SPARK-27577][MLLIB] Correct thresholds downsampled in BinaryClassificationMetrics
3c5fca1 is described below

commit 3c5fca1716130248b54389e2555b47f35c5a6901
Author: Shaochen Shi <sh...@bytedance.com>
AuthorDate: Tue May 7 08:41:58 2019 -0500

    [SPARK-27577][MLLIB] Correct thresholds downsampled in BinaryClassificationMetrics
    
    ## What changes were proposed in this pull request?
    
    Use the score of the last record in each chunk (instead of the first) as the chunk's threshold when calculating metrics with downsampling in `BinaryClassificationMetrics`.
    
    ## How was this patch tested?
    
    A new unit test is added to verify thresholds from downsampled records.
    
    Closes #24470 from shishaochen/spark-mllib-binary-metrics.
    
    Authored-by: Shaochen Shi <sh...@bytedance.com>
    Signed-off-by: Sean Owen <se...@databricks.com>
    (cherry picked from commit d5308cd86fff1e4bf9c24e0dd73d8d2c92737c4f)
    Signed-off-by: Sean Owen <se...@databricks.com>
---
 .../spark/mllib/evaluation/BinaryClassificationMetrics.scala  | 11 +++++++----
 .../mllib/evaluation/BinaryClassificationMetricsSuite.scala   | 11 +++++++++++
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index 2cfcf38..764806b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -175,12 +175,15 @@ class BinaryClassificationMetrics @Since("1.3.0") (
             grouping = Int.MaxValue
           }
           counts.mapPartitions(_.grouped(grouping.toInt).map { pairs =>
-            // The score of the combined point will be just the first one's score
-            val firstScore = pairs.head._1
-            // The point will contain all counts in this chunk
+            // The score of the combined point will be just the last one's score, which is also
+            // the minimum in each chunk since all scores are already sorted in descending order.
+            val lastScore = pairs.last._1
+            // The combined point will contain all counts in this chunk. Thus, calculated
+            // metrics (like precision, recall, etc.) on its score (or so-called threshold) are
+            // the same as those without sampling.
             val agg = new BinaryLabelCounter()
             pairs.foreach(pair => agg += pair._2)
-            (firstScore, agg)
+            (lastScore, agg)
           })
         }
       }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index a08917a..4cc9ee5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -155,6 +155,17 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
         (1.0, 1.0), (1.0, 1.0)
       ) ==
       downsampledROC)
+
+    val downsampledRecall = downsampled.recallByThreshold().collect().sorted.toList
+    assert(
+      // May have to add 1 if the sample factor didn't divide evenly
+      numBins + (if (scoreAndLabels.size % numBins == 0) 0 else 1) ==
+      downsampledRecall.size)
+    assert(
+      List(
+        (0.1, 1.0), (0.2, 1.0), (0.4, 0.75), (0.6, 0.75), (0.8, 0.25)
+      ) ==
+      downsampledRecall)
   }
 
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org