Posted to commits@spark.apache.org by fe...@apache.org on 2016/08/18 12:33:58 UTC

spark git commit: [SPARK-16447][ML][SPARKR] LDA wrapper in SparkR

Repository: spark
Updated Branches:
  refs/heads/master 68f5087d2 -> b72bb62d4


[SPARK-16447][ML][SPARKR] LDA wrapper in SparkR

## What changes were proposed in this pull request?

Add an LDA wrapper in SparkR with the following interfaces (a usage sketch follows the list):

- spark.lda(data, ...)

- spark.posterior(object, newData, ...)

- spark.perplexity(object, ...)

- summary(object)

- write.ml(object)

- read.ml(path)
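
A minimal usage sketch of the new interfaces, following the unit tests and roxygen examples in this patch; the data path, the `value` column name, and the `tempfile` destination are illustrative:

```r
# Assumes an active SparkR session (sparkR.session()) and a text file with one document per line.
df <- read.text("data/mllib/sample_lda_data.txt")

# Fit an LDA model on the text column.
model <- spark.lda(df, features = "value", k = 3, maxIter = 20, optimizer = "online")

# Inspect the fitted model: docConcentration, topicConcentration, logPerplexity, topics, ...
stats <- summary(model)

# Score data: adds a "topicDistribution" column with per-document topic probabilities.
posterior <- spark.posterior(model, df)

# Log perplexity on the given data.
logPerp <- spark.perplexity(model, df)

# Round-trip the fitted model to disk.
modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
write.ml(model, modelPath)
restored <- read.ml(modelPath)
```
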

## How was this patch tested?

Tested with SparkR unit tests.

Author: Xusen Yin <yi...@gmail.com>

Closes #14229 from yinxusen/SPARK-16447.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b72bb62d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b72bb62d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b72bb62d

Branch: refs/heads/master
Commit: b72bb62d421840f82d663c6b8e3922bd14383fbb
Parents: 68f5087
Author: Xusen Yin <yi...@gmail.com>
Authored: Thu Aug 18 05:33:52 2016 -0700
Committer: Felix Cheung <fe...@apache.org>
Committed: Thu Aug 18 05:33:52 2016 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                                 |   3 +
 R/pkg/R/generics.R                              |  14 ++
 R/pkg/R/mllib.R                                 | 166 +++++++++++++-
 R/pkg/inst/tests/testthat/test_mllib.R          |  87 ++++++++
 .../org/apache/spark/ml/clustering/LDA.scala    |   4 +
 .../org/apache/spark/ml/r/LDAWrapper.scala      | 216 +++++++++++++++++++
 .../scala/org/apache/spark/ml/r/RWrappers.scala |   2 +
 7 files changed, 490 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b72bb62d/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index c71eec5..4404cff 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -25,6 +25,9 @@ exportMethods("glm",
               "fitted",
               "spark.naiveBayes",
               "spark.survreg",
+              "spark.lda",
+              "spark.posterior",
+              "spark.perplexity",
               "spark.isoreg",
               "spark.gaussianMixture")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b72bb62d/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 06bb25d..fe04bcf 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1304,6 +1304,19 @@ setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("s
 #' @export
 setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
 
+#' @rdname spark.lda
+#' @param ... Additional parameters to tune LDA.
+#' @export
+setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") })
+
+#' @rdname spark.lda
+#' @export
+setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark.posterior") })
+
+#' @rdname spark.lda
+#' @export
+setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") })
+
 #' @rdname spark.isoreg
 #' @export
 setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
@@ -1315,6 +1328,7 @@ setGeneric("spark.gaussianMixture",
              standardGeneric("spark.gaussianMixture")
            })
 
+#' write.ml
 #' @rdname write.ml
 #' @export
 setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })

http://git-wip-us.apache.org/repos/asf/spark/blob/b72bb62d/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index db74046..b952741 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -39,6 +39,13 @@ setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj"))
 #' @note NaiveBayesModel since 2.0.0
 setClass("NaiveBayesModel", representation(jobj = "jobj"))
 
+#' S4 class that represents an LDAModel
+#'
+#' @param jobj a Java object reference to the backing Scala LDAWrapper
+#' @export
+#' @note LDAModel since 2.1.0
+setClass("LDAModel", representation(jobj = "jobj"))
+
 #' S4 class that represents an AFTSurvivalRegressionModel
 #'
 #' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper
@@ -75,7 +82,7 @@ setClass("GaussianMixtureModel", representation(jobj = "jobj"))
 #' @name write.ml
 #' @export
 #' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
+#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.lda}
 #' @seealso \link{spark.isoreg}
 #' @seealso \link{read.ml}
 NULL
@@ -315,6 +322,94 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
             return(list(apriori = apriori, tables = tables))
           })
 
+# Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda()
+
+#' @param newData A SparkDataFrame for testing
+#' @return \code{spark.posterior} returns a SparkDataFrame containing posterior probability
+#'         vectors named "topicDistribution"
+#' @rdname spark.lda
+#' @aliases spark.posterior,LDAModel,SparkDataFrame-method
+#' @export
+#' @note spark.posterior(LDAModel) since 2.1.0
+setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"),
+          function(object, newData) {
+            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+          })
+
+# Returns the summary of a Latent Dirichlet Allocation model produced by \code{spark.lda}
+
+#' @param object A Latent Dirichlet Allocation model fitted by \code{spark.lda}.
+#' @param maxTermsPerTopic Maximum number of terms to collect for each topic. Default value of 10.
+#' @return \code{summary} returns a list containing
+#'         \item{\code{docConcentration}}{concentration parameter commonly named \code{alpha} for
+#'               the prior placed on documents distributions over topics \code{theta}}
+#'         \item{\code{topicConcentration}}{concentration parameter commonly named \code{beta} or
+#'               \code{eta} for the prior placed on topic distributions over terms}
+#'         \item{\code{logLikelihood}}{log likelihood of the entire corpus}
+#'         \item{\code{logPerplexity}}{log perplexity}
+#'         \item{\code{isDistributed}}{TRUE for distributed model while FALSE for local model}
+#'         \item{\code{vocabSize}}{number of terms in the corpus}
+#'         \item{\code{topics}}{top 10 terms and their weights of all topics}
+#'         \item{\code{vocabulary}}{all terms of the training corpus; NULL if a libsvm-format file
+#'               is used as the training set}
+#' @rdname spark.lda
+#' @aliases summary,LDAModel-method
+#' @export
+#' @note summary(LDAModel) since 2.1.0
+setMethod("summary", signature(object = "LDAModel"),
+          function(object, maxTermsPerTopic) {
+            maxTermsPerTopic <- as.integer(ifelse(missing(maxTermsPerTopic), 10, maxTermsPerTopic))
+            jobj <- object@jobj
+            docConcentration <- callJMethod(jobj, "docConcentration")
+            topicConcentration <- callJMethod(jobj, "topicConcentration")
+            logLikelihood <- callJMethod(jobj, "logLikelihood")
+            logPerplexity <- callJMethod(jobj, "logPerplexity")
+            isDistributed <- callJMethod(jobj, "isDistributed")
+            vocabSize <- callJMethod(jobj, "vocabSize")
+            topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic))
+            vocabulary <- callJMethod(jobj, "vocabulary")
+            return(list(docConcentration = unlist(docConcentration),
+                        topicConcentration = topicConcentration,
+                        logLikelihood = logLikelihood, logPerplexity = logPerplexity,
+                        isDistributed = isDistributed, vocabSize = vocabSize,
+                        topics = topics,
+                        vocabulary = unlist(vocabulary)))
+          })
+
+# Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda}
+
+#' @return \code{spark.perplexity} returns the log perplexity of the given SparkDataFrame, or the
+#'         log perplexity of the training data if the argument "data" is missing.
+#' @rdname spark.lda
+#' @aliases spark.perplexity,LDAModel-method
+#' @export
+#' @note spark.perplexity(LDAModel) since 2.1.0
+setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"),
+          function(object, data) {
+            return(ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"),
+                   callJMethod(object@jobj, "computeLogPerplexity", data@sdf)))
+         })
+
+# Saves the Latent Dirichlet Allocation model to the input path.
+
+#' @param path The directory where the model is saved
+#' @param overwrite Whether to overwrite the output path if it already exists. Default is FALSE,
+#'                  which means an exception is thrown if the output path already exists.
+#'
+#' @rdname spark.lda
+#' @aliases write.ml,LDAModel,character-method
+#' @export
+#' @seealso \link{read.ml}
+#' @note write.ml(LDAModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "LDAModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            writer <- callJMethod(object@jobj, "write")
+            if (overwrite) {
+              writer <- callJMethod(writer, "overwrite")
+            }
+            invisible(callJMethod(writer, "save", path))
+          })
+
 #' Isotonic Regression Model
 #'
 #' Fits an Isotonic Regression model against a Spark DataFrame, similarly to R's isoreg().
@@ -700,6 +795,8 @@ read.ml <- function(path) {
       return(new("GeneralizedLinearRegressionModel", jobj = jobj))
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) {
       return(new("KMeansModel", jobj = jobj))
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LDAWrapper")) {
+      return(new("LDAModel", jobj = jobj))
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
       return(new("IsotonicRegressionModel", jobj = jobj))
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
@@ -751,6 +848,71 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula
             return(new("AFTSurvivalRegressionModel", jobj = jobj))
           })
 
+#' Latent Dirichlet Allocation
+#'
+#' \code{spark.lda} fits a Latent Dirichlet Allocation model on a SparkDataFrame. Users can call
+#' \code{summary} to get a summary of the fitted LDA model, \code{spark.posterior} to compute
+#' posterior probabilities on new data, \code{spark.perplexity} to compute log perplexity on new
+#' data and \code{write.ml}/\code{read.ml} to save/load fitted models.
+#'
+#' @param data A SparkDataFrame for training
+#' @param features Features column name, default "features". Either a libSVM-format column or a
+#'        character-format column is valid.
+#' @param k Number of topics, default 10
+#' @param maxIter Maximum iterations, default 20
+#' @param optimizer Optimizer to train an LDA model, "online" or "em", default "online"
+#' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in
+#'        each iteration of mini-batch gradient descent, in range (0, 1], default 0.05
+#' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for
+#'        the prior placed on topic distributions over terms, default -1 to set automatically on the
+#'        Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only a
+#'        length-1 numeric is accepted.
+#' @param docConcentration concentration parameter (commonly named \code{alpha}) for the
+#'        prior placed on documents distributions over topics (\code{theta}), default -1 to set
+#'        automatically on the Spark side. Use \code{summary} to retrieve the effective
+#'        docConcentration. Only a length-1 or length-\code{k} numeric is accepted.
+#' @param customizedStopWords stopwords to be removed from the given corpus. The parameter is
+#'        ignored if a libSVM-format column is used as the features column.
+#' @param maxVocabSize maximum vocabulary size, default 1 << 18
+#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model
+#' @rdname spark.lda
+#' @aliases spark.lda,SparkDataFrame-method
+#' @seealso topicmodels: \url{https://cran.r-project.org/web/packages/topicmodels/}
+#' @export
+#' @examples
+#' \dontrun{
+#' text <- read.df("path/to/data", source = "libsvm")
+#' model <- spark.lda(data = text, optimizer = "em")
+#'
+#' # get a summary of the model
+#' summary(model)
+#'
+#' # compute posterior probabilities
+#' posterior <- spark.posterior(model, df)
+#' showDF(posterior)
+#'
+#' # compute perplexity
+#' perplexity <- spark.perplexity(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.lda since 2.1.0
+setMethod("spark.lda", signature(data = "SparkDataFrame"),
+          function(data, features = "features", k = 10, maxIter = 20, optimizer = c("online", "em"),
+                   subsamplingRate = 0.05, topicConcentration = -1, docConcentration = -1,
+                   customizedStopWords = "", maxVocabSize = bitwShiftL(1, 18)) {
+            optimizer <- match.arg(optimizer)
+            jobj <- callJStatic("org.apache.spark.ml.r.LDAWrapper", "fit", data@sdf, features,
+                                as.integer(k), as.integer(maxIter), optimizer,
+                                as.numeric(subsamplingRate), topicConcentration,
+                                as.array(docConcentration), as.array(customizedStopWords),
+                                maxVocabSize)
+            return(new("LDAModel", jobj = jobj))
+          })
 
 # Returns a summary of the AFT survival regression model produced by spark.survreg,
 # similarly to R's summary().
@@ -891,4 +1053,4 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
 setMethod("predict", signature(object = "GaussianMixtureModel"),
           function(object, newData) {
             return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
-          })
+          })
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/spark/blob/b72bb62d/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 9617986..8c380fb 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -570,4 +570,91 @@ test_that("spark.gaussianMixture", {
   unlink(modelPath)
 })
 
+test_that("spark.lda with libsvm", {
+  text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")
+  model <- spark.lda(text, optimizer = "em")
+
+  stats <- summary(model, 10)
+  isDistributed <- stats$isDistributed
+  logLikelihood <- stats$logLikelihood
+  logPerplexity <- stats$logPerplexity
+  vocabSize <- stats$vocabSize
+  topics <- stats$topicTopTerms
+  weights <- stats$topicTopTermsWeights
+  vocabulary <- stats$vocabulary
+
+  expect_false(isDistributed)
+  expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+  expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+  expect_equal(vocabSize, 11)
+  expect_true(is.null(vocabulary))
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+
+  expect_false(stats2$isDistributed)
+  expect_equal(logLikelihood, stats2$logLikelihood)
+  expect_equal(logPerplexity, stats2$logPerplexity)
+  expect_equal(vocabSize, stats2$vocabSize)
+  expect_equal(vocabulary, stats2$vocabulary)
+
+  unlink(modelPath)
+})
+
+test_that("spark.lda with text input", {
+  text <- read.text("data/mllib/sample_lda_data.txt")
+  model <- spark.lda(text, optimizer = "online", features = "value")
+
+  stats <- summary(model)
+  isDistributed <- stats$isDistributed
+  logLikelihood <- stats$logLikelihood
+  logPerplexity <- stats$logPerplexity
+  vocabSize <- stats$vocabSize
+  topics <- stats$topicTopTerms
+  weights <- stats$topicTopTermsWeights
+  vocabulary <- stats$vocabulary
+
+  expect_false(isDistributed)
+  expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+  expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+  expect_equal(vocabSize, 10)
+  expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")))
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+
+  expect_false(stats2$isDistributed)
+  expect_equal(logLikelihood, stats2$logLikelihood)
+  expect_equal(logPerplexity, stats2$logPerplexity)
+  expect_equal(vocabSize, stats2$vocabSize)
+  expect_true(all.equal(vocabulary, stats2$vocabulary))
+
+  unlink(modelPath)
+})
+
+test_that("spark.posterior and spark.perplexity", {
+  text <- read.text("data/mllib/sample_lda_data.txt")
+  model <- spark.lda(text, features = "value", k = 3)
+
+  # Assert perplexities are equal
+  stats <- summary(model)
+  logPerplexity <- spark.perplexity(model, text)
+  expect_equal(logPerplexity, stats$logPerplexity)
+
+  # Assert the sum of every topic distribution is equal to 1
+  posterior <- spark.posterior(model, text)
+  local.posterior <- collect(posterior)$topicDistribution
+  expect_equal(length(local.posterior), sum(unlist(local.posterior)))
+})
+
 sparkR.session.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/b72bb62d/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 034f2c3..b5a764b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -386,6 +386,10 @@ sealed abstract class LDAModel private[ml] (
   @Since("1.6.0")
   protected def getModel: OldLDAModel
 
+  private[ml] def getEffectiveDocConcentration: Array[Double] = getModel.docConcentration.toArray
+
+  private[ml] def getEffectiveTopicConcentration: Double = getModel.topicConcentration
+
   /**
    * The features for LDA should be a [[Vector]] representing the word counts in a document.
    * The vector should be of length vocabSize, with counts for each term (word).

http://git-wip-us.apache.org/repos/asf/spark/blob/b72bb62d/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
new file mode 100644
index 0000000..cbe6a70
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
+import org.apache.spark.ml.clustering.{LDA, LDAModel}
+import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
+import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.param.ParamPair
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StringType
+
+
+private[r] class LDAWrapper private (
+    val pipeline: PipelineModel,
+    val logLikelihood: Double,
+    val logPerplexity: Double,
+    val vocabulary: Array[String]) extends MLWritable {
+
+  import LDAWrapper._
+
+  private val lda: LDAModel = pipeline.stages.last.asInstanceOf[LDAModel]
+  private val preprocessor: PipelineModel =
+    new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", pipeline.stages.dropRight(1))
+
+  def transform(data: Dataset[_]): DataFrame = {
+    val vec2ary = udf { vec: Vector => vec.toArray }
+    val outputCol = lda.getTopicDistributionCol
+    val tempCol = s"${Identifiable.randomUID(outputCol)}"
+    val preprocessed = preprocessor.transform(data)
+    lda.transform(preprocessed, ParamPair(lda.topicDistributionCol, tempCol))
+      .withColumn(outputCol, vec2ary(col(tempCol)))
+      .drop(TOKENIZER_COL, STOPWORDS_REMOVER_COL, COUNT_VECTOR_COL, tempCol)
+  }
+
+  def computeLogPerplexity(data: Dataset[_]): Double = {
+    lda.logPerplexity(preprocessor.transform(data))
+  }
+
+  def topics(maxTermsPerTopic: Int): DataFrame = {
+    val topicIndices: DataFrame = lda.describeTopics(maxTermsPerTopic)
+    if (vocabulary.isEmpty || vocabulary.length < vocabSize) {
+      topicIndices
+    } else {
+      val index2term = udf { indices: mutable.WrappedArray[Int] => indices.map(i => vocabulary(i)) }
+      topicIndices
+        .select(col("topic"), index2term(col("termIndices")).as("term"), col("termWeights"))
+    }
+  }
+
+  lazy val isDistributed: Boolean = lda.isDistributed
+  lazy val vocabSize: Int = lda.vocabSize
+  lazy val docConcentration: Array[Double] = lda.getEffectiveDocConcentration
+  lazy val topicConcentration: Double = lda.getEffectiveTopicConcentration
+
+  override def write: MLWriter = new LDAWrapper.LDAWrapperWriter(this)
+}
+
+private[r] object LDAWrapper extends MLReadable[LDAWrapper] {
+
+  val TOKENIZER_COL = s"${Identifiable.randomUID("rawTokens")}"
+  val STOPWORDS_REMOVER_COL = s"${Identifiable.randomUID("tokens")}"
+  val COUNT_VECTOR_COL = s"${Identifiable.randomUID("features")}"
+
+  private def getPreStages(
+      features: String,
+      customizedStopWords: Array[String],
+      maxVocabSize: Int): Array[PipelineStage] = {
+    val tokenizer = new RegexTokenizer()
+      .setInputCol(features)
+      .setOutputCol(TOKENIZER_COL)
+    val stopWordsRemover = new StopWordsRemover()
+      .setInputCol(TOKENIZER_COL)
+      .setOutputCol(STOPWORDS_REMOVER_COL)
+    stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
+    val countVectorizer = new CountVectorizer()
+      .setVocabSize(maxVocabSize)
+      .setInputCol(STOPWORDS_REMOVER_COL)
+      .setOutputCol(COUNT_VECTOR_COL)
+
+    Array(tokenizer, stopWordsRemover, countVectorizer)
+  }
+
+  def fit(
+      data: DataFrame,
+      features: String,
+      k: Int,
+      maxIter: Int,
+      optimizer: String,
+      subsamplingRate: Double,
+      topicConcentration: Double,
+      docConcentration: Array[Double],
+      customizedStopWords: Array[String],
+      maxVocabSize: Int): LDAWrapper = {
+
+    val lda = new LDA()
+      .setK(k)
+      .setMaxIter(maxIter)
+      .setSubsamplingRate(subsamplingRate)
+
+    val featureSchema = data.schema(features)
+    val stages = featureSchema.dataType match {
+      case d: StringType =>
+        getPreStages(features, customizedStopWords, maxVocabSize) ++
+          Array(lda.setFeaturesCol(COUNT_VECTOR_COL))
+      case d: VectorUDT =>
+        Array(lda.setFeaturesCol(features))
+      case _ =>
+        throw new SparkException(
+          s"Unsupported input features type of ${featureSchema.dataType.typeName}," +
+            s" only String type and Vector type are supported now.")
+    }
+
+    if (topicConcentration != -1) {
+      lda.setTopicConcentration(topicConcentration)
+    } else {
+      // Auto-set topicConcentration
+    }
+
+    if (docConcentration.length == 1) {
+      if (docConcentration.head != -1) {
+        lda.setDocConcentration(docConcentration.head)
+      } else {
+        // Auto-set docConcentration
+      }
+    } else {
+      lda.setDocConcentration(docConcentration)
+    }
+
+    val pipeline = new Pipeline().setStages(stages)
+    val model = pipeline.fit(data)
+
+    val vocabulary: Array[String] = featureSchema.dataType match {
+      case d: StringType =>
+        val countVectorModel = model.stages(2).asInstanceOf[CountVectorizerModel]
+        countVectorModel.vocabulary
+      case _ => Array.empty[String]
+    }
+
+    val ldaModel: LDAModel = model.stages.last.asInstanceOf[LDAModel]
+    val preprocessor: PipelineModel =
+      new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", model.stages.dropRight(1))
+
+    val preprocessedData = preprocessor.transform(data)
+
+    new LDAWrapper(
+      model,
+      ldaModel.logLikelihood(preprocessedData),
+      ldaModel.logPerplexity(preprocessedData),
+      vocabulary)
+  }
+
+  override def read: MLReader[LDAWrapper] = new LDAWrapperReader
+
+  override def load(path: String): LDAWrapper = super.load(path)
+
+  class LDAWrapperWriter(instance: LDAWrapper) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("logLikelihood" -> instance.logLikelihood) ~
+        ("logPerplexity" -> instance.logPerplexity) ~
+        ("vocabulary" -> instance.vocabulary.toList)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class LDAWrapperReader extends MLReader[LDAWrapper] {
+
+    override def load(path: String): LDAWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val logLikelihood = (rMetadata \ "logLikelihood").extract[Double]
+      val logPerplexity = (rMetadata \ "logPerplexity").extract[Double]
+      val vocabulary = (rMetadata \ "vocabulary").extract[List[String]].toArray
+
+      val pipeline = PipelineModel.load(pipelinePath)
+      new LDAWrapper(pipeline, logLikelihood, logPerplexity, vocabulary)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b72bb62d/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 88ac26b..e23af51 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -44,6 +44,8 @@ private[r] object RWrappers extends MLReader[Object] {
         GeneralizedLinearRegressionWrapper.load(path)
       case "org.apache.spark.ml.r.KMeansWrapper" =>
         KMeansWrapper.load(path)
+      case "org.apache.spark.ml.r.LDAWrapper" =>
+        LDAWrapper.load(path)
       case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
         IsotonicRegressionWrapper.load(path)
       case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org