Posted to commits@spark.apache.org by jk...@apache.org on 2016/11/02 00:00:07 UTC

spark git commit: [SPARK-18088][ML] Various ChiSqSelector cleanups

Repository: spark
Updated Branches:
  refs/heads/master b929537b6 -> 91c33a0ca


[SPARK-18088][ML] Various ChiSqSelector cleanups

## What changes were proposed in this pull request?
- Renamed kbest to numTopFeatures
- Renamed alpha to fpr
- Added missing Since annotations
- Doc cleanups

## How was this patch tested?

Added new standardized unit tests for spark.ml.
Improved existing unit test coverage a bit.

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #15647 from jkbradley/chisqselector-follow-ups.
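
For illustration, the three selector types named in this commit can be sketched in plain Python. This is a minimal sketch rather than Spark's implementation: the function name `select_features`, its signature, and the sample p-values are hypothetical, and it assumes the per-feature chi-squared p-values have already been computed. It follows the selection rules stated in the patch: sort by ascending p-value and keep a fixed count (`numTopFeatures`) or a fraction (`percentile`), or keep every feature whose p-value is below the threshold (`fpr`).

```python
def select_features(p_values, selector_type="numTopFeatures",
                    num_top_features=50, percentile=0.1, fpr=0.05):
    """Return the sorted indices of the selected features.

    Hypothetical helper, not part of Spark. `p_values[i]` is assumed to be
    the chi-squared test p-value of feature i against the label.
    """
    # Pair each p-value with its feature index, like zipWithIndex in the patch.
    indexed = list(enumerate(p_values))
    # Stable sort by ascending p-value (a smaller p-value is stronger
    # evidence that the feature and the label are dependent).
    by_p_value = sorted(indexed, key=lambda pair: pair[1])

    if selector_type == "numTopFeatures":
        chosen = by_p_value[:num_top_features]
    elif selector_type == "percentile":
        # Truncate toward zero, mirroring `.toInt` in the Scala code.
        chosen = by_p_value[:int(len(p_values) * percentile)]
    elif selector_type == "fpr":
        # Strict inequality: a p-value equal to the threshold is dropped.
        chosen = [pair for pair in indexed if pair[1] < fpr]
    else:
        raise ValueError("Unknown ChiSqSelector type: %s" % selector_type)

    # Returning indices in ascending order is a choice made for this sketch.
    return sorted(index for index, _ in chosen)


# Made-up p-values for three features.
example_p_values = [0.01, 0.5, 0.15]
top_one = select_features(example_p_values, "numTopFeatures", num_top_features=1)
top_fraction = select_features(example_p_values, "percentile", percentile=0.34)
below_threshold = select_features(example_p_values, "fpr", fpr=0.2)
```

With only three features, the default `num_top_features=50` keeps all of them, which matches the documented behaviour when there are fewer features than requested.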


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/91c33a0c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/91c33a0c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/91c33a0c

Branch: refs/heads/master
Commit: 91c33a0ca5c8287f710076ed7681e5aa13ca068f
Parents: b929537
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Tue Nov 1 17:00:00 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Nov 1 17:00:00 2016 -0700

----------------------------------------------------------------------
 docs/ml-features.md                             |  12 +-
 docs/mllib-feature-extraction.md                |  15 +--
 .../apache/spark/ml/feature/ChiSqSelector.scala |  59 ++++----
 .../spark/mllib/api/python/PythonMLLibAPI.scala |   4 +-
 .../spark/mllib/feature/ChiSqSelector.scala     |  45 +++----
 .../spark/ml/feature/ChiSqSelectorSuite.scala   | 135 ++++++++++---------
 .../mllib/feature/ChiSqSelectorSuite.scala      |  17 +--
 python/pyspark/ml/feature.py                    |  37 ++---
 python/pyspark/mllib/feature.py                 |  58 ++++----
 9 files changed, 197 insertions(+), 185 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/91c33a0c/docs/ml-features.md
----------------------------------------------------------------------
diff --git a/docs/ml-features.md b/docs/ml-features.md
index 64c6a16..352887d 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -1338,14 +1338,14 @@ for more details on the API.
 `ChiSqSelector` stands for Chi-Squared feature selection. It operates on labeled data with
 categorical features. ChiSqSelector uses the
 [Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which
-features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`:
+features to choose. It supports three selection methods: `numTopFeatures`, `percentile`, `fpr`:
 
-* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power.
-* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number.
-* `FPR` chooses all features whose false positive rate meets some threshold.
+* `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. This is akin to yielding the features with the most predictive power.
+* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number.
+* `fpr` chooses all features whose p-value is below a threshold, thus controlling the false positive rate of selection.
 
-By default, the selection method is `KBest`, the default number of top features is 50. User can use
-`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods.
+By default, the selection method is `numTopFeatures`, with the default number of top features set to 50.
+The user can choose a selection method using `setSelectorType`.
 
 **Examples**
 

http://git-wip-us.apache.org/repos/asf/spark/blob/91c33a0c/docs/mllib-feature-extraction.md
----------------------------------------------------------------------
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index 87e1e02..42568c3 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -227,22 +227,19 @@ both speed and statistical learning behavior.
 [`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements
 Chi-Squared feature selection. It operates on labeled data with categorical features. ChiSqSelector uses the
 [Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which
-features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`:
+features to choose. It supports three selection methods: `numTopFeatures`, `percentile`, `fpr`:
 
-* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power.
-* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number.
-* `FPR` chooses all features whose false positive rate meets some threshold.
+* `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. This is akin to yielding the features with the most predictive power.
+* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number.
+* `fpr` chooses all features whose p-value is below a threshold, thus controlling the false positive rate of selection.
 
-By default, the selection method is `KBest`, the default number of top features is 50. User can use
-`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods.
+By default, the selection method is `numTopFeatures`, with the default number of top features set to 50.
+The user can choose a selection method using `setSelectorType`.
 
 The number of features to select can be tuned using a held-out validation set.
 
 ### Model Fitting
 
-`ChiSqSelector` takes a `numTopFeatures` parameter specifying the number of top features that
-the selector will select.
-
 The [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method takes
 an input of `RDD[LabeledPoint]` with categorical features, learns the summary statistics, and then
 returns a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space.

http://git-wip-us.apache.org/repos/asf/spark/blob/91c33a0c/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index d0385e2..653fa41 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -42,69 +42,80 @@ private[feature] trait ChiSqSelectorParams extends Params
   with HasFeaturesCol with HasOutputCol with HasLabelCol {
 
   /**
-   * Number of features that selector will select (ordered by statistic value descending). If the
+   * Number of features that selector will select, ordered by ascending p-value. If the
    * number of features is less than numTopFeatures, then this will select all features.
-   * Only applicable when selectorType = "kbest".
+   * Only applicable when selectorType = "numTopFeatures".
    * The default value of numTopFeatures is 50.
    *
    * @group param
    */
+  @Since("1.6.0")
   final val numTopFeatures = new IntParam(this, "numTopFeatures",
-    "Number of features that selector will select, ordered by statistics value descending. If the" +
+    "Number of features that selector will select, ordered by ascending p-value. If the" +
       " number of features is < numTopFeatures, then this will select all features.",
     ParamValidators.gtEq(1))
   setDefault(numTopFeatures -> 50)
 
   /** @group getParam */
+  @Since("1.6.0")
   def getNumTopFeatures: Int = $(numTopFeatures)
 
   /**
    * Percentile of features that selector will select, ordered by statistics value descending.
    * Only applicable when selectorType = "percentile".
    * Default value is 0.1.
+   * @group param
    */
+  @Since("2.1.0")
   final val percentile = new DoubleParam(this, "percentile",
-    "Percentile of features that selector will select, ordered by statistics value descending.",
+    "Percentile of features that selector will select, ordered by ascending p-value.",
     ParamValidators.inRange(0, 1))
   setDefault(percentile -> 0.1)
 
   /** @group getParam */
+  @Since("2.1.0")
   def getPercentile: Double = $(percentile)
 
   /**
    * The highest p-value for features to be kept.
    * Only applicable when selectorType = "fpr".
    * Default value is 0.05.
+   * @group param
    */
-  final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.",
+  final val fpr = new DoubleParam(this, "fpr", "The highest p-value for features to be kept.",
     ParamValidators.inRange(0, 1))
-  setDefault(alpha -> 0.05)
+  setDefault(fpr -> 0.05)
 
   /** @group getParam */
-  def getAlpha: Double = $(alpha)
+  def getFpr: Double = $(fpr)
 
   /**
    * The selector type of the ChisqSelector.
-   * Supported options: "kbest" (default), "percentile" and "fpr".
+   * Supported options: "numTopFeatures" (default), "percentile", "fpr".
+   * @group param
    */
+  @Since("2.1.0")
   final val selectorType = new Param[String](this, "selectorType",
     "The selector type of the ChisqSelector. " +
-      "Supported options: kbest (default), percentile and fpr.",
-    ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray))
-  setDefault(selectorType -> OldChiSqSelector.KBest)
+      "Supported options: " + OldChiSqSelector.supportedSelectorTypes.mkString(", "),
+    ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes))
+  setDefault(selectorType -> OldChiSqSelector.NumTopFeatures)
 
   /** @group getParam */
+  @Since("2.1.0")
   def getSelectorType: String = $(selectorType)
 }
 
 /**
  * Chi-Squared feature selection, which selects categorical features to use for predicting a
  * categorical label.
- * The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
- * `kbest` chooses the `k` top features according to a chi-squared test.
- * `percentile` is similar but chooses a fraction of all features instead of a fixed number.
- * `fpr` chooses all features whose false positive rate meets some threshold.
- * By default, the selection method is `kbest`, the default number of top features is 50.
+ * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`.
+ *  - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test.
+ *  - `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ *  - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false
+ *    positive rate of selection.
+ * By default, the selection method is `numTopFeatures`, with the default number of top features
+ * set to 50.
  */
 @Since("1.6.0")
 final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String)
@@ -114,10 +125,6 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
   def this() = this(Identifiable.randomUID("chiSqSelector"))
 
   /** @group setParam */
-  @Since("2.1.0")
-  def setSelectorType(value: String): this.type = set(selectorType, value)
-
-  /** @group setParam */
   @Since("1.6.0")
   def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
 
@@ -127,7 +134,11 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
 
   /** @group setParam */
   @Since("2.1.0")
-  def setAlpha(value: Double): this.type = set(alpha, value)
+  def setFpr(value: Double): this.type = set(fpr, value)
+
+  /** @group setParam */
+  @Since("2.1.0")
+  def setSelectorType(value: String): this.type = set(selectorType, value)
 
   /** @group setParam */
   @Since("1.6.0")
@@ -153,15 +164,15 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
       .setSelectorType($(selectorType))
       .setNumTopFeatures($(numTopFeatures))
       .setPercentile($(percentile))
-      .setAlpha($(alpha))
+      .setFpr($(fpr))
     val model = selector.fit(input)
     copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
   }
 
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType))
-    otherPairs.foreach { case (_, paramName: String) =>
+    val otherPairs = OldChiSqSelector.supportedSelectorTypes.filter(_ != $(selectorType))
+    otherPairs.foreach { paramName: String =>
       if (isSet(getParam(paramName))) {
         logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.")
       }
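
The simplification in the hunk above appears to rely on each selector type name now matching the name of the param that configures it, so the separate type-to-param mapping is no longer needed. A rough Python sketch of that check follows; `unused_param_warnings` and its arguments are hypothetical names used only for illustration, and the Scala code logs through `logWarning` rather than returning strings.

```python
SUPPORTED_SELECTOR_TYPES = ["numTopFeatures", "percentile", "fpr"]


def unused_param_warnings(selector_type, explicitly_set_params):
    """Return warnings for params that were set but belong to another selector type.

    Hypothetical helper, not part of Spark: selector type names double as
    param names, so one list serves both purposes.
    """
    other_params = [name for name in SUPPORTED_SELECTOR_TYPES if name != selector_type]
    return [
        "Param %s will take no effect when selector type = %s." % (name, selector_type)
        for name in other_params
        if name in explicitly_set_params
    ]


# A user picked "fpr" but also set "percentile", which is then ignored.
warnings_for_fpr = unused_param_warnings("fpr", {"fpr", "percentile"})
```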

http://git-wip-us.apache.org/repos/asf/spark/blob/91c33a0c/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 904000f..034e362 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -638,13 +638,13 @@ private[python] class PythonMLLibAPI extends Serializable {
       selectorType: String,
       numTopFeatures: Int,
       percentile: Double,
-      alpha: Double,
+      fpr: Double,
       data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
     new ChiSqSelector()
       .setSelectorType(selectorType)
       .setNumTopFeatures(numTopFeatures)
       .setPercentile(percentile)
-      .setAlpha(alpha)
+      .setFpr(fpr)
       .fit(data.rdd)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/91c33a0c/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index f8276de..f9156b6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -161,7 +161,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
       Loader.checkSchema[Data](dataFrame.schema)
 
       val features = dataArray.rdd.map {
-        case Row(feature: Int) => (feature)
+        case Row(feature: Int) => feature
       }.collect()
 
       new ChiSqSelectorModel(features)
@@ -171,18 +171,20 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
 
 /**
  * Creates a ChiSquared feature selector.
- * The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
- * `kbest` chooses the `k` top features according to a chi-squared test.
- * `percentile` is similar but chooses a fraction of all features instead of a fixed number.
- * `fpr` chooses all features whose false positive rate meets some threshold.
- * By default, the selection method is `kbest`, the default number of top features is 50.
+ * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`.
+ *  - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test.
+ *  - `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ *  - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false
+ *    positive rate of selection.
+ * By default, the selection method is `numTopFeatures`, with the default number of top features
+ * set to 50.
  */
 @Since("1.3.0")
 class ChiSqSelector @Since("2.1.0") () extends Serializable {
   var numTopFeatures: Int = 50
   var percentile: Double = 0.1
-  var alpha: Double = 0.05
-  var selectorType = ChiSqSelector.KBest
+  var fpr: Double = 0.05
+  var selectorType = ChiSqSelector.NumTopFeatures
 
   /**
    * The is the same to call this() and setNumTopFeatures(numTopFeatures)
@@ -207,15 +209,15 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
   }
 
   @Since("2.1.0")
-  def setAlpha(value: Double): this.type = {
-    require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]")
-    alpha = value
+  def setFpr(value: Double): this.type = {
+    require(0.0 <= value && value <= 1.0, "FPR must be in [0,1]")
+    fpr = value
     this
   }
 
   @Since("2.1.0")
   def setSelectorType(value: String): this.type = {
-    require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value),
+    require(ChiSqSelector.supportedSelectorTypes.contains(value),
       s"ChiSqSelector Type: $value was not supported.")
     selectorType = value
     this
@@ -232,7 +234,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
   def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
     val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex
     val features = selectorType match {
-      case ChiSqSelector.KBest =>
+      case ChiSqSelector.NumTopFeatures =>
         chiSqTestResult
           .sortBy { case (res, _) => res.pValue }
           .take(numTopFeatures)
@@ -242,7 +244,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
           .take((chiSqTestResult.length * percentile).toInt)
       case ChiSqSelector.FPR =>
         chiSqTestResult
-          .filter { case (res, _) => res.pValue < alpha }
+          .filter { case (res, _) => res.pValue < fpr }
       case errorType =>
         throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
     }
@@ -251,22 +253,17 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
   }
 }
 
-@Since("2.1.0")
-object ChiSqSelector {
+private[spark] object ChiSqSelector {
 
-  /** String name for `kbest` selector type. */
-  private[spark] val KBest: String = "kbest"
+  /** String name for `numTopFeatures` selector type. */
+  val NumTopFeatures: String = "numTopFeatures"
 
   /** String name for `percentile` selector type. */
-  private[spark] val Percentile: String = "percentile"
+  val Percentile: String = "percentile"
 
   /** String name for `fpr` selector type. */
   private[spark] val FPR: String = "fpr"
 
-  /** Set of selector type and param pairs that ChiSqSelector supports. */
-  private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures",
-    Percentile -> "percentile", FPR -> "alpha")
-
   /** Set of selector types that ChiSqSelector supports. */
-  private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1)
+  val supportedSelectorTypes: Array[String] = Array(NumTopFeatures, Percentile, FPR)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/91c33a0c/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 6af06d8..80970fd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -19,85 +19,72 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Dataset, Row}
 
 class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
   with DefaultReadWriteTest {
 
-  test("Test Chi-Square selector") {
-    import testImplicits._
-    val data = Seq(
-      LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
-      LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
-      LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
-      LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))
-    )
+  @transient var dataset: Dataset[_] = _
 
-    val preFilteredData = Seq(
-      Vectors.dense(8.0),
-      Vectors.dense(0.0),
-      Vectors.dense(0.0),
-      Vectors.dense(8.0)
-    )
+  override def beforeAll(): Unit = {
+    super.beforeAll()
 
-    val df = sc.parallelize(data.zip(preFilteredData))
-      .map(x => (x._1.label, x._1.features, x._2))
-      .toDF("label", "data", "preFilteredData")
-
-    val selector = new ChiSqSelector()
-      .setSelectorType("kbest")
-      .setNumTopFeatures(1)
-      .setFeaturesCol("data")
-      .setLabelCol("label")
-      .setOutputCol("filtered")
-
-    selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach {
-      case Row(vec1: Vector, vec2: Vector) =>
-        assert(vec1 ~== vec2 absTol 1e-1)
-    }
-
-    selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df)
-      .select("filtered", "preFilteredData").collect().foreach {
-        case Row(vec1: Vector, vec2: Vector) =>
-          assert(vec1 ~== vec2 absTol 1e-1)
-      }
+    // Toy dataset, including the top feature for a chi-squared test.
+    // These data are chosen such that each feature's test has a distinct p-value.
+    /*  To verify the results with R, run:
+      library(stats)
+      x1 <- c(8.0, 0.0, 0.0, 7.0, 8.0)
+      x2 <- c(7.0, 9.0, 9.0, 9.0, 7.0)
+      x3 <- c(0.0, 6.0, 8.0, 5.0, 3.0)
+      y <- c(0.0, 1.0, 1.0, 2.0, 2.0)
+      chisq.test(x1,y)
+      chisq.test(x2,y)
+      chisq.test(x3,y)
+     */
+    dataset = spark.createDataFrame(Seq(
+      (0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0))), Vectors.dense(8.0)),
+      (1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0))), Vectors.dense(0.0)),
+      (1.0, Vectors.dense(Array(0.0, 9.0, 8.0)), Vectors.dense(0.0)),
+      (2.0, Vectors.dense(Array(7.0, 9.0, 5.0)), Vectors.dense(7.0)),
+      (2.0, Vectors.dense(Array(8.0, 7.0, 3.0)), Vectors.dense(8.0))
+    )).toDF("label", "features", "topFeature")
+  }
 
-    val preFilteredData2 = Seq(
-      Vectors.dense(8.0, 7.0),
-      Vectors.dense(0.0, 9.0),
-      Vectors.dense(0.0, 9.0),
-      Vectors.dense(8.0, 9.0)
-    )
+  test("params") {
+    ParamsSuite.checkParams(new ChiSqSelector)
+    val model = new ChiSqSelectorModel("myModel",
+      new org.apache.spark.mllib.feature.ChiSqSelectorModel(Array(1, 3, 4)))
+    ParamsSuite.checkParams(model)
+  }
 
-    val df2 = sc.parallelize(data.zip(preFilteredData2))
-      .map(x => (x._1.label, x._1.features, x._2))
-      .toDF("label", "data", "preFilteredData")
+  test("Test Chi-Square selector: numTopFeatures") {
+    val selector = new ChiSqSelector()
+      .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1)
+    ChiSqSelectorSuite.testSelector(selector, dataset)
+  }
 
-    selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2)
-      .select("filtered", "preFilteredData").collect().foreach {
-        case Row(vec1: Vector, vec2: Vector) =>
-          assert(vec1 ~== vec2 absTol 1e-1)
-      }
+  test("Test Chi-Square selector: percentile") {
+    val selector = new ChiSqSelector()
+      .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.34)
+    ChiSqSelectorSuite.testSelector(selector, dataset)
   }
 
-  test("ChiSqSelector read/write") {
-    val t = new ChiSqSelector()
-      .setFeaturesCol("myFeaturesCol")
-      .setLabelCol("myLabelCol")
-      .setOutputCol("myOutputCol")
-      .setNumTopFeatures(2)
-    testDefaultReadWrite(t)
+  test("Test Chi-Square selector: fpr") {
+    val selector = new ChiSqSelector()
+      .setOutputCol("filtered").setSelectorType("fpr").setFpr(0.2)
+    ChiSqSelectorSuite.testSelector(selector, dataset)
   }
 
-  test("ChiSqSelectorModel read/write") {
-    val oldModel = new feature.ChiSqSelectorModel(Array(1, 3))
-    val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel)
-    val newInstance = testDefaultReadWrite(instance)
-    assert(newInstance.selectedFeatures === instance.selectedFeatures)
+  test("read/write") {
+    def checkModelData(model: ChiSqSelectorModel, model2: ChiSqSelectorModel): Unit = {
+      assert(model.selectedFeatures === model2.selectedFeatures)
+    }
+    val nb = new ChiSqSelector
+    testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and not support other types") {
@@ -108,3 +95,25 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
       }
   }
 }
+
+object ChiSqSelectorSuite {
+
+  private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): Unit = {
+    selector.fit(dataset).transform(dataset).select("filtered", "topFeature").collect()
+      .foreach { case Row(vec1: Vector, vec2: Vector) =>
+        assert(vec1 ~== vec2 absTol 1e-1)
+      }
+  }
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "selectorType" -> "percentile",
+    "numTopFeatures" -> 1,
+    "percentile" -> 0.12,
+    "outputCol" -> "myOutput"
+  )
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/91c33a0c/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index ac702b4..77219e5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -54,33 +54,34 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
         LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
         LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2)
     val preFilteredData =
-      Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))),
+      Seq(LabeledPoint(0.0, Vectors.dense(Array(8.0))),
         LabeledPoint(1.0, Vectors.dense(Array(0.0))),
         LabeledPoint(1.0, Vectors.dense(Array(0.0))),
         LabeledPoint(2.0, Vectors.dense(Array(8.0))))
     val model = new ChiSqSelector(1).fit(labeledDiscreteData)
     val filteredData = labeledDiscreteData.map { lp =>
       LabeledPoint(lp.label, model.transform(lp.features))
-    }.collect().toSet
-    assert(filteredData == preFilteredData)
+    }.collect().toSeq
+    assert(filteredData === preFilteredData)
   }
 
-  test("ChiSqSelector by FPR transform test (sparse & dense vector)") {
+  test("ChiSqSelector by fpr transform test (sparse & dense vector)") {
     val labeledDiscreteData = sc.parallelize(
       Seq(LabeledPoint(0.0, Vectors.sparse(4, Array((0, 8.0), (1, 7.0)))),
         LabeledPoint(1.0, Vectors.sparse(4, Array((1, 9.0), (2, 6.0), (3, 4.0)))),
         LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 4.0))),
         LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0, 9.0)))), 2)
     val preFilteredData =
-      Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
+      Seq(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
         LabeledPoint(1.0, Vectors.dense(Array(4.0))),
         LabeledPoint(1.0, Vectors.dense(Array(4.0))),
         LabeledPoint(2.0, Vectors.dense(Array(9.0))))
-    val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData)
+    val model: ChiSqSelectorModel = new ChiSqSelector().setSelectorType("fpr")
+      .setFpr(0.1).fit(labeledDiscreteData)
     val filteredData = labeledDiscreteData.map { lp =>
       LabeledPoint(lp.label, model.transform(lp.features))
-    }.collect().toSet
-    assert(filteredData == preFilteredData)
+    }.collect().toSeq
+    assert(filteredData === preFilteredData)
   }
 
   test("model load / save") {

http://git-wip-us.apache.org/repos/asf/spark/blob/91c33a0c/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 94afe82..635cf13 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2606,42 +2606,43 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
 
     selectorType = Param(Params._dummy(), "selectorType",
                          "The selector type of the ChisqSelector. " +
-                         "Supported options: kbest (default), percentile and fpr.",
+                         "Supported options: numTopFeatures (default), percentile and fpr.",
                          typeConverter=TypeConverters.toString)
 
     numTopFeatures = \
         Param(Params._dummy(), "numTopFeatures",
-              "Number of features that selector will select, ordered by statistics value " +
-              "descending. If the number of features is < numTopFeatures, then this will select " +
+              "Number of features that selector will select, ordered by ascending p-value. " +
+              "If the number of features is < numTopFeatures, then this will select " +
               "all features.", typeConverter=TypeConverters.toInt)
 
     percentile = Param(Params._dummy(), "percentile", "Percentile of features that selector " +
-                       "will select, ordered by statistics value descending.",
+                       "will select, ordered by ascending p-value.",
                        typeConverter=TypeConverters.toFloat)
 
-    alpha = Param(Params._dummy(), "alpha", "The highest p-value for features to be kept.",
-                  typeConverter=TypeConverters.toFloat)
+    fpr = Param(Params._dummy(), "fpr", "The highest p-value for features to be kept.",
+                typeConverter=TypeConverters.toFloat)
 
     @keyword_only
     def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None,
-                 labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05):
+                 labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05):
         """
         __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \
-                 labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05)
+                 labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05)
         """
         super(ChiSqSelector, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
-        self._setDefault(numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05)
+        self._setDefault(numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1,
+                         fpr=0.05)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("2.0.0")
     def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,
-                  labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05):
+                  labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05):
         """
         setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, \
-                  labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05)
+                  labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05)
         Sets params for this ChiSqSelector.
         """
         kwargs = self.setParams._input_kwargs
@@ -2665,7 +2666,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
     def setNumTopFeatures(self, value):
         """
         Sets the value of :py:attr:`numTopFeatures`.
-        Only applicable when selectorType = "kbest".
+        Only applicable when selectorType = "numTopFeatures".
         """
         return self._set(numTopFeatures=value)
 
@@ -2692,19 +2693,19 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
         return self.getOrDefault(self.percentile)
 
     @since("2.1.0")
-    def setAlpha(self, value):
+    def setFpr(self, value):
         """
-        Sets the value of :py:attr:`alpha`.
+        Sets the value of :py:attr:`fpr`.
         Only applicable when selectorType = "fpr".
         """
-        return self._set(alpha=value)
+        return self._set(fpr=value)
 
     @since("2.1.0")
-    def getAlpha(self):
+    def getFpr(self):
         """
-        Gets the value of alpha or its default value.
+        Gets the value of fpr or its default value.
         """
-        return self.getOrDefault(self.alpha)
+        return self.getOrDefault(self.fpr)
 
     def _create_model(self, java_model):
         return ChiSqSelectorModel(java_model)

http://git-wip-us.apache.org/repos/asf/spark/blob/91c33a0c/python/pyspark/mllib/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 50ef7c7..7eaa228 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -274,52 +274,48 @@ class ChiSqSelectorModel(JavaVectorTransformer):
 class ChiSqSelector(object):
     """
     Creates a ChiSquared feature selector.
-    The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
-    `kbest` chooses the `k` top features according to a chi-squared test.
+    The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`.
+    `numTopFeatures` chooses a fixed number of top features according to a chi-squared test.
     `percentile` is similar but chooses a fraction of all features instead of a fixed number.
-    `fpr` chooses all features whose false positive rate meets some threshold.
-    By default, the selection method is `kbest`, the default number of top features is 50.
+    `fpr` chooses all features whose p-value is below a threshold, thus controlling the false
+    positive rate of selection.
+    By default, the selection method is `numTopFeatures`, with the default number of top features
+    set to 50.
 
-    >>> data = [
+    >>> data = sc.parallelize([
     ...     LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
     ...     LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})),
     ...     LabeledPoint(1.0, [0.0, 9.0, 8.0]),
-    ...     LabeledPoint(2.0, [8.0, 9.0, 5.0])
-    ... ]
-    >>> model = ChiSqSelector().setNumTopFeatures(1).fit(sc.parallelize(data))
+    ...     LabeledPoint(2.0, [7.0, 9.0, 5.0]),
+    ...     LabeledPoint(2.0, [8.0, 7.0, 3.0])
+    ... ])
+    >>> model = ChiSqSelector(numTopFeatures=1).fit(data)
     >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
     SparseVector(1, {})
-    >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
-    DenseVector([8.0])
-    >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit(
-    ...     sc.parallelize(data))
+    >>> model.transform(DenseVector([7.0, 9.0, 5.0]))
+    DenseVector([7.0])
+    >>> model = ChiSqSelector(selectorType="fpr", fpr=0.2).fit(data)
     >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
     SparseVector(1, {})
-    >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
-    DenseVector([8.0])
-    >>> data = [
-    ...     LabeledPoint(0.0, SparseVector(4, {0: 8.0, 1: 7.0})),
-    ...     LabeledPoint(1.0, SparseVector(4, {1: 9.0, 2: 6.0, 3: 4.0})),
-    ...     LabeledPoint(1.0, [0.0, 9.0, 8.0, 4.0]),
-    ...     LabeledPoint(2.0, [8.0, 9.0, 5.0, 9.0])
-    ... ]
-    >>> model = ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(sc.parallelize(data))
-    >>> model.transform(DenseVector([1.0,2.0,3.0,4.0]))
-    DenseVector([4.0])
+    >>> model.transform(DenseVector([7.0, 9.0, 5.0]))
+    DenseVector([7.0])
+    >>> model = ChiSqSelector(selectorType="percentile", percentile=0.34).fit(data)
+    >>> model.transform(DenseVector([7.0, 9.0, 5.0]))
+    DenseVector([7.0])
 
     .. versionadded:: 1.4.0
     """
-    def __init__(self, numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05):
+    def __init__(self, numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1, fpr=0.05):
         self.numTopFeatures = numTopFeatures
         self.selectorType = selectorType
         self.percentile = percentile
-        self.alpha = alpha
+        self.fpr = fpr
 
     @since('2.1.0')
     def setNumTopFeatures(self, numTopFeatures):
         """
         set numTopFeature for feature selection by number of top features.
-        Only applicable when selectorType = "kbest".
+        Only applicable when selectorType = "numTopFeatures".
         """
         self.numTopFeatures = int(numTopFeatures)
         return self
@@ -334,19 +330,19 @@ class ChiSqSelector(object):
         return self
 
     @since('2.1.0')
-    def setAlpha(self, alpha):
+    def setFpr(self, fpr):
         """
-        set alpha [0.0, 1.0] for feature selection by FPR.
+        set FPR [0.0, 1.0] for feature selection by FPR.
         Only applicable when selectorType = "fpr".
         """
-        self.alpha = float(alpha)
+        self.fpr = float(fpr)
         return self
 
     @since('2.1.0')
     def setSelectorType(self, selectorType):
         """
         set the selector type of the ChisqSelector.
-        Supported options: "kbest" (default), "percentile" and "fpr".
+        Supported options: "numTopFeatures" (default), "percentile", "fpr".
         """
         self.selectorType = str(selectorType)
         return self
@@ -362,7 +358,7 @@ class ChiSqSelector(object):
                      Apply feature discretizer before using this function.
         """
         jmodel = callMLlibFunc("fitChiSqSelector", self.selectorType, self.numTopFeatures,
-                               self.percentile, self.alpha, data)
+                               self.percentile, self.fpr, data)
         return ChiSqSelectorModel(jmodel)
 
 
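Note on the three selector types described in the pyspark.mllib docstring above: the following is a minimal pure-Python sketch of how `numTopFeatures`, `percentile` and `fpr` could choose feature indices from per-feature chi-squared p-values. It is not the Spark implementation. The function name `select_features`, its snake_case arguments, the truncating rounding used for `percentile`, and the strict `<` comparison used for `fpr` are all assumptions made for illustration.

```python
def select_features(p_values, selector_type="numTopFeatures",
                    num_top_features=50, percentile=0.1, fpr=0.05):
    """Return the sorted indices of the features that would be kept.

    Hypothetical helper: `p_values[i]` is assumed to be the chi-squared
    test p-value of feature i against the label.
    """
    # Rank features by ascending p-value (most significant first).
    ranked = sorted(range(len(p_values)), key=lambda i: p_values[i])

    if selector_type == "numTopFeatures":
        # Keep a fixed number of top features.
        chosen = ranked[:num_top_features]
    elif selector_type == "percentile":
        # Keep a fraction of all features (truncation is an assumption).
        chosen = ranked[:int(len(p_values) * percentile)]
    elif selector_type == "fpr":
        # Keep every feature whose p-value is below the threshold
        # (strict comparison is an assumption).
        chosen = [i for i in ranked if p_values[i] < fpr]
    else:
        raise ValueError("Unknown selectorType: %s" % selector_type)

    return sorted(chosen)


if __name__ == "__main__":
    p_values = [0.5, 0.01, 0.2, 0.04]
    print(select_features(p_values, "numTopFeatures", num_top_features=2))
    print(select_features(p_values, "percentile", percentile=0.5))
    print(select_features(p_values, "fpr", fpr=0.05))
```

With the renamed parameters, only the value matching the chosen `selectorType` is consulted, which mirrors the "Only applicable when selectorType = ..." notes in the setter docstrings.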

