You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text rendering.
Posted to commits@spark.apache.org by me...@apache.org on 2016/02/12 00:05:37 UTC

spark git commit: [SPARK-11515][ML] QuantileDiscretizer should take random seed

Repository: spark
Updated Branches:
  refs/heads/master efb65e09b -> 574571c87


[SPARK-11515][ML] QuantileDiscretizer should take random seed

cc jkbradley

Author: Yu ISHIKAWA <yu...@gmail.com>

Closes #9535 from yu-iskw/SPARK-11515.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/574571c8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/574571c8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/574571c8

Branch: refs/heads/master
Commit: 574571c87098795a2206a113ee9ed4bafba8f00f
Parents: efb65e0
Author: Yu ISHIKAWA <yu...@gmail.com>
Authored: Thu Feb 11 15:05:34 2016 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Feb 11 15:05:34 2016 -0800

----------------------------------------------------------------------
 .../spark/ml/feature/QuantileDiscretizer.scala       | 15 ++++++++++-----
 .../spark/ml/feature/QuantileDiscretizerSuite.scala  |  2 +-
 2 files changed, 11 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/574571c8/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 8fd0ce2..2a294d3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param.{IntParam, _}
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.types.{DoubleType, StructType}
@@ -33,7 +33,8 @@ import org.apache.spark.util.random.XORShiftRandom
 /**
  * Params for [[QuantileDiscretizer]].
  */
-private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol {
+private[feature] trait QuantileDiscretizerBase extends Params
+  with HasInputCol with HasOutputCol with HasSeed {
 
   /**
    * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
@@ -73,6 +74,9 @@ final class QuantileDiscretizer(override val uid: String)
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /** @group setParam */
+  def setSeed(value: Long): this.type = set(seed, value)
+
   override def transformSchema(schema: StructType): StructType = {
     validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
@@ -85,7 +89,8 @@ final class QuantileDiscretizer(override val uid: String)
   }
 
   override def fit(dataset: DataFrame): Bucketizer = {
-    val samples = QuantileDiscretizer.getSampledInput(dataset.select($(inputCol)), $(numBuckets))
+    val samples = QuantileDiscretizer
+      .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
       .map { case Row(feature: Double) => feature }
     val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
     val splits = QuantileDiscretizer.getSplits(candidates)
@@ -101,13 +106,13 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi
   /**
    * Sampling from the given dataset to collect quantile statistics.
    */
-  private[feature] def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = {
+  private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = {
     val totalSamples = dataset.count()
     require(totalSamples > 0,
       "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
     val requiredSamples = math.max(numBins * numBins, 10000)
     val fraction = math.min(requiredSamples / dataset.count(), 1.0)
-    dataset.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
+    dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/574571c8/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 722f1ab..4fde429 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -93,7 +93,7 @@ private object QuantileDiscretizerSuite extends SparkFunSuite {
 
     val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
     val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
-      .setNumBuckets(numBucket)
+      .setNumBuckets(numBucket).setSeed(1)
     val result = discretizer.fit(df).transform(df)
 
     val transformedFeatures = result.select("result").collect()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org