Github user AnthonyTruchet commented on a diff in the pull request:

    https://github.com/apache/spark/pull/14299#discussion_r71723751
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala ---
    @@ -313,133 +313,139 @@ class Word2Vec extends Serializable with Logging {
         val expTable = sc.broadcast(createExpTable())
         val bcVocab = sc.broadcast(vocab)
         val bcVocabHash = sc.broadcast(vocabHash)
    -    // each partition is a collection of sentences,
    -    // will be translated into arrays of Index integer
     -    val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
    -      // Each sentence will map to 0 or more Array[Int]
    -      sentenceIter.flatMap { sentence =>
    -        // Sentence of words, some of which map to a word index
    -        val wordIndexes = sentence.flatMap(bcVocabHash.value.get)
    -        // break wordIndexes into trunks of maxSentenceLength when has more
    -        wordIndexes.grouped(maxSentenceLength).map(_.toArray)
    +
    +    try {
    +      // each partition is a collection of sentences,
    +      // will be translated into arrays of Index integer
     +      val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
    +        // Each sentence will map to 0 or more Array[Int]
    +        sentenceIter.flatMap { sentence =>
    +          // Sentence of words, some of which map to a word index
    +          val wordIndexes = sentence.flatMap(bcVocabHash.value.get)
     +          // break wordIndexes into trunks of maxSentenceLength when has more
    +          wordIndexes.grouped(maxSentenceLength).map(_.toArray)
    +        }
           }
    -    }
     
    -    val newSentences = sentences.repartition(numPartitions).cache()
    -    val initRandom = new XORShiftRandom(seed)
    +      val newSentences = sentences.repartition(numPartitions).cache()
    +      val initRandom = new XORShiftRandom(seed)
     
    -    if (vocabSize.toLong * vectorSize >= Int.MaxValue) {
     -      throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" +
     -        " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " +
     -        "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue`.")
    -    }
    +      if (vocabSize.toLong * vectorSize >= Int.MaxValue) {
     +        throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" +
     +          " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " +
     +          "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue`.")
    +      }
     
    -    val syn0Global =
     -      Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
    -    val syn1Global = new Array[Float](vocabSize * vectorSize)
    -    var alpha = learningRate
    -
    -    for (k <- 1 to numIterations) {
    -      val bcSyn0Global = sc.broadcast(syn0Global)
    -      val bcSyn1Global = sc.broadcast(syn1Global)
     -      val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
     -        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
    -        val syn0Modify = new Array[Int](vocabSize)
    -        val syn1Modify = new Array[Int](vocabSize)
     -        val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0L, 0L)) {
    -          case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
    -            var lwc = lastWordCount
    -            var wc = wordCount
    -            if (wordCount - lastWordCount > 10000) {
    -              lwc = wordCount
    -              // TODO: discount by iteration?
    -              alpha =
     -                learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
     -              if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
    -              logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
    -            }
    -            wc += sentence.length
    -            var pos = 0
    -            while (pos < sentence.length) {
    -              val word = sentence(pos)
    -              val b = random.nextInt(window)
    -              // Train Skip-gram
    -              var a = b
    -              while (a < window * 2 + 1 - b) {
    -                if (a != window) {
    -                  val c = pos - window + a
    -                  if (c >= 0 && c < sentence.length) {
    -                    val lastWord = sentence(c)
    -                    val l1 = lastWord * vectorSize
    -                    val neu1e = new Array[Float](vectorSize)
    -                    // Hierarchical softmax
    -                    var d = 0
    -                    while (d < bcVocab.value(word).codeLen) {
    -                      val inner = bcVocab.value(word).point(d)
    -                      val l2 = inner * vectorSize
    -                      // Propagate hidden -> output
     -                      var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
    -                      if (f > -MAX_EXP && f < MAX_EXP) {
     -                        val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
    -                        f = expTable.value(ind)
     -                        val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
    -                        blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
    -                        blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
    -                        syn1Modify(inner) += 1
    +      val syn0Global =
     +        Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
    +      val syn1Global = new Array[Float](vocabSize * vectorSize)
    +      var alpha = learningRate
    +
    +      for (k <- 1 to numIterations) {
    +        val bcSyn0Global = sc.broadcast(syn0Global)
    +        val bcSyn1Global = sc.broadcast(syn1Global)
     +        val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
     +          val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
    +          val syn0Modify = new Array[Int](vocabSize)
    +          val syn1Modify = new Array[Int](vocabSize)
     +          val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0L, 0L)) {
    +            case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
    +              var lwc = lastWordCount
    +              var wc = wordCount
    +              if (wordCount - lastWordCount > 10000) {
    +                lwc = wordCount
    +                // TODO: discount by iteration?
    +                alpha =
     +                  learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
     +                if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
    +                logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
    +              }
    +              wc += sentence.length
    +              var pos = 0
    +              while (pos < sentence.length) {
    +                val word = sentence(pos)
    +                val b = random.nextInt(window)
    +                // Train Skip-gram
    +                var a = b
    +                while (a < window * 2 + 1 - b) {
    +                  if (a != window) {
    +                    val c = pos - window + a
    +                    if (c >= 0 && c < sentence.length) {
    +                      val lastWord = sentence(c)
    +                      val l1 = lastWord * vectorSize
    +                      val neu1e = new Array[Float](vectorSize)
    +                      // Hierarchical softmax
    +                      var d = 0
    +                      while (d < bcVocab.value(word).codeLen) {
    +                        val inner = bcVocab.value(word).point(d)
    +                        val l2 = inner * vectorSize
    +                        // Propagate hidden -> output
     +                        var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
    +                        if (f > -MAX_EXP && f < MAX_EXP) {
     +                          val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
    +                          f = expTable.value(ind)
     +                          val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
     +                          blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
     +                          blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
    +                          syn1Modify(inner) += 1
    +                        }
    +                        d += 1
                           }
    -                      d += 1
     +                      blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
    +                      syn0Modify(lastWord) += 1
                         }
    -                    blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
    -                    syn0Modify(lastWord) += 1
                       }
    +                  a += 1
                     }
    -                a += 1
    +                pos += 1
                   }
    -              pos += 1
    +              (syn0, syn1, lwc, wc)
    +          }
    +          val syn0Local = model._1
    +          val syn1Local = model._2
    +          // Only output modified vectors.
    +          Iterator.tabulate(vocabSize) { index =>
    +            if (syn0Modify(index) > 0) {
     +              Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
    +            } else {
    +              None
    +            }
    +          }.flatten ++ Iterator.tabulate(vocabSize) { index =>
    +            if (syn1Modify(index) > 0) {
     +              Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
    +            } else {
    +              None
                 }
    -            (syn0, syn1, lwc, wc)
    +          }.flatten
             }
    -        val syn0Local = model._1
    -        val syn1Local = model._2
    -        // Only output modified vectors.
    -        Iterator.tabulate(vocabSize) { index =>
    -          if (syn0Modify(index) > 0) {
     -            Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
    -          } else {
    -            None
    -          }
    -        }.flatten ++ Iterator.tabulate(vocabSize) { index =>
    -          if (syn1Modify(index) > 0) {
     -            Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
    -          } else {
    -            None
    -          }
    -        }.flatten
    -      }
    -      val synAgg = partial.reduceByKey { case (v1, v2) =>
    +        val synAgg = partial.reduceByKey { case (v1, v2) =>
               blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
               v1
    -      }.collect()
    -      var i = 0
    -      while (i < synAgg.length) {
    -        val index = synAgg(i)._1
    -        if (index < vocabSize) {
     -          Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
    -        } else {
     -          Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
    +        }.collect()
    +        var i = 0
    +        while (i < synAgg.length) {
    +          val index = synAgg(i)._1
    +          if (index < vocabSize) {
     +            Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
    +          } else {
     +            Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
    +          }
    +          i += 1
             }
    -        i += 1
    +        bcSyn0Global.unpersist(false)
    +        bcSyn1Global.unpersist(false)
           }
    -      bcSyn0Global.unpersist(false)
    -      bcSyn1Global.unpersist(false)
    -    }
    -    newSentences.unpersist()
    -    expTable.destroy()
    -    bcVocab.destroy()
    -    bcVocabHash.destroy()
    +      newSentences.unpersist()
     
    -    val wordArray = vocab.map(_.word)
    -    new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
    +      val wordArray = vocab.map(_.word)
    +      new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
    +    }
    +    finally
    --- End diff --
    
    Btw, are you aware of http://jsuereth.com/scala-arm/? It might prove damn useful to ease resource management without cluttering the syntax (much). It is akin to `using` clauses in C# or context managers in Python.
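
    For illustration, here is a minimal sketch of what that could look like for the broadcast cleanup in this method. Caveat: the `Resource` instance for Spark's `Broadcast` is hypothetical (out of the box scala-arm only knows closeable-like types), so read it as a shape, not a drop-in:

        import org.apache.spark.broadcast.Broadcast
        import resource._

        // Hypothetical adapter telling scala-arm how to release a Broadcast;
        // `close` is the release hook of scala-arm's Resource type class.
        implicit def broadcastResource[T]: Resource[Broadcast[T]] =
          new Resource[Broadcast[T]] {
            override def close(b: Broadcast[T]): Unit = b.destroy()
          }

        // With this in scope, the explicit try/finally collapses into a
        // for-comprehension; every broadcast is destroyed when the block
        // exits, whether normally or via an exception.
        for {
          expTable    <- managed(sc.broadcast(createExpTable()))
          bcVocab     <- managed(sc.broadcast(vocab))
          bcVocabHash <- managed(sc.broadcast(vocabHash))
        } {
          // ... the training loop from the diff goes here ...
        }

    Acquisition and release then stay adjacent, much like a `with` block in Python.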

