You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/04/29 19:43:03 UTC

spark git commit: [SPARK-11940][PYSPARK][ML] Python API for ml.clustering.LDA PR2

Repository: spark
Updated Branches:
  refs/heads/master f08dcdb8d -> 775772de3


[SPARK-11940][PYSPARK][ML] Python API for ml.clustering.LDA PR2

## What changes were proposed in this pull request?

pyspark.ml API for LDA
* LDA, LDAModel, LocalLDAModel, DistributedLDAModel
* includes persistence

This replaces [https://github.com/apache/spark/pull/10242]

## How was this patch tested?

* doc test for LDA, including Param setters
* unit test for persistence

Author: Joseph K. Bradley <jo...@databricks.com>
Author: Jeff Zhang <zj...@apache.org>

Closes #12723 from jkbradley/zjffdu-SPARK-11940.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/775772de
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/775772de
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/775772de

Branch: refs/heads/master
Commit: 775772de36d5b7e80595aad850aa1dcea8791688
Parents: f08dcdb
Author: Jeff Zhang <zj...@gmail.com>
Authored: Fri Apr 29 10:42:52 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Fri Apr 29 10:42:52 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/clustering/LDA.scala    |   7 +-
 python/pyspark/ml/clustering.py                 | 488 ++++++++++++++++++-
 python/pyspark/ml/tests.py                      |  57 ++-
 3 files changed, 546 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/775772de/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 1554d56..38ecc5a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -355,7 +355,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
  * :: Experimental ::
  * Model fitted by [[LDA]].
  *
- * @param vocabSize  Vocabulary size (number of terms or terms in the vocabulary)
+ * @param vocabSize  Vocabulary size (number of terms or words in the vocabulary)
  * @param sparkSession  Used to construct local DataFrames for returning query results
  */
 @Since("1.6.0")
@@ -745,9 +745,8 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
  *  - "topic": multinomial distribution over terms representing some concept
  *  - "document": one piece of text, corresponding to one row in the input data
  *
- * References:
- *  - Original LDA paper (journal version):
- *    Blei, Ng, and Jordan.  "Latent Dirichlet Allocation."  JMLR, 2003.
+ * Original LDA paper (journal version):
+ *  Blei, Ng, and Jordan.  "Latent Dirichlet Allocation."  JMLR, 2003.
  *
  * Input data (featuresCol):
  *  LDA is given a collection of documents as input data, via the featuresCol parameter.

http://git-wip-us.apache.org/repos/asf/spark/blob/775772de/python/pyspark/ml/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 16ce02e..50ebf4f 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -23,7 +23,8 @@ from pyspark.mllib.common import inherit_doc
 
 __all__ = ['BisectingKMeans', 'BisectingKMeansModel',
            'KMeans', 'KMeansModel',
-           'GaussianMixture', 'GaussianMixtureModel']
+           'GaussianMixture', 'GaussianMixtureModel',
+           'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel']
 
 
 class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
@@ -450,6 +451,491 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
         return BisectingKMeansModel(java_model)
 
 
