Posted to reviews@spark.apache.org by MrBago <gi...@git.apache.org> on 2017/12/08 02:35:12 UTC
[GitHub] spark pull request #19904: [SPARK-22707][ML] Optimize CrossValidator memory ...
Github user MrBago commented on a diff in the pull request:
https://github.com/apache/spark/pull/19904#discussion_r155693144
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---
@@ -146,31 +146,34 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
+ var completeFitCount = 0
+ val signal = new Object
// Fit models in a Future for training in parallel
- val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
- Future[Model[_]] {
+ val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
+ Future[Double] {
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
+ signal.synchronized {
+ completeFitCount += 1
+ signal.notify()
+ }
if (collectSubModelsParam) {
subModels.get(splitIndex)(paramIndex) = model
}
- model
- } (executionContext)
- }
-
- // Unpersist training data only when all models have trained
- Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
- .onComplete { _ => trainingDataset.unpersist() } (executionContext)
-
- // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up
- val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
- modelFuture.map { model =>
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
logDebug(s"Got metric $metric for model trained with $paramMap.")
metric
} (executionContext)
}
+ Future {
+ signal.synchronized {
+ while (completeFitCount < epm.length) {
--- End diff --
Sorry, I'm not too familiar with Futures in Scala. Is it safe to create a blocking future like this? Do you risk starving the thread pool? Can we just use an if statement in the `synchronized` block above? Something like:
```
completeFitCount += 1
if (completeFitCount == epm.length) {
trainingDataset.unpersist()
}
```
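[Editor's note] As a rough, self-contained sketch of the counter-based idea above (the object name, the toy `i * i` "metric", and the `cleanedUp` flag are all illustrative stand-ins, not the actual CrossValidator code):

```scala
import java.util.concurrent.Executors
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future}

object CounterUnpersist {
  // Runs `numTasks` dummy "fit" jobs on a small pool. The last job to
  // finish performs the cleanup inline, under the lock, so no extra
  // blocking Future is needed and the pool cannot be starved.
  def run(numTasks: Int): (Int, Boolean) = {
    val pool = Executors.newFixedThreadPool(4)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)

    val lock = new Object
    var completed = 0
    var cleanedUp = false // stands in for trainingDataset.unpersist()

    val futures = (1 to numTasks).map { i =>
      Future {
        val metric = i * i // stands in for est.fit(...) + eval.evaluate(...)
        lock.synchronized {
          completed += 1
          if (completed == numTasks) cleanedUp = true
        }
        metric
      }
    }

    val sum = Await.result(Future.sequence(futures), 10.seconds).sum
    pool.shutdown()
    (sum, lock.synchronized(cleanedUp))
  }
}
```

Because the increment and the completion check happen inside the same `synchronized` block, exactly one task observes `completed == numTasks` and runs the cleanup, with no dedicated waiter thread tied up.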
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org