You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/04/22 01:42:51 UTC

spark git commit: [SPARK-6065] [MLlib] Optimize word2vec.findSynonyms using blas calls

Repository: spark
Updated Branches:
  refs/heads/master a70e849c7 -> 7fe6142cd


[SPARK-6065] [MLlib] Optimize word2vec.findSynonyms using blas calls

1. Use blas calls to find the dot product between two vectors.
2. Prevent re-computing the L2 norm of the given vector for each word in model.

Author: MechCoder <ma...@gmail.com>

Closes #5467 from MechCoder/spark-6065 and squashes the following commits:

dd0b0b2 [MechCoder] Preallocate wordVectors
ffc9240 [MechCoder] Minor
6b74c81 [MechCoder] Switch back to native blas calls
da1642d [MechCoder] Explicit types and indexing
64575b0 [MechCoder] Save indexedmap and a wordvecmat instead of matrix
fbe0108 [MechCoder] Made the following changes 1. Calculate norms during initialization. 2. Use Blas calls from linalg.blas
1350cf3 [MechCoder] [SPARK-6065] Optimize word2vec.findSynonynms using blas calls


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7fe6142c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7fe6142c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7fe6142c

Branch: refs/heads/master
Commit: 7fe6142cd3c39ec79899878c3deca9d5130d05b1
Parents: a70e849
Author: MechCoder <ma...@gmail.com>
Authored: Tue Apr 21 16:42:45 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Apr 21 16:42:45 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/feature/Word2Vec.scala   | 57 +++++++++++++++++---
 1 file changed, 51 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7fe6142c/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index b2d9053..98e8311 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -34,7 +34,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, BLAS, DenseVector}
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd._
 import org.apache.spark.util.Utils
@@ -429,7 +429,36 @@ class Word2Vec extends Serializable with Logging {
  */
 @Experimental
 class Word2VecModel private[mllib] (
-    private val model: Map[String, Array[Float]]) extends Serializable with Saveable {
+    model: Map[String, Array[Float]]) extends Serializable with Saveable {
+
+  // wordList: Ordered list of words obtained from model.
+  private val wordList: Array[String] = model.keys.toArray
+
+  // wordIndex: Maps each word to an index, which can retrieve the corresponding
+  //            vector from wordVectors (see below).
+  private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
+
+  // vectorSize: Dimension of each word's vector.
+  private val vectorSize = model.head._2.size
+  private val numWords = wordIndex.size
+
+  // wordVectors: Array of length numWords * vectorSize; the vector of the word mapped to
+  //              index ind can be retrieved by the slice
+  //              (ind * vectorSize, ind * vectorSize + vectorSize)
+  // wordVecNorms: Array of length numWords, each value being the Euclidean norm
+  //               of the wordVector.
+  private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
+    val wordVectors = new Array[Float](vectorSize * numWords)
+    val wordVecNorms = new Array[Double](numWords)
+    var i = 0
+    while (i < numWords) {
+      val vec = model.get(wordList(i)).get
+      Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize)
+      wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
+      i += 1
+    }
+    (wordVectors, wordVecNorms)
+  }
 
   private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
@@ -443,7 +472,7 @@ class Word2VecModel private[mllib] (
   override protected def formatVersion = "1.0"
 
   def save(sc: SparkContext, path: String): Unit = {
-    Word2VecModel.SaveLoadV1_0.save(sc, path, model)
+    Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors)
   }
 
   /**
@@ -479,9 +508,23 @@ class Word2VecModel private[mllib] (
    */
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
-    // TODO: optimize top-k
+
     val fVector = vector.toArray.map(_.toFloat)
-    model.mapValues(vec => cosineSimilarity(fVector, vec))
+    val cosineVec = Array.fill[Float](numWords)(0)
+    val alpha: Float = 1
+    val beta: Float = 0
+
+    blas.sgemv(
+      "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
+
+    // No need to divide by the norm of the given vector: it is the same for every word.
+    val updatedCosines = new Array[Double](numWords)
+    var ind = 0
+    while (ind < numWords) {
+      updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind)
+      ind += 1
+    }
+    wordList.zip(updatedCosines)
       .toSeq
       .sortBy(- _._2)
       .take(num + 1)
@@ -493,7 +536,9 @@ class Word2VecModel private[mllib] (
    * Returns a map of words to their vector representations.
    */
   def getVectors: Map[String, Array[Float]] = {
-    model
+    wordIndex.map { case (word, ind) =>
+      (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
+    }
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org