You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2016/09/21 09:17:45 UTC

spark git commit: [SPARK-17017][MLLIB][ML] add a chiSquare Selector based on False Positive Rate (FPR) test

Repository: spark
Updated Branches:
  refs/heads/master 28fafa3ee -> b366f1849


[SPARK-17017][MLLIB][ML] add a chiSquare Selector based on False Positive Rate (FPR) test

## What changes were proposed in this pull request?

Univariate feature selection works by selecting the best features based on univariate statistical tests. False Positive Rate (FPR) is a popular univariate statistical test for feature selection. This PR adds a chiSquare Selector based on the False Positive Rate (FPR) test, similar to the one implemented in scikit-learn:
http://scikit-learn.org/stable/modules/feature_selection.html#univariate-feature-selection

## How was this patch tested?

Added Scala unit tests.

Author: Peng, Meng <pe...@intel.com>

Closes #14597 from mpjlu/fprChiSquare.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b366f184
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b366f184
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b366f184

Branch: refs/heads/master
Commit: b366f18496e1ce8bd20fe58a0245ef7d91819a03
Parents: 28fafa3
Author: Peng, Meng <pe...@intel.com>
Authored: Wed Sep 21 10:17:38 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Wed Sep 21 10:17:38 2016 +0100

----------------------------------------------------------------------
 .../apache/spark/ml/feature/ChiSqSelector.scala |  69 ++++++++++++-
 .../spark/mllib/api/python/PythonMLLibAPI.scala |  28 ++++-
 .../spark/mllib/feature/ChiSqSelector.scala     | 103 ++++++++++++++-----
 .../spark/ml/feature/ChiSqSelectorSuite.scala   |  11 +-
 .../mllib/feature/ChiSqSelectorSuite.scala      |  18 ++++
 project/MimaExcludes.scala                      |   3 +
 python/pyspark/mllib/feature.py                 |  71 ++++++++++++-
 7 files changed, 262 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b366f184/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 1482eb3..0c6a37b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -27,6 +27,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.feature.ChiSqSelectorType
 import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.rdd.RDD
@@ -54,11 +55,47 @@ private[feature] trait ChiSqSelectorParams extends Params
 
   /** @group getParam */
   def getNumTopFeatures: Int = $(numTopFeatures)
+
+  final val percentile = new DoubleParam(this, "percentile",
+    "Percentile of features that selector will select, ordered by statistics value descending.",
+    ParamValidators.inRange(0, 1))
+  setDefault(percentile -> 0.1)
+
+  /** @group getParam */
+  def getPercentile: Double = $(percentile)
+
+  final val alpha = new DoubleParam(this, "alpha",
+    "The highest p-value for features to be kept.",
+    ParamValidators.inRange(0, 1))
+  setDefault(alpha -> 0.05)
+
+  /** @group getParam */
+  def getAlpha: Double = $(alpha)
+
+  /**
+   * The ChiSqSelector supports KBest, Percentile and FPR selection,
+   * which are the same as the ChiSqSelectorType values defined in MLlib.
+   * Calling setNumTopFeatures sets the selectorType to KBest.
+   * Calling setPercentile sets the selectorType to Percentile.
+   * Calling setAlpha sets the selectorType to FPR.
+   */
+  final val selectorType = new Param[String](this, "selectorType",
+    "ChiSqSelector Type: KBest, Percentile, FPR")
+  setDefault(selectorType -> ChiSqSelectorType.KBest.toString)
+
+  /** @group getParam */
+  def getChiSqSelectorType: String = $(selectorType)
 }
 
 /**
  * Chi-Squared feature selection, which selects categorical features to use for predicting a
  * categorical label.
+ * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
+ * `KBest` chooses the `k` top features according to a chi-squared test.
+ * `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ * `FPR` chooses all features whose p-value is below the threshold `alpha`.
+ * By default, the selection method is `KBest` and the default number of top features is 50.
+ * Users can call setNumTopFeatures, setPercentile or setAlpha to choose the selection method.
  */
 @Since("1.6.0")
 final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String)
