You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2017/10/07 07:30:51 UTC

spark git commit: [SPARK-22156][MLLIB] Fix update equation of learning rate in Word2Vec.scala

Repository: spark
Updated Branches:
  refs/heads/master 2030f1951 -> 5eacc3bfa


[SPARK-22156][MLLIB] Fix update equation of learning rate in Word2Vec.scala

## What changes were proposed in this pull request?

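The current update equation for the learning rate is incorrect when `numIterations` > `1`: it does not account for the words already processed in previous iterations.
This PR makes the update follow the [original C code](https://github.com/tmikolov/word2vec/blob/master/word2vec.c#L393).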

cc: mengxr

## How was this patch tested?

Manual tests.

I modified [this example code](https://spark.apache.org/docs/2.1.1/mllib-feature-extraction.html#example) and ran it with the two settings below.

### `numIterations=1`

#### Code

```scala
import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}

val input = sc.textFile("data/mllib/sample_lda_data.txt").map(line => line.split(" ").toSeq)

val word2vec = new Word2Vec()

val model = word2vec.fit(input)

val synonyms = model.findSynonyms("1", 5)

for((synonym, cosineSimilarity) <- synonyms) {
  println(s"$synonym $cosineSimilarity")
}
```

#### Result

```
2 0.175856813788414
0 0.10971353203058243
4 0.09818313270807266
3 0.012947646901011467
9 -0.09881238639354706
```

### `numIterations=5`

#### Code

```scala
import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}

val input = sc.textFile("data/mllib/sample_lda_data.txt").map(line => line.split(" ").toSeq)

val word2vec = new Word2Vec()
word2vec.setNumIterations(5)

val model = word2vec.fit(input)

val synonyms = model.findSynonyms("1", 5)

for((synonym, cosineSimilarity) <- synonyms) {
  println(s"$synonym $cosineSimilarity")
}
```

#### Result

```
0 0.9898583889007568
2 0.9808019399642944
4 0.9794934391975403
3 0.9506527781486511
9 -0.9065656661987305
```

Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>

Closes #19372 from nzw0301/master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5eacc3bf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5eacc3bf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5eacc3bf

Branch: refs/heads/master
Commit: 5eacc3bfa9b9c1435ce04222ac7f943b5f930cf4
Parents: 2030f19
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Authored: Sat Oct 7 08:30:48 2017 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Sat Oct 7 08:30:48 2017 +0100

----------------------------------------------------------------------
 .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5eacc3bf/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 6f96813..b8c306d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -353,11 +353,14 @@ class Word2Vec extends Serializable with Logging {
     val syn0Global =
       Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
     val syn1Global = new Array[Float](vocabSize * vectorSize)
+    val totalWordsCounts = numIterations * trainWordsCount + 1
     var alpha = learningRate
 
     for (k <- 1 to numIterations) {
       val bcSyn0Global = sc.broadcast(syn0Global)
       val bcSyn1Global = sc.broadcast(syn1Global)
+      val numWordsProcessedInPreviousIterations = (k - 1) * trainWordsCount
+
       val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
         val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
         val syn0Modify = new Array[Int](vocabSize)
@@ -368,11 +371,12 @@ class Word2Vec extends Serializable with Logging {
             var wc = wordCount
             if (wordCount - lastWordCount > 10000) {
               lwc = wordCount
-              // TODO: discount by iteration?
-              alpha =
-                learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
+              alpha = learningRate *
+                (1 - (numPartitions * wordCount.toDouble + numWordsProcessedInPreviousIterations) /
+                  totalWordsCounts)
               if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
-              logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
+              logInfo(s"wordCount = ${wordCount + numWordsProcessedInPreviousIterations}, " +
+                s"alpha = $alpha")
             }
             wc += sentence.length
             var pos = 0


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org