This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 208f523b2cf [SPARK-40393][PS][TESTS] Refactor expanding and rolling
test for function with input
208f523b2cf is described below
commit 208f523b2cfeef0390604a8439a255b776765ae0
Author: Yikun Jiang <[email protected]>
AuthorDate: Tue Sep 13 14:39:06 2022 +0900
[SPARK-40393][PS][TESTS] Refactor expanding and rolling test for function
with input
### What changes were proposed in this pull request?
Refactor the expanding and rolling test helpers so they accept either a method name or a function as input.
### Why are the changes needed?
Refactor the expanding and rolling test helpers so they accept either a method name (str) or a function (e.g. a lambda) as input:
```python
# Before
self._test_groupby_rolling_func("count")
# After
# A str is still accepted
self._test_groupby_rolling_func("count")
# A lambda is also accepted, to support more function-call styles
self._test_groupby_expanding_func(
lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower")
)
```
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- CI passed
- cherry-pick:
https://github.com/apache/spark/commit/e22ee1c80279c5d23125c44327cd1dd58f5a592a
and tested manually.
Closes #37835 from Yikun/SPARK-40327.
Authored-by: Yikun Jiang <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/pandas/tests/test_expanding.py | 100 +++++++++++++-------------
python/pyspark/pandas/tests/test_rolling.py | 94 +++++++++++++-----------
python/pyspark/testing/pandasutils.py | 6 ++
3 files changed, 110 insertions(+), 90 deletions(-)
diff --git a/python/pyspark/pandas/tests/test_expanding.py
b/python/pyspark/pandas/tests/test_expanding.py
index 9ea8e08bb01..aeb0e9f297b 100644
--- a/python/pyspark/pandas/tests/test_expanding.py
+++ b/python/pyspark/pandas/tests/test_expanding.py
@@ -26,37 +26,37 @@ from pyspark.testing.pandasutils import
PandasOnSparkTestCase, TestUtils
class ExpandingTest(PandasOnSparkTestCase, TestUtils):
- def _test_expanding_func(self, f):
+ def _test_expanding_func(self, ps_func, pd_func=None):
+ if not pd_func:
+ pd_func = ps_func
+ if isinstance(pd_func, str):
+ pd_func = self.convert_str_to_lambda(pd_func)
+ if isinstance(ps_func, str):
+ ps_func = self.convert_str_to_lambda(ps_func)
pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a")
psser = ps.from_pandas(pser)
- self.assert_eq(
- getattr(psser.expanding(2), f)(), getattr(pser.expanding(2), f)(),
almost=True
- )
- self.assert_eq(
- getattr(psser.expanding(2), f)().sum(),
- getattr(pser.expanding(2), f)().sum(),
- almost=True,
- )
+ self.assert_eq(ps_func(psser.expanding(2)),
pd_func(pser.expanding(2)), almost=True)
+ self.assert_eq(ps_func(psser.expanding(2)),
pd_func(pser.expanding(2)), almost=True)
# Multiindex
pser = pd.Series(
[1, 2, 3], index=pd.MultiIndex.from_tuples([("a", "x"), ("a",
"y"), ("b", "z")])
)
psser = ps.from_pandas(pser)
- self.assert_eq(getattr(psser.expanding(2), f)(),
getattr(pser.expanding(2), f)())
+ self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)))
pdf = pd.DataFrame(
{"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]},
index=np.random.rand(4)
)
psdf = ps.from_pandas(pdf)
- self.assert_eq(getattr(psdf.expanding(2), f)(),
getattr(pdf.expanding(2), f)())
- self.assert_eq(getattr(psdf.expanding(2), f)().sum(),
getattr(pdf.expanding(2), f)().sum())
+ self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2)))
+ self.assert_eq(ps_func(psdf.expanding(2)).sum(),
pd_func(pdf.expanding(2)).sum())
# Multiindex column
columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
pdf.columns = columns
psdf.columns = columns
- self.assert_eq(getattr(psdf.expanding(2), f)(),
getattr(pdf.expanding(2), f)())
+ self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2)))
def test_expanding_error(self):
with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
@@ -97,16 +97,22 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils):
def test_expanding_kurt(self):
self._test_expanding_func("kurt")
- def _test_groupby_expanding_func(self, f):
+ def _test_groupby_expanding_func(self, ps_func, pd_func=None):
+ if not pd_func:
+ pd_func = ps_func
+ if isinstance(pd_func, str):
+ pd_func = self.convert_str_to_lambda(pd_func)
+ if isinstance(ps_func, str):
+ ps_func = self.convert_str_to_lambda(ps_func)
pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a")
psser = ps.from_pandas(pser)
self.assert_eq(
- getattr(psser.groupby(psser).expanding(2), f)().sort_index(),
- getattr(pser.groupby(pser).expanding(2), f)().sort_index(),
+ ps_func(psser.groupby(psser).expanding(2)).sort_index(),
+ pd_func(pser.groupby(pser).expanding(2)).sort_index(),
)
self.assert_eq(
- getattr(psser.groupby(psser).expanding(2), f)().sum(),
- getattr(pser.groupby(pser).expanding(2), f)().sum(),
+ ps_func(psser.groupby(psser).expanding(2)).sum(),
+ pd_func(pser.groupby(pser).expanding(2)).sum(),
)
# Multiindex
@@ -117,8 +123,8 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils):
)
psser = ps.from_pandas(pser)
self.assert_eq(
- getattr(psser.groupby(psser).expanding(2), f)().sort_index(),
- getattr(pser.groupby(pser).expanding(2), f)().sort_index(),
+ ps_func(psser.groupby(psser).expanding(2)).sort_index(),
+ pd_func(pser.groupby(pser).expanding(2)).sort_index(),
)
pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0,
1.0]})
@@ -127,42 +133,42 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils):
# The behavior of GroupBy.expanding is changed from pandas 1.3.
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
self.assert_eq(
- getattr(psdf.groupby(psdf.a).expanding(2), f)().sort_index(),
- getattr(pdf.groupby(pdf.a).expanding(2), f)().sort_index(),
+ ps_func(psdf.groupby(psdf.a).expanding(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a).expanding(2)).sort_index(),
)
self.assert_eq(
- getattr(psdf.groupby(psdf.a).expanding(2), f)().sum(),
- getattr(pdf.groupby(pdf.a).expanding(2), f)().sum(),
+ ps_func(psdf.groupby(psdf.a).expanding(2)).sum(),
+ pd_func(pdf.groupby(pdf.a).expanding(2)).sum(),
)
self.assert_eq(
- getattr(psdf.groupby(psdf.a + 1).expanding(2),
f)().sort_index(),
- getattr(pdf.groupby(pdf.a + 1).expanding(2), f)().sort_index(),
+ ps_func(psdf.groupby(psdf.a + 1).expanding(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a + 1).expanding(2)).sort_index(),
)
else:
self.assert_eq(
- getattr(psdf.groupby(psdf.a).expanding(2), f)().sort_index(),
- getattr(pdf.groupby(pdf.a).expanding(2), f)().drop("a",
axis=1).sort_index(),
+ ps_func(psdf.groupby(psdf.a).expanding(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a).expanding(2)).drop("a",
axis=1).sort_index(),
)
self.assert_eq(
- getattr(psdf.groupby(psdf.a).expanding(2), f)().sum(),
- getattr(pdf.groupby(pdf.a).expanding(2), f)().sum().drop("a"),
+ ps_func(psdf.groupby(psdf.a).expanding(2)).sum(),
+ pd_func(pdf.groupby(pdf.a).expanding(2)).sum().drop("a"),
)
self.assert_eq(
- getattr(psdf.groupby(psdf.a + 1).expanding(2),
f)().sort_index(),
- getattr(pdf.groupby(pdf.a + 1).expanding(2), f)().drop("a",
axis=1).sort_index(),
+ ps_func(psdf.groupby(psdf.a + 1).expanding(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a + 1).expanding(2)).drop("a",
axis=1).sort_index(),
)
self.assert_eq(
- getattr(psdf.b.groupby(psdf.a).expanding(2), f)().sort_index(),
- getattr(pdf.b.groupby(pdf.a).expanding(2), f)().sort_index(),
+ ps_func(psdf.b.groupby(psdf.a).expanding(2)).sort_index(),
+ pd_func(pdf.b.groupby(pdf.a).expanding(2)).sort_index(),
)
self.assert_eq(
- getattr(psdf.groupby(psdf.a)["b"].expanding(2), f)().sort_index(),
- getattr(pdf.groupby(pdf.a)["b"].expanding(2), f)().sort_index(),
+ ps_func(psdf.groupby(psdf.a)["b"].expanding(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a)["b"].expanding(2)).sort_index(),
)
self.assert_eq(
- getattr(psdf.groupby(psdf.a)[["b"]].expanding(2),
f)().sort_index(),
- getattr(pdf.groupby(pdf.a)[["b"]].expanding(2), f)().sort_index(),
+ ps_func(psdf.groupby(psdf.a)[["b"]].expanding(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a)[["b"]].expanding(2)).sort_index(),
)
# Multiindex column
@@ -173,25 +179,23 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils):
# The behavior of GroupBy.expanding is changed from pandas 1.3.
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
self.assert_eq(
- getattr(psdf.groupby(("a", "x")).expanding(2),
f)().sort_index(),
- getattr(pdf.groupby(("a", "x")).expanding(2),
f)().sort_index(),
+ ps_func(psdf.groupby(("a", "x")).expanding(2)).sort_index(),
+ pd_func(pdf.groupby(("a", "x")).expanding(2)).sort_index(),
)
self.assert_eq(
- getattr(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2),
f)().sort_index(),
- getattr(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2),
f)().sort_index(),
+ ps_func(psdf.groupby([("a", "x"), ("a",
"y")]).expanding(2)).sort_index(),
+ pd_func(pdf.groupby([("a", "x"), ("a",
"y")]).expanding(2)).sort_index(),
)
else:
self.assert_eq(
- getattr(psdf.groupby(("a", "x")).expanding(2),
f)().sort_index(),
- getattr(pdf.groupby(("a", "x")).expanding(2), f)()
- .drop(("a", "x"), axis=1)
- .sort_index(),
+ ps_func(psdf.groupby(("a", "x")).expanding(2)).sort_index(),
+ pd_func(pdf.groupby(("a", "x")).expanding(2)).drop(("a", "x"),
axis=1).sort_index(),
)
self.assert_eq(
- getattr(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2),
f)().sort_index(),
- getattr(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2),
f)()
+ ps_func(psdf.groupby([("a", "x"), ("a",
"y")]).expanding(2)).sort_index(),
+ pd_func(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2))
.drop([("a", "x"), ("a", "y")], axis=1)
.sort_index(),
)
diff --git a/python/pyspark/pandas/tests/test_rolling.py
b/python/pyspark/pandas/tests/test_rolling.py
index bf793765655..3f92eba79ce 100644
--- a/python/pyspark/pandas/tests/test_rolling.py
+++ b/python/pyspark/pandas/tests/test_rolling.py
@@ -36,11 +36,17 @@ class RollingTest(PandasOnSparkTestCase, TestUtils):
):
Rolling(1, 2)
- def _test_rolling_func(self, f):
+ def _test_rolling_func(self, ps_func, pd_func=None):
+ if not pd_func:
+ pd_func = ps_func
+ if isinstance(pd_func, str):
+ pd_func = self.convert_str_to_lambda(pd_func)
+ if isinstance(ps_func, str):
+ ps_func = self.convert_str_to_lambda(ps_func)
pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a")
psser = ps.from_pandas(pser)
- self.assert_eq(getattr(psser.rolling(2), f)(),
getattr(pser.rolling(2), f)())
- self.assert_eq(getattr(psser.rolling(2), f)().sum(),
getattr(pser.rolling(2), f)().sum())
+ self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2)))
+ self.assert_eq(ps_func(psser.rolling(2)).sum(),
pd_func(pser.rolling(2)).sum())
# Multiindex
pser = pd.Series(
@@ -49,20 +55,20 @@ class RollingTest(PandasOnSparkTestCase, TestUtils):
name="a",
)
psser = ps.from_pandas(pser)
- self.assert_eq(getattr(psser.rolling(2), f)(),
getattr(pser.rolling(2), f)())
+ self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2)))
pdf = pd.DataFrame(
{"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]},
index=np.random.rand(4)
)
psdf = ps.from_pandas(pdf)
- self.assert_eq(getattr(psdf.rolling(2), f)(), getattr(pdf.rolling(2),
f)())
- self.assert_eq(getattr(psdf.rolling(2), f)().sum(),
getattr(pdf.rolling(2), f)().sum())
+ self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2)))
+ self.assert_eq(ps_func(psdf.rolling(2)).sum(),
pd_func(pdf.rolling(2)).sum())
# Multiindex column
columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
pdf.columns = columns
psdf.columns = columns
- self.assert_eq(getattr(psdf.rolling(2), f)(), getattr(pdf.rolling(2),
f)())
+ self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2)))
def test_rolling_min(self):
self._test_rolling_func("min")
@@ -91,16 +97,22 @@ class RollingTest(PandasOnSparkTestCase, TestUtils):
def test_rolling_kurt(self):
self._test_rolling_func("kurt")
- def _test_groupby_rolling_func(self, f):
+ def _test_groupby_rolling_func(self, ps_func, pd_func=None):
+ if not pd_func:
+ pd_func = ps_func
+ if isinstance(pd_func, str):
+ pd_func = self.convert_str_to_lambda(pd_func)
+ if isinstance(ps_func, str):
+ ps_func = self.convert_str_to_lambda(ps_func)
pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a")
psser = ps.from_pandas(pser)
self.assert_eq(
- getattr(psser.groupby(psser).rolling(2), f)().sort_index(),
- getattr(pser.groupby(pser).rolling(2), f)().sort_index(),
+ ps_func(psser.groupby(psser).rolling(2)).sort_index(),
+ pd_func(pser.groupby(pser).rolling(2)).sort_index(),
)
self.assert_eq(
- getattr(psser.groupby(psser).rolling(2), f)().sum(),
- getattr(pser.groupby(pser).rolling(2), f)().sum(),
+ ps_func(psser.groupby(psser).rolling(2)).sum(),
+ pd_func(pser.groupby(pser).rolling(2)).sum(),
)
# Multiindex
@@ -111,8 +123,8 @@ class RollingTest(PandasOnSparkTestCase, TestUtils):
)
psser = ps.from_pandas(pser)
self.assert_eq(
- getattr(psser.groupby(psser).rolling(2), f)().sort_index(),
- getattr(pser.groupby(pser).rolling(2), f)().sort_index(),
+ ps_func(psser.groupby(psser).rolling(2)).sort_index(),
+ pd_func(pser.groupby(pser).rolling(2)).sort_index(),
)
pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0,
1.0]})
@@ -121,42 +133,42 @@ class RollingTest(PandasOnSparkTestCase, TestUtils):
# The behavior of GroupBy.rolling is changed from pandas 1.3.
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
self.assert_eq(
- getattr(psdf.groupby(psdf.a).rolling(2), f)().sort_index(),
- getattr(pdf.groupby(pdf.a).rolling(2), f)().sort_index(),
+ ps_func(psdf.groupby(psdf.a).rolling(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a).rolling(2)).sort_index(),
)
self.assert_eq(
- getattr(psdf.groupby(psdf.a).rolling(2), f)().sum(),
- getattr(pdf.groupby(pdf.a).rolling(2), f)().sum(),
+ ps_func(psdf.groupby(psdf.a).rolling(2)).sum(),
+ pd_func(pdf.groupby(pdf.a).rolling(2)).sum(),
)
self.assert_eq(
- getattr(psdf.groupby(psdf.a + 1).rolling(2), f)().sort_index(),
- getattr(pdf.groupby(pdf.a + 1).rolling(2), f)().sort_index(),
+ ps_func(psdf.groupby(psdf.a + 1).rolling(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a + 1).rolling(2)).sort_index(),
)
else:
self.assert_eq(
- getattr(psdf.groupby(psdf.a).rolling(2), f)().sort_index(),
- getattr(pdf.groupby(pdf.a).rolling(2), f)().drop("a",
axis=1).sort_index(),
+ ps_func(psdf.groupby(psdf.a).rolling(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a).rolling(2)).drop("a",
axis=1).sort_index(),
)
self.assert_eq(
- getattr(psdf.groupby(psdf.a).rolling(2), f)().sum(),
- getattr(pdf.groupby(pdf.a).rolling(2), f)().sum().drop("a"),
+ ps_func(psdf.groupby(psdf.a).rolling(2)).sum(),
+ pd_func(pdf.groupby(pdf.a).rolling(2)).sum().drop("a"),
)
self.assert_eq(
- getattr(psdf.groupby(psdf.a + 1).rolling(2), f)().sort_index(),
- getattr(pdf.groupby(pdf.a + 1).rolling(2), f)().drop("a",
axis=1).sort_index(),
+ ps_func(psdf.groupby(psdf.a + 1).rolling(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a + 1).rolling(2)).drop("a",
axis=1).sort_index(),
)
self.assert_eq(
- getattr(psdf.b.groupby(psdf.a).rolling(2), f)().sort_index(),
- getattr(pdf.b.groupby(pdf.a).rolling(2), f)().sort_index(),
+ ps_func(psdf.b.groupby(psdf.a).rolling(2)).sort_index(),
+ pd_func(pdf.b.groupby(pdf.a).rolling(2)).sort_index(),
)
self.assert_eq(
- getattr(psdf.groupby(psdf.a)["b"].rolling(2), f)().sort_index(),
- getattr(pdf.groupby(pdf.a)["b"].rolling(2), f)().sort_index(),
+ ps_func(psdf.groupby(psdf.a)["b"].rolling(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a)["b"].rolling(2)).sort_index(),
)
self.assert_eq(
- getattr(psdf.groupby(psdf.a)[["b"]].rolling(2), f)().sort_index(),
- getattr(pdf.groupby(pdf.a)[["b"]].rolling(2), f)().sort_index(),
+ ps_func(psdf.groupby(psdf.a)[["b"]].rolling(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a)[["b"]].rolling(2)).sort_index(),
)
# Multiindex column
@@ -167,25 +179,23 @@ class RollingTest(PandasOnSparkTestCase, TestUtils):
# The behavior of GroupBy.rolling is changed from pandas 1.3.
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
self.assert_eq(
- getattr(psdf.groupby(("a", "x")).rolling(2), f)().sort_index(),
- getattr(pdf.groupby(("a", "x")).rolling(2), f)().sort_index(),
+ ps_func(psdf.groupby(("a", "x")).rolling(2)).sort_index(),
+ pd_func(pdf.groupby(("a", "x")).rolling(2)).sort_index(),
)
self.assert_eq(
- getattr(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2),
f)().sort_index(),
- getattr(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2),
f)().sort_index(),
+ ps_func(psdf.groupby([("a", "x"), ("a",
"y")]).rolling(2)).sort_index(),
+ pd_func(pdf.groupby([("a", "x"), ("a",
"y")]).rolling(2)).sort_index(),
)
else:
self.assert_eq(
- getattr(psdf.groupby(("a", "x")).rolling(2), f)().sort_index(),
- getattr(pdf.groupby(("a", "x")).rolling(2), f)()
- .drop(("a", "x"), axis=1)
- .sort_index(),
+ ps_func(psdf.groupby(("a", "x")).rolling(2)).sort_index(),
+ pd_func(pdf.groupby(("a", "x")).rolling(2)).drop(("a", "x"),
axis=1).sort_index(),
)
self.assert_eq(
- getattr(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2),
f)().sort_index(),
- getattr(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)()
+ ps_func(psdf.groupby([("a", "x"), ("a",
"y")]).rolling(2)).sort_index(),
+ pd_func(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2))
.drop([("a", "x"), ("a", "y")], axis=1)
.sort_index(),
)
diff --git a/python/pyspark/testing/pandasutils.py
b/python/pyspark/testing/pandasutils.py
index baa43e5b9d5..ad2f74e8af4 100644
--- a/python/pyspark/testing/pandasutils.py
+++ b/python/pyspark/testing/pandasutils.py
@@ -65,6 +65,12 @@ class PandasOnSparkTestCase(ReusedSQLTestCase):
super(PandasOnSparkTestCase, cls).setUpClass()
cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True)
+ def convert_str_to_lambda(self, func):
+ """
+ This function coverts `func` str to lambda call
+ """
+ return lambda x: getattr(x, func)()
+
def assertPandasEqual(self, left, right, check_exact=True):
if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
try:
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]