@@ -69,7 +106,22 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
 
   /** @group setParam */
   @Since("1.6.0")
-  def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
+  def setNumTopFeatures(value: Int): this.type = {
+    set(selectorType, ChiSqSelectorType.KBest.toString)
+    set(numTopFeatures, value)
+  }
+
+  @Since("2.1.0")
+  def setPercentile(value: Double): this.type = {
+    set(selectorType, ChiSqSelectorType.Percentile.toString)
+    set(percentile, value)
+  }
+
+  @Since("2.1.0")
+  def setAlpha(value: Double): this.type = {
+    set(selectorType, ChiSqSelectorType.FPR.toString)
+    set(alpha, value)
+  }
 
   /** @group setParam */
   @Since("1.6.0")
@@ -91,8 +143,19 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
         case Row(label: Double, features: Vector) =>
           OldLabeledPoint(label, OldVectors.fromML(features))
       }
-    val chiSqSelector = new feature.ChiSqSelector($(numTopFeatures)).fit(input)
-    copyValues(new ChiSqSelectorModel(uid, chiSqSelector).setParent(this))
+    var selector = new feature.ChiSqSelector()
+    ChiSqSelectorType.withName($(selectorType)) match {
+      case ChiSqSelectorType.KBest =>
+        selector.setNumTopFeatures($(numTopFeatures))
+      case ChiSqSelectorType.Percentile =>
+        selector.setPercentile($(percentile))
+      case ChiSqSelectorType.FPR =>
+        selector.setAlpha($(alpha))
+      case errorType =>
+        throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
+    }
+    val model = selector.fit(input)
+    copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
   }
 
   @Since("1.6.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/b366f184/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 2ed6c6b..5cffbf0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -629,13 +629,35 @@ private[python] class PythonMLLibAPI extends Serializable {
   }
 
   /**
-   * Java stub for ChiSqSelector.fit(). This stub returns a
+   * Java stub for ChiSqSelector.fit() when the selection type is KBest. This stub returns a
    * handle to the Java object instead of the content of the Java object.
    * Extra care needs to be taken in the Python code to ensure it gets freed on
    * exit; see the Py4J documentation.
    */
