You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2018/09/01 13:41:10 UTC

spark git commit: [SPARK-25289][ML] Avoid exception in ChiSqSelector with FDR when no feature is selected

Repository: spark
Updated Branches:
  refs/heads/master 7c36ee46d -> 6ad8d4c37


[SPARK-25289][ML] Avoid exception in ChiSqSelector with FDR when no feature is selected

## What changes were proposed in this pull request?

Currently, when FDR is used for `ChiSqSelector` and no feature is selected, an exception is thrown because `max` is called on an empty collection.

This PR fixes the problem by handling this case explicitly and returning an empty array, as sklearn (the reference for the initial FDR implementation) does.

## How was this patch tested?

Added a unit test.

Closes #22303 from mgaido91/SPARK-25289.

Authored-by: Marco Gaido <ma...@gmail.com>
Signed-off-by: Sean Owen <se...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6ad8d4c3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6ad8d4c3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6ad8d4c3

Branch: refs/heads/master
Commit: 6ad8d4c375772c0c907c25837de762b5b9266a8e
Parents: 7c36ee4
Author: Marco Gaido <ma...@gmail.com>
Authored: Sat Sep 1 08:41:07 2018 -0500
Committer: Sean Owen <se...@databricks.com>
Committed: Sat Sep 1 08:41:07 2018 -0500

----------------------------------------------------------------------
 .../org/apache/spark/mllib/feature/ChiSqSelector.scala  | 12 ++++++++----
 .../apache/spark/ml/feature/ChiSqSelectorSuite.scala    | 11 +++++++++++
 2 files changed, 19 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6ad8d4c3/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index f923be8..aa78e91 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -28,6 +28,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.stat.Statistics
+import org.apache.spark.mllib.stat.test.ChiSqTestResult
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SparkSession}
@@ -272,13 +273,16 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
         // https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure
         val tempRes = chiSqTestResult
           .sortBy { case (res, _) => res.pValue }
-        val maxIndex = tempRes
+        val selected = tempRes
           .zipWithIndex
           .filter { case ((res, _), index) =>
             res.pValue <= fdr * (index + 1) / chiSqTestResult.length }
-          .map { case (_, index) => index }
-          .max
-        tempRes.take(maxIndex + 1)
+        if (selected.isEmpty) {
+          Array.empty[(ChiSqTestResult, Int)]
+        } else {
+          val maxIndex = selected.map(_._2).max
+          tempRes.take(maxIndex + 1)
+        }
       case ChiSqSelector.FWE =>
         chiSqTestResult
           .filter { case (res, _) => res.pValue < fwe / chiSqTestResult.length }

http://git-wip-us.apache.org/repos/asf/spark/blob/6ad8d4c3/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index c843df9..80499e7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -163,6 +163,17 @@ class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest {
       }
   }
 
+  test("SPARK-25289: ChiSqSelector should not fail when selecting no features with FDR") {
+    val labeledPoints = (0 to 1).map { n =>
+        val v = Vectors.dense((1 to 3).map(_ => n * 1.0).toArray)
+        (n.toDouble, v)
+      }
+    val inputDF = spark.createDataFrame(labeledPoints).toDF("label", "features")
+    val selector = new ChiSqSelector().setSelectorType("fdr").setFdr(0.05)
+    val model = selector.fit(inputDF)
+    assert(model.selectedFeatures.isEmpty)
+  }
+
   private def testSelector(selector: ChiSqSelector, data: Dataset[_]): ChiSqSelectorModel = {
     val selectorModel = selector.fit(data)
     testTransformer[(Double, Vector, Vector)](data.toDF(), selectorModel,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org