You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/11 18:14:23 UTC

spark git commit: [SPARK-6092] [MLLIB] Add RankingMetrics in PySpark/MLlib

Repository: spark
Updated Branches:
  refs/heads/master d70a07689 -> 042dda3c5


[SPARK-6092] [MLLIB] Add RankingMetrics in PySpark/MLlib

Author: Yanbo Liang <yb...@gmail.com>

Closes #6044 from yanboliang/spark-6092 and squashes the following commits:

726a9b1 [Yanbo Liang] add newRankingMetrics
33f649c [Yanbo Liang] Add RankingMetrics in PySpark/MLlib


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/042dda3c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/042dda3c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/042dda3c

Branch: refs/heads/master
Commit: 042dda3c5c25b5ecb6ae4fd37c85b211b01c187b
Parents: d70a076
Author: Yanbo Liang <yb...@gmail.com>
Authored: Mon May 11 09:14:20 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon May 11 09:14:20 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/api/python/PythonMLLibAPI.scala | 10 +++
 python/pyspark/mllib/evaluation.py              | 78 +++++++++++++++++++-
 2 files changed, 86 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/042dda3c/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 8c30ad4..f4c4775 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -32,6 +32,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.mllib.classification._
 import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.evaluation.RankingMetrics
 import org.apache.spark.mllib.feature._
 import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
 import org.apache.spark.mllib.linalg._
@@ -50,6 +51,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTree
 import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
@@ -923,6 +925,14 @@ private[python] class PythonMLLibAPI extends Serializable {
     RG.gammaVectorRDD(jsc.sc, shape, scale, numRows, numCols, parts, s)
   }
 
+  /**
+   * Java stub for the constructor of Python mllib RankingMetrics
+   */
+  def newRankingMetrics(predictionAndLabels: DataFrame): RankingMetrics[Any] = {
+    new RankingMetrics(predictionAndLabels.map(
+      r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any])))
+  }
+
 
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/042dda3c/python/pyspark/mllib/evaluation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index 3691459..4c777f2 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -15,9 +15,12 @@
 # limitations under the License.
 #
 
-from pyspark.mllib.common import JavaModelWrapper
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
 from pyspark.sql import SQLContext
-from pyspark.sql.types import StructField, StructType, DoubleType
+from pyspark.sql.types import StructField, StructType, DoubleType, IntegerType, ArrayType
+
+__all__ = ['BinaryClassificationMetrics', 'RegressionMetrics',
+           'MulticlassMetrics', 'RankingMetrics']
 
 
 class BinaryClassificationMetrics(JavaModelWrapper):
@@ -270,6 +273,77 @@ class MulticlassMetrics(JavaModelWrapper):
             return self.call("weightedFMeasure", beta)
 
 
+class RankingMetrics(JavaModelWrapper):
+    """
+    Evaluator for ranking algorithms.
+
+    >>> predictionAndLabels = sc.parallelize([
+    ...     ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
+    ...     ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
+    ...     ([1, 2, 3, 4, 5], [])])
+    >>> metrics = RankingMetrics(predictionAndLabels)
+    >>> metrics.precisionAt(1)
+    0.33...
+    >>> metrics.precisionAt(5)
+    0.26...
+    >>> metrics.precisionAt(15)
+    0.17...
+    >>> metrics.meanAveragePrecision
+    0.35...
+    >>> metrics.ndcgAt(3)
+    0.33...
+    >>> metrics.ndcgAt(10)
+    0.48...
+
+    """
+
+    def __init__(self, predictionAndLabels):
+        """
+        :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs.
+        """
+        sc = predictionAndLabels.ctx
+        sql_ctx = SQLContext(sc)
+        df = sql_ctx.createDataFrame(predictionAndLabels,
+                                     schema=sql_ctx._inferSchema(predictionAndLabels))
+        java_model = callMLlibFunc("newRankingMetrics", df._jdf)
+        super(RankingMetrics, self).__init__(java_model)
+
+    def precisionAt(self, k):
+        """
+        Compute the average precision of all the queries, truncated at ranking position k.
+
+        If for a query, the ranking algorithm returns n (n < k) results, the precision value
+        will be computed as #(relevant items retrieved) / k. This formula also applies when
+        the size of the ground truth set is less than k.
+
+        If a query has an empty ground truth set, zero will be used as precision together
+        with a log warning.
+        """
+        return self.call("precisionAt", int(k))
+
+    @property
+    def meanAveragePrecision(self):
+        """
+        Returns the mean average precision (MAP) of all the queries.
+        If a query has an empty ground truth set, the average precision will be zero and
+        a log warning is generated.
+        """
+        return self.call("meanAveragePrecision")
+
+    def ndcgAt(self, k):
+        """
+        Compute the average NDCG value of all the queries, truncated at ranking position k.
+        The discounted cumulative gain at position k is computed as:
+            sum_{i=1}^{k} (2^{rel_i} - 1) / log(i + 1), where rel_i is the relevance of item i,
+        and the NDCG is this DCG divided by the ideal DCG computed on the ground truth set.
+        In the current implementation, the relevance value is binary.
+
+        If a query has an empty ground truth set, zero will be used as ndcg together with
+        a log warning.
+        """
+        return self.call("ndcgAt", int(k))
+
+
 def _test():
     import doctest
     from pyspark import SparkContext


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org