Posted to commits@spark.apache.org by jk...@apache.org on 2015/07/09 19:26:42 UTC

spark git commit: [SPARK-8703] [ML] Add CountVectorizer as an ML transformer to convert documents to word-count vectors

Repository: spark
Updated Branches:
  refs/heads/master c59e268d1 -> 0cd84c86c


[SPARK-8703] [ML] Add CountVectorizer as an ML transformer to convert documents to word-count vectors

jira: https://issues.apache.org/jira/browse/SPARK-8703

Converts a text document to a sparse vector of token counts.
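
For reference, a minimal usage sketch against the API introduced in this patch (the column names and toy data are illustrative; sqlContext is assumed to be in scope, as in the test suite below):

  val cv = new CountVectorizerModel(Array("a", "b", "c"))
    .setInputCol("words")      // input: ArrayType(StringType) column of tokens
    .setOutputCol("features")  // output: sparse vector of per-document term counts
  val df = sqlContext.createDataFrame(Seq(
    (0, Seq("a", "b", "b", "c", "a"))
  )).toDF("id", "words")
  cv.transform(df).select("features").show()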

I can further add an estimator to extract the vocabulary from a corpus if that's appropriate.
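
For illustration only, such an estimator could build the vocabulary from corpus-wide term counts along these lines (a rough sketch, not part of this patch; fitVocabulary, the vocabSize cutoff, and the RDD-based counting are all assumptions):

  // Hypothetical helper: keep the vocabSize most frequent terms in the corpus.
  def fitVocabulary(dataset: org.apache.spark.sql.DataFrame,
                    inputCol: String,
                    vocabSize: Int): Array[String] = {
    dataset.select(inputCol).rdd
      .flatMap(_.getAs[Seq[String]](0)) // flatten documents into a stream of terms
      .map(term => (term, 1L))
      .reduceByKey(_ + _)               // corpus-wide term counts
      .sortBy(-_._2)                    // most frequent first
      .take(vocabSize)
      .map(_._1)
  }
  // e.g. new CountVectorizerModel(fitVocabulary(df, "words", 10000))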

Author: Yuhao Yang <hh...@gmail.com>

Closes #7084 from hhbyyh/countVectorization and squashes the following commits:

5f3f655 [Yuhao Yang] text change
24728e4 [Yuhao Yang] style improvement
576728a [Yuhao Yang] rename to model and some fix
1deca28 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into countVectorization
99b0c14 [Yuhao Yang] undo extension from HashingTF
12c2dc8 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into countVectorization
7ee1c31 [Yuhao Yang] extends HashingTF
809fb59 [Yuhao Yang] minor fix for ut
7c61fb3 [Yuhao Yang] add countVectorizer


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0cd84c86
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0cd84c86
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0cd84c86

Branch: refs/heads/master
Commit: 0cd84c86cac68600a74d84e50ad40c0c8b84822a
Parents: c59e268
Author: Yuhao Yang <hh...@gmail.com>
Authored: Thu Jul 9 10:26:38 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Jul 9 10:26:38 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/feature/CountVectorizerModel.scala | 82 ++++++++++++++++++++
 .../spark/ml/feature/CountVectorizorSuite.scala | 73 +++++++++++++++++
 2 files changed, 155 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0cd84c86/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
new file mode 100644
index 0000000..6b77de8
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.feature
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector}
+import org.apache.spark.sql.types.{StringType, ArrayType, DataType}
+
+/**
+ * :: Experimental ::
+ * Converts a text document to a sparse vector of token counts.
+ * @param vocabulary An array of terms. Only terms in the vocabulary will be counted.
+ */
+@Experimental
+class CountVectorizerModel (override val uid: String, val vocabulary: Array[String])
+  extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] {
+
+  def this(vocabulary: Array[String]) =
+    this(Identifiable.randomUID("cntVec"), vocabulary)
+
+  /**
+   * Filter to ignore rare words in a document. For each document, terms with a
+   * frequency (count) less than the given threshold are ignored.
+   * Default: 1
+   * @group param
+   */
+  val minTermFreq: IntParam = new IntParam(this, "minTermFreq",
+    "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " +
+      "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1))
+
+  /** @group setParam */
+  def setMinTermFreq(value: Int): this.type = set(minTermFreq, value)
+
+  /** @group getParam */
+  def getMinTermFreq: Int = $(minTermFreq)
+
+  setDefault(minTermFreq -> 1)
+
+  override protected def createTransformFunc: Seq[String] => Vector = {
+    val dict = vocabulary.zipWithIndex.toMap
+    document =>
+      val termCounts = mutable.HashMap.empty[Int, Double]
+      document.foreach { term =>
+        dict.get(term) match {
+          case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0)
+          case None => // ignore terms not in the vocabulary
+        }
+      }
+      Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq)
+  }
+
+  override protected def validateInputType(inputType: DataType): Unit = {
+    require(inputType.sameType(ArrayType(StringType)),
+      s"Input type must be ArrayType(StringType) but got $inputType.")
+  }
+
+  override protected def outputDataType: DataType = new VectorUDT()
+
+  override def copy(extra: ParamMap): CountVectorizerModel = {
+    val copied = new CountVectorizerModel(uid, vocabulary)
+    copyValues(copied, extra)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0cd84c86/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala
new file mode 100644
index 0000000..e90d9d4
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("params") {
+    ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
+  }
+
+  test("CountVectorizerModel common cases") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, "a b c d".split(" ").toSeq,
+        Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
+      (1, "a b b c d  a".split(" ").toSeq,
+        Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))),
+      (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq((0, 1.0)))),
+      (3, "".split(" ").toSeq, Vectors.sparse(4, Seq())), // empty string
+      (4, "a notInDict d".split(" ").toSeq,
+        Vectors.sparse(4, Seq((0, 1.0), (3, 1.0))))  // with words not in vocabulary
+    )).toDF("id", "words", "expected")
+    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+      .setInputCol("words")
+      .setOutputCol("features")
+    val output = cv.transform(df).collect()
+    output.foreach { p =>
+      val features = p.getAs[Vector]("features")
+      val expected = p.getAs[Vector]("expected")
+      assert(features ~== expected absTol 1e-14)
+    }
+  }
+
+  test("CountVectorizerModel with minTermFreq") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
+      (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))),
+      (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())),
+      (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq())))
+    ).toDF("id", "words", "expected")
+    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setMinTermFreq(3)
+    val output = cv.transform(df).collect()
+    output.foreach { p =>
+      val features = p.getAs[Vector]("features")
+      val expected = p.getAs[Vector]("expected")
+      assert(features ~== expected absTol 1e-14)
+    }
+  }
+}
+
+


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org