Repository: spark
Updated Branches:
  refs/heads/master 9038d94e1 -> f0060b75f


[MLlib] Correctly set vectorSize and alpha

mengxr
Correctly set vectorSize and alpha in Word2Vec training.

Author: Liquan Pei <[email protected]>

Closes #1900 from Ishiihara/Word2Vec-bugfix and squashes the following commits:

85f64f2 [Liquan Pei] correctly set vectorSize and alpha


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f0060b75
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f0060b75
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f0060b75

Branch: refs/heads/master
Commit: f0060b75ff67ab60babf54149a6860edc53cb6e9
Parents: 9038d94
Author: Liquan Pei <[email protected]>
Authored: Tue Aug 12 00:28:00 2014 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Tue Aug 12 00:28:00 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/feature/Word2Vec.scala   | 25 ++++++++++----------
 1 file changed, 12 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f0060b75/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 395037e..ecd49ea 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -119,7 +119,6 @@ class Word2Vec extends Serializable with Logging {
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
   private val MAX_SENTENCE_LENGTH = 1000
-  private val layer1Size = vectorSize
 
   /** context words from [-window, window] */
   private val window = 5
@@ -131,7 +130,6 @@ class Word2Vec extends Serializable with Logging {
   private var vocabSize = 0
   private var vocab: Array[VocabWord] = null
   private var vocabHash = mutable.HashMap.empty[String, Int]
-  private var alpha = startingAlpha
 
   private def learnVocab(words: RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
@@ -287,9 +285,10 @@ class Word2Vec extends Serializable with Logging {
     val newSentences = sentences.repartition(numPartitions).cache()
     val initRandom = new XORShiftRandom(seed)
     var syn0Global =
-      Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
-    var syn1Global = new Array[Float](vocabSize * layer1Size)
+      Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
+    var syn1Global = new Array[Float](vocabSize * vectorSize)
 
+    var alpha = startingAlpha
     for (k <- 1 to numIterations) {
       val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
@@ -317,24 +316,24 @@ class Word2Vec extends Serializable with Logging {
                   val c = pos - window + a
                   if (c >= 0 && c < sentence.size) {
                     val lastWord = sentence(c)
-                    val l1 = lastWord * layer1Size
-                    val neu1e = new Array[Float](layer1Size)
+                    val l1 = lastWord * vectorSize
+                    val neu1e = new Array[Float](vectorSize)
                     // Hierarchical softmax
                     var d = 0
                     while (d < bcVocab.value(word).codeLen) {
-                      val l2 = bcVocab.value(word).point(d) * layer1Size
+                      val l2 = bcVocab.value(word).point(d) * vectorSize
                       // Propagate hidden -> output
-                      var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1)
+                      var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
                       if (f > -MAX_EXP && f < MAX_EXP) {
                        val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                         f = expTable.value(ind)
                        val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
-                        blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
-                        blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+                        blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
+                        blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
                       }
                       d += 1
                     }
-                    blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1)
+                    blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
                   }
                 }
                 a += 1
@@ -365,8 +364,8 @@ class Word2Vec extends Serializable with Logging {
     var i = 0
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
-      val vector = new Array[Float](layer1Size)
-      Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
+      val vector = new Array[Float](vectorSize)
+      Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
       word2VecMap += word -> vector
       i += 1
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to