You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/10/02 19:26:04 UTC

spark git commit: [SPARK-6530] [ML] Add chi-square selector for ml package

Repository: spark
Updated Branches:
  refs/heads/master 23a9448c0 -> 633aaae0a


[SPARK-6530] [ML] Add chi-square selector for ml package

See JIRA [here](https://issues.apache.org/jira/browse/SPARK-6530).

Author: Xusen Yin <yi...@gmail.com>

Closes #5742 from yinxusen/SPARK-6530.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/633aaae0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/633aaae0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/633aaae0

Branch: refs/heads/master
Commit: 633aaae0a1e31e9ba634423840e350b22342c6b5
Parents: 23a9448
Author: Xusen Yin <yi...@gmail.com>
Authored: Fri Oct 2 10:25:58 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Fri Oct 2 10:25:58 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/ChiSqSelector.scala | 150 +++++++++++++++++++
 .../spark/mllib/feature/ChiSqSelector.scala     |   2 +
 .../spark/ml/feature/ChiSqSelectorSuite.scala   |  61 ++++++++
 3 files changed, 213 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/633aaae0/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
new file mode 100644
index 0000000..5e4061f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml._
+import org.apache.spark.ml.attribute.{AttributeGroup, _}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+
+/**
+ * Shared parameters of [[ChiSqSelector]] and [[ChiSqSelectorModel]]: the features, label and
+ * output column names plus how many of the top-ranked features to keep.
+ */
+private[feature] trait ChiSqSelectorParams
+  extends Params with HasFeaturesCol with HasOutputCol with HasLabelCol {
+
+  /**
+   * How many features to keep, ranked by Chi-Squared statistic (largest first). Must be >= 1;
+   * when the input has fewer features than this, every feature is kept. Defaults to 50.
+   * @group param
+   */
+  final val numTopFeatures: IntParam = new IntParam(this, "numTopFeatures",
+    "Number of features that selector will select, ordered by statistics value descending. If the" +
+      " number of features is < numTopFeatures, then this will select all features.",
+    ParamValidators.gtEq(1))
+  setDefault(numTopFeatures, 50)
+
+  /** @group getParam */
+  def getNumTopFeatures: Int = getOrDefault(numTopFeatures)
+}
+
+/**
+ * :: Experimental ::
+ * Estimator that scores each categorical feature with a Chi-Squared test against a categorical
+ * label and keeps the `numTopFeatures` highest-scoring ones, producing a [[ChiSqSelectorModel]].
+ */
+@Experimental
+final class ChiSqSelector(override val uid: String)
+  extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams {
+
+  def this() = this(Identifiable.randomUID("chiSqSelector"))
+
+  /** @group setParam */
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /** @group setParam */
+  def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
+
+  /** Validates the schema, then delegates the fitting to the RDD-based mllib selector. */
+  override def fit(dataset: DataFrame): ChiSqSelectorModel = {
+    transformSchema(dataset.schema, logging = true)
+    val labeled = dataset.select($(labelCol), $(featuresCol)).map {
+      case Row(label: Double, features: Vector) => LabeledPoint(label, features)
+    }
+    val mllibModel = new feature.ChiSqSelector($(numTopFeatures)).fit(labeled)
+    copyValues(new ChiSqSelectorModel(uid, mllibModel).setParent(this))
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+    SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+  }
+
+  override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra)
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by [[ChiSqSelector]]. Applies the fitted mllib selector to the features column and
+ * carries the attribute metadata of the selected features over to the output column.
+ */
+@Experimental
+final class ChiSqSelectorModel private[ml] (
+    override val uid: String,
+    private val chiSqSelector: feature.ChiSqSelectorModel)
+  extends Model[ChiSqSelectorModel] with ChiSqSelectorParams {
+
+  /** @group setParam */
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /** @group setParam */
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    val transformedSchema = transformSchema(dataset.schema, logging = true)
+    val newField = transformedSchema.last
+    val selector = udf { chiSqSelector.transform _ }
+    dataset.withColumn($(outputCol), selector(col($(featuresCol))), newField.metadata)
+  }
+
+  /** Appends the output column, rejecting a name that already exists in the input schema. */
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    require(!schema.fieldNames.contains($(outputCol)), s"Column ${$(outputCol)} already exists.")
+    StructType(schema.fields :+ prepOutputField(schema))
+  }
+
+  /**
+   * Prepare the output column field. Existing feature attributes are kept for selected indices;
+   * without input metadata every selected feature is marked nominal (categorical).
+   */
+  private def prepOutputField(schema: StructType): StructField = {
+    val selected = chiSqSelector.selectedFeatures.toSet
+    val origAttrGroup = AttributeGroup.fromStructField(schema($(featuresCol)))
+    val featureAttributes: Array[Attribute] = origAttrGroup.attributes match {
+      case Some(attrs) => attrs.zipWithIndex.collect { case (a, i) if selected.contains(i) => a }
+      case None => Array.fill[Attribute](selected.size)(NominalAttribute.defaultAttr)
+    }
+    new AttributeGroup($(outputCol), featureAttributes).toStructField()
+  }
+
+  override def copy(extra: ParamMap): ChiSqSelectorModel = {
+    val copied = new ChiSqSelectorModel(uid, chiSqSelector)
+    copyValues(copied, extra).setParent(parent)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/633aaae0/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 4743cfd..b1524cf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -109,6 +109,8 @@ class ChiSqSelectorModel @Since("1.3.0") (
  * Creates a ChiSquared feature selector.
  * @param numTopFeatures number of features that selector will select
  *                       (ordered by statistic value descending)
+ *                       Note that if the number of features is < numTopFeatures, then this will
+ *                       select all features.
  */
 @Since("1.3.0")
 @Experimental

http://git-wip-us.apache.org/repos/asf/spark/blob/633aaae0/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
new file mode 100644
index 0000000..e5a4296
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{Row, SQLContext}
+
+class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
+  test("Test Chi-Square selector") {
+    val sqlContext = SQLContext.getOrCreate(sc)
+    import sqlContext.implicits._
+
+    val data = Seq(
+      LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
+      LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
+      LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
+      LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))
+    )
+
+    // With numTopFeatures = 1 only the last feature (index 2) is expected to be kept.
+    val preFilteredData = Seq(0.0, 6.0, 8.0, 5.0).map(v => Vectors.dense(v))
+
+    val df = sc.parallelize(data.zip(preFilteredData))
+      .map { case (lp, expected) => (lp.label, lp.features, expected) }
+      .toDF("label", "data", "preFilteredData")
+
+    val selector = new ChiSqSelector()
+      .setNumTopFeatures(1)
+      .setFeaturesCol("data")
+      .setLabelCol("label")
+      .setOutputCol("filtered")
+
+    val results = selector.fit(df).transform(df)
+      .select("filtered", "preFilteredData").collect()
+    // Guard against a vacuous pass: every input row must be compared.
+    assert(results.length === data.length)
+    results.foreach {
+      case Row(vec1: Vector, vec2: Vector) =>
+        assert(vec1 ~== vec2 absTol 1e-1)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org