GitHub user nzw0301 commented on a diff in the pull request:
https://github.com/apache/spark/pull/19372#discussion_r142332793
--- Diff: mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala ---
@@ -368,11 +371,12 @@ class Word2Vec extends Serializable with Logging {
               var wc = wordCount
               if (wordCount - lastWordCount > 10000) {
                 lwc = wordCount
-                // TODO: discount by iteration?
-                alpha =
-                  learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
+                alpha = learningRate *
+                  (1 - (numPartitions * wordCount.toDouble + numWordsProcessedInPreviousIterations) /
+                    totalWordsCounts)
                 if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
-                logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
+                logInfo("wordCount = " + (wordCount + numWordsProcessedInPreviousIterations) +
--- End diff --
@srowen Done.
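
For readers following the diff, here is a minimal standalone sketch of the updated decay expression, so its behaviour across iterations can be checked outside Spark. The object `AlphaSchedule` and the function `alpha` are hypothetical names, not part of Word2Vec.scala. The diff does not show how `totalWordsCounts` and `numWordsProcessedInPreviousIterations` are defined; the sample values below assume roughly `numIterations * trainWordsCount + 1` and `(k - 1) * trainWordsCount` for iteration `k`.

```scala
object AlphaSchedule {
  // Hypothetical helper mirroring the "+" lines of the diff.
  // Parameter names follow the diff; how the caller computes
  // totalWordsCounts and numWordsProcessedInPreviousIterations
  // is an assumption and is not shown in the hunk.
  def alpha(
      learningRate: Double,
      numPartitions: Int,
      wordCount: Long,
      numWordsProcessedInPreviousIterations: Long,
      totalWordsCounts: Long): Double = {
    // Linear decay over all words seen so far, including earlier iterations.
    val decayed = learningRate *
      (1 - (numPartitions * wordCount.toDouble + numWordsProcessedInPreviousIterations) /
        totalWordsCounts)
    // Same floor as in the diff: never drop below learningRate * 0.0001.
    math.max(decayed, learningRate * 0.0001)
  }
}

// Assumed sample values: 2 iterations over 1,000,000 training words, 1 partition.
val total = 2L * 1000000L + 1L
val startOfFirstIteration = AlphaSchedule.alpha(0.025, 1, 0L, 0L, total)
val startOfSecondIteration = AlphaSchedule.alpha(0.025, 1, 0L, 1000000L, total)
```

Under these assumptions the second iteration starts at about half the initial rate rather than resetting to `learningRate`, which is the behaviour the removed TODO ("discount by iteration?") asked about.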