You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/04/01 01:01:12 UTC
spark git commit: [SPARK-5692] [MLlib] Word2Vec save/load
Repository: spark
Updated Branches:
refs/heads/master 2036bc599 -> 0e00f12d3
[SPARK-5692] [MLlib] Word2Vec save/load
Word2Vec model now supports saving and loading.
a] The Metadata stored in JSON format consists of "class", "version", "vectorSize" and "numWords"
b] The data stored in Parquet file format consists of an Array of rows with each row consisting of 2 columns, first being the word: String and the second, an Array of Floats.
Author: MechCoder <ma...@gmail.com>
Closes #5291 from MechCoder/spark-5692 and squashes the following commits:
1142f3a [MechCoder] Add numWords to metaData
bfe4c39 [MechCoder] [SPARK-5692] Word2Vec save/load
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0e00f12d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0e00f12d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0e00f12d
Branch: refs/heads/master
Commit: 0e00f12d33d28d064c166262b14e012a1aeaa7b0
Parents: 2036bc5
Author: MechCoder <ma...@gmail.com>
Authored: Tue Mar 31 16:01:08 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Mar 31 16:01:08 2015 -0700
----------------------------------------------------------------------
.../apache/spark/mllib/feature/Word2Vec.scala | 87 +++++++++++++++++++-
.../spark/mllib/feature/Word2VecSuite.scala | 26 ++++++
2 files changed, 110 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/0e00f12d/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 59a79e5..9ee7e4a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -25,14 +25,21 @@ import scala.collection.mutable.ArrayBuilder
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.Logging
+import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.sql.{SQLContext, Row}
/**
* Entry in vocabulary
@@ -422,7 +429,7 @@ class Word2Vec extends Serializable with Logging {
*/
@Experimental
class Word2VecModel private[mllib] (
- private val model: Map[String, Array[Float]]) extends Serializable {
+ private val model: Map[String, Array[Float]]) extends Serializable with Saveable {
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
@@ -432,7 +439,13 @@ class Word2VecModel private[mllib] (
if (norm1 == 0 || norm2 == 0) return 0.0
blas.sdot(n, v1, 1, v2,1) / norm1 / norm2
}
-
+
+ override protected def formatVersion = "1.0"
+
+ def save(sc: SparkContext, path: String): Unit = {
+ Word2VecModel.SaveLoadV1_0.save(sc, path, model)
+ }
+
/**
* Transforms a word to its vector representation
* @param word a word
@@ -475,7 +488,7 @@ class Word2VecModel private[mllib] (
.tail
.toArray
}
-
+
/**
* Returns a map of words to their vector representations.
*/
@@ -483,3 +496,71 @@ class Word2VecModel private[mllib] (
model
}
}
+
+@Experimental
+object Word2VecModel extends Loader[Word2VecModel] {
+
+ private object SaveLoadV1_0 {
+
+ val formatVersionV1_0 = "1.0"
+
+ val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel"
+
+ case class Data(word: String, vector: Array[Float])
+
+ def load(sc: SparkContext, path: String): Word2VecModel = {
+ val dataPath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ val dataFrame = sqlContext.parquetFile(dataPath)
+
+ val dataArray = dataFrame.select("word", "vector").collect()
+
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ Loader.checkSchema[Data](dataFrame.schema)
+
+ val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap
+ new Word2VecModel(word2VecMap)
+ }
+
+ def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]) = {
+
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ val vectorSize = model.values.head.size
+ val numWords = model.size
+ val metadata = compact(render
+ (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
+ ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ val dataArray = model.toSeq.map { case (w, v) => Data(w, v) }
+ sc.parallelize(dataArray.toSeq, 1).toDF().saveAsParquetFile(Loader.dataPath(path))
+ }
+ }
+
+ override def load(sc: SparkContext, path: String): Word2VecModel = {
+
+ val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
+ implicit val formats = DefaultFormats
+ val expectedVectorSize = (metadata \ "vectorSize").extract[Int]
+ val expectedNumWords = (metadata \ "numWords").extract[Int]
+ val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+ (loadedClassName, loadedVersion) match {
+ case (classNameV1_0, "1.0") =>
+ val model = SaveLoadV1_0.load(sc, path)
+ val vectorSize = model.getVectors.values.head.size
+ val numWords = model.getVectors.size
+ require(expectedVectorSize == vectorSize,
+ s"Word2VecModel requires each word to be mapped to a vector of size " +
+ s"$expectedVectorSize, got vector of size $vectorSize")
+ require(expectedNumWords == numWords,
+ s"Word2VecModel requires $expectedNumWords words, but got $numWords")
+ model
+ case _ => throw new Exception(
+ s"Word2VecModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $loadedVersion). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/0e00f12d/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 5227869..98a98a7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -21,6 +21,9 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
+
class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
// TODO: add more tests
@@ -51,4 +54,27 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
assert(syms(0)._1 == "taiwan")
assert(syms(1)._1 == "japan")
}
+
+ test("model load / save") {
+
+ val word2VecMap = Map(
+ ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
+ ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
+ ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+ ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
+ )
+ val model = new Word2VecModel(word2VecMap)
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ try {
+ model.save(sc, path)
+ val sameModel = Word2VecModel.load(sc, path)
+ assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq))
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org