You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2019/02/01 23:30:22 UTC

[spark] branch master updated: [SPARK-26754][PYTHON] Add hasTrainingSummary to replace duplicate code in PySpark

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5bb9647  [SPARK-26754][PYTHON] Add hasTrainingSummary to replace duplicate code in PySpark
5bb9647 is described below

commit 5bb9647e1019ea7eb17af7d2057fdacb7f4c560b
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Fri Feb 1 17:29:58 2019 -0600

    [SPARK-26754][PYTHON] Add hasTrainingSummary to replace duplicate code in PySpark
    
    ## What changes were proposed in this pull request?
    
    Python version of https://github.com/apache/spark/pull/17654
    
    ## How was this patch tested?
    
    Existing Python unit test
    
    Closes #23676 from huaxingao/spark26754.
    
    Authored-by: Huaxin Gao <hu...@us.ibm.com>
    Signed-off-by: Sean Owen <se...@databricks.com>
---
 python/pyspark/ml/classification.py | 19 ++++++-------------
 python/pyspark/ml/clustering.py     | 37 ++++++-------------------------------
 python/pyspark/ml/regression.py     | 30 ++++++------------------------
 python/pyspark/ml/util.py           | 26 ++++++++++++++++++++++++++
 4 files changed, 44 insertions(+), 68 deletions(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 89b9278..134b9e0 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -483,7 +483,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
         return self.getOrDefault(self.upperBoundsOnIntercepts)
 
 
-class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
+class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable,
+                              HasTrainingSummary):
     """
     Model fitted by LogisticRegression.
 
@@ -532,24 +533,16 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable
         trained on the training set. An exception is thrown if `trainingSummary is None`.
         """
         if self.hasSummary:
-            java_lrt_summary = self._call_java("summary")
             if self.numClasses <= 2:
-                return BinaryLogisticRegressionTrainingSummary(java_lrt_summary)
+                return BinaryLogisticRegressionTrainingSummary(super(LogisticRegressionModel,
+                                                                     self).summary)
             else:
-                return LogisticRegressionTrainingSummary(java_lrt_summary)
+                return LogisticRegressionTrainingSummary(super(LogisticRegressionModel,
+                                                               self).summary)
         else:
             raise RuntimeError("No training summary available for this %s" %
                                self.__class__.__name__)
 
-    @property
-    @since("2.0.0")
-    def hasSummary(self):
-        """
-        Indicates whether a training summary exists for this model
-        instance.
-        """
-        return self._call_java("hasSummary")
-
     @since("2.0.0")
     def evaluate(self, dataset):
         """
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index b9c6bdf..864e2a3 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -97,7 +97,7 @@ class ClusteringSummary(JavaWrapper):
         return self._call_java("numIter")
 
 
-class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable, HasTrainingSummary):
     """
     Model fitted by GaussianMixture.
 
@@ -126,22 +126,13 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
 
     @property
     @since("2.1.0")
-    def hasSummary(self):
-        """
-        Indicates whether a training summary exists for this model
-        instance.
-        """
-        return self._call_java("hasSummary")
-
-    @property
-    @since("2.1.0")
     def summary(self):
         """
         Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
         training set. An exception is thrown if no summary exists.
         """
         if self.hasSummary:
-            return GaussianMixtureSummary(self._call_java("summary"))
+            return GaussianMixtureSummary(super(GaussianMixtureModel, self).summary)
         else:
             raise RuntimeError("No training summary available for this %s" %
                                self.__class__.__name__)
@@ -323,7 +314,7 @@ class KMeansSummary(ClusteringSummary):
         return self._call_java("trainingCost")
 
 
