You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ml...@apache.org on 2017/09/06 12:12:38 UTC

spark git commit: [SPARK-19357][ML] Adding parallel model evaluation in ML tuning

Repository: spark
Updated Branches:
  refs/heads/master 4ee7dfe41 -> 16c4c03c7


[SPARK-19357][ML] Adding parallel model evaluation in ML tuning

## What changes were proposed in this pull request?
Modified `CrossValidator` and `TrainValidationSplit` to be able to evaluate models in parallel for a given parameter grid.  The level of parallelism is controlled by a parameter `parallelism` used to schedule a number of models to be trained/evaluated so that the jobs can be run concurrently.  This is a naive approach that does not check the cluster for needed resources, so care must be taken by the user to tune the parameter appropriately.  The default value is `1` which will train/evaluate in serial.

## How was this patch tested?
Added unit tests for CrossValidator and TrainValidationSplit to verify that model selection is the same when run in serial vs parallel.  Manual testing to verify tasks run in parallel when param is > 1. Added parameter usage to relevant examples.

Author: Bryan Cutler <cu...@gmail.com>

Closes #16774 from BryanCutler/parallel-model-eval-SPARK-19357.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/16c4c03c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/16c4c03c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/16c4c03c

Branch: refs/heads/master
Commit: 16c4c03c71394ab30c8edaf4418973e1a2c5ebfe
Parents: 4ee7dfe
Author: Bryan Cutler <cu...@gmail.com>
Authored: Wed Sep 6 14:12:27 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Wed Sep 6 14:12:27 2017 +0200

----------------------------------------------------------------------
 docs/ml-tuning.md                               |  2 +
 ...ModelSelectionViaCrossValidationExample.java |  4 +-
 ...SelectionViaTrainValidationSplitExample.java |  3 +-
 ...odelSelectionViaCrossValidationExample.scala |  1 +
 ...electionViaTrainValidationSplitExample.scala |  2 +
 .../spark/ml/param/shared/HasParallelism.scala  | 59 ++++++++++++++++
 .../apache/spark/ml/tuning/CrossValidator.scala | 71 ++++++++++++++------
 .../spark/ml/tuning/TrainValidationSplit.scala  | 57 ++++++++++++----
 .../spark/ml/tuning/CrossValidatorSuite.scala   | 27 ++++++++
 .../ml/tuning/TrainValidationSplitSuite.scala   | 35 +++++++++-
 10 files changed, 221 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/16c4c03c/docs/ml-tuning.md
----------------------------------------------------------------------
diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md
index e9123db..64dc46c 100644
--- a/docs/ml-tuning.md
+++ b/docs/ml-tuning.md
@@ -55,6 +55,8 @@ for multiclass problems. The default metric used to choose the best `ParamMap` c
 method in each of these evaluators.
 
 To help construct the parameter grid, users can use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility.
+By default, sets of parameters from the parameter grid are evaluated in serial. Parameter evaluation can be done in parallel by setting `parallelism` with a value of 2 or more (a value of 1 will be serial) before running model selection with `CrossValidator` or `TrainValidationSplit` (NOTE: this is not yet supported in Python).
+The value of `parallelism` should be chosen carefully to maximize parallelism without exceeding cluster resources, and larger values may not always lead to improved performance.  Generally speaking, a value up to 10 should be sufficient for most clusters.
 
 # Cross-Validation
 

