This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 5cca6ba82950 [SPARK-46451][PS][TESTS] Reorganize `GroupbyStatTests`
5cca6ba82950 is described below

commit 5cca6ba82950903b3ba20c7ec34cc442e3eb2e9d
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Dec 19 19:57:20 2023 +0800

    [SPARK-46451][PS][TESTS] Reorganize `GroupbyStatTests`

    ### What changes were proposed in this pull request?
    Reorganize `GroupbyStatTests`.

    ### Why are the changes needed?
    `GroupbyStatTests` and its parity test are slow; factor the slow tests out into separate modules to make them more suitable for parallelism.

    ### Does this PR introduce _any_ user-facing change?
    No, test-only.

    ### How was this patch tested?
    CI

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #44409 from zhengruifeng/ps_test_group_stat.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 dev/sparktestsupport/modules.py                    |  10 +-
 .../tests/connect/groupby/test_parity_stat.py      |   6 +-
 ...test_parity_stat.py => test_parity_stat_adv.py} |  10 +-
 ...est_parity_stat.py => test_parity_stat_ddof.py} |  10 +-
 ...est_parity_stat.py => test_parity_stat_func.py} |  10 +-
 ...est_parity_stat.py => test_parity_stat_prod.py} |  10 +-
 python/pyspark/pandas/tests/groupby/test_stat.py   | 263 ++-------------------
 .../pyspark/pandas/tests/groupby/test_stat_adv.py  | 161 +++++++++++++
 .../pyspark/pandas/tests/groupby/test_stat_ddof.py |  91 +++++++
 .../pyspark/pandas/tests/groupby/test_stat_func.py | 119 ++++++++++
 .../pyspark/pandas/tests/groupby/test_stat_prod.py |  86 +++++++
 11 files changed, 519 insertions(+), 257 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 38d3d42b658c..1c632871ba61 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -852,6 +852,10 @@ pyspark_pandas_slow = Module(
         "pyspark.pandas.tests.groupby.test_size",
         "pyspark.pandas.tests.groupby.test_split_apply",
         "pyspark.pandas.tests.groupby.test_stat",
+        "pyspark.pandas.tests.groupby.test_stat_adv",
+        "pyspark.pandas.tests.groupby.test_stat_ddof",
+        "pyspark.pandas.tests.groupby.test_stat_func",
+        "pyspark.pandas.tests.groupby.test_stat_prod",
         "pyspark.pandas.tests.groupby.test_value_counts",
         "pyspark.pandas.tests.test_indexing",
         "pyspark.pandas.tests.test_ops_on_diff_frames",
@@ -1060,7 +1064,6 @@ pyspark_pandas_connect_part0 = Module(
         "pyspark.pandas.tests.connect.computation.test_parity_describe",
         "pyspark.pandas.tests.connect.computation.test_parity_eval",
         "pyspark.pandas.tests.connect.computation.test_parity_melt",
-        "pyspark.pandas.tests.connect.groupby.test_parity_stat",
         "pyspark.pandas.tests.connect.frame.test_parity_attrs",
         "pyspark.pandas.tests.connect.frame.test_parity_axis",
         "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_dot_frame",
@@ -1197,6 +1200,11 @@ pyspark_pandas_connect_part3 = Module(
         "pyspark.pandas.tests.connect.io.test_parity_dataframe_conversion",
         "pyspark.pandas.tests.connect.io.test_parity_dataframe_spark_io",
         "pyspark.pandas.tests.connect.io.test_parity_series_conversion",
+        "pyspark.pandas.tests.connect.groupby.test_parity_stat",
+        "pyspark.pandas.tests.connect.groupby.test_parity_stat_adv",
+        "pyspark.pandas.tests.connect.groupby.test_parity_stat_ddof",
+        "pyspark.pandas.tests.connect.groupby.test_parity_stat_func",
+        "pyspark.pandas.tests.connect.groupby.test_parity_stat_prod",
         "pyspark.pandas.tests.connect.indexes.test_parity_datetime",
         "pyspark.pandas.tests.connect.indexes.test_parity_datetime_at",
         "pyspark.pandas.tests.connect.indexes.test_parity_datetime_between",
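The split registered above follows the mixin pattern used throughout these test modules: the test bodies live in a mixin class, and each registered module only declares a thin concrete suite. A minimal sketch of that pattern, with illustrative names only (nothing below is the actual Spark code):

```python
import unittest


class SlowStatTestsMixin:
    # Shared test logic; concrete suites only choose the base classes.
    def test_something_slow(self):
        self.assertEqual(sum(range(10)), 45)


class SlowStatTests(SlowStatTestsMixin, unittest.TestCase):
    pass


if __name__ == "__main__":
    unittest.main()
```

Because each new `test_stat_*` module is an independent test goal, the test scheduler can run the four pieces concurrently instead of serializing one large `test_stat` module.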
"pyspark.pandas.tests.connect.indexes.test_parity_datetime_between", diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py index a7c2e10dc3f5..c0f64a990e6f 100644 --- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py +++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py @@ -21,7 +21,11 @@ from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): +class GroupbyParityStatTests( + GroupbyStatMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): pass diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_adv.py similarity index 81% copy from python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py copy to python/pyspark/pandas/tests/connect/groupby/test_parity_stat_adv.py index a7c2e10dc3f5..2138e44263b2 100644 --- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py +++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_adv.py @@ -16,17 +16,21 @@ # import unittest -from pyspark.pandas.tests.groupby.test_stat import GroupbyStatMixin +from pyspark.pandas.tests.groupby.test_stat_adv import GroupbyStatAdvMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): +class GroupbyStatAdvParityTests( + GroupbyStatAdvMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.groupby.test_parity_stat import * # noqa: F401 + from pyspark.pandas.tests.connect.groupby.test_parity_stat_adv import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_ddof.py similarity index 81% copy from python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py copy to python/pyspark/pandas/tests/connect/groupby/test_parity_stat_ddof.py index a7c2e10dc3f5..3abec1d7abd7 100644 --- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py +++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_ddof.py @@ -16,17 +16,21 @@ # import unittest -from pyspark.pandas.tests.groupby.test_stat import GroupbyStatMixin +from pyspark.pandas.tests.groupby.test_stat_ddof import DdofTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): +class GroupbyStatDdofParityTests( + DdofTestsMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.groupby.test_parity_stat import * # noqa: F401 + from pyspark.pandas.tests.connect.groupby.test_parity_stat_ddof import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_func.py similarity index 81% copy from python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py copy to 
copy to python/pyspark/pandas/tests/connect/groupby/test_parity_stat_func.py
index a7c2e10dc3f5..41a7944aa49f 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_func.py
@@ -16,17 +16,21 @@
 #
 import unittest
 
-from pyspark.pandas.tests.groupby.test_stat import GroupbyStatMixin
+from pyspark.pandas.tests.groupby.test_stat_func import FuncTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 
-class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GroupbyStatFuncParityTests(
+    FuncTestsMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+):
     pass
 
 
 if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.groupby.test_parity_stat import *  # noqa: F401
+    from pyspark.pandas.tests.connect.groupby.test_parity_stat_func import *  # noqa: F401
 
     try:
         import xmlrunner  # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_prod.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
copy to python/pyspark/pandas/tests/connect/groupby/test_parity_stat_prod.py
index a7c2e10dc3f5..f0abab3a2142 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_prod.py
@@ -16,17 +16,21 @@
 #
 import unittest
 
-from pyspark.pandas.tests.groupby.test_stat import GroupbyStatMixin
+from pyspark.pandas.tests.groupby.test_stat_prod import ProdTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 
-class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GroupbyStatProdParityTests(
+    ProdTestsMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+):
     pass
 
 
 if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.groupby.test_parity_stat import *  # noqa: F401
+    from pyspark.pandas.tests.connect.groupby.test_parity_stat_prod import *  # noqa: F401
 
     try:
         import xmlrunner  # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/groupby/test_stat.py b/python/pyspark/pandas/tests/groupby/test_stat.py
index 29991ae1d54c..61d8cc5357f6 100644
--- a/python/pyspark/pandas/tests/groupby/test_stat.py
+++ b/python/pyspark/pandas/tests/groupby/test_stat.py
@@ -20,26 +20,11 @@ import numpy as np
 import pandas as pd
 
 from pyspark import pandas as ps
-from pyspark.testing.pandasutils import ComparisonTestBase
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
 
 
-class GroupbyStatMixin:
-    @property
-    def pdf(self):
-        return pd.DataFrame(
-            {
-                "A": [1, 2, 1, 2],
-                "B": [3.1, 4.1, 4.1, 3.1],
-                "C": ["a", "b", "b", "a"],
-                "D": [True, False, False, True],
-            }
-        )
-
-    @property
-    def psdf(self):
-        return ps.from_pandas(self.pdf)
-
+class GroupbyStatTestingFuncMixin:
     # TODO: All statistical functions should leverage this utility
     def _test_stat_func(self, func, check_exact=True):
         pdf, psdf = self.pdf, self.psdf
@@ -57,64 +42,22 @@ class GroupbyStatMixin:
             check_exact=check_exact,
         )
 
-    def test_basic_stat_funcs(self):
-        self._test_stat_func(
-            lambda groupby_obj: groupby_obj.var(numeric_only=True), check_exact=False
-        )
-
-        pdf, psdf = self.pdf, self.psdf
-
-        # Unlike pandas', the median in pandas-on-Spark is an approximated median based upon
-        # approximate percentile computation because computing median across a large dataset
-        # is extremely expensive.
-        expected = ps.DataFrame({"B": [3.1, 3.1], "D": [0, 0]}, index=pd.Index([1, 2], name="A"))
-        self.assert_eq(
-            psdf.groupby("A").median().sort_index(),
-            expected,
-        )
-        self.assert_eq(
-            psdf.groupby("A").median(numeric_only=None).sort_index(),
-            expected,
-        )
-        self.assert_eq(
-            psdf.groupby("A").median(numeric_only=False).sort_index(),
-            expected,
-        )
-        self.assert_eq(
-            psdf.groupby("A")["B"].median().sort_index(),
-            expected.B,
-        )
-        with self.assertRaises(TypeError):
-            psdf.groupby("A")["C"].mean()
-
-        with self.assertRaisesRegex(
-            TypeError, "Unaccepted data types of aggregation columns; numeric or bool expected."
-        ):
-            psdf.groupby("A")[["C"]].std()
-
-        with self.assertRaisesRegex(
-            TypeError, "Unaccepted data types of aggregation columns; numeric or bool expected."
-        ):
-            psdf.groupby("A")[["C"]].sem()
-
-        self.assert_eq(
-            psdf.groupby("A").std().sort_index(),
-            pdf.groupby("A").std(numeric_only=True).sort_index(),
-            check_exact=False,
-        )
-        self.assert_eq(
-            psdf.groupby("A").sem().sort_index(),
-            pdf.groupby("A").sem(numeric_only=True).sort_index(),
-            check_exact=False,
+class GroupbyStatMixin(GroupbyStatTestingFuncMixin):
+    @property
+    def pdf(self):
+        return pd.DataFrame(
+            {
+                "A": [1, 2, 1, 2],
+                "B": [3.1, 4.1, 4.1, 3.1],
+                "C": ["a", "b", "b", "a"],
+                "D": [True, False, False, True],
+            }
         )
 
-        # TODO: fix bug of `sum` and re-enable the test below
-        # self._test_stat_func(lambda groupby_obj: groupby_obj.sum(), check_exact=False)
-        self.assert_eq(
-            psdf.groupby("A").sum().sort_index(),
-            pdf.groupby("A").sum().sort_index(),
-            check_exact=False,
-        )
+    @property
+    def psdf(self):
+        return ps.from_pandas(self.pdf)
 
     def test_mean(self):
         self._test_stat_func(lambda groupby_obj: groupby_obj.mean(numeric_only=True))
@@ -122,48 +65,6 @@ class GroupbyStatMixin:
         with self.assertRaises(TypeError):
             psdf.groupby("A")["C"].mean()
 
-    def test_quantile(self):
-        dfs = [
-            pd.DataFrame(
-                [["a", 1], ["a", 2], ["a", 3], ["b", 1], ["b", 3], ["b", 5]], columns=["key", "val"]
-            ),
-            pd.DataFrame(
-                [["a", True], ["a", True], ["a", False], ["b", True], ["b", True], ["b", False]],
-                columns=["key", "val"],
-            ),
-        ]
-        for df in dfs:
-            psdf = ps.from_pandas(df)
-            # q accept float and int between 0 and 1
-            for i in [0, 0.1, 0.5, 1]:
-                self.assert_eq(
-                    df.groupby("key").quantile(q=i, interpolation="lower"),
-                    psdf.groupby("key").quantile(q=i),
-                    almost=True,
-                )
-                self.assert_eq(
-                    df.groupby("key")["val"].quantile(q=i, interpolation="lower"),
-                    psdf.groupby("key")["val"].quantile(q=i),
-                    almost=True,
-                )
-            # raise ValueError when q not in [0, 1]
-            with self.assertRaises(ValueError):
-                psdf.groupby("key").quantile(q=1.1)
-            with self.assertRaises(ValueError):
-                psdf.groupby("key").quantile(q=-0.1)
-            with self.assertRaises(ValueError):
-                psdf.groupby("key").quantile(q=2)
-            with self.assertRaises(ValueError):
-                psdf.groupby("key").quantile(q=np.nan)
-            # raise TypeError when q type mismatch
-            with self.assertRaises(TypeError):
-                psdf.groupby("key").quantile(q="0.1")
-            # raise NotImplementedError when q is list like type
-            with self.assertRaises(NotImplementedError):
-                psdf.groupby("key").quantile(q=(0.1, 0.5))
-            with self.assertRaises(NotImplementedError):
-                psdf.groupby("key").quantile(q=[0.1, 0.5])
-
     def test_min(self):
         self._test_stat_func(lambda groupby_obj: groupby_obj.min())
         self._test_stat_func(lambda groupby_obj: groupby_obj.min(min_count=2))
@@ -197,88 +98,6 @@ class GroupbyStatMixin:
             psdf.groupby("A").sum(min_count=3).sort_index(),
         )
 
-    def test_first(self):
-        self._test_stat_func(lambda groupby_obj: groupby_obj.first())
-        self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=None))
-        self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=True))
-
-        pdf = pd.DataFrame(
-            {
-                "A": [1, 2, 1, 2],
-                "B": [-1.5, np.nan, -3.2, 0.1],
-            }
-        )
-        psdf = ps.from_pandas(pdf)
-        self.assert_eq(
-            pdf.groupby("A").first().sort_index(), psdf.groupby("A").first().sort_index()
-        )
-        self.assert_eq(
-            pdf.groupby("A").first(min_count=1).sort_index(),
-            psdf.groupby("A").first(min_count=1).sort_index(),
-        )
-        self.assert_eq(
-            pdf.groupby("A").first(min_count=2).sort_index(),
-            psdf.groupby("A").first(min_count=2).sort_index(),
-        )
-
-    def test_last(self):
-        self._test_stat_func(lambda groupby_obj: groupby_obj.last())
-        self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=None))
-        self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=True))
-
-        pdf = pd.DataFrame(
-            {
-                "A": [1, 2, 1, 2],
-                "B": [-1.5, np.nan, -3.2, 0.1],
-            }
-        )
-        psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.groupby("A").last().sort_index(), psdf.groupby("A").last().sort_index())
-        self.assert_eq(
-            pdf.groupby("A").last(min_count=1).sort_index(),
-            psdf.groupby("A").last(min_count=1).sort_index(),
-        )
-        self.assert_eq(
-            pdf.groupby("A").last(min_count=2).sort_index(),
-            psdf.groupby("A").last(min_count=2).sort_index(),
-        )
-
-    def test_nth(self):
-        for n in [0, 1, 2, 128, -1, -2, -128]:
-            self._test_stat_func(lambda groupby_obj: groupby_obj.nth(n))
-
-        with self.assertRaisesRegex(NotImplementedError, "slice or list"):
-            self.psdf.groupby("B").nth(slice(0, 2))
-        with self.assertRaisesRegex(NotImplementedError, "slice or list"):
-            self.psdf.groupby("B").nth([0, 1, -1])
-        with self.assertRaisesRegex(TypeError, "Invalid index"):
-            self.psdf.groupby("B").nth("x")
-
-    def test_prod(self):
-        pdf = pd.DataFrame(
-            {
-                "A": [1, 2, 1, 2, 1],
-                "B": [3.1, 4.1, 4.1, 3.1, 0.1],
-                "C": ["a", "b", "b", "a", "c"],
-                "D": [True, False, False, True, False],
-                "E": [-1, -2, 3, -4, -2],
-                "F": [-1.5, np.nan, -3.2, 0.1, 0],
-                "G": [np.nan, np.nan, np.nan, np.nan, np.nan],
-            }
-        )
-        psdf = ps.from_pandas(pdf)
-
-        for n in [0, 1, 2, 128, -1, -2, -128]:
-            self._test_stat_func(
-                lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n),
-                check_exact=False,
-            )
-            self.assert_eq(
-                pdf.groupby("A").prod(min_count=n, numeric_only=True).sort_index(),
-                psdf.groupby("A").prod(min_count=n).sort_index(),
-                almost=True,
-            )
-
     def test_median(self):
         psdf = ps.DataFrame(
             {
@@ -303,54 +122,12 @@
         with self.assertRaisesRegex(TypeError, "accuracy must be an integer; however"):
             psdf.groupby("a").median(accuracy="a")
 
-    def test_ddof(self):
-        pdf = pd.DataFrame(
-            {
-                "a": [1, 1, 1, 1, 2, 2, 2, 3, 3, 3] * 3,
-                "b": [2, 3, 1, 4, 6, 9, 8, 10, 7, 5] * 3,
-                "c": [3, 5, 2, 5, 1, 2, 6, 4, 3, 6] * 3,
-            },
-            index=np.random.rand(10 * 3),
-        )
-        psdf = ps.from_pandas(pdf)
-
-        for ddof in [-1, 0, 1, 2, 3]:
-            # std
-            self.assert_eq(
-                pdf.groupby("a").std(ddof=ddof).sort_index(),
-                psdf.groupby("a").std(ddof=ddof).sort_index(),
-                check_exact=False,
-            )
-            self.assert_eq(
-                pdf.groupby("a")["b"].std(ddof=ddof).sort_index(),
-                psdf.groupby("a")["b"].std(ddof=ddof).sort_index(),
-                check_exact=False,
-            )
-            # var
-            self.assert_eq(
-                pdf.groupby("a").var(ddof=ddof).sort_index(),
-                psdf.groupby("a").var(ddof=ddof).sort_index(),
-                check_exact=False,
-            )
-            self.assert_eq(
-                pdf.groupby("a")["b"].var(ddof=ddof).sort_index(),
-                psdf.groupby("a")["b"].var(ddof=ddof).sort_index(),
-                check_exact=False,
-            )
-            # sem
-            self.assert_eq(
-                pdf.groupby("a").sem(ddof=ddof).sort_index(),
-                psdf.groupby("a").sem(ddof=ddof).sort_index(),
-                check_exact=False,
-            )
-            self.assert_eq(
-                pdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
-                psdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
-                check_exact=False,
-            )
-
-class GroupbyStatTests(GroupbyStatMixin, ComparisonTestBase, SQLTestUtils):
+class GroupbyStatTests(
+    GroupbyStatMixin,
+    PandasOnSparkTestCase,
+    SQLTestUtils,
+):
     pass
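The comment removed from `test_stat.py` above (and kept in `test_stat_func.py` below) is worth unpacking: pandas-on-Spark computes the groupby median from an approximate percentile rather than an exact sort, since an exact median over a large distributed dataset is expensive. A hedged sketch of the same idea, expressed directly with the public Spark SQL function `percentile_approx` (the actual pandas-on-Spark internals may differ):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 3.1), (2, 4.1), (1, 4.1), (2, 3.1)], ["A", "B"])

# Approximate median per group; the accuracy argument trades memory for precision.
df.groupBy("A").agg(F.percentile_approx("B", 0.5, 10000).alias("B_median")).show()
```

This approximation is also why the tests pin the expected medians as constants (`expected = ps.DataFrame(...)`) instead of comparing against pandas.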
psdf.groupby("key").quantile(q=(0.1, 0.5)) + with self.assertRaises(NotImplementedError): + psdf.groupby("key").quantile(q=[0.1, 0.5]) + + def test_first(self): + self._test_stat_func(lambda groupby_obj: groupby_obj.first()) + self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=None)) + self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=True)) + + pdf = pd.DataFrame( + { + "A": [1, 2, 1, 2], + "B": [-1.5, np.nan, -3.2, 0.1], + } + ) + psdf = ps.from_pandas(pdf) + self.assert_eq( + pdf.groupby("A").first().sort_index(), psdf.groupby("A").first().sort_index() + ) + self.assert_eq( + pdf.groupby("A").first(min_count=1).sort_index(), + psdf.groupby("A").first(min_count=1).sort_index(), + ) + self.assert_eq( + pdf.groupby("A").first(min_count=2).sort_index(), + psdf.groupby("A").first(min_count=2).sort_index(), + ) + + def test_last(self): + self._test_stat_func(lambda groupby_obj: groupby_obj.last()) + self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=None)) + self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=True)) + + pdf = pd.DataFrame( + { + "A": [1, 2, 1, 2], + "B": [-1.5, np.nan, -3.2, 0.1], + } + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf.groupby("A").last().sort_index(), psdf.groupby("A").last().sort_index()) + self.assert_eq( + pdf.groupby("A").last(min_count=1).sort_index(), + psdf.groupby("A").last(min_count=1).sort_index(), + ) + self.assert_eq( + pdf.groupby("A").last(min_count=2).sort_index(), + psdf.groupby("A").last(min_count=2).sort_index(), + ) + + def test_nth(self): + for n in [0, 1, 2, 128, -1, -2, -128]: + self._test_stat_func(lambda groupby_obj: groupby_obj.nth(n)) + + with self.assertRaisesRegex(NotImplementedError, "slice or list"): + self.psdf.groupby("B").nth(slice(0, 2)) + with self.assertRaisesRegex(NotImplementedError, "slice or list"): + self.psdf.groupby("B").nth([0, 1, -1]) + with self.assertRaisesRegex(TypeError, "Invalid index"): + self.psdf.groupby("B").nth("x") + + +class GroupbyStatAdvTests( + GroupbyStatAdvMixin, + PandasOnSparkTestCase, + SQLTestUtils, +): + pass + + +if __name__ == "__main__": + from pyspark.pandas.tests.groupby.test_stat_adv import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/groupby/test_stat_ddof.py b/python/pyspark/pandas/tests/groupby/test_stat_ddof.py new file mode 100644 index 000000000000..63e974bd69fc --- /dev/null +++ b/python/pyspark/pandas/tests/groupby/test_stat_ddof.py @@ -0,0 +1,91 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/python/pyspark/pandas/tests/groupby/test_stat_ddof.py b/python/pyspark/pandas/tests/groupby/test_stat_ddof.py
new file mode 100644
index 000000000000..63e974bd69fc
--- /dev/null
+++ b/python/pyspark/pandas/tests/groupby/test_stat_ddof.py
@@ -0,0 +1,91 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+
+
+class DdofTestsMixin:
+    def test_ddof(self):
+        pdf = pd.DataFrame(
+            {
+                "a": [1, 1, 1, 1, 2, 2, 2, 3, 3, 3] * 3,
+                "b": [2, 3, 1, 4, 6, 9, 8, 10, 7, 5] * 3,
+                "c": [3, 5, 2, 5, 1, 2, 6, 4, 3, 6] * 3,
+            },
+            index=np.random.rand(10 * 3),
+        )
+        psdf = ps.from_pandas(pdf)
+
+        for ddof in [-1, 0, 1, 2, 3]:
+            # std
+            self.assert_eq(
+                pdf.groupby("a").std(ddof=ddof).sort_index(),
+                psdf.groupby("a").std(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
+            self.assert_eq(
+                pdf.groupby("a")["b"].std(ddof=ddof).sort_index(),
+                psdf.groupby("a")["b"].std(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
+            # var
+            self.assert_eq(
+                pdf.groupby("a").var(ddof=ddof).sort_index(),
+                psdf.groupby("a").var(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
+            self.assert_eq(
+                pdf.groupby("a")["b"].var(ddof=ddof).sort_index(),
+                psdf.groupby("a")["b"].var(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
+            # sem
+            self.assert_eq(
+                pdf.groupby("a").sem(ddof=ddof).sort_index(),
+                psdf.groupby("a").sem(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
+            self.assert_eq(
+                pdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
+                psdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
+
+
+class DdofTests(
+    DdofTestsMixin,
+    PandasOnSparkTestCase,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.groupby.test_stat_ddof import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
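For readers unfamiliar with the parameter swept in `test_ddof` above: `ddof` ("delta degrees of freedom") changes the divisor of the variance from N to N - ddof. A small self-contained check of that arithmetic, using NumPy only and matching the pandas semantics these tests assume:

```python
import numpy as np

x = np.array([2.0, 3.0, 1.0, 4.0])
n = len(x)
for ddof in [0, 1, 2]:
    # Variance with divisor n - ddof; ddof=1 is the unbiased sample variance.
    manual = ((x - x.mean()) ** 2).sum() / (n - ddof)
    assert np.isclose(manual, x.var(ddof=ddof))
```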
diff --git a/python/pyspark/pandas/tests/groupby/test_stat_func.py b/python/pyspark/pandas/tests/groupby/test_stat_func.py
new file mode 100644
index 000000000000..257394b59f51
--- /dev/null
+++ b/python/pyspark/pandas/tests/groupby/test_stat_func.py
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.pandas.tests.groupby.test_stat import GroupbyStatTestingFuncMixin
+
+
+class FuncTestsMixin(GroupbyStatTestingFuncMixin):
+    @property
+    def pdf(self):
+        return pd.DataFrame(
+            {
+                "A": [1, 2, 1, 2],
+                "B": [3.1, 4.1, 4.1, 3.1],
+                "C": ["a", "b", "b", "a"],
+                "D": [True, False, False, True],
+            }
+        )
+
+    @property
+    def psdf(self):
+        return ps.from_pandas(self.pdf)
+
+    def test_basic_stat_funcs(self):
+        self._test_stat_func(
+            lambda groupby_obj: groupby_obj.var(numeric_only=True), check_exact=False
+        )
+
+        pdf, psdf = self.pdf, self.psdf
+
+        # Unlike pandas', the median in pandas-on-Spark is an approximated median based upon
+        # approximate percentile computation because computing median across a large dataset
+        # is extremely expensive.
+        expected = ps.DataFrame({"B": [3.1, 3.1], "D": [0, 0]}, index=pd.Index([1, 2], name="A"))
+        self.assert_eq(
+            psdf.groupby("A").median().sort_index(),
+            expected,
+        )
+        self.assert_eq(
+            psdf.groupby("A").median(numeric_only=None).sort_index(),
+            expected,
+        )
+        self.assert_eq(
+            psdf.groupby("A").median(numeric_only=False).sort_index(),
+            expected,
+        )
+        self.assert_eq(
+            psdf.groupby("A")["B"].median().sort_index(),
+            expected.B,
+        )
+        with self.assertRaises(TypeError):
+            psdf.groupby("A")["C"].mean()
+
+        with self.assertRaisesRegex(
+            TypeError, "Unaccepted data types of aggregation columns; numeric or bool expected."
+        ):
+            psdf.groupby("A")[["C"]].std()
+
+        with self.assertRaisesRegex(
+            TypeError, "Unaccepted data types of aggregation columns; numeric or bool expected."
+        ):
+            psdf.groupby("A")[["C"]].sem()
+
+        self.assert_eq(
+            psdf.groupby("A").std().sort_index(),
+            pdf.groupby("A").std(numeric_only=True).sort_index(),
+            check_exact=False,
+        )
+        self.assert_eq(
+            psdf.groupby("A").sem().sort_index(),
+            pdf.groupby("A").sem(numeric_only=True).sort_index(),
+            check_exact=False,
+        )
+
+        # TODO: fix bug of `sum` and re-enable the test below
+        # self._test_stat_func(lambda groupby_obj: groupby_obj.sum(), check_exact=False)
+        self.assert_eq(
+            psdf.groupby("A").sum().sort_index(),
+            pdf.groupby("A").sum().sort_index(),
+            check_exact=False,
+        )
+
+
+class FuncTests(
+    FuncTestsMixin,
+    PandasOnSparkTestCase,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.groupby.test_stat_func import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
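The `min_count` sweep in `test_prod` of the next file is easiest to read with the pandas semantics in mind: a group's product is NaN unless the group has at least `min_count` non-NA values, and an empty product falls back to 1 when `min_count=0`. A plain-pandas illustration:

```python
import numpy as np
import pandas as pd

pdf = pd.DataFrame({"A": [1, 1, 2], "B": [2.0, 3.0, np.nan]})

print(pdf.groupby("A")["B"].prod())             # group 1 -> 6.0, group 2 -> 1.0 (empty product)
print(pdf.groupby("A")["B"].prod(min_count=1))  # group 1 -> 6.0, group 2 -> NaN
print(pdf.groupby("A")["B"].prod(min_count=3))  # both NaN (fewer than 3 valid values)
```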
diff --git a/python/pyspark/pandas/tests/groupby/test_stat_prod.py b/python/pyspark/pandas/tests/groupby/test_stat_prod.py
new file mode 100644
index 000000000000..31da55d26018
--- /dev/null
+++ b/python/pyspark/pandas/tests/groupby/test_stat_prod.py
@@ -0,0 +1,86 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.pandas.tests.groupby.test_stat import GroupbyStatTestingFuncMixin
+
+
+class ProdTestsMixin(GroupbyStatTestingFuncMixin):
+    @property
+    def pdf(self):
+        return pd.DataFrame(
+            {
+                "A": [1, 2, 1, 2],
+                "B": [3.1, 4.1, 4.1, 3.1],
+                "C": ["a", "b", "b", "a"],
+                "D": [True, False, False, True],
+            }
+        )
+
+    @property
+    def psdf(self):
+        return ps.from_pandas(self.pdf)
+
+    def test_prod(self):
+        pdf = pd.DataFrame(
+            {
+                "A": [1, 2, 1, 2, 1],
+                "B": [3.1, 4.1, 4.1, 3.1, 0.1],
+                "C": ["a", "b", "b", "a", "c"],
+                "D": [True, False, False, True, False],
+                "E": [-1, -2, 3, -4, -2],
+                "F": [-1.5, np.nan, -3.2, 0.1, 0],
+                "G": [np.nan, np.nan, np.nan, np.nan, np.nan],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        for n in [0, 1, 2, 128, -1, -2, -128]:
+            self._test_stat_func(
+                lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n),
+                check_exact=False,
+            )
+            self.assert_eq(
+                pdf.groupby("A").prod(min_count=n, numeric_only=True).sort_index(),
+                psdf.groupby("A").prod(min_count=n).sort_index(),
+                almost=True,
+            )
+
+
+class ProdTests(
+    ProdTestsMixin,
+    PandasOnSparkTestCase,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.groupby.test_stat_prod import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org