Repository: spark

Updated Branches:
  refs/heads/master d3b886976 -> 7654385f2
[SPARK-17595][MLLIB] Use a bounded priority queue to find synonyms in Word2VecModel

## What changes were proposed in this pull request?

The code in `Word2VecModel.findSynonyms` that chooses the vocabulary elements with the highest similarity to the query vector currently sorts the collection of similarities for every vocabulary element. This involves making multiple copies of the collection of similarities while doing a (relatively) expensive sort. It is more efficient to find the best matches by maintaining a bounded priority queue and populating it with a single pass over the vocabulary, and that is exactly what this patch does.

## How was this patch tested?

This patch adds no user-visible functionality, so its correctness is exercised by existing tests. To ensure that this approach is actually faster, I wrote a microbenchmark for `findSynonyms`:

```
object W2VTiming {
  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.mllib.feature.Word2VecModel

  def run(modelPath: String, scOpt: Option[SparkContext] = None) {
    val sc = scOpt.getOrElse(
      new SparkContext(new SparkConf(true).setMaster("local[*]").setAppName("test")))
    val model = Word2VecModel.load(sc, modelPath)
    val keys = model.getVectors.keys
    val start = System.currentTimeMillis
    for (key <- keys) {
      model.findSynonyms(key, 5)
      model.findSynonyms(key, 10)
      model.findSynonyms(key, 25)
      model.findSynonyms(key, 50)
    }
    val finish = System.currentTimeMillis
    println("run completed in " + (finish - start) + "ms")
  }
}
```

I ran this benchmark on a model generated from the complete works of Jane Austen and found that the new approach was over 3x faster than the old one. (If the `num` argument to `findSynonyms` is very close to the vocabulary size, the new approach will have less of an advantage over the old one.)

Author: William Benton <wi...@redhat.com>

Closes #15150 from willb/SPARK-17595.
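The single-pass top-k idea can be sketched independently of Spark's internal `org.apache.spark.util.BoundedPriorityQueue`. The `topK` helper below is hypothetical (not Spark code), built on `scala.collection.mutable.PriorityQueue`: keep at most `k` elements, evicting the current minimum whenever a better candidate arrives, so the vocabulary is traversed exactly once and nothing larger than `k` is ever sorted.

```scala
import scala.collection.mutable

// Hypothetical sketch of a bounded-priority-queue top-k pass.
// Not Spark's BoundedPriorityQueue; same idea in plain Scala.
def topK[A](items: Iterable[A], k: Int)(implicit ord: Ordering[A]): Seq[A] = {
  // Min-heap of size <= k under `ord`: the head is the smallest retained element.
  val pq = mutable.PriorityQueue.empty[A](ord.reverse)
  for (item <- items) {
    if (pq.size < k) {
      pq += item
    } else if (ord.gt(item, pq.head)) {
      pq.dequeue() // evict the current minimum
      pq += item
    }
  }
  // dequeueAll yields ascending order under `ord`; reverse for best-first.
  pq.dequeueAll.reverse
}

// Made-up (word, similarity) pairs for illustration.
implicit val bySimilarity: Ordering[(String, Double)] = Ordering.by(_._2)
val sims = Seq(("cat", 0.9), ("dog", 0.7), ("fish", 0.2), ("bird", 0.8))
println(topK(sims, 2)) // highest-similarity pairs first
```

Each element costs at most O(log k) heap work, so the whole pass is O(n log k) versus O(n log n) for a full sort, which matches the speedup the patch is after when `num` is much smaller than the vocabulary.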
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7654385f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7654385f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7654385f

Branch: refs/heads/master
Commit: 7654385f268a3f481c4574ce47a19ab21155efd5
Parents: d3b8869
Author: William Benton <wi...@redhat.com>
Authored: Wed Sep 21 09:45:06 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Wed Sep 21 09:45:06 2016 +0100

----------------------------------------------------------------------
 .../org/apache/spark/mllib/feature/Word2Vec.scala | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/7654385f/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 42ca966..2364d43 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -35,6 +35,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd._
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.BoundedPriorityQueue
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
@@ -555,7 +556,7 @@ class Word2VecModel private[spark] (
       num: Int,
       wordOpt: Option[String]): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
-    // TODO: optimize top-k
+
     val fVector = vector.toArray.map(_.toFloat)
     val cosineVec = Array.fill[Float](numWords)(0)
     val alpha: Float = 1
@@ -580,10 +581,16 @@ class Word2VecModel private[spark] (
       ind += 1
     }
-    val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2)
+    val pq = new BoundedPriorityQueue[(String, Double)](num + 1)(Ordering.by(_._2))
+
+    for(i <- cosVec.indices) {
+      pq += Tuple2(wordList(i), cosVec(i))
+    }
+
+    val scored = pq.toSeq.sortBy(-_._2)

     val filtered = wordOpt match {
-      case Some(w) => scored.take(num + 1).filter(tup => w != tup._1)
+      case Some(w) => scored.filter(tup => w != tup._1)
       case None => scored

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
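One design choice in the diff worth noting: the queue capacity is `num + 1` rather than `num`, because when `wordOpt` names the query word itself, that word scores highest, occupies one slot, and is filtered out afterwards; the extra slot guarantees `num` genuine synonyms survive the filter. A minimal sketch of that filter step, with made-up words and scores:

```scala
// Hypothetical similarities illustrating the num + 1 capacity choice:
// the query word "spark" scores 1.0 against itself and is dropped after
// the top-(num + 1) selection, leaving exactly `num` synonyms.
val num = 2
val scored = Seq(("spark", 1.0), ("flame", 0.9), ("fire", 0.8)) // top (num + 1) by similarity
val filtered = scored.filter { case (w, _) => w != "spark" }     // drop the query word itself
println(filtered.take(num))
```

With capacity `num` instead, the query word would crowd out one real synonym and the caller would receive only `num - 1` results.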