This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 364ff27fc0d7 [SPARK-55867][PS] Fix StringMethods with pandas 3
364ff27fc0d7 is described below

commit 364ff27fc0d7c76ef474d17500070e4bc74360ae
Author: Takuya Ueshin <[email protected]>
AuthorDate: Mon Mar 9 13:02:38 2026 +0900

    [SPARK-55867][PS] Fix StringMethods with pandas 3
    
    ### What changes were proposed in this pull request?
    
    Fixes `StringMethods` with pandas 3.
    
    ### Why are the changes needed?
    
    Some `StringMethods` methods fail when running with pandas 3:
    
    - `findall`
    - `match`
    - `rsplit`
    - `split`
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the behavior of these methods will be more consistent with pandas 3.
    
    ### How was this patch tested?
    
    Updated the related tests; the other existing tests should pass unchanged.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #54664 from ueshin/issues/SPARK-55867/string_methods.
    
    Authored-by: Takuya Ueshin <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/pandas/strings.py                   | 55 +++++++++++++++++++---
 .../pandas/tests/series/test_string_ops_adv.py     | 37 +++++++++++----
 python/pyspark/pandas/typedef/typehints.py         | 17 ++++---
 python/pyspark/testing/pandasutils.py              | 12 +++++
 4 files changed, 99 insertions(+), 22 deletions(-)

