You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2017/01/16 14:06:07 UTC

spark git commit: [SPARK-19066][SPARKR] SparkR LDA doesn't set optimizer correctly

Repository: spark
Updated Branches:
  refs/heads/master e635cbb6e -> 12c8c2160


[SPARK-19066][SPARKR] SparkR LDA doesn't set optimizer correctly

## What changes were proposed in this pull request?

spark.lda passes the optimizer ("em" or "online") as a string to the backend. However, LDAWrapper does not set the optimizer based on the value passed from R. As a result, when the optimizer is "em", the `isDistributed` field is FALSE, whereas it should be TRUE according to the Scala code.

In addition, the `summary` method should also return the results that are specific to `DistributedLDAModel`.

## How was this patch tested?
Manually tested by comparing against the Scala example.
Modified the existing unit tests: fixed the incorrect unit test and added the necessary tests for the `summary` method.

Author: wm624@hotmail.com <wm...@hotmail.com>

Closes #16464 from wangmiao1981/new.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/12c8c216
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/12c8c216
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/12c8c216

Branch: refs/heads/master
Commit: 12c8c2160829ad8ccdab1741530361cdabdcd39d
Parents: e635cbb
Author: wm624@hotmail.com <wm...@hotmail.com>
Authored: Mon Jan 16 06:05:59 2017 -0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Mon Jan 16 06:05:59 2017 -0800

----------------------------------------------------------------------
 R/pkg/R/mllib_clustering.R                      | 20 +++++++++++++++++++-
 .../inst/tests/testthat/test_mllib_clustering.R | 16 ++++++++++++++--
 R/pkg/inst/tests/testthat/test_mllib_tree.R     |  1 -
 .../org/apache/spark/ml/r/LDAWrapper.scala      | 10 +++++++++-
 4 files changed, 42 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/12c8c216/R/pkg/R/mllib_clustering.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R
index ca5182d..fb8d9e7 100644
--- a/R/pkg/R/mllib_clustering.R
+++ b/R/pkg/R/mllib_clustering.R
@@ -397,6 +397,13 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"),
 #'         \item{\code{topics}}{top 10 terms and their weights of all topics}
 #'         \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file
 #'               used as training set}
+#'         \item{\code{trainingLogLikelihood}}{Log likelihood of the observed tokens in the training set,
+#'               given the current parameter estimates:
+#'               log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters)
+#'               It is only for distributed LDA model (i.e., optimizer = "em")}
+#'         \item{\code{logPrior}}{Log probability of the current parameter estimate:
+#'               log P(topics, topic distributions for docs | Dirichlet hyperparameters)
+#'               It is only for distributed LDA model (i.e., optimizer = "em")}
 #' @rdname spark.lda
 #' @aliases summary,LDAModel-method
 #' @export
@@ -413,11 +420,22 @@ setMethod("summary", signature(object = "LDAModel"),
             vocabSize <- callJMethod(jobj, "vocabSize")
             topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic))
             vocabulary <- callJMethod(jobj, "vocabulary")
+            trainingLogLikelihood <- if (isDistributed) {
+              callJMethod(jobj, "trainingLogLikelihood")
+            } else {
+              NA
+            }
+            logPrior <- if (isDistributed) {
+              callJMethod(jobj, "logPrior")
+            } else {
+              NA
+            }
             list(docConcentration = unlist(docConcentration),
                  topicConcentration = topicConcentration,
                  logLikelihood = logLikelihood, logPerplexity = logPerplexity,
                  isDistributed = isDistributed, vocabSize = vocabSize,
-                 topics = topics, vocabulary = unlist(vocabulary))
+                 topics = topics, vocabulary = unlist(vocabulary),
+                 trainingLogLikelihood = trainingLogLikelihood, logPrior = logPrior)
           })
 
 #  Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda}

