This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 9243ff6b827d [SPARK-51091][ML][PYTHON][CONNECT] Fix the default params 
of `StopWordsRemover`
9243ff6b827d is described below

commit 9243ff6b827de7c246ce7c0e3abd377680bc99df
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 23:28:54 2025 +0800

    [SPARK-51091][ML][PYTHON][CONNECT] Fix the default params of 
`StopWordsRemover`
    
    ### What changes were proposed in this pull request?
    Fix the default params of `StopWordsRemover`
    
    ### Why are the changes needed?
    For feature parity between Spark Connect and Spark Classic.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the default values of `locale` and `stopWords` are now set in Spark Connect mode.
    
    ### How was this patch tested?
    Added tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #49809 from zhengruifeng/ml_connect_swr.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit 4fd750c3f41cd195f35c03af982ff5cb74f2b3e1)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../apache/spark/ml/feature/StopWordsRemover.scala | 32 ++++++++++------------
 .../org/apache/spark/ml/util/ConnectHelper.scala   | 10 ++++++-
 python/pyspark/ml/feature.py                       | 30 ++++++++++----------
 .../ml/tests/connect/test_parity_feature.py        |  4 +--
 python/pyspark/ml/tests/test_feature.py            | 24 +++++++++++++---
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |  4 ++-
 6 files changed, 62 insertions(+), 42 deletions(-)

diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 99a20f3aa52c..cb9d8b32f006 100755
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import java.util.Locale
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.internal.{LogKeys, MDC}
+import org.apache.spark.internal.{Logging, LogKeys, MDC}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, 
HasOutputCol, HasOutputCols}
@@ -122,21 +122,6 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") 
override val uid: String
   @Since("2.4.0")
   def getLocale: String = $(locale)
 
-  /**
-   * Returns system default locale, or `Locale.US` if the default locale is 
not in available locales
-   * in JVM.
-   */
-  private val getDefaultOrUS: Locale = {
-    if (Locale.getAvailableLocales.contains(Locale.getDefault)) {
-      Locale.getDefault
-    } else {
-      logWarning(log"Default locale set was [${MDC(LogKeys.LOCALE, 
Locale.getDefault)}]; " +
-        log"however, it was not found in available locales in JVM, falling 
back to en_US locale. " +
-        log"Set param `locale` in order to respect another locale.")
-      Locale.US
-    }
-  }
-
   /** Returns the input and output column names corresponding in pair. */
   private[feature] def getInOutCols(): (Array[String], Array[String]) = {
     if (isSet(inputCol)) {
@@ -147,7 +132,7 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") 
override val uid: String
   }
 
   setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
-    caseSensitive -> false, locale -> getDefaultOrUS.toString)
+    caseSensitive -> false, locale -> StopWordsRemover.getDefaultOrUS.toString)
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
@@ -218,7 +203,7 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") 
override val uid: String
 }
 
 @Since("1.6.0")
-object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] {
+object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] with 
Logging {
 
   private[feature]
   val supportedLanguages = Set("danish", "dutch", "english", "finnish", 
"french", "german",
@@ -241,4 +226,15 @@ object StopWordsRemover extends 
DefaultParamsReadable[StopWordsRemover] {
     val is = 
getClass.getResourceAsStream(s"/org/apache/spark/ml/feature/stopwords/$language.txt")
     scala.io.Source.fromInputStream(is)(scala.io.Codec.UTF8).getLines().toArray
   }
+
+  private[spark] def getDefaultOrUS: Locale = {
+    if (Locale.getAvailableLocales.contains(Locale.getDefault)) {
+      Locale.getDefault
+    } else {
+      logWarning(log"Default locale set was [${MDC(LogKeys.LOCALE, 
Locale.getDefault)}]; " +
+        log"however, it was not found in available locales in JVM, falling 
back to en_US locale. " +
+        log"Set param `locale` in order to respect another locale.")
+      Locale.US
+    }
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
index d4a0a1301e15..fb2e1a0c0b4e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.ml.util
 
 import org.apache.spark.ml.Model
-import org.apache.spark.ml.feature.{CountVectorizerModel, StringIndexerModel}
+import org.apache.spark.ml.feature._
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.types.StructType
@@ -46,6 +46,14 @@ private[spark] class ConnectHelper(override val uid: String) 
extends Model[Conne
     new CountVectorizerModel(uid, vocabulary)
   }
 
