You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2020/03/04 04:21:00 UTC
[spark] branch master updated: [SPARK-29212][ML][PYSPARK] Add
common classes without using JVM backend
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e1b3e9a [SPARK-29212][ML][PYSPARK] Add common classes without using JVM backend
e1b3e9a is described below
commit e1b3e9a3d25978dc0ad4609ecbc157ea1eebe2dd
Author: zero323 <ms...@gmail.com>
AuthorDate: Wed Mar 4 12:20:02 2020 +0800
[SPARK-29212][ML][PYSPARK] Add common classes without using JVM backend
### What changes were proposed in this pull request?
Implement common base ML classes (`Predictor`, `PredictionModel`, `Classifier`, `ClassificationModel`, `ProbabilisticClassifier`, `ProbabilisticClassificationModel`, `Regressor`, `RegressionModel`) for non-Java backends.
Note
- `Predictor` and `JavaClassifier` should be abstract, as the `_fit` method is not implemented.
- `PredictionModel` should be abstract, as the `_transform` method is not implemented.
### Why are the changes needed?
To provide extension points for non-JVM algorithms, as well as a public hierarchy (as opposed to the `Java*` variants, which are commonly described in docstrings as private) which can be used to distinguish between different classes of predictors.
For longer discussion see [SPARK-29212](https://issues.apache.org/jira/browse/SPARK-29212) and / or https://github.com/apache/spark/pull/25776.
### Does this PR introduce any user-facing change?
It adds new base classes as listed above, but effective interfaces (method resolution order notwithstanding) stay the same.
Additionally, "private" `Java*` classes in `ml.regression` and `ml.classification` have been renamed to follow PEP-8 conventions (added leading underscore).
It is open for discussion whether the same should be done to the equivalent classes from `ml.wrapper`.
If we take `JavaClassifier` as an example, type hierarchy will change from
![old pyspark ml classification JavaClassifier](https://user-images.githubusercontent.com/1554276/72657093-5c0b0c80-39a0-11ea-9069-a897d75de483.png)
to
![new pyspark ml classification _JavaClassifier](https://user-images.githubusercontent.com/1554276/72657098-64fbde00-39a0-11ea-8f80-01187a5ea5a6.png)
Similarly the old model
![old pyspark ml classification JavaClassificationModel](https://user-images.githubusercontent.com/1554276/72657103-7513bd80-39a0-11ea-9ffc-59eb6ab61fde.png)
will become
![new pyspark ml classification _JavaClassificationModel](https://user-images.githubusercontent.com/1554276/72657110-80ff7f80-39a0-11ea-9f5c-fe408664e827.png)
### How was this patch tested?
Existing unit tests.
Closes #27245 from zero323/SPARK-29212.
Authored-by: zero323 <ms...@gmail.com>
Signed-off-by: zhengruifeng <ru...@foxmail.com>
---
python/pyspark/ml/__init__.py | 6 +-
python/pyspark/ml/base.py | 81 ++++++++++++++++-
python/pyspark/ml/classification.py | 158 +++++++++++++++++++++++++---------
python/pyspark/ml/regression.py | 71 ++++++++++-----
python/pyspark/ml/tests/test_param.py | 6 +-
python/pyspark/ml/wrapper.py | 52 ++---------
6 files changed, 258 insertions(+), 116 deletions(-)
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
index d99a253..47fc78e 100644
--- a/python/pyspark/ml/__init__.py
+++ b/python/pyspark/ml/__init__.py
@@ -19,13 +19,15 @@
DataFrame-based machine learning APIs to let users quickly assemble and configure practical
machine learning pipelines.
"""
-from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer
+from pyspark.ml.base import Estimator, Model, Predictor, PredictionModel, \
+ Transformer, UnaryTransformer
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml import classification, clustering, evaluation, feature, fpm, \
image, pipeline, recommendation, regression, stat, tuning, util, linalg, param
__all__ = [
- "Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel",
+ "Transformer", "UnaryTransformer", "Estimator", "Model",
+ "Predictor", "PredictionModel", "Pipeline", "PipelineModel",
"classification", "clustering", "evaluation", "feature", "fpm", "image",
"recommendation", "regression", "stat", "tuning", "util", "linalg", "param",
]
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index 542cb25..b8df5a3 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
-from abc import ABCMeta, abstractmethod
+from abc import ABCMeta, abstractmethod, abstractproperty
import copy
import threading
@@ -246,3 +246,82 @@ class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
transformedDataset = dataset.withColumn(self.getOutputCol(),
transformUDF(dataset[self.getInputCol()]))
return transformedDataset
+
+
+@inherit_doc
+class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
+ """
+ Params for :py:class:`Predictor` and :py:class:`PredictorModel`.
+
+ .. versionadded:: 3.0.0
+ """
+ pass
+
+
+@inherit_doc
+class Predictor(Estimator, _PredictorParams):
+ """
+ Estimator for prediction tasks (regression and classification).
+ """
+
+ __metaclass__ = ABCMeta
+
+ @since("3.0.0")
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
+ @since("3.0.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("3.0.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+
+@inherit_doc
+class PredictionModel(Model, _PredictorParams):
+ """
+ Model for prediction tasks (regression and classification).
+ """
+
+ __metaclass__ = ABCMeta
+
+ @since("3.0.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("3.0.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ @abstractproperty
+ @since("2.1.0")
+ def numFeatures(self):
+ """
+ Returns the number of features the model was trained on. If unknown, returns -1
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ @since("3.0.0")
+ def predict(self, value):
+ """
+ Predict label for the given features.
+ """
+ raise NotImplementedError()
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 1436b78..0d88aa8 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -17,18 +17,20 @@
import operator
import sys
+from abc import ABCMeta, abstractmethod, abstractproperty
from multiprocessing.pool import ThreadPool
from pyspark import since, keyword_only
-from pyspark.ml import Estimator, Model
+from pyspark.ml import Estimator, Predictor, PredictionModel, Model
from pyspark.ml.param.shared import *
from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
_TreeEnsembleModel, _RandomForestParams, _GBTParams, \
_HasVarianceImpurity, _TreeClassifierParams, _TreeEnsembleParams
from pyspark.ml.regression import _FactorizationMachinesParams, DecisionTreeRegressionModel
from pyspark.ml.util import *
+from pyspark.ml.base import _PredictorParams
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
- JavaPredictor, _JavaPredictorParams, JavaPredictionModel, JavaWrapper
+ JavaPredictor, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc, _java2py, _py2java
from pyspark.ml.linalg import Vectors
from pyspark.sql import DataFrame
@@ -49,9 +51,9 @@ __all__ = ['LinearSVC', 'LinearSVCModel',
'FMClassifier', 'FMClassificationModel']
-class _JavaClassifierParams(HasRawPredictionCol, _JavaPredictorParams):
+class _ClassifierParams(HasRawPredictionCol, _PredictorParams):
"""
- Java Classifier Params for classification tasks.
+ Classifier Params for classification tasks.
.. versionadded:: 3.0.0
"""
@@ -59,12 +61,14 @@ class _JavaClassifierParams(HasRawPredictionCol, _JavaPredictorParams):
@inherit_doc
-class JavaClassifier(JavaPredictor, _JavaClassifierParams):
+class Classifier(Predictor, _ClassifierParams):
"""
- Java Classifier for classification tasks.
+ Classifier for classification tasks.
Classes are indexed {0, 1, ..., numClasses - 1}.
"""
+ __metaclass__ = ABCMeta
+
@since("3.0.0")
def setRawPredictionCol(self, value):
"""
@@ -74,13 +78,14 @@ class JavaClassifier(JavaPredictor, _JavaClassifierParams):
@inherit_doc
-class JavaClassificationModel(JavaPredictionModel, _JavaClassifierParams):
+class ClassificationModel(PredictionModel, _ClassifierParams):
"""
- Java Model produced by a ``Classifier``.
+ Model produced by a ``Classifier``.
Classes are indexed {0, 1, ..., numClasses - 1}.
- To be mixed in with class:`pyspark.ml.JavaModel`
"""
+ __metaclass__ = ABCMeta
+
@since("3.0.0")
def setRawPredictionCol(self, value):
"""
@@ -88,26 +93,27 @@ class JavaClassificationModel(JavaPredictionModel, _JavaClassifierParams):
"""
return self._set(rawPredictionCol=value)
- @property
+ @abstractproperty
@since("2.1.0")
def numClasses(self):
"""
Number of classes (values which the label can take).
"""
- return self._call_java("numClasses")
+ raise NotImplementedError()
+ @abstractmethod
@since("3.0.0")
def predictRaw(self, value):
"""
Raw prediction for each possible label.
"""
- return self._call_java("predictRaw", value)
+ raise NotImplementedError()
-class _JavaProbabilisticClassifierParams(HasProbabilityCol, HasThresholds, _JavaClassifierParams):
+class _ProbabilisticClassifierParams(HasProbabilityCol, HasThresholds, _ClassifierParams):
"""
- Params for :py:class:`JavaProbabilisticClassifier` and
- :py:class:`JavaProbabilisticClassificationModel`.
+ Params for :py:class:`ProbabilisticClassifier` and
+ :py:class:`ProbabilisticClassificationModel`.
.. versionadded:: 3.0.0
"""
@@ -115,11 +121,13 @@ class _JavaProbabilisticClassifierParams(HasProbabilityCol, HasThresholds, _Java
@inherit_doc
-class JavaProbabilisticClassifier(JavaClassifier, _JavaProbabilisticClassifierParams):
+class ProbabilisticClassifier(Classifier, _ProbabilisticClassifierParams):
"""
- Java Probabilistic Classifier for classification tasks.
+ Probabilistic Classifier for classification tasks.
"""
+ __metaclass__ = ABCMeta
+
@since("3.0.0")
def setProbabilityCol(self, value):
"""
@@ -136,12 +144,14 @@ class JavaProbabilisticClassifier(JavaClassifier, _JavaProbabilisticClassifierPa
@inherit_doc
-class JavaProbabilisticClassificationModel(JavaClassificationModel,
- _JavaProbabilisticClassifierParams):
+class ProbabilisticClassificationModel(ClassificationModel,
+ _ProbabilisticClassifierParams):
"""
- Java Model produced by a ``ProbabilisticClassifier``.
+ Model produced by a ``ProbabilisticClassifier``.
"""
+ __metaclass__ = ABCMeta
+
@since("3.0.0")
def setProbabilityCol(self, value):
"""
@@ -156,6 +166,72 @@ class JavaProbabilisticClassificationModel(JavaClassificationModel,
"""
return self._set(thresholds=value)
+ @abstractmethod
+ @since("3.0.0")
+ def predictProbability(self, value):
+ """
+ Predict the probability of each class given the features.
+ """
+ raise NotImplementedError()
+
+
+@inherit_doc
+class _JavaClassifier(Classifier, JavaPredictor):
+ """
+ Java Classifier for classification tasks.
+ Classes are indexed {0, 1, ..., numClasses - 1}.
+ """
+
+ __metaclass__ = ABCMeta
+
+ @since("3.0.0")
+ def setRawPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`rawPredictionCol`.
+ """
+ return self._set(rawPredictionCol=value)
+
+
+@inherit_doc
+class _JavaClassificationModel(ClassificationModel, JavaPredictionModel):
+ """
+ Java Model produced by a ``Classifier``.
+ Classes are indexed {0, 1, ..., numClasses - 1}.
+ To be mixed in with class:`pyspark.ml.JavaModel`
+ """
+
+ @property
+ @since("2.1.0")
+ def numClasses(self):
+ """
+ Number of classes (values which the label can take).
+ """
+ return self._call_java("numClasses")
+
+ @since("3.0.0")
+ def predictRaw(self, value):
+ """
+ Raw prediction for each possible label.
+ """
+ return self._call_java("predictRaw", value)
+
+
+@inherit_doc
+class _JavaProbabilisticClassifier(ProbabilisticClassifier, _JavaClassifier):
+ """
+ Java Probabilistic Classifier for classification tasks.
+ """
+
+ __metaclass__ = ABCMeta
+
+
+@inherit_doc
+class _JavaProbabilisticClassificationModel(ProbabilisticClassificationModel,
+ _JavaClassificationModel):
+ """
+ Java Model produced by a ``ProbabilisticClassifier``.
+ """
+
@since("3.0.0")
def predictProbability(self, value):
"""
@@ -164,7 +240,7 @@ class JavaProbabilisticClassificationModel(JavaClassificationModel,
return self._call_java("predictProbability", value)
-class _LinearSVCParams(_JavaClassifierParams, HasRegParam, HasMaxIter, HasFitIntercept, HasTol,
+class _LinearSVCParams(_ClassifierParams, HasRegParam, HasMaxIter, HasFitIntercept, HasTol,
HasStandardization, HasWeightCol, HasAggregationDepth, HasThreshold):
"""
Params for :py:class:`LinearSVC` and :py:class:`LinearSVCModel`.
@@ -180,7 +256,7 @@ class _LinearSVCParams(_JavaClassifierParams, HasRegParam, HasMaxIter, HasFitInt
@inherit_doc
-class LinearSVC(JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadable):
+class LinearSVC(_JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadable):
"""
`Linear SVM Classifier <https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM>`_
@@ -343,7 +419,7 @@ class LinearSVC(JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadable
return self._set(aggregationDepth=value)
-class LinearSVCModel(JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable):
+class LinearSVCModel(_JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable):
"""
Model fitted by LinearSVC.
@@ -374,7 +450,7 @@ class LinearSVCModel(JavaClassificationModel, _LinearSVCParams, JavaMLWritable,
return self._call_java("intercept")
-class _LogisticRegressionParams(_JavaProbabilisticClassifierParams, HasRegParam,
+class _LogisticRegressionParams(_ProbabilisticClassifierParams, HasRegParam,
HasElasticNetParam, HasMaxIter, HasFitIntercept, HasTol,
HasStandardization, HasWeightCol, HasAggregationDepth,
HasThreshold):
@@ -533,7 +609,7 @@ class _LogisticRegressionParams(_JavaProbabilisticClassifierParams, HasRegParam,
@inherit_doc
-class LogisticRegression(JavaProbabilisticClassifier, _LogisticRegressionParams, JavaMLWritable,
+class LogisticRegression(_JavaProbabilisticClassifier, _LogisticRegressionParams, JavaMLWritable,
JavaMLReadable):
"""
Logistic regression.
@@ -759,7 +835,7 @@ class LogisticRegression(JavaProbabilisticClassifier, _LogisticRegressionParams,
return self._set(aggregationDepth=value)
-class LogisticRegressionModel(JavaProbabilisticClassificationModel, _LogisticRegressionParams,
+class LogisticRegressionModel(_JavaProbabilisticClassificationModel, _LogisticRegressionParams,
JavaMLWritable, JavaMLReadable, HasTrainingSummary):
"""
Model fitted by LogisticRegression.
@@ -1131,7 +1207,7 @@ class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams):
@inherit_doc
-class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifierParams,
+class DecisionTreeClassifier(_JavaProbabilisticClassifier, _DecisionTreeClassifierParams,
JavaMLWritable, JavaMLReadable):
"""
`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
@@ -1326,7 +1402,7 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifie
@inherit_doc
-class DecisionTreeClassificationModel(_DecisionTreeModel, JavaProbabilisticClassificationModel,
+class DecisionTreeClassificationModel(_DecisionTreeModel, _JavaProbabilisticClassificationModel,
_DecisionTreeClassifierParams, JavaMLWritable,
JavaMLReadable):
"""
@@ -1366,7 +1442,7 @@ class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams):
@inherit_doc
-class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifierParams,
+class RandomForestClassifier(_JavaProbabilisticClassifier, _RandomForestClassifierParams,
JavaMLWritable, JavaMLReadable):
"""
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
@@ -1585,7 +1661,7 @@ class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifie
return self._set(minWeightFractionPerNode=value)
-class RandomForestClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
+class RandomForestClassificationModel(_TreeEnsembleModel, _JavaProbabilisticClassificationModel,
_RandomForestClassifierParams, JavaMLWritable,
JavaMLReadable):
"""
@@ -1639,7 +1715,7 @@ class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
@inherit_doc
-class GBTClassifier(JavaProbabilisticClassifier, _GBTClassifierParams,
+class GBTClassifier(_JavaProbabilisticClassifier, _GBTClassifierParams,
JavaMLWritable, JavaMLReadable):
"""
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
@@ -1904,7 +1980,7 @@ class GBTClassifier(JavaProbabilisticClassifier, _GBTClassifierParams,
return self._set(minWeightFractionPerNode=value)
-class GBTClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
+class GBTClassificationModel(_TreeEnsembleModel, _JavaProbabilisticClassificationModel,
_GBTClassifierParams, JavaMLWritable, JavaMLReadable):
"""
Model fitted by GBTClassifier.
@@ -1945,7 +2021,7 @@ class GBTClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassification
return self._call_java("evaluateEachIteration", dataset)
-class _NaiveBayesParams(_JavaPredictorParams, HasWeightCol):
+class _NaiveBayesParams(_PredictorParams, HasWeightCol):
"""
Params for :py:class:`NaiveBayes` and :py:class:`NaiveBayesModel`.
@@ -1975,7 +2051,7 @@ class _NaiveBayesParams(_JavaPredictorParams, HasWeightCol):
@inherit_doc
-class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds, HasWeightCol,
+class NaiveBayes(_JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds, HasWeightCol,
JavaMLWritable, JavaMLReadable):
"""
Naive Bayes Classifiers.
@@ -2119,7 +2195,7 @@ class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
return self._set(weightCol=value)
-class NaiveBayesModel(JavaProbabilisticClassificationModel, _NaiveBayesParams, JavaMLWritable,
+class NaiveBayesModel(_JavaProbabilisticClassificationModel, _NaiveBayesParams, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by NaiveBayes.
@@ -2152,7 +2228,7 @@ class NaiveBayesModel(JavaProbabilisticClassificationModel, _NaiveBayesParams, J
return self._call_java("sigma")
-class _MultilayerPerceptronParams(_JavaProbabilisticClassifierParams, HasSeed, HasMaxIter,
+class _MultilayerPerceptronParams(_ProbabilisticClassifierParams, HasSeed, HasMaxIter,
HasTol, HasStepSize, HasSolver, HasBlockSize):
"""
Params for :py:class:`MultilayerPerceptronClassifier`.
@@ -2185,7 +2261,7 @@ class _MultilayerPerceptronParams(_JavaProbabilisticClassifierParams, HasSeed, H
@inherit_doc
-class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPerceptronParams,
+class MultilayerPerceptronClassifier(_JavaProbabilisticClassifier, _MultilayerPerceptronParams,
JavaMLWritable, JavaMLReadable):
"""
Classifier trainer based on the Multilayer Perceptron.
@@ -2348,7 +2424,7 @@ class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPer
return self._set(solver=value)
-class MultilayerPerceptronClassificationModel(JavaProbabilisticClassificationModel,
+class MultilayerPerceptronClassificationModel(_JavaProbabilisticClassificationModel,
_MultilayerPerceptronParams, JavaMLWritable,
JavaMLReadable):
"""
@@ -2366,7 +2442,7 @@ class MultilayerPerceptronClassificationModel(JavaProbabilisticClassificationMod
return self._call_java("weights")
-class _OneVsRestParams(_JavaClassifierParams, HasWeightCol):
+class _OneVsRestParams(_ClassifierParams, HasWeightCol):
"""
Params for :py:class:`OneVsRest` and :py:class:`OneVsRestModelModel`.
"""
@@ -2802,7 +2878,7 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
@inherit_doc
-class FMClassifier(JavaProbabilisticClassifier, _FactorizationMachinesParams, JavaMLWritable,
+class FMClassifier(_JavaProbabilisticClassifier, _FactorizationMachinesParams, JavaMLWritable,
JavaMLReadable):
"""
Factorization Machines learning algorithm for classification.
@@ -2973,7 +3049,7 @@ class FMClassifier(JavaProbabilisticClassifier, _FactorizationMachinesParams, Ja
return self._set(regParam=value)
-class FMClassificationModel(JavaProbabilisticClassificationModel, _FactorizationMachinesParams,
+class FMClassificationModel(_JavaProbabilisticClassificationModel, _FactorizationMachinesParams,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`FMClassifier`.
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index a4c9782..f227fe0 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -16,15 +16,18 @@
#
import sys
+from abc import ABCMeta
from pyspark import since, keyword_only
+from pyspark.ml import Predictor, PredictionModel
+from pyspark.ml.base import _PredictorParams
from pyspark.ml.param.shared import *
from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
_TreeEnsembleModel, _TreeEnsembleParams, _RandomForestParams, _GBTParams, \
_HasVarianceImpurity, _TreeRegressorParams
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
- JavaPredictor, JavaPredictionModel, _JavaPredictorParams, JavaWrapper
+ JavaPredictor, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc
from pyspark.sql import DataFrame
@@ -41,26 +44,48 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
'FMRegressor', 'FMRegressionModel']
-class JavaRegressor(JavaPredictor, _JavaPredictorParams):
+class Regressor(Predictor, _PredictorParams):
+ """
+ Regressor for regression tasks.
+
+ .. versionadded:: 3.0.0
+ """
+
+ __metaclass__ = ABCMeta
+
+
+class RegressionModel(PredictionModel, _PredictorParams):
+ """
+ Model produced by a ``Regressor``.
+
+ .. versionadded:: 3.0.0
+ """
+
+ __metaclass__ = ABCMeta
+
+
+class _JavaRegressor(Regressor, JavaPredictor):
"""
Java Regressor for regression tasks.
.. versionadded:: 3.0.0
"""
- pass
+
+ __metaclass__ = ABCMeta
-class JavaRegressionModel(JavaPredictionModel, _JavaPredictorParams):
+class _JavaRegressionModel(RegressionModel, JavaPredictionModel):
"""
Java Model produced by a ``_JavaRegressor``.
To be mixed in with class:`pyspark.ml.JavaModel`
.. versionadded:: 3.0.0
"""
- pass
+
+ __metaclass__ = ABCMeta
-class _LinearRegressionParams(_JavaPredictorParams, HasRegParam, HasElasticNetParam, HasMaxIter,
+class _LinearRegressionParams(_PredictorParams, HasRegParam, HasElasticNetParam, HasMaxIter,
HasTol, HasFitIntercept, HasStandardization, HasWeightCol, HasSolver,
HasAggregationDepth, HasLoss):
"""
@@ -88,7 +113,7 @@ class _LinearRegressionParams(_JavaPredictorParams, HasRegParam, HasElasticNetPa
@inherit_doc
-class LinearRegression(JavaRegressor, _LinearRegressionParams, JavaMLWritable, JavaMLReadable):
+class LinearRegression(_JavaRegressor, _LinearRegressionParams, JavaMLWritable, JavaMLReadable):
"""
Linear regression.
@@ -270,7 +295,7 @@ class LinearRegression(JavaRegressor, _LinearRegressionParams, JavaMLWritable, J
return self._set(lossType=value)
-class LinearRegressionModel(JavaRegressionModel, _LinearRegressionParams, GeneralJavaMLWritable,
+class LinearRegressionModel(_JavaRegressionModel, _LinearRegressionParams, GeneralJavaMLWritable,
JavaMLReadable, HasTrainingSummary):
"""
Model fitted by :class:`LinearRegression`.
@@ -777,7 +802,7 @@ class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, Ha
@inherit_doc
-class DecisionTreeRegressor(JavaRegressor, _DecisionTreeRegressorParams, JavaMLWritable,
+class DecisionTreeRegressor(_JavaRegressor, _DecisionTreeRegressorParams, JavaMLWritable,
JavaMLReadable):
"""
`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
@@ -973,7 +998,7 @@ class DecisionTreeRegressor(JavaRegressor, _DecisionTreeRegressorParams, JavaMLW
@inherit_doc
class DecisionTreeRegressionModel(
- JavaRegressionModel, _DecisionTreeModel, _DecisionTreeRegressorParams,
+ _JavaRegressionModel, _DecisionTreeModel, _DecisionTreeRegressorParams,
JavaMLWritable, JavaMLReadable
):
"""
@@ -1021,7 +1046,7 @@ class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams):
@inherit_doc
-class RandomForestRegressor(JavaRegressor, _RandomForestRegressorParams, JavaMLWritable,
+class RandomForestRegressor(_JavaRegressor, _RandomForestRegressorParams, JavaMLWritable,
JavaMLReadable):
"""
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
@@ -1230,7 +1255,7 @@ class RandomForestRegressor(JavaRegressor, _RandomForestRegressorParams, JavaMLW
class RandomForestRegressionModel(
- JavaRegressionModel, _TreeEnsembleModel, _RandomForestRegressorParams,
+ _JavaRegressionModel, _TreeEnsembleModel, _RandomForestRegressorParams,
JavaMLWritable, JavaMLReadable
):
"""
@@ -1284,7 +1309,7 @@ class _GBTRegressorParams(_GBTParams, _TreeRegressorParams):
@inherit_doc
-class GBTRegressor(JavaRegressor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
+class GBTRegressor(_JavaRegressor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
"""
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
learning algorithm for regression.
@@ -1526,7 +1551,7 @@ class GBTRegressor(JavaRegressor, _GBTRegressorParams, JavaMLWritable, JavaMLRea
class GBTRegressionModel(
- JavaRegressionModel, _TreeEnsembleModel, _GBTRegressorParams,
+ _JavaRegressionModel, _TreeEnsembleModel, _GBTRegressorParams,
JavaMLWritable, JavaMLReadable
):
"""
@@ -1571,7 +1596,7 @@ class GBTRegressionModel(
return self._call_java("evaluateEachIteration", dataset, loss)
-class _AFTSurvivalRegressionParams(_JavaPredictorParams, HasMaxIter, HasTol, HasFitIntercept,
+class _AFTSurvivalRegressionParams(_PredictorParams, HasMaxIter, HasTol, HasFitIntercept,
HasAggregationDepth):
"""
Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`.
@@ -1618,7 +1643,7 @@ class _AFTSurvivalRegressionParams(_JavaPredictorParams, HasMaxIter, HasTol, Has
@inherit_doc
-class AFTSurvivalRegression(JavaRegressor, _AFTSurvivalRegressionParams,
+class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams,
JavaMLWritable, JavaMLReadable):
"""
Accelerated Failure Time (AFT) Model Survival Regression
@@ -1759,7 +1784,7 @@ class AFTSurvivalRegression(JavaRegressor, _AFTSurvivalRegressionParams,
return self._set(aggregationDepth=value)
-class AFTSurvivalRegressionModel(JavaRegressionModel, _AFTSurvivalRegressionParams,
+class AFTSurvivalRegressionModel(_JavaRegressionModel, _AFTSurvivalRegressionParams,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`AFTSurvivalRegression`.
@@ -1813,7 +1838,7 @@ class AFTSurvivalRegressionModel(JavaRegressionModel, _AFTSurvivalRegressionPara
return self._call_java("predictQuantiles", features)
-class _GeneralizedLinearRegressionParams(_JavaPredictorParams, HasFitIntercept, HasMaxIter,
+class _GeneralizedLinearRegressionParams(_PredictorParams, HasFitIntercept, HasMaxIter,
HasTol, HasRegParam, HasWeightCol, HasSolver,
HasAggregationDepth):
"""
@@ -1891,7 +1916,7 @@ class _GeneralizedLinearRegressionParams(_JavaPredictorParams, HasFitIntercept,
@inherit_doc
-class GeneralizedLinearRegression(JavaRegressor, _GeneralizedLinearRegressionParams,
+class GeneralizedLinearRegression(_JavaRegressor, _GeneralizedLinearRegressionParams,
JavaMLWritable, JavaMLReadable):
"""
Generalized Linear Regression.
@@ -2096,7 +2121,7 @@ class GeneralizedLinearRegression(JavaRegressor, _GeneralizedLinearRegressionPar
return self._set(aggregationDepth=value)
-class GeneralizedLinearRegressionModel(JavaRegressionModel, _GeneralizedLinearRegressionParams,
+class GeneralizedLinearRegressionModel(_JavaRegressionModel, _GeneralizedLinearRegressionParams,
JavaMLWritable, JavaMLReadable, HasTrainingSummary):
"""
Model fitted by :class:`GeneralizedLinearRegression`.
@@ -2328,7 +2353,7 @@ class GeneralizedLinearRegressionTrainingSummary(GeneralizedLinearRegressionSumm
return self._call_java("toString")
-class _FactorizationMachinesParams(_JavaPredictorParams, HasMaxIter, HasStepSize, HasTol,
+class _FactorizationMachinesParams(_PredictorParams, HasMaxIter, HasStepSize, HasTol,
HasSolver, HasSeed, HasFitIntercept, HasRegParam):
"""
Params for :py:class:`FMRegressor`, :py:class:`FMRegressionModel`, :py:class:`FMClassifier`
@@ -2384,7 +2409,7 @@ class _FactorizationMachinesParams(_JavaPredictorParams, HasMaxIter, HasStepSize
@inherit_doc
-class FMRegressor(JavaRegressor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable):
+class FMRegressor(_JavaRegressor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable):
"""
Factorization Machines learning algorithm for regression.
@@ -2548,7 +2573,7 @@ class FMRegressor(JavaRegressor, _FactorizationMachinesParams, JavaMLWritable, J
return self._set(regParam=value)
-class FMRegressionModel(JavaRegressionModel, _FactorizationMachinesParams, JavaMLWritable,
+class FMRegressionModel(_JavaRegressionModel, _FactorizationMachinesParams, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by :class:`FMRegressor`.
diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py
index 777b493..61f9f18 100644
--- a/python/pyspark/ml/tests/test_param.py
+++ b/python/pyspark/ml/tests/test_param.py
@@ -348,8 +348,9 @@ class DefaultValuesTests(PySparkTestCase):
Test :py:class:`JavaParams` classes to see if their default Param values match
those in their Scala counterparts.
"""
-
def test_java_params(self):
+ import re
+
import pyspark.ml.feature
import pyspark.ml.classification
import pyspark.ml.clustering
@@ -365,8 +366,9 @@ class DefaultValuesTests(PySparkTestCase):
for name, cls in inspect.getmembers(module, inspect.isclass):
if not name.endswith('Model') and not name.endswith('Params') \
and issubclass(cls, JavaParams) and not inspect.isabstract(cls) \
- and not name.startswith('Java') and name != '_LSH':
+ and not re.match("_?Java", name) and name != '_LSH':
# NOTE: disable check_params_exist until there is parity with Scala API
+
check_params(self, cls(), check_params_exist=False)
# Additional classes that need explicit construction
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index ae3a6ba..e59c6c7b 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -23,7 +23,8 @@ if sys.version >= '3':
from pyspark import since
from pyspark import SparkContext
from pyspark.sql import DataFrame
-from pyspark.ml import Estimator, Transformer, Model
+from pyspark.ml import Estimator, Predictor, PredictionModel, Transformer, Model
+from pyspark.ml.base import _PredictorParams
from pyspark.ml.param import Params
from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol
from pyspark.ml.util import _jvm
@@ -377,63 +378,20 @@ class JavaModel(JavaTransformer, Model):
@inherit_doc
-class _JavaPredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
- """
- Params for :py:class:`JavaPredictor` and :py:class:`JavaPredictorModel`.
-
- .. versionadded:: 3.0.0
- """
- pass
-
-
-@inherit_doc
-class JavaPredictor(JavaEstimator, _JavaPredictorParams):
+class JavaPredictor(Predictor, JavaEstimator, _PredictorParams):
"""
(Private) Java Estimator for prediction tasks (regression and classification).
"""
- @since("3.0.0")
- def setLabelCol(self, value):
- """
- Sets the value of :py:attr:`labelCol`.
- """
- return self._set(labelCol=value)
-
- @since("3.0.0")
- def setFeaturesCol(self, value):
- """
- Sets the value of :py:attr:`featuresCol`.
- """
- return self._set(featuresCol=value)
-
- @since("3.0.0")
- def setPredictionCol(self, value):
- """
- Sets the value of :py:attr:`predictionCol`.
- """
- return self._set(predictionCol=value)
+ __metaclass__ = ABCMeta
@inherit_doc
-class JavaPredictionModel(JavaModel, _JavaPredictorParams):
+class JavaPredictionModel(PredictionModel, JavaModel, _PredictorParams):
"""
(Private) Java Model for prediction tasks (regression and classification).
"""
- @since("3.0.0")
- def setFeaturesCol(self, value):
- """
- Sets the value of :py:attr:`featuresCol`.
- """
- return self._set(featuresCol=value)
-
- @since("3.0.0")
- def setPredictionCol(self, value):
- """
- Sets the value of :py:attr:`predictionCol`.
- """
- return self._set(predictionCol=value)
-
@property
@since("2.1.0")
def numFeatures(self):
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org