You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/07/30 04:02:19 UTC
spark git commit: [SPARK-9440] [MLLIB] Add hyperparameters to LocalLDAModel save/load
Repository: spark
Updated Branches:
refs/heads/master 2a9fe4a4e -> a200e6456
[SPARK-9440] [MLLIB] Add hyperparameters to LocalLDAModel save/load
jkbradley MechCoder
Resolves blocking issue for SPARK-6793. Please review after #7705 is merged.
Author: Feynman Liang <fl...@databricks.com>
Closes #7757 from feynmanliang/SPARK-9940-localSaveLoad and squashes the following commits:
d0d8cf4 [Feynman Liang] Fix thisClassName
0f30109 [Feynman Liang] Fix tests after changing LDAModel public API
dc61981 [Feynman Liang] Add hyperparams to LocalLDAModel save/load
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a200e645
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a200e645
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a200e645
Branch: refs/heads/master
Commit: a200e64561c8803731578267df16906f6773cbea
Parents: 2a9fe4a
Author: Feynman Liang <fl...@databricks.com>
Authored: Wed Jul 29 19:02:15 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Wed Jul 29 19:02:15 2015 -0700
----------------------------------------------------------------------
.../spark/mllib/clustering/LDAModel.scala | 40 ++++++++++++++------
.../spark/mllib/clustering/LDASuite.scala | 6 ++-
2 files changed, 33 insertions(+), 13 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a200e645/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 059b52e..ece2884 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -215,7 +215,8 @@ class LocalLDAModel private[clustering] (
override protected def formatVersion = "1.0"
override def save(sc: SparkContext, path: String): Unit = {
- LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix)
+ LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
+ gammaShape)
}
// TODO
// override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
@@ -312,16 +313,23 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
// as a Row in data.
case class Data(topic: Vector, index: Int)
- // TODO: explicitly save docConcentration, topicConcentration, and gammaShape for use in
- // model.predict()
- def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
+ def save(
+ sc: SparkContext,
+ path: String,
+ topicsMatrix: Matrix,
+ docConcentration: Vector,
+ topicConcentration: Double,
+ gammaShape: Double): Unit = {
val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
val k = topicsMatrix.numCols
val metadata = compact(render
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
- ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
+ ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~
+ ("docConcentration" -> docConcentration.toArray.toSeq) ~
+ ("topicConcentration" -> topicConcentration) ~
+ ("gammaShape" -> gammaShape)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
@@ -331,7 +339,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
}
- def load(sc: SparkContext, path: String): LocalLDAModel = {
+ def load(
+ sc: SparkContext,
+ path: String,
+ docConcentration: Vector,
+ topicConcentration: Double,
+ gammaShape: Double): LocalLDAModel = {
val dataPath = Loader.dataPath(path)
val sqlContext = SQLContext.getOrCreate(sc)
val dataFrame = sqlContext.read.parquet(dataPath)
@@ -348,8 +361,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
val topicsMat = Matrices.fromBreeze(brzTopics)
// TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940
- new LocalLDAModel(topicsMat,
- Vectors.dense(Array.fill(topicsMat.numRows)(1.0 / topicsMat.numRows)), 1D, 100D)
+ new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape)
}
}
@@ -358,11 +370,15 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
implicit val formats = DefaultFormats
val expectedK = (metadata \ "k").extract[Int]
val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
+ val docConcentration =
+ Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
+ val topicConcentration = (metadata \ "topicConcentration").extract[Double]
+ val gammaShape = (metadata \ "gammaShape").extract[Double]
val classNameV1_0 = SaveLoadV1_0.thisClassName
val model = (loadedClassName, loadedVersion) match {
case (className, "1.0") if className == classNameV1_0 =>
- SaveLoadV1_0.load(sc, path)
+ SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape)
case _ => throw new Exception(
s"LocalLDAModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $loadedVersion). Supported:\n" +
@@ -565,7 +581,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
val thisFormatVersion = "1.0"
- val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
+ val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel"
// Store globalTopicTotals as a Vector.
case class Data(globalTopicTotals: Vector)
@@ -591,7 +607,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
import sqlContext.implicits._
val metadata = compact(render
- (("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~
+ (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("k" -> k) ~ ("vocabSize" -> vocabSize) ~
("docConcentration" -> docConcentration.toArray.toSeq) ~
("topicConcentration" -> topicConcentration) ~
@@ -660,7 +676,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
val topicConcentration = (metadata \ "topicConcentration").extract[Double]
val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
val gammaShape = (metadata \ "gammaShape").extract[Double]
- val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
val model = (loadedClassName, loadedVersion) match {
case (className, "1.0") if className == classNameV1_0 => {
http://git-wip-us.apache.org/repos/asf/spark/blob/a200e645/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index aa36336..b91c7ce 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -334,7 +334,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
test("model save/load") {
// Test for LocalLDAModel.
val localModel = new LocalLDAModel(tinyTopics,
- Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D)
+ Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D)
val tempDir1 = Utils.createTempDir()
val path1 = tempDir1.toURI.toString
@@ -360,6 +360,9 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(samelocalModel.topicsMatrix === localModel.topicsMatrix)
assert(samelocalModel.k === localModel.k)
assert(samelocalModel.vocabSize === localModel.vocabSize)
+ assert(samelocalModel.docConcentration === localModel.docConcentration)
+ assert(samelocalModel.topicConcentration === localModel.topicConcentration)
+ assert(samelocalModel.gammaShape === localModel.gammaShape)
val sameDistributedModel = DistributedLDAModel.load(sc, path2)
assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
@@ -368,6 +371,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
assert(distributedModel.docConcentration === sameDistributedModel.docConcentration)
assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration)
+ assert(distributedModel.gammaShape === sameDistributedModel.gammaShape)
assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals)
val graph = distributedModel.graph
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org