+  def stopWordsRemoverLoadDefaultStopWords(language: String): Array[String] = {
+    StopWordsRemover.loadDefaultStopWords(language)
+  }
+
+  def stopWordsRemoverGetDefaultOrUS: String = {
+    StopWordsRemover.getDefaultOrUS.toString
+  }
+
   override def copy(extra: ParamMap): ConnectHelper = defaultCopy(extra)
 
   override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 6d1ddf5e51c4..3797b0f2b04c 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -5043,7 +5043,6 @@ class StopWordsRemover(
     Notes
     -----
     - null values from input array are preserved unless adding null to 
stopWords explicitly.
-    - In Spark Connect Mode, the default value of parameter `locale` and 
`stopWords` are not set.
 
     Examples
     --------
@@ -5142,19 +5141,14 @@ class StopWordsRemover(
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.feature.StopWordsRemover", self.uid
         )
-        if isinstance(self._java_obj, str):
-            # Skip setting the default value of 'locale' and 'stopWords', which
-            # needs to invoke a JVM method.
-            # So if users don't explicitly set 'locale' and/or 'stopWords', 
then the getters fails.
-            self._setDefault(
-                caseSensitive=False,
-            )
+        if is_remote():
+            helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+            locale = helper._call_java("stopWordsRemoverGetDefaultOrUS")
         else:
-            self._setDefault(
-                stopWords=StopWordsRemover.loadDefaultStopWords("english"),
-                caseSensitive=False,
-                locale=self._java_obj.getLocale(),
-            )
+            locale = self._java_obj.getLocale()
+
+        stopWords = StopWordsRemover.loadDefaultStopWords("english")
+        self._setDefault(stopWords=stopWords, caseSensitive=False, 
locale=locale)
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
@@ -5279,8 +5273,14 @@ class StopWordsRemover(
         Supported languages: danish, dutch, english, finnish, french, german, 
hungarian,
         italian, norwegian, portuguese, russian, spanish, swedish, turkish
         """
-        stopWordsObj = getattr(_jvm(), 
"org.apache.spark.ml.feature.StopWordsRemover")
-        return list(stopWordsObj.loadDefaultStopWords(language))
+        if is_remote():
+            helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+            stopWords = 
helper._call_java("stopWordsRemoverLoadDefaultStopWords", language)
+            return list(stopWords)
+
+        else:
+            stopWordsObj = getattr(_jvm(), 
"org.apache.spark.ml.feature.StopWordsRemover")
+            return list(stopWordsObj.loadDefaultStopWords(language))
 
 
 class _TargetEncoderParams(
diff --git a/python/pyspark/ml/tests/connect/test_parity_feature.py 
b/python/pyspark/ml/tests/connect/test_parity_feature.py
index 323721130e12..dd4580e4da5a 100644
--- a/python/pyspark/ml/tests/connect/test_parity_feature.py
+++ b/python/pyspark/ml/tests/connect/test_parity_feature.py
@@ -22,9 +22,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
 class FeatureParityTests(FeatureTestsMixin, ReusedConnectTestCase):
-    @unittest.skip("Need to support.")
-    def test_stop_words_lengague_selection(self):
-        super().test_stop_words_lengague_selection()
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/test_feature.py 
b/python/pyspark/ml/tests/test_feature.py
index a3e580ec7220..cd2011c5ec87 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -883,8 +883,9 @@ class FeatureTestsMixin:
             remover2 = StopWordsRemover.load(d)
             self.assertEqual(str(remover), str(remover2))
 
-    def test_stop_words_remover_II(self):
-        dataset = self.spark.createDataFrame([Row(input=["a", "panda"])])
+    def test_stop_words_remover_with_given_words(self):
+        spark = self.spark
+        dataset = spark.createDataFrame([Row(input=["a", "panda"])])
         stopWordRemover = StopWordsRemover(inputCol="input", 
outputCol="output")
         # Default
         self.assertEqual(stopWordRemover.getInputCol(), "input")
@@ -905,15 +906,30 @@ class FeatureTestsMixin:
         transformedDF = stopWordRemover.transform(dataset)
         self.assertEqual(transformedDF.head().output, [])
 
-    def test_stop_words_language_selection(self):
+    def test_stop_words_remover_with_turkish(self):
+        spark = self.spark
+        dataset = spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
         stopWordRemover = StopWordsRemover(inputCol="input", 
outputCol="output")
         stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
-        dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", 
"biri"])])
         stopWordRemover.setStopWords(stopwords)
         self.assertEqual(stopWordRemover.getStopWords(), stopwords)
         transformedDF = stopWordRemover.transform(dataset)
         self.assertEqual(transformedDF.head().output, [])
 
+    def test_stop_words_remover_default(self):
+        stopWordRemover = StopWordsRemover(inputCol="input", 
outputCol="output")
+
+        # check the default value of local
+        locale = stopWordRemover.getLocale()
+        self.assertIsInstance(locale, str)
+        self.assertTrue(len(locale) > 0)
+
+        # check the default value of stop words
+        stopwords = stopWordRemover.getStopWords()
+        self.assertIsInstance(stopwords, list)
+        self.assertTrue(len(stopwords) > 0)
+        self.assertTrue(all(isinstance(word, str) for word in stopwords))
+
     def test_binarizer(self):
         b0 = Binarizer()
         self.assertListEqual(
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 2cbae0fa7d9d..b613b2202137 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -659,7 +659,9 @@ private[ml] object MLUtils {
         "handleOverwrite",
         "stringIndexerModelFromLabels",
         "stringIndexerModelFromLabelsArray",
-        "countVectorizerModelFromVocabulary")))
+        "countVectorizerModelFromVocabulary",
+        "stopWordsRemoverLoadDefaultStopWords",
+        "stopWordsRemoverGetDefaultOrUS")))
 
   private def validate(obj: Any, method: String): Unit = {
     assert(obj != null)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to