This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 2893cd304a9b [SPARK-46391][PS][TESTS] Reorganize `ExpandingParityTests` 2893cd304a9b is described below commit 2893cd304a9b7a6782727da58204331a9083cdaf Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Thu Dec 14 18:41:32 2023 +0800 [SPARK-46391][PS][TESTS] Reorganize `ExpandingParityTests` ### What changes were proposed in this pull request? Reorganize `ExpandingParityTests` ### Why are the changes needed? to make the test more consistent with pandas ### Does this PR introduce _any_ user-facing change? no, test-only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #44332 from zhengruifeng/ps_test_expanding. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- dev/sparktestsupport/modules.py | 12 ++- .../connect/{ => window}/test_parity_expanding.py | 10 +- .../test_parity_expanding_adv.py} | 12 ++- .../test_parity_expanding_error.py} | 12 ++- .../test_parity_groupby_expanding.py} | 12 ++- .../test_parity_groupby_expanding_adv.py} | 12 ++- .../pyspark/pandas/tests/window/test_expanding.py | 96 +++++++++++++++++++ .../test_expanding_adv.py} | 33 +++++-- .../test_expanding_error.py} | 28 ++++-- .../test_groupby_expanding.py} | 103 ++------------------- .../test_groupby_expanding_adv.py} | 35 +++++-- 11 files changed, 220 insertions(+), 145 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index fee9198dff42..22fdde139d28 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -727,7 +727,11 @@ pyspark_pandas = Module( "pyspark.pandas.tests.test_dataframe_conversion", "pyspark.pandas.tests.test_dataframe_spark_io", "pyspark.pandas.tests.test_default_index", - "pyspark.pandas.tests.test_expanding", + "pyspark.pandas.tests.window.test_expanding", + "pyspark.pandas.tests.window.test_expanding_adv", + 
"pyspark.pandas.tests.window.test_expanding_error", + "pyspark.pandas.tests.window.test_groupby_expanding", + "pyspark.pandas.tests.window.test_groupby_expanding_adv", "pyspark.pandas.tests.test_extension", "pyspark.pandas.tests.window.test_ewm_error", "pyspark.pandas.tests.window.test_ewm_mean", @@ -1135,7 +1139,11 @@ pyspark_pandas_connect_part2 = Module( "pyspark.pandas.tests.connect.window.test_parity_groupby_rolling", "pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_adv", "pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_count", - "pyspark.pandas.tests.connect.test_parity_expanding", + "pyspark.pandas.tests.connect.window.test_parity_expanding", + "pyspark.pandas.tests.connect.window.test_parity_expanding_adv", + "pyspark.pandas.tests.connect.window.test_parity_expanding_error", + "pyspark.pandas.tests.connect.window.test_parity_groupby_expanding", + "pyspark.pandas.tests.connect.window.test_parity_groupby_expanding_adv", "pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_rolling", "pyspark.pandas.tests.connect.computation.test_parity_missing_data", "pyspark.pandas.tests.connect.groupby.test_parity_index", diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py b/python/pyspark/pandas/tests/connect/window/test_parity_expanding.py similarity index 79% copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py copy to python/pyspark/pandas/tests/connect/window/test_parity_expanding.py index 7f8b1a3cac2f..ac83a1c3b34c 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_expanding.py @@ -16,19 +16,21 @@ # import unittest -from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin +from pyspark.pandas.tests.window.test_expanding import ExpandingMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from 
pyspark.testing.pandasutils import PandasOnSparkTestUtils class ExpandingParityTests( - ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase + ExpandingMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_expanding import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py b/python/pyspark/pandas/tests/connect/window/test_parity_expanding_adv.py similarity index 77% copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py copy to python/pyspark/pandas/tests/connect/window/test_parity_expanding_adv.py index 7f8b1a3cac2f..0baec678bede 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_expanding_adv.py @@ -16,19 +16,21 @@ # import unittest -from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin +from pyspark.pandas.tests.window.test_expanding_adv import ExpandingAdvMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class ExpandingParityTests( - ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class ExpandingAdvParityTests( + ExpandingAdvMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_expanding_adv import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py b/python/pyspark/pandas/tests/connect/window/test_parity_expanding_error.py similarity 
index 76% copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py copy to python/pyspark/pandas/tests/connect/window/test_parity_expanding_error.py index 7f8b1a3cac2f..a8531a02799c 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_expanding_error.py @@ -16,19 +16,21 @@ # import unittest -from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin +from pyspark.pandas.tests.window.test_expanding_error import ExpandingErrorMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class ExpandingParityTests( - ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class ExpandingErrorParityTests( + ExpandingErrorMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_expanding_error import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding.py similarity index 76% copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py copy to python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding.py index 7f8b1a3cac2f..356bc5298264 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding.py @@ -16,19 +16,21 @@ # import unittest -from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin +from pyspark.pandas.tests.window.test_groupby_expanding import GroupByExpandingMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from 
pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class ExpandingParityTests( - ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class GroupByExpandingParityTests( + GroupByExpandingMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_groupby_expanding import * # noqa try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding_adv.py similarity index 75% copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py copy to python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding_adv.py index 7f8b1a3cac2f..b743e335b154 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_expanding_adv.py @@ -16,19 +16,21 @@ # import unittest -from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin +from pyspark.pandas.tests.window.test_groupby_expanding_adv import GroupByExpandingAdvMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class ExpandingParityTests( - ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class GroupByExpandingAdvParityTests( + GroupByExpandingAdvMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_groupby_expanding_adv import * # noqa try: import xmlrunner # type: 
ignore[import] diff --git a/python/pyspark/pandas/tests/window/test_expanding.py b/python/pyspark/pandas/tests/window/test_expanding.py new file mode 100644 index 000000000000..ebe54ff21719 --- /dev/null +++ b/python/pyspark/pandas/tests/window/test_expanding.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import numpy as np +import pandas as pd + +import pyspark.pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase + + +class ExpandingTestingFuncMixin: + def _test_expanding_func(self, ps_func, pd_func=None): + if not pd_func: + pd_func = ps_func + if isinstance(pd_func, str): + pd_func = self.convert_str_to_lambda(pd_func) + if isinstance(ps_func, str): + ps_func = self.convert_str_to_lambda(ps_func) + pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a") + psser = ps.from_pandas(pser) + self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True) + self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True) + + # Multiindex + pser = pd.Series( + [1, 2, 3], index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]) + ) + psser = ps.from_pandas(pser) + self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2))) + + pdf = pd.DataFrame( + {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2))) + self.assert_eq(ps_func(psdf.expanding(2)).sum(), pd_func(pdf.expanding(2)).sum()) + + # Multiindex column + columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) + pdf.columns = columns + psdf.columns = columns + self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2))) + + +class ExpandingMixin(ExpandingTestingFuncMixin): + def test_expanding_repr(self): + self.assertEqual(repr(ps.range(10).expanding(5)), "Expanding [min_periods=5]") + + def test_expanding_count(self): + self._test_expanding_func("count") + + def test_expanding_min(self): + self._test_expanding_func("min") + + def test_expanding_max(self): + self._test_expanding_func("max") + + def test_expanding_mean(self): + self._test_expanding_func("mean") + + def test_expanding_sum(self): + self._test_expanding_func("sum") + + +class ExpandingTests( + 
ExpandingMixin, + PandasOnSparkTestCase, +): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.pandas.tests.window.test_expanding import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py b/python/pyspark/pandas/tests/window/test_expanding_adv.py similarity index 55% copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py copy to python/pyspark/pandas/tests/window/test_expanding_adv.py index 7f8b1a3cac2f..e537f1ecfbc0 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py +++ b/python/pyspark/pandas/tests/window/test_expanding_adv.py @@ -14,24 +14,41 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import unittest -from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin -from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.pandas.tests.window.test_expanding import ExpandingTestingFuncMixin -class ExpandingParityTests( - ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class ExpandingAdvMixin(ExpandingTestingFuncMixin): + def test_expanding_quantile(self): + self._test_expanding_func(lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower")) + + def test_expanding_std(self): + self._test_expanding_func("std") + + def test_expanding_var(self): + self._test_expanding_func("var") + + def test_expanding_skew(self): + self._test_expanding_func("skew") + + def test_expanding_kurt(self): + self._test_expanding_func("kurt") + + +class ExpandingAdvTests( + ExpandingAdvMixin, + PandasOnSparkTestCase, ): pass if 
__name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa: F401 + import unittest + from pyspark.pandas.tests.window.test_expanding_adv import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py b/python/pyspark/pandas/tests/window/test_expanding_error.py similarity index 60% copy from python/pyspark/pandas/tests/connect/test_parity_expanding.py copy to python/pyspark/pandas/tests/window/test_expanding_error.py index 7f8b1a3cac2f..fa888f5f1696 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py +++ b/python/pyspark/pandas/tests/window/test_expanding_error.py @@ -14,24 +14,36 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import unittest -from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin -from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +import pyspark.pandas as ps +from pyspark.pandas.window import Expanding +from pyspark.testing.pandasutils import PandasOnSparkTestCase -class ExpandingParityTests( - ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class ExpandingErrorMixin: + def test_expanding_error(self): + with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"): + ps.range(10).expanding(-1) + + with self.assertRaisesRegex( + TypeError, "psdf_or_psser must be a series or dataframe; however, got:.*int" + ): + Expanding(1, 2) + + +class ExpandingErrorTests( + ExpandingErrorMixin, + PandasOnSparkTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa: F401 + import unittest + from pyspark.pandas.tests.window.test_expanding_error import 
* # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_expanding.py b/python/pyspark/pandas/tests/window/test_groupby_expanding.py similarity index 56% rename from python/pyspark/pandas/tests/test_expanding.py rename to python/pyspark/pandas/tests/window/test_groupby_expanding.py index 5166f8132665..44fecd7e58eb 100644 --- a/python/pyspark/pandas/tests/test_expanding.py +++ b/python/pyspark/pandas/tests/window/test_groupby_expanding.py @@ -19,85 +19,10 @@ import numpy as np import pandas as pd import pyspark.pandas as ps -from pyspark.pandas.window import Expanding from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils -class ExpandingTestsMixin: - def _test_expanding_func(self, ps_func, pd_func=None): - if not pd_func: - pd_func = ps_func - if isinstance(pd_func, str): - pd_func = self.convert_str_to_lambda(pd_func) - if isinstance(ps_func, str): - ps_func = self.convert_str_to_lambda(ps_func) - pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a") - psser = ps.from_pandas(pser) - self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True) - self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True) - - # Multiindex - pser = pd.Series( - [1, 2, 3], index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]) - ) - psser = ps.from_pandas(pser) - self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2))) - - pdf = pd.DataFrame( - {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) - ) - psdf = ps.from_pandas(pdf) - self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2))) - self.assert_eq(ps_func(psdf.expanding(2)).sum(), pd_func(pdf.expanding(2)).sum()) - - # Multiindex column - columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) - 
pdf.columns = columns - psdf.columns = columns - self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2))) - - def test_expanding_error(self): - with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"): - ps.range(10).expanding(-1) - - with self.assertRaisesRegex( - TypeError, "psdf_or_psser must be a series or dataframe; however, got:.*int" - ): - Expanding(1, 2) - - def test_expanding_repr(self): - self.assertEqual(repr(ps.range(10).expanding(5)), "Expanding [min_periods=5]") - - def test_expanding_count(self): - self._test_expanding_func("count") - - def test_expanding_min(self): - self._test_expanding_func("min") - - def test_expanding_max(self): - self._test_expanding_func("max") - - def test_expanding_mean(self): - self._test_expanding_func("mean") - - def test_expanding_quantile(self): - self._test_expanding_func(lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower")) - - def test_expanding_sum(self): - self._test_expanding_func("sum") - - def test_expanding_std(self): - self._test_expanding_func("std") - - def test_expanding_var(self): - self._test_expanding_func("var") - - def test_expanding_skew(self): - self._test_expanding_func("skew") - - def test_expanding_kurt(self): - self._test_expanding_func("kurt") - +class GroupByExpandingTestingFuncMixin: def _test_groupby_expanding_func(self, ps_func, pd_func=None): if not pd_func: pd_func = ps_func @@ -172,6 +97,8 @@ class ExpandingTestsMixin: pd_func(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2)).sort_index(), ) + +class GroupByExpandingMixin(GroupByExpandingTestingFuncMixin): def test_groupby_expanding_count(self): self._test_groupby_expanding_func("count") @@ -184,34 +111,20 @@ class ExpandingTestsMixin: def test_groupby_expanding_mean(self): self._test_groupby_expanding_func("mean") - def test_groupby_expanding_quantile(self): - self._test_groupby_expanding_func( - lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower") - ) - def test_groupby_expanding_sum(self): 
self._test_groupby_expanding_func("sum") - def test_groupby_expanding_std(self): - self._test_groupby_expanding_func("std") - - def test_groupby_expanding_var(self): - self._test_groupby_expanding_func("var") - - def test_groupby_expanding_skew(self): - self._test_groupby_expanding_func("skew") - - def test_groupby_expanding_kurt(self): - self._test_groupby_expanding_func("kurt") - -class ExpandingTests(ExpandingTestsMixin, PandasOnSparkTestCase, TestUtils): +class GroupByExpandingTests( + GroupByExpandingMixin, + PandasOnSparkTestCase, +): pass if __name__ == "__main__": import unittest - from pyspark.pandas.tests.test_expanding import * # noqa: F401 + from pyspark.pandas.tests.window.test_groupby_expanding import * # noqa: F401 try: import xmlrunner diff --git a/python/pyspark/pandas/tests/connect/test_parity_expanding.py b/python/pyspark/pandas/tests/window/test_groupby_expanding_adv.py similarity index 50% rename from python/pyspark/pandas/tests/connect/test_parity_expanding.py rename to python/pyspark/pandas/tests/window/test_groupby_expanding_adv.py index 7f8b1a3cac2f..22cb03dc0ff3 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_expanding.py +++ b/python/pyspark/pandas/tests/window/test_groupby_expanding_adv.py @@ -14,24 +14,43 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import unittest -from pyspark.pandas.tests.test_expanding import ExpandingTestsMixin -from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils +from pyspark.pandas.tests.window.test_groupby_expanding import GroupByExpandingTestingFuncMixin -class ExpandingParityTests( - ExpandingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +class GroupByExpandingAdvMixin(GroupByExpandingTestingFuncMixin): + def test_groupby_expanding_quantile(self): + self._test_groupby_expanding_func( + lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower") + ) + + def test_groupby_expanding_std(self): + self._test_groupby_expanding_func("std") + + def test_groupby_expanding_var(self): + self._test_groupby_expanding_func("var") + + def test_groupby_expanding_skew(self): + self._test_groupby_expanding_func("skew") + + def test_groupby_expanding_kurt(self): + self._test_groupby_expanding_func("kurt") + + +class GroupByExpandingAdvTests( + GroupByExpandingAdvMixin, + PandasOnSparkTestCase, ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_expanding import * # noqa: F401 + import unittest + from pyspark.pandas.tests.window.test_groupby_expanding_adv import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org