You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/05/22 03:04:52 UTC

spark git commit: [SPARK-7219] [MLLIB] Output feature attributes in HashingTF

Repository: spark
Updated Branches:
  refs/heads/master f5db4b416 -> 85b96372c


[SPARK-7219] [MLLIB] Output feature attributes in HashingTF

This PR updates `HashingTF` to output ML attributes that tell the number of features in the output column. We need to expand `UnaryTransformer` to support output metadata. A `def outputMetadata: Metadata` is not sufficient because the metadata may also depend on the input data. Though this is not true for `HashingTF`, I think it is reasonable to update `UnaryTransformer` in a separate PR. `checkParams` is added to verify common requirements for params. I will send a separate PR to use it in other test suites. jkbradley

Author: Xiangrui Meng <me...@databricks.com>

Closes #6308 from mengxr/SPARK-7219 and squashes the following commits:

9bd2922 [Xiangrui Meng] address comments
e82a68a [Xiangrui Meng] remove sqlContext from test suite
995535b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7219
2194703 [Xiangrui Meng] add test for attributes
178ae23 [Xiangrui Meng] update HashingTF with tests
91a6106 [Xiangrui Meng] WIP


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/85b96372
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/85b96372
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/85b96372

Branch: refs/heads/master
Commit: 85b96372cf0fd055f89fc639f45c1f2cb02a378f
Parents: f5db4b4
Author: Xiangrui Meng <me...@databricks.com>
Authored: Thu May 21 18:04:45 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu May 21 18:04:45 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/HashingTF.scala | 34 +++++++++---
 .../spark/ml/feature/HashingTFSuite.scala       | 55 ++++++++++++++++++++
 .../org/apache/spark/ml/param/ParamsSuite.scala | 20 +++++++
 3 files changed, 101 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/85b96372/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 30033ce..8942d45 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -18,22 +18,31 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.param.{IntParam, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{udf, col}
+import org.apache.spark.sql.types.{ArrayType, StructType}
 
 /**
  * :: AlphaComponent ::
  * Maps a sequence of terms to their term frequencies using the hashing trick.
  */
 @AlphaComponent
-class HashingTF(override val uid: String) extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
+class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol {
 
   def this() = this(Identifiable.randomUID("hashingTF"))
 
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
   /**
    * Number of features.  Should be > 0.
    * (default = 2^18^)
@@ -50,10 +59,19 @@ class HashingTF(override val uid: String) extends UnaryTransformer[Iterable[_],
   /** @group setParam */
   def setNumFeatures(value: Int): this.type = set(numFeatures, value)
 
-  override protected def createTransformFunc: Iterable[_] => Vector = {
+  override def transform(dataset: DataFrame): DataFrame = {
+    val outputSchema = transformSchema(dataset.schema)
     val hashingTF = new feature.HashingTF($(numFeatures))
-    hashingTF.transform
+    val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
+    val metadata = outputSchema($(outputCol)).metadata
+    dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
   }
 
-  override protected def outputDataType: DataType = new VectorUDT()
+  override def transformSchema(schema: StructType): StructType = {
+    val inputType = schema($(inputCol)).dataType
+    require(inputType.isInstanceOf[ArrayType],
+      s"The input column must be ArrayType, but got $inputType.")
+    val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
+    SchemaUtils.appendColumn(schema, attrGroup.toStructField())
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/85b96372/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
new file mode 100644
index 0000000..2e4beb0
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
+
+class HashingTFSuite extends FunSuite with MLlibTestSparkContext {
+
+  test("params") {
+    val hashingTF = new HashingTF
+    ParamsSuite.checkParams(hashingTF, 3)
+  }
+
+  test("hashingTF") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, "a a b b c d".split(" ").toSeq)
+    )).toDF("id", "words")
+    val n = 100
+    val hashingTF = new HashingTF()
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setNumFeatures(n)
+    val output = hashingTF.transform(df)
+    val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
+    require(attrGroup.numAttributes === Some(n))
+    val features = output.select("features").first().getAs[Vector](0)
+    // Assume perfect hash on "a", "b", "c", and "d".
+    def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n)
+    val expected = Vectors.sparse(n,
+      Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
+    assert(features ~== expected absTol 1e-14)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/85b96372/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index b96874f..d270ad7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -201,3 +201,23 @@ class ParamsSuite extends FunSuite {
     assert(inArray(1) && inArray(2) && !inArray(0))
   }
 }
+
+object ParamsSuite extends FunSuite {
+
+  /**
+   * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered
+   * by names; 3) param parent has the same UID as the object's UID; 4) param name is the same as
+   * the param method name.
+   */
+  def checkParams(obj: Params, expectedNumParams: Int): Unit = {
+    val params = obj.params
+    require(params.length === expectedNumParams,
+      s"Expect $expectedNumParams params but got ${params.length}: ${params.map(_.name).toSeq}.")
+    val paramNames = params.map(_.name)
+    require(paramNames === paramNames.sorted)
+    params.foreach { p =>
+      assert(p.parent === obj.uid)
+      assert(obj.getParam(p.name) === p)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org