This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 20df062d85e8 [SPARK-46327][PS][CONNECT][TESTS] Reorganize `SeriesStringTests`
20df062d85e8 is described below
commit 20df062d85e80422a55afae80ddbf2060f26516c
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Dec 8 21:08:50 2023 +0900
[SPARK-46327][PS][CONNECT][TESTS] Reorganize `SeriesStringTests`
### What changes were proposed in this pull request?
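Split `SeriesStringTests` into `pyspark.pandas.tests.series.test_string_ops_basic` and `pyspark.pandas.tests.series.test_string_ops_adv`, together with the matching Spark Connect parity tests under `pyspark.pandas.tests.connect.series`.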
### Why are the changes needed?
Test code cleanup.
### Does this PR introduce _any_ user-facing change?
No, test-only.
### How was this patch tested?
Existing CI.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44257 from zhengruifeng/ps_test_ser_str.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
dev/sparktestsupport/modules.py | 6 +-
.../test_parity_string_ops_adv.py} | 8 +-
.../test_parity_string_ops_basic.py} | 8 +-
.../test_string_ops_adv.py} | 125 +-------------
.../pandas/tests/series/test_string_ops_basic.py | 184 +++++++++++++++++++++
5 files changed, 199 insertions(+), 132 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 834b3bd235aa..e67cfce0f5c0 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -752,7 +752,8 @@ pyspark_pandas = Module(
"pyspark.pandas.tests.test_scalars",
"pyspark.pandas.tests.test_series_conversion",
"pyspark.pandas.tests.test_series_datetime",
- "pyspark.pandas.tests.test_series_string",
+ "pyspark.pandas.tests.series.test_string_ops_adv",
+ "pyspark.pandas.tests.series.test_string_ops_basic",
"pyspark.pandas.tests.test_spark_functions",
"pyspark.pandas.tests.test_sql",
"pyspark.pandas.tests.test_typedef",
@@ -1005,7 +1006,8 @@ pyspark_pandas_connect_part0 = Module(
"pyspark.pandas.tests.connect.test_parity_scalars",
"pyspark.pandas.tests.connect.test_parity_series_conversion",
"pyspark.pandas.tests.connect.test_parity_series_datetime",
- "pyspark.pandas.tests.connect.test_parity_series_string",
+ "pyspark.pandas.tests.connect.series.test_parity_string_ops_adv",
+ "pyspark.pandas.tests.connect.series.test_parity_string_ops_basic",
"pyspark.pandas.tests.connect.test_parity_spark_functions",
"pyspark.pandas.tests.connect.test_parity_sql",
"pyspark.pandas.tests.connect.test_parity_typedef",
diff --git a/python/pyspark/pandas/tests/connect/test_parity_series_string.py
b/python/pyspark/pandas/tests/connect/series/test_parity_string_ops_adv.py
similarity index 80%
copy from python/pyspark/pandas/tests/connect/test_parity_series_string.py
copy to python/pyspark/pandas/tests/connect/series/test_parity_string_ops_adv.py
index d7c0335f15c7..1213ae073cf5 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_series_string.py
+++ b/python/pyspark/pandas/tests/connect/series/test_parity_string_ops_adv.py
@@ -16,19 +16,19 @@
#
import unittest
-from pyspark.pandas.tests.test_series_string import SeriesStringTestsMixin
+from pyspark.pandas.tests.series.test_string_ops_adv import
SeriesStringOpsAdvMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class SeriesStringParityTests(
- SeriesStringTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase
+class SeriesParityStringOpsAdvTests(
+ SeriesStringOpsAdvMixin, PandasOnSparkTestUtils, ReusedConnectTestCase
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_series_string import * #
noqa: F401
+ from pyspark.pandas.tests.connect.series.test_parity_string_ops_adv import
* # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_series_string.py
b/python/pyspark/pandas/tests/connect/series/test_parity_string_ops_basic.py
similarity index 81%
rename from python/pyspark/pandas/tests/connect/test_parity_series_string.py
rename to
python/pyspark/pandas/tests/connect/series/test_parity_string_ops_basic.py
index d7c0335f15c7..58f10fa505b0 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_series_string.py
+++ b/python/pyspark/pandas/tests/connect/series/test_parity_string_ops_basic.py
@@ -16,19 +16,19 @@
#
import unittest
-from pyspark.pandas.tests.test_series_string import SeriesStringTestsMixin
+from pyspark.pandas.tests.series.test_string_ops_basic import
SeriesStringOpsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class SeriesStringParityTests(
- SeriesStringTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase
+class SeriesStringOpsParityTests(
+ SeriesStringOpsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_series_string import * #
noqa: F401
+ from pyspark.pandas.tests.connect.series.test_parity_string_ops_basic
import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/test_series_string.py
b/python/pyspark/pandas/tests/series/test_string_ops_adv.py
similarity index 67%
rename from python/pyspark/pandas/tests/test_series_string.py
rename to python/pyspark/pandas/tests/series/test_string_ops_adv.py
index b8d35764f1bc..d1954cdf6dad 100644
--- a/python/pyspark/pandas/tests/test_series_string.py
+++ b/python/pyspark/pandas/tests/series/test_string_ops_adv.py
@@ -18,14 +18,13 @@
import pandas as pd
import numpy as np
import re
-import unittest
from pyspark import pandas as ps
from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.sqlutils import SQLTestUtils
-class SeriesStringTestsMixin:
+class SeriesStringOpsAdvMixin:
@property
def pser(self):
return pd.Series(
@@ -49,124 +48,6 @@ class SeriesStringTestsMixin:
def check_func_on_series(self, func, pser, almost=False):
self.assert_eq(func(ps.from_pandas(pser)), func(pser), almost=almost)
- def test_string_add_str_num(self):
- pdf = pd.DataFrame(dict(col1=["a"], col2=[1]))
- psdf = ps.from_pandas(pdf)
- with self.assertRaises(TypeError):
- psdf["col1"] + psdf["col2"]
-
- def test_string_add_assign(self):
- pdf = pd.DataFrame(dict(col1=["a", "b", "c"], col2=["1", "2", "3"]))
- psdf = ps.from_pandas(pdf)
- psdf["col1"] += psdf["col2"]
- pdf["col1"] += pdf["col2"]
- self.assert_eq(psdf["col1"], pdf["col1"])
-
- def test_string_add_str_str(self):
- pdf = pd.DataFrame(dict(col1=["a", "b", "c"], col2=["1", "2", "3"]))
- psdf = ps.from_pandas(pdf)
-
- # TODO: Fix the Series names
- self.assert_eq(psdf["col1"] + psdf["col2"], pdf["col1"] + pdf["col2"])
- self.assert_eq(psdf["col2"] + psdf["col1"], pdf["col2"] + pdf["col1"])
-
- def test_string_add_str_lit(self):
- pdf = pd.DataFrame(dict(col1=["a", "b", "c"]))
- psdf = ps.from_pandas(pdf)
- self.assert_eq(psdf["col1"] + "_lit", pdf["col1"] + "_lit")
- self.assert_eq("_lit" + psdf["col1"], "_lit" + pdf["col1"])
-
- def test_string_capitalize(self):
- self.check_func(lambda x: x.str.capitalize())
-
- def test_string_title(self):
- self.check_func(lambda x: x.str.title())
-
- def test_string_lower(self):
- self.check_func(lambda x: x.str.lower())
-
- def test_string_upper(self):
- self.check_func(lambda x: x.str.upper())
-
- def test_string_swapcase(self):
- self.check_func(lambda x: x.str.swapcase())
-
- def test_string_startswith(self):
- pattern = "car"
- self.check_func(lambda x: x.str.startswith(pattern))
- self.check_func(lambda x: x.str.startswith(pattern, na=False))
-
- def test_string_endswith(self):
- pattern = "s"
- self.check_func(lambda x: x.str.endswith(pattern))
- self.check_func(lambda x: x.str.endswith(pattern, na=False))
-
- def test_string_strip(self):
- self.check_func(lambda x: x.str.strip())
- self.check_func(lambda x: x.str.strip("es\t"))
- self.check_func(lambda x: x.str.strip("1"))
-
- def test_string_lstrip(self):
- self.check_func(lambda x: x.str.lstrip())
- self.check_func(lambda x: x.str.lstrip("\n1le"))
- self.check_func(lambda x: x.str.lstrip("s"))
-
- def test_string_rstrip(self):
- self.check_func(lambda x: x.str.rstrip())
- self.check_func(lambda x: x.str.rstrip("\t ec"))
- self.check_func(lambda x: x.str.rstrip("0"))
-
- def test_string_get(self):
- self.check_func(lambda x: x.str.get(6))
- self.check_func(lambda x: x.str.get(-1))
-
- def test_string_isalnum(self):
- self.check_func(lambda x: x.str.isalnum())
-
- def test_string_isalpha(self):
- self.check_func(lambda x: x.str.isalpha())
-
- def test_string_isdigit(self):
- self.check_func(lambda x: x.str.isdigit())
-
- def test_string_isspace(self):
- self.check_func(lambda x: x.str.isspace())
-
- def test_string_islower(self):
- self.check_func(lambda x: x.str.islower())
-
- def test_string_isupper(self):
- self.check_func(lambda x: x.str.isupper())
-
- def test_string_istitle(self):
- self.check_func(lambda x: x.str.istitle())
-
- def test_string_isnumeric(self):
- self.check_func(lambda x: x.str.isnumeric())
-
- def test_string_isdecimal(self):
- self.check_func(lambda x: x.str.isdecimal())
-
- def test_string_cat(self):
- psser = ps.from_pandas(self.pser)
- with self.assertRaises(NotImplementedError):
- psser.str.cat()
-
- def test_string_center(self):
- self.check_func(lambda x: x.str.center(0))
- self.check_func(lambda x: x.str.center(10))
- self.check_func(lambda x: x.str.center(10, "x"))
-
- def test_string_contains(self):
- self.check_func(lambda x: x.str.contains("le", regex=False))
- self.check_func(lambda x: x.str.contains("White", case=True,
regex=False))
- self.check_func(lambda x: x.str.contains("apples|carrots", regex=True))
- self.check_func(lambda x: x.str.contains("BANANAS",
flags=re.IGNORECASE, na=False))
-
- def test_string_count(self):
- self.check_func(lambda x: x.str.count("wh|Wh"))
- self.check_func(lambda x: x.str.count("WH", flags=re.IGNORECASE))
-
def test_string_decode(self):
psser = ps.from_pandas(self.pser)
with self.assertRaises(NotImplementedError):
@@ -332,13 +213,13 @@ class SeriesStringTestsMixin:
self.check_func(lambda x: x.str.get_dummies())
-class SeriesStringTests(SeriesStringTestsMixin, PandasOnSparkTestCase,
SQLTestUtils):
+class SeriesStringOpsAdvTests(SeriesStringOpsAdvMixin, PandasOnSparkTestCase,
SQLTestUtils):
pass
if __name__ == "__main__":
import unittest
- from pyspark.pandas.tests.test_series_string import * # noqa: F401
+ from pyspark.pandas.tests.series.test_string_ops_adv import * # noqa: F401
try:
import xmlrunner
diff --git a/python/pyspark/pandas/tests/series/test_string_ops_basic.py
b/python/pyspark/pandas/tests/series/test_string_ops_basic.py
new file mode 100644
index 000000000000..bdb3bf74b098
--- /dev/null
+++ b/python/pyspark/pandas/tests/series/test_string_ops_basic.py
@@ -0,0 +1,184 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pandas as pd
+import numpy as np
+import re
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+
+
+class SeriesStringOpsMixin:
+    @property
+    def pser(self):  # fixture: mixed case, digits, "", edge whitespace, None/NaN
+        return pd.Series(
+            [
+                "apples",
+                "Bananas",
+                "carrots",
+                "1",
+                "100",
+                "",
+                "\nleading-whitespace",
+                "trailing-Whitespace \t",
+                None,
+                np.nan,  # not np.NaN: that alias was removed in NumPy 2.0
+            ]
+        )
+
+    def check_func(self, func, almost=False):
+        self.check_func_on_series(func, self.pser, almost=almost)
+
+    def check_func_on_series(self, func, pser, almost=False):
+        self.assert_eq(func(ps.from_pandas(pser)), func(pser), almost=almost)
+
+    def test_string_add_str_num(self):
+        pdf = pd.DataFrame(dict(col1=["a"], col2=[1]))
+        psdf = ps.from_pandas(pdf)
+        with self.assertRaises(TypeError):
+            psdf["col1"] + psdf["col2"]
+
+    def test_string_add_assign(self):
+        pdf = pd.DataFrame(dict(col1=["a", "b", "c"], col2=["1", "2", "3"]))
+        psdf = ps.from_pandas(pdf)
+        psdf["col1"] += psdf["col2"]
+        pdf["col1"] += pdf["col2"]
+        self.assert_eq(psdf["col1"], pdf["col1"])
+
+    def test_string_add_str_str(self):
+        pdf = pd.DataFrame(dict(col1=["a", "b", "c"], col2=["1", "2", "3"]))
+        psdf = ps.from_pandas(pdf)
+
+        # TODO: Fix the Series names
+        self.assert_eq(psdf["col1"] + psdf["col2"], pdf["col1"] + pdf["col2"])
+        self.assert_eq(psdf["col2"] + psdf["col1"], pdf["col2"] + pdf["col1"])
+
+    def test_string_add_str_lit(self):
+        pdf = pd.DataFrame(dict(col1=["a", "b", "c"]))
+        psdf = ps.from_pandas(pdf)
+        self.assert_eq(psdf["col1"] + "_lit", pdf["col1"] + "_lit")
+        self.assert_eq("_lit" + psdf["col1"], "_lit" + pdf["col1"])
+
+    def test_string_capitalize(self):
+        self.check_func(lambda x: x.str.capitalize())
+
+    def test_string_title(self):
+        self.check_func(lambda x: x.str.title())
+
+    def test_string_lower(self):
+        self.check_func(lambda x: x.str.lower())
+
+    def test_string_upper(self):
+        self.check_func(lambda x: x.str.upper())
+
+    def test_string_swapcase(self):
+        self.check_func(lambda x: x.str.swapcase())
+
+    def test_string_startswith(self):
+        pattern = "car"
+        self.check_func(lambda x: x.str.startswith(pattern))
+        self.check_func(lambda x: x.str.startswith(pattern, na=False))
+
+    def test_string_endswith(self):
+        pattern = "s"
+        self.check_func(lambda x: x.str.endswith(pattern))
+        self.check_func(lambda x: x.str.endswith(pattern, na=False))
+
+    def test_string_strip(self):
+        self.check_func(lambda x: x.str.strip())
+        self.check_func(lambda x: x.str.strip("es\t"))
+        self.check_func(lambda x: x.str.strip("1"))
+
+    def test_string_lstrip(self):
+        self.check_func(lambda x: x.str.lstrip())
+        self.check_func(lambda x: x.str.lstrip("\n1le"))
+        self.check_func(lambda x: x.str.lstrip("s"))
+
+    def test_string_rstrip(self):
+        self.check_func(lambda x: x.str.rstrip())
+        self.check_func(lambda x: x.str.rstrip("\t ec"))
+        self.check_func(lambda x: x.str.rstrip("0"))
+
+    def test_string_get(self):
+        self.check_func(lambda x: x.str.get(6))
+        self.check_func(lambda x: x.str.get(-1))
+
+    def test_string_isalnum(self):
+        self.check_func(lambda x: x.str.isalnum())
+
+    def test_string_isalpha(self):
+        self.check_func(lambda x: x.str.isalpha())
+
+    def test_string_isdigit(self):
+        self.check_func(lambda x: x.str.isdigit())
+
+    def test_string_isspace(self):
+        self.check_func(lambda x: x.str.isspace())
+
+    def test_string_islower(self):
+        self.check_func(lambda x: x.str.islower())
+
+    def test_string_isupper(self):
+        self.check_func(lambda x: x.str.isupper())
+
+    def test_string_istitle(self):
+        self.check_func(lambda x: x.str.istitle())
+
+    def test_string_isnumeric(self):
+        self.check_func(lambda x: x.str.isnumeric())
+
+    def test_string_isdecimal(self):
+        self.check_func(lambda x: x.str.isdecimal())
+
+    def test_string_cat(self):
+        psser = ps.from_pandas(self.pser)
+        with self.assertRaises(NotImplementedError):
+            psser.str.cat()
+
+    def test_string_center(self):
+        self.check_func(lambda x: x.str.center(0))
+        self.check_func(lambda x: x.str.center(10))
+        self.check_func(lambda x: x.str.center(10, "x"))
+
+    def test_string_contains(self):
+        self.check_func(lambda x: x.str.contains("le", regex=False))
+        self.check_func(lambda x: x.str.contains("White", case=True, regex=False))
+        self.check_func(lambda x: x.str.contains("apples|carrots", regex=True))
+        self.check_func(lambda x: x.str.contains("BANANAS", flags=re.IGNORECASE, na=False))
+
+    def test_string_count(self):
+        self.check_func(lambda x: x.str.count("wh|Wh"))
+        self.check_func(lambda x: x.str.count("WH", flags=re.IGNORECASE))
+
+
+class SeriesStringOpsTests(SeriesStringOpsMixin, PandasOnSparkTestCase, SQLTestUtils):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.pandas.tests.series.test_string_ops_basic import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]