You are viewing a plain-text version of this content. (The "here" link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@spark.apache.org by jk...@apache.org on 2018/01/05 06:45:19 UTC
spark git commit: [SPARK-22949][ML] Apply CrossValidator approach to Driver/Distributed memory tradeoff for TrainValidationSplit
Repository: spark
Updated Branches:
refs/heads/master 52fc5c17d -> cf0aa6557
[SPARK-22949][ML] Apply CrossValidator approach to Driver/Distributed memory tradeoff for TrainValidationSplit
## What changes were proposed in this pull request?
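Avoid holding all models in memory at once in `TrainValidationSplit`. Each model is now fit and evaluated inside the same `Future`, so it can be released as soon as its metric has been computed. The training dataset is unpersisted only after all metrics have been produced. This is the same approach `CrossValidator` already uses.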
## How was this patch tested?
Existing tests.
Author: Bago Amirbekian <ba...@databricks.com>
Closes #20143 from MrBago/trainValidMemoryFix.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cf0aa655
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cf0aa655
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cf0aa655
Branch: refs/heads/master
Commit: cf0aa65576acbe0209c67f04c029058fd73555c1
Parents: 52fc5c1
Author: Bago Amirbekian <ba...@databricks.com>
Authored: Thu Jan 4 22:45:15 2018 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Jan 4 22:45:15 2018 -0800
----------------------------------------------------------------------
.../apache/spark/ml/tuning/CrossValidator.scala | 4 +++-
.../spark/ml/tuning/TrainValidationSplit.scala | 18 ++++--------------
2 files changed, 7 insertions(+), 15 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/cf0aa655/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 095b54c..a0b507d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -160,8 +160,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
} (executionContext)
}
- // Wait for metrics to be calculated before unpersisting validation dataset
+ // Wait for metrics to be calculated
val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
+
+ // Unpersist training & validation set once all metrics have been produced
trainingDataset.unpersist()
validationDataset.unpersist()
foldMetrics
http://git-wip-us.apache.org/repos/asf/spark/blob/cf0aa655/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index c73bd18..8826ef3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -143,24 +143,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
// Fit models in a Future for training in parallel
logDebug(s"Train split with multiple sets of parameters.")
- val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
- Future[Model[_]] {
+ val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
+ Future[Double] {
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
if (collectSubModelsParam) {
subModels.get(paramIndex) = model
}
- model
- } (executionContext)
- }
-
- // Unpersist training data only when all models have trained
- Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
- .onComplete { _ => trainingDataset.unpersist() } (executionContext)
-
- // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up
- val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
- modelFuture.map { model =>
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
logDebug(s"Got metric $metric for model trained with $paramMap.")
@@ -171,7 +160,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
// Wait for all metrics to be calculated
val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
- // Unpersist validation set once all metrics have been produced
+ // Unpersist training & validation set once all metrics have been produced
+ trainingDataset.unpersist()
validationDataset.unpersist()
logInfo(s"Train validation split metrics: ${metrics.toSeq}")
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org