You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2019/02/08 17:47:43 UTC

[spark] branch master updated: [SPARK-26185][PYTHON] add weightCol in python MulticlassClassificationEvaluator

This is an automated email from the ASF dual-hosted git repository.

holden pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 91e64e2  [SPARK-26185][PYTHON] add weightCol in python MulticlassClassificationEvaluator
91e64e2 is described below

commit 91e64e24d54287b1e4564358a2ef2bc8c0e6a22b
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Fri Feb 8 09:46:54 2019 -0800

    [SPARK-26185][PYTHON] add weightCol in python MulticlassClassificationEvaluator
    
    ## What changes were proposed in this pull request?
    
    Add a weightCol parameter to the Python versions of MulticlassClassificationEvaluator and MulticlassMetrics.
    
    ## How was this patch tested?
    
    Added doctests covering the new weight column for both MulticlassClassificationEvaluator and MulticlassMetrics.
    
    Closes #23157 from huaxingao/spark-26185.
    
    Authored-by: Huaxin Gao <hu...@us.ibm.com>
    Signed-off-by: Holden Karau <ho...@pigscanfly.ca>
---
 python/pyspark/ml/evaluation.py    | 23 +++++++++++++++------
 python/pyspark/mllib/evaluation.py | 42 +++++++++++++++++++++++++++++++++-----
 2 files changed, 54 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 8eaf076..f563a2d 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -22,7 +22,7 @@ from pyspark import since, keyword_only
 from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol, \
-    HasFeaturesCol
+    HasFeaturesCol, HasWeightCol
 from pyspark.ml.common import inherit_doc
 from pyspark.ml.util import JavaMLReadable, JavaMLWritable
 
@@ -257,7 +257,7 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
 
 
 @inherit_doc
-class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
+class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeightCol,
                                         JavaMLReadable, JavaMLWritable):
     """
     .. note:: Experimental
@@ -279,6 +279,17 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
     >>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path)
     >>> str(evaluator2.getPredictionCol())
     'prediction'
+    >>> scoreAndLabelsAndWeight = [(0.0, 0.0, 1.0), (0.0, 1.0, 1.0), (0.0, 0.0, 1.0),
+    ...     (1.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0),
+    ...     (2.0, 2.0, 1.0), (2.0, 0.0, 1.0)]
+    >>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["prediction", "label", "weight"])
+    ...
+    >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction",
+    ...     weightCol="weight")
+    >>> evaluator.evaluate(dataset)
+    0.66...
+    >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
+    0.66...
 
     .. versionadded:: 1.5.0
     """
@@ -289,10 +300,10 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
 
     @keyword_only
     def __init__(self, predictionCol="prediction", labelCol="label",
-                 metricName="f1"):
+                 metricName="f1", weightCol=None):
         """
         __init__(self, predictionCol="prediction", labelCol="label", \
-                 metricName="f1")
+                 metricName="f1", weightCol=None)
         """
         super(MulticlassClassificationEvaluator, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -318,10 +329,10 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
     @keyword_only
     @since("1.5.0")
     def setParams(self, predictionCol="prediction", labelCol="label",
-                  metricName="f1"):
+                  metricName="f1", weightCol=None):
         """
         setParams(self, predictionCol="prediction", labelCol="label", \
-                  metricName="f1")
+                  metricName="f1", weightCol=None)
         Sets params for multiclass classification evaluator.
         """
         kwargs = self._input_kwargs
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index 6ca6df6..b028394 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -162,7 +162,7 @@ class MulticlassMetrics(JavaModelWrapper):
     """
     Evaluator for multiclass classification.
 
-    :param predictionAndLabels: an RDD of (prediction, label) pairs.
+    :param predAndLabelsWithOptWeight: an RDD of prediction, label and optional weight.
 
     >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
     ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
@@ -191,16 +191,48 @@ class MulticlassMetrics(JavaModelWrapper):
     0.66...
     >>> metrics.weightedFMeasure(2.0)
     0.65...
+    >>> predAndLabelsWithOptWeight = sc.parallelize([(0.0, 0.0, 1.0), (0.0, 1.0, 1.0),
+    ...      (0.0, 0.0, 1.0), (1.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0),
+    ...      (2.0, 2.0, 1.0), (2.0, 0.0, 1.0)])
+    >>> metrics = MulticlassMetrics(predAndLabelsWithOptWeight)
+    >>> metrics.confusionMatrix().toArray()
+    array([[ 2.,  1.,  1.],
+           [ 1.,  3.,  0.],
+           [ 0.,  0.,  1.]])
+    >>> metrics.falsePositiveRate(0.0)
+    0.2...
+    >>> metrics.precision(1.0)
+    0.75...
+    >>> metrics.recall(2.0)
+    1.0...
+    >>> metrics.fMeasure(0.0, 2.0)
+    0.52...
+    >>> metrics.accuracy
+    0.66...
+    >>> metrics.weightedFalsePositiveRate
+    0.19...
+    >>> metrics.weightedPrecision
+    0.68...
+    >>> metrics.weightedRecall
+    0.66...
+    >>> metrics.weightedFMeasure()
+    0.66...
+    >>> metrics.weightedFMeasure(2.0)
+    0.65...
 
     .. versionadded:: 1.4.0
     """
 
-    def __init__(self, predictionAndLabels):
-        sc = predictionAndLabels.ctx
+    def __init__(self, predAndLabelsWithOptWeight):
+        sc = predAndLabelsWithOptWeight.ctx
         sql_ctx = SQLContext.getOrCreate(sc)
-        df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
+        numCol = len(predAndLabelsWithOptWeight.first())
+        schema = StructType([
             StructField("prediction", DoubleType(), nullable=False),
-            StructField("label", DoubleType(), nullable=False)]))
+            StructField("label", DoubleType(), nullable=False)])
+        if (numCol == 3):
+            schema.add("weight", DoubleType(), False)
+        df = sql_ctx.createDataFrame(predAndLabelsWithOptWeight, schema)
         java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
         java_model = java_class(df._jdf)
         super(MulticlassMetrics, self).__init__(java_model)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org