-  def fitChiSqSelector(numTopFeatures: Int, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
-    new ChiSqSelector(numTopFeatures).fit(data.rdd)
+  def fitChiSqSelectorKBest(numTopFeatures: Int,
+    data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
+    new ChiSqSelector().setNumTopFeatures(numTopFeatures).fit(data.rdd)
+  }
+
+  /**
+   * Java stub for ChiSqSelector.fit() when the selection type is Percentile. This stub returns a
+   * handle to the Java object instead of the content of the Java object.
+   * Extra care needs to be taken in the Python code to ensure it gets freed on
+   * exit; see the Py4J documentation.
+   */
+  def fitChiSqSelectorPercentile(percentile: Double,
+    data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
+    new ChiSqSelector().setPercentile(percentile).fit(data.rdd)
+  }
+
+  /**
+   * Java stub for ChiSqSelector.fit() when the selection type is FPR. This stub returns a
+   * handle to the Java object instead of the content of the Java object.
+   * Extra care needs to be taken in the Python code to ensure it gets freed on
+   * exit; see the Py4J documentation.
+   */
+  def fitChiSqSelectorFPR(alpha: Double, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
+    new ChiSqSelector().setAlpha(alpha).fit(data.rdd)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/b366f184/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 33a1f18..f68a017 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -32,27 +32,21 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SparkSession}
 
+@Since("2.1.0")
+private[spark] object ChiSqSelectorType extends Enumeration {
+  type SelectorType = Value
+  val KBest, Percentile, FPR = Value
+}
+
 /**
  * Chi Squared selector model.
  *
- * @param selectedFeatures list of indices to select (filter). Must be ordered asc
+ * @param selectedFeatures list of indices to select (filter).
  */
 @Since("1.3.0")
 class ChiSqSelectorModel @Since("1.3.0") (
   @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
 
-  require(isSorted(selectedFeatures), "Array has to be sorted asc")
-
-  protected def isSorted(array: Array[Int]): Boolean = {
-    var i = 1
-    val len = array.length
-    while (i < len) {
-      if (array(i) < array(i-1)) return false
-      i += 1
-    }
-    true
-  }
-
   /**
    * Applies transformation on a vector.
    *
@@ -69,21 +63,22 @@ class ChiSqSelectorModel @Since("1.3.0") (
    * Preserves the order of filtered features the same as their indices are stored.
    * Might be moved to Vector as .slice
    * @param features vector
-   * @param filterIndices indices of features to filter, must be ordered asc
+   * @param filterIndices indices of features to filter
    */
   private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
+    val orderedIndices = filterIndices.sorted
     features match {
       case SparseVector(size, indices, values) =>
-        val newSize = filterIndices.length
+        val newSize = orderedIndices.length
         val newValues = new ArrayBuilder.ofDouble
         val newIndices = new ArrayBuilder.ofInt
         var i = 0
         var j = 0
         var indicesIdx = 0
         var filterIndicesIdx = 0
-        while (i < indices.length && j < filterIndices.length) {
+        while (i < indices.length && j < orderedIndices.length) {
           indicesIdx = indices(i)
-          filterIndicesIdx = filterIndices(j)
+          filterIndicesIdx = orderedIndices(j)
           if (indicesIdx == filterIndicesIdx) {
             newIndices += j
             newValues += values(i)
@@ -101,7 +96,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
         Vectors.sparse(newSize, newIndices.result(), newValues.result())
       case DenseVector(values) =>
         val values = features.toArray
-        Vectors.dense(filterIndices.map(i => values(i)))
+        Vectors.dense(orderedIndices.map(i => values(i)))
       case other =>
         throw new UnsupportedOperationException(
           s"Only sparse and dense vectors are supported but got ${other.getClass}.")
@@ -171,14 +166,57 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
 
 /**
  * Creates a ChiSquared feature selector.
- * @param numTopFeatures number of features that selector will select
- *                       (ordered by statistic value descending)
- *                       Note that if the number of features is less than numTopFeatures,
- *                       then this will select all features.
+ * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
+ * `KBest` chooses the `k` top features according to a chi-squared test.
+ * `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ * `FPR` chooses all features whose p-value is below the threshold `alpha`.
+ * By default, the selection method is `KBest` and the default number of top features is 50.
+ * Users can call setNumTopFeatures, setPercentile or setAlpha to choose the selection method.
  */
 @Since("1.3.0")
-class ChiSqSelector @Since("1.3.0") (
-  @Since("1.3.0") val numTopFeatures: Int) extends Serializable {
+class ChiSqSelector @Since("2.1.0") () extends Serializable {
+  var numTopFeatures: Int = 50
+  var percentile: Double = 0.1
+  var alpha: Double = 0.05
+  var selectorType = ChiSqSelectorType.KBest
+
+  /**
+   * This is the same as calling this() and then setNumTopFeatures(numTopFeatures).
+   */
+  @Since("1.3.0")
+  def this(numTopFeatures: Int) {
+    this()
+    this.numTopFeatures = numTopFeatures
+  }
+
+  @Since("1.6.0")
+  def setNumTopFeatures(value: Int): this.type = {
+    numTopFeatures = value
+    selectorType = ChiSqSelectorType.KBest
+    this
+  }
+
+  @Since("2.1.0")
+  def setPercentile(value: Double): this.type = {
+    require(0.0 <= value && value <= 1.0, "Percentile must be in [0,1]")
+    percentile = value
+    selectorType = ChiSqSelectorType.Percentile
+    this
+  }
+
+  @Since("2.1.0")
+  def setAlpha(value: Double): this.type = {
+    require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]")
+    alpha = value
+    selectorType = ChiSqSelectorType.FPR
+    this
+  }
+
+  @Since("2.1.0")
+  def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = {
+    selectorType = value
+    this
+  }
 
   /**
    * Returns a ChiSquared feature selector.
@@ -189,11 +227,20 @@ class ChiSqSelector @Since("1.3.0") (
    */
   @Since("1.3.0")
   def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
-    val indices = Statistics.chiSqTest(data)
+    val chiSqTestResult = Statistics.chiSqTest(data)
       .zipWithIndex.sortBy { case (res, _) => -res.statistic }
-      .take(numTopFeatures)
-      .map { case (_, indices) => indices }
-      .sorted
+    val features = selectorType match {
+      case ChiSqSelectorType.KBest => chiSqTestResult
+        .take(numTopFeatures)
+      case ChiSqSelectorType.Percentile => chiSqTestResult
+        .take((chiSqTestResult.length * percentile).toInt)
+      case ChiSqSelectorType.FPR => chiSqTestResult
+        .filter{ case (res, _) => res.pValue < alpha }
+      case errorType =>
+        throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
+    }
+    val indices = features.map { case (_, indices) => indices }
     new ChiSqSelectorModel(indices)
   }
 }
+

http://git-wip-us.apache.org/repos/asf/spark/blob/b366f184/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 3558290..e0293db 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -49,16 +49,23 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
       .map(x => (x._1.label, x._1.features, x._2))
       .toDF("label", "data", "preFilteredData")
 
-    val model = new ChiSqSelector()
+    val selector = new ChiSqSelector()
       .setNumTopFeatures(1)
       .setFeaturesCol("data")
       .setLabelCol("label")
       .setOutputCol("filtered")
 
-    model.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach {
+    selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach {
       case Row(vec1: Vector, vec2: Vector) =>
         assert(vec1 ~== vec2 absTol 1e-1)
     }
+
+    selector.setPercentile(0.34).fit(df).transform(df)
+    .select("filtered", "preFilteredData").collect().foreach {
+      case Row(vec1: Vector, vec2: Vector) =>
+        assert(vec1 ~== vec2 absTol 1e-1)
+    }
+
   }
 
   test("ChiSqSelector read/write") {

http://git-wip-us.apache.org/repos/asf/spark/blob/b366f184/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index 734800a..e181a54 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -65,6 +65,24 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(filteredData == preFilteredData)
   }
 
+  test("ChiSqSelector by FPR transform test (sparse & dense vector)") {
+    val labeledDiscreteData = sc.parallelize(
+      Seq(LabeledPoint(0.0, Vectors.sparse(4, Array((0, 8.0), (1, 7.0)))),
+        LabeledPoint(1.0, Vectors.sparse(4, Array((1, 9.0), (2, 6.0), (3, 4.0)))),
+        LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 4.0))),
+        LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0, 9.0)))), 2)
+    val preFilteredData =
+      Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
+        LabeledPoint(1.0, Vectors.dense(Array(4.0))),
+        LabeledPoint(1.0, Vectors.dense(Array(4.0))),
+        LabeledPoint(2.0, Vectors.dense(Array(9.0))))
+    val model = new ChiSqSelector().setAlpha(0.1).fit(labeledDiscreteData)
+    val filteredData = labeledDiscreteData.map { lp =>
+      LabeledPoint(lp.label, model.transform(lp.features))
+    }.collect().toSet
+    assert(filteredData == preFilteredData)
+  }
+
   test("model load / save") {
     val model = ChiSqSelectorSuite.createModel()
     val tempDir = Utils.createTempDir()

http://git-wip-us.apache.org/repos/asf/spark/blob/b366f184/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index d4cbf51..f13f3ff 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -815,6 +815,9 @@ object MimaExcludes {
     ) ++ Seq(
       // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature.
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
+    ) ++ Seq(
+      // [SPARK-17017] Add chiSquare selector based on False Positive Rate (FPR) test
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted")
     )
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b366f184/python/pyspark/mllib/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 5d99644..077c113 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -271,11 +271,22 @@ class ChiSqSelectorModel(JavaVectorTransformer):
         return JavaVectorTransformer.transform(self, vector)
 
 
