You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2017/10/07 07:30:51 UTC
spark git commit: [SPARK-22156][MLLIB] Fix update equation of
learning rate in Word2Vec.scala
Repository: spark
Updated Branches:
refs/heads/master 2030f1951 -> 5eacc3bfa
[SPARK-22156][MLLIB] Fix update equation of learning rate in Word2Vec.scala
## What changes were proposed in this pull request?
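The current learning-rate update equation is incorrect when `numIterations` > 1. Alpha is decayed using only the words processed in the current iteration, and the iteration index plays no part. As a result, alpha restarts near `learningRate` at the beginning of every iteration instead of decaying linearly over the whole training run.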
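This fix mirrors the learning-rate schedule in the [original C code](https://github.com/tmikolov/word2vec/blob/master/word2vec.c#L393), where alpha decays with respect to the total number of words processed across all iterations.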
cc: mengxr
## How was this patch tested?
Manual tests.
I ran a modified version of [this example code](https://spark.apache.org/docs/2.1.1/mllib-feature-extraction.html#example).
### `numIterations=1`
#### Code
```scala
import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}
// Each line of the sample corpus becomes one sentence of space-separated tokens.
val sentences = sc.textFile("data/mllib/sample_lda_data.txt").map(_.split(" ").toSeq)
// Fit with Word2Vec's default parameters; this is the numIterations = 1 run.
val trained = new Word2Vec().fit(sentences)
// Print the 5 nearest neighbours of token "1" with their cosine similarity.
trained.findSynonyms("1", 5).foreach { case (synonym, cosineSimilarity) =>
  println(s"$synonym $cosineSimilarity")
}
```
#### Result
```
2 0.175856813788414
0 0.10971353203058243
4 0.09818313270807266
3 0.012947646901011467
9 -0.09881238639354706
```
### `numIterations=5`
#### Code
```scala
import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}
// Each line of the sample corpus becomes one sentence of space-separated tokens.
val sentences = sc.textFile("data/mllib/sample_lda_data.txt").map(_.split(" ").toSeq)
// Train for 5 iterations: the numIterations > 1 case the learning-rate fix targets.
val trainer = new Word2Vec()
trainer.setNumIterations(5)
val trained = trainer.fit(sentences)
// Print the 5 nearest neighbours of token "1" with their cosine similarity.
trained.findSynonyms("1", 5).foreach { case (synonym, cosineSimilarity) =>
  println(s"$synonym $cosineSimilarity")
}
```
#### Result
```
0 0.9898583889007568
2 0.9808019399642944
4 0.9794934391975403
3 0.9506527781486511
9 -0.9065656661987305
```
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Closes #19372 from nzw0301/master.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5eacc3bf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5eacc3bf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5eacc3bf
Branch: refs/heads/master
Commit: 5eacc3bfa9b9c1435ce04222ac7f943b5f930cf4
Parents: 2030f19
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Authored: Sat Oct 7 08:30:48 2017 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Sat Oct 7 08:30:48 2017 +0100
----------------------------------------------------------------------
.../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 12 ++++++++----
1 file changed, 8 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5eacc3bf/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 6f96813..b8c306d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -353,11 +353,14 @@ class Word2Vec extends Serializable with Logging {
val syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)
+ val totalWordsCounts = numIterations * trainWordsCount + 1
var alpha = learningRate
for (k <- 1 to numIterations) {
val bcSyn0Global = sc.broadcast(syn0Global)
val bcSyn1Global = sc.broadcast(syn1Global)
+ val numWordsProcessedInPreviousIterations = (k - 1) * trainWordsCount
+
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
val syn0Modify = new Array[Int](vocabSize)
@@ -368,11 +371,12 @@ class Word2Vec extends Serializable with Logging {
var wc = wordCount
if (wordCount - lastWordCount > 10000) {
lwc = wordCount
- // TODO: discount by iteration?
- alpha =
- learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
+ alpha = learningRate *
+ (1 - (numPartitions * wordCount.toDouble + numWordsProcessedInPreviousIterations) /
+ totalWordsCounts)
if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
- logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
+ logInfo(s"wordCount = ${wordCount + numWordsProcessedInPreviousIterations}, " +
+ s"alpha = $alpha")
}
wc += sentence.length
var pos = 0
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org