Posted to commits@spark.apache.org by ru...@apache.org on 2020/03/22 04:45:10 UTC

[spark] branch master updated: [SPARK-31185][ML] Implement VarianceThresholdSelector

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 307cfe1  [SPARK-31185][ML] Implement VarianceThresholdSelector
307cfe1 is described below

commit 307cfe1f8eacef38fc41a46d94273a5bb6b95fbe
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Sun Mar 22 12:44:18 2020 +0800

    [SPARK-31185][ML] Implement VarianceThresholdSelector
    
    ### What changes were proposed in this pull request?
    Implement a feature selector that removes all low-variance features. Features with a
    variance not greater than the threshold will be removed. The default is to keep all
    features with non-zero variance, i.e. remove the features that have the same value in
    all samples.
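
    A rough usage sketch of the new API (the data, column names, and threshold value below
    are illustrative; the estimator, its params, and fit/transform are the ones added in
    this patch):

        import org.apache.spark.ml.feature.VarianceThresholdSelector
        import org.apache.spark.ml.linalg.Vectors

        // Toy data: the "features" column holds the vectors to filter.
        val df = spark.createDataFrame(Seq(
          (1, Vectors.dense(6.0, 7.0, 0.0, 5.0, 6.0, 0.0)),
          (2, Vectors.dense(0.0, 9.0, 6.0, 5.0, 5.0, 9.0)),
          (3, Vectors.dense(0.0, 9.0, 3.0, 5.0, 5.0, 5.0))
        )).toDF("id", "features")

        // Keep only features whose variance is strictly greater than the threshold.
        val selector = new VarianceThresholdSelector()
          .setVarianceThreshold(8.0)
          .setFeaturesCol("features")
          .setOutputCol("selected")

        val model = selector.fit(df)
        model.transform(df).show(truncate = false)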
    
    ### Why are the changes needed?
    VarianceThreshold is a simple baseline approach to feature selection. It removes all
    features whose variance does not meet a given threshold. The idea is that when a feature
    does not vary much within itself, it generally has very little predictive power; in the
    extreme case, a feature that takes the same value in every sample has zero variance and
    carries no information for distinguishing samples.
    scikit-learn already implements this selector:
    https://scikit-learn.org/stable/modules/feature_selection.html#variance-threshold
    
    ### Does this PR introduce any user-facing change?
    Yes. This adds a new VarianceThresholdSelector estimator and its corresponding model to
    the ML feature package.
    
    ### How was this patch tested?
    Add new test suite.
    
    Closes #27954 from huaxingao/variance-threshold.
    
    Authored-by: Huaxin Gao <hu...@us.ibm.com>
    Signed-off-by: zhengruifeng <ru...@foxmail.com>
---
 .../ml/feature/VarianceThresholdSelector.scala     | 274 +++++++++++++++++++++
 .../feature/VarianceThresholdSelectorSuite.scala   | 137 +++++++++++
 2 files changed, 411 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
new file mode 100644
index 0000000..3eac29c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml._
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.stat.Summarizer
+import org.apache.spark.ml.util._
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{StructField, StructType}
+
+
+/**
+ * Params for [[VarianceThresholdSelector]] and [[VarianceThresholdSelectorModel]].
+ */
+private[feature] trait VarianceThresholdSelectorParams extends Params
+  with HasFeaturesCol with HasOutputCol {
+
+  /**
+   * Param for variance threshold. Features with a variance not greater than this threshold
+   * will be removed. The default value is 0.0.
+   *
+   * @group param
+   */
+  @Since("3.1.0")
+  final val varianceThreshold = new DoubleParam(this, "varianceThreshold",
+    "Param for variance threshold. Features with a variance not greater than this threshold" +
+      " will be removed. The default value is 0.0.", ParamValidators.gtEq(0))
+  setDefault(varianceThreshold -> 0.0)
+
+  /** @group getParam */
+  @Since("3.1.0")
+  def getVarianceThreshold: Double = $(varianceThreshold)
+}
+
+/**
+ * Feature selector that removes all low-variance features. Features with a
+ * variance not greater than the threshold will be removed. The default is to keep
+ * all features with non-zero variance, i.e. remove the features that have the
+ * same value in all samples.
+ */
+@Since("3.1.0")
+final class VarianceThresholdSelector @Since("3.1.0")(@Since("3.1.0") override val uid: String)
+  extends Estimator[VarianceThresholdSelectorModel] with VarianceThresholdSelectorParams
+    with DefaultParamsWritable {
+
+  @Since("3.1.0")
+  def this() = this(Identifiable.randomUID("VarianceThresholdSelector"))
+
+  /** @group setParam */
+  @Since("3.1.0")
+  def setVarianceThreshold(value: Double): this.type = set(varianceThreshold, value)
+
+  /** @group setParam */
+  @Since("3.1.0")
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  @Since("3.1.0")
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  @Since("3.1.0")
+  override def fit(dataset: Dataset[_]): VarianceThresholdSelectorModel = {
+    transformSchema(dataset.schema, logging = true)
+
+    val Row(maxs: Vector, mins: Vector, variances: Vector) = dataset
+      .select(Summarizer.metrics("max", "min", "variance").summary(col($(featuresCol)))
+        .as("summary"))
+      .select("summary.max", "summary.min", "summary.variance")
+      .first()
+
+    val numFeatures = maxs.size
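+    // Keep the indices of features whose variance is strictly greater than the
+    // threshold; constant features are treated as having zero variance and dropped.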
+    val indices = Array.tabulate(numFeatures) { i =>
+      // Use peak-to-peak to avoid numeric precision issues for constant features
+      (i, if (maxs(i) == mins(i)) 0.0 else variances(i))
+    }.filter(_._2 > getVarianceThreshold).map(_._1)
+    copyValues(new VarianceThresholdSelectorModel(uid, indices.sorted)
+      .setParent(this))
+  }
+
+  @Since("3.1.0")
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+  }
+
+  @Since("3.1.0")
+  override def copy(extra: ParamMap): VarianceThresholdSelector = defaultCopy(extra)
+}
+
+@Since("3.1.0")
+object VarianceThresholdSelector extends DefaultParamsReadable[VarianceThresholdSelector] {
+
+  @Since("3.1.0")
+  override def load(path: String): VarianceThresholdSelector = super.load(path)
+}
+
+/**
+ * Model fitted by [[VarianceThresholdSelector]].
+ */
+@Since("3.1.0")
+class VarianceThresholdSelectorModel private[ml](
+    @Since("3.1.0") override val uid: String,
+    @Since("3.1.0") val selectedFeatures: Array[Int])
+  extends Model[VarianceThresholdSelectorModel] with VarianceThresholdSelectorParams
+    with MLWritable {
+
+  if (selectedFeatures.length >= 2) {
+    require(selectedFeatures.sliding(2).forall(l => l(0) < l(1)),
+      "Index should be strictly increasing.")
+  }
+
+  /** @group setParam */
+  @Since("3.1.0")
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  @Since("3.1.0")
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  @Since("3.1.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    val outputSchema = transformSchema(dataset.schema, logging = true)
+
+    val newSize = selectedFeatures.length
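+    // Slice each input vector down to the selected feature indices; sparse input
+    // stays sparse and dense input stays dense.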
+    val func = { vector: Vector =>
+      vector match {
+        case SparseVector(_, indices, values) =>
+          val (newIndices, newValues) = compressSparse(indices, values)
+          Vectors.sparse(newSize, newIndices, newValues)
+        case DenseVector(values) =>
+          Vectors.dense(selectedFeatures.map(values))
+        case other =>
+          throw new UnsupportedOperationException(
+            s"Only sparse and dense vectors are supported but got ${other.getClass}.")
+      }
+    }
+
+    val transformer = udf(func)
+    dataset.withColumn($(outputCol), transformer(col($(featuresCol))),
+      outputSchema($(outputCol)).metadata)
+  }
+
+  @Since("3.1.0")
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    val newField = prepOutputField(schema)
+    SchemaUtils.appendColumn(schema, newField)
+  }
+
+  /**
+   * Prepare the output column field, including per-feature metadata.
+   */
+  private def prepOutputField(schema: StructType): StructField = {
+    val selector = selectedFeatures.toSet
+    val origAttrGroup = AttributeGroup.fromStructField(schema($(featuresCol)))
+    val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) {
+      origAttrGroup.attributes.get.zipWithIndex.filter(x => selector.contains(x._2)).map(_._1)
+    } else {
+      Array.fill[Attribute](selector.size)(NominalAttribute.defaultAttr)
+    }
+    val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes)
+    newAttributeGroup.toStructField()
+  }
+
+  @Since("3.1.0")
+  override def copy(extra: ParamMap): VarianceThresholdSelectorModel = {
+    val copied = new VarianceThresholdSelectorModel(uid, selectedFeatures)
+      .setParent(parent)
+    copyValues(copied, extra)
+  }
+
+  @Since("3.1.0")
+  override def write: MLWriter =
+    new VarianceThresholdSelectorModel.VarianceThresholdSelectorWriter(this)
+
+  @Since("3.1.0")
+  override def toString: String = {
+    s"VarianceThresholdSelectorModel: uid=$uid, numSelectedFeatures=${selectedFeatures.length}"
+  }
+
+  private[spark] def compressSparse(
+      indices: Array[Int],
+      values: Array[Double]): (Array[Int], Array[Double]) = {
+    val newValues = new ArrayBuilder.ofDouble
+    val newIndices = new ArrayBuilder.ofInt
+    var i = 0
+    var j = 0
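+    // Merge-walk the vector's sorted indices (i) against the sorted selected feature
+    // indices (j): a value is kept only when both point at the same original index,
+    // and it is re-indexed to position j in the compressed output.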
+    while (i < indices.length && j < selectedFeatures.length) {
+      val indicesIdx = indices(i)
+      val filterIndicesIdx = selectedFeatures(j)
+      if (indicesIdx == filterIndicesIdx) {
+        newIndices += j
+        newValues += values(i)
+        j += 1
+        i += 1
+      } else {
+        if (indicesIdx > filterIndicesIdx) {
+          j += 1
+        } else {
+          i += 1
+        }
+      }
+    }
+    // TODO: Sparse representation might be ineffective if (newSize ~= newValues.size)
+    (newIndices.result(), newValues.result())
+  }
+}
+
+@Since("3.1.0")
+object VarianceThresholdSelectorModel extends MLReadable[VarianceThresholdSelectorModel] {
+
+  @Since("3.1.0")
+  override def read: MLReader[VarianceThresholdSelectorModel] =
+    new VarianceThresholdSelectorModelReader
+
+  @Since("3.1.0")
+  override def load(path: String): VarianceThresholdSelectorModel = super.load(path)
+
+  private[VarianceThresholdSelectorModel] class VarianceThresholdSelectorWriter(
+      instance: VarianceThresholdSelectorModel) extends MLWriter {
+
+    private case class Data(selectedFeatures: Seq[Int])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.selectedFeatures.toSeq)
+      val dataPath = new Path(path, "data").toString
+      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class VarianceThresholdSelectorModelReader extends
+    MLReader[VarianceThresholdSelectorModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[VarianceThresholdSelectorModel].getName
+
+    override def load(path: String): VarianceThresholdSelectorModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sparkSession.read.parquet(dataPath)
+        .select("selectedFeatures").head()
+      val selectedFeatures = data.getAs[Seq[Int]](0).toArray
+      val model = new VarianceThresholdSelectorModel(metadata.uid, selectedFeatures)
+      metadata.getAndSetParams(model)
+      model
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VarianceThresholdSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VarianceThresholdSelectorSuite.scala
new file mode 100644
index 0000000..6cc9803
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VarianceThresholdSelectorSuite.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.sql.{Dataset, Row}
+
+class VarianceThresholdSelectorSuite extends MLTest with DefaultReadWriteTest {
+
+  import testImplicits._
+
+  @transient var dataset: Dataset[_] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    val data = Seq(
+      (1, Vectors.dense(Array(6.0, 7.0, 0.0, 5.0, 6.0, 0.0)),
+        Vectors.dense(Array(6.0, 7.0, 0.0, 6.0, 0.0))),
+      (2, Vectors.dense(Array(0.0, 9.0, 6.0, 5.0, 5.0, 9.0)),
+        Vectors.dense(Array(0.0, 9.0, 6.0, 5.0, 9.0))),
+      (3, Vectors.dense(Array(0.0, 9.0, 3.0, 5.0, 5.0, 5.0)),
+        Vectors.dense(Array(0.0, 9.0, 3.0, 5.0, 5.0))),
+      (4, Vectors.dense(Array(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)),
+        Vectors.dense(Array(0.0, 9.0, 8.0, 6.0, 4.0))),
+      (5, Vectors.dense(Array(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)),
+        Vectors.dense(Array(8.0, 9.0, 6.0, 4.0, 4.0))),
+      (6, Vectors.dense(Array(8.0, 9.0, 6.0, 5.0, 0.0, 0.0)),
+        Vectors.dense(Array(8.0, 9.0, 6.0, 0.0, 0.0))))
+
+    dataset = spark.createDataFrame(data).toDF("id", "features", "expected")
+  }
+
+  test("params") {
+    ParamsSuite.checkParams(new VarianceThresholdSelector)
+  }
+
+  test("Test VarianceThresholdSelector: varainceThreshold not set") {
+    val selector = new VarianceThresholdSelector().setOutputCol("filtered")
+    testSelector(selector, dataset)
+  }
+
+  test("Test VarianceThresholdSelector: set varianceThreshold") {
+    val df = spark.createDataFrame(Seq(
+      (1, Vectors.dense(Array(6.0, 7.0, 0.0, 7.0, 6.0, 0.0)),
+        Vectors.dense(Array(6.0, 7.0, 0.0))),
+      (2, Vectors.dense(Array(0.0, 9.0, 6.0, 0.0, 5.0, 9.0)),
+        Vectors.dense(Array(0.0, 0.0, 9.0))),
+      (3, Vectors.dense(Array(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)),
+        Vectors.dense(Array(0.0, 0.0, 5.0))),
+      (4, Vectors.dense(Array(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)),
+        Vectors.dense(Array(0.0, 5.0, 4.0))),
+      (5, Vectors.dense(Array(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)),
+        Vectors.dense(Array(8.0, 5.0, 4.0))),
+      (6, Vectors.dense(Array(8.0, 9.0, 6.0, 0.0, 0.0, 0.0)),
+        Vectors.dense(Array(8.0, 0.0, 0.0)))
+    )).toDF("id", "features", "expected")
+    val selector = new VarianceThresholdSelector()
+      .setVarianceThreshold(8.2)
+      .setOutputCol("filtered")
+    testSelector(selector, df)
+  }
+
+  test("Test VarianceThresholdSelector: sparse vector") {
+    val df = spark.createDataFrame(Seq(
+      (1, Vectors.sparse(6, Array((0, 6.0), (1, 7.0), (3, 7.0), (4, 6.0))),
+        Vectors.dense(Array(6.0, 0.0, 7.0, 0.0))),
+      (2, Vectors.sparse(6, Array((1, 9.0), (2, 6.0), (4, 5.0), (5, 9.0))),
+        Vectors.dense(Array(0.0, 6.0, 0.0, 9.0))),
+      (3, Vectors.sparse(6, Array((1, 9.0), (2, 3.0), (4, 5.0), (5, 5.0))),
+        Vectors.dense(Array(0.0, 3.0, 0.0, 5.0))),
+      (4, Vectors.dense(Array(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)),
+        Vectors.dense(Array(0.0, 8.0, 5.0, 4.0))),
+      (5, Vectors.dense(Array(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)),
+        Vectors.dense(Array(8.0, 6.0, 5.0, 4.0))),
+      (6, Vectors.dense(Array(8.0, 9.0, 6.0, 4.0, 0.0, 0.0)),
+        Vectors.dense(Array(8.0, 6.0, 4.0, 0.0)))
+    )).toDF("id", "features", "expected")
+    val selector = new VarianceThresholdSelector()
+      .setVarianceThreshold(8.1)
+      .setOutputCol("filtered")
+    testSelector(selector, df)
+  }
+
+  test("read/write") {
+    def checkModelData(
+        model: VarianceThresholdSelectorModel,
+        model2: VarianceThresholdSelectorModel): Unit = {
+      assert(model.selectedFeatures === model2.selectedFeatures)
+    }
+    val varSelector = new VarianceThresholdSelector
+    testEstimatorAndModelReadWrite(varSelector, dataset,
+      VarianceThresholdSelectorSuite.allParamSettings,
+      VarianceThresholdSelectorSuite.allParamSettings, checkModelData)
+  }
+
+  private def testSelector(
+      selector: VarianceThresholdSelector,
+      data: Dataset[_]): VarianceThresholdSelectorModel = {
+    val selectorModel = selector.fit(data)
+    testTransformer[(Int, Vector, Vector)](data.toDF(), selectorModel,
+      "filtered", "expected") {
+      case Row(vec1: Vector, vec2: Vector) =>
+        assert(vec1 ~== vec2 absTol 1e-6)
+    }
+    selectorModel
+  }
+}
+
+object VarianceThresholdSelectorSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "varianceThreshold" -> 0.12,
+    "outputCol" -> "myOutput"
+  )
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org