+class ChiSqSelectorType:
+    """
+    This class defines the selector types of Chi Square Selector.
+    """
+    KBest, Percentile, FPR = range(3)
+
+
 class ChiSqSelector(object):
     """
     Creates a ChiSquared feature selector.
-
-    :param numTopFeatures: number of features that selector will select.
+    The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
+    `KBest` chooses the `k` top features according to a chi-squared test.
+    `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
+    `FPR` chooses all features whose p-value is below the threshold `alpha`.
+    By default, the selection method is `KBest` and the default number of top features is 50.
+    Users can call setNumTopFeatures, setPercentile or setAlpha to choose the selection method.
 
     >>> data = [
     ...     LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
@@ -283,16 +294,58 @@ class ChiSqSelector(object):
     ...     LabeledPoint(1.0, [0.0, 9.0, 8.0]),
     ...     LabeledPoint(2.0, [8.0, 9.0, 5.0])
     ... ]
-    >>> model = ChiSqSelector(1).fit(sc.parallelize(data))
+    >>> model = ChiSqSelector().setNumTopFeatures(1).fit(sc.parallelize(data))
+    >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
+    SparseVector(1, {0: 6.0})
+    >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
+    DenseVector([5.0])
+    >>> model = ChiSqSelector().setPercentile(0.34).fit(sc.parallelize(data))
     >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
     SparseVector(1, {0: 6.0})
     >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
     DenseVector([5.0])
+    >>> data = [
+    ...     LabeledPoint(0.0, SparseVector(4, {0: 8.0, 1: 7.0})),
+    ...     LabeledPoint(1.0, SparseVector(4, {1: 9.0, 2: 6.0, 3: 4.0})),
+    ...     LabeledPoint(1.0, [0.0, 9.0, 8.0, 4.0]),
+    ...     LabeledPoint(2.0, [8.0, 9.0, 5.0, 9.0])
+    ... ]
+    >>> model = ChiSqSelector().setAlpha(0.1).fit(sc.parallelize(data))
+    >>> model.transform(DenseVector([1.0,2.0,3.0,4.0]))
+    DenseVector([4.0])
 
     .. versionadded:: 1.4.0
     """
