Posted to commits@spark.apache.org by me...@apache.org on 2015/05/10 09:57:23 UTC
spark git commit: [SPARK-6091] [MLLIB] Add MulticlassMetrics in PySpark/MLlib
Repository: spark
Updated Branches:
refs/heads/master b13162b36 -> bf7e81a51
[SPARK-6091] [MLLIB] Add MulticlassMetrics in PySpark/MLlib
https://issues.apache.org/jira/browse/SPARK-6091
Author: Yanbo Liang <yb...@gmail.com>
Closes #6011 from yanboliang/spark-6091 and squashes the following commits:
bb3e4ba [Yanbo Liang] trigger jenkins
53c045d [Yanbo Liang] keep compatibility for python 2.6
972d5ac [Yanbo Liang] Add MulticlassMetrics in PySpark/MLlib
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bf7e81a5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bf7e81a5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bf7e81a5
Branch: refs/heads/master
Commit: bf7e81a51cd81706570615cd67362c86602dec88
Parents: b13162b
Author: Yanbo Liang <yb...@gmail.com>
Authored: Sun May 10 00:57:14 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Sun May 10 00:57:14 2015 -0700
----------------------------------------------------------------------
.../mllib/evaluation/MulticlassMetrics.scala | 8 ++
python/pyspark/mllib/evaluation.py | 129 +++++++++++++++++++
2 files changed, 137 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/bf7e81a5/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
index 666362a..4628dc5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Matrices, Matrix}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
/**
* ::Experimental::
@@ -33,6 +34,13 @@ import org.apache.spark.rdd.RDD
@Experimental
class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
+ /**
+ * An auxiliary constructor taking a DataFrame.
+ * @param predictionAndLabels a DataFrame with two double columns: prediction and label
+ */
+ private[mllib] def this(predictionAndLabels: DataFrame) =
+ this(predictionAndLabels.map(r => (r.getDouble(0), r.getDouble(1))))
+
private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
private lazy val labelCount: Long = labelCountByClass.values.sum
private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
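The new auxiliary constructor exists so that the Python wrapper (added below in evaluation.py) can hand its data across Py4J as a DataFrame rather than as an RDD of Scala tuples. A rough sketch of that hand-off from the Python side, assuming an active SparkContext sc (the variable names are illustrative):

from pyspark.sql import SQLContext
from pyspark.sql.types import StructType, StructField, DoubleType

# Build the two-double-column DataFrame that the auxiliary constructor expects.
sql_ctx = SQLContext(sc)
predictionAndLabels = sc.parallelize([(0.0, 0.0), (1.0, 1.0), (2.0, 0.0)])
df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
    StructField("prediction", DoubleType(), nullable=False),
    StructField("label", DoubleType(), nullable=False)]))

# Reach the JVM-side MulticlassMetrics(DataFrame) constructor through the Py4J gateway.
java_model = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics(df._jdf)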
http://git-wip-us.apache.org/repos/asf/spark/blob/bf7e81a5/python/pyspark/mllib/evaluation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index 3e11df0..3691459 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -141,6 +141,135 @@ class RegressionMetrics(JavaModelWrapper):
return self.call("r2")
+class MulticlassMetrics(JavaModelWrapper):
+ """
+ Evaluator for multiclass classification.
+
+ >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
+ ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
+ >>> metrics = MulticlassMetrics(predictionAndLabels)
+ >>> metrics.falsePositiveRate(0.0)
+ 0.2...
+ >>> metrics.precision(1.0)
+ 0.75...
+ >>> metrics.recall(2.0)
+ 1.0...
+ >>> metrics.fMeasure(0.0, 2.0)
+ 0.52...
+ >>> metrics.precision()
+ 0.66...
+ >>> metrics.recall()
+ 0.66...
+ >>> metrics.weightedFalsePositiveRate
+ 0.19...
+ >>> metrics.weightedPrecision
+ 0.68...
+ >>> metrics.weightedRecall
+ 0.66...
+ >>> metrics.weightedFMeasure()
+ 0.66...
+ >>> metrics.weightedFMeasure(2.0)
+ 0.65...
+ """
+
+ def __init__(self, predictionAndLabels):
+ """
+ :param predictionAndLabels: an RDD of (prediction, label) pairs.
+ """
+ sc = predictionAndLabels.ctx
+ sql_ctx = SQLContext(sc)
+ df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
+ StructField("prediction", DoubleType(), nullable=False),
+ StructField("label", DoubleType(), nullable=False)]))
+ java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
+ java_model = java_class(df._jdf)
+ super(MulticlassMetrics, self).__init__(java_model)
+
+ def truePositiveRate(self, label):
+ """
+ Returns true positive rate for a given label (category).
+ """
+ return self.call("truePositiveRate", label)
+
+ def falsePositiveRate(self, label):
+ """
+ Returns false positive rate for a given label (category).
+ """
+ return self.call("falsePositiveRate", label)
+
+ def precision(self, label=None):
+ """
+ Returns overall precision, or precision for a given label (category) if specified.
+ """
+ if label is None:
+ return self.call("precision")
+ else:
+ return self.call("precision", float(label))
+
+ def recall(self, label=None):
+ """
+ Returns overall recall, or recall for a given label (category) if specified.
+ """
+ if label is None:
+ return self.call("recall")
+ else:
+ return self.call("recall", float(label))
+
+ def fMeasure(self, label=None, beta=None):
+ """
+ Returns overall f-measure, or f-measure for a given label (category) if specified.
+ """
+ if beta is None:
+ if label is None:
+ return self.call("fMeasure")
+ else:
+ return self.call("fMeasure", label)
+ else:
+ if label is None:
+ raise Exception("If the beta parameter is specified, label cannot be None")
+ else:
+ return self.call("fMeasure", label, beta)
+
+ @property
+ def weightedTruePositiveRate(self):
+ """
+ Returns weighted true positive rate.
+ (equal to precision, recall and f-measure)
+ """
+ return self.call("weightedTruePositiveRate")
+
+ @property
+ def weightedFalsePositiveRate(self):
+ """
+ Returns weighted false positive rate.
+ """
+ return self.call("weightedFalsePositiveRate")
+
+ @property
+ def weightedRecall(self):
+ """
+ Returns weighted averaged recall.
+ (equal to precision, recall and f-measure)
+ """
+ return self.call("weightedRecall")
+
+ @property
+ def weightedPrecision(self):
+ """
+ Returns weighted averaged precision.
+ """
+ return self.call("weightedPrecision")
+
+ def weightedFMeasure(self, beta=None):
+ """
+ Returns weighted averaged f-measure.
+ """
+ if beta is None:
+ return self.call("weightedFMeasure")
+ else:
+ return self.call("weightedFMeasure", beta)
+
+
def _test():
import doctest
from pyspark import SparkContext
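Taken together, the new wrapper can be exercised end to end along these lines (a minimal sketch; the application name is illustrative, and the data and approximate results mirror the doctest above):

from pyspark import SparkContext
from pyspark.mllib.evaluation import MulticlassMetrics

sc = SparkContext(appName="MulticlassMetricsExample")

# (prediction, label) pairs, e.g. model predictions zipped with the true labels.
predictionAndLabels = sc.parallelize([
    (0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
    (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0),
    (2.0, 2.0), (2.0, 0.0)])

metrics = MulticlassMetrics(predictionAndLabels)

print(metrics.precision(1.0))         # precision for label 1.0, ~0.75
print(metrics.recall())               # overall recall, ~0.66
print(metrics.weightedFMeasure(2.0))  # weighted F-measure with beta = 2.0, ~0.65

sc.stop()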