Posted to commits@spark.apache.org by jk...@apache.org on 2018/01/05 06:45:19 UTC

spark git commit: [SPARK-22949][ML] Apply CrossValidator approach to Driver/Distributed memory tradeoff for TrainValidationSplit

Repository: spark
Updated Branches:
  refs/heads/master 52fc5c17d -> cf0aa6557


[SPARK-22949][ML] Apply CrossValidator approach to Driver/Distributed memory tradeoff for TrainValidationSplit

## What changes were proposed in this pull request?

Avoid holding all fitted models in driver memory at once in `TrainValidationSplit`: fit and evaluate each model inside a single `Future`, keeping only the resulting metric (and, when requested, the sub-models), mirroring the approach already used by `CrossValidator`.
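
For context, the change collapses the previous two-stage pipeline (a `Future[Model[_]]` per param map, followed by a dependent metric future) into a single `Future[Double]` that fits and evaluates in one closure, so each model can be dropped as soon as its metric is computed. Below is a minimal, self-contained sketch of that pattern outside Spark; the `fit`/`evaluate` helpers and the fixed thread pool are illustrative stand-ins, not Spark APIs.

```scala
import java.util.concurrent.Executors

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object FitAndScoreSketch {

  // Hypothetical stand-ins for an Estimator's fit and an Evaluator's evaluate.
  def fit(param: Double): Array[Double] = Array.fill(1000000)(param)     // the "model"
  def evaluate(model: Array[Double]): Double = model.sum / model.length  // the "metric"

  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(2)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)

    val params = Seq(0.1, 0.5, 1.0)

    // One Future per param setting: fit and evaluate inside the same closure, so each
    // model goes out of scope (and can be garbage-collected) once its metric exists.
    val metricFutures = params.map { p =>
      Future {
        val model = fit(p)
        evaluate(model)
      }
    }

    // Only the Double metrics are retained on the driver, never all models at once.
    val metrics = metricFutures.map(Await.result(_, Duration.Inf))
    println(metrics.mkString(", "))

    pool.shutdown()
  }
}
```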

## How was this patch tested?

Existing tests.

Author: Bago Amirbekian <ba...@databricks.com>

Closes #20143 from MrBago/trainValidMemoryFix.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cf0aa655
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cf0aa655
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cf0aa655

Branch: refs/heads/master
Commit: cf0aa65576acbe0209c67f04c029058fd73555c1
Parents: 52fc5c1
Author: Bago Amirbekian <ba...@databricks.com>
Authored: Thu Jan 4 22:45:15 2018 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Jan 4 22:45:15 2018 -0800

----------------------------------------------------------------------
 .../apache/spark/ml/tuning/CrossValidator.scala   |  4 +++-
 .../spark/ml/tuning/TrainValidationSplit.scala    | 18 ++++--------------
 2 files changed, 7 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cf0aa655/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 095b54c..a0b507d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -160,8 +160,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
         } (executionContext)
       }
 
-      // Wait for metrics to be calculated before unpersisting validation dataset
+      // Wait for metrics to be calculated
       val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
+
+      // Unpersist training & validation set once all metrics have been produced
       trainingDataset.unpersist()
       validationDataset.unpersist()
       foldMetrics

http://git-wip-us.apache.org/repos/asf/spark/blob/cf0aa655/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index c73bd18..8826ef3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -143,24 +143,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
 
     // Fit models in a Future for training in parallel
     logDebug(s"Train split with multiple sets of parameters.")
-    val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
-      Future[Model[_]] {
+    val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
+      Future[Double] {
         val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
 
         if (collectSubModelsParam) {
           subModels.get(paramIndex) = model
         }
-        model
-      } (executionContext)
-    }
-
-    // Unpersist training data only when all models have trained
-    Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
-      .onComplete { _ => trainingDataset.unpersist() } (executionContext)
-
-    // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up
-    val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
-      modelFuture.map { model =>
         // TODO: duplicate evaluator to take extra params from input
         val metric = eval.evaluate(model.transform(validationDataset, paramMap))
         logDebug(s"Got metric $metric for model trained with $paramMap.")
@@ -171,7 +160,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     // Wait for all metrics to be calculated
     val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
 
-    // Unpersist validation set once all metrics have been produced
+    // Unpersist training & validation set once all metrics have been produced
+    trainingDataset.unpersist()
     validationDataset.unpersist()
 
     logInfo(s"Train validation split metrics: ${metrics.toSeq}")
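
Condensed, the TrainValidationSplit fit path after this change reads roughly as follows (reconstructed from the hunks above; the return of the metric and the closing of the Future sit in unchanged context lines not shown in the diff):

    val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
      Future[Double] {
        val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]

        if (collectSubModelsParam) {
          subModels.get(paramIndex) = model
        }
        // TODO: duplicate evaluator to take extra params from input
        val metric = eval.evaluate(model.transform(validationDataset, paramMap))
        logDebug(s"Got metric $metric for model trained with $paramMap.")
        metric
      } (executionContext)
    }

    // Wait for all metrics to be calculated
    val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))

    // Unpersist training & validation set once all metrics have been produced
    trainingDataset.unpersist()
    validationDataset.unpersist()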

