This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 9243ff6b827d [SPARK-51091][ML][PYTHON][CONNECT] Fix the default params
of `StopWordsRemover`
9243ff6b827d is described below
commit 9243ff6b827de7c246ce7c0e3abd377680bc99df
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 23:28:54 2025 +0800
[SPARK-51091][ML][PYTHON][CONNECT] Fix the default params of
`StopWordsRemover`
### What changes were proposed in this pull request?
Fix the default params of `StopWordsRemover`
### Why are the changes needed?
For feature parity: Spark Connect should use the same default parameter values as Spark Classic.
### Does this PR introduce _any_ user-facing change?
Yes. In Spark Connect mode, the default values of the `locale` and `stopWords` params of `StopWordsRemover` are now set, matching Spark Classic behavior.
### How was this patch tested?
added tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49809 from zhengruifeng/ml_connect_swr.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 4fd750c3f41cd195f35c03af982ff5cb74f2b3e1)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../apache/spark/ml/feature/StopWordsRemover.scala | 32 ++++++++++------------
.../org/apache/spark/ml/util/ConnectHelper.scala | 10 ++++++-
python/pyspark/ml/feature.py | 30 ++++++++++----------
.../ml/tests/connect/test_parity_feature.py | 4 +--
python/pyspark/ml/tests/test_feature.py | 24 +++++++++++++---
.../org/apache/spark/sql/connect/ml/MLUtils.scala | 4 ++-
6 files changed, 62 insertions(+), 42 deletions(-)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 99a20f3aa52c..cb9d8b32f006 100755
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import java.util.Locale
import org.apache.spark.annotation.Since
-import org.apache.spark.internal.{LogKeys, MDC}
+import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols,
HasOutputCol, HasOutputCols}
@@ -122,21 +122,6 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0")
override val uid: String
@Since("2.4.0")
def getLocale: String = $(locale)
- /**
- * Returns system default locale, or `Locale.US` if the default locale is
not in available locales
- * in JVM.
- */
- private val getDefaultOrUS: Locale = {
- if (Locale.getAvailableLocales.contains(Locale.getDefault)) {
- Locale.getDefault
- } else {
- logWarning(log"Default locale set was [${MDC(LogKeys.LOCALE,
Locale.getDefault)}]; " +
- log"however, it was not found in available locales in JVM, falling
back to en_US locale. " +
- log"Set param `locale` in order to respect another locale.")
- Locale.US
- }
- }
-
/** Returns the input and output column names corresponding in pair. */
private[feature] def getInOutCols(): (Array[String], Array[String]) = {
if (isSet(inputCol)) {
@@ -147,7 +132,7 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0")
override val uid: String
}
setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
- caseSensitive -> false, locale -> getDefaultOrUS.toString)
+ caseSensitive -> false, locale -> StopWordsRemover.getDefaultOrUS.toString)
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
@@ -218,7 +203,7 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0")
override val uid: String
}
@Since("1.6.0")
-object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] {
+object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] with
Logging {
private[feature]
val supportedLanguages = Set("danish", "dutch", "english", "finnish",
"french", "german",
@@ -241,4 +226,15 @@ object StopWordsRemover extends
DefaultParamsReadable[StopWordsRemover] {
val is =
getClass.getResourceAsStream(s"/org/apache/spark/ml/feature/stopwords/$language.txt")
scala.io.Source.fromInputStream(is)(scala.io.Codec.UTF8).getLines().toArray
}
+
+ private[spark] def getDefaultOrUS: Locale = {
+ if (Locale.getAvailableLocales.contains(Locale.getDefault)) {
+ Locale.getDefault
+ } else {
+ logWarning(log"Default locale set was [${MDC(LogKeys.LOCALE,
Locale.getDefault)}]; " +
+ log"however, it was not found in available locales in JVM, falling
back to en_US locale. " +
+ log"Set param `locale` in order to respect another locale.")
+ Locale.US
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
index d4a0a1301e15..fb2e1a0c0b4e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.util
import org.apache.spark.ml.Model
-import org.apache.spark.ml.feature.{CountVectorizerModel, StringIndexerModel}
+import org.apache.spark.ml.feature._
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.types.StructType
@@ -46,6 +46,14 @@ private[spark] class ConnectHelper(override val uid: String)
extends Model[Conne
new CountVectorizerModel(uid, vocabulary)
}
+ def stopWordsRemoverLoadDefaultStopWords(language: String): Array[String] = {
+ StopWordsRemover.loadDefaultStopWords(language)
+ }
+
+ def stopWordsRemoverGetDefaultOrUS: String = {
+ StopWordsRemover.getDefaultOrUS.toString
+ }
+
override def copy(extra: ParamMap): ConnectHelper = defaultCopy(extra)
override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 6d1ddf5e51c4..3797b0f2b04c 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -5043,7 +5043,6 @@ class StopWordsRemover(
Notes
-----
- null values from input array are preserved unless adding null to
stopWords explicitly.
- - In Spark Connect Mode, the default value of parameter `locale` and
`stopWords` are not set.
Examples
--------
@@ -5142,19 +5141,14 @@ class StopWordsRemover(
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.feature.StopWordsRemover", self.uid
)
- if isinstance(self._java_obj, str):
- # Skip setting the default value of 'locale' and 'stopWords', which
- # needs to invoke a JVM method.
- # So if users don't explicitly set 'locale' and/or 'stopWords',
then the getters fails.
- self._setDefault(
- caseSensitive=False,
- )
+ if is_remote():
+ helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+ locale = helper._call_java("stopWordsRemoverGetDefaultOrUS")
else:
- self._setDefault(
- stopWords=StopWordsRemover.loadDefaultStopWords("english"),
- caseSensitive=False,
- locale=self._java_obj.getLocale(),
- )
+ locale = self._java_obj.getLocale()
+
+ stopWords = StopWordsRemover.loadDefaultStopWords("english")
+ self._setDefault(stopWords=stopWords, caseSensitive=False,
locale=locale)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -5279,8 +5273,14 @@ class StopWordsRemover(
Supported languages: danish, dutch, english, finnish, french, german,
hungarian,
italian, norwegian, portuguese, russian, spanish, swedish, turkish
"""
- stopWordsObj = getattr(_jvm(),
"org.apache.spark.ml.feature.StopWordsRemover")
- return list(stopWordsObj.loadDefaultStopWords(language))
+ if is_remote():
+ helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+ stopWords =
helper._call_java("stopWordsRemoverLoadDefaultStopWords", language)
+ return list(stopWords)
+
+ else:
+ stopWordsObj = getattr(_jvm(),
"org.apache.spark.ml.feature.StopWordsRemover")
+ return list(stopWordsObj.loadDefaultStopWords(language))
class _TargetEncoderParams(
diff --git a/python/pyspark/ml/tests/connect/test_parity_feature.py
b/python/pyspark/ml/tests/connect/test_parity_feature.py
index 323721130e12..dd4580e4da5a 100644
--- a/python/pyspark/ml/tests/connect/test_parity_feature.py
+++ b/python/pyspark/ml/tests/connect/test_parity_feature.py
@@ -22,9 +22,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
class FeatureParityTests(FeatureTestsMixin, ReusedConnectTestCase):
- @unittest.skip("Need to support.")
- def test_stop_words_lengague_selection(self):
- super().test_stop_words_lengague_selection()
+ pass
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/test_feature.py
b/python/pyspark/ml/tests/test_feature.py
index a3e580ec7220..cd2011c5ec87 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -883,8 +883,9 @@ class FeatureTestsMixin:
remover2 = StopWordsRemover.load(d)
self.assertEqual(str(remover), str(remover2))
- def test_stop_words_remover_II(self):
- dataset = self.spark.createDataFrame([Row(input=["a", "panda"])])
+ def test_stop_words_remover_with_given_words(self):
+ spark = self.spark
+ dataset = spark.createDataFrame([Row(input=["a", "panda"])])
stopWordRemover = StopWordsRemover(inputCol="input",
outputCol="output")
# Default
self.assertEqual(stopWordRemover.getInputCol(), "input")
@@ -905,15 +906,30 @@ class FeatureTestsMixin:
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, [])
- def test_stop_words_language_selection(self):
+ def test_stop_words_remover_with_turkish(self):
+ spark = self.spark
+ dataset = spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
stopWordRemover = StopWordsRemover(inputCol="input",
outputCol="output")
stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
- dataset = self.spark.createDataFrame([Row(input=["acaba", "ama",
"biri"])])
stopWordRemover.setStopWords(stopwords)
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, [])
+ def test_stop_words_remover_default(self):
+ stopWordRemover = StopWordsRemover(inputCol="input",
outputCol="output")
+
+ # check the default value of local
+ locale = stopWordRemover.getLocale()
+ self.assertIsInstance(locale, str)
+ self.assertTrue(len(locale) > 0)
+
+ # check the default value of stop words
+ stopwords = stopWordRemover.getStopWords()
+ self.assertIsInstance(stopwords, list)
+ self.assertTrue(len(stopwords) > 0)
+ self.assertTrue(all(isinstance(word, str) for word in stopwords))
+
def test_binarizer(self):
b0 = Binarizer()
self.assertListEqual(
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 2cbae0fa7d9d..b613b2202137 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -659,7 +659,9 @@ private[ml] object MLUtils {
"handleOverwrite",
"stringIndexerModelFromLabels",
"stringIndexerModelFromLabelsArray",
- "countVectorizerModelFromVocabulary")))
+ "countVectorizerModelFromVocabulary",
+ "stopWordsRemoverLoadDefaultStopWords",
+ "stopWordsRemoverGetDefaultOrUS")))
private def validate(obj: Any, method: String): Unit = {
assert(obj != null)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]