You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/12/16 20:05:41 UTC

spark git commit: [SPARK-9694][ML] Add random seed Param to Scala CrossValidator

Repository: spark
Updated Branches:
  refs/heads/master 7b6dc29d0 -> 860dc7f2f


[SPARK-9694][ML] Add random seed Param to Scala CrossValidator

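Add a random seed Param to the Scala CrossValidator. CrossValidatorParams now mixes in HasSeed, and CrossValidator gains a setSeed setter. The seed is passed to MLUtils.kFold instead of the previously hard-coded seed 0. MLUtils.kFold also gets an overload that takes a Long seed; the existing Int-seed version now delegates to it.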

Author: Yanbo Liang <yb...@gmail.com>

Closes #9108 from yanboliang/spark-9694.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/860dc7f2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/860dc7f2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/860dc7f2

Branch: refs/heads/master
Commit: 860dc7f2f8dd01f2562ba83b7af27ba29d91cb62
Parents: 7b6dc29
Author: Yanbo Liang <yb...@gmail.com>
Authored: Wed Dec 16 11:05:37 2015 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Wed Dec 16 11:05:37 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/tuning/CrossValidator.scala      | 11 ++++++++---
 .../main/scala/org/apache/spark/mllib/util/MLUtils.scala |  8 ++++++++
 2 files changed, 16 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/860dc7f2/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 5c09f1a..40f8857 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -29,8 +29,9 @@ import org.apache.spark.ml.classification.OneVsRestParams
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+import org.apache.spark.ml.param.shared.HasSeed
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
@@ -39,7 +40,7 @@ import org.apache.spark.sql.types.StructType
 /**
  * Params for [[CrossValidator]] and [[CrossValidatorModel]].
  */
-private[ml] trait CrossValidatorParams extends ValidatorParams {
+private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
   /**
    * Param for number of folds for cross validation.  Must be >= 2.
    * Default: 3
@@ -85,6 +86,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   @Since("1.2.0")
   def setNumFolds(value: Int): this.type = set(numFolds, value)
 
+  /** @group setParam */
+  @Since("2.0.0")
+  def setSeed(value: Long): this.type = set(seed, value)
+
   @Since("1.4.0")
   override def fit(dataset: DataFrame): CrossValidatorModel = {
     val schema = dataset.schema
@@ -95,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     val epm = $(estimatorParamMaps)
     val numModels = epm.length
     val metrics = new Array[Double](epm.length)
-    val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0)
+    val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed))
     splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
       val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
       val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()

http://git-wip-us.apache.org/repos/asf/spark/blob/860dc7f2/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 414ea99..4c9151f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -265,6 +265,14 @@ object MLUtils {
    */
   @Since("1.0.0")
   def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
+    kFold(rdd, numFolds, seed.toLong)
+  }
+
+  /**
+   * Version of [[kFold()]] taking a Long seed.
+   */
+  @Since("2.0.0")
+  def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Long): Array[(RDD[T], RDD[T])] = {
     val numFoldsF = numFolds.toFloat
     (1 to numFolds).map { fold =>
       val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org