-class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable):
+class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable, HasTrainingSummary):
     """
     Model fitted by KMeans.
 
@@ -337,21 +328,13 @@ class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable):
 
     @property
     @since("2.1.0")
-    def hasSummary(self):
-        """
-        Indicates whether a training summary exists for this model instance.
-        """
-        return self._call_java("hasSummary")
-
-    @property
-    @since("2.1.0")
     def summary(self):
         """
         Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
         training set. An exception is thrown if no summary exists.
         """
         if self.hasSummary:
-            return KMeansSummary(self._call_java("summary"))
+            return KMeansSummary(super(KMeansModel, self).summary)
         else:
             raise RuntimeError("No training summary available for this %s" %
                                self.__class__.__name__)
@@ -507,7 +490,7 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol
         return self.getOrDefault(self.distanceMeasure)
 
 
-class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable, HasTrainingSummary):
     """
     Model fitted by BisectingKMeans.
 
@@ -536,21 +519,13 @@ class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
 
     @property
     @since("2.1.0")
-    def hasSummary(self):
-        """
-        Indicates whether a training summary exists for this model instance.
-        """
-        return self._call_java("hasSummary")
-
-    @property
-    @since("2.1.0")
     def summary(self):
         """
         Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
         training set. An exception is thrown if no summary exists.
         """
         if self.hasSummary:
-            return BisectingKMeansSummary(self._call_java("summary"))
+            return BisectingKMeansSummary(super(BisectingKMeansModel, self).summary)
         else:
             raise RuntimeError("No training summary available for this %s" %
                                self.__class__.__name__)
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 9e1f8f8..7841de9 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -161,7 +161,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
         return self.getOrDefault(self.epsilon)
 
 
-class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable):
+class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable,
+                            HasTrainingSummary):
     """
     Model fitted by :class:`LinearRegression`.
 
@@ -201,21 +202,11 @@ class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritabl
         `trainingSummary is None`.
         """
         if self.hasSummary:
-            java_lrt_summary = self._call_java("summary")
-            return LinearRegressionTrainingSummary(java_lrt_summary)
+            return LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)
         else:
             raise RuntimeError("No training summary available for this %s" %
                                self.__class__.__name__)
 
-    @property
-    @since("2.0.0")
-    def hasSummary(self):
-        """
-        Indicates whether a training summary exists for this model
-        instance.
-        """
-        return self._call_java("hasSummary")
-
     @since("2.0.0")
     def evaluate(self, dataset):
         """
@@ -1648,7 +1639,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
 
 
 class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
-                                       JavaMLReadable):
+                                       JavaMLReadable, HasTrainingSummary):
     """
     .. note:: Experimental
 
@@ -1682,21 +1673,12 @@ class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWri
         `trainingSummary is None`.
         """
         if self.hasSummary:
-            java_glrt_summary = self._call_java("summary")
-            return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary)
+            return GeneralizedLinearRegressionTrainingSummary(
+                super(GeneralizedLinearRegressionModel, self).summary)
         else:
             raise RuntimeError("No training summary available for this %s" %
                                self.__class__.__name__)
 
-    @property
-    @since("2.0.0")
-    def hasSummary(self):
-        """
-        Indicates whether a training summary exists for this model
-        instance.
-        """
-        return self._call_java("hasSummary")
-
     @since("2.0.0")
     def evaluate(self, dataset):
         """
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index e846834..e184e1a 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -611,3 +611,29 @@ class DefaultParamsReader(MLReader):
         py_type = DefaultParamsReader.__get_class(pythonClassName)
         instance = py_type.load(path)
         return instance
+
+
+@inherit_doc
+class HasTrainingSummary(object):
+    """
+    Base class for models that provide a training summary.
+    .. versionadded:: 3.0.0
+    """
+
+    @property
+    @since("2.1.0")
+    def hasSummary(self):
+        """
+        Indicates whether a training summary exists for this model
+        instance.
+        """
+        return self._call_java("hasSummary")
+
+    @property
+    @since("2.1.0")
+    def summary(self):
+        """
+        Gets summary of the model trained on the training set. An exception is thrown if
+        no summary exists.
+        """
+        return (self._call_java("summary"))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org