You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2016/04/10 10:13:28 UTC

spark git commit: [SPARK-14497][ML] Use top instead of sortBy() to get top N frequent words as dict in CountVectorizer

Repository: spark
Updated Branches:
  refs/heads/master 22014e6fb -> f4344582b


[SPARK-14497][ML] Use top instead of sortBy() to get top N frequent words as dict in CountVectorizer

## What changes were proposed in this pull request?

Replace sortBy() with top() to calculate the top N frequent words as dictionary.

## How was this patch tested?
Existing unit tests. The terms with the same TF would be sorted in descending order, but their relative order is not deterministic. The test would fail if terms with the same TF, like "c" and "d", were hardcoded into the expected dictionary.

Author: fwang1 <de...@gmail.com>

Closes #12265 from lionelfeng/master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f4344582
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f4344582
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f4344582

Branch: refs/heads/master
Commit: f4344582ba28983bf3892d08e11236f090f5bf92
Parents: 22014e6
Author: fwang1 <de...@gmail.com>
Authored: Sun Apr 10 01:13:25 2016 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Sun Apr 10 01:13:25 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/CountVectorizer.scala | 14 ++++----------
 .../spark/ml/feature/CountVectorizerSuite.scala       |  7 ++++---
 2 files changed, 8 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f4344582/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index f1be971..00abbbe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -170,16 +170,10 @@ class CountVectorizer(override val uid: String)
       (word, count)
     }.cache()
     val fullVocabSize = wordCounts.count()
-    val vocab: Array[String] = {
-      val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) {
-        // Use all terms
-        wordCounts.collect().sortBy(-_._2)
-      } else {
-        // Sort terms to select vocab
-        wordCounts.sortBy(_._2, ascending = false).take(vocSize)
-      }
-      tmpSortedWC.map(_._1)
-    }
+
+    val vocab = wordCounts
+      .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2))
+      .map(_._1)
 
     require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
     copyValues(new CountVectorizerModel(uid, vocab).setParent(this))

http://git-wip-us.apache.org/repos/asf/spark/blob/f4344582/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index ff0de06..7641e3b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -59,14 +59,15 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
       (0, split("a b c d e"),
         Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
       (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
-      (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))),
-      (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
+      (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))),
+      (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))),
+      (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
     ).toDF("id", "words", "expected")
     val cv = new CountVectorizer()
       .setInputCol("words")
       .setOutputCol("features")
       .fit(df)
-    assert(cv.vocabulary === Array("a", "b", "c", "d", "e"))
+    assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e"))
 
     cv.transform(df).select("features", "expected").collect().foreach {
       case Row(features: Vector, expected: Vector) =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org