+@inherit_doc
+class LDAModel(JavaModel):
+    """
+    .. note:: Experimental
+
+    Latent Dirichlet Allocation (LDA) model.
+    This abstraction permits for different underlying representations,
+    including local and distributed data structures.
+
+    .. versionadded:: 2.0.0
+    """
+
+    @since("2.0.0")
+    def isDistributed(self):
+        """
+        Indicates whether this instance is of type DistributedLDAModel
+        """
+        return self._call_java("isDistributed")
+
+    @since("2.0.0")
+    def vocabSize(self):
+        """Vocabulary size (number of terms or words in the vocabulary)"""
+        return self._call_java("vocabSize")
+
+    @since("2.0.0")
+    def topicsMatrix(self):
+        """
+        Inferred topics, where each topic is represented by a distribution over terms.
+        This is a matrix of size vocabSize x k, where each column is a topic.
+        No guarantees are given about the ordering of the topics.
+
+        WARNING: If this model is actually a :py:class:`DistributedLDAModel` instance produced by
+        the Expectation-Maximization ("em") `optimizer`, then this method could involve
+        collecting a large amount of data to the driver (on the order of vocabSize x k).
+        """
+        return self._call_java("topicsMatrix")
+
+    @since("2.0.0")
+    def logLikelihood(self, dataset):
+        """
+        Calculates a lower bound on the log likelihood of the entire corpus.
+        See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
+
+        WARNING: If this model is an instance of :py:class:`DistributedLDAModel` (produced when
+        :py:attr:`optimizer` is set to "em"), this involves collecting a large
+        :py:func:`topicsMatrix` to the driver. This implementation may be changed in the future.
+        """
+        return self._call_java("logLikelihood", dataset)
+
+    @since("2.0.0")
+    def logPerplexity(self, dataset):
+        """
+        Calculate an upper bound on perplexity.  (Lower is better.)
+        See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
+
+        WARNING: If this model is an instance of :py:class:`DistributedLDAModel` (produced when
+        :py:attr:`optimizer` is set to "em"), this involves collecting a large
+        :py:func:`topicsMatrix` to the driver. This implementation may be changed in the future.
+        """
+        return self._call_java("logPerplexity", dataset)
+
+    @since("2.0.0")
+    def describeTopics(self, maxTermsPerTopic=10):
+        """
+        Return a DataFrame (topic, termIndices, termWeights) of each topic's top-weighted terms.
+        """
+        return self._call_java("describeTopics", maxTermsPerTopic)
+
+    @since("2.0.0")
+    def estimatedDocConcentration(self):
+        """
+        Value for :py:attr:`LDA.docConcentration` estimated from data.
+        If Online LDA was used and :py:attr:`LDA.optimizeDocConcentration` was set to false,
+        then this returns the fixed (given) value for the :py:attr:`LDA.docConcentration` parameter.
+        """
+        return self._call_java("estimatedDocConcentration")
+
+
+@inherit_doc
+class DistributedLDAModel(LDAModel, JavaMLReadable, JavaMLWritable):
+    """
+    .. note:: Experimental
+
+    Distributed model fitted by :py:class:`LDA`.
+    This type of model is currently only produced by Expectation-Maximization (EM).
+
+    This model stores the inferred topics, the full training dataset, and the topic distribution
+    for each training document.
+
+    .. versionadded:: 2.0.0
+    """
+
+    @since("2.0.0")
+    def toLocal(self):
+        """
+        Convert this distributed model to a local representation.  This discards info about the
+        training dataset.
+
+        WARNING: This involves collecting a large :py:func:`topicsMatrix` to the driver.
+        """
+        return LocalLDAModel(self._call_java("toLocal"))
+
+    @since("2.0.0")
+    def trainingLogLikelihood(self):
+        """
+        Log likelihood of the observed tokens in the training set,
+        given the current parameter estimates:
+        log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters)
+
+        Notes:
+          - This excludes the prior; for that, use :py:func:`logPrior`.
+          - Even with :py:func:`logPrior`, this is NOT the same as the data log likelihood given
+            the hyperparameters.
+          - This is computed from the topic distributions computed during training. If you call
+            :py:func:`logLikelihood` on the same training dataset, the topic distributions
+            will be computed again, possibly giving different results.
+        """
+        return self._call_java("trainingLogLikelihood")
+
+    @since("2.0.0")
+    def logPrior(self):
+        """
+        Log probability of the current parameter estimate:
+        log P(topics, topic distributions for docs | alpha, eta)
+        """
+        return self._call_java("logPrior")
+
+    @since("2.0.0")
+    def getCheckpointFiles(self):
+        """
+        If using checkpointing and :py:attr:`LDA.keepLastCheckpoint` is set to true, then there may
+        be saved checkpoint files.  This method is provided so that users can manage those files.
+
+        Note that removing the checkpoints can cause failures if a partition is lost and is needed
+        by certain :py:class:`DistributedLDAModel` methods.  Reference counting will clean up the
+        checkpoints when this model and derivative data go out of scope.
+
+        :return: List of checkpoint files from training
+        """
+        return self._call_java("getCheckpointFiles")
+
+
+@inherit_doc
+class LocalLDAModel(LDAModel, JavaMLReadable, JavaMLWritable):
+    """
+    .. note:: Experimental
+
+    Local (non-distributed) model fitted by :py:class:`LDA`.
+    This model stores the inferred topics only; it does not store info about the training dataset.
+
+    .. versionadded:: 2.0.0
+    """
+    pass
+
+
+@inherit_doc
+class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInterval,
+          JavaMLReadable, JavaMLWritable):
+    """
+    .. note:: Experimental
+
+    Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
+
+    Terminology:
+
+     - "term" = "word": an element of the vocabulary
+     - "token": instance of a term appearing in a document
+     - "topic": multinomial distribution over terms representing some concept
+     - "document": one piece of text, corresponding to one row in the input data
+
+    Original LDA paper (journal version):
+      Blei, Ng, and Jordan.  "Latent Dirichlet Allocation."  JMLR, 2003.
+
+    Input data (featuresCol):
+    LDA is given a collection of documents as input data, via the featuresCol parameter.
+    Each document is specified as a :py:class:`Vector` of length vocabSize, where each entry is the
+    count for the corresponding term (word) in the document.  Feature transformers such as
+    :py:class:`pyspark.ml.feature.Tokenizer` and :py:class:`pyspark.ml.feature.CountVectorizer`
+    can be useful for converting text to word count vectors.
+
+    >>> from pyspark.mllib.linalg import Vectors, SparseVector
+    >>> from pyspark.ml.clustering import LDA
+    >>> df = sqlContext.createDataFrame([[1, Vectors.dense([0.0, 1.0])],
+    ...      [2, SparseVector(2, {0: 1.0})],], ["id", "features"])
+    >>> lda = LDA(k=2, seed=1, optimizer="em")
+    >>> model = lda.fit(df)
+    >>> model.isDistributed()
+    True
+    >>> localModel = model.toLocal()
+    >>> localModel.isDistributed()
+    False
+    >>> model.vocabSize()
+    2
+    >>> model.describeTopics().show()
+    +-----+-----------+--------------------+
+    |topic|termIndices|         termWeights|
+    +-----+-----------+--------------------+
+    |    0|     [1, 0]|[0.50401530077160...|
+    |    1|     [0, 1]|[0.50401530077160...|
+    +-----+-----------+--------------------+
+    ...
+    >>> model.topicsMatrix()
+    DenseMatrix(2, 2, [0.496, 0.504, 0.504, 0.496], 0)
+    >>> lda_path = temp_path + "/lda"
+    >>> lda.save(lda_path)
+    >>> sameLDA = LDA.load(lda_path)
+    >>> distributed_model_path = temp_path + "/lda_distributed_model"
+    >>> model.save(distributed_model_path)
+    >>> sameModel = DistributedLDAModel.load(distributed_model_path)
+    >>> local_model_path = temp_path + "/lda_local_model"
+    >>> localModel.save(local_model_path)
+    >>> sameLocalModel = LocalLDAModel.load(local_model_path)
+
+    .. versionadded:: 2.0.0
+    """
+
+    k = Param(Params._dummy(), "k", "number of topics (clusters) to infer",
+              typeConverter=TypeConverters.toInt)
+    optimizer = Param(Params._dummy(), "optimizer",
+                      "Optimizer or inference algorithm used to estimate the LDA model.  "
+                      "Supported: online, em", typeConverter=TypeConverters.toString)
+    learningOffset = Param(Params._dummy(), "learningOffset",
+                           "A (positive) learning parameter that downweights early iterations."
+                           " Larger values make early iterations count less",
+                           typeConverter=TypeConverters.toFloat)
+    learningDecay = Param(Params._dummy(), "learningDecay", "Learning rate, set as an "
+                          "exponential decay rate. This should be between (0.5, 1.0] to "
+                          "guarantee asymptotic convergence.", typeConverter=TypeConverters.toFloat)
+    subsamplingRate = Param(Params._dummy(), "subsamplingRate",
+                            "Fraction of the corpus to be sampled and used in each iteration "
+                            "of mini-batch gradient descent, in range (0, 1].",
+                            typeConverter=TypeConverters.toFloat)
+    optimizeDocConcentration = Param(Params._dummy(), "optimizeDocConcentration",
+                                     "Indicates whether the docConcentration (Dirichlet parameter "
+                                     "for document-topic distribution) will be optimized during "
+                                     "training.", typeConverter=TypeConverters.toBoolean)
+    docConcentration = Param(Params._dummy(), "docConcentration",
+                             "Concentration parameter (commonly named \"alpha\") for the "
+                             "prior placed on documents' distributions over topics (\"theta\").",
+                             typeConverter=TypeConverters.toListFloat)
+    topicConcentration = Param(Params._dummy(), "topicConcentration",
+                               "Concentration parameter (commonly named \"beta\" or \"eta\") for "
+                               "the prior placed on topics' distributions over terms.",
+                               typeConverter=TypeConverters.toFloat)
+    topicDistributionCol = Param(Params._dummy(), "topicDistributionCol",
+                                 "Output column with estimates of the topic mixture distribution "
+                                 "for each document (often called \"theta\" in the literature). "
+                                 "Returns a vector of zeros for an empty document.",
+                                 typeConverter=TypeConverters.toString)
+    keepLastCheckpoint = Param(Params._dummy(), "keepLastCheckpoint",
+                               "(For EM optimizer) If using checkpointing, this indicates whether"
+                               " to keep the last checkpoint. If false, then the checkpoint will be"
+                               " deleted. Deleting the checkpoint can cause failures if a data"
+                               " partition is lost, so set this bit with care.",
+                               typeConverter=TypeConverters.toBoolean)
+
+    @keyword_only
+    def __init__(self, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,
+                 k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,
+                 subsamplingRate=0.05, optimizeDocConcentration=True,
+                 docConcentration=None, topicConcentration=None,
+                 topicDistributionCol="topicDistribution", keepLastCheckpoint=True):
+        """
+        __init__(self, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,\
+                  k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\
+                  subsamplingRate=0.05, optimizeDocConcentration=True,\
+                  docConcentration=None, topicConcentration=None,\
+                  topicDistributionCol="topicDistribution", keepLastCheckpoint=True):
+        """
+        super(LDA, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.LDA", self.uid)
+        self._setDefault(maxIter=20, checkpointInterval=10,
+                         k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,
+                         subsamplingRate=0.05, optimizeDocConcentration=True,
+                         topicDistributionCol="topicDistribution", keepLastCheckpoint=True)
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    def _create_model(self, java_model):
+        # Let the JVM decide the model type instead of re-parsing the optimizer string.
+        if java_model.isDistributed():
+            return DistributedLDAModel(java_model)
+        return LocalLDAModel(java_model)
+
+    @keyword_only
+    @since("2.0.0")
+    def setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,
+                  k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,
+                  subsamplingRate=0.05, optimizeDocConcentration=True,
+                  docConcentration=None, topicConcentration=None,
+                  topicDistributionCol="topicDistribution", keepLastCheckpoint=True):
+        """
+        setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,\
+                  k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\
+                  subsamplingRate=0.05, optimizeDocConcentration=True,\
+                  docConcentration=None, topicConcentration=None,\
+                  topicDistributionCol="topicDistribution", keepLastCheckpoint=True):
+
+        Sets params for LDA.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    @since("2.0.0")
+    def setK(self, value):
+        """
+        Sets the value of :py:attr:`k`.
+
+        >>> algo = LDA().setK(10)
+        >>> algo.getK()
+        10
+        """
+        return self._set(k=value)
+
+    @since("2.0.0")
+    def getK(self):
+        """
+        Gets the value of :py:attr:`k` or its default value.
+        """
+        return self.getOrDefault(self.k)
+
+    @since("2.0.0")
+    def setOptimizer(self, value):
+        """
+        Sets the value of :py:attr:`optimizer`.
+        Currently only 'em' and 'online' are supported.
+
+        >>> algo = LDA().setOptimizer("em")
+        >>> algo.getOptimizer()
+        'em'
+        """
+        return self._set(optimizer=value)
+
+    @since("2.0.0")
+    def getOptimizer(self):
+        """
+        Gets the value of :py:attr:`optimizer` or its default value.
+        """
+        return self.getOrDefault(self.optimizer)
+
+    @since("2.0.0")
+    def setLearningOffset(self, value):
+        """
+        Sets the value of :py:attr:`learningOffset`.
+
+        >>> algo = LDA().setLearningOffset(100)
+        >>> algo.getLearningOffset()
+        100.0
+        """
+        return self._set(learningOffset=value)
+
+    @since("2.0.0")
+    def getLearningOffset(self):
+        """
+        Gets the value of :py:attr:`learningOffset` or its default value.
+        """
+        return self.getOrDefault(self.learningOffset)
+
+    @since("2.0.0")
+    def setLearningDecay(self, value):
+        """
+        Sets the value of :py:attr:`learningDecay`.
+
+        >>> algo = LDA().setLearningDecay(0.1)
+        >>> algo.getLearningDecay()
+        0.1...
+        """
+        return self._set(learningDecay=value)
+
+    @since("2.0.0")
+    def getLearningDecay(self):
+        """
+        Gets the value of :py:attr:`learningDecay` or its default value.
+        """
+        return self.getOrDefault(self.learningDecay)
+
+    @since("2.0.0")
+    def setSubsamplingRate(self, value):
+        """
+        Sets the value of :py:attr:`subsamplingRate`.
+
+        >>> algo = LDA().setSubsamplingRate(0.1)
+        >>> algo.getSubsamplingRate()
+        0.1...
+        """
+        return self._set(subsamplingRate=value)
+
+    @since("2.0.0")
+    def getSubsamplingRate(self):
+        """
+        Gets the value of :py:attr:`subsamplingRate` or its default value.
+        """
+        return self.getOrDefault(self.subsamplingRate)
+
+    @since("2.0.0")
+    def setOptimizeDocConcentration(self, value):
+        """
+        Sets the value of :py:attr:`optimizeDocConcentration`.
+
+        >>> algo = LDA().setOptimizeDocConcentration(True)
+        >>> algo.getOptimizeDocConcentration()
+        True
+        """
+        return self._set(optimizeDocConcentration=value)
+
+    @since("2.0.0")
+    def getOptimizeDocConcentration(self):
+        """
+        Gets the value of :py:attr:`optimizeDocConcentration` or its default value.
+        """
+        return self.getOrDefault(self.optimizeDocConcentration)
+
+    @since("2.0.0")
+    def setDocConcentration(self, value):
+        """
+        Sets the value of :py:attr:`docConcentration`.
+
+        >>> algo = LDA().setDocConcentration([0.1, 0.2])
+        >>> algo.getDocConcentration()
+        [0.1..., 0.2...]
+        """
+        return self._set(docConcentration=value)
+
+    @since("2.0.0")
+    def getDocConcentration(self):
+        """
+        Gets the value of :py:attr:`docConcentration` or its default value.
+        """
+        return self.getOrDefault(self.docConcentration)
+
+    @since("2.0.0")
+    def setTopicConcentration(self, value):
+        """
+        Sets the value of :py:attr:`topicConcentration`.
+
+        >>> algo = LDA().setTopicConcentration(0.5)
+        >>> algo.getTopicConcentration()
+        0.5...
+        """
+        return self._set(topicConcentration=value)
+
+    @since("2.0.0")
+    def getTopicConcentration(self):
+        """
+        Gets the value of :py:attr:`topicConcentration` or its default value.
+        """
+        return self.getOrDefault(self.topicConcentration)
+
+    @since("2.0.0")
+    def setTopicDistributionCol(self, value):
+        """
+        Sets the value of :py:attr:`topicDistributionCol`.
+
+        >>> algo = LDA().setTopicDistributionCol("topicDistributionCol")
+        >>> algo.getTopicDistributionCol()
+        'topicDistributionCol'
+        """
+        return self._set(topicDistributionCol=value)
+
+    @since("2.0.0")
+    def getTopicDistributionCol(self):
+        """
+        Gets the value of :py:attr:`topicDistributionCol` or its default value.
+        """
+        return self.getOrDefault(self.topicDistributionCol)
+
+    @since("2.0.0")
+    def setKeepLastCheckpoint(self, value):
+        """
+        Sets the value of :py:attr:`keepLastCheckpoint`.
+
+        >>> algo = LDA().setKeepLastCheckpoint(False)
+        >>> algo.getKeepLastCheckpoint()
+        False
+        """
+        return self._set(keepLastCheckpoint=value)
+
+    @since("2.0.0")
+    def getKeepLastCheckpoint(self):
+        """
+        Gets the value of :py:attr:`keepLastCheckpoint` or its default value.
+        """
+        return self.getOrDefault(self.keepLastCheckpoint)
+
+
 if __name__ == "__main__":
     import doctest
     import pyspark.ml.clustering

http://git-wip-us.apache.org/repos/asf/spark/blob/775772de/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 36cecd4..e7d4c0a 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -46,7 +46,7 @@ from pyspark import keyword_only
 from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
 from pyspark.ml.classification import (
     LogisticRegression, DecisionTreeClassifier, OneVsRest, OneVsRestModel)
-from pyspark.ml.clustering import KMeans
+from pyspark.ml.clustering import *
 from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
 from pyspark.ml.feature import *
 from pyspark.ml.param import Param, Params, TypeConverters
@@ -809,6 +809,61 @@ class PersistenceTest(PySparkTestCase):
             pass
 
 
+class LDATest(PySparkTestCase):
+
+    def _compare(self, m1, m2):
+        """
+        Temp method for comparing instances.
+        TODO: Replace with generic implementation once SPARK-14706 is merged.
+        """
+        self.assertEqual(m1.uid, m2.uid)
+        self.assertEqual(type(m1), type(m2))
+        self.assertEqual(len(m1.params), len(m2.params))
+        for p in m1.params:
+            if m1.isDefined(p):
+                self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
+                self.assertEqual(p.parent, m2.getParam(p.name).parent)
+        if isinstance(m1, LDAModel):
+            self.assertEqual(m1.isDistributed(), m2.isDistributed())
+            self.assertEqual(m1.vocabSize(), m2.vocabSize())
+            self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix())
+
+    def test_persistence(self):
+        """
+        Test save/load round trips for LDA, LocalLDAModel and DistributedLDAModel.
+        """
+        sqlContext = SQLContext(self.sc)
+        df = sqlContext.createDataFrame([
+            [1, Vectors.dense([0.0, 1.0])],
+            [2, Vectors.sparse(2, {0: 1.0})],
+        ], ["id", "features"])
+        # Fit model
+        lda = LDA(k=2, seed=1, optimizer="em")
+        distributedModel = lda.fit(df)
+        self.assertTrue(distributedModel.isDistributed())
+        localModel = distributedModel.toLocal()
+        self.assertFalse(localModel.isDistributed())
+        # Define paths. Registering rmtree with addCleanup guarantees the temp dir is
+        # removed (errors ignored) even when a save/load assertion below fails.
+        path = tempfile.mkdtemp()
+        self.addCleanup(rmtree, path, ignore_errors=True)
+        lda_path = path + "/lda"
+        dist_model_path = path + "/distLDAModel"
+        local_model_path = path + "/localLDAModel"
+        # Test LDA
+        lda.save(lda_path)
+        lda2 = LDA.load(lda_path)
+        self._compare(lda, lda2)
+        # Test DistributedLDAModel
+        distributedModel.save(dist_model_path)
+        distributedModel2 = DistributedLDAModel.load(dist_model_path)
+        self._compare(distributedModel, distributedModel2)
+        # Test LocalLDAModel
+        localModel.save(local_model_path)
+        localModel2 = LocalLDAModel.load(local_model_path)
+        self._compare(localModel, localModel2)
+
+
 class TrainingSummaryTest(PySparkTestCase):
 
     def test_linear_regression_summary(self):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org