Github user sethah commented on a diff in the pull request:
https://github.com/apache/spark/pull/17673#discussion_r142992652
--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/Word2VecCBOWSolver.scala ---
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.feature
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.random.XORShiftRandom
+
+object Word2VecCBOWSolver extends Logging {
+ // learning rate is updated for every batch of size batchSize
+ private val batchSize = 10000
+
+ // power to raise the unigram distribution with
+ private val power = 0.75
+
+ private val MAX_EXP = 6
+
+ case class Vocabulary(
+ totalWordCount: Long,
+ vocabMap: Map[String, Int],
+ unigramTable: Array[Int],
+ samplingTable: Array[Float])
+
+ /**
+ * This method trains a Word2Vec model using the Continuous Bag Of Words (CBOW) architecture
+ * with negative sampling, using BLAS to vectorize operations where applicable.
+ * The algorithm is parallelized in the same way as the skip-gram based estimation:
+ * we divide the input data into N equally sized random partitions, generate the initial
+ * weights, and broadcast them to the N partitions so that all partitions start from the
+ * same initial weights. We then run N independent estimations that each fit a model on one
+ * partition. The weights learned by the N models are averaged and the averaged weights are
+ * rebroadcast. This process is repeated `maxIter` times.
+ *
+ * @param input An RDD of sentences, each represented as an iterable of words.
+ * @return Estimated word2vec model
+ */
+ def fitCBOW[S <: Iterable[String]](
+ word2Vec: Word2Vec,
+ input: RDD[S]): feature.Word2VecModel = {
+
+ val negativeSamples = word2Vec.getNegativeSamples
+ val sample = word2Vec.getSample
+
+ val Vocabulary(totalWordCount, vocabMap, uniTable, sampleTable) =
+ generateVocab(input, word2Vec.getMinCount, sample, word2Vec.getUnigramTableSize)
+ val vocabSize = vocabMap.size
+
+ assert(negativeSamples < vocabSize, s"The number of negative samples ($negativeSamples)" +
+ s" must be smaller than the vocabulary size ($vocabSize)")
+
+ val seed = word2Vec.getSeed
+ val initRandom = new XORShiftRandom(seed)
+
+ val vectorSize = word2Vec.getVectorSize
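+ // input (syn0) word vectors start uniformly in [-0.5, 0.5); output (syn1) vectors start at zero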
+ val syn0Global = Array.fill(vocabSize * vectorSize)(initRandom.nextFloat - 0.5f)
+ val syn1Global = Array.fill(vocabSize * vectorSize)(0.0f)
+
+ val sc = input.context
+
+ val vocabMapBroadcast = sc.broadcast(vocabMap)
+ val unigramTableBroadcast = sc.broadcast(uniTable)
+ val sampleTableBroadcast = sc.broadcast(sampleTable)
+
+ val windowSize = word2Vec.getWindowSize
+ val maxSentenceLength = word2Vec.getMaxSentenceLength
+ val numPartitions = word2Vec.getNumPartitions
+
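+ // map each word to its vocabulary index (dropping out-of-vocabulary words) and split
+ // sentences into chunks of at most maxSentenceLength words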
+ val digitSentences = input.flatMap { sentence =>
+ val wordIndexes = sentence.flatMap(vocabMapBroadcast.value.get)
+ wordIndexes.grouped(maxSentenceLength).map(_.toArray)
+ }.repartition(numPartitions).cache()
+
+ val learningRate = word2Vec.getStepSize
+
+ val wordsPerPartition = totalWordCount / numPartitions
+
+ logInfo(s"VocabSize: ${vocabMap.size}, TotalWordCount:
$totalWordCount")
+
+ val maxIter = word2Vec.getMaxIter
+ for {iteration <- 1 to maxIter} {
+ logInfo(s"Starting iteration: $iteration")
+ val iterationStartTime = System.nanoTime()
+
+ val syn0bc = sc.broadcast(syn0Global)
+ val syn1bc = sc.broadcast(syn1Global)
+
+ val partialFits = digitSentences.mapPartitionsWithIndex { case (i_, iter) =>
+ logInfo(s"Iteration: $iteration, Partition: $i_")
+ val random = new XORShiftRandom(seed ^ ((i_ + 1) << 16) ^ ((-iteration - 1) << 8))
+ val contextWordPairs = iter.flatMap { s =>
+ val doSample = sample > Double.MinPositiveValue
+ generateContextWordPairs(s, windowSize, doSample, sampleTableBroadcast.value, random)
+ }
+
+ val groupedBatches = contextWordPairs.grouped(batchSize)
+
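+ // labels for the logistic objective: 1 for the target word, then 0 for each negative sample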
+ val negLabels = 1.0f +: Array.fill(negativeSamples)(0.0f)
+ val syn0 = syn0bc.value
+ val syn1 = syn1bc.value
+ val unigramTable = unigramTableBroadcast.value
+
+ // initialize intermediate arrays
+ val contextVec = new Array[Float](vectorSize)
+ val l2Vectors = new Array[Float](vectorSize * (negativeSamples + 1))
+ val gb = new Array[Float](negativeSamples + 1)
+ val neu1e = new Array[Float](vectorSize)
+ val wordIndices = new Array[Int](negativeSamples + 1)
+
+ val time = System.nanoTime
+ var batchTime = System.nanoTime
+ var idx = -1L
+ for (batch <- groupedBatches) {
+ idx = idx + 1
+
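+ // the learning rate decays linearly with the fraction of the total work
+ // (wordsPerPartition * maxIter) already processed, floored at 0.0001 * learningRate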
+ val wordRatio = idx.toFloat * batchSize / (maxIter * (wordsPerPartition.toFloat + 1)) +
+ ((iteration - 1).toFloat / maxIter)
+ val alpha = math.max(learningRate * 0.0001, learningRate * (1 - wordRatio)).toFloat
+
+ if (idx % 10 == 0 && idx > 0) {
+ logInfo(s"Partition: $i_, wordRatio = $wordRatio, alpha = $alpha")
+ val wordCount = batchSize * idx
+ val timeTaken = (System.nanoTime - time) / 1e6
+ val batchWordCount = 10 * batchSize
+ val currentBatchTime = (System.nanoTime - batchTime) / 1e6
+ batchTime = System.nanoTime
+ logDebug(s"Partition: $i_, Batch time: $currentBatchTime ms,
batch speed: " +
+ s"${batchWordCount / currentBatchTime * 1000} words/s")
+ logDebug(s"Partition: $i_, Cumulative time: $timeTaken ms,
cumulative speed: " +
+ s"${wordCount / timeTaken * 1000} words/s")
+ }
+
+ val errors = for ((contextIds, word) <- batch) yield {
+ // initialize vectors to 0
+ zeroVector(contextVec)
+ zeroVector(l2Vectors)
+ zeroVector(gb)
+ zeroVector(neu1e)
+
+ val scale = 1.0f / contextIds.length
+
+ // feed forward
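+ // contextVec accumulates the average of the context word vectors (scale = 1 / contextIds.length)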
+ contextIds.foreach { c =>
+ blas.saxpy(vectorSize, scale, syn0, c * vectorSize, 1, contextVec, 0, 1)
+ }
+
+ generateNegativeSamples(random, word, unigramTable, negativeSamples, wordIndices)
+
+ Iterator.range(0, wordIndices.length).foreach { i =>
+ Array.copy(syn1, vectorSize * wordIndices(i), l2Vectors, vectorSize * i, vectorSize)
+ }
+
+ // propagating hidden to output in batch
+ val rows = negativeSamples + 1
+ val cols = vectorSize
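+ // gb(i) = dot product of candidate word vector i (target word first, then the negative
+ // samples) with the averaged context vector, computed in one matrix-vector multiply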
+ blas.sgemv("T", cols, rows, 1.0f, l2Vectors, 0, cols, contextVec, 0, 1, 0.0f, gb, 0, 1)
+
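+ // skip the update when the pre-activation saturates the sigmoid (|gb(i)| >= MAX_EXP)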
+ Iterator.range(0, negativeSamples + 1).foreach { i =>
+ if (gb(i) > -MAX_EXP && gb(i) < MAX_EXP) {
+ val v = 1.0f / (1 + math.exp(-gb(i)).toFloat)
+ // computing error gradient
+ val err = (negLabels(i) - v) * alpha
+ // update hidden -> output layer, syn1
+ blas.saxpy(vectorSize, err, contextVec, 0, 1, syn1, wordIndices(i) * vectorSize, 1)
+ // update for word vectors
+ blas.saxpy(vectorSize, err, l2Vectors, i * vectorSize, 1, neu1e, 0, 1)
+ gb.update(i, err)
+ } else {
+ gb.update(i, 0.0f)
+ }
+ }
+
+ // update input -> hidden layer, syn0
+ contextIds.foreach { i =>
+ blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, i * vectorSize, 1)
+ }
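+ // error for this example: sum of |label - sigmoid| over the target word and negative samples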
+ gb.map(math.abs).sum / alpha
+ }
+ logInfo(s"Partition: $i_, Average Batch Error = ${errors.sum /
batchSize}")
+ }
+ Iterator.tabulate(vocabSize) { index =>
+ (index, syn0.slice(index * vectorSize, (index + 1) * vectorSize))
+ } ++ Iterator.tabulate(vocabSize) { index =>
+ (vocabSize + index, syn1.slice(index * vectorSize, (index + 1) * vectorSize))
+ }
+ }
+
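+ // combine the partial models: sum the per-row vectors across partitions, then scale by
+ // 1 / numPartitions to average; row indices [0, vocabSize) are syn0, the rest are syn1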
+ val aggedMatrices = partialFits.reduceByKey { case (v1, v2) =>
+ blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
+ v1
+ }.collect()
+
+ val norm = 1.0f / numPartitions
+ aggedMatrices.foreach { case (index, v) =>
+ blas.sscal(v.length, norm, v, 0, 1)
+ if (index < vocabSize) {
+ Array.copy(v, 0, syn0Global, index * vectorSize, vectorSize)
+ } else {
+ Array.copy(v, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
+ }
+ }
+
+ syn0bc.destroy(false)
+ syn1bc.destroy(false)
+ val timePerIteration = (System.nanoTime() - iterationStartTime) / 1e6
+ logInfo(s"Total time taken per iteration: ${timePerIteration} ms")
+ }
+ digitSentences.unpersist()
+ vocabMapBroadcast.destroy()
+ unigramTableBroadcast.destroy()
+ sampleTableBroadcast.destroy()
+
+ new feature.Word2VecModel(vocabMap, syn0Global)
+ }
+
+ /**
+ * Similar to InitUnigramTable in the original code.
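+ * Each word id occupies a share of table entries proportional to its normalized weight,
+ * so sampling a uniform random index from the table draws word ids from that distribution.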
+ */
+ private def generateUnigramTable(
+ normalizedWeights: Array[Double],
+ tableSize: Int): Array[Int] = {
+ val table = new Array[Int](tableSize)
+ var index = 0
+ var wordId = 0
+ while (index < table.length) {
+ table.update(index, wordId)
+ if (index.toFloat / table.length >= normalizedWeights(wordId)) {
+ wordId = math.min(normalizedWeights.length - 1, wordId + 1)
+ }
+ index += 1
+ }
+ table
+ }
+
+ private def generateVocab[S <: Iterable[String]](
+ input: RDD[S],
+ minCount: Int,
+ sample: Double,
+ unigramTableSize: Int): Vocabulary = {
+ val sc = input.context
--- End diff ---
not used