diff --git a/python/pyspark/pandas/strings.py b/python/pyspark/pandas/strings.py
index 1fb6a0c505a5..b5b7d25b1204 100644
--- a/python/pyspark/pandas/strings.py
+++ b/python/pyspark/pandas/strings.py
@@ -33,8 +33,12 @@ from typing import (
 
 import numpy as np
 import pandas as pd
+from pandas.api.extensions import no_default
 
+from pyspark._globals import _NoValue, _NoValueType
+from pyspark.loose_version import LooseVersion
 from pyspark.pandas.utils import ansi_mode_context, is_ansi_mode_enabled
+from pyspark.pandas.typedef.typehints import is_str_dtype, SeriesType
 from pyspark.sql.types import StringType, BinaryType, ArrayType, LongType, 
MapType
 from pyspark.sql import functions as F
 from pyspark.sql.functions import pandas_udf
@@ -1170,13 +1174,18 @@ class StringMethods:
         2    [b, b]
         dtype: object
         """
+        str_dtype = is_str_dtype(self._data.dtype)
 
         # type hint does not support to specify array type yet.
         @pandas_udf(  # type: ignore[call-overload]
             returnType=ArrayType(StringType(), containsNull=True)
         )
         def pudf(s: pd.Series) -> pd.Series:
-            return s.str.findall(pat, flags)
+            ret = s.str.findall(pat, flags)
+            if str_dtype:
+                # ArrayType does not support NaN, so replace with None
+                ret = ret.replace(np.nan, None)
+            return ret
 
         return self._data._with_new_scol(scol=pudf(self._data.spark.column))
 
@@ -1266,11 +1275,12 @@ class StringMethods:
         1                   None
         dtype: object
         """
+        ret_type: SeriesType = SeriesType(self._data.dtype, StringType())
 
-        def pandas_join(s) -> ps.Series[str]:  # type: ignore[no-untyped-def]
+        def pandas_join(s):  # type: ignore[no-untyped-def]
             return s.str.join(sep)
 
-        return self._data.pandas_on_spark.transform_batch(pandas_join)
+        return self._data.pandas_on_spark._transform_batch(pandas_join, 
ret_type)
 
     def len(self) -> "ps.Series":
         """
@@ -1342,7 +1352,13 @@ class StringMethods:
 
         return self._data.pandas_on_spark.transform_batch(pandas_ljust)
 
-    def match(self, pat: str, case: bool = True, flags: int = 0, na: Any = 
np.nan) -> "ps.Series":
+    def match(
+        self,
+        pat: str,
+        case: Union[bool, _NoValueType] = _NoValue,
+        flags: Union[int, _NoValueType] = _NoValue,
+        na: Any = _NoValue,
+    ) -> "ps.Series":
         """
         Determine if each string matches a regular expression.
 
@@ -1403,6 +1419,21 @@ class StringMethods:
         dtype: object
         """
 
+        if LooseVersion(pd.__version__) < "3.0.0":
+            if case is _NoValue:
+                case = True
+            if flags is _NoValue:
+                flags = 0
+            if na is _NoValue:
+                na = np.nan
+        else:
+            if case is _NoValue:
+                case = no_default  # type: ignore[assignment]
+            if flags is _NoValue:
+                flags = no_default  # type: ignore[assignment]
+            if na is _NoValue:
+                na = no_default
+
         def pandas_match(s) -> ps.Series[bool]:  # type: ignore[no-untyped-def]
             return s.str.match(pat, case, flags, na)
 
@@ -2035,9 +2066,15 @@ class StringMethods:
         # type hint does not support to specify array type yet.
         return_type = ArrayType(StringType(), containsNull=True)
 
+        str_dtype = is_str_dtype(self._data.dtype)
+
         @pandas_udf(returnType=return_type)  # type: ignore[call-overload]
         def pudf(s: pd.Series) -> pd.Series:
-            return s.str.split(pat, n=n)
+            ret = s.str.split(pat, n=n)
+            if str_dtype:
+                # ArrayType does not support NaN, so replace with None
+                ret = ret.replace(np.nan, None)
+            return ret
 
         psser = self._data._with_new_scol(
             
pudf(self._data.spark.column).alias(self._data._internal.data_spark_column_names[0]),
@@ -2189,9 +2226,15 @@ class StringMethods:
         # type hint does not support to specify array type yet.
         return_type = ArrayType(StringType(), containsNull=True)
 
+        str_dtype = is_str_dtype(self._data.dtype)
+
         @pandas_udf(returnType=return_type)  # type: ignore[call-overload]
         def pudf(s: pd.Series) -> pd.Series:
-            return s.str.rsplit(pat, n=n)
+            ret = s.str.rsplit(pat, n=n)
+            if str_dtype:
+                # ArrayType does not support NaN, so replace with None
+                ret = ret.replace(np.nan, None)
+            return ret
 
         psser = self._data._with_new_scol(
             
pudf(self._data.spark.column).alias(self._data._internal.data_spark_column_names[0]),
diff --git a/python/pyspark/pandas/tests/series/test_string_ops_adv.py 
b/python/pyspark/pandas/tests/series/test_string_ops_adv.py
index ad4775d02940..922914ee3dfe 100644
--- a/python/pyspark/pandas/tests/series/test_string_ops_adv.py
+++ b/python/pyspark/pandas/tests/series/test_string_ops_adv.py
@@ -19,6 +19,7 @@ import numpy as np
 import re
 
 from pyspark import pandas as ps
+from pyspark.loose_version import LooseVersion
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
 
@@ -44,8 +45,10 @@ class SeriesStringOpsAdvMixin:
     def check_func(self, func, almost=False):
         self.check_func_on_series(func, self.pser, almost=almost)
 
-    def check_func_on_series(self, func, pser, almost=False):
-        self.assert_eq(func(ps.from_pandas(pser)), func(pser), almost=almost)
+    def check_func_on_series(self, func, pser, almost=False, 
ignore_null=False):
+        self.assert_eq(
+            func(ps.from_pandas(pser)), func(pser), almost=almost, 
ignore_null=ignore_null
+        )
 
     def test_string_decode(self):
         psser = ps.from_pandas(self.pser)
@@ -73,10 +76,16 @@ class SeriesStringOpsAdvMixin:
         self.check_func(lambda x: x.str.find("a", start=0, end=1))
 
     def test_string_findall(self):
-        self.check_func_on_series(lambda x: x.str.findall("es|as").apply(str), 
self.pser[:-1])
-        self.check_func_on_series(
-            lambda x: x.str.findall("wh.*", flags=re.IGNORECASE).apply(str), 
self.pser[:-1]
-        )
+        if LooseVersion(pd.__version__) < "3.0.0":
+            self.check_func_on_series(lambda x: 
x.str.findall("es|as").apply(str), self.pser[:-1])
+            self.check_func_on_series(
+                lambda x: x.str.findall("wh.*", 
flags=re.IGNORECASE).apply(str), self.pser[:-1]
+            )
+        else:
+            self.check_func_on_series(lambda x: x.str.findall("es|as"), 
self.pser, ignore_null=True)
+            self.check_func_on_series(
+                lambda x: x.str.findall("wh.*", flags=re.IGNORECASE), 
self.pser, ignore_null=True
+            )
 
     def test_string_index(self):
         pser = pd.Series(["tea", "eat"])
@@ -173,8 +182,12 @@ class SeriesStringOpsAdvMixin:
         self.check_func(lambda x: x.str.slice_replace(start=1, stop=3, 
repl="X"))
 
     def test_string_split(self):
-        self.check_func_on_series(lambda x: repr(x.str.split()), 
self.pser[:-1])
-        self.check_func_on_series(lambda x: repr(x.str.split(r"p*")), 
self.pser[:-1])
+        if LooseVersion(pd.__version__) < "3.0.0":
+            self.check_func_on_series(lambda x: repr(x.str.split()), 
self.pser[:-1])
+            self.check_func_on_series(lambda x: repr(x.str.split(r"p*")), 
self.pser[:-1])
+        else:
+            self.check_func_on_series(lambda x: x.str.split(), self.pser, 
ignore_null=True)
+            self.check_func_on_series(lambda x: x.str.split(r"p*"), self.pser, 
ignore_null=True)
         pser = pd.Series(["This is a sentence.", "This-is-a-long-word."])
         self.check_func_on_series(lambda x: repr(x.str.split(n=2)), pser)
         self.check_func_on_series(lambda x: repr(x.str.split(pat="-", n=2)), 
pser)
@@ -185,8 +198,12 @@ class SeriesStringOpsAdvMixin:
         self.check_func_on_series(lambda x: repr(x.str.split("-", n=1, 
expand=True)), pser)
 
     def test_string_rsplit(self):
-        self.check_func_on_series(lambda x: repr(x.str.rsplit()), 
self.pser[:-1])
-        self.check_func_on_series(lambda x: repr(x.str.rsplit(r"p*")), 
self.pser[:-1])
+        if LooseVersion(pd.__version__) < "3.0.0":
+            self.check_func_on_series(lambda x: repr(x.str.rsplit()), 
self.pser[:-1])
+            self.check_func_on_series(lambda x: repr(x.str.rsplit(r"p*")), 
self.pser[:-1])
+        else:
+            self.check_func_on_series(lambda x: x.str.rsplit(), self.pser, 
ignore_null=True)
+            self.check_func_on_series(lambda x: x.str.rsplit(r"p*"), 
self.pser, ignore_null=True)
         pser = pd.Series(["This is a sentence.", "This-is-a-long-word."])
         self.check_func_on_series(lambda x: repr(x.str.rsplit(n=2)), pser)
         self.check_func_on_series(lambda x: repr(x.str.rsplit(pat="-", n=2)), 
pser)
diff --git a/python/pyspark/pandas/typedef/typehints.py 
b/python/pyspark/pandas/typedef/typehints.py
index ce5f0971693f..8af93f9dcc13 100644
--- a/python/pyspark/pandas/typedef/typehints.py
+++ b/python/pyspark/pandas/typedef/typehints.py
@@ -325,14 +325,19 @@ def spark_type_to_pandas_dtype(
         )
 
 
-def handle_dtype_as_extension_dtype(tpe: Dtype) -> bool:
+def is_str_dtype(tpe: Dtype) -> bool:
     if LooseVersion(pd.__version__) < "3.0.0":
-        return isinstance(tpe, extension_dtypes)
-
+        return False
     if extension_object_dtypes_available:
-        if isinstance(tpe, StringDtype):
-            return tpe.na_value is pd.NA
-    return isinstance(tpe, extension_dtypes)
+        return isinstance(tpe, StringDtype) and tpe.na_value is np.nan
+    return False
+
+
+def handle_dtype_as_extension_dtype(tpe: Dtype) -> bool:
+    if is_str_dtype(tpe):
+        return False
+    else:
+        return isinstance(tpe, extension_dtypes)
 
 
 def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype, 
types.DataType]:
diff --git a/python/pyspark/testing/pandasutils.py 
b/python/pyspark/testing/pandasutils.py
index 3c529e524a2d..36af5c7c4d18 100644
--- a/python/pyspark/testing/pandasutils.py
+++ b/python/pyspark/testing/pandasutils.py
@@ -40,6 +40,14 @@ try:
 except ImportError:
     pass
 
+try:
+    from pyspark.sql.pandas.utils import require_minimum_numpy_version
+
+    require_minimum_numpy_version()
+    import numpy as np
+except ImportError:
+    pass
+
 from pyspark.loose_version import LooseVersion
 import pyspark.pandas as ps
 from pyspark.pandas.frame import DataFrame
@@ -134,6 +142,10 @@ def _assert_pandas_almost_equal(
     """
 
     def compare_vals_approx(val1, val2):
+        if isinstance(val1, np.ndarray):
+            return compare_vals_approx(list(val1), val2)
+        if isinstance(val2, np.ndarray):
+            return compare_vals_approx(val1, list(val2))
         # compare vals for approximate equality
         if isinstance(val1, (float, decimal.Decimal)) or isinstance(val2, 
(float, decimal.Decimal)):
             if abs(float(val1) - float(val2)) > (atol + rtol * 
abs(float(val2))):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to