Posted to commits@spark.apache.org by ru...@apache.org on 2022/06/29 09:37:19 UTC

[spark] branch master updated: [SPARK-39446][MLLIB][FOLLOWUP] Modify ranking metrics for java and python

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new feae21c6445 [SPARK-39446][MLLIB][FOLLOWUP] Modify ranking metrics for java and python
feae21c6445 is described below

commit feae21c6445f8767bf5f62bb54f6c61a8df4e0c1
Author: uchiiii <uc...@gmail.com>
AuthorDate: Wed Jun 29 17:36:48 2022 +0800

    [SPARK-39446][MLLIB][FOLLOWUP] Modify ranking metrics for java and python
    
    ### What changes were proposed in this pull request?
    - Updated the Java-facing `RankingMetrics.of` factory to also accept (predicted ranking, ground truth set, relevance value of ground truth set) triples
    - Updated the Python `RankingMetrics` type hints and docstring to accept the same triples
    - Added a test for Java
    
    ### Why are the changes needed?
    
    - To expose the change in https://github.com/apache/spark/pull/36843 to Java and Python.
    - To update the documentation for Java and Python.
    
    ### Does this PR introduce _any_ user-facing change?
    
    - Java users can pass a `JavaRDD` of (predicted ranking, ground truth set, relevance value of ground truth set) triples to `RankingMetrics.of`.
    - Python users can likewise pass an `RDD` of such triples to `RankingMetrics`; see the sketch below.
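    
    For illustration, a minimal PySpark sketch of the two accepted input shapes (assumes an active `SparkContext` named `sc`; the data is made up):
    
    ```python
    from pyspark.mllib.evaluation import RankingMetrics
    
    # Pair form (unchanged): (predicted ranking, ground truth set)
    pairs = sc.parallelize([([1, 2, 3, 4, 5], [1, 3])])
    print(RankingMetrics(pairs).precisionAt(2))
    
    # Triple form (new): (predicted ranking, ground truth set, relevance values of the ground truth set)
    triples = sc.parallelize([([1, 2, 3, 4, 5], [1, 3], [2.0, 1.0])])
    print(RankingMetrics(triples).ndcgAt(2))
    ```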
    
    ### How was this patch tested?
    - Added a Java test (`rankingMetricsWithRelevance` in `JavaRankingMetricsSuite`)
    
    Closes #37019 from uchiiii/modify_ranking_metrics_for_java_and_python.
    
    Authored-by: uchiiii <uc...@gmail.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../spark/mllib/evaluation/RankingMetrics.scala    | 16 ++++++++++---
 .../mllib/evaluation/JavaRankingMetricsSuite.java  | 27 ++++++++++++++++++++++
 python/pyspark/mllib/evaluation.py                 | 14 ++++++++---
 3 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index 87a17f57caf..6ff8262c498 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -267,12 +267,22 @@ object RankingMetrics {
   /**
    * Creates a [[RankingMetrics]] instance (for Java users).
    * @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs
+   *                            or (predicted ranking, ground truth set,
+   *                            relevance value of ground truth set).
+   *                            Since 3.4.0, it supports NDCG evaluation with relevance values.
    */
   @Since("1.4.0")
-  def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = {
+  def of[E, T <: jl.Iterable[E], A <: jl.Iterable[Double]](
+      predictionAndLabels: JavaRDD[_ <: Product]): RankingMetrics[E] = {
     implicit val tag = JavaSparkContext.fakeClassTag[E]
-    val rdd = predictionAndLabels.rdd.map { case (predictions, labels) =>
-      (predictions.asScala.toArray, labels.asScala.toArray)
+    val rdd = predictionAndLabels.rdd.map {
+      case (predictions, labels) =>
+        (predictions.asInstanceOf[T].asScala.toArray, labels.asInstanceOf[T].asScala.toArray)
+      case (predictions, labels, rels) =>
+        (
+          predictions.asInstanceOf[T].asScala.toArray,
+          labels.asInstanceOf[T].asScala.toArray,
+          rels.asInstanceOf[A].asScala.toArray)
     }
     new RankingMetrics(rdd)
   }
diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
index 50822c61fdc..4dcb2920610 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
@@ -22,7 +22,9 @@ import java.util.Arrays;
 import java.util.List;
 
 import scala.Tuple2;
+import scala.Tuple3;
 import scala.Tuple2$;
+import scala.Tuple3$;
 
 import org.junit.Assert;
 import org.junit.Test;
@@ -32,6 +34,8 @@ import org.apache.spark.api.java.JavaRDD;
 
 public class JavaRankingMetricsSuite extends SharedSparkSession {
   private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels;
+  private transient JavaRDD<Tuple3<List<Integer>, List<Integer>, List<Double>>>
+    predictionLabelsAndRelevance;
 
   @Override
   public void setUp() throws IOException {
@@ -43,6 +47,22 @@ public class JavaRankingMetricsSuite extends SharedSparkSession {
         Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
       Tuple2$.MODULE$.apply(
         Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
+    predictionLabelsAndRelevance = jsc.parallelize(Arrays.asList(
+      Tuple3$.MODULE$.apply(
+        Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5),
+        Arrays.asList(1, 2, 3, 4, 5),
+        Arrays.asList(3.0, 2.0, 1.0, 1.0, 1.0)
+      ),
+      Tuple3$.MODULE$.apply(
+        Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10),
+        Arrays.asList(1, 2, 3),
+        Arrays.asList(2.0, 0.0, 0.0)
+      ),
+      Tuple3$.MODULE$.apply(
+        Arrays.asList(1, 2, 3, 4, 5),
+        Arrays.<Integer>asList(),
+        Arrays.<Double>asList()
+      )), 3);
   }
 
   @Test
@@ -51,4 +71,11 @@ public class JavaRankingMetricsSuite extends SharedSparkSession {
     Assert.assertEquals(0.355026, metrics.meanAveragePrecision(), 1e-5);
     Assert.assertEquals(0.75 / 3.0, metrics.precisionAt(4), 1e-5);
   }
+
+  @Test
+  public void rankingMetricsWithRelevance() {
+    RankingMetrics<?> metrics = RankingMetrics.of(predictionLabelsAndRelevance);
+    Assert.assertEquals(0.355026, metrics.meanAveragePrecision(), 1e-5);
+    Assert.assertEquals(0.511959, metrics.ndcgAt(3), 1e-5);
+  }
 }
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index 1003ba68c5f..cee61a1b241 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-from typing import Generic, List, Optional, Tuple, TypeVar
+from typing import Generic, List, Optional, Tuple, TypeVar, Union
 
 import sys
 
@@ -418,7 +418,10 @@ class RankingMetrics(JavaModelWrapper, Generic[T]):
     Parameters
     ----------
     predictionAndLabels : :py:class:`pyspark.RDD`
-        an RDD of (predicted ranking, ground truth set) pairs.
+        an RDD of (predicted ranking, ground truth set) pairs
+        or (predicted ranking, ground truth set,
+        relevance value of ground truth set).
+        Since 3.4.0, it supports NDCG evaluation with relevance values.
 
     Examples
     --------
@@ -451,7 +454,12 @@ class RankingMetrics(JavaModelWrapper, Generic[T]):
     0.66...
     """
 
-    def __init__(self, predictionAndLabels: RDD[Tuple[List[T], List[T]]]):
+    def __init__(
+        self,
+        predictionAndLabels: Union[
+            RDD[Tuple[List[T], List[T]]], RDD[Tuple[List[T], List[T], List[float]]]
+        ],
+    ):
         sc = predictionAndLabels.ctx
         sql_ctx = SQLContext.getOrCreate(sc)
         df = sql_ctx.createDataFrame(
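
For illustration, a minimal PySpark sketch of the new triple input, reusing the data from `JavaRankingMetricsSuite` above (assumes a local PySpark environment; the expected values are the ones asserted in that Java test):

```python
from pyspark import SparkContext
from pyspark.mllib.evaluation import RankingMetrics

sc = SparkContext.getOrCreate()

# (predicted ranking, ground truth set, relevance values of the ground truth set),
# mirroring predictionLabelsAndRelevance in JavaRankingMetricsSuite above
prediction_labels_and_relevance = sc.parallelize([
    ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5], [3.0, 2.0, 1.0, 1.0, 1.0]),
    ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3], [2.0, 0.0, 0.0]),
    ([1, 2, 3, 4, 5], [], []),
])

metrics = RankingMetrics(prediction_labels_and_relevance)
print(metrics.meanAveragePrecision)  # ~0.355026, per the Java test above
print(metrics.ndcgAt(3))             # ~0.511959, per the Java test above
```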


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org