You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2020/03/22 04:45:10 UTC
[spark] branch master updated: [SPARK-31185][ML] Implement
VarianceThresholdSelector
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 307cfe1 [SPARK-31185][ML] Implement VarianceThresholdSelector
307cfe1 is described below
commit 307cfe1f8eacef38fc41a46d94273a5bb6b95fbe
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Sun Mar 22 12:44:18 2020 +0800
[SPARK-31185][ML] Implement VarianceThresholdSelector
### What changes were proposed in this pull request?
Implement a feature selector that removes all low-variance features. Features with a
variance not greater than the threshold will be removed. The default is to keep all features with non-zero variance, i.e. remove the features that have the same value in all samples.
### Why are the changes needed?
VarianceThreshold is a simple baseline approach to feature selection. It removes all features whose variance does not exceed a given threshold. The idea is that when a feature doesn’t vary much, it generally has very little predictive power.
scikit-learn has implemented this selector:
https://scikit-learn.org/stable/modules/feature_selection.html#variance-threshold
### Does this PR introduce any user-facing change?
Yes.
### How was this patch tested?
Add new test suite.
Closes #27954 from huaxingao/variance-threshold.
Authored-by: Huaxin Gao <hu...@us.ibm.com>
Signed-off-by: zhengruifeng <ru...@foxmail.com>
---
.../ml/feature/VarianceThresholdSelector.scala | 274 +++++++++++++++++++++
.../feature/VarianceThresholdSelectorSuite.scala | 137 +++++++++++
2 files changed, 411 insertions(+)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
new file mode 100644
index 0000000..3eac29c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml._
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.stat.Summarizer
+import org.apache.spark.ml.util._
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{StructField, StructType}
+
+
+/**
+ * Shared params for [[VarianceThresholdSelector]] and [[VarianceThresholdSelectorModel]]:
+ * the input features column, the output column and the variance threshold.
+ */
+private[feature] trait VarianceThresholdSelectorParams extends Params
+  with HasFeaturesCol with HasOutputCol {
+
+  /**
+   * Param for variance threshold. Features with a variance not greater than this threshold
+   * will be removed. The value must be non-negative; the default of 0.0 removes only the
+   * features whose variance is zero.
+   *
+   * @group param
+   */
+  @Since("3.1.0")
+  final val varianceThreshold: DoubleParam = new DoubleParam(
+    this,
+    "varianceThreshold",
+    "Param for variance threshold. Features with a variance not greater than this threshold" +
+      " will be removed. The default value is 0.0.",
+    ParamValidators.gtEq(0.0))
+
+  setDefault(varianceThreshold, 0.0)
+
+  /** @group getParam */
+  @Since("3.1.0")
+  def getVarianceThreshold: Double = $(varianceThreshold)
+}
+
+/**
+ * Feature selector that removes all low-variance features. Features with a
+ * variance not greater than the `varianceThreshold` param are removed. By default every
+ * feature with non-zero variance is kept, i.e. only the features that have the same value
+ * in all samples are removed.
+ */
+@Since("3.1.0")
+final class VarianceThresholdSelector @Since("3.1.0")(@Since("3.1.0") override val uid: String)
+  extends Estimator[VarianceThresholdSelectorModel] with VarianceThresholdSelectorParams
+  with DefaultParamsWritable {
+
+  @Since("3.1.0")
+  def this() = this(Identifiable.randomUID("VarianceThresholdSelector"))
+
+  /** @group setParam */
+  @Since("3.1.0")
+  def setVarianceThreshold(value: Double): this.type = set(varianceThreshold, value)
+
+  /** @group setParam */
+  @Since("3.1.0")
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  @Since("3.1.0")
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /**
+   * Computes the per-feature max, min and variance with [[Summarizer]] and returns a model
+   * that keeps the indices of the features whose variance is strictly greater than the
+   * threshold.
+   */
+  @Since("3.1.0")
+  override def fit(dataset: Dataset[_]): VarianceThresholdSelectorModel = {
+    transformSchema(dataset.schema, logging = true)
+
+    val summaryCol = Summarizer.metrics("max", "min", "variance")
+      .summary(col($(featuresCol)))
+      .as("summary")
+    val Row(maxValues: Vector, minValues: Vector, featureVariances: Vector) = dataset
+      .select(summaryCol)
+      .select("summary.max", "summary.min", "summary.variance")
+      .first()
+
+    val threshold = getVarianceThreshold
+    // Indices are visited in ascending order, so the result already satisfies the model's
+    // strictly-increasing requirement.
+    val selected = (0 until maxValues.size).filter { i =>
+      // A feature whose max equals its min is constant: use exactly 0.0 as its variance
+      // (peak-to-peak check) so floating-point noise in the computed variance cannot keep it.
+      val variance = if (maxValues(i) == minValues(i)) 0.0 else featureVariances(i)
+      variance > threshold
+    }.toArray
+
+    copyValues(new VarianceThresholdSelectorModel(uid, selected).setParent(this))
+  }
+
+  @Since("3.1.0")
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+  }
+
+  @Since("3.1.0")
+  override def copy(extra: ParamMap): VarianceThresholdSelector = defaultCopy(extra)
+}
+
+/** Companion object providing `load` for persisted [[VarianceThresholdSelector]]s. */
+@Since("3.1.0")
+object VarianceThresholdSelector extends DefaultParamsReadable[VarianceThresholdSelector] {
+
+  @Since("3.1.0")
+  override def load(path: String): VarianceThresholdSelector = read.load(path)
+}
+
+/**
+ * Model fitted by [[VarianceThresholdSelector]]. It keeps only the features whose indices
+ * are listed in `selectedFeatures`, in that order.
+ *
+ * @param selectedFeatures indices of the features to keep; must be strictly increasing.
+ */
+@Since("3.1.0")
+class VarianceThresholdSelectorModel private[ml](
+    @Since("3.1.0") override val uid: String,
+    @Since("3.1.0") val selectedFeatures: Array[Int])
+  extends Model[VarianceThresholdSelectorModel] with VarianceThresholdSelectorParams
+  with MLWritable {
+
+  // compressSparse's two-pointer merge depends on this ordering.
+  require(
+    selectedFeatures.indices.drop(1).forall(k => selectedFeatures(k - 1) < selectedFeatures(k)),
+    "Index should be strictly increasing.")
+
+  /** @group setParam */
+  @Since("3.1.0")
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  @Since("3.1.0")
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /**
+   * Appends the output column containing only the selected features. Sparse inputs stay
+   * sparse and dense inputs stay dense.
+   */
+  @Since("3.1.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    val outputSchema = transformSchema(dataset.schema, logging = true)
+    val numSelected = selectedFeatures.length
+
+    val selectVector: Vector => Vector = {
+      case SparseVector(_, indices, values) =>
+        val (keptIndices, keptValues) = compressSparse(indices, values)
+        Vectors.sparse(numSelected, keptIndices, keptValues)
+      case DenseVector(values) =>
+        Vectors.dense(selectedFeatures.map(i => values(i)))
+      case other =>
+        throw new UnsupportedOperationException(
+          s"Only sparse and dense vectors are supported but got ${other.getClass}.")
+    }
+
+    val vectorSlicer = udf(selectVector)
+    dataset.withColumn($(outputCol), vectorSlicer(col($(featuresCol))),
+      outputSchema($(outputCol)).metadata)
+  }
+
+  @Since("3.1.0")
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    SchemaUtils.appendColumn(schema, prepOutputField(schema))
+  }
+
+  /**
+   * Builds the output column field. When the input column carries per-feature attributes, the
+   * attributes of the selected features are carried over. Otherwise one default nominal
+   * attribute is emitted per selected feature.
+   */
+  private def prepOutputField(schema: StructType): StructField = {
+    val keep = selectedFeatures.toSet
+    val inputGroup = AttributeGroup.fromStructField(schema($(featuresCol)))
+    val outputAttrs: Array[Attribute] = inputGroup.attributes match {
+      case Some(attrs) =>
+        attrs.zipWithIndex.collect { case (attr, idx) if keep.contains(idx) => attr }
+      case None =>
+        Array.fill[Attribute](keep.size)(NominalAttribute.defaultAttr)
+    }
+    new AttributeGroup($(outputCol), outputAttrs).toStructField()
+  }
+
+  @Since("3.1.0")
+  override def copy(extra: ParamMap): VarianceThresholdSelectorModel = {
+    val copied = new VarianceThresholdSelectorModel(uid, selectedFeatures).setParent(parent)
+    copyValues(copied, extra)
+  }
+
+  @Since("3.1.0")
+  override def write: MLWriter =
+    new VarianceThresholdSelectorModel.VarianceThresholdSelectorWriter(this)
+
+  @Since("3.1.0")
+  override def toString: String = {
+    s"VarianceThresholdSelectorModel: uid=$uid, numSelectedFeatures=${selectedFeatures.length}"
+  }
+
+  /**
+   * Restricts a sparse vector, given by its active `indices` and `values`, to
+   * `selectedFeatures`. The returned indices are positions within `selectedFeatures`, i.e.
+   * indices into the output vector.
+   * NOTE(review): assumes `indices` is strictly increasing, as in a well-formed SparseVector.
+   */
+  private[spark] def compressSparse(
+      indices: Array[Int],
+      values: Array[Double]): (Array[Int], Array[Double]) = {
+    val keptIndices = new ArrayBuilder.ofInt
+    val keptValues = new ArrayBuilder.ofDouble
+    var src = 0 // cursor into the input's active entries
+    var dst = 0 // cursor into selectedFeatures, which is also the output index
+    while (src < indices.length && dst < selectedFeatures.length) {
+      val inputIndex = indices(src)
+      val wantedIndex = selectedFeatures(dst)
+      if (inputIndex < wantedIndex) {
+        src += 1
+      } else if (inputIndex > wantedIndex) {
+        dst += 1
+      } else {
+        keptIndices += dst
+        keptValues += values(src)
+        src += 1
+        dst += 1
+      }
+    }
+    // TODO: Sparse representation might be ineffective if (newSize ~= newValues.size)
+    (keptIndices.result(), keptValues.result())
+  }
+}
+
+/** Companion object providing persistence (`read` / `load`) for the fitted model. */
+@Since("3.1.0")
+object VarianceThresholdSelectorModel extends MLReadable[VarianceThresholdSelectorModel] {
+
+  @Since("3.1.0")
+  override def read: MLReader[VarianceThresholdSelectorModel] =
+    new VarianceThresholdSelectorModelReader
+
+  @Since("3.1.0")
+  override def load(path: String): VarianceThresholdSelectorModel = super.load(path)
+
+  /**
+   * Saves the params metadata plus a single-partition Parquet dataset under `path/data`
+   * holding the selected feature indices.
+   */
+  private[VarianceThresholdSelectorModel] class VarianceThresholdSelectorWriter(
+      instance: VarianceThresholdSelectorModel) extends MLWriter {
+
+    private case class Data(selectedFeatures: Seq[Int])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val dataDir = new Path(path, "data").toString
+      val record = Data(instance.selectedFeatures.toSeq)
+      sparkSession.createDataFrame(Seq(record)).repartition(1).write.parquet(dataDir)
+    }
+  }
+
+  /** Loads a model written by the companion's writer: metadata/params, then `path/data`. */
+  private class VarianceThresholdSelectorModelReader
+    extends MLReader[VarianceThresholdSelectorModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[VarianceThresholdSelectorModel].getName
+
+    override def load(path: String): VarianceThresholdSelectorModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataDir = new Path(path, "data").toString
+      val firstRow = sparkSession.read.parquet(dataDir).select("selectedFeatures").head()
+      val model = new VarianceThresholdSelectorModel(
+        metadata.uid, firstRow.getAs[Seq[Int]](0).toArray)
+      metadata.getAndSetParams(model)
+      model
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VarianceThresholdSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VarianceThresholdSelectorSuite.scala
new file mode 100644
index 0000000..6cc9803
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VarianceThresholdSelectorSuite.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.sql.{Dataset, Row}
+
+class VarianceThresholdSelectorSuite extends MLTest with DefaultReadWriteTest {
+
+  import testImplicits._
+
+  @transient var dataset: Dataset[_] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    // Feature 3 is constant (always 5.0), so with the default threshold of 0.0 it is the only
+    // feature removed; "expected" holds the remaining five features.
+    val data = Seq(
+      (1, Vectors.dense(Array(6.0, 7.0, 0.0, 5.0, 6.0, 0.0)),
+        Vectors.dense(Array(6.0, 7.0, 0.0, 6.0, 0.0))),
+      (2, Vectors.dense(Array(0.0, 9.0, 6.0, 5.0, 5.0, 9.0)),
+        Vectors.dense(Array(0.0, 9.0, 6.0, 5.0, 9.0))),
+      (3, Vectors.dense(Array(0.0, 9.0, 3.0, 5.0, 5.0, 5.0)),
+        Vectors.dense(Array(0.0, 9.0, 3.0, 5.0, 5.0))),
+      (4, Vectors.dense(Array(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)),
+        Vectors.dense(Array(0.0, 9.0, 8.0, 6.0, 4.0))),
+      (5, Vectors.dense(Array(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)),
+        Vectors.dense(Array(8.0, 9.0, 6.0, 4.0, 4.0))),
+      (6, Vectors.dense(Array(8.0, 9.0, 6.0, 5.0, 0.0, 0.0)),
+        Vectors.dense(Array(8.0, 9.0, 6.0, 0.0, 0.0))))
+
+    dataset = spark.createDataFrame(data).toDF("id", "features", "expected")
+  }
+
+  test("params") {
+    ParamsSuite.checkParams(new VarianceThresholdSelector)
+  }
+
+  test("Test VarianceThresholdSelector: varianceThreshold not set") {
+    val selector = new VarianceThresholdSelector().setOutputCol("filtered")
+    testSelector(selector, dataset)
+  }
+
+  test("Test VarianceThresholdSelector: set varianceThreshold") {
+    // Only features 0, 3 and 5 have a variance above 8.2 (feature 2 sits just below it).
+    val df = spark.createDataFrame(Seq(
+      (1, Vectors.dense(Array(6.0, 7.0, 0.0, 7.0, 6.0, 0.0)),
+        Vectors.dense(Array(6.0, 7.0, 0.0))),
+      (2, Vectors.dense(Array(0.0, 9.0, 6.0, 0.0, 5.0, 9.0)),
+        Vectors.dense(Array(0.0, 0.0, 9.0))),
+      (3, Vectors.dense(Array(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)),
+        Vectors.dense(Array(0.0, 0.0, 5.0))),
+      (4, Vectors.dense(Array(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)),
+        Vectors.dense(Array(0.0, 5.0, 4.0))),
+      (5, Vectors.dense(Array(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)),
+        Vectors.dense(Array(8.0, 5.0, 4.0))),
+      (6, Vectors.dense(Array(8.0, 9.0, 6.0, 0.0, 0.0, 0.0)),
+        Vectors.dense(Array(8.0, 0.0, 0.0)))
+    )).toDF("id", "features", "expected")
+    val selector = new VarianceThresholdSelector()
+      .setVarianceThreshold(8.2)
+      .setOutputCol("filtered")
+    testSelector(selector, df)
+  }
+
+  test("Test VarianceThresholdSelector: sparse vector") {
+    // The first three rows are sparse so the model's sparse path is exercised; with a
+    // threshold of 8.1 features 0, 2, 3 and 5 are kept.
+    val df = spark.createDataFrame(Seq(
+      (1, Vectors.sparse(6, Array((0, 6.0), (1, 7.0), (3, 7.0), (4, 6.0))),
+        Vectors.dense(Array(6.0, 0.0, 7.0, 0.0))),
+      (2, Vectors.sparse(6, Array((1, 9.0), (2, 6.0), (4, 5.0), (5, 9.0))),
+        Vectors.dense(Array(0.0, 6.0, 0.0, 9.0))),
+      (3, Vectors.sparse(6, Array((1, 9.0), (2, 3.0), (4, 5.0), (5, 5.0))),
+        Vectors.dense(Array(0.0, 3.0, 0.0, 5.0))),
+      (4, Vectors.dense(Array(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)),
+        Vectors.dense(Array(0.0, 8.0, 5.0, 4.0))),
+      (5, Vectors.dense(Array(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)),
+        Vectors.dense(Array(8.0, 6.0, 5.0, 4.0))),
+      (6, Vectors.dense(Array(8.0, 9.0, 6.0, 4.0, 0.0, 0.0)),
+        Vectors.dense(Array(8.0, 6.0, 4.0, 0.0)))
+    )).toDF("id", "features", "expected")
+    val selector = new VarianceThresholdSelector()
+      .setVarianceThreshold(8.1)
+      .setOutputCol("filtered")
+    testSelector(selector, df)
+  }
+
+  test("read/write") {
+    def checkModelData(model: VarianceThresholdSelectorModel, model2:
+      VarianceThresholdSelectorModel): Unit = {
+      assert(model.selectedFeatures === model2.selectedFeatures)
+    }
+    val varSelector = new VarianceThresholdSelector
+    testEstimatorAndModelReadWrite(varSelector, dataset,
+      VarianceThresholdSelectorSuite.allParamSettings,
+      VarianceThresholdSelectorSuite.allParamSettings, checkModelData)
+  }
+
+  /**
+   * Fits `selector` on `data` and checks that the "filtered" output column matches the
+   * "expected" column row by row (absolute tolerance 1e-6). Returns the fitted model.
+   */
+  private def testSelector(selector: VarianceThresholdSelector, data: Dataset[_]):
+    VarianceThresholdSelectorModel = {
+    val selectorModel = selector.fit(data)
+    testTransformer[(Int, Vector, Vector)](data.toDF(), selectorModel,
+      "filtered", "expected") {
+      case Row(vec1: Vector, vec2: Vector) =>
+        assert(vec1 ~== vec2 absTol 1e-6)
+    }
+    selectorModel
+  }
+}
+
+object VarianceThresholdSelectorSuite {
+
+  /**
+   * Valid values for every Param, each differing from its default. Used by tests that must
+   * touch all Params, e.g. save/load. Input columns are left out to keep those tests simple.
+   */
+  val allParamSettings: Map[String, Any] =
+    Map("varianceThreshold" -> 0.12, "outputCol" -> "myOutput")
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org