Posted to reviews@spark.apache.org by MrBago <gi...@git.apache.org> on 2017/12/08 02:35:12 UTC

[GitHub] spark pull request #19904: [SPARK-22707][ML] Optimize CrossValidator memory ...

Github user MrBago commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19904#discussion_r155693144
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---
    @@ -146,31 +146,34 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
           val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
           logDebug(s"Train split $splitIndex with multiple sets of parameters.")
     
    +      var completeFitCount = 0
    +      val signal = new Object
           // Fit models in a Future for training in parallel
    -      val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
    -        Future[Model[_]] {
    +      val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
    +        Future[Double] {
               val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
    +          signal.synchronized {
    +            completeFitCount += 1
    +            signal.notify()
    +          }
     
               if (collectSubModelsParam) {
                 subModels.get(splitIndex)(paramIndex) = model
               }
    -          model
    -        } (executionContext)
    -      }
    -
    -      // Unpersist training data only when all models have trained
    -      Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
    -        .onComplete { _ => trainingDataset.unpersist() } (executionContext)
    -
    -      // Evaluate models in a Future that will calculate a metric and allow model to be cleaned up
    -      val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
    -        modelFuture.map { model =>
               // TODO: duplicate evaluator to take extra params from input
               val metric = eval.evaluate(model.transform(validationDataset, paramMap))
               logDebug(s"Got metric $metric for model trained with $paramMap.")
               metric
             } (executionContext)
           }
    +      Future {
    +        signal.synchronized {
    +          while (completeFitCount < epm.length) {
    --- End diff --
    
    Sorry, I'm not too familiar with Futures in Scala. Is it safe to create a blocking future like this? Do you risk starving the thread pool? Can we just use an if statement in the `synchronized` block above? Something like:
    
    ```
    completeFitCount += 1
    if (completeFitCount == epm.length) {
      trainingDataset.unpersist()
    }
    ```
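    
    For reference, here is a minimal, self-contained sketch of what I mean. `fitOne` and `unpersistTrainingData` are hypothetical stand-ins for `est.fit(...)` and `trainingDataset.unpersist()` from the diff; the last fit to finish triggers the unpersist inside the `synchronized` block, so no dedicated waiting Future ties up a pool thread:
    
    ```
    import java.util.concurrent.Executors
    import scala.concurrent.{Await, ExecutionContext, Future}
    import scala.concurrent.duration.Duration
    
    object UnpersistOnLastFit {
      def main(args: Array[String]): Unit = {
        // Hypothetical stand-ins for the estimator and cached dataset in the diff.
        val paramMaps = Seq("a", "b", "c")
        def fitOne(param: String): String = s"model-$param"
        def unpersistTrainingData(): Unit = println("trainingDataset.unpersist()")
    
        val pool = Executors.newFixedThreadPool(2)
        implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
    
        val lock = new Object
        var completeFitCount = 0
    
        val metricFutures = paramMaps.map { param =>
          Future {
            val model = fitOne(param)
            lock.synchronized {
              completeFitCount += 1
              // The last fit to complete releases the cached training data;
              // no Future ever blocks waiting on the counter.
              if (completeFitCount == paramMaps.length) unpersistTrainingData()
            }
            model.length.toDouble // stand-in for eval.evaluate(...)
          }
        }
    
        println(Await.result(Future.sequence(metricFutures), Duration.Inf))
        pool.shutdown()
      }
    }
    ```
    
    If the blocking wait is kept instead, I believe wrapping it in `scala.concurrent.blocking { ... }` would at least let a pool like `ExecutionContext.global` spawn extra threads while the task is parked, though as far as I know a fixed-size pool created from `Executors` ignores that hint.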


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org