You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/06/29 09:37:19 UTC
[spark] branch master updated: [SPARK-39446][MLLIB][FOLLOWUP] Modify ranking metrics for java and python
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new feae21c6445 [SPARK-39446][MLLIB][FOLLOWUP] Modify ranking metrics for java and python
feae21c6445 is described below
commit feae21c6445f8767bf5f62bb54f6c61a8df4e0c1
Author: uchiiii <uc...@gmail.com>
AuthorDate: Wed Jun 29 17:36:48 2022 +0800
[SPARK-39446][MLLIB][FOLLOWUP] Modify ranking metrics for java and python
### What changes were proposed in this pull request?
- Updated `RankingMetrics` for Java and Python
- Modified the interface for Java and Python
- Added test for Java
### Why are the changes needed?
- To expose the change in https://github.com/apache/spark/pull/36843 to Java and Python.
- To update the document for Java and Python.
### Does this PR introduce _any_ user-facing change?
- Java users can now construct a `RankingMetrics` from a JavaRDD of (predicted ranking, ground truth set, relevance value of ground truth set) triples.
### How was this patch tested?
- Added test for Java
Closes #37019 from uchiiii/modify_ranking_metrics_for_java_and_python.
Authored-by: uchiiii <uc...@gmail.com>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../spark/mllib/evaluation/RankingMetrics.scala | 16 ++++++++++---
.../mllib/evaluation/JavaRankingMetricsSuite.java | 27 ++++++++++++++++++++++
python/pyspark/mllib/evaluation.py | 14 ++++++++---
3 files changed, 51 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index 87a17f57caf..6ff8262c498 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -267,12 +267,22 @@ object RankingMetrics {
/**
* Creates a [[RankingMetrics]] instance (for Java users).
* @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs
+ * or (predicted ranking, ground truth set,
+ * relevance value of ground truth set).
+ * Since 3.4.0, it supports ndcg evaluation with relevance value.
*/
@Since("1.4.0")
- def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = {
+ def of[E, T <: jl.Iterable[E], A <: jl.Iterable[Double]](
+ predictionAndLabels: JavaRDD[_ <: Product]): RankingMetrics[E] = {
implicit val tag = JavaSparkContext.fakeClassTag[E]
- val rdd = predictionAndLabels.rdd.map { case (predictions, labels) =>
- (predictions.asScala.toArray, labels.asScala.toArray)
+ val rdd = predictionAndLabels.rdd.map {
+ case (predictions, labels) =>
+ (predictions.asInstanceOf[T].asScala.toArray, labels.asInstanceOf[T].asScala.toArray)
+ case (predictions, labels, rels) =>
+ (
+ predictions.asInstanceOf[T].asScala.toArray,
+ labels.asInstanceOf[T].asScala.toArray,
+ rels.asInstanceOf[A].asScala.toArray)
}
new RankingMetrics(rdd)
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
index 50822c61fdc..4dcb2920610 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
@@ -22,7 +22,9 @@ import java.util.Arrays;
import java.util.List;
import scala.Tuple2;
+import scala.Tuple3;
import scala.Tuple2$;
+import scala.Tuple3$;
import org.junit.Assert;
import org.junit.Test;
@@ -32,6 +34,8 @@ import org.apache.spark.api.java.JavaRDD;
public class JavaRankingMetricsSuite extends SharedSparkSession {
private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels;
+ private transient JavaRDD<Tuple3<List<Integer>, List<Integer>, List<Double>>>
+ predictionLabelsAndRelevance;
@Override
public void setUp() throws IOException {
@@ -43,6 +47,22 @@ public class JavaRankingMetricsSuite extends SharedSparkSession {
Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
Tuple2$.MODULE$.apply(
Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
+ predictionLabelsAndRelevance = jsc.parallelize(Arrays.asList(
+ Tuple3$.MODULE$.apply(
+ Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5),
+ Arrays.asList(1, 2, 3, 4, 5),
+ Arrays.asList(3.0, 2.0, 1.0, 1.0, 1.0)
+ ),
+ Tuple3$.MODULE$.apply(
+ Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10),
+ Arrays.asList(1, 2, 3),
+ Arrays.asList(2.0, 0.0, 0.0)
+ ),
+ Tuple3$.MODULE$.apply(
+ Arrays.asList(1, 2, 3, 4, 5),
+ Arrays.<Integer>asList(),
+ Arrays.<Double>asList()
+ )), 3);
}
@Test
@@ -51,4 +71,11 @@ public class JavaRankingMetricsSuite extends SharedSparkSession {
Assert.assertEquals(0.355026, metrics.meanAveragePrecision(), 1e-5);
Assert.assertEquals(0.75 / 3.0, metrics.precisionAt(4), 1e-5);
}
+
+ @Test
+ public void rankingMetricsWithRelevance() {
+ RankingMetrics<?> metrics = RankingMetrics.of(predictionLabelsAndRelevance);
+ Assert.assertEquals(0.355026, metrics.meanAveragePrecision(), 1e-5);
+ Assert.assertEquals(0.511959, metrics.ndcgAt(3), 1e-5);
+ }
}
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index 1003ba68c5f..cee61a1b241 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
-from typing import Generic, List, Optional, Tuple, TypeVar
+from typing import Generic, List, Optional, Tuple, TypeVar, Union
import sys
@@ -418,7 +418,10 @@ class RankingMetrics(JavaModelWrapper, Generic[T]):
Parameters
----------
predictionAndLabels : :py:class:`pyspark.RDD`
- an RDD of (predicted ranking, ground truth set) pairs.
+ an RDD of (predicted ranking, ground truth set) pairs
+ or (predicted ranking, ground truth set,
+ relevance value of ground truth set).
+ Since 3.4.0, it supports ndcg evaluation with relevance value.
Examples
--------
@@ -451,7 +454,12 @@ class RankingMetrics(JavaModelWrapper, Generic[T]):
0.66...
"""
- def __init__(self, predictionAndLabels: RDD[Tuple[List[T], List[T]]]):
+ def __init__(
+ self,
+ predictionAndLabels: Union[
+ RDD[Tuple[List[T], List[T]]], RDD[Tuple[List[T], List[T], List[float]]]
+ ],
+ ):
sc = predictionAndLabels.ctx
sql_ctx = SQLContext.getOrCreate(sc)
df = sql_ctx.createDataFrame(
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org