You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2016/10/01 20:10:54 UTC

spark git commit: [SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement.

Repository: spark
Updated Branches:
  refs/heads/master af6ece33d -> b88cb63da


[SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement.

## What changes were proposed in this pull request?

Partial revert of #15277: instead of requiring the selected feature indices to be passed in already sorted, the model now sorts them itself and stores the sorted copy for use during transformation.

## How was this patch tested?

Existing tests.

Author: Sean Owen <so...@cloudera.com>

Closes #15299 from srowen/SPARK-17704.2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b88cb63d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b88cb63d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b88cb63d

Branch: refs/heads/master
Commit: b88cb63da39786c07cb4bfa70afed32ec5eb3286
Parents: af6ece3
Author: Sean Owen <so...@cloudera.com>
Authored: Sat Oct 1 16:10:39 2016 -0400
Committer: Sean Owen <so...@cloudera.com>
Committed: Sat Oct 1 16:10:39 2016 -0400

----------------------------------------------------------------------
 .../apache/spark/ml/feature/ChiSqSelector.scala |  2 +-
 .../spark/mllib/feature/ChiSqSelector.scala     | 22 ++++++++++----------
 python/pyspark/ml/feature.py                    |  2 +-
 3 files changed, 13 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b88cb63d/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 9c131a4..d0385e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -193,7 +193,7 @@ final class ChiSqSelectorModel private[ml] (
 
   import ChiSqSelectorModel._
 
-  /** list of indices to select (filter). Must be ordered asc */
+  /** list of indices to select (filter). */
   @Since("1.6.0")
   val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b88cb63d/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 706ce78..c305b36 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -35,14 +35,15 @@ import org.apache.spark.sql.{Row, SparkSession}
 /**
  * Chi Squared selector model.
  *
- * @param selectedFeatures list of indices to select (filter). Must be ordered asc
+ * @param selectedFeatures list of indices to select (filter).
  */
 @Since("1.3.0")
 class ChiSqSelectorModel @Since("1.3.0") (
   @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
 
-  require(isSorted(selectedFeatures), "Array has to be sorted asc")
+  private val filterIndices = selectedFeatures.sorted
 
+  @deprecated("not intended for subclasses to use", "2.1.0")
   protected def isSorted(array: Array[Int]): Boolean = {
     var i = 1
     val len = array.length
@@ -61,7 +62,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
    */
   @Since("1.3.0")
   override def transform(vector: Vector): Vector = {
-    compress(vector, selectedFeatures)
+    compress(vector)
   }
 
   /**
@@ -69,9 +70,8 @@ class ChiSqSelectorModel @Since("1.3.0") (
    * Preserves the order of filtered features the same as their indices are stored.
    * Might be moved to Vector as .slice
    * @param features vector
-   * @param filterIndices indices of features to filter, must be ordered asc
    */
-  private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
+  private def compress(features: Vector): Vector = {
     features match {
       case SparseVector(size, indices, values) =>
         val newSize = filterIndices.length
@@ -230,23 +230,23 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
    */
   @Since("1.3.0")
   def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
-    val chiSqTestResult = Statistics.chiSqTest(data)
+    val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex
     val features = selectorType match {
       case ChiSqSelector.KBest =>
-        chiSqTestResult.zipWithIndex
+        chiSqTestResult
           .sortBy { case (res, _) => -res.statistic }
           .take(numTopFeatures)
       case ChiSqSelector.Percentile =>
-        chiSqTestResult.zipWithIndex
+        chiSqTestResult
           .sortBy { case (res, _) => -res.statistic }
           .take((chiSqTestResult.length * percentile).toInt)
       case ChiSqSelector.FPR =>
-        chiSqTestResult.zipWithIndex
-          .filter{ case (res, _) => res.pValue < alpha }
+        chiSqTestResult
+          .filter { case (res, _) => res.pValue < alpha }
       case errorType =>
         throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
     }
-    val indices = features.map { case (_, indices) => indices }.sorted
+    val indices = features.map { case (_, index) => index }
     new ChiSqSelectorModel(indices)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b88cb63d/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 12a1384..64b21ca 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2705,7 +2705,7 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable):
     @since("2.0.0")
     def selectedFeatures(self):
         """
-        List of indices to select (filter). Must be ordered asc.
+        List of indices to select (filter).
         """
         return self._call_java("selectedFeatures")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org