You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2021/01/18 04:20:46 UTC

[spark] branch branch-3.1 updated: [SPARK-34080][ML][PYTHON][FOLLOWUP] Add UnivariateFeatureSelector - make methods private

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 56f93e5  [SPARK-34080][ML][PYTHON][FOLLOWUP] Add UnivariateFeatureSelector - make methods private
56f93e5 is described below

commit 56f93e56ab731be27a05a299fcbe0ef529f280ba
Author: Ruifeng Zheng <ru...@foxmail.com>
AuthorDate: Mon Jan 18 13:19:59 2021 +0900

    [SPARK-34080][ML][PYTHON][FOLLOWUP] Add UnivariateFeatureSelector - make methods private
    
    ### What changes were proposed in this pull request?
    1. Make `getTopIndices` private and `selectIndicesFromPValues` `private[feature]`;
    2. Avoid setting `selectionThreshold` in `fit`;
    3. Move param checking to `transformSchema`.
    
    ### Why are the changes needed?
    `getTopIndices`/`selectIndicesFromPValues` should not be exposed to end users.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing test suites.
    
    Closes #31222 from zhengruifeng/selector_clean_up.
    
    Authored-by: Ruifeng Zheng <ru...@foxmail.com>
    Signed-off-by: HyukjinKwon <gu...@apache.org>
    (cherry picked from commit ac322a1ac3be79b5e514f0119275f53b3a40c923)
    Signed-off-by: HyukjinKwon <gu...@apache.org>
---
 .../ml/feature/UnivariateFeatureSelector.scala     | 74 ++++++++--------------
 1 file changed, 27 insertions(+), 47 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
index 6d5f09e..bfe1d5f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
@@ -76,8 +76,7 @@ private[feature] trait UnivariateFeatureSelectorParams extends Params
   @Since("3.1.1")
   final val selectionMode = new Param[String](this, "selectionMode",
     "The selection mode. Supported options: numTopFeatures, percentile, fpr, fdr, fwe",
-    ParamValidators.inArray(Array("numTopFeatures", "percentile", "fpr", "fdr",
-      "fwe")))
+    ParamValidators.inArray(Array("numTopFeatures", "percentile", "fpr", "fdr", "fwe")))
 
   /** @group getParam */
   @Since("3.1.1")
@@ -161,48 +160,17 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v
     transformSchema(dataset.schema, logging = true)
     val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
 
-    $(selectionMode) match {
-      case ("numTopFeatures") =>
-        if (!isSet(selectionThreshold)) {
-          set(selectionThreshold, 50.0)
-        } else {
-          require($(selectionThreshold) > 0 && $(selectionThreshold).toInt == $(selectionThreshold),
-            "selectionThreshold needs to be a positive Integer for selection mode numTopFeatures")
-        }
-      case ("percentile") =>
-        if (!isSet(selectionThreshold)) {
-          set(selectionThreshold, 0.1)
-        } else {
-          require($(selectionThreshold) >= 0 && $(selectionThreshold) <= 1,
-            "selectionThreshold needs to be in the range of 0 to 1 for selection mode percentile")
-        }
-      case ("fpr") =>
-        if (!isSet(selectionThreshold)) {
-          set(selectionThreshold, 0.05)
-        } else {
-          require($(selectionThreshold) >= 0 && $(selectionThreshold) <= 1,
-            "selectionThreshold needs to be in the range of 0 to 1 for selection mode fpr")
-        }
-      case ("fdr") =>
-        if (!isSet(selectionThreshold)) {
-          set(selectionThreshold, 0.05)
-        } else {
-          require($(selectionThreshold) >= 0 && $(selectionThreshold) <= 1,
-            "selectionThreshold needs to be in the range of 0 to 1 for selection mode fdr")
-        }
-      case ("fwe") =>
-        if (!isSet(selectionThreshold)) {
-          set(selectionThreshold, 0.05)
-        } else {
-          require($(selectionThreshold) >= 0 && $(selectionThreshold) <= 1,
-            "selectionThreshold needs to be in the range of 0 to 1 for selection mode fwe")
-        }
-      case _ =>
-        throw new IllegalArgumentException(s"Unsupported selection mode:" +
-          s" selectionMode=${$(selectionMode)}")
+    var threshold = Double.NaN
+    if (isSet(selectionThreshold)) {
+      threshold = $(selectionThreshold)
+    } else {
+      $(selectionMode) match {
+        case "numTopFeatures" => threshold = 50
+        case "percentile" => threshold = 0.1
+        case "fpr" | "fdr" | "fwe" => threshold = 0.05
+      }
     }
 
-    require(isSet(featureType) && isSet(labelType), "featureType and labelType need to be set")
     val resultDF = ($(featureType), $(labelType)) match {
       case ("categorical", "categorical") =>
         ChiSquareTest.test(dataset.toDF, getFeaturesCol, getLabelCol, true)
@@ -215,14 +183,12 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v
           s" featureType=${$(featureType)}, labelType=${$(labelType)}")
     }
 
-    val indices =
-      selectIndicesFromPValues(numFeatures, resultDF, $(selectionMode), $(selectionThreshold))
-
+    val indices = selectIndicesFromPValues(numFeatures, resultDF, $(selectionMode), threshold)
     copyValues(new UnivariateFeatureSelectorModel(uid, indices)
       .setParent(this))
   }
 
-  def getTopIndices(df: DataFrame, k: Int): Array[Int] = {
+  private def getTopIndices(df: DataFrame, k: Int): Array[Int] = {
     val spark = SparkSession.builder().getOrCreate()
     import spark.implicits._
     df.sort("pValue", "featureIndex")
@@ -232,7 +198,7 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v
       .collect()
   }
 
-  def selectIndicesFromPValues(
+  private[feature] def selectIndicesFromPValues(
       numFeatures: Int,
       resultDF: DataFrame,
       selectionMode: String,
@@ -276,6 +242,20 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v
 
   @Since("3.1.1")
   override def transformSchema(schema: StructType): StructType = {
+    if (isSet(selectionThreshold)) {
+      val threshold = $(selectionThreshold)
+      $(selectionMode) match {
+        case "numTopFeatures" =>
+          require(threshold >= 1 && threshold.toInt == threshold,
+            s"selectionThreshold needs to be a positive Integer for selection mode " +
+              s"numTopFeatures, but got $threshold")
+        case "percentile" | "fpr" | "fdr" | "fwe" =>
+          require(0 <= threshold && threshold <= 1,
+            s"selectionThreshold needs to be in the range [0, 1] for selection mode " +
+              s"${$(selectionMode)}, but got $threshold")
+      }
+    }
+    require(isSet(featureType) && isSet(labelType), "featureType and labelType need to be set")
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.checkNumericType(schema, $(labelCol))
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org