Repository: spark
Updated Branches:
  refs/heads/branch-2.0 c9bd67e94 -> eb2675de9


[SPARK-17548][MLLIB] Word2VecModel.findSynonyms no longer spuriously rejects the best match when invoked with a vector

## What changes were proposed in this pull request?

This pull request changes the behavior of `Word2VecModel.findSynonyms` so that 
it will not spuriously reject the best match when invoked with a vector that 
does not correspond to a word in the model's vocabulary.  Instead of blindly 
discarding the best match, the changed implementation discards a match that 
corresponds to the query word (in cases where `findSynonyms` is invoked with a 
word) or that has an identical angle to the query vector.
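
For illustration, here is a minimal sketch of the resulting behaviour (not part of the patch; the toy vocabulary mirrors the new test case, and the model is built directly from word vectors the same way the test does):

```scala
import org.apache.spark.mllib.feature.Word2VecModel
import org.apache.spark.mllib.linalg.Vectors

// Toy model built directly from word vectors, as in the new Word2VecSuite test;
// no training (and no SparkContext) is needed for this sketch.
val model = new Word2VecModel(Map(
  "china"  -> Array(0.50f, 0.50f, 0.50f, 0.50f),
  "japan"  -> Array(0.40f, 0.50f, 0.50f, 0.50f),
  "taiwan" -> Array(0.60f, 0.50f, 0.50f, 0.50f)
))

// Querying with a word still excludes that word from the results.
model.findSynonyms("china", 2).foreach(println)   // (taiwan,...), (japan,...)

// Querying with an arbitrary vector no longer discards the best match;
// previously the closest word ("china") was silently dropped here.
model.findSynonyms(Vectors.dense(0.52, 0.5, 0.5, 0.5), 2).foreach(println)
```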

## How was this patch tested?

I added a test to `Word2VecSuite` to ensure that the word whose vector is most similar to a 
supplied query vector is not spuriously rejected.

Author: William Benton <wi...@redhat.com>

Closes #15105 from willb/fix/findSynonyms.

(cherry picked from commit 25cbbe6ca334140204e7035ab8b9d304da9b8a8a)
Signed-off-by: Sean Owen <so...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/eb2675de
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/eb2675de
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/eb2675de

Branch: refs/heads/branch-2.0
Commit: eb2675de92b865852d7aa3ef25a20e6cff940299
Parents: c9bd67e
Author: William Benton <wi...@redhat.com>
Authored: Sat Sep 17 12:49:58 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Sat Sep 17 12:50:09 2016 +0100

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Word2Vec.scala  | 20 ++++++-----
 .../mllib/api/python/Word2VecModelWrapper.scala | 22 ++++++++++--
 .../apache/spark/mllib/feature/Word2Vec.scala   | 37 +++++++++++++++-----
 .../spark/mllib/feature/Word2VecSuite.scala     | 16 +++++++++
 python/pyspark/mllib/feature.py                 | 12 +++++--
 5 files changed, 83 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/eb2675de/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index c2b434c..14c0512 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -221,24 +221,26 @@ class Word2VecModel private[ml] (
   }
 
   /**
-   * Find "num" number of words closest in similarity to the given word.
-   * Returns a dataframe with the words and the cosine similarities between the
-   * synonyms and the given word.
+   * Find "num" number of words closest in similarity to the given word, not
+   * including the word itself. Returns a dataframe with the words and the
+   * cosine similarities between the synonyms and the given word.
    */
   @Since("1.5.0")
   def findSynonyms(word: String, num: Int): DataFrame = {
-    findSynonyms(wordVectors.transform(word), num)
+    val spark = SparkSession.builder().getOrCreate()
+    spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
   }
 
   /**
-   * Find "num" number of words closest to similarity to the given vector 
representation
-   * of the word. Returns a dataframe with the words and the cosine 
similarities between the
-   * synonyms and the given word vector.
+   * Find "num" number of words whose vector representation most similar to 
the supplied vector.
+   * If the supplied vector is the vector representation of a word in the 
model's vocabulary,
+   * that word will be in the results.  Returns a dataframe with the words and 
the cosine
+   * similarities between the synonyms and the given word vector.
    */
   @Since("2.0.0")
-  def findSynonyms(word: Vector, num: Int): DataFrame = {
+  def findSynonyms(vec: Vector, num: Int): DataFrame = {
     val spark = SparkSession.builder().getOrCreate()
-    spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
+    spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity")
   }
 
   /** @group setParam */

http://git-wip-us.apache.org/repos/asf/spark/blob/eb2675de/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
index 4b4ed22..5cbfbff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
@@ -43,18 +43,34 @@ private[python] class Word2VecModelWrapper(model: Word2VecModel) {
     rdd.rdd.map(model.transform)
   }
 
+  /**
+   * Finds synonyms of a word; do not include the word itself in results.
+   * @param word a word
+   * @param num number of synonyms to find
+   * @return a list consisting of a list of words and a vector of cosine similarities
+   */
   def findSynonyms(word: String, num: Int): JList[Object] = {
-    val vec = transform(word)
-    findSynonyms(vec, num)
+    prepareResult(model.findSynonyms(word, num))
   }
 
+  /**
+   * Finds words similar to the vector representation of a word without
+   * filtering results.
+   * @param vector a vector
+   * @param num number of synonyms to find
+   * @return a list consisting of a list of words and a vector of cosine similarities
+   */
   def findSynonyms(vector: Vector, num: Int): JList[Object] = {
-    val result = model.findSynonyms(vector, num)
+    prepareResult(model.findSynonyms(vector, num))
+  }
+
+  private def prepareResult(result: Array[(String, Double)]) = {
     val similarity = Vectors.dense(result.map(_._2))
     val words = result.map(_._1)
     List(words, similarity).map(_.asInstanceOf[Object]).asJava
   }
 
+
   def getVectors: JMap[String, JList[Float]] = {
     model.getVectors.map { case (k, v) =>
       (k, v.toList.asJava)

http://git-wip-us.apache.org/repos/asf/spark/blob/eb2675de/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index bc75646..761996f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -518,7 +518,7 @@ class Word2VecModel private[spark] (
   }
 
   /**
-   * Find synonyms of a word
+   * Find synonyms of a word; do not include the word itself in results.
    * @param word a word
    * @param num number of synonyms to find
    * @return array of (word, cosineSimilarity)
@@ -526,17 +526,34 @@ class Word2VecModel private[spark] (
   @Since("1.1.0")
   def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
     val vector = transform(word)
-    findSynonyms(vector, num)
+    findSynonyms(vector, num, Some(word))
   }
 
   /**
-   * Find synonyms of the vector representation of a word
+   * Find synonyms of the vector representation of a word, possibly
+   * including any words in the model vocabulary whose vector representation
+   * is the supplied vector.
    * @param vector vector representation of a word
    * @param num number of synonyms to find
    * @return array of (word, cosineSimilarity)
    */
   @Since("1.1.0")
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
+    findSynonyms(vector, num, None)
+  }
+
+  /**
+   * Find synonyms of the vector representation of a word, rejecting
+   * words identical to the value of wordOpt, if one is supplied.
+   * @param vector vector representation of a word
+   * @param num number of synonyms to find
+   * @param wordOpt optionally, a word to reject from the results list
+   * @return array of (word, cosineSimilarity)
+   */
+  private def findSynonyms(
+      vector: Vector,
+      num: Int,
+      wordOpt: Option[String]): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
     // TODO: optimize top-k
     val fVector = vector.toArray.map(_.toFloat)
@@ -563,12 +580,14 @@ class Word2VecModel private[spark] (
       ind += 1
     }
 
-    wordList.zip(cosVec)
-      .toSeq
-      .sortBy(-_._2)
-      .take(num + 1)
-      .tail
-      .toArray
+    val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2)
+
+    val filtered = wordOpt match {
+      case Some(w) => scored.take(num + 1).filter(tup => w != tup._1)
+      case None => scored
+    }
+
+    filtered.take(num).toArray
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/eb2675de/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 22de4c4..f4fa216 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.feature
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.util.Utils
 
@@ -68,6 +69,21 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(syms(1)._1 == "japan")
   }
 
+  test("findSynonyms doesn't reject similar word vectors when called with a 
vector") {
+    val num = 2
+    val word2VecMap = Map(
+      ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
+      ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
+      ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+      ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
+    )
+    val model = new Word2VecModel(word2VecMap)
+    val syms = model.findSynonyms(Vectors.dense(Array(0.52, 0.5, 0.5, 0.5)), num)
+    assert(syms.length == num)
+    assert(syms(0)._1 == "china")
+    assert(syms(1)._1 == "taiwan")
+  }
+
   test("model load / save") {
 
     val word2VecMap = Map(

http://git-wip-us.apache.org/repos/asf/spark/blob/eb2675de/python/pyspark/mllib/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index c8a6e33..9295318 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -545,8 +545,7 @@ class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader):
 
 @ignore_unicode_prefix
 class Word2Vec(object):
-    """
-    Word2Vec creates vector representation of words in a text corpus.
+    """Word2Vec creates vector representation of words in a text corpus.
     The algorithm first constructs a vocabulary from the corpus
     and then learns vector representation of words in the vocabulary.
     The vector representation can be used as features in
@@ -568,13 +567,19 @@ class Word2Vec(object):
     >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" "))
     >>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc)
 
+    Querying for synonyms of a word will not return that word:
+
     >>> syms = model.findSynonyms("a", 2)
     >>> [s[0] for s in syms]
     [u'b', u'c']
+
+    But querying for synonyms of a vector may return the word whose
+    representation is that vector:
+
     >>> vec = model.transform("a")
     >>> syms = model.findSynonyms(vec, 2)
     >>> [s[0] for s in syms]
-    [u'b', u'c']
+    [u'a', u'b']
 
     >>> import os, tempfile
     >>> path = tempfile.mkdtemp()
@@ -592,6 +597,7 @@ class Word2Vec(object):
     ...     pass
 
     .. versionadded:: 1.2.0
+
     """
     def __init__(self):
         """

