You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2022/06/25 19:16:56 UTC

[spark] branch master updated: [SPARK-39446][MLLIB][FOLLOWUP] Modify constructor of RankingMetrics class

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f465a3d943e [SPARK-39446][MLLIB][FOLLOWUP] Modify constructor of RankingMetrics class
f465a3d943e is described below

commit f465a3d943ea692b9ba377fcfcf17012c3bea29f
Author: uchiiii <uc...@gmail.com>
AuthorDate: Sat Jun 25 14:16:45 2022 -0500

    [SPARK-39446][MLLIB][FOLLOWUP] Modify constructor of RankingMetrics class
    
    ### What changes were proposed in this pull request?
    - Merged the two constructors into one using `RDD[_ <: Product]`.
    
    ### Why are the changes needed?
    - To make code simpler.
    - To support even more inputs.
    - ~~The previous code treats `rel` as an empty array when `rel` is not provided, which is not that beautiful. This change removes this.~~
    
    ### Does this PR introduce _any_ user-facing change?
    NO
    
    ### How was this patch tested?
    
    Closes #36920 from uchiiii/modify_ranking_metrics.
    
    Authored-by: uchiiii <uc...@gmail.com>
    Signed-off-by: Sean Owen <sr...@gmail.com>
---
 .../spark/mllib/evaluation/RankingMetrics.scala    | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index 7fccff9a24e..87a17f57caf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -38,16 +38,14 @@ import org.apache.spark.rdd.RDD
  *                            Since 3.4.0, it supports ndcg evaluation with relevance value.
  */
 @Since("1.2.0")
-class RankingMetrics[T: ClassTag] @Since("3.4.0") (
-    predictionAndLabels: RDD[(Array[T], Array[T], Array[Double])])
+class RankingMetrics[T: ClassTag] @Since("1.2.0") (predictionAndLabels: RDD[_ <: Product])
     extends Logging
     with Serializable {
 
-  @Since("1.2.0")
-  def this(predictionAndLabelsWithoutRelevance: => RDD[(Array[T], Array[T])]) = {
-    this(predictionAndLabelsWithoutRelevance.map {
-      case (pred, lab) => (pred, lab, Array.empty[Double])
-    })
+  private val rdd = predictionAndLabels.map {
+    case (pred: Array[T], lab: Array[T]) => (pred, lab, Array.empty[Double])
+    case (pred: Array[T], lab: Array[T], rel: Array[Double]) => (pred, lab, rel)
+    case _ => throw new IllegalArgumentException(s"Expected RDD of tuples or triplets")
   }
 
   /**
@@ -70,7 +68,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") (
   @Since("1.2.0")
   def precisionAt(k: Int): Double = {
     require(k > 0, "ranking position k should be positive")
-    predictionAndLabels.map { case (pred, lab, _) =>
+    rdd.map { case (pred, lab, _) =>
       countRelevantItemRatio(pred, lab, k, k)
     }.mean()
   }
@@ -82,7 +80,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") (
    */
   @Since("1.2.0")
   lazy val meanAveragePrecision: Double = {
-    predictionAndLabels.map { case (pred, lab, _) =>
+    rdd.map { case (pred, lab, _) =>
       val labSet = lab.toSet
       val k = math.max(pred.length, labSet.size)
       averagePrecision(pred, labSet, k)
@@ -99,7 +97,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") (
   @Since("3.0.0")
   def meanAveragePrecisionAt(k: Int): Double = {
     require(k > 0, "ranking position k should be positive")
-    predictionAndLabels.map { case (pred, lab, _) =>
+    rdd.map { case (pred, lab, _) =>
       averagePrecision(pred, lab.toSet, k)
     }.mean()
   }
@@ -154,7 +152,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") (
   @Since("1.2.0")
   def ndcgAt(k: Int): Double = {
     require(k > 0, "ranking position k should be positive")
-    predictionAndLabels.map { case (pred, lab, rel) =>
+    rdd.map { case (pred, lab, rel) =>
       val useBinary = rel.isEmpty
       val labSet = lab.toSet
       val relMap = lab.zip(rel).toMap
@@ -224,7 +222,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") (
   @Since("3.0.0")
   def recallAt(k: Int): Double = {
     require(k > 0, "ranking position k should be positive")
-    predictionAndLabels.map { case (pred, lab, _) =>
+    rdd.map { case (pred, lab, _) =>
       countRelevantItemRatio(pred, lab, k, lab.toSet.size)
     }.mean()
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org