This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 8a01aa64e1c4 [SPARK-50958][ML][PYTHON][CONNECT] Support
`Word2VecModel.findSynonymsArray` on connect
8a01aa64e1c4 is described below
commit 8a01aa64e1c4118a206daf7be08adbc03779c102
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 18:19:33 2025 +0800
[SPARK-50958][ML][PYTHON][CONNECT] Support
`Word2VecModel.findSynonymsArray` on connect
### What changes were proposed in this pull request?
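Fix `Word2VecModel.findSynonymsArray` on connect by building its result from `findSynonyms(word, num).collect()` instead of calling the JVM object directly, and drop `findSynonymsArray` from the `Word2VecModel` entry in `MLUtils.scala`.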
### Why are the changes needed?
For feature parity between Spark Connect and classic PySpark.
### Does this PR introduce _any_ user-facing change?
Yes, `Word2VecModel.findSynonymsArray` is now supported on Spark Connect.
### How was this patch tested?
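Enabled the previously commented-out `findSynonymsArray` assertions in `python/pyspark/ml/tests/test_feature.py`.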
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49804 from zhengruifeng/ml_connect_fix_w2v_find.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 6ed7733265c2dc9ced0a36bfd8bee5fc7950b152)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/ml/feature.py | 9 ++++-----
python/pyspark/ml/tests/test_feature.py | 13 ++++++-------
.../scala/org/apache/spark/sql/connect/ml/MLUtils.scala | 2 +-
3 files changed, 11 insertions(+), 13 deletions(-)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index fb9c96bd6114..d6c743510744 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -6480,11 +6480,10 @@ class Word2VecModel(JavaModel, _Word2VecParams,
JavaMLReadable["Word2VecModel"],
Returns an array with two fields word and similarity (which
gives the cosine similarity).
"""
- if not isinstance(word, str):
- word = _convert_to_vector(word)
- assert self._java_obj is not None
- tuples = self._java_obj.findSynonymsArray(word, num)
- return list(map(lambda st: (st._1(), st._2()), list(tuples)))
+ res = []
+ for row in self.findSynonyms(word, num).collect():
+ res.append((str(row.word), float(row.similarity)))
+ return res
class _PCAParams(HasInputCol, HasOutputCol):
diff --git a/python/pyspark/ml/tests/test_feature.py
b/python/pyspark/ml/tests/test_feature.py
index 4298905e452b..d86499ee9e42 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -678,13 +678,12 @@ class FeatureTestsMixin:
self.assertEqual(synonyms.columns, ["word", "similarity"])
self.assertEqual(synonyms.count(), 2)
- # TODO(SPARK-50958): Support Word2VecModel.findSynonymsArray
- # synonyms = model.findSynonymsArray("a", 2)
- # self.assertEqual(len(synonyms), 2)
- # self.assertEqual(synonyms[0][0], "b")
- # self.assertTrue(np.allclose(synonyms[0][1], -0.024012837558984756,
atol=1e-4))
- # self.assertEqual(synonyms[1][0], "c")
- # self.assertTrue(np.allclose(synonyms[1][1], -0.19355154037475586,
atol=1e-4))
+ synonyms = model.findSynonymsArray("a", 2)
+ self.assertEqual(len(synonyms), 2)
+ self.assertEqual(synonyms[0][0], "b")
+ self.assertTrue(np.allclose(synonyms[0][1], -0.024012837558984756,
atol=1e-4))
+ self.assertEqual(synonyms[1][0], "c")
+ self.assertTrue(np.allclose(synonyms[1][1], -0.19355154037475586,
atol=1e-4))
output = model.transform(df)
self.assertEqual(output.columns, ["sentence", "model"])
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 4418dee68ed1..ba881f0de48f 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -645,7 +645,7 @@ private[ml] object MLUtils {
(classOf[UnivariateFeatureSelectorModel], Set("selectedFeatures")),
(classOf[VarianceThresholdSelectorModel], Set("selectedFeatures")),
(classOf[PCAModel], Set("pc", "explainedVariance")),
- (classOf[Word2VecModel], Set("getVectors", "findSynonyms",
"findSynonymsArray")),
+ (classOf[Word2VecModel], Set("getVectors", "findSynonyms")),
(classOf[CountVectorizerModel], Set("vocabulary")),
(classOf[OneHotEncoderModel], Set("categorySizes")),
(classOf[StringIndexerModel], Set("labels", "labelsArray")),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]