This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 364ff27fc0d7 [SPARK-55867][PS] Fix StringMethods with pandas 3
364ff27fc0d7 is described below
commit 364ff27fc0d7c76ef474d17500070e4bc74360ae
Author: Takuya Ueshin <[email protected]>
AuthorDate: Mon Mar 9 13:02:38 2026 +0900
[SPARK-55867][PS] Fix StringMethods with pandas 3
### What changes were proposed in this pull request?
Fixes `StringMethods` with pandas 3.
### Why are the changes needed?
Several `StringMethods` methods fail when running with pandas 3:
- `findall`
- `match`
- `rsplit`
- `split`
### Does this PR introduce _any_ user-facing change?
Yes, the affected `StringMethods` will now behave more consistently with pandas 3.
### How was this patch tested?
Updated the related tests; the other existing tests should continue to pass.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54664 from ueshin/issues/SPARK-55867/string_methods.
Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/pandas/strings.py | 55 +++++++++++++++++++---
.../pandas/tests/series/test_string_ops_adv.py | 37 +++++++++++----
python/pyspark/pandas/typedef/typehints.py | 17 ++++---
python/pyspark/testing/pandasutils.py | 12 +++++
4 files changed, 99 insertions(+), 22 deletions(-)
diff --git a/python/pyspark/pandas/strings.py b/python/pyspark/pandas/strings.py
index 1fb6a0c505a5..b5b7d25b1204 100644
--- a/python/pyspark/pandas/strings.py
+++ b/python/pyspark/pandas/strings.py
@@ -33,8 +33,12 @@ from typing import (
import numpy as np
import pandas as pd
+from pandas.api.extensions import no_default
+from pyspark._globals import _NoValue, _NoValueType
+from pyspark.loose_version import LooseVersion
from pyspark.pandas.utils import ansi_mode_context, is_ansi_mode_enabled
+from pyspark.pandas.typedef.typehints import is_str_dtype, SeriesType
from pyspark.sql.types import StringType, BinaryType, ArrayType, LongType,
MapType
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf
@@ -1170,13 +1174,18 @@ class StringMethods:
2 [b, b]
dtype: object
"""
+ str_dtype = is_str_dtype(self._data.dtype)
# type hint does not support to specify array type yet.
@pandas_udf( # type: ignore[call-overload]
returnType=ArrayType(StringType(), containsNull=True)
)
def pudf(s: pd.Series) -> pd.Series:
- return s.str.findall(pat, flags)
+ ret = s.str.findall(pat, flags)
+ if str_dtype:
+ # ArrayType does not support NaN, so replace with None
+ ret = ret.replace(np.nan, None)
+ return ret
return self._data._with_new_scol(scol=pudf(self._data.spark.column))
@@ -1266,11 +1275,12 @@ class StringMethods:
1 None
dtype: object
"""
+ ret_type: SeriesType = SeriesType(self._data.dtype, StringType())
- def pandas_join(s) -> ps.Series[str]: # type: ignore[no-untyped-def]
+ def pandas_join(s): # type: ignore[no-untyped-def]
return s.str.join(sep)
- return self._data.pandas_on_spark.transform_batch(pandas_join)
+ return self._data.pandas_on_spark._transform_batch(pandas_join,
ret_type)
def len(self) -> "ps.Series":
"""
@@ -1342,7 +1352,13 @@ class StringMethods:
return self._data.pandas_on_spark.transform_batch(pandas_ljust)
- def match(self, pat: str, case: bool = True, flags: int = 0, na: Any =
np.nan) -> "ps.Series":
+ def match(
+ self,
+ pat: str,
+ case: Union[bool, _NoValueType] = _NoValue,
+ flags: Union[int, _NoValueType] = _NoValue,
+ na: Any = _NoValue,
+ ) -> "ps.Series":
"""
Determine if each string matches a regular expression.
@@ -1403,6 +1419,21 @@ class StringMethods:
dtype: object
"""
+ if LooseVersion(pd.__version__) < "3.0.0":
+ if case is _NoValue:
+ case = True
+ if flags is _NoValue:
+ flags = 0
+ if na is _NoValue:
+ na = np.nan
+ else:
+ if case is _NoValue:
+ case = no_default # type: ignore[assignment]
+ if flags is _NoValue:
+ flags = no_default # type: ignore[assignment]
+ if na is _NoValue:
+ na = no_default
+
def pandas_match(s) -> ps.Series[bool]: # type: ignore[no-untyped-def]
return s.str.match(pat, case, flags, na)
@@ -2035,9 +2066,15 @@ class StringMethods:
# type hint does not support to specify array type yet.
return_type = ArrayType(StringType(), containsNull=True)
+ str_dtype = is_str_dtype(self._data.dtype)
+
@pandas_udf(returnType=return_type) # type: ignore[call-overload]
def pudf(s: pd.Series) -> pd.Series:
- return s.str.split(pat, n=n)
+ ret = s.str.split(pat, n=n)
+ if str_dtype:
+ # ArrayType does not support NaN, so replace with None
+ ret = ret.replace(np.nan, None)
+ return ret
psser = self._data._with_new_scol(
pudf(self._data.spark.column).alias(self._data._internal.data_spark_column_names[0]),
@@ -2189,9 +2226,15 @@ class StringMethods:
# type hint does not support to specify array type yet.
return_type = ArrayType(StringType(), containsNull=True)
+ str_dtype = is_str_dtype(self._data.dtype)
+
@pandas_udf(returnType=return_type) # type: ignore[call-overload]
def pudf(s: pd.Series) -> pd.Series:
- return s.str.rsplit(pat, n=n)
+ ret = s.str.rsplit(pat, n=n)
+ if str_dtype:
+ # ArrayType does not support NaN, so replace with None
+ ret = ret.replace(np.nan, None)
+ return ret
psser = self._data._with_new_scol(
pudf(self._data.spark.column).alias(self._data._internal.data_spark_column_names[0]),
diff --git a/python/pyspark/pandas/tests/series/test_string_ops_adv.py
b/python/pyspark/pandas/tests/series/test_string_ops_adv.py
index ad4775d02940..922914ee3dfe 100644
--- a/python/pyspark/pandas/tests/series/test_string_ops_adv.py
+++ b/python/pyspark/pandas/tests/series/test_string_ops_adv.py
@@ -19,6 +19,7 @@ import numpy as np
import re
from pyspark import pandas as ps
+from pyspark.loose_version import LooseVersion
from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.sqlutils import SQLTestUtils
@@ -44,8 +45,10 @@ class SeriesStringOpsAdvMixin:
def check_func(self, func, almost=False):
self.check_func_on_series(func, self.pser, almost=almost)
- def check_func_on_series(self, func, pser, almost=False):
- self.assert_eq(func(ps.from_pandas(pser)), func(pser), almost=almost)
+ def check_func_on_series(self, func, pser, almost=False,
ignore_null=False):
+ self.assert_eq(
+ func(ps.from_pandas(pser)), func(pser), almost=almost,
ignore_null=ignore_null
+ )
def test_string_decode(self):
psser = ps.from_pandas(self.pser)
@@ -73,10 +76,16 @@ class SeriesStringOpsAdvMixin:
self.check_func(lambda x: x.str.find("a", start=0, end=1))
def test_string_findall(self):
- self.check_func_on_series(lambda x: x.str.findall("es|as").apply(str),
self.pser[:-1])
- self.check_func_on_series(
- lambda x: x.str.findall("wh.*", flags=re.IGNORECASE).apply(str),
self.pser[:-1]
- )
+ if LooseVersion(pd.__version__) < "3.0.0":
+ self.check_func_on_series(lambda x:
x.str.findall("es|as").apply(str), self.pser[:-1])
+ self.check_func_on_series(
+ lambda x: x.str.findall("wh.*",
flags=re.IGNORECASE).apply(str), self.pser[:-1]
+ )
+ else:
+ self.check_func_on_series(lambda x: x.str.findall("es|as"),
self.pser, ignore_null=True)
+ self.check_func_on_series(
+ lambda x: x.str.findall("wh.*", flags=re.IGNORECASE),
self.pser, ignore_null=True
+ )
def test_string_index(self):
pser = pd.Series(["tea", "eat"])
@@ -173,8 +182,12 @@ class SeriesStringOpsAdvMixin:
self.check_func(lambda x: x.str.slice_replace(start=1, stop=3,
repl="X"))
def test_string_split(self):
- self.check_func_on_series(lambda x: repr(x.str.split()),
self.pser[:-1])
- self.check_func_on_series(lambda x: repr(x.str.split(r"p*")),
self.pser[:-1])
+ if LooseVersion(pd.__version__) < "3.0.0":
+ self.check_func_on_series(lambda x: repr(x.str.split()),
self.pser[:-1])
+ self.check_func_on_series(lambda x: repr(x.str.split(r"p*")),
self.pser[:-1])
+ else:
+ self.check_func_on_series(lambda x: x.str.split(), self.pser,
ignore_null=True)
+ self.check_func_on_series(lambda x: x.str.split(r"p*"), self.pser,
ignore_null=True)
pser = pd.Series(["This is a sentence.", "This-is-a-long-word."])
self.check_func_on_series(lambda x: repr(x.str.split(n=2)), pser)
self.check_func_on_series(lambda x: repr(x.str.split(pat="-", n=2)),
pser)
@@ -185,8 +198,12 @@ class SeriesStringOpsAdvMixin:
self.check_func_on_series(lambda x: repr(x.str.split("-", n=1,
expand=True)), pser)
def test_string_rsplit(self):
- self.check_func_on_series(lambda x: repr(x.str.rsplit()),
self.pser[:-1])
- self.check_func_on_series(lambda x: repr(x.str.rsplit(r"p*")),
self.pser[:-1])
+ if LooseVersion(pd.__version__) < "3.0.0":
+ self.check_func_on_series(lambda x: repr(x.str.rsplit()),
self.pser[:-1])
+ self.check_func_on_series(lambda x: repr(x.str.rsplit(r"p*")),
self.pser[:-1])
+ else:
+ self.check_func_on_series(lambda x: x.str.rsplit(), self.pser,
ignore_null=True)
+ self.check_func_on_series(lambda x: x.str.rsplit(r"p*"),
self.pser, ignore_null=True)
pser = pd.Series(["This is a sentence.", "This-is-a-long-word."])
self.check_func_on_series(lambda x: repr(x.str.rsplit(n=2)), pser)
self.check_func_on_series(lambda x: repr(x.str.rsplit(pat="-", n=2)),
pser)
diff --git a/python/pyspark/pandas/typedef/typehints.py
b/python/pyspark/pandas/typedef/typehints.py
index ce5f0971693f..8af93f9dcc13 100644
--- a/python/pyspark/pandas/typedef/typehints.py
+++ b/python/pyspark/pandas/typedef/typehints.py
@@ -325,14 +325,19 @@ def spark_type_to_pandas_dtype(
)
-def handle_dtype_as_extension_dtype(tpe: Dtype) -> bool:
+def is_str_dtype(tpe: Dtype) -> bool:
if LooseVersion(pd.__version__) < "3.0.0":
- return isinstance(tpe, extension_dtypes)
-
+ return False
if extension_object_dtypes_available:
- if isinstance(tpe, StringDtype):
- return tpe.na_value is pd.NA
- return isinstance(tpe, extension_dtypes)
+ return isinstance(tpe, StringDtype) and tpe.na_value is np.nan
+ return False
+
+
+def handle_dtype_as_extension_dtype(tpe: Dtype) -> bool:
+ if is_str_dtype(tpe):
+ return False
+ else:
+ return isinstance(tpe, extension_dtypes)
def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype,
types.DataType]:
diff --git a/python/pyspark/testing/pandasutils.py
b/python/pyspark/testing/pandasutils.py
index 3c529e524a2d..36af5c7c4d18 100644
--- a/python/pyspark/testing/pandasutils.py
+++ b/python/pyspark/testing/pandasutils.py
@@ -40,6 +40,14 @@ try:
except ImportError:
pass
+try:
+ from pyspark.sql.pandas.utils import require_minimum_numpy_version
+
+ require_minimum_numpy_version()
+ import numpy as np
+except ImportError:
+ pass
+
from pyspark.loose_version import LooseVersion
import pyspark.pandas as ps
from pyspark.pandas.frame import DataFrame
@@ -134,6 +142,10 @@ def _assert_pandas_almost_equal(
"""
def compare_vals_approx(val1, val2):
+ if isinstance(val1, np.ndarray):
+ return compare_vals_approx(list(val1), val2)
+ if isinstance(val2, np.ndarray):
+ return compare_vals_approx(val1, list(val2))
# compare vals for approximate equality
if isinstance(val1, (float, decimal.Decimal)) or isinstance(val2,
(float, decimal.Decimal)):
if abs(float(val1) - float(val2)) > (atol + rtol *
abs(float(val2))):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]