Posted to commits@spark.apache.org by me...@apache.org on 2015/05/10 09:57:23 UTC

spark git commit: [SPARK-6091] [MLLIB] Add MulticlassMetrics in PySpark/MLlib

Repository: spark
Updated Branches:
  refs/heads/master b13162b36 -> bf7e81a51


[SPARK-6091] [MLLIB] Add MulticlassMetrics in PySpark/MLlib

https://issues.apache.org/jira/browse/SPARK-6091

Author: Yanbo Liang <yb...@gmail.com>

Closes #6011 from yanboliang/spark-6091 and squashes the following commits:

bb3e4ba [Yanbo Liang] trigger jenkins
53c045d [Yanbo Liang] keep compatibility for python 2.6
972d5ac [Yanbo Liang] Add MulticlassMetrics in PySpark/MLlib
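
For readers skimming the change, a minimal usage sketch of the new Python API
follows. The class and method names are taken from the diff below; the
SparkContext setup and the sample predictions are illustrative assumptions,
not part of this commit.

    from pyspark import SparkContext
    from pyspark.mllib.evaluation import MulticlassMetrics

    sc = SparkContext(appName="MulticlassMetricsSketch")  # hypothetical app name

    # An RDD of (prediction, label) pairs; the values are made up.
    predictionAndLabels = sc.parallelize([
        (0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0)])

    metrics = MulticlassMetrics(predictionAndLabels)
    print(metrics.precision())           # overall precision
    print(metrics.recall(1.0))           # recall for label 1.0
    print(metrics.weightedFMeasure())    # label-frequency-weighted f-measure

    sc.stop()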


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bf7e81a5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bf7e81a5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bf7e81a5

Branch: refs/heads/master
Commit: bf7e81a51cd81706570615cd67362c86602dec88
Parents: b13162b
Author: Yanbo Liang <yb...@gmail.com>
Authored: Sun May 10 00:57:14 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Sun May 10 00:57:14 2015 -0700

----------------------------------------------------------------------
 .../mllib/evaluation/MulticlassMetrics.scala    |   8 ++
 python/pyspark/mllib/evaluation.py              | 129 +++++++++++++++++++
 2 files changed, 137 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bf7e81a5/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
index 666362a..4628dc5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Matrices, Matrix}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
 
 /**
  * ::Experimental::
@@ -33,6 +34,13 @@ import org.apache.spark.rdd.RDD
 @Experimental
 class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
 
+  /**
+   * An auxiliary constructor taking a DataFrame.
+   * @param predictionAndLabels a DataFrame with two double columns: prediction and label
+   */
+  private[mllib] def this(predictionAndLabels: DataFrame) =
+    this(predictionAndLabels.map(r => (r.getDouble(0), r.getDouble(1))))
+
   private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
   private lazy val labelCount: Long = labelCountByClass.values.sum
   private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
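
The auxiliary constructor above is the JVM half of a py4j handshake: Python
cannot construct an RDD of Scala Double pairs directly, so the Python wrapper
(in the next file) ships its data across as a DataFrame with a fixed
two-column schema, which this constructor maps back to RDD[(Double, Double)].
A condensed sketch of the Python side of that bridge, assuming a SparkContext
sc and a pair RDD rdd of (prediction, label) floats:

    from pyspark.sql import SQLContext
    from pyspark.sql.types import StructType, StructField, DoubleType

    # Fixed schema matching what the Scala constructor expects:
    # column 0 = prediction, column 1 = label, both doubles.
    schema = StructType([
        StructField("prediction", DoubleType(), nullable=False),
        StructField("label", DoubleType(), nullable=False)])
    df = SQLContext(sc).createDataFrame(rdd, schema=schema)

    # df._jdf is the underlying Java DataFrame; passing it through the py4j
    # gateway invokes the private[mllib] constructor added above.
    java_model = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics(df._jdf)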

http://git-wip-us.apache.org/repos/asf/spark/blob/bf7e81a5/python/pyspark/mllib/evaluation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index 3e11df0..3691459 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -141,6 +141,135 @@ class RegressionMetrics(JavaModelWrapper):
         return self.call("r2")
 
 
+class MulticlassMetrics(JavaModelWrapper):
+    """
+    Evaluator for multiclass classification.
+
+    >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
+    ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
+    >>> metrics = MulticlassMetrics(predictionAndLabels)
+    >>> metrics.falsePositiveRate(0.0)
+    0.2...
+    >>> metrics.precision(1.0)
+    0.75...
+    >>> metrics.recall(2.0)
+    1.0...
+    >>> metrics.fMeasure(0.0, 2.0)
+    0.52...
+    >>> metrics.precision()
+    0.66...
+    >>> metrics.recall()
+    0.66...
+    >>> metrics.weightedFalsePositiveRate
+    0.19...
+    >>> metrics.weightedPrecision
+    0.68...
+    >>> metrics.weightedRecall
+    0.66...
+    >>> metrics.weightedFMeasure()
+    0.66...
+    >>> metrics.weightedFMeasure(2.0)
+    0.65...
+    """
+
+    def __init__(self, predictionAndLabels):
+        """
+        :param predictionAndLabels: an RDD of (prediction, label) pairs.
+        """
+        sc = predictionAndLabels.ctx
+        sql_ctx = SQLContext(sc)
+        df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
+            StructField("prediction", DoubleType(), nullable=False),
+            StructField("label", DoubleType(), nullable=False)]))
+        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
+        java_model = java_class(df._jdf)
+        super(MulticlassMetrics, self).__init__(java_model)
+
+    def truePositiveRate(self, label):
+        """
+        Returns true positive rate for a given label (category).
+        """
+        return self.call("truePositiveRate", label)
+
+    def falsePositiveRate(self, label):
+        """
+        Returns false positive rate for a given label (category).
+        """
+        return self.call("falsePositiveRate", label)
+
+    def precision(self, label=None):
+        """
+        Returns overall precision, or precision for a given label (category) if one is specified.
+        """
+        if label is None:
+            return self.call("precision")
+        else:
+            return self.call("precision", float(label))
+
+    def recall(self, label=None):
+        """
+        Returns overall recall, or recall for a given label (category) if one is specified.
+        """
+        if label is None:
+            return self.call("recall")
+        else:
+            return self.call("recall", float(label))
+
+    def fMeasure(self, label=None, beta=None):
+        """
+        Returns overall f-measure, or f-measure for a given label (category) if one is
+        specified. The optional beta selects the F-beta measure and requires a label.
+        """
+        if beta is None:
+            if label is None:
+                return self.call("fMeasure")
+            else:
+                return self.call("fMeasure", float(label))
+        else:
+            if label is None:
+                raise Exception("If the beta parameter is specified, label cannot be None")
+            else:
+                return self.call("fMeasure", float(label), float(beta))
+
+    @property
+    def weightedTruePositiveRate(self):
+        """
+        Returns weighted true positive rate.
+        (equal to the overall precision, recall and f-measure)
+        """
+        return self.call("weightedTruePositiveRate")
+
+    @property
+    def weightedFalsePositiveRate(self):
+        """
+        Returns weighted false positive rate.
+        """
+        return self.call("weightedFalsePositiveRate")
+
+    @property
+    def weightedRecall(self):
+        """
+        Returns the weighted average recall.
+        (equal to the overall precision, recall and f-measure)
+        """
+        return self.call("weightedRecall")
+
+    @property
+    def weightedPrecision(self):
+        """
+        Returns the weighted average precision.
+        """
+        return self.call("weightedPrecision")
+
+    def weightedFMeasure(self, beta=None):
+        """
+        Returns the weighted average f-measure.
+        """
+        if beta is None:
+            return self.call("weightedFMeasure")
+        else:
+            return self.call("weightedFMeasure", float(beta))
+
+
 def _test():
     import doctest
     from pyspark import SparkContext
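
To make the doctest figures above easier to audit, here is the arithmetic
behind a few of them, checked in plain Python with the (prediction, label)
pairs copied from the doctest; no Spark is required:

    pairs = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
             (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]

    # precision(1.0): true positives for class 1 over everything predicted 1.
    tp1 = sum(1 for p, l in pairs if p == 1.0 and l == 1.0)        # 3
    pred1 = sum(1 for p, _ in pairs if p == 1.0)                   # 4
    print(tp1 / float(pred1))                                      # 0.75

    # Overall precision: total correct over total count.
    print(sum(1 for p, l in pairs if p == l) / float(len(pairs)))  # 0.666...

    # falsePositiveRate(0.0): wrong 0-predictions over non-0 labels.
    fp0 = sum(1 for p, l in pairs if p == 0.0 and l != 0.0)        # 1
    neg0 = sum(1 for _, l in pairs if l != 0.0)                    # 5
    print(fp0 / float(neg0))                                       # 0.2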

