Github user shubhamchopra commented on a diff in the pull request: https://github.com/apache/spark/pull/17673#discussion_r143529286 --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala --- @@ -189,6 +305,136 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5) } + test("window size - CBOW") { + + val spark = this.spark + import spark.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setSolver("cbow-ns") + .setMaxUnigramTableSize(10000) + .setNegativeSamples(5) + .setVectorSize(3) + .setWindowSize(2) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val (synonyms, similarity) = model.findSynonyms("a", 6).rdd.map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + // Increase the window size + val biggerModel = new Word2Vec() + .setSolver("cbow-ns") + .setMaxUnigramTableSize(10000) + .setNegativeSamples(5) + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .setWindowSize(10) + .fit(docDF) + + val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).rdd.map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + // The similarity score should be very different with the larger window + assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5) + } + + test("negative sampling - CBOW") { + + val spark = this.spark + import spark.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setSolver("cbow-ns") + 
.setMaxUnigramTableSize(10000) + .setNegativeSamples(2) + .setVectorSize(3) + .setWindowSize(2) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val (synonyms, similarity) = model.findSynonyms("a", 6).rdd.map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + // Increase the window size --- End diff -- Corrected. Thanks.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org