This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dc186c5e6b6 [SPARK-43773][CONNECT][PYTHON] Implement 'levenshtein(str1, 
str2[, threshold])' functions in python client
dc186c5e6b6 is described below

commit dc186c5e6b6bdb63345081ee9f70b8c102792cdd
Author: panbingkun <pbk1...@gmail.com>
AuthorDate: Sun May 28 08:38:32 2023 +0800

    [SPARK-43773][CONNECT][PYTHON] Implement 'levenshtein(str1, 
str2[, threshold])' functions in python client
    
    ### What changes were proposed in this pull request?
    The pr aims to implement 'levenshtein(str1, str2[, threshold])' functions 
in python client
    
    ### Why are the changes needed?
    After adding a max distance argument to the levenshtein() function, we have 
already implemented it on the Scala side, so we need to align it on `pyspark`.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    - Manual testing
    python/run-tests --testnames 'python.pyspark.sql.tests.test_functions 
FunctionsTests.test_levenshtein_function'
    - Pass GA
    
    Closes #41296 from panbingkun/SPARK-43773.
    
    Lead-authored-by: panbingkun <pbk1...@gmail.com>
    Co-authored-by: panbingkun <84731...@qq.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/functions.py               |  9 +++++++--
 python/pyspark/sql/functions.py                       | 19 +++++++++++++++++--
 .../sql/tests/connect/test_connect_function.py        |  5 +++++
 python/pyspark/sql/tests/test_functions.py            |  7 +++++++
 4 files changed, 36 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/connect/functions.py 
b/python/pyspark/sql/connect/functions.py
index b7d7bc937cf..d3a05d6a1c6 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -1878,8 +1878,13 @@ def substring_index(str: "ColumnOrName", delim: str, 
count: int) -> Column:
 substring_index.__doc__ = pysparkfuncs.substring_index.__doc__
 
 
-def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column:
-    return _invoke_function_over_columns("levenshtein", left, right)
+def levenshtein(
+    left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = 
None
+) -> Column:
+    if threshold is None:
+        return _invoke_function_over_columns("levenshtein", left, right)
+    else:
+        return _invoke_function("levenshtein", _to_col(left), _to_col(right), 
lit(threshold))
 
 
 levenshtein.__doc__ = pysparkfuncs.levenshtein.__doc__
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index e9b71f7d617..fe35f12c402 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -6594,7 +6594,9 @@ def substring_index(str: "ColumnOrName", delim: str, 
count: int) -> Column:
 
 
 @try_remote_functions
-def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column:
+def levenshtein(
+    left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = 
None
+) -> Column:
     """Computes the Levenshtein distance of the two given strings.
 
     .. versionadded:: 1.5.0
@@ -6608,6 +6610,12 @@ def levenshtein(left: "ColumnOrName", right: 
"ColumnOrName") -> Column:
         first column value.
     right : :class:`~pyspark.sql.Column` or str
         second column value.
+    threshold : int, optional
+        if set, returns the levenshtein distance of the two given strings
+        when it is less than or equal to the given threshold; otherwise 
returns -1
+
+        .. versionchanged:: 3.5.0
+            Added ``threshold`` argument.
 
     Returns
     -------
@@ -6619,8 +6627,15 @@ def levenshtein(left: "ColumnOrName", right: 
"ColumnOrName") -> Column:
     >>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
     >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
     [Row(d=3)]
+    >>> df0.select(levenshtein('l', 'r', 2).alias('d')).collect()
+    [Row(d=-1)]
     """
-    return _invoke_function_over_columns("levenshtein", left, right)
+    if threshold is None:
+        return _invoke_function_over_columns("levenshtein", left, right)
+    else:
+        return _invoke_function(
+            "levenshtein", _to_java_column(left), _to_java_column(right), 
threshold
+        )
 
 
 @try_remote_functions
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py 
b/python/pyspark/sql/tests/connect/test_connect_function.py
index e274635d3c6..3e3b4dd5b16 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -1924,6 +1924,11 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, 
PandasOnSparkTestUtils, S
             cdf.select(CF.levenshtein(cdf.b, cdf.c)).toPandas(),
             sdf.select(SF.levenshtein(sdf.b, sdf.c)).toPandas(),
         )
+        self.assert_eq(
+            cdf.select(CF.levenshtein(cdf.b, cdf.c, 1)).toPandas(),
+            sdf.select(SF.levenshtein(sdf.b, sdf.c, 1)).toPandas(),
+        )
+
         self.assert_eq(
             cdf.select(CF.locate("e", cdf.b)).toPandas(),
             sdf.select(SF.locate("e", sdf.b)).toPandas(),
diff --git a/python/pyspark/sql/tests/test_functions.py 
b/python/pyspark/sql/tests/test_functions.py
index 9067de34633..72c6c365b80 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -377,6 +377,13 @@ class FunctionsTestsMixin:
         actual = df.select(F.array_contains(df.data, "1").alias("b")).collect()
         self.assertEqual([Row(b=True), Row(b=False)], actual)
 
+    def test_levenshtein_function(self):
+        df = self.spark.createDataFrame([("kitten", "sitting")], ["l", "r"])
+        actual_without_threshold = df.select(F.levenshtein(df.l, 
df.r).alias("b")).collect()
+        self.assertEqual([Row(b=3)], actual_without_threshold)
+        actual_with_threshold = df.select(F.levenshtein(df.l, df.r, 
2).alias("b")).collect()
+        self.assertEqual([Row(b=-1)], actual_with_threshold)
+
     def test_between_function(self):
         df = self.spark.createDataFrame(
             [Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to