http://git-wip-us.apache.org/repos/asf/spark/blob/12c8c216/R/pkg/inst/tests/testthat/test_mllib_clustering.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
index f013991..cfbdea5 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
@@ -166,12 +166,16 @@ test_that("spark.lda with libsvm", {
   topics <- stats$topicTopTerms
   weights <- stats$topicTopTermsWeights
   vocabulary <- stats$vocabulary
+  trainingLogLikelihood <- stats$trainingLogLikelihood
+  logPrior <- stats$logPrior
 
-  expect_false(isDistributed)
+  expect_true(isDistributed)
   expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
   expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
   expect_equal(vocabSize, 11)
   expect_true(is.null(vocabulary))
+  expect_true(trainingLogLikelihood <= 0 & !is.na(trainingLogLikelihood))
+  expect_true(logPrior <= 0 & !is.na(logPrior))
 
   # Test model save/load
   modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
@@ -181,11 +185,13 @@ test_that("spark.lda with libsvm", {
   model2 <- read.ml(modelPath)
   stats2 <- summary(model2)
 
-  expect_false(stats2$isDistributed)
+  expect_true(stats2$isDistributed)
   expect_equal(logLikelihood, stats2$logLikelihood)
   expect_equal(logPerplexity, stats2$logPerplexity)
   expect_equal(vocabSize, stats2$vocabSize)
   expect_equal(vocabulary, stats2$vocabulary)
+  expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood)
+  expect_equal(logPrior, stats2$logPrior)
 
   unlink(modelPath)
 })
@@ -202,12 +208,16 @@ test_that("spark.lda with text input", {
   topics <- stats$topicTopTerms
   weights <- stats$topicTopTermsWeights
   vocabulary <- stats$vocabulary
+  trainingLogLikelihood <- stats$trainingLogLikelihood
+  logPrior <- stats$logPrior
 
   expect_false(isDistributed)
   expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
   expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
   expect_equal(vocabSize, 10)
   expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")))
+  expect_true(is.na(trainingLogLikelihood))
+  expect_true(is.na(logPrior))
 
   # Test model save/load
   modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp")
@@ -222,6 +232,8 @@ test_that("spark.lda with text input", {
   expect_equal(logPerplexity, stats2$logPerplexity)
   expect_equal(vocabSize, stats2$vocabSize)
   expect_true(all.equal(vocabulary, stats2$vocabulary))
+  expect_true(is.na(stats2$trainingLogLikelihood))
+  expect_true(is.na(stats2$logPrior))
 
   unlink(modelPath)
 })

http://git-wip-us.apache.org/repos/asf/spark/blob/12c8c216/R/pkg/inst/tests/testthat/test_mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R
index 5d13539..e6fda25 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_tree.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R
@@ -126,7 +126,6 @@ test_that("spark.randomForest", {
                                          63.53160, 64.05470, 65.12710, 64.30450,
                                          66.70910, 67.86125, 68.08700, 67.21865,
                                          68.89275, 69.53180, 69.39640, 69.68250),
-
                tolerance = 1e-4)
   stats <- summary(model)
   expect_equal(stats$numTrees, 20)

http://git-wip-us.apache.org/repos/asf/spark/blob/12c8c216/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
index cbe6a70..e096bf1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
@@ -26,7 +26,7 @@ import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkException
 import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
-import org.apache.spark.ml.clustering.{LDA, LDAModel}
+import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA, LDAModel}
 import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param.ParamPair
@@ -45,6 +45,13 @@ private[r] class LDAWrapper private (
   import LDAWrapper._
 
   private val lda: LDAModel = pipeline.stages.last.asInstanceOf[LDAModel]
+
+  // The following variables are called by R-side code only when the LDA model is distributed
+  lazy private val distributedModel =
+    pipeline.stages.last.asInstanceOf[DistributedLDAModel]
+  lazy val trainingLogLikelihood: Double = distributedModel.trainingLogLikelihood
+  lazy val logPrior: Double = distributedModel.logPrior
+
   private val preprocessor: PipelineModel =
     new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", pipeline.stages.dropRight(1))
 
@@ -122,6 +129,7 @@ private[r] object LDAWrapper extends MLReadable[LDAWrapper] {
       .setK(k)
       .setMaxIter(maxIter)
       .setSubsamplingRate(subsamplingRate)
+      .setOptimizer(optimizer)
 
     val featureSchema = data.schema(features)
     val stages = featureSchema.dataType match {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org