This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 8a01aa64e1c4 [SPARK-50958][ML][PYTHON][CONNECT] Support `Word2VecModel.findSynonymsArray` on connect
8a01aa64e1c4 is described below

commit 8a01aa64e1c4118a206daf7be08adbc03779c102
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 18:19:33 2025 +0800

    [SPARK-50958][ML][PYTHON][CONNECT] Support `Word2VecModel.findSynonymsArray` on connect
    
    ### What changes were proposed in this pull request?
    Support `Word2VecModel.findSynonymsArray` on Connect by building the result on the client from the DataFrame returned by `findSynonyms`.
    
    ### Why are the changes needed?
    For feature parity between Spark Classic and Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `Word2VecModel.findSynonymsArray` is now supported on Connect.
    
    ### How was this patch tested?
    Added tests for `findSynonymsArray`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49804 from zhengruifeng/ml_connect_fix_w2v_find.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit 6ed7733265c2dc9ced0a36bfd8bee5fc7950b152)
    Signed-off-by: Ruifeng Zheng <[email protected]>
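
A minimal usage sketch of what this change enables, assuming a Spark Connect endpoint at sc://localhost; the toy corpus and parameter values below are illustrative, not taken from the patch:

    from pyspark.ml.feature import Word2Vec
    from pyspark.sql import SparkSession

    # Connect to a Spark Connect server (the endpoint is an assumption for this sketch).
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    # Tiny illustrative corpus: each row is a list of tokens.
    df = spark.createDataFrame(
        [("a b c".split(" "),), ("a b b c a".split(" "),)],
        ["sentence"],
    )

    model = Word2Vec(
        vectorSize=3, minCount=0, seed=42, inputCol="sentence", outputCol="model"
    ).fit(df)

    # findSynonyms already worked on Connect (it returns a DataFrame);
    # findSynonymsArray now works as well and yields (word, similarity) tuples.
    for word, similarity in model.findSynonymsArray("a", 2):
        print(word, similarity)

Under the hood the client builds that list by collecting the DataFrame returned by `findSynonyms`, so no JVM-only call is required on the Connect server.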
---
 python/pyspark/ml/feature.py                                |  9 ++++-----
 python/pyspark/ml/tests/test_feature.py                     | 13 ++++++-------
 .../scala/org/apache/spark/sql/connect/ml/MLUtils.scala     |  2 +-
 3 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index fb9c96bd6114..d6c743510744 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -6480,11 +6480,10 @@ class Word2VecModel(JavaModel, _Word2VecParams, JavaMLReadable["Word2VecModel"],
         Returns an array with two fields word and similarity (which
         gives the cosine similarity).
         """
-        if not isinstance(word, str):
-            word = _convert_to_vector(word)
-        assert self._java_obj is not None
-        tuples = self._java_obj.findSynonymsArray(word, num)
-        return list(map(lambda st: (st._1(), st._2()), list(tuples)))
+        res = []
+        for row in self.findSynonyms(word, num).collect():
+            res.append((str(row.word), float(row.similarity)))
+        return res
 
 
 class _PCAParams(HasInputCol, HasOutputCol):
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
index 4298905e452b..d86499ee9e42 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -678,13 +678,12 @@ class FeatureTestsMixin:
         self.assertEqual(synonyms.columns, ["word", "similarity"])
         self.assertEqual(synonyms.count(), 2)
 
-        # TODO(SPARK-50958): Support Word2VecModel.findSynonymsArray
-        # synonyms = model.findSynonymsArray("a", 2)
-        # self.assertEqual(len(synonyms), 2)
-        # self.assertEqual(synonyms[0][0], "b")
-        # self.assertTrue(np.allclose(synonyms[0][1], -0.024012837558984756, atol=1e-4))
-        # self.assertEqual(synonyms[1][0], "c")
-        # self.assertTrue(np.allclose(synonyms[1][1], -0.19355154037475586, atol=1e-4))
+        synonyms = model.findSynonymsArray("a", 2)
+        self.assertEqual(len(synonyms), 2)
+        self.assertEqual(synonyms[0][0], "b")
+        self.assertTrue(np.allclose(synonyms[0][1], -0.024012837558984756, atol=1e-4))
+        self.assertEqual(synonyms[1][0], "c")
+        self.assertTrue(np.allclose(synonyms[1][1], -0.19355154037475586, atol=1e-4))
 
         output = model.transform(df)
         self.assertEqual(output.columns, ["sentence", "model"])
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 4418dee68ed1..ba881f0de48f 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -645,7 +645,7 @@ private[ml] object MLUtils {
     (classOf[UnivariateFeatureSelectorModel], Set("selectedFeatures")),
     (classOf[VarianceThresholdSelectorModel], Set("selectedFeatures")),
     (classOf[PCAModel], Set("pc", "explainedVariance")),
-    (classOf[Word2VecModel], Set("getVectors", "findSynonyms", 
"findSynonymsArray")),
+    (classOf[Word2VecModel], Set("getVectors", "findSynonyms")),
     (classOf[CountVectorizerModel], Set("vocabulary")),
     (classOf[OneHotEncoderModel], Set("categorySizes")),
     (classOf[StringIndexerModel], Set("labels", "labelsArray")),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
