Posted to commits@spark.apache.org by sr...@apache.org on 2016/09/26 08:45:37 UTC

spark git commit: [SPARK-17017][FOLLOW-UP][ML] Refactor of ChiSqSelector and add ML Python API.

Repository: spark
Updated Branches:
  refs/heads/master 59d87d240 -> ac65139be


[SPARK-17017][FOLLOW-UP][ML] Refactor of ChiSqSelector and add ML Python API.

## What changes were proposed in this pull request?
#14597 modified ```ChiSqSelector``` to support the ```fpr``` selector type; however, it left some issues that need to be addressed:
* We should allow users to set the selector type explicitly rather than switching it implicitly through different setter functions, because depending on the order of setter calls leads to unexpected behavior. For example, if users set both ```numTopFeatures``` and ```percentile```, either a ```kbest``` or a ```percentile``` model is trained depending on the order of the calls (whichever is set last wins). This confuses users, so the selector type should be set explicitly. We handle similar issues elsewhere in the ML code base, such as in ```GeneralizedLinearRegression``` and ```LogisticRegression```.
* Meanwhile, if the ```fpr``` model needs more than one parameter besides ```alpha```, the existing framework cannot handle it elegantly, and the same is true for the ```kbest``` and ```percentile``` models. Setting the selector type explicitly solves this issue as well.
* If users may set the selector type explicitly, we should handle parameter interactions: for example, if users set ```selectorType = percentile``` and ```alpha = 0.1```, we should notify them that ```alpha``` will take no effect. Such complex parameter interaction checks belong in ```transformSchema```. (FYI #11620)
* We should use lower-case selector type names to follow the MLlib convention.
* Add ML Python API.
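The ordering problem described in the first bullet can be seen in a small pure-Python model of the two designs. This is a hypothetical sketch, not Spark code: the classes `ImplicitSelector` and `ExplicitSelector` and their snake_case setters are illustrative stand-ins for the old and new `ChiSqSelector` setter behavior.

```python
# Hypothetical pure-Python model of the two setter designs; not Spark code.


class ImplicitSelector:
    """Old behavior: each setter silently switches the selector type."""

    def __init__(self):
        self.selector_type = "kbest"
        self.num_top_features = 50
        self.percentile = 0.1

    def set_num_top_features(self, value):
        self.num_top_features = value
        self.selector_type = "kbest"  # side effect: the last setter called wins
        return self

    def set_percentile(self, value):
        self.percentile = value
        self.selector_type = "percentile"  # side effect: the last setter called wins
        return self


class ExplicitSelector:
    """New behavior: setters only store values; the type is chosen explicitly."""

    def __init__(self):
        self.selector_type = "kbest"
        self.num_top_features = 50
        self.percentile = 0.1

    def set_selector_type(self, value):
        self.selector_type = value
        return self

    def set_num_top_features(self, value):
        self.num_top_features = value
        return self

    def set_percentile(self, value):
        self.percentile = value
        return self


# Same values, different call order: the old design ends up with different types.
a = ImplicitSelector().set_num_top_features(1).set_percentile(0.34)
b = ImplicitSelector().set_percentile(0.34).set_num_top_features(1)
print(a.selector_type, b.selector_type)  # percentile kbest

# The new design is order-independent; the type changes only when asked to.
c = ExplicitSelector().set_num_top_features(1).set_percentile(0.34)
d = ExplicitSelector().set_percentile(0.34).set_num_top_features(1)
print(c.selector_type, d.selector_type)  # kbest kbest
```

In the explicit design, switching to percentile selection requires a deliberate `set_selector_type("percentile")` call, which mirrors the `setSelectorType` setter this patch adds.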

## How was this patch tested?
Unit tests.

Author: Yanbo Liang <yb...@gmail.com>

Closes #15214 from yanboliang/spark-17017.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ac65139b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ac65139b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ac65139b

Branch: refs/heads/master
Commit: ac65139be96dbf87402b9a85729a93afd3c6ff17
Parents: 59d87d2
Author: Yanbo Liang <yb...@gmail.com>
Authored: Mon Sep 26 09:45:33 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Mon Sep 26 09:45:33 2016 +0100

----------------------------------------------------------------------
 .../apache/spark/ml/feature/ChiSqSelector.scala | 86 +++++++++++---------
 .../spark/mllib/api/python/PythonMLLibAPI.scala | 38 +++------
 .../spark/mllib/feature/ChiSqSelector.scala     | 51 +++++++-----
 .../spark/ml/feature/ChiSqSelectorSuite.scala   | 27 ++++--
 .../mllib/feature/ChiSqSelectorSuite.scala      |  2 +-
 python/pyspark/ml/feature.py                    | 71 ++++++++++++++--
 python/pyspark/mllib/feature.py                 | 59 +++++++-------
 7 files changed, 206 insertions(+), 128 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ac65139b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 0c6a37b..9c131a4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.feature.ChiSqSelectorType
+import org.apache.spark.mllib.feature.{ChiSqSelector => OldChiSqSelector}
 import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.rdd.RDD
@@ -44,7 +44,9 @@ private[feature] trait ChiSqSelectorParams extends Params
   /**
    * Number of features that selector will select (ordered by statistic value descending). If the
    * number of features is less than numTopFeatures, then this will select all features.
+   * Only applicable when selectorType = "kbest".
    * The default value of numTopFeatures is 50.
+   *
    * @group param
    */
   final val numTopFeatures = new IntParam(this, "numTopFeatures",
@@ -56,6 +58,11 @@ private[feature] trait ChiSqSelectorParams extends Params
   /** @group getParam */
   def getNumTopFeatures: Int = $(numTopFeatures)
 
+  /**
+   * Percentile of features that selector will select, ordered by statistics value descending.
+   * Only applicable when selectorType = "percentile".
+   * Default value is 0.1.
+   */
   final val percentile = new DoubleParam(this, "percentile",
     "Percentile of features that selector will select, ordered by statistics value descending.",
     ParamValidators.inRange(0, 1))
@@ -64,8 +71,12 @@ private[feature] trait ChiSqSelectorParams extends Params
   /** @group getParam */
   def getPercentile: Double = $(percentile)
 
-  final val alpha = new DoubleParam(this, "alpha",
-    "The highest p-value for features to be kept.",
+  /**
+   * The highest p-value for features to be kept.
+   * Only applicable when selectorType = "fpr".
+   * Default value is 0.05.
+   */
+  final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.",
     ParamValidators.inRange(0, 1))
   setDefault(alpha -> 0.05)
 
@@ -73,29 +84,27 @@ private[feature] trait ChiSqSelectorParams extends Params
   def getAlpha: Double = $(alpha)
 
   /**
-   * The ChiSqSelector supports KBest, Percentile, FPR selection,
-   * which is the same as ChiSqSelectorType defined in MLLIB.
-   * when call setNumTopFeatures, the selectorType is set to KBest
-   * when call setPercentile, the selectorType is set to Percentile
-   * when call setAlpha, the selectorType is set to FPR
+   * The selector type of the ChisqSelector.
+   * Supported options: "kbest" (default), "percentile" and "fpr".
    */
   final val selectorType = new Param[String](this, "selectorType",
-    "ChiSqSelector Type: KBest, Percentile, FPR")
-  setDefault(selectorType -> ChiSqSelectorType.KBest.toString)
+    "The selector type of the ChisqSelector. " +
+      "Supported options: kbest (default), percentile and fpr.",
+    ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray))
+  setDefault(selectorType -> OldChiSqSelector.KBest)
 
   /** @group getParam */
-  def getChiSqSelectorType: String = $(selectorType)
+  def getSelectorType: String = $(selectorType)
 }
 
 /**
  * Chi-Squared feature selection, which selects categorical features to use for predicting a
  * categorical label.
- * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
- * `KBest` chooses the `k` top features according to a chi-squared test.
- * `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
- * `FPR` chooses all features whose false positive rate meets some threshold.
- * By default, the selection method is `KBest`, the default number of top features is 50.
- * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods.
+ * The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
+ * `kbest` chooses the `k` top features according to a chi-squared test.
+ * `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ * `fpr` chooses all features whose false positive rate meets some threshold.
+ * By default, the selection method is `kbest`, the default number of top features is 50.
  */
 @Since("1.6.0")
 final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String)
@@ -105,23 +114,20 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
   def this() = this(Identifiable.randomUID("chiSqSelector"))
 
   /** @group setParam */
+  @Since("2.1.0")
+  def setSelectorType(value: String): this.type = set(selectorType, value)
+
+  /** @group setParam */
   @Since("1.6.0")
-  def setNumTopFeatures(value: Int): this.type = {
-    set(selectorType, ChiSqSelectorType.KBest.toString)
-    set(numTopFeatures, value)
-  }
+  def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
 
+  /** @group setParam */
   @Since("2.1.0")
-  def setPercentile(value: Double): this.type = {
-    set(selectorType, ChiSqSelectorType.Percentile.toString)
-    set(percentile, value)
-  }
+  def setPercentile(value: Double): this.type = set(percentile, value)
 
+  /** @group setParam */
   @Since("2.1.0")
-  def setAlpha(value: Double): this.type = {
-    set(selectorType, ChiSqSelectorType.FPR.toString)
-    set(alpha, value)
-  }
+  def setAlpha(value: Double): this.type = set(alpha, value)
 
   /** @group setParam */
   @Since("1.6.0")
@@ -143,23 +149,23 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
         case Row(label: Double, features: Vector) =>
           OldLabeledPoint(label, OldVectors.fromML(features))
       }
-    var selector = new feature.ChiSqSelector()
-    ChiSqSelectorType.withName($(selectorType)) match {
-      case ChiSqSelectorType.KBest =>
-        selector.setNumTopFeatures($(numTopFeatures))
-      case ChiSqSelectorType.Percentile =>
-        selector.setPercentile($(percentile))
-      case ChiSqSelectorType.FPR =>
-        selector.setAlpha($(alpha))
-      case errorType =>
-        throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
-    }
+    val selector = new feature.ChiSqSelector()
+      .setSelectorType($(selectorType))
+      .setNumTopFeatures($(numTopFeatures))
+      .setPercentile($(percentile))
+      .setAlpha($(alpha))
     val model = selector.fit(input)
     copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
   }
 
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
+    val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType))
+    otherPairs.foreach { case (_, paramName: String) =>
+      if (isSet(getParam(paramName))) {
+        logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.")
+      }
+    }
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.checkNumericType(schema, $(labelCol))
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
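The new `transformSchema` check only logs a warning; it does not reject the configuration. Below is a minimal pure-Python sketch of that check, assuming a plain set of explicitly set param names stands in for Spark's `isSet`. The function `ineffective_params` is hypothetical, and it iterates the pairs in sorted order so the output is deterministic (the Scala code iterates a `Set`, whose order is unspecified).

```python
import logging

# Mirrors the (selector type, param name) pairs in
# OldChiSqSelector.supportedTypeAndParamPairs from the diff.
SUPPORTED_TYPE_AND_PARAM_PAIRS = {
    ("kbest", "numTopFeatures"),
    ("percentile", "percentile"),
    ("fpr", "alpha"),
}

logger = logging.getLogger("chisq_sketch")


def ineffective_params(selector_type, explicitly_set):
    """Return the explicitly set params that the chosen selector type ignores.

    `explicitly_set` plays the role of Spark's isSet(): it holds the names the
    user actually set, not names that merely have default values.
    """
    ignored = []
    for type_name, param_name in sorted(SUPPORTED_TYPE_AND_PARAM_PAIRS):
        if type_name != selector_type and param_name in explicitly_set:
            # Warn only; the model is still trained with the chosen type.
            logger.warning("Param %s will take no effect when selector type = %s.",
                           param_name, selector_type)
            ignored.append(param_name)
    return ignored


print(ineffective_params("percentile", {"percentile", "alpha"}))  # ['alpha']
print(ineffective_params("kbest", set()))  # []
```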

http://git-wip-us.apache.org/repos/asf/spark/blob/ac65139b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 5cffbf0..904000f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -629,35 +629,23 @@ private[python] class PythonMLLibAPI extends Serializable {
   }
 
   /**
-   * Java stub for ChiSqSelector.fit() when the seletion type is KBest. This stub returns a
+   * Java stub for ChiSqSelector.fit(). This stub returns a
    * handle to the Java object instead of the content of the Java object.
    * Extra care needs to be taken in the Python code to ensure it gets freed on
    * exit; see the Py4J documentation.
    */
-  def fitChiSqSelectorKBest(numTopFeatures: Int,
-    data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
-    new ChiSqSelector().setNumTopFeatures(numTopFeatures).fit(data.rdd)
-  }
-
-  /**
-   * Java stub for ChiSqSelector.fit() when the selection type is Percentile. This stub returns a
-   * handle to the Java object instead of the content of the Java object.
-   * Extra care needs to be taken in the Python code to ensure it gets freed on
-   * exit; see the Py4J documentation.
-   */
-  def fitChiSqSelectorPercentile(percentile: Double,
-    data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
-    new ChiSqSelector().setPercentile(percentile).fit(data.rdd)
-  }
-
-  /**
-   * Java stub for ChiSqSelector.fit() when the selection type is FPR. This stub returns a
-   * handle to the Java object instead of the content of the Java object.
-   * Extra care needs to be taken in the Python code to ensure it gets freed on
-   * exit; see the Py4J documentation.
-   */
-  def fitChiSqSelectorFPR(alpha: Double, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
-    new ChiSqSelector().setAlpha(alpha).fit(data.rdd)
+  def fitChiSqSelector(
+      selectorType: String,
+      numTopFeatures: Int,
+      percentile: Double,
+      alpha: Double,
+      data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
+    new ChiSqSelector()
+      .setSelectorType(selectorType)
+      .setNumTopFeatures(numTopFeatures)
+      .setPercentile(percentile)
+      .setAlpha(alpha)
+      .fit(data.rdd)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/ac65139b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index f68a017..0f7c6e8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -32,12 +32,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SparkSession}
 
-@Since("2.1.0")
-private[spark] object ChiSqSelectorType extends Enumeration {
-  type SelectorType = Value
-  val KBest, Percentile, FPR = Value
-}
-
 /**
  * Chi Squared selector model.
  *
@@ -166,19 +160,18 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
 
 /**
  * Creates a ChiSquared feature selector.
- * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
- * `KBest` chooses the `k` top features according to a chi-squared test.
- * `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
- * `FPR` chooses all features whose false positive rate meets some threshold.
- * By default, the selection method is `KBest`, the default number of top features is 50.
- * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods.
+ * The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
+ * `kbest` chooses the `k` top features according to a chi-squared test.
+ * `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ * `fpr` chooses all features whose false positive rate meets some threshold.
+ * By default, the selection method is `kbest`, the default number of top features is 50.
  */
 @Since("1.3.0")
 class ChiSqSelector @Since("2.1.0") () extends Serializable {
   var numTopFeatures: Int = 50
   var percentile: Double = 0.1
   var alpha: Double = 0.05
-  var selectorType = ChiSqSelectorType.KBest
+  var selectorType = ChiSqSelector.KBest
 
   /**
    * The is the same to call this() and setNumTopFeatures(numTopFeatures)
@@ -192,7 +185,6 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
   @Since("1.6.0")
   def setNumTopFeatures(value: Int): this.type = {
     numTopFeatures = value
-    selectorType = ChiSqSelectorType.KBest
     this
   }
 
@@ -200,7 +192,6 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
   def setPercentile(value: Double): this.type = {
     require(0.0 <= value && value <= 1.0, "Percentile must be in [0,1]")
     percentile = value
-    selectorType = ChiSqSelectorType.Percentile
     this
   }
 
@@ -208,12 +199,13 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
   def setAlpha(value: Double): this.type = {
     require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]")
     alpha = value
-    selectorType = ChiSqSelectorType.FPR
     this
   }
 
   @Since("2.1.0")
-  def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = {
+  def setSelectorType(value: String): this.type = {
+    require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value),
+      s"ChiSqSelector Type: $value was not supported.")
     selectorType = value
     this
   }
@@ -230,11 +222,11 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
     val chiSqTestResult = Statistics.chiSqTest(data)
       .zipWithIndex.sortBy { case (res, _) => -res.statistic }
     val features = selectorType match {
-      case ChiSqSelectorType.KBest => chiSqTestResult
+      case ChiSqSelector.KBest => chiSqTestResult
         .take(numTopFeatures)
-      case ChiSqSelectorType.Percentile => chiSqTestResult
+      case ChiSqSelector.Percentile => chiSqTestResult
         .take((chiSqTestResult.length * percentile).toInt)
-      case ChiSqSelectorType.FPR => chiSqTestResult
+      case ChiSqSelector.FPR => chiSqTestResult
         .filter{ case (res, _) => res.pValue < alpha }
       case errorType =>
         throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
@@ -244,3 +236,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
   }
 }
 
+@Since("2.1.0")
+object ChiSqSelector {
+
+  /** String name for `kbest` selector type. */
+  private[spark] val KBest: String = "kbest"
+
+  /** String name for `percentile` selector type. */
+  private[spark] val Percentile: String = "percentile"
+
+  /** String name for `fpr` selector type. */
+  private[spark] val FPR: String = "fpr"
+
+  /** Set of selector type and param pairs that ChiSqSelector supports. */
+  private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures",
+    Percentile -> "percentile", FPR -> "alpha")
+
+  /** Set of selector types that ChiSqSelector supports. */
+  private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1)
+}
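The selection step in `ChiSqSelector.fit` above ranks features by chi-squared statistic and then applies one rule per selector type, after `setSelectorType` has rejected unsupported names. A hedged pure-Python sketch of those rules follows; `ChiSqResult`, `select_features` and the sample statistics and p-values are hypothetical (they are not produced by a real chi-squared test), and the sketch returns indices in ascending order by its own choice.

```python
from collections import namedtuple

# Hypothetical stand-in for one chi-squared test result: only the two fields
# the selector reads.
ChiSqResult = namedtuple("ChiSqResult", ["statistic", "p_value"])

SUPPORTED_SELECTOR_TYPES = {"kbest", "percentile", "fpr"}


def select_features(results, selector_type="kbest", num_top_features=50,
                    percentile=0.1, alpha=0.05):
    """Return the indices of the selected features, in ascending order."""
    if selector_type not in SUPPORTED_SELECTOR_TYPES:
        raise ValueError("ChiSqSelector Type: %s was not supported." % selector_type)
    # Pair each result with its feature index, then sort by statistic, descending.
    ranked = sorted(enumerate(results), key=lambda pair: -pair[1].statistic)
    if selector_type == "kbest":
        chosen = ranked[:num_top_features]
    elif selector_type == "percentile":
        # int() truncates toward zero, like Scala's .toInt on a positive Double.
        chosen = ranked[:int(len(ranked) * percentile)]
    else:  # "fpr": keep every feature whose p-value is below alpha
        chosen = [pair for pair in ranked if pair[1].p_value < alpha]
    return sorted(index for index, _ in chosen)


results = [ChiSqResult(0.5, 0.78), ChiSqResult(4.0, 0.14), ChiSqResult(6.0, 0.05)]
print(select_features(results, "kbest", num_top_features=1))  # [2]
print(select_features(results, "percentile", percentile=0.34))  # [2]
print(select_features(results, "fpr", alpha=0.2))  # [1, 2]
```

With three features, `percentile = 0.34` keeps `int(3 * 0.34) = 1` feature, so it agrees with `kbest` at k = 1, while `fpr` ignores the ranking cut-off and keeps every feature that passes the p-value threshold.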

http://git-wip-us.apache.org/repos/asf/spark/blob/ac65139b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index e0293db..6b56e42 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -50,6 +50,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
       .toDF("label", "data", "preFilteredData")
 
     val selector = new ChiSqSelector()
+      .setSelectorType("kbest")
       .setNumTopFeatures(1)
       .setFeaturesCol("data")
       .setLabelCol("label")
@@ -60,12 +61,28 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
         assert(vec1 ~== vec2 absTol 1e-1)
     }
 
-    selector.setPercentile(0.34).fit(df).transform(df)
-    .select("filtered", "preFilteredData").collect().foreach {
-      case Row(vec1: Vector, vec2: Vector) =>
-        assert(vec1 ~== vec2 absTol 1e-1)
-    }
+    selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df)
+      .select("filtered", "preFilteredData").collect().foreach {
+        case Row(vec1: Vector, vec2: Vector) =>
+          assert(vec1 ~== vec2 absTol 1e-1)
+      }
+
+    val preFilteredData2 = Seq(
+      Vectors.dense(8.0, 7.0),
+      Vectors.dense(0.0, 9.0),
+      Vectors.dense(0.0, 9.0),
+      Vectors.dense(8.0, 9.0)
+    )
 
+    val df2 = sc.parallelize(data.zip(preFilteredData2))
+      .map(x => (x._1.label, x._1.features, x._2))
+      .toDF("label", "data", "preFilteredData")
+
+    selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2)
+      .select("filtered", "preFilteredData").collect().foreach {
+        case Row(vec1: Vector, vec2: Vector) =>
+          assert(vec1 ~== vec2 absTol 1e-1)
+      }
   }
 
   test("ChiSqSelector read/write") {

http://git-wip-us.apache.org/repos/asf/spark/blob/ac65139b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index e181a54..ec23a4a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -76,7 +76,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
         LabeledPoint(1.0, Vectors.dense(Array(4.0))),
         LabeledPoint(1.0, Vectors.dense(Array(4.0))),
         LabeledPoint(2.0, Vectors.dense(Array(9.0))))
-    val model = new ChiSqSelector().setAlpha(0.1).fit(labeledDiscreteData)
+    val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData)
     val filteredData = labeledDiscreteData.map { lp =>
       LabeledPoint(lp.label, model.transform(lp.features))
     }.collect().toSet

http://git-wip-us.apache.org/repos/asf/spark/blob/ac65139b/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index c45434f..12a1384 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2586,39 +2586,68 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
     .. versionadded:: 2.0.0
     """
 
+    selectorType = Param(Params._dummy(), "selectorType",
+                         "The selector type of the ChisqSelector. " +
+                         "Supported options: kbest (default), percentile and fpr.",
+                         typeConverter=TypeConverters.toString)
+
     numTopFeatures = \
         Param(Params._dummy(), "numTopFeatures",
               "Number of features that selector will select, ordered by statistics value " +
               "descending. If the number of features is < numTopFeatures, then this will select " +
               "all features.", typeConverter=TypeConverters.toInt)
 
+    percentile = Param(Params._dummy(), "percentile", "Percentile of features that selector " +
+                       "will select, ordered by statistics value descending.",
+                       typeConverter=TypeConverters.toFloat)
+
+    alpha = Param(Params._dummy(), "alpha", "The highest p-value for features to be kept.",
+                  typeConverter=TypeConverters.toFloat)
+
     @keyword_only
-    def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"):
+    def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None,
+                 labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05):
         """
-        __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label")
+        __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \
+                 labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05)
         """
         super(ChiSqSelector, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
-        self._setDefault(numTopFeatures=50)
+        self._setDefault(numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("2.0.0")
     def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,
-                  labelCol="labels"):
+                  labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05):
         """
-        setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\
-                  labelCol="labels")
+        setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, \
+                  labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05)
         Sets params for this ChiSqSelector.
         """
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
 
+    @since("2.1.0")
+    def setSelectorType(self, value):
+        """
+        Sets the value of :py:attr:`selectorType`.
+        """
+        return self._set(selectorType=value)
+
+    @since("2.1.0")
+    def getSelectorType(self):
+        """
+        Gets the value of selectorType or its default value.
+        """
+        return self.getOrDefault(self.selectorType)
+
     @since("2.0.0")
     def setNumTopFeatures(self, value):
         """
         Sets the value of :py:attr:`numTopFeatures`.
+        Only applicable when selectorType = "kbest".
         """
         return self._set(numTopFeatures=value)
 
@@ -2629,6 +2658,36 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
         """
         return self.getOrDefault(self.numTopFeatures)
 
+    @since("2.1.0")
+    def setPercentile(self, value):
+        """
+        Sets the value of :py:attr:`percentile`.
+        Only applicable when selectorType = "percentile".
+        """
+        return self._set(percentile=value)
+
+    @since("2.1.0")
+    def getPercentile(self):
+        """
+        Gets the value of percentile or its default value.
+        """
+        return self.getOrDefault(self.percentile)
+
+    @since("2.1.0")
+    def setAlpha(self, value):
+        """
+        Sets the value of :py:attr:`alpha`.
+        Only applicable when selectorType = "fpr".
+        """
+        return self._set(alpha=value)
+
+    @since("2.1.0")
+    def getAlpha(self):
+        """
+        Gets the value of alpha or its default value.
+        """
+        return self.getOrDefault(self.alpha)
+
     def _create_model(self, java_model):
         return ChiSqSelectorModel(java_model)
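The Python wrapper keeps two kinds of values apart: defaults registered with `_setDefault`, and values the caller passes explicitly, which `setParams` forwards to `_set`. `getOrDefault` returns the explicit value when there is one and otherwise the default. A minimal pure-Python sketch of that split, assuming nothing about PySpark internals beyond what the diff shows; the class `MiniChiSqSelectorParams` and its snake_case methods are hypothetical and far simpler than `pyspark.ml.param.Params`.

```python
class MiniChiSqSelectorParams:
    """Hypothetical, simplified model of the default/explicit param split."""

    # Same defaults as the _setDefault call in the diff.
    _DEFAULTS = {"numTopFeatures": 50, "selectorType": "kbest",
                 "percentile": 0.1, "alpha": 0.05}

    def __init__(self, **kwargs):
        # Only keyword arguments the caller actually passed count as "set".
        self._param_map = {}
        self.set_params(**kwargs)

    def set_params(self, **kwargs):
        unknown = set(kwargs) - set(self._DEFAULTS)
        if unknown:
            raise TypeError("unknown params: %s" % sorted(unknown))
        self._param_map.update(kwargs)
        return self

    def is_set(self, name):
        return name in self._param_map

    def get_or_default(self, name):
        # Explicit value first, then the registered default.
        return self._param_map.get(name, self._DEFAULTS[name])


p = MiniChiSqSelectorParams(selectorType="percentile", percentile=0.34)
print(p.get_or_default("selectorType"), p.get_or_default("alpha"))  # percentile 0.05
print(p.is_set("alpha"))  # False
```

In this sketch, `alpha` only has a default, so a check keyed on `is_set` (like the `transformSchema` warning) would not flag it.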
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ac65139b/python/pyspark/mllib/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 077c113..4aea818 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -271,22 +271,14 @@ class ChiSqSelectorModel(JavaVectorTransformer):
         return JavaVectorTransformer.transform(self, vector)
 
 
-class ChiSqSelectorType:
-    """
-    This class defines the selector types of Chi Square Selector.
-    """
-    KBest, Percentile, FPR = range(3)
-
-
 class ChiSqSelector(object):
     """
     Creates a ChiSquared feature selector.
     The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
-    `KBest` chooses the `k` top features according to a chi-squared test.
-    `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
-    `FPR` chooses all features whose false positive rate meets some threshold.
-    By default, the selection method is `KBest`, the default number of top features is 50.
-    User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods.
+    `kbest` chooses the `k` top features according to a chi-squared test.
+    `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+    `fpr` chooses all features whose false positive rate meets some threshold.
+    By default, the selection method is `kbest`, the default number of top features is 50.
 
     >>> data = [
     ...     LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
@@ -299,7 +291,8 @@ class ChiSqSelector(object):
     SparseVector(1, {0: 6.0})
     >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
     DenseVector([5.0])
-    >>> model = ChiSqSelector().setPercentile(0.34).fit(sc.parallelize(data))
+    >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit(
+    ...     sc.parallelize(data))
     >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
     SparseVector(1, {0: 6.0})
     >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
@@ -310,41 +303,52 @@ class ChiSqSelector(object):
     ...     LabeledPoint(1.0, [0.0, 9.0, 8.0, 4.0]),
     ...     LabeledPoint(2.0, [8.0, 9.0, 5.0, 9.0])
     ... ]
-    >>> model = ChiSqSelector().setAlpha(0.1).fit(sc.parallelize(data))
+    >>> model = ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(sc.parallelize(data))
     >>> model.transform(DenseVector([1.0,2.0,3.0,4.0]))
     DenseVector([4.0])
 
     .. versionadded:: 1.4.0
     """
-    def __init__(self, numTopFeatures=50):
+    def __init__(self, numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05):
         self.numTopFeatures = numTopFeatures
-        self.selectorType = ChiSqSelectorType.KBest
+        self.selectorType = selectorType
+        self.percentile = percentile
+        self.alpha = alpha
 
     @since('2.1.0')
     def setNumTopFeatures(self, numTopFeatures):
         """
-        set numTopFeature for feature selection by number of top features
+        set numTopFeature for feature selection by number of top features.
+        Only applicable when selectorType = "kbest".
         """
         self.numTopFeatures = int(numTopFeatures)
-        self.selectorType = ChiSqSelectorType.KBest
         return self
 
     @since('2.1.0')
     def setPercentile(self, percentile):
         """
-        set percentile [0.0, 1.0] for feature selection by percentile
+        set percentile [0.0, 1.0] for feature selection by percentile.
+        Only applicable when selectorType = "percentile".
         """
         self.percentile = float(percentile)
-        self.selectorType = ChiSqSelectorType.Percentile
         return self
 
     @since('2.1.0')
     def setAlpha(self, alpha):
         """
-        set alpha [0.0, 1.0] for feature selection by FPR
+        set alpha [0.0, 1.0] for feature selection by FPR.
+        Only applicable when selectorType = "fpr".
         """
         self.alpha = float(alpha)
-        self.selectorType = ChiSqSelectorType.FPR
+        return self
+
+    @since('2.1.0')
+    def setSelectorType(self, selectorType):
+        """
+        set the selector type of the ChisqSelector.
+        Supported options: "kbest" (default), "percentile" and "fpr".
+        """
+        self.selectorType = str(selectorType)
         return self
 
     @since('1.4.0')
@@ -357,15 +361,8 @@ class ChiSqSelector(object):
                      treated as categorical for each distinct value.
                      Apply feature discretizer before using this function.
         """
-        if self.selectorType == ChiSqSelectorType.KBest:
-            jmodel = callMLlibFunc("fitChiSqSelectorKBest", self.numTopFeatures, data)
-        elif self.selectorType == ChiSqSelectorType.Percentile:
-            jmodel = callMLlibFunc("fitChiSqSelectorPercentile", self.percentile, data)
-        elif self.selectorType == ChiSqSelectorType.FPR:
-            jmodel = callMLlibFunc("fitChiSqSelectorFPR", self.alpha, data)
-        else:
-            raise ValueError("ChiSqSelector type supports KBest(0), Percentile(1) and"
-                             " FPR(2), the current value is: %s" % self.selectorType)
+        jmodel = callMLlibFunc("fitChiSqSelector", self.selectorType, self.numTopFeatures,
+                               self.percentile, self.alpha, data)
         return ChiSqSelectorModel(jmodel)
 
 

