You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2021/01/18 04:20:46 UTC
[spark] branch branch-3.1 updated:
[SPARK-34080][ML][PYTHON][FOLLOWUP] Add UnivariateFeatureSelector - make
methods private
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new 56f93e5 [SPARK-34080][ML][PYTHON][FOLLOWUP] Add UnivariateFeatureSelector - make methods private
56f93e5 is described below
commit 56f93e56ab731be27a05a299fcbe0ef529f280ba
Author: Ruifeng Zheng <ru...@foxmail.com>
AuthorDate: Mon Jan 18 13:19:59 2021 +0900
[SPARK-34080][ML][PYTHON][FOLLOWUP] Add UnivariateFeatureSelector - make methods private
### What changes were proposed in this pull request?
1. Make `getTopIndices` private and `selectIndicesFromPValues` `private[feature]`;
2. Avoid setting `selectionThreshold` in `fit`;
3. Move param checking to `transformSchema`.
### Why are the changes needed?
`getTopIndices`/`selectIndicesFromPValues` should not be exposed to end users.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing test suites.
Closes #31222 from zhengruifeng/selector_clean_up.
Authored-by: Ruifeng Zheng <ru...@foxmail.com>
Signed-off-by: HyukjinKwon <gu...@apache.org>
(cherry picked from commit ac322a1ac3be79b5e514f0119275f53b3a40c923)
Signed-off-by: HyukjinKwon <gu...@apache.org>
---
.../ml/feature/UnivariateFeatureSelector.scala | 74 ++++++++--------------
1 file changed, 27 insertions(+), 47 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
index 6d5f09e..bfe1d5f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
@@ -76,8 +76,7 @@ private[feature] trait UnivariateFeatureSelectorParams extends Params
@Since("3.1.1")
final val selectionMode = new Param[String](this, "selectionMode",
"The selection mode. Supported options: numTopFeatures, percentile, fpr, fdr, fwe",
- ParamValidators.inArray(Array("numTopFeatures", "percentile", "fpr", "fdr",
- "fwe")))
+ ParamValidators.inArray(Array("numTopFeatures", "percentile", "fpr", "fdr", "fwe")))
/** @group getParam */
@Since("3.1.1")
@@ -161,48 +160,17 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v
transformSchema(dataset.schema, logging = true)
val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
- $(selectionMode) match {
- case ("numTopFeatures") =>
- if (!isSet(selectionThreshold)) {
- set(selectionThreshold, 50.0)
- } else {
- require($(selectionThreshold) > 0 && $(selectionThreshold).toInt == $(selectionThreshold),
- "selectionThreshold needs to be a positive Integer for selection mode numTopFeatures")
- }
- case ("percentile") =>
- if (!isSet(selectionThreshold)) {
- set(selectionThreshold, 0.1)
- } else {
- require($(selectionThreshold) >= 0 && $(selectionThreshold) <= 1,
- "selectionThreshold needs to be in the range of 0 to 1 for selection mode percentile")
- }
- case ("fpr") =>
- if (!isSet(selectionThreshold)) {
- set(selectionThreshold, 0.05)
- } else {
- require($(selectionThreshold) >= 0 && $(selectionThreshold) <= 1,
- "selectionThreshold needs to be in the range of 0 to 1 for selection mode fpr")
- }
- case ("fdr") =>
- if (!isSet(selectionThreshold)) {
- set(selectionThreshold, 0.05)
- } else {
- require($(selectionThreshold) >= 0 && $(selectionThreshold) <= 1,
- "selectionThreshold needs to be in the range of 0 to 1 for selection mode fdr")
- }
- case ("fwe") =>
- if (!isSet(selectionThreshold)) {
- set(selectionThreshold, 0.05)
- } else {
- require($(selectionThreshold) >= 0 && $(selectionThreshold) <= 1,
- "selectionThreshold needs to be in the range of 0 to 1 for selection mode fwe")
- }
- case _ =>
- throw new IllegalArgumentException(s"Unsupported selection mode:" +
- s" selectionMode=${$(selectionMode)}")
+ var threshold = Double.NaN
+ if (isSet(selectionThreshold)) {
+ threshold = $(selectionThreshold)
+ } else {
+ $(selectionMode) match {
+ case "numTopFeatures" => threshold = 50
+ case "percentile" => threshold = 0.1
+ case "fpr" | "fdr" | "fwe" => threshold = 0.05
+ }
}
- require(isSet(featureType) && isSet(labelType), "featureType and labelType need to be set")
val resultDF = ($(featureType), $(labelType)) match {
case ("categorical", "categorical") =>
ChiSquareTest.test(dataset.toDF, getFeaturesCol, getLabelCol, true)
@@ -215,14 +183,12 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v
s" featureType=${$(featureType)}, labelType=${$(labelType)}")
}
- val indices =
- selectIndicesFromPValues(numFeatures, resultDF, $(selectionMode), $(selectionThreshold))
-
+ val indices = selectIndicesFromPValues(numFeatures, resultDF, $(selectionMode), threshold)
copyValues(new UnivariateFeatureSelectorModel(uid, indices)
.setParent(this))
}
- def getTopIndices(df: DataFrame, k: Int): Array[Int] = {
+ private def getTopIndices(df: DataFrame, k: Int): Array[Int] = {
val spark = SparkSession.builder().getOrCreate()
import spark.implicits._
df.sort("pValue", "featureIndex")
@@ -232,7 +198,7 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v
.collect()
}
- def selectIndicesFromPValues(
+ private[feature] def selectIndicesFromPValues(
numFeatures: Int,
resultDF: DataFrame,
selectionMode: String,
@@ -276,6 +242,20 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v
@Since("3.1.1")
override def transformSchema(schema: StructType): StructType = {
+ if (isSet(selectionThreshold)) {
+ val threshold = $(selectionThreshold)
+ $(selectionMode) match {
+ case "numTopFeatures" =>
+ require(threshold >= 1 && threshold.toInt == threshold,
+ s"selectionThreshold needs to be a positive Integer for selection mode " +
+ s"numTopFeatures, but got $threshold")
+ case "percentile" | "fpr" | "fdr" | "fwe" =>
+ require(0 <= threshold && threshold <= 1,
+ s"selectionThreshold needs to be in the range [0, 1] for selection mode " +
+ s"${$(selectionMode)}, but got $threshold")
+ }
+ }
+ require(isSet(featureType) && isSet(labelType), "featureType and labelType need to be set")
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.checkNumericType(schema, $(labelCol))
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org