http://git-wip-us.apache.org/repos/asf/spark/blob/16c4c03c/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
index 975c65e..d973279 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
@@ -94,7 +94,9 @@ public class JavaModelSelectionViaCrossValidationExample {
     CrossValidator cv = new CrossValidator()
       .setEstimator(pipeline)
       .setEvaluator(new BinaryClassificationEvaluator())
-      .setEstimatorParamMaps(paramGrid).setNumFolds(2);  // Use 3+ in practice
+      .setEstimatorParamMaps(paramGrid)
+      .setNumFolds(2)  // Use 3+ in practice
+      .setParallelism(2);  // Evaluate up to 2 parameter settings in parallel
 
     // Run cross-validation, and choose the best set of parameters.
     CrossValidatorModel cvModel = cv.fit(training);

http://git-wip-us.apache.org/repos/asf/spark/blob/16c4c03c/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
index 9a4722b..2ef8bea 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
@@ -70,7 +70,8 @@ public class JavaModelSelectionViaTrainValidationSplitExample {
       .setEstimator(lr)
       .setEvaluator(new RegressionEvaluator())
       .setEstimatorParamMaps(paramGrid)
-      .setTrainRatio(0.8);  // 80% for training and the remaining 20% for validation
+      .setTrainRatio(0.8)  // 80% for training and the remaining 20% for validation
+      .setParallelism(2);  // Evaluate up to 2 parameter settings in parallel
 
     // Run train validation split, and choose the best set of parameters.
     TrainValidationSplitModel model = trainValidationSplit.fit(training);

http://git-wip-us.apache.org/repos/asf/spark/blob/16c4c03c/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala
index c1ff9ef..87d96dd 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala
@@ -93,6 +93,7 @@ object ModelSelectionViaCrossValidationExample {
       .setEvaluator(new BinaryClassificationEvaluator)
       .setEstimatorParamMaps(paramGrid)
       .setNumFolds(2)  // Use 3+ in practice
+      .setParallelism(2)  // Evaluate up to 2 parameter settings in parallel
 
     // Run cross-validation, and choose the best set of parameters.
     val cvModel = cv.fit(training)

http://git-wip-us.apache.org/repos/asf/spark/blob/16c4c03c/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala
index 1cd2641..71e41e7 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala
@@ -65,6 +65,8 @@ object ModelSelectionViaTrainValidationSplitExample {
       .setEstimatorParamMaps(paramGrid)
       // 80% of the data will be used for training and the remaining 20% for validation.
       .setTrainRatio(0.8)
+      // Evaluate up to 2 parameter settings in parallel
+      .setParallelism(2)
 
     // Run train validation split, and choose the best set of parameters.
     val model = trainValidationSplit.fit(training)

http://git-wip-us.apache.org/repos/asf/spark/blob/16c4c03c/mllib/src/main/scala/org/apache/spark/ml/param/shared/HasParallelism.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/HasParallelism.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/HasParallelism.scala
new file mode 100644
index 0000000..021d0b3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/HasParallelism.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param.shared
+
+import scala.concurrent.ExecutionContext
+
+import org.apache.spark.ml.param.{IntParam, Params, ParamValidators}
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Trait to define a level of parallelism for algorithms that are able to use
+ * multithreaded execution, and provide a thread-pool based execution context.
+ */
+private[ml] trait HasParallelism extends Params {
+
+  /**
+   * The number of threads to use when running parallel algorithms.
+   * Default is 1 for serial execution
+   *
+   * @group expertParam
+   */
+  val parallelism = new IntParam(this, "parallelism",
+    "the number of threads to use when running parallel algorithms", ParamValidators.gtEq(1))
+
+  setDefault(parallelism -> 1)
+
+  /** @group expertGetParam */
+  def getParallelism: Int = $(parallelism)
+
+  /**
+   * Create a new execution context with a thread-pool that has a maximum number of threads
+   * set to the value of [[parallelism]]. If this param is set to 1, a same-thread executor
+   * will be used to run in serial.
+   */
+  private[ml] def getExecutionContext: ExecutionContext = {
+    getParallelism match {
+      case 1 =>
+        ThreadUtils.sameThread
+      case n =>
+        ExecutionContext.fromExecutorService(ThreadUtils
+          .newDaemonCachedThreadPool(s"${this.getClass.getSimpleName}-thread-pool", n))
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/16c4c03c/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 2012d6c..ce2a3a2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -20,20 +20,23 @@ package org.apache.spark.ml.tuning
 import java.util.{List => JList}
 
 import scala.collection.JavaConverters._
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
 
-import com.github.fommil.netlib.F2jBLAS
 import org.apache.hadoop.fs.Path
 import org.json4s.DefaultFormats
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml._
+import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.HasParallelism
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ThreadUtils
 
 /**
  * Params for [[CrossValidator]] and [[CrossValidatorModel]].
@@ -64,13 +67,11 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
 @Since("1.2.0")
 class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   extends Estimator[CrossValidatorModel]
-  with CrossValidatorParams with MLWritable with Logging {
+  with CrossValidatorParams with HasParallelism with MLWritable with Logging {
 
   @Since("1.2.0")
   def this() = this(Identifiable.randomUID("cv"))
 
-  private val f2jBLAS = new F2jBLAS
-
   /** @group setParam */
   @Since("1.2.0")
   def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
@@ -91,6 +92,15 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   @Since("2.0.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /**
+   * Set the maximum level of parallelism to evaluate models in parallel.
+   * Default is 1 for serial evaluation
+   *
+   * @group expertSetParam
+   */
+  @Since("2.3.0")
+  def setParallelism(value: Int): this.type = set(parallelism, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): CrossValidatorModel = {
     val schema = dataset.schema
@@ -99,32 +109,49 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     val est = $(estimator)
     val eval = $(evaluator)
     val epm = $(estimatorParamMaps)
-    val numModels = epm.length
-    val metrics = new Array[Double](epm.length)
+
+    // Create execution context based on $(parallelism)
+    val executionContext = getExecutionContext
 
     val instr = Instrumentation.create(this, dataset)
-    instr.logParams(numFolds, seed)
+    instr.logParams(numFolds, seed, parallelism)
     logTuningParams(instr)
 
+    // Compute metrics for each model over each split
     val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
-    splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
+    val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
       val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
       val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
-      // multi-model training
       logDebug(s"Train split $splitIndex with multiple sets of parameters.")
-      val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
-      trainingDataset.unpersist()
-      var i = 0
-      while (i < numModels) {
-        // TODO: duplicate evaluator to take extra params from input
-        val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
-        logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
-        metrics(i) += metric
-        i += 1
+
+      // Fit models in a Future for training in parallel
+      val modelFutures = epm.map { paramMap =>
+        Future[Model[_]] {
+          val model = est.fit(trainingDataset, paramMap)
+          model.asInstanceOf[Model[_]]
+        } (executionContext)
+      }
+
+      // Unpersist training data only when all models have trained
+      Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
+        .onComplete { _ => trainingDataset.unpersist() } (executionContext)
+
+      // Evaluate models in a Future that will calculate a metric and allow model to be cleaned up
+      val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
+        modelFuture.map { model =>
+          // TODO: duplicate evaluator to take extra params from input
+          val metric = eval.evaluate(model.transform(validationDataset, paramMap))
+          logDebug(s"Got metric $metric for model trained with $paramMap.")
+          metric
+        } (executionContext)
       }
+
+      // Wait for metrics to be calculated before unpersisting validation dataset
+      val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
       validationDataset.unpersist()
-    }
-    f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
+      foldMetrics
+    }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits
+
     logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
     val (bestMetric, bestIndex) =
       if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)

http://git-wip-us.apache.org/repos/asf/spark/blob/16c4c03c/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index db7c9d1..16db0f5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -20,6 +20,8 @@ package org.apache.spark.ml.tuning
 import java.util.{List => JList}
 
 import scala.collection.JavaConverters._
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
 import scala.language.existentials
 
 import org.apache.hadoop.fs.Path
@@ -30,9 +32,11 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.HasParallelism
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ThreadUtils
 
 /**
  * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
@@ -62,7 +66,7 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams {
 @Since("1.5.0")
 class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   extends Estimator[TrainValidationSplitModel]
-  with TrainValidationSplitParams with MLWritable with Logging {
+  with TrainValidationSplitParams with HasParallelism with MLWritable with Logging {
 
   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("tvs"))
@@ -87,6 +91,15 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   @Since("2.0.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /**
+   * Set the maximum level of parallelism to evaluate models in parallel.
+   * Default is 1 for serial evaluation
+   *
+   * @group expertSetParam
+   */
+  @Since("2.3.0")
+  def setParallelism(value: Int): this.type = set(parallelism, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
     val schema = dataset.schema
@@ -94,11 +107,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     val est = $(estimator)
     val eval = $(evaluator)
     val epm = $(estimatorParamMaps)
-    val numModels = epm.length
-    val metrics = new Array[Double](epm.length)
+
+    // Create execution context based on $(parallelism)
+    val executionContext = getExecutionContext
 
     val instr = Instrumentation.create(this, dataset)
-    instr.logParams(trainRatio, seed)
+    instr.logParams(trainRatio, seed, parallelism)
     logTuningParams(instr)
 
     val Array(trainingDataset, validationDataset) =
@@ -106,18 +120,33 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     trainingDataset.cache()
     validationDataset.cache()
 
-    // multi-model training
+    // Fit models in a Future for training in parallel
     logDebug(s"Train split with multiple sets of parameters.")
-    val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
-    trainingDataset.unpersist()
-    var i = 0
-    while (i < numModels) {
-      // TODO: duplicate evaluator to take extra params from input
-      val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
-      logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
-      metrics(i) += metric
-      i += 1
+    val modelFutures = epm.map { paramMap =>
+      Future[Model[_]] {
+        val model = est.fit(trainingDataset, paramMap)
+        model.asInstanceOf[Model[_]]
+      } (executionContext)
     }
+
+    // Unpersist training data only when all models have trained
+    Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
+      .onComplete { _ => trainingDataset.unpersist() } (executionContext)
+
+    // Evaluate models in a Future that will calculate a metric and allow model to be cleaned up
+    val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
+      modelFuture.map { model =>
+        // TODO: duplicate evaluator to take extra params from input
+        val metric = eval.evaluate(model.transform(validationDataset, paramMap))
+        logDebug(s"Got metric $metric for model trained with $paramMap.")
+        metric
+      } (executionContext)
+    }
+
+    // Wait for all metrics to be calculated
+    val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
+
+    // Unpersist validation set once all metrics have been produced
     validationDataset.unpersist()
 
     logInfo(s"Train validation split metrics: ${metrics.toSeq}")

http://git-wip-us.apache.org/repos/asf/spark/blob/16c4c03c/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 90778d7..a8d4377 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -120,6 +120,33 @@ class CrossValidatorSuite
     }
   }
 
+  test("cross validation with parallel evaluation") {
+    val lr = new LogisticRegression
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.001, 1000.0))
+      .addGrid(lr.maxIter, Array(0, 3))
+      .build()
+    val eval = new BinaryClassificationEvaluator
+    val cv = new CrossValidator()
+      .setEstimator(lr)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setNumFolds(2)
+      .setParallelism(1)
+    val cvSerialModel = cv.fit(dataset)
+    cv.setParallelism(2)
+    val cvParallelModel = cv.fit(dataset)
+
+    val serialMetrics = cvSerialModel.avgMetrics.sorted
+    val parallelMetrics = cvParallelModel.avgMetrics.sorted
+    assert(serialMetrics === parallelMetrics)
+
+    val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression]
+    val parentParallel = cvParallelModel.bestModel.parent.asInstanceOf[LogisticRegression]
+    assert(parentSerial.getRegParam === parentParallel.getRegParam)
+    assert(parentSerial.getMaxIter === parentParallel.getMaxIter)
+  }
+
   test("read/write: CrossValidator with simple estimator") {
     val lr = new LogisticRegression().setMaxIter(3)
     val evaluator = new BinaryClassificationEvaluator()

http://git-wip-us.apache.org/repos/asf/spark/blob/16c4c03c/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index aa8b4cf..7480173 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -36,9 +36,14 @@ class TrainValidationSplitSuite
 
   import testImplicits._
 
-  test("train validation with logistic regression") {
-    val dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF()
+  @transient var dataset: Dataset[_] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF()
+  }
 
+  test("train validation with logistic regression") {
     val lr = new LogisticRegression
     val lrParamMaps = new ParamGridBuilder()
       .addGrid(lr.regParam, Array(0.001, 1000.0))
@@ -117,6 +122,32 @@ class TrainValidationSplitSuite
     }
   }
 
+  test("train validation with parallel evaluation") {
+    val lr = new LogisticRegression
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.001, 1000.0))
+      .addGrid(lr.maxIter, Array(0, 3))
+      .build()
+    val eval = new BinaryClassificationEvaluator
+    val cv = new TrainValidationSplit()
+      .setEstimator(lr)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setParallelism(1)
+    val cvSerialModel = cv.fit(dataset)
+    cv.setParallelism(2)
+    val cvParallelModel = cv.fit(dataset)
+
+    val serialMetrics = cvSerialModel.validationMetrics.sorted
+    val parallelMetrics = cvParallelModel.validationMetrics.sorted
+    assert(serialMetrics === parallelMetrics)
+
+    val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression]
+    val parentParallel = cvParallelModel.bestModel.parent.asInstanceOf[LogisticRegression]
+    assert(parentSerial.getRegParam === parentParallel.getRegParam)
+    assert(parentSerial.getMaxIter === parentParallel.getMaxIter)
+  }
+
   test("read/write: TrainValidationSplit") {
     val lr = new LogisticRegression().setMaxIter(3)
     val evaluator = new BinaryClassificationEvaluator()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org