This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 208f523b2cf [SPARK-40393][PS][TESTS] Refactor expanding and rolling test for function with input 208f523b2cf is described below commit 208f523b2cfeef0390604a8439a255b776765ae0 Author: Yikun Jiang <yikunk...@gmail.com> AuthorDate: Tue Sep 13 14:39:06 2022 +0900 [SPARK-40393][PS][TESTS] Refactor expanding and rolling test for function with input ### What changes were proposed in this pull request? Refactor expanding and rolling test for function with input ### Why are the changes needed? Refactor expanding and rolling test for function with input: ```python # Before self._test_groupby_rolling_func("count") # After # str can be accepted self._test_groupby_rolling_func("count") # Can also accept lambda to support more func style self._test_groupby_expanding_func( lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower") ) ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - CI passed - cherry-pick: https://github.com/apache/spark/commit/e22ee1c80279c5d23125c44327cd1dd58f5a592a and test manually. Closes #37835 from Yikun/SPARK-40327. 
Authored-by: Yikun Jiang <yikunk...@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/pandas/tests/test_expanding.py | 100 +++++++++++++------------- python/pyspark/pandas/tests/test_rolling.py | 94 +++++++++++++----------- python/pyspark/testing/pandasutils.py | 6 ++ 3 files changed, 110 insertions(+), 90 deletions(-) diff --git a/python/pyspark/pandas/tests/test_expanding.py b/python/pyspark/pandas/tests/test_expanding.py index 9ea8e08bb01..aeb0e9f297b 100644 --- a/python/pyspark/pandas/tests/test_expanding.py +++ b/python/pyspark/pandas/tests/test_expanding.py @@ -26,37 +26,37 @@ from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils class ExpandingTest(PandasOnSparkTestCase, TestUtils): - def _test_expanding_func(self, f): + def _test_expanding_func(self, ps_func, pd_func=None): + if not pd_func: + pd_func = ps_func + if isinstance(pd_func, str): + pd_func = self.convert_str_to_lambda(pd_func) + if isinstance(ps_func, str): + ps_func = self.convert_str_to_lambda(ps_func) pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a") psser = ps.from_pandas(pser) - self.assert_eq( - getattr(psser.expanding(2), f)(), getattr(pser.expanding(2), f)(), almost=True - ) - self.assert_eq( - getattr(psser.expanding(2), f)().sum(), - getattr(pser.expanding(2), f)().sum(), - almost=True, - ) + self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True) + self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True) # Multiindex pser = pd.Series( [1, 2, 3], index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]) ) psser = ps.from_pandas(pser) - self.assert_eq(getattr(psser.expanding(2), f)(), getattr(pser.expanding(2), f)()) + self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2))) pdf = pd.DataFrame( {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) ) psdf = ps.from_pandas(pdf) - 
self.assert_eq(getattr(psdf.expanding(2), f)(), getattr(pdf.expanding(2), f)()) - self.assert_eq(getattr(psdf.expanding(2), f)().sum(), getattr(pdf.expanding(2), f)().sum()) + self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2))) + self.assert_eq(ps_func(psdf.expanding(2)).sum(), pd_func(pdf.expanding(2)).sum()) # Multiindex column columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) pdf.columns = columns psdf.columns = columns - self.assert_eq(getattr(psdf.expanding(2), f)(), getattr(pdf.expanding(2), f)()) + self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2))) def test_expanding_error(self): with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"): @@ -97,16 +97,22 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils): def test_expanding_kurt(self): self._test_expanding_func("kurt") - def _test_groupby_expanding_func(self, f): + def _test_groupby_expanding_func(self, ps_func, pd_func=None): + if not pd_func: + pd_func = ps_func + if isinstance(pd_func, str): + pd_func = self.convert_str_to_lambda(pd_func) + if isinstance(ps_func, str): + ps_func = self.convert_str_to_lambda(ps_func) pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a") psser = ps.from_pandas(pser) self.assert_eq( - getattr(psser.groupby(psser).expanding(2), f)().sort_index(), - getattr(pser.groupby(pser).expanding(2), f)().sort_index(), + ps_func(psser.groupby(psser).expanding(2)).sort_index(), + pd_func(pser.groupby(pser).expanding(2)).sort_index(), ) self.assert_eq( - getattr(psser.groupby(psser).expanding(2), f)().sum(), - getattr(pser.groupby(pser).expanding(2), f)().sum(), + ps_func(psser.groupby(psser).expanding(2)).sum(), + pd_func(pser.groupby(pser).expanding(2)).sum(), ) # Multiindex @@ -117,8 +123,8 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils): ) psser = ps.from_pandas(pser) self.assert_eq( - getattr(psser.groupby(psser).expanding(2), f)().sort_index(), - getattr(pser.groupby(pser).expanding(2), 
f)().sort_index(), + ps_func(psser.groupby(psser).expanding(2)).sort_index(), + pd_func(pser.groupby(pser).expanding(2)).sort_index(), ) pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}) @@ -127,42 +133,42 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils): # The behavior of GroupBy.expanding is changed from pandas 1.3. if LooseVersion(pd.__version__) >= LooseVersion("1.3"): self.assert_eq( - getattr(psdf.groupby(psdf.a).expanding(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a).expanding(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a).expanding(2)).sort_index(), + pd_func(pdf.groupby(pdf.a).expanding(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a).expanding(2), f)().sum(), - getattr(pdf.groupby(pdf.a).expanding(2), f)().sum(), + ps_func(psdf.groupby(psdf.a).expanding(2)).sum(), + pd_func(pdf.groupby(pdf.a).expanding(2)).sum(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a + 1).expanding(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a + 1).expanding(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a + 1).expanding(2)).sort_index(), + pd_func(pdf.groupby(pdf.a + 1).expanding(2)).sort_index(), ) else: self.assert_eq( - getattr(psdf.groupby(psdf.a).expanding(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a).expanding(2), f)().drop("a", axis=1).sort_index(), + ps_func(psdf.groupby(psdf.a).expanding(2)).sort_index(), + pd_func(pdf.groupby(pdf.a).expanding(2)).drop("a", axis=1).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a).expanding(2), f)().sum(), - getattr(pdf.groupby(pdf.a).expanding(2), f)().sum().drop("a"), + ps_func(psdf.groupby(psdf.a).expanding(2)).sum(), + pd_func(pdf.groupby(pdf.a).expanding(2)).sum().drop("a"), ) self.assert_eq( - getattr(psdf.groupby(psdf.a + 1).expanding(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a + 1).expanding(2), f)().drop("a", axis=1).sort_index(), + ps_func(psdf.groupby(psdf.a + 1).expanding(2)).sort_index(), + pd_func(pdf.groupby(pdf.a + 
1).expanding(2)).drop("a", axis=1).sort_index(), ) self.assert_eq( - getattr(psdf.b.groupby(psdf.a).expanding(2), f)().sort_index(), - getattr(pdf.b.groupby(pdf.a).expanding(2), f)().sort_index(), + ps_func(psdf.b.groupby(psdf.a).expanding(2)).sort_index(), + pd_func(pdf.b.groupby(pdf.a).expanding(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a)["b"].expanding(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a)["b"].expanding(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a)["b"].expanding(2)).sort_index(), + pd_func(pdf.groupby(pdf.a)["b"].expanding(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a)[["b"]].expanding(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a)[["b"]].expanding(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a)[["b"]].expanding(2)).sort_index(), + pd_func(pdf.groupby(pdf.a)[["b"]].expanding(2)).sort_index(), ) # Multiindex column @@ -173,25 +179,23 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils): # The behavior of GroupBy.expanding is changed from pandas 1.3. 
if LooseVersion(pd.__version__) >= LooseVersion("1.3"): self.assert_eq( - getattr(psdf.groupby(("a", "x")).expanding(2), f)().sort_index(), - getattr(pdf.groupby(("a", "x")).expanding(2), f)().sort_index(), + ps_func(psdf.groupby(("a", "x")).expanding(2)).sort_index(), + pd_func(pdf.groupby(("a", "x")).expanding(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(), - getattr(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(), + ps_func(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2)).sort_index(), + pd_func(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2)).sort_index(), ) else: self.assert_eq( - getattr(psdf.groupby(("a", "x")).expanding(2), f)().sort_index(), - getattr(pdf.groupby(("a", "x")).expanding(2), f)() - .drop(("a", "x"), axis=1) - .sort_index(), + ps_func(psdf.groupby(("a", "x")).expanding(2)).sort_index(), + pd_func(pdf.groupby(("a", "x")).expanding(2)).drop(("a", "x"), axis=1).sort_index(), ) self.assert_eq( - getattr(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(), - getattr(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)() + ps_func(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2)).sort_index(), + pd_func(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2)) .drop([("a", "x"), ("a", "y")], axis=1) .sort_index(), ) diff --git a/python/pyspark/pandas/tests/test_rolling.py b/python/pyspark/pandas/tests/test_rolling.py index bf793765655..3f92eba79ce 100644 --- a/python/pyspark/pandas/tests/test_rolling.py +++ b/python/pyspark/pandas/tests/test_rolling.py @@ -36,11 +36,17 @@ class RollingTest(PandasOnSparkTestCase, TestUtils): ): Rolling(1, 2) - def _test_rolling_func(self, f): + def _test_rolling_func(self, ps_func, pd_func=None): + if not pd_func: + pd_func = ps_func + if isinstance(pd_func, str): + pd_func = self.convert_str_to_lambda(pd_func) + if isinstance(ps_func, str): + ps_func = self.convert_str_to_lambda(ps_func) pser = 
pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a") psser = ps.from_pandas(pser) - self.assert_eq(getattr(psser.rolling(2), f)(), getattr(pser.rolling(2), f)()) - self.assert_eq(getattr(psser.rolling(2), f)().sum(), getattr(pser.rolling(2), f)().sum()) + self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2))) + self.assert_eq(ps_func(psser.rolling(2)).sum(), pd_func(pser.rolling(2)).sum()) # Multiindex pser = pd.Series( @@ -49,20 +55,20 @@ class RollingTest(PandasOnSparkTestCase, TestUtils): name="a", ) psser = ps.from_pandas(pser) - self.assert_eq(getattr(psser.rolling(2), f)(), getattr(pser.rolling(2), f)()) + self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2))) pdf = pd.DataFrame( {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) ) psdf = ps.from_pandas(pdf) - self.assert_eq(getattr(psdf.rolling(2), f)(), getattr(pdf.rolling(2), f)()) - self.assert_eq(getattr(psdf.rolling(2), f)().sum(), getattr(pdf.rolling(2), f)().sum()) + self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2))) + self.assert_eq(ps_func(psdf.rolling(2)).sum(), pd_func(pdf.rolling(2)).sum()) # Multiindex column columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) pdf.columns = columns psdf.columns = columns - self.assert_eq(getattr(psdf.rolling(2), f)(), getattr(pdf.rolling(2), f)()) + self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2))) def test_rolling_min(self): self._test_rolling_func("min") @@ -91,16 +97,22 @@ class RollingTest(PandasOnSparkTestCase, TestUtils): def test_rolling_kurt(self): self._test_rolling_func("kurt") - def _test_groupby_rolling_func(self, f): + def _test_groupby_rolling_func(self, ps_func, pd_func=None): + if not pd_func: + pd_func = ps_func + if isinstance(pd_func, str): + pd_func = self.convert_str_to_lambda(pd_func) + if isinstance(ps_func, str): + ps_func = self.convert_str_to_lambda(ps_func) pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a") psser = 
ps.from_pandas(pser) self.assert_eq( - getattr(psser.groupby(psser).rolling(2), f)().sort_index(), - getattr(pser.groupby(pser).rolling(2), f)().sort_index(), + ps_func(psser.groupby(psser).rolling(2)).sort_index(), + pd_func(pser.groupby(pser).rolling(2)).sort_index(), ) self.assert_eq( - getattr(psser.groupby(psser).rolling(2), f)().sum(), - getattr(pser.groupby(pser).rolling(2), f)().sum(), + ps_func(psser.groupby(psser).rolling(2)).sum(), + pd_func(pser.groupby(pser).rolling(2)).sum(), ) # Multiindex @@ -111,8 +123,8 @@ class RollingTest(PandasOnSparkTestCase, TestUtils): ) psser = ps.from_pandas(pser) self.assert_eq( - getattr(psser.groupby(psser).rolling(2), f)().sort_index(), - getattr(pser.groupby(pser).rolling(2), f)().sort_index(), + ps_func(psser.groupby(psser).rolling(2)).sort_index(), + pd_func(pser.groupby(pser).rolling(2)).sort_index(), ) pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}) @@ -121,42 +133,42 @@ class RollingTest(PandasOnSparkTestCase, TestUtils): # The behavior of GroupBy.rolling is changed from pandas 1.3. 
if LooseVersion(pd.__version__) >= LooseVersion("1.3"): self.assert_eq( - getattr(psdf.groupby(psdf.a).rolling(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a).rolling(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a).rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a).rolling(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a).rolling(2), f)().sum(), - getattr(pdf.groupby(pdf.a).rolling(2), f)().sum(), + ps_func(psdf.groupby(psdf.a).rolling(2)).sum(), + pd_func(pdf.groupby(pdf.a).rolling(2)).sum(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a + 1).rolling(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a + 1).rolling(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a + 1).rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a + 1).rolling(2)).sort_index(), ) else: self.assert_eq( - getattr(psdf.groupby(psdf.a).rolling(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a).rolling(2), f)().drop("a", axis=1).sort_index(), + ps_func(psdf.groupby(psdf.a).rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a).rolling(2)).drop("a", axis=1).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a).rolling(2), f)().sum(), - getattr(pdf.groupby(pdf.a).rolling(2), f)().sum().drop("a"), + ps_func(psdf.groupby(psdf.a).rolling(2)).sum(), + pd_func(pdf.groupby(pdf.a).rolling(2)).sum().drop("a"), ) self.assert_eq( - getattr(psdf.groupby(psdf.a + 1).rolling(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a + 1).rolling(2), f)().drop("a", axis=1).sort_index(), + ps_func(psdf.groupby(psdf.a + 1).rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a + 1).rolling(2)).drop("a", axis=1).sort_index(), ) self.assert_eq( - getattr(psdf.b.groupby(psdf.a).rolling(2), f)().sort_index(), - getattr(pdf.b.groupby(pdf.a).rolling(2), f)().sort_index(), + ps_func(psdf.b.groupby(psdf.a).rolling(2)).sort_index(), + pd_func(pdf.b.groupby(pdf.a).rolling(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a)["b"].rolling(2), f)().sort_index(), - 
getattr(pdf.groupby(pdf.a)["b"].rolling(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a)["b"].rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a)["b"].rolling(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a)[["b"]].rolling(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a)[["b"]].rolling(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a)[["b"]].rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a)[["b"]].rolling(2)).sort_index(), ) # Multiindex column @@ -167,25 +179,23 @@ class RollingTest(PandasOnSparkTestCase, TestUtils): # The behavior of GroupBy.rolling is changed from pandas 1.3. if LooseVersion(pd.__version__) >= LooseVersion("1.3"): self.assert_eq( - getattr(psdf.groupby(("a", "x")).rolling(2), f)().sort_index(), - getattr(pdf.groupby(("a", "x")).rolling(2), f)().sort_index(), + ps_func(psdf.groupby(("a", "x")).rolling(2)).sort_index(), + pd_func(pdf.groupby(("a", "x")).rolling(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(), - getattr(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(), + ps_func(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(), + pd_func(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(), ) else: self.assert_eq( - getattr(psdf.groupby(("a", "x")).rolling(2), f)().sort_index(), - getattr(pdf.groupby(("a", "x")).rolling(2), f)() - .drop(("a", "x"), axis=1) - .sort_index(), + ps_func(psdf.groupby(("a", "x")).rolling(2)).sort_index(), + pd_func(pdf.groupby(("a", "x")).rolling(2)).drop(("a", "x"), axis=1).sort_index(), ) self.assert_eq( - getattr(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(), - getattr(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)() + ps_func(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(), + pd_func(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2)) .drop([("a", "x"), ("a", "y")], axis=1) .sort_index(), ) diff --git 
a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py index baa43e5b9d5..ad2f74e8af4 100644 --- a/python/pyspark/testing/pandasutils.py +++ b/python/pyspark/testing/pandasutils.py @@ -65,6 +65,12 @@ class PandasOnSparkTestCase(ReusedSQLTestCase): super(PandasOnSparkTestCase, cls).setUpClass() cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True) + def convert_str_to_lambda(self, func): + """ + This function converts `func` str to lambda call + """ + return lambda x: getattr(x, func)() + def assertPandasEqual(self, left, right, check_exact=True): if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): try: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org