Repository: spark
Updated Branches:
  refs/heads/master 1d7db65e9 -> 5d6a53d98


[SPARK-15064][ML] Locale support in StopWordsRemover

## What changes were proposed in this pull request?

Add a `locale` param to `StopWordsRemover`, so that case-insensitive matching of stop words uses the specified locale instead of the default JVM locale.
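
A minimal usage sketch, mirroring the Scala API added in this patch (the Turkish locale is just an illustration, taken from the new tests):

```scala
import org.apache.spark.ml.feature.StopWordsRemover

// Filter stop words case-insensitively under the Turkish locale, where
// lowercasing of the dotted/dotless 'I' differs from English.
val remover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
  .setStopWords(Array("milk", "cookie"))
  .setCaseSensitive(false)
  .setLocale("tr")
```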

## How was this patch tested?

Scala and Python unit tests.

Author: Lee Dongjin <[email protected]>

Closes #21501 from dongjinleekr/feature/SPARK-15064.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5d6a53d9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5d6a53d9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5d6a53d9

Branch: refs/heads/master
Commit: 5d6a53d9831cc1e2115560db5cebe0eea2565dcd
Parents: 1d7db65
Author: Lee Dongjin <[email protected]>
Authored: Tue Jun 12 08:16:37 2018 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Tue Jun 12 08:16:37 2018 -0700

----------------------------------------------------------------------
 .../spark/ml/feature/StopWordsRemover.scala     | 30 ++++++++++--
 .../ml/feature/StopWordsRemoverSuite.scala      | 51 ++++++++++++++++++++
 python/pyspark/ml/feature.py                    | 30 ++++++++++--
 python/pyspark/ml/tests.py                      |  7 +++
 4 files changed, 109 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5d6a53d9/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 3fcd84c..0f946dd 100755
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.ml.feature
 
+import java.util.Locale
+
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
+import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -84,7 +86,27 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
   @Since("1.5.0")
   def getCaseSensitive: Boolean = $(caseSensitive)
 
-  setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false)
+  /**
+   * Locale of the input for case insensitive matching. Ignored when [[caseSensitive]]
+   * is true.
+   * Default: Locale.getDefault.toString
+   * @group param
+   */
+  @Since("2.4.0")
+  val locale: Param[String] = new Param[String](this, "locale",
+    "Locale of the input for case insensitive matching. Ignored when 
caseSensitive is true.",
+    
ParamValidators.inArray[String](Locale.getAvailableLocales.map(_.toString)))
+
+  /** @group setParam */
+  @Since("2.4.0")
+  def setLocale(value: String): this.type = set(locale, value)
+
+  /** @group getParam */
+  @Since("2.4.0")
+  def getLocale: String = $(locale)
+
+  setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
+    caseSensitive -> false, locale -> Locale.getDefault.toString)
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
@@ -95,8 +117,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
         terms.filter(s => !stopWordsSet.contains(s))
       }
     } else {
-      // TODO: support user locale (SPARK-15064)
-      val toLower = (s: String) => if (s != null) s.toLowerCase else s
+      val lc = new Locale($(locale))
+      val toLower = (s: String) => if (s != null) s.toLowerCase(lc) else s
       val lowerStopWords = $(stopWords).map(toLower(_)).toSet
       udf { terms: Seq[String] =>
         terms.filter(s => !lowerStopWords.contains(toLower(s)))
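
The locale matters here because `String.toLowerCase` is locale-sensitive. A minimal sketch of the Turkish case the tests below exercise (plain JVM behavior, independent of Spark):

```scala
import java.util.Locale

val tr = new Locale("tr")
// In Turkish, the dotted uppercase 'İ' lowercases to plain 'i',
// while the ASCII 'I' lowercases to the dotless 'ı'.
println("İ".toLowerCase(tr))  // prints "i"
println("I".toLowerCase(tr))  // prints "ı", not "i"
```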

http://git-wip-us.apache.org/repos/asf/spark/blob/5d6a53d9/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index 21259a5..20972d1 100755
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -65,6 +65,57 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
     testStopWordsRemover(remover, dataSet)
   }
 
+  test("StopWordsRemover with localed input (case insensitive)") {
+    val stopWords = Array("milk", "cookie")
+    val remover = new StopWordsRemover()
+      .setInputCol("raw")
+      .setOutputCol("filtered")
+      .setStopWords(stopWords)
+      .setCaseSensitive(false)
+      .setLocale("tr")  // Turkish alphabet: has no Q, W, X but has dotted and 
dotless 'I's.
+    val dataSet = Seq(
+      // scalastyle:off
+      (Seq("mİlk", "and", "nuts"), Seq("and", "nuts")),
+      // scalastyle:on
+      (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")),
+      (Seq(null), Seq(null)),
+      (Seq(), Seq())
+    ).toDF("raw", "expected")
+
+    testStopWordsRemover(remover, dataSet)
+  }
+
+  test("StopWordsRemover with localed input (case sensitive)") {
+    val stopWords = Array("milk", "cookie")
+    val remover = new StopWordsRemover()
+      .setInputCol("raw")
+      .setOutputCol("filtered")
+      .setStopWords(stopWords)
+      .setCaseSensitive(true)
+      .setLocale("tr")  // Turkish alphabet: has no Q, W, X but has dotted and 
dotless 'I's.
+    val dataSet = Seq(
+      // scalastyle:off
+      (Seq("mİlk", "and", "nuts"), Seq("mİlk", "and", "nuts")),
+      // scalastyle:on
+      (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")),
+      (Seq(null), Seq(null)),
+      (Seq(), Seq())
+    ).toDF("raw", "expected")
+
+    testStopWordsRemover(remover, dataSet)
+  }
+
+  test("StopWordsRemover with invalid locale") {
+    intercept[IllegalArgumentException] {
+      val stopWords = Array("test", "a", "an", "the")
+      new StopWordsRemover()
+        .setInputCol("raw")
+        .setOutputCol("filtered")
+        .setStopWords(stopWords)
+        .setLocale("rt")  // invalid locale
+    }
+  }
+
   test("StopWordsRemover case sensitive") {
     val remover = new StopWordsRemover()
       .setInputCol("raw")

http://git-wip-us.apache.org/repos/asf/spark/blob/5d6a53d9/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index cdda30c..14800d4 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2582,25 +2582,31 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
                       typeConverter=TypeConverters.toListString)
     caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
                           "comparison over the stop words", typeConverter=TypeConverters.toBoolean)
+    locale = Param(Params._dummy(), "locale", "locale of the input. ignored when case sensitive " +
+                   "is true", typeConverter=TypeConverters.toString)
 
     @keyword_only
-    def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
+    def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
+                 locale=None):
         """
-        __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false)
+        __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
+        locale=None)
         """
         super(StopWordsRemover, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
                                             self.uid)
         self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"),
-                         caseSensitive=False)
+                         caseSensitive=False, locale=self._java_obj.getLocale())
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.6.0")
-    def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
+    def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
+                  locale=None):
         """
-        setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false)
+        setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
+        locale=None)
         Sets params for this StopWordRemover.
         """
         kwargs = self._input_kwargs
@@ -2634,6 +2640,20 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
         """
         return self.getOrDefault(self.caseSensitive)
 
+    @since("2.4.0")
+    def setLocale(self, value):
+        """
+        Sets the value of :py:attr:`locale`.
+        """
+        return self._set(locale=value)
+
+    @since("2.4.0")
+    def getLocale(self):
+        """
+        Gets the value of :py:attr:`locale`.
+        """
+        return self.getOrDefault(self.locale)
+
     @staticmethod
     @since("2.0.0")
     def loadDefaultStopWords(language):

http://git-wip-us.apache.org/repos/asf/spark/blob/5d6a53d9/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 0dde0db..ebd36cb 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -681,6 +681,13 @@ class FeatureTests(SparkSessionTestCase):
         self.assertEqual(stopWordRemover.getStopWords(), stopwords)
         transformedDF = stopWordRemover.transform(dataset)
         self.assertEqual(transformedDF.head().output, [])
+        # with locale
+        stopwords = ["BELKİ"]
+        dataset = self.spark.createDataFrame([Row(input=["belki"])])
+        stopWordRemover.setStopWords(stopwords).setLocale("tr")
+        self.assertEqual(stopWordRemover.getStopWords(), stopwords)
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEqual(transformedDF.head().output, [])
 
     def test_count_vectorizer_with_binary(self):
         dataset = self.spark.createDataFrame([

