You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/04/16 16:31:31 UTC
spark git commit: [SPARK-21088][ML] CrossValidator, TrainValidationSplit support collect all models when fitting: Python API
Repository: spark
Updated Branches:
refs/heads/master 5003736ad -> 04614820e
[SPARK-21088][ML] CrossValidator, TrainValidationSplit support collect all models when fitting: Python API
## What changes were proposed in this pull request?
Add python API for collecting sub-models during CrossValidator/TrainValidationSplit fitting.
## How was this patch tested?
Unit tests added (`test_expose_sub_models` in both CrossValidatorTests and TrainValidationSplitTests).
Author: WeichenXu <we...@databricks.com>
Closes #19627 from WeichenXu123/expose-model-list-py.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/04614820
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/04614820
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/04614820
Branch: refs/heads/master
Commit: 04614820e103feeae91299dc90dba1dd628fd485
Parents: 5003736
Author: WeichenXu <we...@databricks.com>
Authored: Mon Apr 16 11:31:24 2018 -0500
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Apr 16 11:31:24 2018 -0500
----------------------------------------------------------------------
.../apache/spark/ml/tuning/CrossValidator.scala | 11 ++
.../spark/ml/tuning/TrainValidationSplit.scala | 11 ++
.../pyspark/ml/param/_shared_params_code_gen.py | 5 +
python/pyspark/ml/param/shared.py | 24 +++++
python/pyspark/ml/tests.py | 78 ++++++++++++++
python/pyspark/ml/tuning.py | 107 ++++++++++++++-----
python/pyspark/ml/util.py | 4 +
7 files changed, 211 insertions(+), 29 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index a0b507d..c2826dc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -270,6 +270,17 @@ class CrossValidatorModel private[ml] (
this
}
+ // A Python-friendly auxiliary method
+ private[tuning] def setSubModels(subModels: JList[JList[Model[_]]])
+ : CrossValidatorModel = {
+ _subModels = if (subModels != null) {
+ Some(subModels.asScala.toArray.map(_.asScala.toArray))
+ } else {
+ None
+ }
+ this
+ }
+
/**
* @return submodels represented in two dimension array. The index of outer array is the
* fold index, and the index of inner array corresponds to the ordering of
http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 88ff0df..8d1b9a8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -262,6 +262,17 @@ class TrainValidationSplitModel private[ml] (
this
}
+ // A Python-friendly auxiliary method
+ private[tuning] def setSubModels(subModels: JList[Model[_]])
+ : TrainValidationSplitModel = {
+ _subModels = if (subModels != null) {
+ Some(subModels.asScala.toArray)
+ } else {
+ None
+ }
+ this
+ }
+
/**
* @return submodels represented in array. The index of array corresponds to the ordering of
* estimatorParamMaps
http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/param/_shared_params_code_gen.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index db951d8..6e9e0a3 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -157,6 +157,11 @@ if __name__ == "__main__":
"TypeConverters.toInt"),
("parallelism", "the number of threads to use when running parallel algorithms (>= 1).",
"1", "TypeConverters.toInt"),
+ ("collectSubModels", "Param for whether to collect a list of sub-models trained during " +
+ "tuning. If set to false, then only the single best sub-model will be available after " +
+ "fitting. If set to true, then all sub-models will be available. Warning: For large " +
+ "models, collecting all sub-models can cause OOMs on the Spark driver.",
+ "False", "TypeConverters.toBoolean"),
("loss", "the loss function to be optimized.", None, "TypeConverters.toString")]
code = []
http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/param/shared.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 474c387..08408ee 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -655,6 +655,30 @@ class HasParallelism(Params):
return self.getOrDefault(self.parallelism)
+class HasCollectSubModels(Params):
+ """
+ Mixin for param collectSubModels: Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.
+ """
+
+ collectSubModels = Param(Params._dummy(), "collectSubModels", "Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.", typeConverter=TypeConverters.toBoolean)
+
+ def __init__(self):
+ super(HasCollectSubModels, self).__init__()
+ self._setDefault(collectSubModels=False)
+
+ def setCollectSubModels(self, value):
+ """
+ Sets the value of :py:attr:`collectSubModels`.
+ """
+ return self._set(collectSubModels=value)
+
+ def getCollectSubModels(self):
+ """
+ Gets the value of collectSubModels or its default value.
+ """
+ return self.getOrDefault(self.collectSubModels)
+
+
class HasLoss(Params):
"""
Mixin for param loss: the loss function to be optimized.
http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 4ce5454..2ec0be6 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1018,6 +1018,50 @@ class CrossValidatorTests(SparkSessionTestCase):
cvParallelModel = cv.fit(dataset)
self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics)
+ def test_expose_sub_models(self):
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+
+ numFolds = 3
+ cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
+ numFolds=numFolds, collectSubModels=True)
+
+ def checkSubModels(subModels):
+ self.assertEqual(len(subModels), numFolds)
+ for i in range(numFolds):
+ self.assertEqual(len(subModels[i]), len(grid))
+
+ cvModel = cv.fit(dataset)
+ checkSubModels(cvModel.subModels)
+
+ # Test that the writer option "persistSubModels" defaults to "true"
+ testSubPath = temp_path + "/testCrossValidatorSubModels"
+ savingPathWithSubModels = testSubPath + "cvModel3"
+ cvModel.save(savingPathWithSubModels)
+ cvModel3 = CrossValidatorModel.load(savingPathWithSubModels)
+ checkSubModels(cvModel3.subModels)
+ cvModel4 = cvModel3.copy()
+ checkSubModels(cvModel4.subModels)
+
+ savingPathWithoutSubModels = testSubPath + "cvModel2"
+ cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels)
+ cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels)
+ self.assertEqual(cvModel2.subModels, None)
+
+ for i in range(numFolds):
+ for j in range(len(grid)):
+ self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid)
+
def test_save_load_nested_estimator(self):
temp_path = tempfile.mkdtemp()
dataset = self.spark.createDataFrame(
@@ -1186,6 +1230,40 @@ class TrainValidationSplitTests(SparkSessionTestCase):
tvsParallelModel = tvs.fit(dataset)
self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics)
+ def test_expose_sub_models(self):
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
+ collectSubModels=True)
+ tvsModel = tvs.fit(dataset)
+ self.assertEqual(len(tvsModel.subModels), len(grid))
+
+ # Test that the writer option "persistSubModels" defaults to "true"
+ testSubPath = temp_path + "/testTrainValidationSplitSubModels"
+ savingPathWithSubModels = testSubPath + "cvModel3"
+ tvsModel.save(savingPathWithSubModels)
+ tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels)
+ self.assertEqual(len(tvsModel3.subModels), len(grid))
+ tvsModel4 = tvsModel3.copy()
+ self.assertEqual(len(tvsModel4.subModels), len(grid))
+
+ savingPathWithoutSubModels = testSubPath + "cvModel2"
+ tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels)
+ tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels)
+ self.assertEqual(tvsModel2.subModels, None)
+
+ for i in range(len(grid)):
+ self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid)
+
def test_save_load_nested_estimator(self):
# This tests saving and loading the trained model only.
# Save/load for TrainValidationSplit will be added later: SPARK-13786
http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 545e24c..0c8029f 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -24,7 +24,7 @@ from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
from pyspark.ml.common import _py2java
from pyspark.ml.param import Params, Param, TypeConverters
-from pyspark.ml.param.shared import HasParallelism, HasSeed
+from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaParams
from pyspark.sql.functions import rand
@@ -33,7 +33,7 @@ __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainVa
'TrainValidationSplitModel']
-def _parallelFitTasks(est, train, eva, validation, epm):
+def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel):
"""
Creates a list of callables which can be called from different threads to fit and evaluate
an estimator in parallel. Each callable returns an `(index, metric)` pair.
@@ -43,14 +43,15 @@ def _parallelFitTasks(est, train, eva, validation, epm):
:param eva: Evaluator, used to compute `metric`
:param validation: DataFrame, validation data set, used for evaluation.
:param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation.
- :return: (int, float), an index into `epm` and the associated metric value.
+ :param collectSubModel: Whether to collect the fitted sub-model.
+ :return: (int, float, subModel), an index into `epm`, the associated metric value, and the fitted sub-model (None when `collectSubModel` is false).
"""
modelIter = est.fitMultiple(train, epm)
def singleTask():
index, model = next(modelIter)
metric = eva.evaluate(model.transform(validation, epm[index]))
- return index, metric
+ return index, metric, model if collectSubModel else None
return [singleTask] * len(epm)
@@ -194,7 +195,8 @@ class ValidatorParams(HasSeed):
return java_estimator, java_epms, java_evaluator
-class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable):
+class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels,
+ MLReadable, MLWritable):
"""
K-fold cross validation performs model selection by splitting the dataset into a set of
@@ -233,10 +235,10 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW
@keyword_only
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
- seed=None, parallelism=1):
+ seed=None, parallelism=1, collectSubModels=False):
"""
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
- seed=None, parallelism=1)
+ seed=None, parallelism=1, collectSubModels=False)
"""
super(CrossValidator, self).__init__()
self._setDefault(numFolds=3, parallelism=1)
@@ -246,10 +248,10 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW
@keyword_only
@since("1.4.0")
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
- seed=None, parallelism=1):
+ seed=None, parallelism=1, collectSubModels=False):
"""
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
- seed=None, parallelism=1):
+ seed=None, parallelism=1, collectSubModels=False):
Sets params for cross validator.
"""
kwargs = self._input_kwargs
@@ -282,6 +284,10 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW
metrics = [0.0] * numModels
pool = ThreadPool(processes=min(self.getParallelism(), numModels))
+ subModels = None
+ collectSubModelsParam = self.getCollectSubModels()
+ if collectSubModelsParam:
+ subModels = [[None for j in range(numModels)] for i in range(nFolds)]
for i in range(nFolds):
validateLB = i * h
@@ -290,9 +296,12 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW
validation = df.filter(condition).cache()
train = df.filter(~condition).cache()
- tasks = _parallelFitTasks(est, train, eva, validation, epm)
- for j, metric in pool.imap_unordered(lambda f: f(), tasks):
+ tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
+ for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
metrics[j] += (metric / nFolds)
+ if collectSubModelsParam:
+ subModels[i][j] = subModel
+
validation.unpersist()
train.unpersist()
@@ -301,7 +310,7 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
- return self._copyValues(CrossValidatorModel(bestModel, metrics))
+ return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels))
@since("1.4.0")
def copy(self, extra=None):
@@ -345,9 +354,11 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW
numFolds = java_stage.getNumFolds()
seed = java_stage.getSeed()
parallelism = java_stage.getParallelism()
+ collectSubModels = java_stage.getCollectSubModels()
# Create a new instance of this stage.
py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
- numFolds=numFolds, seed=seed, parallelism=parallelism)
+ numFolds=numFolds, seed=seed, parallelism=parallelism,
+ collectSubModels=collectSubModels)
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -367,6 +378,7 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW
_java_obj.setSeed(self.getSeed())
_java_obj.setNumFolds(self.getNumFolds())
_java_obj.setParallelism(self.getParallelism())
+ _java_obj.setCollectSubModels(self.getCollectSubModels())
return _java_obj
@@ -381,13 +393,15 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
.. versionadded:: 1.4.0
"""
- def __init__(self, bestModel, avgMetrics=[]):
+ def __init__(self, bestModel, avgMetrics=[], subModels=None):
super(CrossValidatorModel, self).__init__()
#: best model from cross validation
self.bestModel = bestModel
#: Average cross-validation metrics for each paramMap in
#: CrossValidator.estimatorParamMaps, in the corresponding order.
self.avgMetrics = avgMetrics
+ #: sub model list from cross validation
+ self.subModels = subModels
def _transform(self, dataset):
return self.bestModel.transform(dataset)
@@ -399,6 +413,7 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
and some extra params. This copies the underlying bestModel,
creates a deep copy of the embedded paramMap, and
copies the embedded and extra parameters over.
+ It does not copy the extra Params into the subModels.
:param extra: Extra parameters to copy to the new instance
:return: Copy of this instance
@@ -407,7 +422,8 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
extra = dict()
bestModel = self.bestModel.copy(extra)
avgMetrics = self.avgMetrics
- return CrossValidatorModel(bestModel, avgMetrics)
+ subModels = self.subModels
+ return CrossValidatorModel(bestModel, avgMetrics, subModels)
@since("2.3.0")
def write(self):
@@ -426,13 +442,17 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
Given a Java CrossValidatorModel, create and return a Python wrapper of it.
Used for ML persistence.
"""
-
bestModel = JavaParams._from_java(java_stage.bestModel())
estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
py_stage = cls(bestModel=bestModel).setEstimator(estimator)
py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
+ if java_stage.hasSubModels():
+ py_stage.subModels = [[JavaParams._from_java(sub_model)
+ for sub_model in fold_sub_models]
+ for fold_sub_models in java_stage.subModels()]
+
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -454,10 +474,16 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
_java_obj.set("evaluator", evaluator)
_java_obj.set("estimator", estimator)
_java_obj.set("estimatorParamMaps", epms)
+
+ if self.subModels is not None:
+ java_sub_models = [[sub_model._to_java() for sub_model in fold_sub_models]
+ for fold_sub_models in self.subModels]
+ _java_obj.setSubModels(java_sub_models)
return _java_obj
-class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable):
+class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels,
+ MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -492,10 +518,10 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl
@keyword_only
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,
- parallelism=1, seed=None):
+ parallelism=1, collectSubModels=False, seed=None):
"""
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
- parallelism=1, seed=None)
+ parallelism=1, collectSubModels=False, seed=None)
"""
super(TrainValidationSplit, self).__init__()
self._setDefault(trainRatio=0.75, parallelism=1)
@@ -505,10 +531,10 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl
@since("2.0.0")
@keyword_only
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,
- parallelism=1, seed=None):
+ parallelism=1, collectSubModels=False, seed=None):
"""
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
- parallelism=1, seed=None):
+ parallelism=1, collectSubModels=False, seed=None):
Sets params for the train validation split.
"""
kwargs = self._input_kwargs
@@ -541,11 +567,19 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl
validation = df.filter(condition).cache()
train = df.filter(~condition).cache()
- tasks = _parallelFitTasks(est, train, eva, validation, epm)
+ subModels = None
+ collectSubModelsParam = self.getCollectSubModels()
+ if collectSubModelsParam:
+ subModels = [None for i in range(numModels)]
+
+ tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
pool = ThreadPool(processes=min(self.getParallelism(), numModels))
metrics = [None] * numModels
- for j, metric in pool.imap_unordered(lambda f: f(), tasks):
+ for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
metrics[j] = metric
+ if collectSubModelsParam:
+ subModels[j] = subModel
+
train.unpersist()
validation.unpersist()
@@ -554,7 +588,7 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
- return self._copyValues(TrainValidationSplitModel(bestModel, metrics))
+ return self._copyValues(TrainValidationSplitModel(bestModel, metrics, subModels))
@since("2.0.0")
def copy(self, extra=None):
@@ -598,9 +632,11 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl
trainRatio = java_stage.getTrainRatio()
seed = java_stage.getSeed()
parallelism = java_stage.getParallelism()
+ collectSubModels = java_stage.getCollectSubModels()
# Create a new instance of this stage.
py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
- trainRatio=trainRatio, seed=seed, parallelism=parallelism)
+ trainRatio=trainRatio, seed=seed, parallelism=parallelism,
+ collectSubModels=collectSubModels)
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -620,7 +656,7 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl
_java_obj.setTrainRatio(self.getTrainRatio())
_java_obj.setSeed(self.getSeed())
_java_obj.setParallelism(self.getParallelism())
-
+ _java_obj.setCollectSubModels(self.getCollectSubModels())
return _java_obj
@@ -633,12 +669,14 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
.. versionadded:: 2.0.0
"""
- def __init__(self, bestModel, validationMetrics=[]):
+ def __init__(self, bestModel, validationMetrics=[], subModels=None):
super(TrainValidationSplitModel, self).__init__()
- #: best model from cross validation
+ #: best model from train validation split
self.bestModel = bestModel
#: evaluated validation metrics
self.validationMetrics = validationMetrics
+ #: sub models from train validation split
+ self.subModels = subModels
def _transform(self, dataset):
return self.bestModel.transform(dataset)
@@ -651,6 +689,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
creates a deep copy of the embedded paramMap, and
copies the embedded and extra parameters over.
And, this creates a shallow copy of the validationMetrics.
+ It does not copy the extra Params into the subModels.
:param extra: Extra parameters to copy to the new instance
:return: Copy of this instance
@@ -659,7 +698,8 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
extra = dict()
bestModel = self.bestModel.copy(extra)
validationMetrics = list(self.validationMetrics)
- return TrainValidationSplitModel(bestModel, validationMetrics)
+ subModels = self.subModels
+ return TrainValidationSplitModel(bestModel, validationMetrics, subModels)
@since("2.3.0")
def write(self):
@@ -687,6 +727,10 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
py_stage = cls(bestModel=bestModel).setEstimator(estimator)
py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
+ if java_stage.hasSubModels():
+ py_stage.subModels = [JavaParams._from_java(sub_model)
+ for sub_model in java_stage.subModels()]
+
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -708,6 +752,11 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
_java_obj.set("evaluator", evaluator)
_java_obj.set("estimator", estimator)
_java_obj.set("estimatorParamMaps", epms)
+
+ if self.subModels is not None:
+ java_sub_models = [sub_model._to_java() for sub_model in self.subModels]
+ _java_obj.setSubModels(java_sub_models)
+
return _java_obj
http://git-wip-us.apache.org/repos/asf/spark/blob/04614820/python/pyspark/ml/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index c3c47bd..a486c6a 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -169,6 +169,10 @@ class JavaMLWriter(MLWriter):
self._jwrite.overwrite()
return self
+ def option(self, key, value):
+ self._jwrite.option(key, value)
+ return self
+
def context(self, sqlContext):
"""
Sets the SQL context to use for saving.
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org