Posted to commits@spark.apache.org by jk...@apache.org on 2015/05/04 09:06:31 UTC

spark git commit: [SPARK-5563] [MLLIB] LDA with online variational inference

Repository: spark
Updated Branches:
  refs/heads/master 9646018bb -> 3539cb7d2


[SPARK-5563] [MLLIB] LDA with online variational inference

JIRA: https://issues.apache.org/jira/browse/SPARK-5563
This PR implements [Online LDA](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf), based on the research of Matt Hoffman and David M. Blei, which gives LDA users an efficient alternative optimizer. The algorithm's main advantages are compatibility with streaming data and low time and memory consumption, because each iteration processes only a subset of the corpus. For more details, please refer to the JIRA.

Online LDA can act as a fast option for LDA, and will be especially helpful for users who need a quick result or who have a large corpus.
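
As a rough illustration of the corpus split: the patch samples a fraction of the documents (miniBatchFraction, default 0.05) on each iteration and recommends choosing maxIterations so that maxIterations * miniBatchFraction >= 1. The sketch below is not part of the patch; the object and method names are hypothetical, and it only does the arithmetic. Because the patch samples with replacement, the rule covers the corpus in expectation rather than guaranteeing that every document is seen.

```scala
// Hypothetical helper, not part of the patch: arithmetic behind the
// "maxIterations * miniBatchFraction >= 1" guidance.
object MiniBatchCoverage {

  /** Smallest maxIterations satisfying maxIterations * miniBatchFraction >= 1. */
  def minIterations(miniBatchFraction: Double): Int = {
    require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0,
      s"miniBatchFraction must be in (0, 1], but was $miniBatchFraction")
    math.ceil(1.0 / miniBatchFraction).toInt
  }

  /** Expected mini-batch size, estimated the same way the patch avoids batch.count. */
  def expectedBatchSize(corpusSize: Long, miniBatchFraction: Double): Int =
    math.ceil(miniBatchFraction * corpusSize).toInt
}
```

For example, MiniBatchCoverage.minIterations(0.5) gives 2, so a run with miniBatchFraction = 0.5 needs at least two iterations to touch a corpus-sized number of documents.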

Correctness test:
I tested this PR against https://github.com/Blei-Lab/onlineldavb and the results are identical. The results and code are uploaded to https://github.com/hhbyyh/LDACrossValidation.

Author: Yuhao Yang <hh...@gmail.com>
Author: Joseph K. Bradley <jo...@databricks.com>

Closes #4419 from hhbyyh/ldaonline and squashes the following commits:

1045eec [Yuhao Yang] Merge pull request #2 from jkbradley/hhbyyh-ldaonline2
cf376ff [Joseph K. Bradley] For private vars needed for testing, I made them private and added accessors.  Java doesn’t understand package-private tags, so this minimizes the issues Java users might encounter.
6149ca6 [Yuhao Yang] fix for setOptimizer
cf0007d [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
54cf8da [Yuhao Yang] some style change
68c2318 [Yuhao Yang] add a java ut
4041723 [Yuhao Yang] add ut
138bfed [Yuhao Yang] Merge pull request #1 from jkbradley/hhbyyh-ldaonline-update
9e910d9 [Joseph K. Bradley] small fix
61d60df [Joseph K. Bradley] Minor cleanups: * Update *Concentration parameter documentation * EM Optimizer: createVertices() does not need to be a function * OnlineLDAOptimizer: typos in doc * Clean up the core code for online LDA (Scala style)
a996a82 [Yuhao Yang] respond to comments
b1178cf [Yuhao Yang] fit into the optimizer framework
dbe3cff [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
15be071 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
b29193b [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
d19ef55 [Yuhao Yang] change OnlineLDA to class
97b9e1a [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
e7bf3b0 [Yuhao Yang] move to seperate file
f367cc9 [Yuhao Yang] change to optimization
8cb16a6 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
62405cc [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
02d0373 [Yuhao Yang] fix style in comment
f6d47ca [Yuhao Yang] Merge branch 'ldaonline' of https://github.com/hhbyyh/spark into ldaonline
d86cdec [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
a570c9a [Yuhao Yang] use sample to pick up batch
4a3f27e [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
e271eb1 [Yuhao Yang] remove non ascii
581c623 [Yuhao Yang] seperate API and adjust batch split
37af91a [Yuhao Yang] iMerge remote-tracking branch 'upstream/master' into ldaonline
20328d1 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline i
aa365d1 [Yuhao Yang] merge upstream master
3a06526 [Yuhao Yang] merge with new example
0dd3947 [Yuhao Yang] kMerge remote-tracking branch 'upstream/master' into ldaonline
0d0f3ee [Yuhao Yang] replace random split with sliding
fa408a8 [Yuhao Yang] ssMerge remote-tracking branch 'upstream/master' into ldaonline
45884ab [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline s
f41c5ca [Yuhao Yang] style fix
26dca1b [Yuhao Yang] style fix and make class private
043e786 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline s Conflicts: 	mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
d640d9c [Yuhao Yang] online lda initial checkin


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3539cb7d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3539cb7d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3539cb7d

Branch: refs/heads/master
Commit: 3539cb7d20f5f878132407ec3b854011b183b2ad
Parents: 9646018
Author: Yuhao Yang <hh...@gmail.com>
Authored: Mon May 4 00:06:25 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon May 4 00:06:25 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/mllib/clustering/LDA.scala |  65 ++--
 .../spark/mllib/clustering/LDAOptimizer.scala   | 320 +++++++++++++++++--
 .../spark/mllib/clustering/JavaLDASuite.java    |  38 ++-
 .../spark/mllib/clustering/LDASuite.scala       |  89 +++++-
 4 files changed, 438 insertions(+), 74 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3539cb7d/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index 37bf88b..c8daa23 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -78,35 +78,29 @@ class LDA private (
    *
    * This is the parameter to a symmetric Dirichlet distribution.
    */
-  def getDocConcentration: Double = {
-    if (this.docConcentration == -1) {
-      (50.0 / k) + 1.0
-    } else {
-      this.docConcentration
-    }
-  }
+  def getDocConcentration: Double = this.docConcentration
 
   /**
    * Concentration parameter (commonly named "alpha") for the prior placed on documents'
    * distributions over topics ("theta").
    *
-   * This is the parameter to a symmetric Dirichlet distribution.
+   * This is the parameter to a symmetric Dirichlet distribution, where larger values
+   * mean more smoothing (more regularization).
    *
-   * This value should be > 1.0, where larger values mean more smoothing (more regularization).
    * If set to -1, then docConcentration is set automatically.
    *  (default = -1 = automatic)
    *
-   * Automatic setting of parameter:
-   *  - For EM: default = (50 / k) + 1.
-   *     - The 50/k is common in LDA libraries.
-   *     - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM.
-   *
-   * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
-   *       but values in (0,1) are not yet supported.
+   * Optimizer-specific parameter settings:
+   *  - EM
+   *     - Value should be > 1.0
+   *     - default = (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows
+   *       Asuncion et al. (2009), who recommend a +1 adjustment for EM.
+   *  - Online
+   *     - Value should be >= 0
+   *     - default = (1.0 / k), following the implementation from
+   *       [[https://github.com/Blei-Lab/onlineldavb]].
    */
   def setDocConcentration(docConcentration: Double): this.type = {
-    require(docConcentration > 1.0 || docConcentration == -1.0,
-      s"LDA docConcentration must be > 1.0 (or -1 for auto), but was set to $docConcentration")
     this.docConcentration = docConcentration
     this
   }
@@ -126,13 +120,7 @@ class LDA private (
    * Note: The topics' distributions over terms are called "beta" in the original LDA paper
    * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
    */
-  def getTopicConcentration: Double = {
-    if (this.topicConcentration == -1) {
-      1.1
-    } else {
-      this.topicConcentration
-    }
-  }
+  def getTopicConcentration: Double = this.topicConcentration
 
   /**
    * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
@@ -143,21 +131,20 @@ class LDA private (
    * Note: The topics' distributions over terms are called "beta" in the original LDA paper
    * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
    *
-   * This value should be > 0.0.
    * If set to -1, then topicConcentration is set automatically.
    *  (default = -1 = automatic)
    *
-   * Automatic setting of parameter:
-   *  - For EM: default = 0.1 + 1.
-   *     - The 0.1 gives a small amount of smoothing.
-   *     - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM.
-   *
-   * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
-   *       but values in (0,1) are not yet supported.
+   * Optimizer-specific parameter settings:
+   *  - EM
+   *     - Value should be > 1.0
+   *     - default = 0.1 + 1, where 0.1 gives a small amount of smoothing and +1 follows
+   *       Asuncion et al. (2009), who recommend a +1 adjustment for EM.
+   *  - Online
+   *     - Value should be >= 0
+   *     - default = (1.0 / k), following the implementation from
+   *       [[https://github.com/Blei-Lab/onlineldavb]].
    */
   def setTopicConcentration(topicConcentration: Double): this.type = {
-    require(topicConcentration > 1.0 || topicConcentration == -1.0,
-      s"LDA topicConcentration must be > 1.0 (or -1 for auto), but was set to $topicConcentration")
     this.topicConcentration = topicConcentration
     this
   }
@@ -223,14 +210,15 @@ class LDA private (
 
   /**
    * Set the LDAOptimizer used to perform the actual calculation by algorithm name.
-   * Currently "em" is supported.
+   * Currently "em" and "online" are supported.
    */
   def setOptimizer(optimizerName: String): this.type = {
     this.ldaOptimizer =
       optimizerName.toLowerCase match {
         case "em" => new EMLDAOptimizer
+        case "online" => new OnlineLDAOptimizer
         case other =>
-          throw new IllegalArgumentException(s"Only em is supported but got $other.")
+          throw new IllegalArgumentException(s"Only em, online are supported but got $other.")
       }
     this
   }
@@ -245,8 +233,7 @@ class LDA private (
    * @return  Inferred LDA model
    */
   def run(documents: RDD[(Long, Vector)]): LDAModel = {
-    val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration,
-      seed, checkpointInterval)
+    val state = ldaOptimizer.initialize(documents, this)
     var iter = 0
     val iterationTimes = Array.fill[Double](maxIterations)(0)
     while (iter < maxIterations) {
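
With this change, the -1 ("automatic") concentration values are resolved by each optimizer instead of by the LDA getters. The following sketch restates the defaults documented in the diff above as a small stand-alone function. It is an illustration only: the object, trait, and method names are hypothetical and do not exist in MLlib.

```scala
// Hypothetical restatement of the documented defaults; not MLlib API.
object ConcentrationDefaults {
  sealed trait Optimizer
  case object EM extends Optimizer
  case object Online extends Optimizer

  /** docConcentration ("alpha"): -1 means "pick the optimizer's default". */
  def resolveDocConcentration(userValue: Double, k: Int, opt: Optimizer): Double =
    if (userValue != -1.0) userValue
    else opt match {
      case EM     => (50.0 / k) + 1.0 // 50/k plus the +1 adjustment for EM
      case Online => 1.0 / k          // follows Blei-Lab/onlineldavb
    }

  /** topicConcentration ("beta" or "eta"): -1 means "pick the optimizer's default". */
  def resolveTopicConcentration(userValue: Double, k: Int, opt: Optimizer): Double =
    if (userValue != -1.0) userValue
    else opt match {
      case EM     => 1.1     // 0.1 smoothing plus the +1 adjustment for EM
      case Online => 1.0 / k // follows Blei-Lab/onlineldavb
    }
}
```

The sketch leaves out range validation; in the patch, the EM optimizer still rejects explicit values that are not > 1.0.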

http://git-wip-us.apache.org/repos/asf/spark/blob/3539cb7d/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index ffd72a2..093aa0f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -19,13 +19,15 @@ package org.apache.spark.mllib.clustering
 
 import java.util.Random
 
-import breeze.linalg.{DenseVector => BDV, normalize}
+import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron}
+import breeze.numerics.{digamma, exp, abs}
+import breeze.stats.distributions.{Gamma, RandBasis}
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Matrices, SparseVector, DenseVector, Vector}
 import org.apache.spark.rdd.RDD
 
 /**
@@ -35,7 +37,7 @@ import org.apache.spark.rdd.RDD
  * hold optimizer-specific parameters for users to set.
  */
 @Experimental
-trait LDAOptimizer{
+trait LDAOptimizer {
 
   /*
     DEVELOPERS NOTE:
@@ -49,13 +51,7 @@ trait LDAOptimizer{
    * Initializer for the optimizer. LDA passes the common parameters to the optimizer and
    * the internal structure can be initialized properly.
    */
-  private[clustering] def initialState(
-      docs: RDD[(Long, Vector)],
-      k: Int,
-      docConcentration: Double,
-      topicConcentration: Double,
-      randomSeed: Long,
-      checkpointInterval: Int): LDAOptimizer
+  private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer
 
   private[clustering] def next(): LDAOptimizer
 
@@ -80,12 +76,12 @@ trait LDAOptimizer{
  *
  */
 @Experimental
-class EMLDAOptimizer extends LDAOptimizer{
+class EMLDAOptimizer extends LDAOptimizer {
 
   import LDA._
 
   /**
-   * Following fields will only be initialized through initialState method
+   * The following fields will only be initialized through the initialize() method
    */
   private[clustering] var graph: Graph[TopicCounts, TokenCount] = null
   private[clustering] var k: Int = 0
@@ -98,13 +94,23 @@ class EMLDAOptimizer extends LDAOptimizer{
   /**
    * Compute bipartite term/doc graph.
    */
-  private[clustering] override def initialState(
-      docs: RDD[(Long, Vector)],
-      k: Int,
-      docConcentration: Double,
-      topicConcentration: Double,
-      randomSeed: Long,
-      checkpointInterval: Int): LDAOptimizer = {
+  override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = {
+
+    val docConcentration = lda.getDocConcentration
+    val topicConcentration = lda.getTopicConcentration
+    val k = lda.getK
+
+    // Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
+    // but values in (0,1) are not yet supported.
+    require(docConcentration > 1.0 || docConcentration == -1.0, s"LDA docConcentration must be" +
+      s" > 1.0 (or -1 for auto) for EM Optimizer, but was set to $docConcentration")
+    require(topicConcentration > 1.0 || topicConcentration == -1.0, s"LDA topicConcentration " +
+      s"must be > 1.0 (or -1 for auto) for EM Optimizer, but was set to $topicConcentration")
+
+    this.docConcentration = if (docConcentration == -1) (50.0 / k) + 1.0 else docConcentration
+    this.topicConcentration = if (topicConcentration == -1) 1.1 else topicConcentration
+    val randomSeed = lda.getSeed
+
     // For each document, create an edge (Document -> Term) for each unique term in the document.
     val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
       // Add edges for terms with non-zero counts.
@@ -113,11 +119,9 @@ class EMLDAOptimizer extends LDAOptimizer{
       }
     }
 
-    val vocabSize = docs.take(1).head._2.size
-
     // Create vertices.
     // Initially, we use random soft assignments of tokens to topics (random gamma).
-    def createVertices(): RDD[(VertexId, TopicCounts)] = {
+    val docTermVertices: RDD[(VertexId, TopicCounts)] = {
       val verticesTMP: RDD[(VertexId, TopicCounts)] =
         edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
           val random = new Random(partIndex + randomSeed)
@@ -130,22 +134,18 @@ class EMLDAOptimizer extends LDAOptimizer{
       verticesTMP.reduceByKey(_ + _)
     }
 
-    val docTermVertices = createVertices()
-
     // Partition such that edges are grouped by document
     this.graph = Graph(docTermVertices, edges).partitionBy(PartitionStrategy.EdgePartition1D)
     this.k = k
-    this.vocabSize = vocabSize
-    this.docConcentration = docConcentration
-    this.topicConcentration = topicConcentration
-    this.checkpointInterval = checkpointInterval
+    this.vocabSize = docs.take(1).head._2.size
+    this.checkpointInterval = lda.getCheckpointInterval
     this.graphCheckpointer = new
       PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval)
     this.globalTopicTotals = computeGlobalTopicTotals()
     this
   }
 
-  private[clustering] override def next(): EMLDAOptimizer = {
+  override private[clustering] def next(): EMLDAOptimizer = {
     require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
 
     val eta = topicConcentration
@@ -202,9 +202,269 @@ class EMLDAOptimizer extends LDAOptimizer{
     graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _)
   }
 
-  private[clustering] override def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
+  override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
     require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
     this.graphCheckpointer.deleteAllCheckpoints()
     new DistributedLDAModel(this, iterationTimes)
   }
 }
+
+
+/**
+ * :: Experimental ::
+ *
+ * An online optimizer for LDA. The Optimizer implements the Online variational Bayes LDA
+ * algorithm, which processes a subset of the corpus on each iteration, and updates the term-topic
+ * distribution adaptively.
+ *
+ * Original Online LDA paper:
+ *   Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
+ */
+@Experimental
+class OnlineLDAOptimizer extends LDAOptimizer {
+
+  // LDA common parameters
+  private var k: Int = 0
+  private var corpusSize: Long = 0
+  private var vocabSize: Int = 0
+
+  /** alias for docConcentration */
+  private var alpha: Double = 0
+
+  /** (private[clustering] for debugging)  Get docConcentration */
+  private[clustering] def getAlpha: Double = alpha
+
+  /** alias for topicConcentration */
+  private var eta: Double = 0
+
+  /** (private[clustering] for debugging)  Get topicConcentration */
+  private[clustering] def getEta: Double = eta
+
+  private var randomGenerator: java.util.Random = null
+
+  // Online LDA specific parameters
+  // Learning rate is: (tau_0 + t)^{-kappa}
+  private var tau_0: Double = 1024
+  private var kappa: Double = 0.51
+  private var miniBatchFraction: Double = 0.05
+
+  // internal data structure
+  private var docs: RDD[(Long, Vector)] = null
+
+  /** Dirichlet parameter for the posterior over topics */
+  private var lambda: BDM[Double] = null
+
+  /** (private[clustering] for debugging) Get parameter for topics */
+  private[clustering] def getLambda: BDM[Double] = lambda
+
+  /** Current iteration (count of invocations of [[next()]]) */
+  private var iteration: Int = 0
+  private var gammaShape: Double = 100
+
+  /**
+   * A (positive) learning parameter that downweights early iterations. Larger values make early
+   * iterations count less.
+   */
+  def getTau_0: Double = this.tau_0
+
+  /**
+   * A (positive) learning parameter that downweights early iterations. Larger values make early
+   * iterations count less.
+   * Default: 1024, following the original Online LDA paper.
+   */
+  def setTau_0(tau_0: Double): this.type = {
+    require(tau_0 > 0,  s"LDA tau_0 must be positive, but was set to $tau_0")
+    this.tau_0 = tau_0
+    this
+  }
+
+  /**
+   * Learning rate: exponential decay rate
+   */
+  def getKappa: Double = this.kappa
+
+  /**
+   * Learning rate: exponential decay rate---should be between
+   * (0.5, 1.0] to guarantee asymptotic convergence.
+   * Default: 0.51, based on the original Online LDA paper.
+   */
+  def setKappa(kappa: Double): this.type = {
+    require(kappa >= 0, s"Online LDA kappa must be nonnegative, but was set to $kappa")
+    this.kappa = kappa
+    this
+  }
+
+  /**
+   * Mini-batch fraction, which sets the fraction of documents sampled and used in each iteration
+   */
+  def getMiniBatchFraction: Double = this.miniBatchFraction
+
+  /**
+   * Mini-batch fraction in (0, 1], which sets the fraction of documents sampled and used in
+   * each iteration.
+   *
+   * Note that this should be adjusted in synch with [[LDA.setMaxIterations()]]
+   * so the entire corpus is used.  Specifically, set both so that
+   * maxIterations * miniBatchFraction >= 1.
+   *
+   * Default: 0.05, i.e., 5% of total documents.
+   */
+  def setMiniBatchFraction(miniBatchFraction: Double): this.type = {
+    require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0,
+      s"Online LDA miniBatchFraction must be in range (0,1], but was set to $miniBatchFraction")
+    this.miniBatchFraction = miniBatchFraction
+    this
+  }
+
+  /**
+   * (private[clustering])
+   * Set the Dirichlet parameter for the posterior over topics.
+   * This is only used for testing now. In the future, it can help support training stop/resume.
+   */
+  private[clustering] def setLambda(lambda: BDM[Double]): this.type = {
+    this.lambda = lambda
+    this
+  }
+
+  /**
+   * (private[clustering])
+   * Used for random initialization of the variational parameters.
+   * Larger value produces values closer to 1.0.
+   * This is only used for testing currently.
+   */
+  private[clustering] def setGammaShape(shape: Double): this.type = {
+    this.gammaShape = shape
+    this
+  }
+
+  override private[clustering] def initialize(
+      docs: RDD[(Long, Vector)],
+      lda: LDA):  OnlineLDAOptimizer = {
+    this.k = lda.getK
+    this.corpusSize = docs.count()
+    this.vocabSize = docs.first()._2.size
+    this.alpha = if (lda.getDocConcentration == -1) 1.0 / k else lda.getDocConcentration
+    this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration
+    this.randomGenerator = new Random(lda.getSeed)
+
+    this.docs = docs
+
+    // Initialize the variational distribution q(beta|lambda)
+    this.lambda = getGammaMatrix(k, vocabSize)
+    this.iteration = 0
+    this
+  }
+
+  override private[clustering] def next(): OnlineLDAOptimizer = {
+    val batch = docs.sample(withReplacement = true, miniBatchFraction, randomGenerator.nextLong())
+    if (batch.isEmpty()) return this
+    submitMiniBatch(batch)
+  }
+
+  /**
+   * Submit a subset (like 1%, decide by the miniBatchFraction) of the corpus to the Online LDA
+   * model, and it will update the topic distribution adaptively for the terms appearing in the
+   * subset.
+   */
+  private[clustering] def submitMiniBatch(batch: RDD[(Long, Vector)]): OnlineLDAOptimizer = {
+    iteration += 1
+    val k = this.k
+    val vocabSize = this.vocabSize
+    val Elogbeta = dirichletExpectation(lambda)
+    val expElogbeta = exp(Elogbeta)
+    val alpha = this.alpha
+    val gammaShape = this.gammaShape
+
+    val stats: RDD[BDM[Double]] = batch.mapPartitions { docs =>
+      val stat = BDM.zeros[Double](k, vocabSize)
+      docs.foreach { doc =>
+        val termCounts = doc._2
+        val (ids: List[Int], cts: Array[Double]) = termCounts match {
+          case v: DenseVector => ((0 until v.size).toList, v.values)
+          case v: SparseVector => (v.indices.toList, v.values)
+          case v => throw new IllegalArgumentException("Online LDA does not support vector type "
+            + v.getClass)
+        }
+
+        // Initialize the variational distribution q(theta|gamma) for the mini-batch
+        var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K
+        var Elogthetad = digamma(gammad) - digamma(sum(gammad))     // 1 * K
+        var expElogthetad = exp(Elogthetad)                         // 1 * K
+        val expElogbetad = expElogbeta(::, ids).toDenseMatrix       // K * ids
+
+        var phinorm = expElogthetad * expElogbetad + 1e-100         // 1 * ids
+        var meanchange = 1D
+        val ctsVector = new BDV[Double](cts).t                      // 1 * ids
+
+        // Iterate between gamma and phi until convergence
+        while (meanchange > 1e-3) {
+          val lastgamma = gammad
+          //        1*K                  1 * ids               ids * k
+          gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha
+          Elogthetad = digamma(gammad) - digamma(sum(gammad))
+          expElogthetad = exp(Elogthetad)
+          phinorm = expElogthetad * expElogbetad + 1e-100
+          meanchange = sum(abs(gammad - lastgamma)) / k
+        }
+
+        val m1 = expElogthetad.t
+        val m2 = (ctsVector / phinorm).t.toDenseVector
+        var i = 0
+        while (i < ids.size) {
+          stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
+          i += 1
+        }
+      }
+      Iterator(stat)
+    }
+
+    val statsSum: BDM[Double] = stats.reduce(_ += _)
+    val batchResult = statsSum :* expElogbeta
+
+    // Note that this is an optimization to avoid batch.count
+    update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt)
+    this
+  }
+
+  override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
+    new LocalLDAModel(Matrices.fromBreeze(lambda).transpose)
+  }
+
+  /**
+   * Update lambda based on the batch submitted. batchSize can be different for each iteration.
+   */
+  private[clustering] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = {
+    val tau_0 = this.getTau_0
+    val kappa = this.getKappa
+
+    // weight of the mini-batch.
+    val weight = math.pow(tau_0 + iter, -kappa)
+
+    // Update lambda based on documents.
+    lambda = lambda * (1 - weight) +
+      (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight
+  }
+
+  /**
+   * Get a random matrix to initialize lambda
+   */
+  private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
+    val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(
+      randomGenerator.nextLong()))
+    val gammaRandomGenerator = new Gamma(gammaShape, 1.0 / gammaShape)(randBasis)
+    val temp = gammaRandomGenerator.sample(row * col).toArray
+    new BDM[Double](col, row, temp).t
+  }
+
+  /**
+   * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation
+   * uses digamma which is accurate but expensive.
+   */
+  private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
+    val rowSum =  sum(alpha(breeze.linalg.*, ::))
+    val digAlpha = digamma(alpha)
+    val digRowSum = digamma(rowSum)
+    val result = digAlpha(::, breeze.linalg.*) - digRowSum
+    result
+  }
+}
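
The update() step above blends the old lambda with a mini-batch estimate, using the weight (tau_0 + t)^(-kappa). The sketch below mirrors that arithmetic on flat arrays so it runs without Breeze or Spark. It is a simplified illustration, not the patch's implementation; the object and method names are hypothetical.

```scala
// Hypothetical, dependency-free mirror of the lambda update arithmetic.
object OnlineUpdateSketch {

  /** Weight of the mini-batch at iteration iter: (tau0 + iter)^(-kappa). */
  def weight(tau0: Double, kappa: Double, iter: Int): Double =
    math.pow(tau0 + iter, -kappa)

  /**
   * Element-wise blend of the current lambda with the mini-batch statistics.
   * The statistics are scaled by corpusSize / batchSize, as if the whole
   * corpus looked like this mini-batch, and eta is added as the prior.
   */
  def updateLambda(
      lambda: Array[Double],
      stat: Array[Double],
      corpusSize: Long,
      batchSize: Int,
      eta: Double,
      w: Double): Array[Double] = {
    require(lambda.length == stat.length, "lambda and stat must have the same length")
    val scale = corpusSize.toDouble / batchSize.toDouble
    lambda.zip(stat).map { case (l, s) => l * (1 - w) + (s * scale + eta) * w }
  }
}
```

With the defaults tau_0 = 1024 and kappa = 0.51, the weight is small from the first iteration and keeps shrinking, so each mini-batch only nudges lambda.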

http://git-wip-us.apache.org/repos/asf/spark/blob/3539cb7d/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index fbe171b..f394d90 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.clustering;
 import java.io.Serializable;
 import java.util.ArrayList;
 
-import org.apache.spark.api.java.JavaRDD;
 import scala.Tuple2;
 
 import org.junit.After;
@@ -30,6 +29,7 @@ import org.junit.Before;
 import org.junit.Test;
 
 import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Matrix;
 import org.apache.spark.mllib.linalg.Vector;
@@ -109,11 +109,45 @@ public class JavaLDASuite implements Serializable {
     assert(model.logPrior() < 0.0);
   }
 
+  @Test
+  public void OnlineOptimizerCompatibility() {
+    int k = 3;
+    double topicSmoothing = 1.2;
+    double termSmoothing = 1.2;
+
+    // Train a model
+    OnlineLDAOptimizer op = new OnlineLDAOptimizer()
+      .setTau_0(1024)
+      .setKappa(0.51)
+      .setGammaShape(1e40)
+      .setMiniBatchFraction(0.5);
+
+    LDA lda = new LDA();
+    lda.setK(k)
+      .setDocConcentration(topicSmoothing)
+      .setTopicConcentration(termSmoothing)
+      .setMaxIterations(5)
+      .setSeed(12345)
+      .setOptimizer(op);
+
+    LDAModel model = lda.run(corpus);
+
+    // Check: basic parameters
+    assertEquals(model.k(), k);
+    assertEquals(model.vocabSize(), tinyVocabSize);
+
+    // Check: topic summaries
+    Tuple2<int[], double[]>[] roundedTopicSummary = model.describeTopics();
+    assertEquals(roundedTopicSummary.length, k);
+    Tuple2<int[], double[]>[] roundedLocalTopicSummary = model.describeTopics();
+    assertEquals(roundedLocalTopicSummary.length, k);
+  }
+
   private static int tinyK = LDASuite$.MODULE$.tinyK();
   private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize();
   private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();
   private static Tuple2<int[], double[]>[] tinyTopicDescription =
       LDASuite$.MODULE$.tinyTopicDescription();
-  JavaPairRDD<Long, Vector> corpus;
+  private JavaPairRDD<Long, Vector> corpus;
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3539cb7d/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 41ec794..2dcc881 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.clustering
 
+import breeze.linalg.{DenseMatrix => BDM}
+
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
@@ -37,7 +39,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
 
     // Check: describeTopics() with all terms
     val fullTopicSummary = model.describeTopics()
-    assert(fullTopicSummary.size === tinyK)
+    assert(fullTopicSummary.length === tinyK)
     fullTopicSummary.zip(tinyTopicDescription).foreach {
       case ((algTerms, algTermWeights), (terms, termWeights)) =>
         assert(algTerms === terms)
@@ -54,7 +56,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
     }
   }
 
-  test("running and DistributedLDAModel") {
+  test("running and DistributedLDAModel with default Optimizer (EM)") {
     val k = 3
     val topicSmoothing = 1.2
     val termSmoothing = 1.2
@@ -99,7 +101,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
     // Check: per-doc topic distributions
     val topicDistributions = model.topicDistributions.collect()
     //  Ensure all documents are covered.
-    assert(topicDistributions.size === tinyCorpus.size)
+    assert(topicDistributions.length === tinyCorpus.length)
     assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
     //  Ensure we have proper distributions
     topicDistributions.foreach { case (docId, topicDistribution) =>
@@ -131,6 +133,87 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
     assert(lda.getBeta === 3.0)
     assert(lda.getTopicConcentration === 3.0)
   }
+
+  test("OnlineLDAOptimizer initialization") {
+    val lda = new LDA().setK(2)
+    val corpus = sc.parallelize(tinyCorpus, 2)
+    val op = new OnlineLDAOptimizer().initialize(corpus, lda)
+    op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau_0(567)
+    assert(op.getAlpha == 0.5) // default 1.0 / k
+    assert(op.getEta == 0.5)   // default 1.0 / k
+    assert(op.getKappa == 0.9876)
+    assert(op.getMiniBatchFraction == 0.123)
+    assert(op.getTau_0 == 567)
+  }
+
+  test("OnlineLDAOptimizer one iteration") {
+    // run OnlineLDAOptimizer for 1 iteration to verify its consistency with Blei-lab,
+    // [[https://github.com/Blei-Lab/onlineldavb]]
+    val k = 2
+    val vocabSize = 6
+
+    def docs: Array[(Long, Vector)] = Array(
+      Vectors.sparse(vocabSize, Array(0, 1, 2), Array(1, 1, 1)), // apple, orange, banana
+      Vectors.sparse(vocabSize, Array(3, 4, 5), Array(1, 1, 1)) // tiger, cat, dog
+    ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
+    val corpus = sc.parallelize(docs, 2)
+
+    // Set GammaShape large to avoid the stochastic impact.
+    val op = new OnlineLDAOptimizer().setTau_0(1024).setKappa(0.51).setGammaShape(1e40)
+      .setMiniBatchFraction(1)
+    val lda = new LDA().setK(k).setMaxIterations(1).setOptimizer(op).setSeed(12345)
+
+    val state = op.initialize(corpus, lda)
+    // override lambda to simulate an intermediate state
+    //    [[ 1.1  1.2  1.3  0.9  0.8  0.7]
+    //     [ 0.9  0.8  0.7  1.1  1.2  1.3]]
+    op.setLambda(new BDM[Double](k, vocabSize,
+      Array(1.1, 0.9, 1.2, 0.8, 1.3, 0.7, 0.9, 1.1, 0.8, 1.2, 0.7, 1.3)))
+
+    // run for one iteration
+    state.submitMiniBatch(corpus)
+
+    // verify the result. Note that this generates the identical result as
+    // [[https://github.com/Blei-Lab/onlineldavb]]
+    val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
+    val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
+    assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1)
+    assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2)
+  }
+
+  test("OnlineLDAOptimizer with toy data") {
+    def toydata: Array[(Long, Vector)] = Array(
+      Vectors.sparse(6, Array(0, 1), Array(1, 1)),
+      Vectors.sparse(6, Array(1, 2), Array(1, 1)),
+      Vectors.sparse(6, Array(0, 2), Array(1, 1)),
+      Vectors.sparse(6, Array(3, 4), Array(1, 1)),
+      Vectors.sparse(6, Array(3, 5), Array(1, 1)),
+      Vectors.sparse(6, Array(4, 5), Array(1, 1))
+    ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
+
+    val docs = sc.parallelize(toydata)
+    val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau_0(1024).setKappa(0.51)
+      .setGammaShape(1e10)
+    val lda = new LDA().setK(2)
+      .setDocConcentration(0.01)
+      .setTopicConcentration(0.01)
+      .setMaxIterations(100)
+      .setOptimizer(op)
+      .setSeed(12345)
+
+    val ldaModel = lda.run(docs)
+    val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
+    val topics = topicIndices.map { case (terms, termWeights) =>
+      terms.zip(termWeights)
+    }
+
+    // check distribution for each topic, typical distribution is (0.3, 0.3, 0.3, 0.02, 0.02, 0.02)
+    topics.foreach { topic =>
+      val smalls = topic.filter(t => t._2 < 0.1).map(_._2)
+      assert(smalls.length == 3 && smalls.sum < 0.2)
+    }
+  }
+
 }
 
 private[clustering] object LDASuite {
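
In the "one iteration" test above, the array passed to the Breeze DenseMatrix constructor looks interleaved because Breeze stores dense matrices in column-major order. The sketch below reads the same flat array back row by row to show that it matches the matrix in the test comment. It is an illustration only, without Breeze; the object and method names are hypothetical.

```scala
// Hypothetical helper showing column-major indexing; does not use Breeze.
object ColumnMajorSketch {

  /** Element (r, c) of a column-major matrix with the given number of rows. */
  def at(data: Array[Double], rows: Int, r: Int, c: Int): Double =
    data(r + c * rows)

  /** Row r of a column-major rows x cols matrix. */
  def row(data: Array[Double], rows: Int, cols: Int, r: Int): Array[Double] =
    Array.tabulate(cols)(c => at(data, rows, r, c))

  // The flat array used in the test: k = 2 rows, vocabSize = 6 columns.
  val testLambda: Array[Double] =
    Array(1.1, 0.9, 1.2, 0.8, 1.3, 0.7, 0.9, 1.1, 0.8, 1.2, 0.7, 1.3)
}
```

Reading row 0 and row 1 of testLambda this way recovers the two topic rows written in the test comment.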

