This is an automated email from the ASF dual-hosted git repository. viirya pushed a commit to branch branch-2.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push: new 6dff114 [SPARK-24666][ML] Fix infinity vectors produced by Word2Vec when numIterations are large 6dff114 is described below commit 6dff114ddb3de7877625b79bea818ba724ccd22d Author: Liang-Chi Hsieh <liang...@uber.com> AuthorDate: Thu Dec 5 16:32:33 2019 -0800 [SPARK-24666][ML] Fix infinity vectors produced by Word2Vec when numIterations are large ### What changes were proposed in this pull request? This patch adds normalization to word vectors when fitting dataset in Word2Vec. ### Why are the changes needed? Running Word2Vec on some datasets, when numIterations is large, can produce infinity word vectors. ### Does this PR introduce any user-facing change? Yes. After this patch, Word2Vec won't produce infinity word vectors. ### How was this patch tested? Manually. This issue is not always reproducible on any dataset. The dataset known to reproduce it is too large (925M) to upload. ```scala case class Sentences(name: String, words: Array[String]) val dataset = spark.read .option("header", "true").option("sep", "\t") .option("quote", "").option("nullValue", "\\N") .csv("/tmp/title.akas.tsv") .filter("region = 'US' or language = 'en'") .select("title") .as[String] .map(s => Sentences(s, s.split(' '))) .persist() println("Training model...") val word2Vec = new Word2Vec() .setInputCol("words") .setOutputCol("vector") .setVectorSize(64) .setWindowSize(4) .setNumPartitions(50) .setMinCount(5) .setMaxIter(30) val model = word2Vec.fit(dataset) model.getVectors.show() ``` Before: ``` Training model... 
+-------------+--------------------+ | word| vector| +-------------+--------------------+ | Unspoken|[-Infinity,-Infin...| | Talent|[-Infinity,Infini...| | Hourglass|[2.02805806500023...| |Nickelodeon's|[-4.2918617120906...| | Priests|[-1.3570403355926...| | Religion:|[-6.7049072282803...| | Bu|[5.05591774315586...| | Totoro:|[-1.0539840178632...| | Trouble,|[-3.5363592836003...| | Hatter|[4.90413981352826...| | '79|[7.50436471285412...| | Vile|[-2.9147142985312...| | 9/11|[-Infinity,Infini...| | Santino|[1.30005911270850...| | Motives|[-1.2538958306253...| | '13|[-4.5040152427657...| | Fierce|[Infinity,Infinit...| | Stover|[-2.6326895394029...| | 'It|[1.66574533864436...| | Butts|[Infinity,Infinit...| +-------------+--------------------+ only showing top 20 rows ``` After: ``` Training model... +-------------+--------------------+ | word| vector| +-------------+--------------------+ | Unspoken|[-0.0454501919448...| | Talent|[-0.2657704949378...| | Hourglass|[-0.1399687677621...| |Nickelodeon's|[-0.1767119318246...| | Priests|[-0.0047509293071...| | Religion:|[-0.0411605164408...| | Bu|[0.11837736517190...| | Totoro:|[0.05258282646536...| | Trouble,|[0.09482011198997...| | Hatter|[0.06040831282734...| | '79|[0.04783720895648...| | Vile|[-0.0017210749210...| | 9/11|[-0.0713915303349...| | Santino|[-0.0412711687386...| | Motives|[-0.0492418706417...| | '13|[-0.0073119504377...| | Fierce|[-0.0565455369651...| | Stover|[0.06938160210847...| | 'It|[0.01117012929171...| | Butts|[0.05374567210674...| +-------------+--------------------+ only showing top 20 rows ``` Closes #26722 from viirya/SPARK-24666-2. 
Lead-authored-by: Liang-Chi Hsieh <liang...@uber.com> Co-authored-by: Liang-Chi Hsieh <vii...@gmail.com> Signed-off-by: Liang-Chi Hsieh <liang...@uber.com> (cherry picked from commit 755d8894485396b0a21304568c8ec5a55030f2fd) Signed-off-by: Liang-Chi Hsieh <liang...@uber.com> --- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 17 ++++++++++++++--- .../org/apache/spark/ml/feature/Word2VecSuite.scala | 8 -------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index d5b91df..bb5d02e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -438,9 +438,20 @@ class Word2Vec extends Serializable with Logging { } }.flatten } - val synAgg = partial.reduceByKey { case (v1, v2) => - blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) - v1 + // SPARK-24666: do normalization for aggregating weights from partitions. + // The original Word2Vec is either single-threaded or multi-threaded, doing Hogwild-style aggregation. + // Our approach needs to do extra normalization, otherwise adding weights continuously may + // cause overflow on float and lead to infinity/-infinity weights. 
+ val synAgg = partial.mapPartitions { iter => + iter.map { case (id, vec) => + (id, (vec, 1)) + } + }.reduceByKey { case ((v1, count1), (v2, count2)) => + blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) + (v1, count1 + count2) + }.map { case (id, (vec, count)) => + blas.sscal(vectorSize, 1.0f / count, vec, 1) + (id, vec) }.collect() var i = 0 while (i < synAgg.length) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index b59c4e7..c816a6c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -75,14 +75,6 @@ class Word2VecSuite extends MLTest with DefaultReadWriteTest { test("getVectors") { val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) - - val codes = Map( - "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451), - "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342), - "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351) - ) - val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) } - val docDF = doc.zip(doc).toDF("text", "alsotext") val model = new Word2Vec() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org