Posted to commits@spark.apache.org by sr...@apache.org on 2020/05/14 15:57:27 UTC
[spark] branch branch-3.0 updated: [SPARK-31681][ML][PYSPARK] Python multiclass logistic regression evaluate should return LogisticRegressionSummary
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 6834f46 [SPARK-31681][ML][PYSPARK] Python multiclass logistic regression evaluate should return LogisticRegressionSummary
6834f46 is described below
commit 6834f4691b3e2603d410bfe24f0db0b6e3a36446
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Thu May 14 10:54:35 2020 -0500
[SPARK-31681][ML][PYSPARK] Python multiclass logistic regression evaluate should return LogisticRegressionSummary
### What changes were proposed in this pull request?
Make `LogisticRegressionModel.evaluate` in PySpark return LogisticRegressionSummary for multiclass logistic regression.
### Why are the changes needed?
Currently `LogisticRegressionModel.evaluate` in PySpark is:
```
@since("2.0.0")
def evaluate(self, dataset):
    if not isinstance(dataset, DataFrame):
        raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
    java_blr_summary = self._call_java("evaluate", dataset)
    return BinaryLogisticRegressionSummary(java_blr_summary)
```
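This always wraps the result in BinaryLogisticRegressionSummary, even for a multiclass model. For multiclass logistic regression, `evaluate` should return LogisticRegressionSummary instead.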
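The selection logic can be sketched without a JVM. This is a minimal sketch, assuming hypothetical stand-in classes (`LogisticRegressionSummaryStub`, `BinaryLogisticRegressionSummaryStub`, `StubLogisticRegressionModel`) and a placeholder `_call_java` in place of the real PySpark objects; only the `numClasses <= 2` check mirrors the actual change, and the DataFrame type check is omitted.

```python
# Hypothetical stand-ins for the PySpark summary classes (no JVM needed).
class LogisticRegressionSummaryStub:
    def __init__(self, java_summary):
        # In PySpark this would be a Py4J handle to the Scala summary.
        self.java_summary = java_summary


class BinaryLogisticRegressionSummaryStub(LogisticRegressionSummaryStub):
    pass


class StubLogisticRegressionModel:
    """Hypothetical model stand-in; only numClasses drives the choice."""

    def __init__(self, num_classes):
        self.numClasses = num_classes

    def _call_java(self, name, dataset):
        # Placeholder for the real Py4J call; returns an opaque value.
        return (name, dataset)

    def evaluate_old(self, dataset):
        # Behavior before the fix: always wrapped as a binary summary.
        java_summary = self._call_java("evaluate", dataset)
        return BinaryLogisticRegressionSummaryStub(java_summary)

    def evaluate_new(self, dataset):
        # Behavior after the fix: pick the wrapper from the class count.
        java_summary = self._call_java("evaluate", dataset)
        if self.numClasses <= 2:
            return BinaryLogisticRegressionSummaryStub(java_summary)
        return LogisticRegressionSummaryStub(java_summary)


multiclass_model = StubLogisticRegressionModel(num_classes=3)
print(type(multiclass_model.evaluate_old("df")).__name__)  # BinaryLogisticRegressionSummaryStub
print(type(multiclass_model.evaluate_new("df")).__name__)  # LogisticRegressionSummaryStub
```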
### Does this PR introduce _any_ user-facing change?
Yes. For multiclass logistic regression in Python, `LogisticRegressionModel.evaluate` now returns LogisticRegressionSummary instead of BinaryLogisticRegressionSummary.
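For callers, the practical consequence is that binary-only metrics such as `areaUnderROC` should be read only after checking the summary type, while metrics like `accuracy` are available on both. A minimal sketch of that guard, using hypothetical stand-in classes and arbitrary placeholder metric values rather than real PySpark objects:

```python
# Hypothetical stand-ins; the real PySpark summaries wrap a JVM object.
class LogisticRegressionSummaryStub:
    def __init__(self, accuracy):
        self.accuracy = accuracy


class BinaryLogisticRegressionSummaryStub(LogisticRegressionSummaryStub):
    def __init__(self, accuracy, area_under_roc):
        super().__init__(accuracy)
        # Binary-only metric, named after the PySpark attribute.
        self.areaUnderROC = area_under_roc


def report_metrics(summary):
    """Collect metrics, reading binary-only ones only from binary summaries."""
    metrics = {"accuracy": summary.accuracy}
    if isinstance(summary, BinaryLogisticRegressionSummaryStub):
        metrics["areaUnderROC"] = summary.areaUnderROC
    return metrics


# Placeholder metric values, not results from any real model.
print(report_metrics(LogisticRegressionSummaryStub(0.5)))               # {'accuracy': 0.5}
print(report_metrics(BinaryLogisticRegressionSummaryStub(0.75, 0.8)))  # {'accuracy': 0.75, 'areaUnderROC': 0.8}
```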
### How was this patch tested?
Unit tests in python/pyspark/ml/tests/test_training_summary.py.
Closes #28503 from huaxingao/lr_summary.
Authored-by: Huaxin Gao <hu...@us.ibm.com>
Signed-off-by: Sean Owen <sr...@gmail.com>
(cherry picked from commit e10516ae63cfc58f2d493e4d3f19940d45c8f033)
Signed-off-by: Sean Owen <sr...@gmail.com>
---
python/pyspark/ml/classification.py | 5 ++++-
python/pyspark/ml/tests/test_training_summary.py | 6 +++++-
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 1436b78..424e16a 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -831,7 +831,10 @@ class LogisticRegressionModel(JavaProbabilisticClassificationModel, _LogisticReg
         if not isinstance(dataset, DataFrame):
             raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
         java_blr_summary = self._call_java("evaluate", dataset)
-        return BinaryLogisticRegressionSummary(java_blr_summary)
+        if self.numClasses <= 2:
+            return BinaryLogisticRegressionSummary(java_blr_summary)
+        else:
+            return LogisticRegressionSummary(java_blr_summary)
 
 
 class LogisticRegressionSummary(JavaWrapper):
diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py
index 1d19ebf..b505409 100644
--- a/python/pyspark/ml/tests/test_training_summary.py
+++ b/python/pyspark/ml/tests/test_training_summary.py
@@ -21,7 +21,8 @@ import unittest
 if sys.version > '3':
     basestring = str
 
-from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.classification import BinaryLogisticRegressionSummary, LogisticRegression, \
+    LogisticRegressionSummary
 from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans
 from pyspark.ml.linalg import Vectors
 from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
@@ -149,6 +150,7 @@ class TrainingSummaryTest(SparkSessionTestCase):
         # test evaluation (with training dataset) produces a summary with same values
         # one check is enough to verify a summary is returned, Scala version runs full test
         sameSummary = model.evaluate(df)
+        self.assertTrue(isinstance(sameSummary, BinaryLogisticRegressionSummary))
         self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
 
     def test_multiclass_logistic_regression_summary(self):
@@ -187,6 +189,8 @@ class TrainingSummaryTest(SparkSessionTestCase):
         # test evaluation (with training dataset) produces a summary with same values
         # one check is enough to verify a summary is returned, Scala version runs full test
         sameSummary = model.evaluate(df)
+        self.assertTrue(isinstance(sameSummary, LogisticRegressionSummary))
+        self.assertFalse(isinstance(sameSummary, BinaryLogisticRegressionSummary))
         self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)
 
     def test_gaussian_mixture_summary(self):
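The multiclass test asserts both `isinstance(sameSummary, LogisticRegressionSummary)` and `not isinstance(sameSummary, BinaryLogisticRegressionSummary)` because, in PySpark, BinaryLogisticRegressionSummary derives from LogisticRegressionSummary, so the first check alone would also pass for a binary summary. The sketch below reproduces that reasoning with hypothetical stand-in classes (`LogisticRegressionSummaryStub`, `BinaryLogisticRegressionSummaryStub`) and a hypothetical helper `is_multiclass_summary`, not the PySpark classes themselves.

```python
# Hypothetical stand-ins mirroring PySpark's hierarchy, in which the
# binary summary class derives from the general summary class.
class LogisticRegressionSummaryStub:
    pass


class BinaryLogisticRegressionSummaryStub(LogisticRegressionSummaryStub):
    pass


def is_multiclass_summary(summary):
    """True only for a plain (non-binary) summary, mirroring the test's two checks."""
    return (isinstance(summary, LogisticRegressionSummaryStub)
            and not isinstance(summary, BinaryLogisticRegressionSummaryStub))


binary = BinaryLogisticRegressionSummaryStub()
multi = LogisticRegressionSummaryStub()

# The base-class check alone cannot tell the two apart.
print(isinstance(binary, LogisticRegressionSummaryStub))  # True
print(isinstance(multi, LogisticRegressionSummaryStub))   # True
# Adding the negative check on the subclass does.
print(is_multiclass_summary(binary), is_multiclass_summary(multi))  # False True
```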