Github user shubhamchopra commented on a diff in the pull request:

    https://github.com/apache/spark/pull/17673#discussion_r143572149
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/Word2VecCBOWSolver.scala ---
    @@ -0,0 +1,344 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import com.github.fommil.netlib.BLAS.{getInstance => blas}
    +
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.mllib.feature
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.util.random.XORShiftRandom
    +
    +object Word2VecCBOWSolver extends Logging {
    +  // the learning rate is updated after every batch of batchSize words
    +  private val batchSize = 10000
    +
    +  // exponent to which the unigram distribution is raised
    +  private val power = 0.75
    +
    +  private val MAX_EXP = 6
    +
    +  case class Vocabulary(
    +    totalWordCount: Long,
    +    vocabMap: Map[String, Int],
    +    unigramTable: Array[Int],
    +    samplingTable: Array[Float])
    +
    +  /**
    +   * This method trains a Word2Vec model using the Continuous Bag Of Words
    +   * (CBOW) architecture with negative sampling, using BLAS to vectorize
    +   * operations where applicable.
    +   * The algorithm is parallelized in the same way as the skip-gram based
    +   * estimation. We divide the input data into N equally sized random
    +   * partitions, generate initial weights, and broadcast them to the N
    +   * partitions, so that all partitions start with the same initial weights.
    +   * We then run N independent estimations, each fitting a model on its own
    +   * partition. The weights learned by the N models are averaged and the
    +   * averaged weights are rebroadcast.
    +   * This process is repeated `maxIter` times.
    +   *
    +   * @param input An RDD of strings. Each string is treated as a sentence.
    +   * @return Estimated word2vec model
    +   */
    +  def fitCBOW[S <: Iterable[String]](
    +      word2Vec: Word2Vec,
    +      input: RDD[S]): feature.Word2VecModel = {
    +
    +    val negativeSamples = word2Vec.getNegativeSamples
    +    val sample = word2Vec.getSample
    +
    +    val Vocabulary(totalWordCount, vocabMap, uniTable, sampleTable) =
    +      generateVocab(input, word2Vec.getMinCount, sample, word2Vec.getUnigramTableSize)
    +    val vocabSize = vocabMap.size
    +
    +    assert(negativeSamples < vocabSize, s"The number of negative samples ($negativeSamples)" +
    +      s" must be smaller than the vocab size ($vocabSize)")
    +
    +    val seed = word2Vec.getSeed
    +    val initRandom = new XORShiftRandom(seed)
    +
    +    val vectorSize = word2Vec.getVectorSize
    +    val syn0Global = Array.fill(vocabSize * vectorSize)(initRandom.nextFloat - 0.5f)
    +    val syn1Global = Array.fill(vocabSize * vectorSize)(0.0f)
    +
    +    val sc = input.context
    +
    +    val vocabMapBroadcast = sc.broadcast(vocabMap)
    +    val unigramTableBroadcast = sc.broadcast(uniTable)
    +    val sampleTableBroadcast = sc.broadcast(sampleTable)
    +
    +    val windowSize = word2Vec.getWindowSize
    +    val maxSentenceLength = word2Vec.getMaxSentenceLength
    +    val numPartitions = word2Vec.getNumPartitions
    +
    +    val digitSentences = input.flatMap { sentence =>
    +      val wordIndexes = sentence.flatMap(vocabMapBroadcast.value.get)
    +      wordIndexes.grouped(maxSentenceLength).map(_.toArray)
    +    }.repartition(numPartitions).cache()
    +
    +    val learningRate = word2Vec.getStepSize
    +
    +    val wordsPerPartition = totalWordCount / numPartitions
    +
    +    logInfo(s"VocabSize: ${vocabMap.size}, TotalWordCount: $totalWordCount")
    +
    +    val maxIter = word2Vec.getMaxIter
    +    for (iteration <- 1 to maxIter) {
    +      logInfo(s"Starting iteration: $iteration")
    +      val iterationStartTime = System.nanoTime()
    +
    +      val syn0bc = sc.broadcast(syn0Global)
    +      val syn1bc = sc.broadcast(syn1Global)
    +
    +      val partialFits = digitSentences.mapPartitionsWithIndex { case (i_, iter) =>
    +        logInfo(s"Iteration: $iteration, Partition: $i_")
    +        val random = new XORShiftRandom(seed ^ ((i_ + 1) << 16) ^ ((-iteration - 1) << 8))
    +        val contextWordPairs = iter.flatMap { s =>
    +          val doSample = sample > Double.MinPositiveValue
    +          generateContextWordPairs(s, windowSize, doSample, sampleTableBroadcast.value, random)
    +        }
    +
    +        val groupedBatches = contextWordPairs.grouped(batchSize)
    +
    +        val negLabels = 1.0f +: Array.fill(negativeSamples)(0.0f)
    +        val syn0 = syn0bc.value
    +        val syn1 = syn1bc.value
    +        val unigramTable = unigramTableBroadcast.value
    +
    +        // initialize intermediate arrays
    +        val contextVec = new Array[Float](vectorSize)
    +        val l2Vectors = new Array[Float](vectorSize * (negativeSamples + 1))
    --- End diff --
    
    Done
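
    As a footnote for readers of this thread: the parameter-averaging scheme the
    scaladoc above describes boils down to a loop like the sketch below. This is
    an illustration of the scheme only, not the PR's code; trainPartition is a
    hypothetical stand-in for the per-partition estimation.

        import org.apache.spark.SparkContext
        import org.apache.spark.rdd.RDD

        // Iterative parameter averaging: each iteration broadcasts the current
        // weights, fits one independent model per partition, and averages the
        // per-partition results element-wise.
        def averageFit(
            sc: SparkContext,
            sentences: RDD[Array[Int]],
            initialWeights: Array[Float],
            numPartitions: Int,
            maxIter: Int)(
            trainPartition: (Iterator[Array[Int]], Array[Float]) => Array[Float]): Array[Float] = {
          var weights = initialWeights
          for (iteration <- 1 to maxIter) {
            val bc = sc.broadcast(weights)
            // N independent estimations, all starting from the same broadcast
            // weights (cloned so tasks never mutate the shared broadcast value)
            val partials = sentences.mapPartitions { iter =>
              Iterator.single(trainPartition(iter, bc.value.clone()))
            }
            // element-wise sum of the N weight arrays, then divide by N
            val summed = partials.reduce { (a, b) =>
              for (i <- a.indices) a(i) += b(i)
              a
            }
            weights = summed.map(_ / numPartitions)
            bc.destroy()
          }
          weights
        }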


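    Similarly, since power = 0.75 and unigramTable appear in the diff with
    little surrounding context: word2vec negative sampling conventionally draws
    negatives from a pre-built table in which word i occupies a share of slots
    proportional to pow(count(i), 0.75). A minimal sketch of that standard
    construction follows; it is an assumption about intent, not the PR's actual
    generateVocab code.

        // Build a word2vec-style negative sampling table: frequent words get
        // more slots, but sub-linearly (counts raised to the 0.75 power).
        def buildUnigramTable(counts: Array[Long], tableSize: Int): Array[Int] = {
          val power = 0.75
          val norm = counts.map(c => math.pow(c.toDouble, power)).sum
          val table = new Array[Int](tableSize)
          var word = 0
          var cumulative = math.pow(counts(0).toDouble, power) / norm
          for (i <- 0 until tableSize) {
            table(i) = word
            // move on to the next word once its cumulative share is used up
            if (i.toDouble / tableSize > cumulative && word < counts.length - 1) {
              word += 1
              cumulative += math.pow(counts(word).toDouble, power) / norm
            }
          }
          table
        }

    Sampling a negative is then just a uniform draw of an index into the table.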