You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2020/05/14 15:57:27 UTC

[spark] branch branch-3.0 updated: [SPARK-31681][ML][PYSPARK] Python multiclass logistic regression evaluate should return LogisticRegressionSummary

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 6834f46  [SPARK-31681][ML][PYSPARK] Python multiclass logistic regression evaluate should return LogisticRegressionSummary
6834f46 is described below

commit 6834f4691b3e2603d410bfe24f0db0b6e3a36446
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Thu May 14 10:54:35 2020 -0500

    [SPARK-31681][ML][PYSPARK] Python multiclass logistic regression evaluate should return LogisticRegressionSummary
    
    ### What changes were proposed in this pull request?
    Return LogisticRegressionSummary for multiclass logistic regression evaluate in PySpark
    
    ### Why are the changes needed?
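    Currently, `LogisticRegressionModel.evaluate` in PySpark always wraps the result in a BinaryLogisticRegressionSummary, regardless of the number of classes: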
    ```
        @since("2.0.0")
        def evaluate(self, dataset):
            if not isinstance(dataset, DataFrame):
                raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
            java_blr_summary = self._call_java("evaluate", dataset)
            return BinaryLogisticRegressionSummary(java_blr_summary)
    ```
    We should return LogisticRegressionSummary instead when the model is a multiclass logistic regression.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. In Python, `evaluate` now returns LogisticRegressionSummary instead of BinaryLogisticRegressionSummary for multiclass logistic regression; models with at most two classes still get BinaryLogisticRegressionSummary.
    
    ### How was this patch tested?
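    Unit tests: the existing binary and multiclass training-summary tests now also assert the type of the summary returned by `evaluate`.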
    
    Closes #28503 from huaxingao/lr_summary.
    
    Authored-by: Huaxin Gao <hu...@us.ibm.com>
    Signed-off-by: Sean Owen <sr...@gmail.com>
    (cherry picked from commit e10516ae63cfc58f2d493e4d3f19940d45c8f033)
    Signed-off-by: Sean Owen <sr...@gmail.com>
---
 python/pyspark/ml/classification.py              | 5 ++++-
 python/pyspark/ml/tests/test_training_summary.py | 6 +++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 1436b78..424e16a 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -831,7 +831,10 @@ class LogisticRegressionModel(JavaProbabilisticClassificationModel, _LogisticReg
         if not isinstance(dataset, DataFrame):
             raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
         java_blr_summary = self._call_java("evaluate", dataset)
-        return BinaryLogisticRegressionSummary(java_blr_summary)
+        if self.numClasses <= 2:
+            return BinaryLogisticRegressionSummary(java_blr_summary)
+        else:
+            return LogisticRegressionSummary(java_blr_summary)
 
 
 class LogisticRegressionSummary(JavaWrapper):
diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py
index 1d19ebf..b505409 100644
--- a/python/pyspark/ml/tests/test_training_summary.py
+++ b/python/pyspark/ml/tests/test_training_summary.py
@@ -21,7 +21,8 @@ import unittest
 if sys.version > '3':
     basestring = str
 
-from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.classification import BinaryLogisticRegressionSummary, LogisticRegression, \
+    LogisticRegressionSummary
 from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans
 from pyspark.ml.linalg import Vectors
 from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
@@ -149,6 +150,7 @@ class TrainingSummaryTest(SparkSessionTestCase):
         # test evaluation (with training dataset) produces a summary with same values
         # one check is enough to verify a summary is returned, Scala version runs full test
         sameSummary = model.evaluate(df)
+        self.assertTrue(isinstance(sameSummary, BinaryLogisticRegressionSummary))
         self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
 
     def test_multiclass_logistic_regression_summary(self):
@@ -187,6 +189,8 @@ class TrainingSummaryTest(SparkSessionTestCase):
         # test evaluation (with training dataset) produces a summary with same values
         # one check is enough to verify a summary is returned, Scala version runs full test
         sameSummary = model.evaluate(df)
+        self.assertTrue(isinstance(sameSummary, LogisticRegressionSummary))
+        self.assertFalse(isinstance(sameSummary, BinaryLogisticRegressionSummary))
         self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)
 
     def test_gaussian_mixture_summary(self):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org