You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2022/06/25 19:16:56 UTC
[spark] branch master updated: [SPARK-39446][MLLIB][FOLLOWUP] Modify constructor of RankingMetrics class
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f465a3d943e [SPARK-39446][MLLIB][FOLLOWUP] Modify constructor of RankingMetrics class
f465a3d943e is described below
commit f465a3d943ea692b9ba377fcfcf17012c3bea29f
Author: uchiiii <uc...@gmail.com>
AuthorDate: Sat Jun 25 14:16:45 2022 -0500
[SPARK-39446][MLLIB][FOLLOWUP] Modify constructor of RankingMetrics class
### What changes were proposed in this pull request?
- Merged the two constructors into one using `RDD[_ <: Product]`.
### Why are the changes needed?
- To make code simpler.
- To support even more inputs.
- ~~The previous code treats `rel` as an empty array when `rel` is not provided, which is not that beautiful. This change removes this.~~
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
Closes #36920 from uchiiii/modify_ranking_metrics.
Authored-by: uchiiii <uc...@gmail.com>
Signed-off-by: Sean Owen <sr...@gmail.com>
---
.../spark/mllib/evaluation/RankingMetrics.scala | 22 ++++++++++------------
1 file changed, 10 insertions(+), 12 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index 7fccff9a24e..87a17f57caf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -38,16 +38,14 @@ import org.apache.spark.rdd.RDD
* Since 3.4.0, it supports ndcg evaluation with relevance value.
*/
@Since("1.2.0")
-class RankingMetrics[T: ClassTag] @Since("3.4.0") (
- predictionAndLabels: RDD[(Array[T], Array[T], Array[Double])])
+class RankingMetrics[T: ClassTag] @Since("1.2.0") (predictionAndLabels: RDD[_ <: Product])
extends Logging
with Serializable {
- @Since("1.2.0")
- def this(predictionAndLabelsWithoutRelevance: => RDD[(Array[T], Array[T])]) = {
- this(predictionAndLabelsWithoutRelevance.map {
- case (pred, lab) => (pred, lab, Array.empty[Double])
- })
+ private val rdd = predictionAndLabels.map {
+ case (pred: Array[T], lab: Array[T]) => (pred, lab, Array.empty[Double])
+ case (pred: Array[T], lab: Array[T], rel: Array[Double]) => (pred, lab, rel)
+ case _ => throw new IllegalArgumentException(s"Expected RDD of tuples or triplets")
}
/**
@@ -70,7 +68,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") (
@Since("1.2.0")
def precisionAt(k: Int): Double = {
require(k > 0, "ranking position k should be positive")
- predictionAndLabels.map { case (pred, lab, _) =>
+ rdd.map { case (pred, lab, _) =>
countRelevantItemRatio(pred, lab, k, k)
}.mean()
}
@@ -82,7 +80,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") (
*/
@Since("1.2.0")
lazy val meanAveragePrecision: Double = {
- predictionAndLabels.map { case (pred, lab, _) =>
+ rdd.map { case (pred, lab, _) =>
val labSet = lab.toSet
val k = math.max(pred.length, labSet.size)
averagePrecision(pred, labSet, k)
@@ -99,7 +97,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") (
@Since("3.0.0")
def meanAveragePrecisionAt(k: Int): Double = {
require(k > 0, "ranking position k should be positive")
- predictionAndLabels.map { case (pred, lab, _) =>
+ rdd.map { case (pred, lab, _) =>
averagePrecision(pred, lab.toSet, k)
}.mean()
}
@@ -154,7 +152,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") (
@Since("1.2.0")
def ndcgAt(k: Int): Double = {
require(k > 0, "ranking position k should be positive")
- predictionAndLabels.map { case (pred, lab, rel) =>
+ rdd.map { case (pred, lab, rel) =>
val useBinary = rel.isEmpty
val labSet = lab.toSet
val relMap = lab.zip(rel).toMap
@@ -224,7 +222,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") (
@Since("3.0.0")
def recallAt(k: Int): Double = {
require(k > 0, "ranking position k should be positive")
- predictionAndLabels.map { case (pred, lab, _) =>
+ rdd.map { case (pred, lab, _) =>
countRelevantItemRatio(pred, lab, k, lab.toSet.size)
}.mean()
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org