Posted to commits@spark.apache.org by jk...@apache.org on 2015/12/08 19:29:55 UTC

spark git commit: [SPARK-10393] use ML pipeline in LDA example

Repository: spark
Updated Branches:
  refs/heads/master 5d96a710a -> 872a2ee28


[SPARK-10393] use ML pipeline in LDA example

jira: https://issues.apache.org/jira/browse/SPARK-10393

Since the text-processing logic has been moved into ML estimators/transformers, replace the related code in the LDA example with an ML pipeline.
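
In essence, the preprocessing in the updated createCorpus() reduces to a three-stage spark.ml pipeline (RegexTokenizer -> StopWordsRemover -> CountVectorizer). A minimal, self-contained sketch of that approach follows; the input path, app name, and vocabulary size are placeholders rather than values from the example (the full change is in the diff below):

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.{SparkConf, SparkContext}

    object LDAPipelineSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("LDAPipelineSketch"))
        val sqlContext = SQLContext.getOrCreate(sc)
        import sqlContext.implicits._

        // One document per line; "corpus/*.txt" is a placeholder input path.
        val df = sc.textFile("corpus/*.txt").toDF("docs")

        // Tokenize, drop stop words, and build term-count vectors, mirroring
        // the column names used in the updated example.
        val tokenizer = new RegexTokenizer().setInputCol("docs").setOutputCol("rawTokens")
        val stopWordsRemover = new StopWordsRemover().setInputCol("rawTokens").setOutputCol("tokens")
        val countVectorizer = new CountVectorizer()
          .setVocabSize(10000) // illustrative vocabulary size
          .setInputCol("tokens")
          .setOutputCol("features")

        val pipeline = new Pipeline().setStages(Array(tokenizer, stopWordsRemover, countVectorizer))
        val model = pipeline.fit(df)

        // The fitted CountVectorizer stage exposes the learned vocabulary.
        val vocab = model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary
        println(s"vocabulary size: ${vocab.length}")

        sc.stop()
      }
    }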

Author: Yuhao Yang <hh...@gmail.com>
Author: yuhaoyang <yu...@zhanglipings-iMac.local>

Closes #8551 from hhbyyh/ldaExUpdate.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/872a2ee2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/872a2ee2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/872a2ee2

Branch: refs/heads/master
Commit: 872a2ee281d84f40a786f765bf772cdb06e8c956
Parents: 5d96a71
Author: Yuhao Yang <hh...@gmail.com>
Authored: Tue Dec 8 10:29:51 2015 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Dec 8 10:29:51 2015 -0800

----------------------------------------------------------------------
 .../spark/examples/mllib/LDAExample.scala       | 153 +++++--------------
 1 file changed, 40 insertions(+), 113 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/872a2ee2/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index 75b0f69..70010b0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -18,19 +18,16 @@
 // scalastyle:off println
 package org.apache.spark.examples.mllib
 
-import java.text.BreakIterator
-
-import scala.collection.mutable
-
 import scopt.OptionParser
 
 import org.apache.log4j.{Level, Logger}
-
-import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA}
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
+import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
-
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.{SparkConf, SparkContext}
 
 /**
  * An example Latent Dirichlet Allocation (LDA) app. Run with
@@ -192,115 +189,45 @@ object LDAExample {
       vocabSize: Int,
       stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = {
 
+    val sqlContext = SQLContext.getOrCreate(sc)
+    import sqlContext.implicits._
+
     // Get dataset of document texts
     // One document per line in each text file. If the input consists of many small files,
     // this can result in a large number of small partitions, which can degrade performance.
     // In this case, consider using coalesce() to create fewer, larger partitions.
-    val textRDD: RDD[String] = sc.textFile(paths.mkString(","))
-
-    // Split text into words
-    val tokenizer = new SimpleTokenizer(sc, stopwordFile)
-    val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) =>
-      id -> tokenizer.getWords(text)
-    }
-    tokenized.cache()
-
-    // Counts words: RDD[(word, wordCount)]
-    val wordCounts: RDD[(String, Long)] = tokenized
-      .flatMap { case (_, tokens) => tokens.map(_ -> 1L) }
-      .reduceByKey(_ + _)
-    wordCounts.cache()
-    val fullVocabSize = wordCounts.count()
-    // Select vocab
-    //  (vocab: Map[word -> id], total tokens after selecting vocab)
-    val (vocab: Map[String, Int], selectedTokenCount: Long) = {
-      val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) {
-        // Use all terms
-        wordCounts.collect().sortBy(-_._2)
-      } else {
-        // Sort terms to select vocab
-        wordCounts.sortBy(_._2, ascending = false).take(vocabSize)
-      }
-      (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum)
-    }
-
-    val documents = tokenized.map { case (id, tokens) =>
-      // Filter tokens by vocabulary, and create word count vector representation of document.
-      val wc = new mutable.HashMap[Int, Int]()
-      tokens.foreach { term =>
-        if (vocab.contains(term)) {
-          val termIndex = vocab(term)
-          wc(termIndex) = wc.getOrElse(termIndex, 0) + 1
-        }
-      }
-      val indices = wc.keys.toArray.sorted
-      val values = indices.map(i => wc(i).toDouble)
-
-      val sb = Vectors.sparse(vocab.size, indices, values)
-      (id, sb)
-    }
-
-    val vocabArray = new Array[String](vocab.size)
-    vocab.foreach { case (term, i) => vocabArray(i) = term }
-
-    (documents, vocabArray, selectedTokenCount)
-  }
-}
-
-/**
- * Simple Tokenizer.
- *
- * TODO: Formalize the interface, and make this a public class in mllib.feature
- */
-private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable {
-
-  private val stopwords: Set[String] = if (stopwordFile.isEmpty) {
-    Set.empty[String]
-  } else {
-    val stopwordText = sc.textFile(stopwordFile).collect()
-    stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet
-  }
-
-  // Matches sequences of Unicode letters
-  private val allWordRegex = "^(\\p{L}*)$".r
-
-  // Ignore words shorter than this length.
-  private val minWordLength = 3
-
-  def getWords(text: String): IndexedSeq[String] = {
-
-    val words = new mutable.ArrayBuffer[String]()
-
-    // Use Java BreakIterator to tokenize text into words.
-    val wb = BreakIterator.getWordInstance
-    wb.setText(text)
-
-    // current,end index start,end of each word
-    var current = wb.first()
-    var end = wb.next()
-    while (end != BreakIterator.DONE) {
-      // Convert to lowercase
-      val word: String = text.substring(current, end).toLowerCase
-      // Remove short words and strings that aren't only letters
-      word match {
-        case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) =>
-          words += w
-        case _ =>
-      }
-
-      current = end
-      try {
-        end = wb.next()
-      } catch {
-        case e: Exception =>
-          // Ignore remaining text in line.
-          // This is a known bug in BreakIterator (for some Java versions),
-          // which fails when it sees certain characters.
-          end = BreakIterator.DONE
-      }
+    val df = sc.textFile(paths.mkString(",")).toDF("docs")
+    val customizedStopWords: Array[String] = if (stopwordFile.isEmpty) {
+      Array.empty[String]
+    } else {
+      val stopWordText = sc.textFile(stopwordFile).collect()
+      stopWordText.flatMap(_.stripMargin.split("\\s+"))
     }
-    words
+    val tokenizer = new RegexTokenizer()
+      .setInputCol("docs")
+      .setOutputCol("rawTokens")
+    val stopWordsRemover = new StopWordsRemover()
+      .setInputCol("rawTokens")
+      .setOutputCol("tokens")
+    stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
+    val countVectorizer = new CountVectorizer()
+      .setVocabSize(vocabSize)
+      .setInputCol("tokens")
+      .setOutputCol("features")
+
+    val pipeline = new Pipeline()
+      .setStages(Array(tokenizer, stopWordsRemover, countVectorizer))
+
+    val model = pipeline.fit(df)
+    val documents = model.transform(df)
+      .select("features")
+      .map { case Row(features: Vector) => features }
+      .zipWithIndex()
+      .map(_.swap)
+
+    (documents,
+      model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary,  // vocabulary
+      documents.map(_._2.numActives).sum().toLong) // total token count
   }
-
 }
 // scalastyle:on println
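
A side note on the comment in createCorpus() about many small input files: textFile() produces roughly one partition per small file, and the comment recommends coalesce() to merge them. A minimal sketch of that suggestion, with a purely illustrative path and partition count:

    import org.apache.spark.{SparkConf, SparkContext}

    object CoalesceSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("CoalesceSketch"))
        // "corpus/*.txt" is a placeholder; with many small files, textFile()
        // yields many small partitions, so merge them into a fixed number.
        val textRDD = sc.textFile("corpus/*.txt").coalesce(64) // 64 is illustrative
        println(s"partitions after coalesce: ${textRDD.partitions.length}")
        sc.stop()
      }
    }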


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org