You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2019/02/08 17:47:43 UTC
[spark] branch master updated: [SPARK-26185][PYTHON] add weightCol
in python MulticlassClassificationEvaluator
This is an automated email from the ASF dual-hosted git repository.
holden pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 91e64e2 [SPARK-26185][PYTHON] add weightCol in python MulticlassClassificationEvaluator
91e64e2 is described below
commit 91e64e24d54287b1e4564358a2ef2bc8c0e6a22b
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Fri Feb 8 09:46:54 2019 -0800
[SPARK-26185][PYTHON] add weightCol in python MulticlassClassificationEvaluator
## What changes were proposed in this pull request?
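Add the `weightCol` param to the Python `MulticlassClassificationEvaluator` (pyspark.ml), and let the Python `MulticlassMetrics` (pyspark.mllib) accept an RDD of (prediction, label) or (prediction, label, weight) tuples.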
## How was this patch tested?
Added doctests to the docstrings of both classes that exercise the weighted input.
Closes #23157 from huaxingao/spark-26185.
Authored-by: Huaxin Gao <hu...@us.ibm.com>
Signed-off-by: Holden Karau <ho...@pigscanfly.ca>
---
python/pyspark/ml/evaluation.py | 23 +++++++++++++++------
python/pyspark/mllib/evaluation.py | 42 +++++++++++++++++++++++++++++++++-----
2 files changed, 54 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 8eaf076..f563a2d 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -22,7 +22,7 @@ from pyspark import since, keyword_only
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol, \
- HasFeaturesCol
+ HasFeaturesCol, HasWeightCol
from pyspark.ml.common import inherit_doc
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
@@ -257,7 +257,7 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
@inherit_doc
-class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
+class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeightCol,
JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -279,6 +279,17 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
>>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path)
>>> str(evaluator2.getPredictionCol())
'prediction'
+ >>> scoreAndLabelsAndWeight = [(0.0, 0.0, 1.0), (0.0, 1.0, 1.0), (0.0, 0.0, 1.0),
+ ... (1.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0),
+ ... (2.0, 2.0, 1.0), (2.0, 0.0, 1.0)]
+ >>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["prediction", "label", "weight"])
+ ...
+ >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction",
+ ... weightCol="weight")
+ >>> evaluator.evaluate(dataset)
+ 0.66...
+ >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
+ 0.66...
.. versionadded:: 1.5.0
"""
@@ -289,10 +300,10 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
@keyword_only
def __init__(self, predictionCol="prediction", labelCol="label",
- metricName="f1"):
+ metricName="f1", weightCol=None):
"""
__init__(self, predictionCol="prediction", labelCol="label", \
- metricName="f1")
+ metricName="f1", weightCol=None)
"""
super(MulticlassClassificationEvaluator, self).__init__()
self._java_obj = self._new_java_obj(
@@ -318,10 +329,10 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
@keyword_only
@since("1.5.0")
def setParams(self, predictionCol="prediction", labelCol="label",
- metricName="f1"):
+ metricName="f1", weightCol=None):
"""
setParams(self, predictionCol="prediction", labelCol="label", \
- metricName="f1")
+ metricName="f1", weightCol=None)
Sets params for multiclass classification evaluator.
"""
kwargs = self._input_kwargs
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index 6ca6df6..b028394 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -162,7 +162,7 @@ class MulticlassMetrics(JavaModelWrapper):
"""
Evaluator for multiclass classification.
- :param predictionAndLabels: an RDD of (prediction, label) pairs.
+ :param predAndLabelsWithOptWeight: an RDD of prediction, label and optional weight.
>>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
@@ -191,16 +191,48 @@ class MulticlassMetrics(JavaModelWrapper):
0.66...
>>> metrics.weightedFMeasure(2.0)
0.65...
+ >>> predAndLabelsWithOptWeight = sc.parallelize([(0.0, 0.0, 1.0), (0.0, 1.0, 1.0),
+ ... (0.0, 0.0, 1.0), (1.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0),
+ ... (2.0, 2.0, 1.0), (2.0, 0.0, 1.0)])
+ >>> metrics = MulticlassMetrics(predAndLabelsWithOptWeight)
+ >>> metrics.confusionMatrix().toArray()
+ array([[ 2., 1., 1.],
+ [ 1., 3., 0.],
+ [ 0., 0., 1.]])
+ >>> metrics.falsePositiveRate(0.0)
+ 0.2...
+ >>> metrics.precision(1.0)
+ 0.75...
+ >>> metrics.recall(2.0)
+ 1.0...
+ >>> metrics.fMeasure(0.0, 2.0)
+ 0.52...
+ >>> metrics.accuracy
+ 0.66...
+ >>> metrics.weightedFalsePositiveRate
+ 0.19...
+ >>> metrics.weightedPrecision
+ 0.68...
+ >>> metrics.weightedRecall
+ 0.66...
+ >>> metrics.weightedFMeasure()
+ 0.66...
+ >>> metrics.weightedFMeasure(2.0)
+ 0.65...
.. versionadded:: 1.4.0
"""
- def __init__(self, predictionAndLabels):
- sc = predictionAndLabels.ctx
+ def __init__(self, predAndLabelsWithOptWeight):
+ sc = predAndLabelsWithOptWeight.ctx
sql_ctx = SQLContext.getOrCreate(sc)
- df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
+ numCol = len(predAndLabelsWithOptWeight.first())
+ schema = StructType([
StructField("prediction", DoubleType(), nullable=False),
- StructField("label", DoubleType(), nullable=False)]))
+ StructField("label", DoubleType(), nullable=False)])
+ if (numCol == 3):
+ schema.add("weight", DoubleType(), False)
+ df = sql_ctx.createDataFrame(predAndLabelsWithOptWeight, schema)
java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
java_model = java_class(df._jdf)
super(MulticlassMetrics, self).__init__(java_model)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org