Github user shubhamchopra commented on a diff in the pull request:
https://github.com/apache/spark/pull/17673#discussion_r143570164
--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/Word2VecCBOWSolver.scala ---
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.feature
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.random.XORShiftRandom
+
+object Word2VecCBOWSolver extends Logging {
+ // learning rate is updated for every batch of size batchSize
+ private val batchSize = 10000
+
+ // power to which the unigram distribution is raised
+ private val power = 0.75
+
+ private val MAX_EXP = 6
+
+ case class Vocabulary(
+ totalWordCount: Long,
+ vocabMap: Map[String, Int],
+ unigramTable: Array[Int],
+ samplingTable: Array[Float])
+
+ /**
+ * This method implements a Continuous Bag Of Words based Word2Vec estimation using
+ * negative sampling optimization, using BLAS for vectorizing operations where applicable.
+ * The algorithm is parallelized in the same way as the skip-gram based estimation.
+ * We divide the input data into N equally sized random partitions.
+ * We then generate initial weights and broadcast them to the N partitions. This way
+ * all the partitions start with the same initial weights. We then run N independent
+ * estimations that each estimate a model on a partition. The weights learned
+ * from each of the N models are averaged and the averaged weights are rebroadcast.
+ * This process is repeated `maxIter` times.
+ *
+ * @param input An RDD of word sequences; each sequence is treated as a sentence.
+ * @return Estimated word2vec model
+ */
+ def fitCBOW[S <: Iterable[String]](
+ word2Vec: Word2Vec,
+ input: RDD[S]): feature.Word2VecModel = {
+
+ val negativeSamples = word2Vec.getNegativeSamples
+ val sample = word2Vec.getSample
+
+ val Vocabulary(totalWordCount, vocabMap, uniTable, sampleTable) =
+ generateVocab(input, word2Vec.getMinCount, sample, word2Vec.getUnigramTableSize)
+ val vocabSize = vocabMap.size
+
+ assert(negativeSamples < vocabSize, s"Vocab size ($vocabSize) cannot be smaller" +
+ s" than negative samples($negativeSamples)")
+
+ val seed = word2Vec.getSeed
+ val initRandom = new XORShiftRandom(seed)
+
+ val vectorSize = word2Vec.getVectorSize
+ val syn0Global = Array.fill(vocabSize * vectorSize)(initRandom.nextFloat - 0.5f)
+ val syn1Global = Array.fill(vocabSize * vectorSize)(0.0f)
+
+ val sc = input.context
+
+ val vocabMapBroadcast = sc.broadcast(vocabMap)
+ val unigramTableBroadcast = sc.broadcast(uniTable)
+ val sampleTableBroadcast = sc.broadcast(sampleTable)
+
+ val windowSize = word2Vec.getWindowSize
+ val maxSentenceLength = word2Vec.getMaxSentenceLength
+ val numPartitions = word2Vec.getNumPartitions
+
+ val digitSentences = input.flatMap { sentence =>
+ val wordIndexes = sentence.flatMap(vocabMapBroadcast.value.get)
+ wordIndexes.grouped(maxSentenceLength).map(_.toArray)
+ }.repartition(numPartitions).cache()
+
+ val learningRate = word2Vec.getStepSize
+
+ val wordsPerPartition = totalWordCount / numPartitions
+
+ logInfo(s"VocabSize: ${vocabMap.size}, TotalWordCount:
$totalWordCount")
+
+ val maxIter = word2Vec.getMaxIter
+ for {iteration <- 1 to maxIter} {
+ logInfo(s"Starting iteration: $iteration")
+ val iterationStartTime = System.nanoTime()
+
+ val syn0bc = sc.broadcast(syn0Global)
+ val syn1bc = sc.broadcast(syn1Global)
+
+ val partialFits = digitSentences.mapPartitionsWithIndex { case (i_, iter) =>
+ logInfo(s"Iteration: $iteration, Partition: $i_")
+ val random = new XORShiftRandom(seed ^ ((i_ + 1) << 16) ^ ((-iteration - 1) << 8))
+ val contextWordPairs = iter.flatMap { s =>
+ val doSample = sample > Double.MinPositiveValue
+ generateContextWordPairs(s, windowSize, doSample, sampleTableBroadcast.value, random)
+ }
+
+ val groupedBatches = contextWordPairs.grouped(batchSize)
+
+ val negLabels = 1.0f +: Array.fill(negativeSamples)(0.0f)
--- End diff --
I instantiate them once and reuse them for the entire iteration. This gave
significant speed-ups.
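
To illustrate the point, here is a minimal, self-contained sketch of that reuse pattern. The names (`LabelReuseSketch`, `negLabels`, `processBatch`) and the batch contents are just illustrative stand-ins, not the actual code in this PR; the idea is only that the label array for negative sampling is constant, so it is built once and shared across all batches rather than reallocated per batch:

```scala
object LabelReuseSketch {
  private val negativeSamples = 5

  // Built once: 1.0f for the positive (true) word followed by 0.0f for each negative
  // sample. The contents never change, so the same array is shared across all batches.
  private val negLabels: Array[Float] = 1.0f +: Array.fill(negativeSamples)(0.0f)

  // Hypothetical per-batch work that only reads the shared label array.
  private def processBatch(batch: Seq[(Int, Int)]): Float =
    batch.map(_ => negLabels.sum).sum

  def main(args: Array[String]): Unit = {
    val batches = Iterator.fill(1000)(Seq((1, 2), (3, 4)))
    // Every batch reuses the same negLabels instance; allocating it inside the loop
    // would instead create one short-lived array per batch for the GC to reclaim.
    println(batches.map(processBatch).sum)
  }
}
```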
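And for the overall scheme the Scaladoc earlier in the diff describes (broadcast the current weights, fit a model per partition, average the results, rebroadcast, repeat), here is a rough toy sketch of that loop. It assumes a local SparkContext, and the per-partition update below is only a placeholder that nudges the weights, not the CBOW gradient step:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object ModelAveragingSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("avg-sketch"))

    val numPartitions = 2
    val data = sc.parallelize(1 to 100, numPartitions)

    var weights = Array.fill(4)(0.0f)
    val maxIter = 3

    for (iteration <- 1 to maxIter) {
      // Broadcast the current weights so every partition starts from the same point.
      val weightsBc = sc.broadcast(weights)
      val partials = data.mapPartitions { iter =>
        val local = weightsBc.value.clone()
        // Placeholder for the per-partition estimation: nudge the local copy of the weights.
        iter.foreach(x => local.indices.foreach(i => local(i) += x * 1e-4f))
        Iterator.single(local)
      }.collect()
      // Average the per-partition weights; the result seeds the next iteration's broadcast.
      weights = partials.transpose.map(col => col.sum / numPartitions)
      weightsBc.destroy()
    }

    println(s"Final weights after $maxIter iterations: ${weights.mkString(", ")}")
    sc.stop()
  }
}
```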