-    def __init__(self, numTopFeatures):
+    def __init__(self, numTopFeatures=50):
+        self.numTopFeatures = numTopFeatures
+        self.selectorType = ChiSqSelectorType.KBest
+
+    @since('2.1.0')
+    def setNumTopFeatures(self, numTopFeatures):
+        """
+        Set numTopFeatures for feature selection by number of top features.
+        """
         self.numTopFeatures = int(numTopFeatures)
+        self.selectorType = ChiSqSelectorType.KBest
+        return self
+
+    @since('2.1.0')
+    def setPercentile(self, percentile):
+        """
+        Set percentile, in the range [0.0, 1.0], for feature selection by percentile.
+        """
+        self.percentile = float(percentile)
+        self.selectorType = ChiSqSelectorType.Percentile
+        return self
+
+    @since('2.1.0')
+    def setAlpha(self, alpha):
+        """
+        Set alpha, in the range [0.0, 1.0], for feature selection by FPR.
+        """
+        self.alpha = float(alpha)
+        self.selectorType = ChiSqSelectorType.FPR
+        return self
 
     @since('1.4.0')
     def fit(self, data):
@@ -304,7 +357,15 @@ class ChiSqSelector(object):
                      treated as categorical for each distinct value.
                      Apply feature discretizer before using this function.
         """
-        jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data)
+        if self.selectorType == ChiSqSelectorType.KBest:
+            jmodel = callMLlibFunc("fitChiSqSelectorKBest", self.numTopFeatures, data)
+        elif self.selectorType == ChiSqSelectorType.Percentile:
+            jmodel = callMLlibFunc("fitChiSqSelectorPercentile", self.percentile, data)
+        elif self.selectorType == ChiSqSelectorType.FPR:
+            jmodel = callMLlibFunc("fitChiSqSelectorFPR", self.alpha, data)
+        else:
+            raise ValueError("ChiSqSelector type supports KBest(0), Percentile(1) and"
+                             " FPR(2), the current value is: %s" % self.selectorType)
         return ChiSqSelectorModel(jmodel)
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org