Github user mengxr commented on a diff in the pull request:
https://github.com/apache/spark/pull/4247#discussion_r23708566
--- Diff:
mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala ---
@@ -290,111 +290,126 @@ class Word2Vec extends Serializable with Logging {
val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
- val syn0Global =
- Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() -
0.5f) / vectorSize)
- val syn1Global = new Array[Float](vocabSize * vectorSize)
- var alpha = learningRate
- for (k <- 1 to numIterations) {
- val partial = newSentences.mapPartitionsWithIndex { case (idx, iter)
=>
- val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k -
1) << 8))
- val syn0Modify = new Array[Int](vocabSize)
- val syn1Modify = new Array[Int](vocabSize)
- val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
- case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
- var lwc = lastWordCount
- var wc = wordCount
- if (wordCount - lastWordCount > 10000) {
- lwc = wordCount
- // TODO: discount by iteration?
- alpha =
- learningRate * (1 - numPartitions * wordCount.toDouble /
(trainWordsCount + 1))
- if (alpha < learningRate * 0.0001) alpha = learningRate *
0.0001
- logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
- }
- wc += sentence.size
- var pos = 0
- while (pos < sentence.size) {
- val word = sentence(pos)
- val b = random.nextInt(window)
- // Train Skip-gram
- var a = b
- while (a < window * 2 + 1 - b) {
- if (a != window) {
- val c = pos - window + a
- if (c >= 0 && c < sentence.size) {
- val lastWord = sentence(c)
- val l1 = lastWord * vectorSize
- val neu1e = new Array[Float](vectorSize)
- // Hierarchical softmax
- var d = 0
- while (d < bcVocab.value(word).codeLen) {
- val inner = bcVocab.value(word).point(d)
- val l2 = inner * vectorSize
- // Propagate hidden -> output
- var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2,
1)
- if (f > -MAX_EXP && f < MAX_EXP) {
- val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE /
MAX_EXP / 2.0)).toInt
- f = expTable.value(ind)
- val g = ((1 - bcVocab.value(word).code(d) - f) *
alpha).toFloat
- blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
- blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
- syn1Modify(inner) += 1
+ val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
+
+ val hint="Please increase minCount or decrease vectorSize in Word2Vec
to avoid an OOM. " +
+ "You are highly recommended to make vocabSize*vectorSize less
than `Int.MaxValue/8`."
+
+ try {
+ if(vocabSize * vectorSize * 8 > Int.MaxValue){
--- End diff --
Please add a space before `(` and after `)` (i.e., `if (...) {`).
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]