Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2019/05/07 13:50:00 UTC

[GitHub] [spark] srowen commented on a change in pull request #24543: [SPARK-27540][MLlib] Add 'meanAveragePrecision_at_k' metric to RankingMetrics

URL: https://github.com/apache/spark/pull/24543#discussion_r281638802
 
 

 ##########
 File path: mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
 ##########
 @@ -71,25 +71,56 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
   lazy val meanAveragePrecision: Double = {
     predictionAndLabels.map { case (pred, lab) =>
       val labSet = lab.toSet
+      val k = math.max(pred.length, labSet.size)
+      averagePrecision(pred, lab, k)
+    }.mean()
+  }
 
-      if (labSet.nonEmpty) {
-        var i = 0
-        var cnt = 0
-        var precSum = 0.0
-        val n = pred.length
-        while (i < n) {
-          if (labSet.contains(pred(i))) {
-            cnt += 1
-            precSum += cnt.toDouble / (i + 1)
-          }
-          i += 1
+  /**
+   * Returns the mean average precision (MAP) at ranking position k of all the queries.
+   * If a query has an empty ground truth set, the average precision is zero and a log
+   * warning is generated.
+   * @param k the position at which to compute the truncated precision; must be positive
+   * @return the mean average precision at the first k ranking positions
+   */
+  @Since("3.0.0")
+  def meanAveragePrecisionAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+    predictionAndLabels.map { case (pred, lab) =>
+      averagePrecision(pred, lab, k)
+    }.mean()
+  }
+
+  /**
+   * Computes the average precision at the first k ranking positions of a single query.
+   * If the query has an empty ground truth set, the value is zero and a log
+   * warning is generated.
+   *
+   * @param pred the predicted ranking
+   * @param lab the ground truth labels
+   * @param k use the top k elements of the predicted ranking; must be positive
+   * @return the average precision at the first k ranking positions
+   */
+  private def averagePrecision(pred: Array[T], lab: Array[T], k: Int): Double = {
 
 Review comment:
  You can pass `labSet` here from the callers to avoid computing it twice.
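
  For illustration only, a minimal sketch of the helper after that suggested change,
  with the body reconstructed from the inline loop this diff removes from
  `meanAveragePrecision`. The `logWarning` call (assuming the class mixes in Spark's
  `Logging` trait) and the `min(labSet.size, k)` normalization are assumptions; the
  diff above shows only the signature:

      private def averagePrecision(pred: Array[T], labSet: Set[T], k: Int): Double = {
        if (labSet.nonEmpty) {
          var i = 0
          var cnt = 0          // relevant items seen so far
          var precSum = 0.0    // running sum of precision@i at every hit
          val n = math.min(pred.length, k)
          while (i < n) {
            if (labSet.contains(pred(i))) {
              cnt += 1
              precSum += cnt.toDouble / (i + 1)
            }
            i += 1
          }
          // Assumed normalization: cap the denominator at k so a query with more
          // relevant items than k can still reach an average precision of 1.0.
          precSum / math.min(labSet.size, k)
        } else {
          logWarning("Empty ground truth set, check input data")
          0.0
        }
      }

  Callers would then build the set once, e.g. `averagePrecision(pred, lab.toSet, k)`
  in `meanAveragePrecisionAt`, instead of converting `lab` again inside the helper.
  From user code, the new metric would be invoked along the lines of
  `new RankingMetrics(predictionAndLabels).meanAveragePrecisionAt(5)`, which scores
  only the top five predictions of each query.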
