Posted to issues@spark.apache.org by "An De Rijdt (Jira)" <ji...@apache.org> on 2020/06/24 15:26:00 UTC
[jira] [Created] (SPARK-32092) CrossvalidatorModel does not save all submodels (it saves only 3)
An De Rijdt created SPARK-32092:
-----------------------------------
Summary: CrossvalidatorModel does not save all submodels (it saves only 3)
Key: SPARK-32092
URL: https://issues.apache.org/jira/browse/SPARK-32092
Project: Spark
Issue Type: Bug
Components: ML, PySpark
Affects Versions: 2.4.5, 2.4.0
Environment: Ran on two systems:
* Local pyspark installation (Windows): spark 2.4.5
* Spark 2.4.0 on a cluster
Reporter: An De Rijdt
When saving a CrossValidatorModel with more than 3 folds of subModels and loading it back, a different number of subModels is returned: only 3 (the default numFolds) ever seem to come back.
With fewer folds than the default of 3 (e.g. numFolds=2), writing fails outright.
The issue seems to be the following (though I am not very familiar with the Scala/Java side):
* the Python object is converted to its Scala/Java counterpart for writing
* on the Scala side, subModels are saved for fold indices 0 until numFolds:
{code:java}
val subModelsPath = new Path(path, "subModels")
for (splitIndex <- 0 until instance.getNumFolds) {
  val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}")
  for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) {
    val modelPath = new Path(splitPath, paramIndex.toString).toString
    instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath)
  }
}
{code}
* numFolds is not available on the CrossValidatorModel in PySpark
* the default numFolds is 3, so the writer ends up saving exactly 3 folds of subModels (see the sketch below)
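For illustration, here is a minimal Python sketch of that loop-bound mismatch. This is a hypothetical stand-in, not the actual Spark writer: save_sub_models and save_one are made-up names, and num_folds_param stands for what the JVM-side model reports (the default of 3) rather than the number of folds actually collected during fit().
{code:python}
def save_sub_models(sub_models, num_folds_param, save_one):
    # sub_models: list (one entry per collected fold) of lists (one entry per
    # param map) of fitted models; num_folds_param: the model's numFolds param.
    for split_index in range(num_folds_param):
        for param_index, model in enumerate(sub_models[split_index]):
            save_one(split_index, param_index, model)


# With 4 folds collected but num_folds_param == 3, the last fold is silently
# dropped; with only 2 folds collected, sub_models[2] raises IndexError, which
# matches the write failure reproduced further below.
collected = [["m%d%d" % (f, p) for p in range(2)] for f in range(4)]
save_sub_models(collected, 3, lambda f, p, m: print(f, p, m))  # fold 3 never saved
{code}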
The first issue can be reproduced by the following failing test, where spark is a SparkSession and tmp_path is a (temporary) directory.
{code:java}
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.linalg import Vectors


def test_save_load_cross_validator(spark, tmp_path):
    temp_path = str(tmp_path)
    dataset = spark.createDataFrame(
        [
            (Vectors.dense([0.0]), 0.0),
            (Vectors.dense([0.4]), 1.0),
            (Vectors.dense([0.5]), 0.0),
            (Vectors.dense([0.6]), 1.0),
            (Vectors.dense([1.0]), 1.0),
        ]
        * 10,
        ["features", "label"],
    )
    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    evaluator = BinaryClassificationEvaluator()
    cv = CrossValidator(
        estimator=lr,
        estimatorParamMaps=grid,
        evaluator=evaluator,
        collectSubModels=True,
        numFolds=4,
    )
    cvModel = cv.fit(dataset)
    # test save/load of CrossValidatorModel
    cvModelPath = temp_path + "/cvModel"
    cvModel.write().overwrite().save(cvModelPath)
    loadedModel = CrossValidatorModel.load(cvModelPath)
    assert len(loadedModel.subModels) == len(cvModel.subModels)
{code}
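Continuing from the test above, one way to see what actually got written (assuming the directory layout produced by the Scala writer quoted earlier, i.e. subModels/fold<n>/<paramIndex>, and a local filesystem path as in this test) is to list the fold directories; with numFolds=4 only fold0 through fold2 would be expected on disk:
{code:python}
import os

# Hypothetical check, run right after cvModel.write().overwrite().save(cvModelPath):
# list the per-fold directories the JVM writer created under subModels/.
sub_models_dir = os.path.join(cvModelPath, "subModels")
print(sorted(os.listdir(sub_models_dir)))  # expectation: ['fold0', 'fold1', 'fold2']
{code}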
The second issue can be reproduced as follows (the write itself fails):
{code:java}
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.linalg import Vectors


def test_save_load_cross_validator(spark, tmp_path):
    temp_path = str(tmp_path)
    dataset = spark.createDataFrame(
        [
            (Vectors.dense([0.0]), 0.0),
            (Vectors.dense([0.4]), 1.0),
            (Vectors.dense([0.5]), 0.0),
            (Vectors.dense([0.6]), 1.0),
            (Vectors.dense([1.0]), 1.0),
        ]
        * 10,
        ["features", "label"],
    )
    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    evaluator = BinaryClassificationEvaluator()
    cv = CrossValidator(
        estimator=lr,
        estimatorParamMaps=grid,
        evaluator=evaluator,
        collectSubModels=True,
        numFolds=2,
    )
    cvModel = cv.fit(dataset)
    # test save/load of CrossValidatorModel
    cvModelPath = temp_path + "/cvModel"
    cvModel.write().overwrite().save(cvModelPath)
    loadedModel = CrossValidatorModel.load(cvModelPath)
    assert len(loadedModel.subModels) == len(cvModel.subModels)
{code}
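Until this is fixed, one possible user-side workaround is to persist the collected subModels yourself from Python instead of relying on the JVM writer. The sketch below assumes the subModels are MLWritable (LogisticRegressionModel in this example); save_sub_models_manually and load_sub_models_manually are made-up helper names.
{code:python}
import os

from pyspark.ml.classification import LogisticRegressionModel


def save_sub_models_manually(cv_model, base_path):
    # cv_model.subModels is a list (folds) of lists (param maps) of fitted models.
    for fold_idx, fold_models in enumerate(cv_model.subModels):
        for param_idx, model in enumerate(fold_models):
            model.write().overwrite().save(
                os.path.join(base_path, "fold%d" % fold_idx, str(param_idx)))


def load_sub_models_manually(base_path, num_folds, num_param_maps):
    # Reload the manually saved sub-models in the same fold/paramMap layout.
    return [
        [LogisticRegressionModel.load(
            os.path.join(base_path, "fold%d" % fold_idx, str(param_idx)))
         for param_idx in range(num_param_maps)]
        for fold_idx in range(num_folds)
    ]
{code}
The loaded CrossValidatorModel's subModels attribute could then be repointed at the result of load_sub_models_manually, sidestepping the numFolds mismatch in the JVM writer.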