You are viewing a plain-text version of this content. The canonical link to it (originally shown as "here") was not preserved in this copy.
Posted to commits@spark.apache.org by sr...@apache.org on 2019/02/01 23:30:22 UTC
[spark] branch master updated: [SPARK-26754][PYTHON] Add hasTrainingSummary to replace duplicate code in PySpark
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5bb9647 [SPARK-26754][PYTHON] Add hasTrainingSummary to replace duplicate code in PySpark
5bb9647 is described below
commit 5bb9647e1019ea7eb17af7d2057fdacb7f4c560b
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Fri Feb 1 17:29:58 2019 -0600
[SPARK-26754][PYTHON] Add hasTrainingSummary to replace duplicate code in PySpark
## What changes were proposed in this pull request?
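Python version of https://github.com/apache/spark/pull/17654. This adds a `HasTrainingSummary` mixin to `pyspark.ml.util` that provides the `hasSummary` and `summary` properties. The models that previously duplicated these properties (LogisticRegressionModel, GaussianMixtureModel, KMeansModel, BisectingKMeansModel, LinearRegressionModel, GeneralizedLinearRegressionModel) now inherit them from the mixin.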
## How was this patch tested?
Existing Python unit tests.
Closes #23676 from huaxingao/spark26754.
Authored-by: Huaxin Gao <hu...@us.ibm.com>
Signed-off-by: Sean Owen <se...@databricks.com>
---
python/pyspark/ml/classification.py | 19 ++++++-------------
python/pyspark/ml/clustering.py | 37 ++++++-------------------------------
python/pyspark/ml/regression.py | 30 ++++++------------------------
python/pyspark/ml/util.py | 26 ++++++++++++++++++++++++++
4 files changed, 44 insertions(+), 68 deletions(-)
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 89b9278..134b9e0 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -483,7 +483,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
return self.getOrDefault(self.upperBoundsOnIntercepts)
-class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
+class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable,
+ HasTrainingSummary):
"""
Model fitted by LogisticRegression.
@@ -532,24 +533,16 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
- java_lrt_summary = self._call_java("summary")
if self.numClasses <= 2:
- return BinaryLogisticRegressionTrainingSummary(java_lrt_summary)
+ return BinaryLogisticRegressionTrainingSummary(super(LogisticRegressionModel,
+ self).summary)
else:
- return LogisticRegressionTrainingSummary(java_lrt_summary)
+ return LogisticRegressionTrainingSummary(super(LogisticRegressionModel,
+ self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
- @property
- @since("2.0.0")
- def hasSummary(self):
- """
- Indicates whether a training summary exists for this model
- instance.
- """
- return self._call_java("hasSummary")
-
@since("2.0.0")
def evaluate(self, dataset):
"""
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index b9c6bdf..864e2a3 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -97,7 +97,7 @@ class ClusteringSummary(JavaWrapper):
return self._call_java("numIter")
-class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable, HasTrainingSummary):
"""
Model fitted by GaussianMixture.
@@ -126,22 +126,13 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
@property
@since("2.1.0")
- def hasSummary(self):
- """
- Indicates whether a training summary exists for this model
- instance.
- """
- return self._call_java("hasSummary")
-
- @property
- @since("2.1.0")
def summary(self):
"""
Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
- return GaussianMixtureSummary(self._call_java("summary"))
+ return GaussianMixtureSummary(super(GaussianMixtureModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
@@ -323,7 +314,7 @@ class KMeansSummary(ClusteringSummary):
return self._call_java("trainingCost")
-class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable):
+class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable, HasTrainingSummary):
"""
Model fitted by KMeans.
@@ -337,21 +328,13 @@ class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable):
@property
@since("2.1.0")
- def hasSummary(self):
- """
- Indicates whether a training summary exists for this model instance.
- """
- return self._call_java("hasSummary")
-
- @property
- @since("2.1.0")
def summary(self):
"""
Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
- return KMeansSummary(self._call_java("summary"))
+ return KMeansSummary(super(KMeansModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
@@ -507,7 +490,7 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol
return self.getOrDefault(self.distanceMeasure)
-class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable, HasTrainingSummary):
"""
Model fitted by BisectingKMeans.
@@ -536,21 +519,13 @@ class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
@property
@since("2.1.0")
- def hasSummary(self):
- """
- Indicates whether a training summary exists for this model instance.
- """
- return self._call_java("hasSummary")
-
- @property
- @since("2.1.0")
def summary(self):
"""
Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
- return BisectingKMeansSummary(self._call_java("summary"))
+ return BisectingKMeansSummary(super(BisectingKMeansModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 9e1f8f8..7841de9 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -161,7 +161,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
return self.getOrDefault(self.epsilon)
-class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable):
+class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable,
+ HasTrainingSummary):
"""
Model fitted by :class:`LinearRegression`.
@@ -201,21 +202,11 @@ class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritabl
`trainingSummary is None`.
"""
if self.hasSummary:
- java_lrt_summary = self._call_java("summary")
- return LinearRegressionTrainingSummary(java_lrt_summary)
+ return LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
- @property
- @since("2.0.0")
- def hasSummary(self):
- """
- Indicates whether a training summary exists for this model
- instance.
- """
- return self._call_java("hasSummary")
-
@since("2.0.0")
def evaluate(self, dataset):
"""
@@ -1648,7 +1639,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
- JavaMLReadable):
+ JavaMLReadable, HasTrainingSummary):
"""
.. note:: Experimental
@@ -1682,21 +1673,12 @@ class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWri
`trainingSummary is None`.
"""
if self.hasSummary:
- java_glrt_summary = self._call_java("summary")
- return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary)
+ return GeneralizedLinearRegressionTrainingSummary(
+ super(GeneralizedLinearRegressionModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
- @property
- @since("2.0.0")
- def hasSummary(self):
- """
- Indicates whether a training summary exists for this model
- instance.
- """
- return self._call_java("hasSummary")
-
@since("2.0.0")
def evaluate(self, dataset):
"""
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index e846834..e184e1a 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -611,3 +611,29 @@ class DefaultParamsReader(MLReader):
py_type = DefaultParamsReader.__get_class(pythonClassName)
instance = py_type.load(path)
return instance
+
+
+@inherit_doc
+class HasTrainingSummary(object):
+ """
+ Base class for models that provide a training summary.
+ .. versionadded:: 3.0.0
+ """
+
+ @property
+ @since("2.1.0")
+ def hasSummary(self):
+ """
+ Indicates whether a training summary exists for this model
+ instance.
+ """
+ return self._call_java("hasSummary")
+
+ @property
+ @since("2.1.0")
+ def summary(self):
+ """
+ Gets the summary of the model trained on the training set. An exception is thrown if
+ no summary exists.
+ """
+ return (self._call_java("summary"))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org