This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 6dff114  [SPARK-24666][ML] Fix infinity vectors produced by Word2Vec 
when numIterations is large
6dff114 is described below

commit 6dff114ddb3de7877625b79bea818ba724ccd22d
Author: Liang-Chi Hsieh <liang...@uber.com>
AuthorDate: Thu Dec 5 16:32:33 2019 -0800

    [SPARK-24666][ML] Fix infinity vectors produced by Word2Vec when 
numIterations is large
    
    ### What changes were proposed in this pull request?
    
    This patch adds normalization to word vectors when fitting a dataset in 
Word2Vec.
    
    ### Why are the changes needed?
    
    Running Word2Vec on some datasets, when numIterations is large, can produce 
infinity word vectors.
    
    ### Does this PR introduce any user-facing change?
    
    Yes. After this patch, Word2Vec won't produce infinity word vectors.
    
    ### How was this patch tested?
    
    Manually. This issue is not reproducible on every dataset. The dataset 
known to reproduce it is too large (925M) to upload.
    
    ```scala
    case class Sentences(name: String, words: Array[String])
    val dataset = spark.read
      .option("header", "true").option("sep", "\t")
      .option("quote", "").option("nullValue", "\\N")
      .csv("/tmp/title.akas.tsv")
      .filter("region = 'US' or language = 'en'")
      .select("title")
      .as[String]
      .map(s => Sentences(s, s.split(' ')))
      .persist()
    
    println("Training model...")
    val word2Vec = new Word2Vec()
      .setInputCol("words")
      .setOutputCol("vector")
      .setVectorSize(64)
      .setWindowSize(4)
      .setNumPartitions(50)
      .setMinCount(5)
      .setMaxIter(30)
    val model = word2Vec.fit(dataset)
    model.getVectors.show()
    ```
    
    Before:
    ```
    Training model...
    +-------------+--------------------+
    |         word|              vector|
    +-------------+--------------------+
    |     Unspoken|[-Infinity,-Infin...|
    |       Talent|[-Infinity,Infini...|
    |    Hourglass|[2.02805806500023...|
    |Nickelodeon's|[-4.2918617120906...|
    |      Priests|[-1.3570403355926...|
    |    Religion:|[-6.7049072282803...|
    |           Bu|[5.05591774315586...|
    |      Totoro:|[-1.0539840178632...|
    |     Trouble,|[-3.5363592836003...|
    |       Hatter|[4.90413981352826...|
    |          '79|[7.50436471285412...|
    |         Vile|[-2.9147142985312...|
    |         9/11|[-Infinity,Infini...|
    |      Santino|[1.30005911270850...|
    |      Motives|[-1.2538958306253...|
    |          '13|[-4.5040152427657...|
    |       Fierce|[Infinity,Infinit...|
    |       Stover|[-2.6326895394029...|
    |          'It|[1.66574533864436...|
    |        Butts|[Infinity,Infinit...|
    +-------------+--------------------+
    only showing top 20 rows
    ```
    
    After:
    ```
    Training model...
    +-------------+--------------------+
    |         word|              vector|
    +-------------+--------------------+
    |     Unspoken|[-0.0454501919448...|
    |       Talent|[-0.2657704949378...|
    |    Hourglass|[-0.1399687677621...|
    |Nickelodeon's|[-0.1767119318246...|
    |      Priests|[-0.0047509293071...|
    |    Religion:|[-0.0411605164408...|
    |           Bu|[0.11837736517190...|
    |      Totoro:|[0.05258282646536...|
    |     Trouble,|[0.09482011198997...|
    |       Hatter|[0.06040831282734...|
    |          '79|[0.04783720895648...|
    |         Vile|[-0.0017210749210...|
    |         9/11|[-0.0713915303349...|
    |      Santino|[-0.0412711687386...|
    |      Motives|[-0.0492418706417...|
    |          '13|[-0.0073119504377...|
    |       Fierce|[-0.0565455369651...|
    |       Stover|[0.06938160210847...|
    |          'It|[0.01117012929171...|
    |        Butts|[0.05374567210674...|
    +-------------+--------------------+
    only showing top 20 rows
    ```
    
    Closes #26722 from viirya/SPARK-24666-2.
    
    Lead-authored-by: Liang-Chi Hsieh <liang...@uber.com>
    Co-authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Liang-Chi Hsieh <liang...@uber.com>
    (cherry picked from commit 755d8894485396b0a21304568c8ec5a55030f2fd)
    Signed-off-by: Liang-Chi Hsieh <liang...@uber.com>
---
 .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 17 ++++++++++++++---
 .../org/apache/spark/ml/feature/Word2VecSuite.scala     |  8 --------
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index d5b91df..bb5d02e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -438,9 +438,20 @@ class Word2Vec extends Serializable with Logging {
           }
         }.flatten
       }
-      val synAgg = partial.reduceByKey { case (v1, v2) =>
-          blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
-          v1
+      // SPARK-24666: do normalization for aggregating weights from partitions.
+      // Original Word2Vec either single-thread or multi-thread which do 
Hogwild-style aggregation.
+      // Our approach needs to do extra normalization, otherwise adding 
weights continuously may
+      // cause overflow on float and lead to infinity/-infinity weights.
+      val synAgg = partial.mapPartitions { iter =>
+        iter.map { case (id, vec) =>
+          (id, (vec, 1))
+        }
+      }.reduceByKey { case ((v1, count1), (v2, count2)) =>
+        blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
+        (v1, count1 + count2)
+      }.map { case (id, (vec, count)) =>
+        blas.sscal(vectorSize, 1.0f / count, vec, 1)
+        (id, vec)
       }.collect()
       var i = 0
       while (i < synAgg.length) {
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index b59c4e7..c816a6c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -75,14 +75,6 @@ class Word2VecSuite extends MLTest with DefaultReadWriteTest 
{
   test("getVectors") {
     val sentence = "a b " * 100 + "a c " * 10
     val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" 
"))
-
-    val codes = Map(
-      "a" -> Array(-0.2811822295188904, -0.6356269121170044, 
-0.3020961284637451),
-      "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342),
-      "c" -> Array(-0.08456747233867645, 0.5137411952018738, 
0.11731560528278351)
-    )
-    val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => 
Vectors.dense(v) }
-
     val docDF = doc.zip(doc).toDF("text", "alsotext")
 
     val model = new Word2Vec()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to