This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new d5241ff2689 [SPARK-46347][PS][TESTS] Reorganize `RollingTests `
d5241ff2689 is described below

commit d5241ff26892fa615b27ae39b0be1b8907f59f29
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Dec 11 10:29:49 2023 -0800

    [SPARK-46347][PS][TESTS] Reorganize `RollingTests `

    ### What changes were proposed in this pull request?
    Reorganize `RollingTests`, break it into multiple small files

    ### Why are the changes needed?
    to be consistent with Pandas's tests

    ### Does this PR introduce _any_ user-facing change?
    no, test only

    ### How was this patch tested?
    ci

    ### Was this patch authored or co-authored using generative AI tooling?
    no

    Closes #44281 from zhengruifeng/ps_test_rolling.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 dev/sparktestsupport/modules.py | 16 +-
 .../test_parity_groupby_rolling.py} | 12 +-
 .../test_parity_groupby_rolling_adv.py} | 12 +-
 .../test_parity_groupby_rolling_count.py} | 12 +-
 .../connect/{ => window}/test_parity_rolling.py | 10 +-
 .../test_parity_rolling_adv.py} | 12 +-
 .../test_parity_rolling_count.py} | 12 +-
 .../test_parity_rolling_error.py} | 12 +-
 python/pyspark/pandas/tests/test_rolling.py | 317 ---------------------
 .../pandas/tests/window/test_groupby_rolling.py | 132 +++++++++
 .../tests/window/test_groupby_rolling_adv.py | 60 ++++
 .../tests/window/test_groupby_rolling_count.py | 113 ++++++++
 python/pyspark/pandas/tests/window/test_rolling.py | 91 ++++++
 .../test_rolling_adv.py} | 33 ++-
 .../pandas/tests/window/test_rolling_count.py | 72 +++++
 .../test_rolling_error.py} | 31 +-
 16 files changed, 578 insertions(+), 369 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 68e9ed8101d..c77a34f1d22 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -750,7 +750,13 @@ pyspark_pandas = Module(
"pyspark.pandas.tests.resample.test_series", "pyspark.pandas.tests.resample.test_timezone", "pyspark.pandas.tests.test_reshape", - "pyspark.pandas.tests.test_rolling", + "pyspark.pandas.tests.window.test_rolling", + "pyspark.pandas.tests.window.test_rolling_adv", + "pyspark.pandas.tests.window.test_rolling_count", + "pyspark.pandas.tests.window.test_rolling_error", + "pyspark.pandas.tests.window.test_groupby_rolling", + "pyspark.pandas.tests.window.test_groupby_rolling_adv", + "pyspark.pandas.tests.window.test_groupby_rolling_count", "pyspark.pandas.tests.test_scalars", "pyspark.pandas.tests.test_series_conversion", "pyspark.pandas.tests.test_series_datetime", @@ -1120,7 +1126,13 @@ pyspark_pandas_connect_part2 = Module( "pyspark.pandas.tests.connect.window.test_parity_ewm_error", "pyspark.pandas.tests.connect.window.test_parity_ewm_mean", "pyspark.pandas.tests.connect.window.test_parity_groupby_ewm_mean", - "pyspark.pandas.tests.connect.test_parity_rolling", + "pyspark.pandas.tests.connect.window.test_parity_rolling", + "pyspark.pandas.tests.connect.window.test_parity_rolling_adv", + "pyspark.pandas.tests.connect.window.test_parity_rolling_count", + "pyspark.pandas.tests.connect.window.test_parity_rolling_error", + "pyspark.pandas.tests.connect.window.test_parity_groupby_rolling", + "pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_adv", + "pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_count", "pyspark.pandas.tests.connect.test_parity_expanding", "pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_rolling", "pyspark.pandas.tests.connect.computation.test_parity_missing_data", diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling.py similarity index 76% copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py copy to python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling.py index 
8318bed24f0..0a3e0b1358f 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling.py @@ -16,19 +16,21 @@ # import unittest -from pyspark.pandas.tests.test_rolling import RollingTestsMixin +from pyspark.pandas.tests.window.test_groupby_rolling import GroupByRollingMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class RollingParityTests( - RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class RollingParityGroupTests( + GroupByRollingMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_groupby_rolling import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling_adv.py similarity index 75% copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py copy to python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling_adv.py index 8318bed24f0..774f8dd9e75 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling_adv.py @@ -16,19 +16,21 @@ # import unittest -from pyspark.pandas.tests.test_rolling import RollingTestsMixin +from pyspark.pandas.tests.window.test_groupby_rolling_adv import GroupByRollingAdvMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class RollingParityTests( - RollingTestsMixin, 
PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class RollingParityGroupAdvTests( + GroupByRollingAdvMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_adv import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling_count.py similarity index 75% copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py copy to python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling_count.py index 8318bed24f0..89dc851b32c 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling_count.py @@ -16,19 +16,21 @@ # import unittest -from pyspark.pandas.tests.test_rolling import RollingTestsMixin +from pyspark.pandas.tests.window.test_groupby_rolling_count import GroupByRollingCountMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class RollingParityTests( - RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class RollingParityGroupCountTests( + GroupByRollingCountMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_count import * # noqa try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_rolling.py similarity index 79% copy from 
python/pyspark/pandas/tests/connect/test_parity_rolling.py copy to python/pyspark/pandas/tests/connect/window/test_parity_rolling.py index 8318bed24f0..9dc3d9dcd4c 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_rolling.py @@ -16,19 +16,21 @@ # import unittest -from pyspark.pandas.tests.test_rolling import RollingTestsMixin +from pyspark.pandas.tests.window.test_rolling import RollingMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestUtils class RollingParityTests( - RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase + RollingMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_rolling import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_rolling_adv.py similarity index 77% copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py copy to python/pyspark/pandas/tests/connect/window/test_parity_rolling_adv.py index 8318bed24f0..ae0d9e0ba11 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_rolling_adv.py @@ -16,19 +16,21 @@ # import unittest -from pyspark.pandas.tests.test_rolling import RollingTestsMixin +from pyspark.pandas.tests.window.test_rolling_adv import RollingAdvMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class RollingParityTests( - 
RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class RollingParityAdvTests( + RollingAdvMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_rolling_adv import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_rolling_count.py similarity index 77% copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py copy to python/pyspark/pandas/tests/connect/window/test_parity_rolling_count.py index 8318bed24f0..7bbe31bc303 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_rolling_count.py @@ -16,19 +16,21 @@ # import unittest -from pyspark.pandas.tests.test_rolling import RollingTestsMixin +from pyspark.pandas.tests.window.test_rolling_count import RollingCountMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class RollingParityTests( - RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class RollingParityCountTests( + RollingCountMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_rolling_count import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_rolling_error.py similarity index 77% copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py 
copy to python/pyspark/pandas/tests/connect/window/test_parity_rolling_error.py index 8318bed24f0..dc4ecb321d7 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_rolling_error.py @@ -16,19 +16,21 @@ # import unittest -from pyspark.pandas.tests.test_rolling import RollingTestsMixin +from pyspark.pandas.tests.window.test_rolling_error import RollingErrorMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class RollingParityTests( - RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class RollingParityErrorTests( + RollingErrorMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_rolling_error import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/test_rolling.py b/python/pyspark/pandas/tests/test_rolling.py deleted file mode 100644 index c7e49eab5bb..00000000000 --- a/python/pyspark/pandas/tests/test_rolling.py +++ /dev/null @@ -1,317 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import unittest - -import numpy as np -import pandas as pd - -import pyspark.pandas as ps -from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils -from pyspark.pandas.window import Rolling - - -class RollingTestsMixin: - def test_rolling_error(self): - with self.assertRaisesRegex(ValueError, "window must be >= 0"): - ps.range(10).rolling(window=-1) - with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"): - ps.range(10).rolling(window=1, min_periods=-1) - - with self.assertRaisesRegex( - TypeError, "psdf_or_psser must be a series or dataframe; however, got:.*int" - ): - Rolling(1, 2) - - def _test_rolling_func(self, ps_func, pd_func=None): - if not pd_func: - pd_func = ps_func - if isinstance(pd_func, str): - pd_func = self.convert_str_to_lambda(pd_func) - if isinstance(ps_func, str): - ps_func = self.convert_str_to_lambda(ps_func) - pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a") - psser = ps.from_pandas(pser) - self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2))) - self.assert_eq(ps_func(psser.rolling(2)).sum(), pd_func(pser.rolling(2)).sum()) - - # Multiindex - pser = pd.Series( - [1, 2, 3], - index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]), - name="a", - ) - psser = ps.from_pandas(pser) - self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2))) - - pdf = pd.DataFrame( - {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) - ) - psdf = ps.from_pandas(pdf) - self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2))) - 
self.assert_eq(ps_func(psdf.rolling(2)).sum(), pd_func(pdf.rolling(2)).sum()) - - # Multiindex column - columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) - pdf.columns = columns - psdf.columns = columns - self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2))) - - def test_rolling_min(self): - self._test_rolling_func("min") - - def test_rolling_max(self): - self._test_rolling_func("max") - - def test_rolling_mean(self): - self._test_rolling_func("mean") - - def test_rolling_quantile(self): - self._test_rolling_func(lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower")) - - def test_rolling_sum(self): - self._test_rolling_func("sum") - - def test_rolling_count(self): - pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a") - psser = ps.from_pandas(pser) - self.assert_eq(psser.rolling(2).count(), pser.rolling(2, min_periods=1).count()) - self.assert_eq(psser.rolling(2).count().sum(), pser.rolling(2, min_periods=1).count().sum()) - - # TODO(SPARK-43432): Fix `min_periods` for Rolling.count() to work same as pandas - # Multiindex - pser = pd.Series( - [1, 2, 3], - index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]), - name="a", - ) - psser = ps.from_pandas(pser) - self.assert_eq(psser.rolling(2).count(), pser.rolling(2, min_periods=1).count()) - - pdf = pd.DataFrame( - {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) - ) - psdf = ps.from_pandas(pdf) - self.assert_eq(psdf.rolling(2).count(), pdf.rolling(2, min_periods=1).count()) - self.assert_eq(psdf.rolling(2).count().sum(), pdf.rolling(2, min_periods=1).count().sum()) - - # Multiindex column - columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) - pdf.columns = columns - psdf.columns = columns - self.assert_eq(psdf.rolling(2).count(), pdf.rolling(2, min_periods=1).count()) - - def test_rolling_std(self): - self._test_rolling_func("std") - - def test_rolling_var(self): - self._test_rolling_func("var") - - def 
test_rolling_skew(self): - self._test_rolling_func("skew") - - def test_rolling_kurt(self): - self._test_rolling_func("kurt") - - def _test_groupby_rolling_func(self, ps_func, pd_func=None): - if not pd_func: - pd_func = ps_func - if isinstance(pd_func, str): - pd_func = self.convert_str_to_lambda(pd_func) - if isinstance(ps_func, str): - ps_func = self.convert_str_to_lambda(ps_func) - pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a") - psser = ps.from_pandas(pser) - self.assert_eq( - ps_func(psser.groupby(psser).rolling(2)).sort_index(), - pd_func(pser.groupby(pser).rolling(2)).sort_index(), - ) - self.assert_eq( - ps_func(psser.groupby(psser).rolling(2)).sum(), - pd_func(pser.groupby(pser).rolling(2)).sum(), - ) - - # Multiindex - pser = pd.Series( - [1, 2, 3, 2], - index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z"), ("c", "z")]), - name="a", - ) - psser = ps.from_pandas(pser) - self.assert_eq( - ps_func(psser.groupby(psser).rolling(2)).sort_index(), - pd_func(pser.groupby(pser).rolling(2)).sort_index(), - ) - - pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}) - psdf = ps.from_pandas(pdf) - - self.assert_eq( - ps_func(psdf.groupby(psdf.a).rolling(2)).sort_index(), - pd_func(pdf.groupby(pdf.a).rolling(2)).sort_index(), - ) - self.assert_eq( - ps_func(psdf.groupby(psdf.a).rolling(2)).sum(), - pd_func(pdf.groupby(pdf.a).rolling(2)).sum(), - ) - self.assert_eq( - ps_func(psdf.groupby(psdf.a + 1).rolling(2)).sort_index(), - pd_func(pdf.groupby(pdf.a + 1).rolling(2)).sort_index(), - ) - - self.assert_eq( - ps_func(psdf.b.groupby(psdf.a).rolling(2)).sort_index(), - pd_func(pdf.b.groupby(pdf.a).rolling(2)).sort_index(), - ) - self.assert_eq( - ps_func(psdf.groupby(psdf.a)["b"].rolling(2)).sort_index(), - pd_func(pdf.groupby(pdf.a)["b"].rolling(2)).sort_index(), - ) - self.assert_eq( - ps_func(psdf.groupby(psdf.a)[["b"]].rolling(2)).sort_index(), - pd_func(pdf.groupby(pdf.a)[["b"]].rolling(2)).sort_index(), - ) - 
- # Multiindex column - columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) - pdf.columns = columns - psdf.columns = columns - - self.assert_eq( - ps_func(psdf.groupby(("a", "x")).rolling(2)).sort_index(), - pd_func(pdf.groupby(("a", "x")).rolling(2)).sort_index(), - ) - - self.assert_eq( - ps_func(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(), - pd_func(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(), - ) - - def test_groupby_rolling_count(self): - pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a") - psser = ps.from_pandas(pser) - # TODO(SPARK-43432): Fix `min_periods` for Rolling.count() to work same as pandas - self.assert_eq( - psser.groupby(psser).rolling(2).count().sort_index(), - pser.groupby(pser).rolling(2, min_periods=1).count().sort_index(), - ) - self.assert_eq( - psser.groupby(psser).rolling(2).count().sum(), - pser.groupby(pser).rolling(2, min_periods=1).count().sum(), - ) - - # Multiindex - pser = pd.Series( - [1, 2, 3, 2], - index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z"), ("c", "z")]), - name="a", - ) - psser = ps.from_pandas(pser) - self.assert_eq( - psser.groupby(psser).rolling(2).count().sort_index(), - pser.groupby(pser).rolling(2, min_periods=1).count().sort_index(), - ) - - pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}) - psdf = ps.from_pandas(pdf) - - self.assert_eq( - psdf.groupby(psdf.a).rolling(2).count().sort_index(), - pdf.groupby(pdf.a).rolling(2, min_periods=1).count().sort_index(), - ) - self.assert_eq( - psdf.groupby(psdf.a).rolling(2).count().sum(), - pdf.groupby(pdf.a).rolling(2, min_periods=1).count().sum(), - ) - self.assert_eq( - psdf.groupby(psdf.a + 1).rolling(2).count().sort_index(), - pdf.groupby(pdf.a + 1).rolling(2, min_periods=1).count().sort_index(), - ) - - self.assert_eq( - psdf.b.groupby(psdf.a).rolling(2).count().sort_index(), - pdf.b.groupby(pdf.a).rolling(2, min_periods=1).count().sort_index(), - ) - 
self.assert_eq( - psdf.groupby(psdf.a)["b"].rolling(2).count().sort_index(), - pdf.groupby(pdf.a)["b"].rolling(2, min_periods=1).count().sort_index(), - ) - self.assert_eq( - psdf.groupby(psdf.a)[["b"]].rolling(2).count().sort_index(), - pdf.groupby(pdf.a)[["b"]].rolling(2, min_periods=1).count().sort_index(), - ) - - # Multiindex column - columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) - pdf.columns = columns - psdf.columns = columns - - self.assert_eq( - psdf.groupby(("a", "x")).rolling(2).count().sort_index(), - pdf.groupby(("a", "x")).rolling(2, min_periods=1).count().sort_index(), - ) - - self.assert_eq( - psdf.groupby([("a", "x"), ("a", "y")]).rolling(2).count().sort_index(), - pdf.groupby([("a", "x"), ("a", "y")]).rolling(2, min_periods=1).count().sort_index(), - ) - - def test_groupby_rolling_min(self): - self._test_groupby_rolling_func("min") - - def test_groupby_rolling_max(self): - self._test_groupby_rolling_func("max") - - def test_groupby_rolling_mean(self): - self._test_groupby_rolling_func("mean") - - def test_groupby_rolling_quantile(self): - self._test_groupby_rolling_func( - lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower") - ) - - def test_groupby_rolling_sum(self): - self._test_groupby_rolling_func("sum") - - def test_groupby_rolling_std(self): - # TODO: `std` now raise error in pandas 1.0.0 - self._test_groupby_rolling_func("std") - - def test_groupby_rolling_var(self): - self._test_groupby_rolling_func("var") - - def test_groupby_rolling_skew(self): - self._test_groupby_rolling_func("skew") - - def test_groupby_rolling_kurt(self): - self._test_groupby_rolling_func("kurt") - - -class RollingTests(RollingTestsMixin, PandasOnSparkTestCase, TestUtils): - pass - - -if __name__ == "__main__": - import unittest - from pyspark.pandas.tests.test_rolling import * # noqa: F401 - - try: - import xmlrunner - - testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) - except ImportError: - testRunner = None 
- unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/window/test_groupby_rolling.py b/python/pyspark/pandas/tests/window/test_groupby_rolling.py new file mode 100644 index 00000000000..a5bced6a8bf --- /dev/null +++ b/python/pyspark/pandas/tests/window/test_groupby_rolling.py @@ -0,0 +1,132 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import numpy as np +import pandas as pd + +import pyspark.pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils + + +class GroupByRollingTestingFuncMixin: + def _test_groupby_rolling_func(self, ps_func, pd_func=None): + if not pd_func: + pd_func = ps_func + if isinstance(pd_func, str): + pd_func = self.convert_str_to_lambda(pd_func) + if isinstance(ps_func, str): + ps_func = self.convert_str_to_lambda(ps_func) + pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a") + psser = ps.from_pandas(pser) + self.assert_eq( + ps_func(psser.groupby(psser).rolling(2)).sort_index(), + pd_func(pser.groupby(pser).rolling(2)).sort_index(), + ) + self.assert_eq( + ps_func(psser.groupby(psser).rolling(2)).sum(), + pd_func(pser.groupby(pser).rolling(2)).sum(), + ) + + # Multiindex + pser = pd.Series( + [1, 2, 3, 2], + index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z"), ("c", "z")]), + name="a", + ) + psser = ps.from_pandas(pser) + self.assert_eq( + ps_func(psser.groupby(psser).rolling(2)).sort_index(), + pd_func(pser.groupby(pser).rolling(2)).sort_index(), + ) + + pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}) + psdf = ps.from_pandas(pdf) + + self.assert_eq( + ps_func(psdf.groupby(psdf.a).rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a).rolling(2)).sort_index(), + ) + self.assert_eq( + ps_func(psdf.groupby(psdf.a).rolling(2)).sum(), + pd_func(pdf.groupby(pdf.a).rolling(2)).sum(), + ) + self.assert_eq( + ps_func(psdf.groupby(psdf.a + 1).rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a + 1).rolling(2)).sort_index(), + ) + + self.assert_eq( + ps_func(psdf.b.groupby(psdf.a).rolling(2)).sort_index(), + pd_func(pdf.b.groupby(pdf.a).rolling(2)).sort_index(), + ) + self.assert_eq( + ps_func(psdf.groupby(psdf.a)["b"].rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a)["b"].rolling(2)).sort_index(), + ) + self.assert_eq( + ps_func(psdf.groupby(psdf.a)[["b"]].rolling(2)).sort_index(), + 
pd_func(pdf.groupby(pdf.a)[["b"]].rolling(2)).sort_index(), + ) + + # Multiindex column + columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) + pdf.columns = columns + psdf.columns = columns + + self.assert_eq( + ps_func(psdf.groupby(("a", "x")).rolling(2)).sort_index(), + pd_func(pdf.groupby(("a", "x")).rolling(2)).sort_index(), + ) + + self.assert_eq( + ps_func(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(), + pd_func(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(), + ) + + +class GroupByRollingMixin(GroupByRollingTestingFuncMixin): + def test_groupby_rolling_min(self): + self._test_groupby_rolling_func("min") + + def test_groupby_rolling_max(self): + self._test_groupby_rolling_func("max") + + def test_groupby_rolling_mean(self): + self._test_groupby_rolling_func("mean") + + def test_groupby_rolling_sum(self): + self._test_groupby_rolling_func("sum") + + +class GroupByRollingTests( + GroupByRollingMixin, + PandasOnSparkTestCase, + TestUtils, +): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.pandas.tests.window.test_groupby_rolling import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/window/test_groupby_rolling_adv.py b/python/pyspark/pandas/tests/window/test_groupby_rolling_adv.py new file mode 100644 index 00000000000..13fa5902d2a --- /dev/null +++ b/python/pyspark/pandas/tests/window/test_groupby_rolling_adv.py @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils +from pyspark.pandas.tests.window.test_groupby_rolling import GroupByRollingTestingFuncMixin + + +class GroupByRollingAdvMixin(GroupByRollingTestingFuncMixin): + def test_groupby_rolling_quantile(self): + self._test_groupby_rolling_func( + lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower") + ) + + def test_groupby_rolling_std(self): + # TODO: `std` now raise error in pandas 1.0.0 + self._test_groupby_rolling_func("std") + + def test_groupby_rolling_var(self): + self._test_groupby_rolling_func("var") + + def test_groupby_rolling_skew(self): + self._test_groupby_rolling_func("skew") + + def test_groupby_rolling_kurt(self): + self._test_groupby_rolling_func("kurt") + + +class GroupByRollingAdvTests( + GroupByRollingAdvMixin, + PandasOnSparkTestCase, + TestUtils, +): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.pandas.tests.window.test_groupby_rolling_adv import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/window/test_groupby_rolling_count.py b/python/pyspark/pandas/tests/window/test_groupby_rolling_count.py new file mode 100644 index 00000000000..7499e2f821a --- /dev/null +++ 
b/python/pyspark/pandas/tests/window/test_groupby_rolling_count.py @@ -0,0 +1,113 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np +import pandas as pd + +import pyspark.pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils + + +class GroupByRollingCountMixin: + def test_groupby_rolling_count(self): + pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a") + psser = ps.from_pandas(pser) + # TODO(SPARK-43432): Fix `min_periods` for Rolling.count() to work same as pandas + self.assert_eq( + psser.groupby(psser).rolling(2).count().sort_index(), + pser.groupby(pser).rolling(2, min_periods=1).count().sort_index(), + ) + self.assert_eq( + psser.groupby(psser).rolling(2).count().sum(), + pser.groupby(pser).rolling(2, min_periods=1).count().sum(), + ) + + # Multiindex + pser = pd.Series( + [1, 2, 3, 2], + index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z"), ("c", "z")]), + name="a", + ) + psser = ps.from_pandas(pser) + self.assert_eq( + psser.groupby(psser).rolling(2).count().sort_index(), + pser.groupby(pser).rolling(2, min_periods=1).count().sort_index(), + ) + + pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}) + psdf = 
ps.from_pandas(pdf) + + self.assert_eq( + psdf.groupby(psdf.a).rolling(2).count().sort_index(), + pdf.groupby(pdf.a).rolling(2, min_periods=1).count().sort_index(), + ) + self.assert_eq( + psdf.groupby(psdf.a).rolling(2).count().sum(), + pdf.groupby(pdf.a).rolling(2, min_periods=1).count().sum(), + ) + self.assert_eq( + psdf.groupby(psdf.a + 1).rolling(2).count().sort_index(), + pdf.groupby(pdf.a + 1).rolling(2, min_periods=1).count().sort_index(), + ) + + self.assert_eq( + psdf.b.groupby(psdf.a).rolling(2).count().sort_index(), + pdf.b.groupby(pdf.a).rolling(2, min_periods=1).count().sort_index(), + ) + self.assert_eq( + psdf.groupby(psdf.a)["b"].rolling(2).count().sort_index(), + pdf.groupby(pdf.a)["b"].rolling(2, min_periods=1).count().sort_index(), + ) + self.assert_eq( + psdf.groupby(psdf.a)[["b"]].rolling(2).count().sort_index(), + pdf.groupby(pdf.a)[["b"]].rolling(2, min_periods=1).count().sort_index(), + ) + + # Multiindex column + columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) + pdf.columns = columns + psdf.columns = columns + + self.assert_eq( + psdf.groupby(("a", "x")).rolling(2).count().sort_index(), + pdf.groupby(("a", "x")).rolling(2, min_periods=1).count().sort_index(), + ) + + self.assert_eq( + psdf.groupby([("a", "x"), ("a", "y")]).rolling(2).count().sort_index(), + pdf.groupby([("a", "x"), ("a", "y")]).rolling(2, min_periods=1).count().sort_index(), + ) + + +class GroupByRollingCountTests( + GroupByRollingCountMixin, + PandasOnSparkTestCase, + TestUtils, +): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.pandas.tests.window.test_groupby_rolling_count import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/window/test_rolling.py b/python/pyspark/pandas/tests/window/test_rolling.py new file mode 
100644 index 00000000000..cf6903afe7c --- /dev/null +++ b/python/pyspark/pandas/tests/window/test_rolling.py @@ -0,0 +1,91 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np +import pandas as pd + +import pyspark.pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase + + +class RollingTestingFuncMixin: + def _test_rolling_func(self, ps_func, pd_func=None): + if not pd_func: + pd_func = ps_func + if isinstance(pd_func, str): + pd_func = self.convert_str_to_lambda(pd_func) + if isinstance(ps_func, str): + ps_func = self.convert_str_to_lambda(ps_func) + pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a") + psser = ps.from_pandas(pser) + self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2))) + self.assert_eq(ps_func(psser.rolling(2)).sum(), pd_func(pser.rolling(2)).sum()) + + # Multiindex + pser = pd.Series( + [1, 2, 3], + index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]), + name="a", + ) + psser = ps.from_pandas(pser) + self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2))) + + pdf = pd.DataFrame( + {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) + ) + psdf = ps.from_pandas(pdf) + 
self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2))) + self.assert_eq(ps_func(psdf.rolling(2)).sum(), pd_func(pdf.rolling(2)).sum()) + + # Multiindex column + columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) + pdf.columns = columns + psdf.columns = columns + self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2))) + + +class RollingMixin(RollingTestingFuncMixin): + def test_rolling_min(self): + self._test_rolling_func("min") + + def test_rolling_max(self): + self._test_rolling_func("max") + + def test_rolling_mean(self): + self._test_rolling_func("mean") + + def test_rolling_sum(self): + self._test_rolling_func("sum") + + +class RollingTests( + RollingMixin, + PandasOnSparkTestCase, +): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.pandas.tests.window.test_rolling import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/window/test_rolling_adv.py similarity index 55% copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py copy to python/pyspark/pandas/tests/window/test_rolling_adv.py index 8318bed24f0..6ae48dfa76d 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py +++ b/python/pyspark/pandas/tests/window/test_rolling_adv.py @@ -14,24 +14,41 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import unittest -from pyspark.pandas.tests.test_rolling import RollingTestsMixin -from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils +from pyspark.pandas.tests.window.test_rolling import RollingTestingFuncMixin -class RollingParityTests( - RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class RollingAdvMixin(RollingTestingFuncMixin): + def test_rolling_quantile(self): + self._test_rolling_func(lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower")) + + def test_rolling_std(self): + self._test_rolling_func("std") + + def test_rolling_var(self): + self._test_rolling_func("var") + + def test_rolling_skew(self): + self._test_rolling_func("skew") + + def test_rolling_kurt(self): + self._test_rolling_func("kurt") + + +class RollingAdvTests( + RollingAdvMixin, + PandasOnSparkTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401 + import unittest + from pyspark.pandas.tests.window.test_rolling_adv import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/window/test_rolling_count.py b/python/pyspark/pandas/tests/window/test_rolling_count.py new file mode 100644 index 00000000000..36ec8cb056a --- /dev/null +++ b/python/pyspark/pandas/tests/window/test_rolling_count.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np +import pandas as pd + +import pyspark.pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase + + +class RollingCountMixin: + def test_rolling_count(self): + pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a") + psser = ps.from_pandas(pser) + self.assert_eq(psser.rolling(2).count(), pser.rolling(2, min_periods=1).count()) + self.assert_eq(psser.rolling(2).count().sum(), pser.rolling(2, min_periods=1).count().sum()) + + # TODO(SPARK-43432): Fix `min_periods` for Rolling.count() to work same as pandas + # Multiindex + pser = pd.Series( + [1, 2, 3], + index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]), + name="a", + ) + psser = ps.from_pandas(pser) + self.assert_eq(psser.rolling(2).count(), pser.rolling(2, min_periods=1).count()) + + pdf = pd.DataFrame( + {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(psdf.rolling(2).count(), pdf.rolling(2, min_periods=1).count()) + self.assert_eq(psdf.rolling(2).count().sum(), pdf.rolling(2, min_periods=1).count().sum()) + + # Multiindex column + columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) + pdf.columns = columns + psdf.columns = columns + self.assert_eq(psdf.rolling(2).count(), pdf.rolling(2, min_periods=1).count()) + + +class RollingCountTests( + RollingCountMixin, + 
PandasOnSparkTestCase, +): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.pandas.tests.window.test_rolling_count import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/window/test_rolling_error.py similarity index 55% rename from python/pyspark/pandas/tests/connect/test_parity_rolling.py rename to python/pyspark/pandas/tests/window/test_rolling_error.py index 8318bed24f0..485eeb78c13 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py +++ b/python/pyspark/pandas/tests/window/test_rolling_error.py @@ -14,24 +14,39 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import unittest -from pyspark.pandas.tests.test_rolling import RollingTestsMixin -from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +import pyspark.pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils +from pyspark.pandas.window import Rolling -class RollingParityTests( - RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class RollingErrorMixin: + def test_rolling_error(self): + with self.assertRaisesRegex(ValueError, "window must be >= 0"): + ps.range(10).rolling(window=-1) + with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"): + ps.range(10).rolling(window=1, min_periods=-1) + + with self.assertRaisesRegex( + TypeError, "psdf_or_psser must be a series or dataframe; however, got:.*int" + ): + Rolling(1, 2) + + +class RollingErrorTests( + RollingErrorMixin, + PandasOnSparkTestCase, + TestUtils, ): pass if __name__ == "__main__": - from 
pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401 + import unittest + from pyspark.pandas.tests.window.test_rolling_error import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org