This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5cca6ba82950 [SPARK-46451][PS][TESTS] Reorganize `GroupbyStatTests`
5cca6ba82950 is described below

commit 5cca6ba82950903b3ba20c7ec34cc442e3eb2e9d
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Dec 19 19:57:20 2023 +0800

    [SPARK-46451][PS][TESTS] Reorganize `GroupbyStatTests`
    
    ### What changes were proposed in this pull request?
    Reorganize `GroupbyStatTests`
    
    ### Why are the changes needed?
    `GroupbyStatTests` and its parity test are slow; factor the slow tests out to make them more suitable for parallelism.
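
    The split follows a single pattern: the shared assertion helper stays in a
    mixin (`GroupbyStatTestingFuncMixin` in `test_stat.py`), and each new module
    hosts a thin test class composing its own mixin with the base test case.
    A minimal sketch of the resulting shape (names taken from the diff below;
    it only runs once this patch is applied):

        from pyspark.testing.pandasutils import PandasOnSparkTestCase
        from pyspark.pandas.tests.groupby.test_stat import GroupbyStatTestingFuncMixin

        class ProdTestsMixin(GroupbyStatTestingFuncMixin):
            def test_prod(self):
                ...  # prod-specific assertions, see test_stat_prod.py below

        class ProdTests(ProdTestsMixin, PandasOnSparkTestCase):
            pass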
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    ci
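
    To exercise one of the split modules locally, the standard PySpark test
    runner accepts individual dotted test module names like those registered in
    `dev/sparktestsupport/modules.py`, e.g. `python/run-tests --testnames
    pyspark.pandas.tests.groupby.test_stat_prod` (runner flag assumed from the
    dev tooling, not introduced by this patch).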
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #44409 from zhengruifeng/ps_test_group_stat.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 dev/sparktestsupport/modules.py                    |  10 +-
 .../tests/connect/groupby/test_parity_stat.py      |   6 +-
 ...test_parity_stat.py => test_parity_stat_adv.py} |  10 +-
 ...est_parity_stat.py => test_parity_stat_ddof.py} |  10 +-
 ...est_parity_stat.py => test_parity_stat_func.py} |  10 +-
 ...est_parity_stat.py => test_parity_stat_prod.py} |  10 +-
 python/pyspark/pandas/tests/groupby/test_stat.py   | 263 ++-------------------
 .../pyspark/pandas/tests/groupby/test_stat_adv.py  | 161 +++++++++++++
 .../pyspark/pandas/tests/groupby/test_stat_ddof.py |  91 +++++++
 .../pyspark/pandas/tests/groupby/test_stat_func.py | 119 ++++++++++
 .../pyspark/pandas/tests/groupby/test_stat_prod.py |  86 +++++++
 11 files changed, 519 insertions(+), 257 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 38d3d42b658c..1c632871ba61 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -852,6 +852,10 @@ pyspark_pandas_slow = Module(
         "pyspark.pandas.tests.groupby.test_size",
         "pyspark.pandas.tests.groupby.test_split_apply",
         "pyspark.pandas.tests.groupby.test_stat",
+        "pyspark.pandas.tests.groupby.test_stat_adv",
+        "pyspark.pandas.tests.groupby.test_stat_ddof",
+        "pyspark.pandas.tests.groupby.test_stat_func",
+        "pyspark.pandas.tests.groupby.test_stat_prod",
         "pyspark.pandas.tests.groupby.test_value_counts",
         "pyspark.pandas.tests.test_indexing",
         "pyspark.pandas.tests.test_ops_on_diff_frames",
@@ -1060,7 +1064,6 @@ pyspark_pandas_connect_part0 = Module(
         "pyspark.pandas.tests.connect.computation.test_parity_describe",
         "pyspark.pandas.tests.connect.computation.test_parity_eval",
         "pyspark.pandas.tests.connect.computation.test_parity_melt",
-        "pyspark.pandas.tests.connect.groupby.test_parity_stat",
         "pyspark.pandas.tests.connect.frame.test_parity_attrs",
         "pyspark.pandas.tests.connect.frame.test_parity_axis",
         "pyspark.pandas.tests.connect.diff_frames_ops.test_parity_dot_frame",
@@ -1197,6 +1200,11 @@ pyspark_pandas_connect_part3 = Module(
         "pyspark.pandas.tests.connect.io.test_parity_dataframe_conversion",
         "pyspark.pandas.tests.connect.io.test_parity_dataframe_spark_io",
         "pyspark.pandas.tests.connect.io.test_parity_series_conversion",
+        "pyspark.pandas.tests.connect.groupby.test_parity_stat",
+        "pyspark.pandas.tests.connect.groupby.test_parity_stat_adv",
+        "pyspark.pandas.tests.connect.groupby.test_parity_stat_ddof",
+        "pyspark.pandas.tests.connect.groupby.test_parity_stat_func",
+        "pyspark.pandas.tests.connect.groupby.test_parity_stat_prod",
         "pyspark.pandas.tests.connect.indexes.test_parity_datetime",
         "pyspark.pandas.tests.connect.indexes.test_parity_datetime_at",
         "pyspark.pandas.tests.connect.indexes.test_parity_datetime_between",
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
index a7c2e10dc3f5..c0f64a990e6f 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
@@ -21,7 +21,11 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 
-class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GroupbyParityStatTests(
+    GroupbyStatMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+):
     pass
 
 
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_adv.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
copy to python/pyspark/pandas/tests/connect/groupby/test_parity_stat_adv.py
index a7c2e10dc3f5..2138e44263b2 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_adv.py
@@ -16,17 +16,21 @@
 #
 import unittest
 
-from pyspark.pandas.tests.groupby.test_stat import GroupbyStatMixin
+from pyspark.pandas.tests.groupby.test_stat_adv import GroupbyStatAdvMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 
-class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GroupbyStatAdvParityTests(
+    GroupbyStatAdvMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+):
     pass
 
 
 if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.groupby.test_parity_stat import *  # noqa: F401
+    from pyspark.pandas.tests.connect.groupby.test_parity_stat_adv import *  # noqa: F401
 
     try:
         import xmlrunner  # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_ddof.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
copy to python/pyspark/pandas/tests/connect/groupby/test_parity_stat_ddof.py
index a7c2e10dc3f5..3abec1d7abd7 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_ddof.py
@@ -16,17 +16,21 @@
 #
 import unittest
 
-from pyspark.pandas.tests.groupby.test_stat import GroupbyStatMixin
+from pyspark.pandas.tests.groupby.test_stat_ddof import DdofTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 
-class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GroupbyStatDdofParityTests(
+    DdofTestsMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+):
     pass
 
 
 if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.groupby.test_parity_stat import *  # noqa: F401
+    from pyspark.pandas.tests.connect.groupby.test_parity_stat_ddof import *  # noqa: F401
 
     try:
         import xmlrunner  # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_func.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
copy to python/pyspark/pandas/tests/connect/groupby/test_parity_stat_func.py
index a7c2e10dc3f5..41a7944aa49f 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_func.py
@@ -16,17 +16,21 @@
 #
 import unittest
 
-from pyspark.pandas.tests.groupby.test_stat import GroupbyStatMixin
+from pyspark.pandas.tests.groupby.test_stat_func import FuncTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 
-class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GroupbyStatFuncParityTests(
+    FuncTestsMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+):
     pass
 
 
 if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.groupby.test_parity_stat import *  # noqa: F401
+    from pyspark.pandas.tests.connect.groupby.test_parity_stat_func import *  # noqa: F401
 
     try:
         import xmlrunner  # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_prod.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
copy to python/pyspark/pandas/tests/connect/groupby/test_parity_stat_prod.py
index a7c2e10dc3f5..f0abab3a2142 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_prod.py
@@ -16,17 +16,21 @@
 #
 import unittest
 
-from pyspark.pandas.tests.groupby.test_stat import GroupbyStatMixin
+from pyspark.pandas.tests.groupby.test_stat_prod import ProdTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 
-class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GroupbyStatProdParityTests(
+    ProdTestsMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+):
     pass
 
 
 if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.groupby.test_parity_stat import *  # noqa: F401
+    from pyspark.pandas.tests.connect.groupby.test_parity_stat_prod import *  # noqa: F401
 
     try:
         import xmlrunner  # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/groupby/test_stat.py b/python/pyspark/pandas/tests/groupby/test_stat.py
index 29991ae1d54c..61d8cc5357f6 100644
--- a/python/pyspark/pandas/tests/groupby/test_stat.py
+++ b/python/pyspark/pandas/tests/groupby/test_stat.py
@@ -20,26 +20,11 @@ import numpy as np
 import pandas as pd
 
 from pyspark import pandas as ps
-from pyspark.testing.pandasutils import ComparisonTestBase
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
 
 
-class GroupbyStatMixin:
-    @property
-    def pdf(self):
-        return pd.DataFrame(
-            {
-                "A": [1, 2, 1, 2],
-                "B": [3.1, 4.1, 4.1, 3.1],
-                "C": ["a", "b", "b", "a"],
-                "D": [True, False, False, True],
-            }
-        )
-
-    @property
-    def psdf(self):
-        return ps.from_pandas(self.pdf)
-
+class GroupbyStatTestingFuncMixin:
     # TODO: All statistical functions should leverage this utility
     def _test_stat_func(self, func, check_exact=True):
         pdf, psdf = self.pdf, self.psdf
@@ -57,64 +42,22 @@ class GroupbyStatMixin:
                 check_exact=check_exact,
             )
 
-    def test_basic_stat_funcs(self):
-        self._test_stat_func(
-            lambda groupby_obj: groupby_obj.var(numeric_only=True), check_exact=False
-        )
-
-        pdf, psdf = self.pdf, self.psdf
 
-        # Unlike pandas', the median in pandas-on-Spark is an approximated median based upon
-        # approximate percentile computation because computing median across a large dataset
-        # is extremely expensive.
-        expected = ps.DataFrame({"B": [3.1, 3.1], "D": [0, 0]}, index=pd.Index([1, 2], name="A"))
-        self.assert_eq(
-            psdf.groupby("A").median().sort_index(),
-            expected,
-        )
-        self.assert_eq(
-            psdf.groupby("A").median(numeric_only=None).sort_index(),
-            expected,
-        )
-        self.assert_eq(
-            psdf.groupby("A").median(numeric_only=False).sort_index(),
-            expected,
-        )
-        self.assert_eq(
-            psdf.groupby("A")["B"].median().sort_index(),
-            expected.B,
-        )
-        with self.assertRaises(TypeError):
-            psdf.groupby("A")["C"].mean()
-
-        with self.assertRaisesRegex(
-            TypeError, "Unaccepted data types of aggregation columns; numeric 
or bool expected."
-        ):
-            psdf.groupby("A")[["C"]].std()
-
-        with self.assertRaisesRegex(
-            TypeError, "Unaccepted data types of aggregation columns; numeric 
or bool expected."
-        ):
-            psdf.groupby("A")[["C"]].sem()
-
-        self.assert_eq(
-            psdf.groupby("A").std().sort_index(),
-            pdf.groupby("A").std(numeric_only=True).sort_index(),
-            check_exact=False,
-        )
-        self.assert_eq(
-            psdf.groupby("A").sem().sort_index(),
-            pdf.groupby("A").sem(numeric_only=True).sort_index(),
-            check_exact=False,
+class GroupbyStatMixin(GroupbyStatTestingFuncMixin):
+    @property
+    def pdf(self):
+        return pd.DataFrame(
+            {
+                "A": [1, 2, 1, 2],
+                "B": [3.1, 4.1, 4.1, 3.1],
+                "C": ["a", "b", "b", "a"],
+                "D": [True, False, False, True],
+            }
         )
 
-        # TODO: fix bug of `sum` and re-enable the test below
-        # self._test_stat_func(lambda groupby_obj: groupby_obj.sum(), check_exact=False)
-        self.assert_eq(
-            psdf.groupby("A").sum().sort_index(),
-            pdf.groupby("A").sum().sort_index(),
-            check_exact=False,
-        )
+    @property
+    def psdf(self):
+        return ps.from_pandas(self.pdf)
 
     def test_mean(self):
         self._test_stat_func(lambda groupby_obj: groupby_obj.mean(numeric_only=True))
@@ -122,48 +65,6 @@ class GroupbyStatMixin:
         with self.assertRaises(TypeError):
             psdf.groupby("A")["C"].mean()
 
-    def test_quantile(self):
-        dfs = [
-            pd.DataFrame(
-                [["a", 1], ["a", 2], ["a", 3], ["b", 1], ["b", 3], ["b", 5]], 
columns=["key", "val"]
-            ),
-            pd.DataFrame(
-                [["a", True], ["a", True], ["a", False], ["b", True], ["b", 
True], ["b", False]],
-                columns=["key", "val"],
-            ),
-        ]
-        for df in dfs:
-            psdf = ps.from_pandas(df)
-            # q accept float and int between 0 and 1
-            for i in [0, 0.1, 0.5, 1]:
-                self.assert_eq(
-                    df.groupby("key").quantile(q=i, interpolation="lower"),
-                    psdf.groupby("key").quantile(q=i),
-                    almost=True,
-                )
-                self.assert_eq(
-                    df.groupby("key")["val"].quantile(q=i, 
interpolation="lower"),
-                    psdf.groupby("key")["val"].quantile(q=i),
-                    almost=True,
-                )
-            # raise ValueError when q not in [0, 1]
-            with self.assertRaises(ValueError):
-                psdf.groupby("key").quantile(q=1.1)
-            with self.assertRaises(ValueError):
-                psdf.groupby("key").quantile(q=-0.1)
-            with self.assertRaises(ValueError):
-                psdf.groupby("key").quantile(q=2)
-            with self.assertRaises(ValueError):
-                psdf.groupby("key").quantile(q=np.nan)
-            # raise TypeError when q type mismatch
-            with self.assertRaises(TypeError):
-                psdf.groupby("key").quantile(q="0.1")
-            # raise NotImplementedError when q is list like type
-            with self.assertRaises(NotImplementedError):
-                psdf.groupby("key").quantile(q=(0.1, 0.5))
-            with self.assertRaises(NotImplementedError):
-                psdf.groupby("key").quantile(q=[0.1, 0.5])
-
     def test_min(self):
         self._test_stat_func(lambda groupby_obj: groupby_obj.min())
         self._test_stat_func(lambda groupby_obj: groupby_obj.min(min_count=2))
@@ -197,88 +98,6 @@ class GroupbyStatMixin:
             psdf.groupby("A").sum(min_count=3).sort_index(),
         )
 
-    def test_first(self):
-        self._test_stat_func(lambda groupby_obj: groupby_obj.first())
-        self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=None))
-        self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=True))
-
-        pdf = pd.DataFrame(
-            {
-                "A": [1, 2, 1, 2],
-                "B": [-1.5, np.nan, -3.2, 0.1],
-            }
-        )
-        psdf = ps.from_pandas(pdf)
-        self.assert_eq(
-            pdf.groupby("A").first().sort_index(), 
psdf.groupby("A").first().sort_index()
-        )
-        self.assert_eq(
-            pdf.groupby("A").first(min_count=1).sort_index(),
-            psdf.groupby("A").first(min_count=1).sort_index(),
-        )
-        self.assert_eq(
-            pdf.groupby("A").first(min_count=2).sort_index(),
-            psdf.groupby("A").first(min_count=2).sort_index(),
-        )
-
-    def test_last(self):
-        self._test_stat_func(lambda groupby_obj: groupby_obj.last())
-        self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=None))
-        self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=True))
-
-        pdf = pd.DataFrame(
-            {
-                "A": [1, 2, 1, 2],
-                "B": [-1.5, np.nan, -3.2, 0.1],
-            }
-        )
-        psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.groupby("A").last().sort_index(), 
psdf.groupby("A").last().sort_index())
-        self.assert_eq(
-            pdf.groupby("A").last(min_count=1).sort_index(),
-            psdf.groupby("A").last(min_count=1).sort_index(),
-        )
-        self.assert_eq(
-            pdf.groupby("A").last(min_count=2).sort_index(),
-            psdf.groupby("A").last(min_count=2).sort_index(),
-        )
-
-    def test_nth(self):
-        for n in [0, 1, 2, 128, -1, -2, -128]:
-            self._test_stat_func(lambda groupby_obj: groupby_obj.nth(n))
-
-        with self.assertRaisesRegex(NotImplementedError, "slice or list"):
-            self.psdf.groupby("B").nth(slice(0, 2))
-        with self.assertRaisesRegex(NotImplementedError, "slice or list"):
-            self.psdf.groupby("B").nth([0, 1, -1])
-        with self.assertRaisesRegex(TypeError, "Invalid index"):
-            self.psdf.groupby("B").nth("x")
-
-    def test_prod(self):
-        pdf = pd.DataFrame(
-            {
-                "A": [1, 2, 1, 2, 1],
-                "B": [3.1, 4.1, 4.1, 3.1, 0.1],
-                "C": ["a", "b", "b", "a", "c"],
-                "D": [True, False, False, True, False],
-                "E": [-1, -2, 3, -4, -2],
-                "F": [-1.5, np.nan, -3.2, 0.1, 0],
-                "G": [np.nan, np.nan, np.nan, np.nan, np.nan],
-            }
-        )
-        psdf = ps.from_pandas(pdf)
-
-        for n in [0, 1, 2, 128, -1, -2, -128]:
-            self._test_stat_func(
-                lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n),
-                check_exact=False,
-            )
-            self.assert_eq(
-                pdf.groupby("A").prod(min_count=n, 
numeric_only=True).sort_index(),
-                psdf.groupby("A").prod(min_count=n).sort_index(),
-                almost=True,
-            )
-
     def test_median(self):
         psdf = ps.DataFrame(
             {
@@ -303,54 +122,12 @@ class GroupbyStatMixin:
         with self.assertRaisesRegex(TypeError, "accuracy must be an integer; 
however"):
             psdf.groupby("a").median(accuracy="a")
 
-    def test_ddof(self):
-        pdf = pd.DataFrame(
-            {
-                "a": [1, 1, 1, 1, 2, 2, 2, 3, 3, 3] * 3,
-                "b": [2, 3, 1, 4, 6, 9, 8, 10, 7, 5] * 3,
-                "c": [3, 5, 2, 5, 1, 2, 6, 4, 3, 6] * 3,
-            },
-            index=np.random.rand(10 * 3),
-        )
-        psdf = ps.from_pandas(pdf)
-
-        for ddof in [-1, 0, 1, 2, 3]:
-            # std
-            self.assert_eq(
-                pdf.groupby("a").std(ddof=ddof).sort_index(),
-                psdf.groupby("a").std(ddof=ddof).sort_index(),
-                check_exact=False,
-            )
-            self.assert_eq(
-                pdf.groupby("a")["b"].std(ddof=ddof).sort_index(),
-                psdf.groupby("a")["b"].std(ddof=ddof).sort_index(),
-                check_exact=False,
-            )
-            # var
-            self.assert_eq(
-                pdf.groupby("a").var(ddof=ddof).sort_index(),
-                psdf.groupby("a").var(ddof=ddof).sort_index(),
-                check_exact=False,
-            )
-            self.assert_eq(
-                pdf.groupby("a")["b"].var(ddof=ddof).sort_index(),
-                psdf.groupby("a")["b"].var(ddof=ddof).sort_index(),
-                check_exact=False,
-            )
-            # sem
-            self.assert_eq(
-                pdf.groupby("a").sem(ddof=ddof).sort_index(),
-                psdf.groupby("a").sem(ddof=ddof).sort_index(),
-                check_exact=False,
-            )
-            self.assert_eq(
-                pdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
-                psdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
-                check_exact=False,
-            )
-
 
-class GroupbyStatTests(GroupbyStatMixin, ComparisonTestBase, SQLTestUtils):
+class GroupbyStatTests(
+    GroupbyStatMixin,
+    PandasOnSparkTestCase,
+    SQLTestUtils,
+):
     pass
 
 
diff --git a/python/pyspark/pandas/tests/groupby/test_stat_adv.py b/python/pyspark/pandas/tests/groupby/test_stat_adv.py
new file mode 100644
index 000000000000..2b124ada85b4
--- /dev/null
+++ b/python/pyspark/pandas/tests/groupby/test_stat_adv.py
@@ -0,0 +1,161 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.pandas.tests.groupby.test_stat import GroupbyStatTestingFuncMixin
+
+
+class GroupbyStatAdvMixin(GroupbyStatTestingFuncMixin):
+    @property
+    def pdf(self):
+        return pd.DataFrame(
+            {
+                "A": [1, 2, 1, 2],
+                "B": [3.1, 4.1, 4.1, 3.1],
+                "C": ["a", "b", "b", "a"],
+                "D": [True, False, False, True],
+            }
+        )
+
+    @property
+    def psdf(self):
+        return ps.from_pandas(self.pdf)
+
+    def test_quantile(self):
+        dfs = [
+            pd.DataFrame(
+                [["a", 1], ["a", 2], ["a", 3], ["b", 1], ["b", 3], ["b", 5]], 
columns=["key", "val"]
+            ),
+            pd.DataFrame(
+                [["a", True], ["a", True], ["a", False], ["b", True], ["b", 
True], ["b", False]],
+                columns=["key", "val"],
+            ),
+        ]
+        for df in dfs:
+            psdf = ps.from_pandas(df)
+            # q accept float and int between 0 and 1
+            for i in [0, 0.1, 0.5, 1]:
+                self.assert_eq(
+                    df.groupby("key").quantile(q=i, interpolation="lower"),
+                    psdf.groupby("key").quantile(q=i),
+                    almost=True,
+                )
+                self.assert_eq(
+                    df.groupby("key")["val"].quantile(q=i, 
interpolation="lower"),
+                    psdf.groupby("key")["val"].quantile(q=i),
+                    almost=True,
+                )
+            # raise ValueError when q not in [0, 1]
+            with self.assertRaises(ValueError):
+                psdf.groupby("key").quantile(q=1.1)
+            with self.assertRaises(ValueError):
+                psdf.groupby("key").quantile(q=-0.1)
+            with self.assertRaises(ValueError):
+                psdf.groupby("key").quantile(q=2)
+            with self.assertRaises(ValueError):
+                psdf.groupby("key").quantile(q=np.nan)
+            # raise TypeError when q type mismatch
+            with self.assertRaises(TypeError):
+                psdf.groupby("key").quantile(q="0.1")
+            # raise NotImplementedError when q is list like type
+            with self.assertRaises(NotImplementedError):
+                psdf.groupby("key").quantile(q=(0.1, 0.5))
+            with self.assertRaises(NotImplementedError):
+                psdf.groupby("key").quantile(q=[0.1, 0.5])
+
+    def test_first(self):
+        self._test_stat_func(lambda groupby_obj: groupby_obj.first())
+        self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=None))
+        self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=True))
+
+        pdf = pd.DataFrame(
+            {
+                "A": [1, 2, 1, 2],
+                "B": [-1.5, np.nan, -3.2, 0.1],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+        self.assert_eq(
+            pdf.groupby("A").first().sort_index(), 
psdf.groupby("A").first().sort_index()
+        )
+        self.assert_eq(
+            pdf.groupby("A").first(min_count=1).sort_index(),
+            psdf.groupby("A").first(min_count=1).sort_index(),
+        )
+        self.assert_eq(
+            pdf.groupby("A").first(min_count=2).sort_index(),
+            psdf.groupby("A").first(min_count=2).sort_index(),
+        )
+
+    def test_last(self):
+        self._test_stat_func(lambda groupby_obj: groupby_obj.last())
+        self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=None))
+        self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=True))
+
+        pdf = pd.DataFrame(
+            {
+                "A": [1, 2, 1, 2],
+                "B": [-1.5, np.nan, -3.2, 0.1],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+        self.assert_eq(pdf.groupby("A").last().sort_index(), 
psdf.groupby("A").last().sort_index())
+        self.assert_eq(
+            pdf.groupby("A").last(min_count=1).sort_index(),
+            psdf.groupby("A").last(min_count=1).sort_index(),
+        )
+        self.assert_eq(
+            pdf.groupby("A").last(min_count=2).sort_index(),
+            psdf.groupby("A").last(min_count=2).sort_index(),
+        )
+
+    def test_nth(self):
+        for n in [0, 1, 2, 128, -1, -2, -128]:
+            self._test_stat_func(lambda groupby_obj: groupby_obj.nth(n))
+
+        with self.assertRaisesRegex(NotImplementedError, "slice or list"):
+            self.psdf.groupby("B").nth(slice(0, 2))
+        with self.assertRaisesRegex(NotImplementedError, "slice or list"):
+            self.psdf.groupby("B").nth([0, 1, -1])
+        with self.assertRaisesRegex(TypeError, "Invalid index"):
+            self.psdf.groupby("B").nth("x")
+
+
+class GroupbyStatAdvTests(
+    GroupbyStatAdvMixin,
+    PandasOnSparkTestCase,
+    SQLTestUtils,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.groupby.test_stat_adv import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/groupby/test_stat_ddof.py b/python/pyspark/pandas/tests/groupby/test_stat_ddof.py
new file mode 100644
index 000000000000..63e974bd69fc
--- /dev/null
+++ b/python/pyspark/pandas/tests/groupby/test_stat_ddof.py
@@ -0,0 +1,91 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+
+
+class DdofTestsMixin:
+    def test_ddof(self):
+        pdf = pd.DataFrame(
+            {
+                "a": [1, 1, 1, 1, 2, 2, 2, 3, 3, 3] * 3,
+                "b": [2, 3, 1, 4, 6, 9, 8, 10, 7, 5] * 3,
+                "c": [3, 5, 2, 5, 1, 2, 6, 4, 3, 6] * 3,
+            },
+            index=np.random.rand(10 * 3),
+        )
+        psdf = ps.from_pandas(pdf)
+
+        for ddof in [-1, 0, 1, 2, 3]:
+            # std
+            self.assert_eq(
+                pdf.groupby("a").std(ddof=ddof).sort_index(),
+                psdf.groupby("a").std(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
+            self.assert_eq(
+                pdf.groupby("a")["b"].std(ddof=ddof).sort_index(),
+                psdf.groupby("a")["b"].std(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
+            # var
+            self.assert_eq(
+                pdf.groupby("a").var(ddof=ddof).sort_index(),
+                psdf.groupby("a").var(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
+            self.assert_eq(
+                pdf.groupby("a")["b"].var(ddof=ddof).sort_index(),
+                psdf.groupby("a")["b"].var(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
+            # sem
+            self.assert_eq(
+                pdf.groupby("a").sem(ddof=ddof).sort_index(),
+                psdf.groupby("a").sem(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
+            self.assert_eq(
+                pdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
+                psdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
+
+
+class DdofTests(
+    DdofTestsMixin,
+    PandasOnSparkTestCase,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.groupby.test_stat_ddof import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/groupby/test_stat_func.py b/python/pyspark/pandas/tests/groupby/test_stat_func.py
new file mode 100644
index 000000000000..257394b59f51
--- /dev/null
+++ b/python/pyspark/pandas/tests/groupby/test_stat_func.py
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.pandas.tests.groupby.test_stat import GroupbyStatTestingFuncMixin
+
+
+class FuncTestsMixin(GroupbyStatTestingFuncMixin):
+    @property
+    def pdf(self):
+        return pd.DataFrame(
+            {
+                "A": [1, 2, 1, 2],
+                "B": [3.1, 4.1, 4.1, 3.1],
+                "C": ["a", "b", "b", "a"],
+                "D": [True, False, False, True],
+            }
+        )
+
+    @property
+    def psdf(self):
+        return ps.from_pandas(self.pdf)
+
+    def test_basic_stat_funcs(self):
+        self._test_stat_func(
+            lambda groupby_obj: groupby_obj.var(numeric_only=True), check_exact=False
+        )
+
+        pdf, psdf = self.pdf, self.psdf
+
+        # Unlike pandas', the median in pandas-on-Spark is an approximated median based upon
+        # approximate percentile computation because computing median across a large dataset
+        # is extremely expensive.
+        expected = ps.DataFrame({"B": [3.1, 3.1], "D": [0, 0]}, index=pd.Index([1, 2], name="A"))
+        self.assert_eq(
+            psdf.groupby("A").median().sort_index(),
+            expected,
+        )
+        self.assert_eq(
+            psdf.groupby("A").median(numeric_only=None).sort_index(),
+            expected,
+        )
+        self.assert_eq(
+            psdf.groupby("A").median(numeric_only=False).sort_index(),
+            expected,
+        )
+        self.assert_eq(
+            psdf.groupby("A")["B"].median().sort_index(),
+            expected.B,
+        )
+        with self.assertRaises(TypeError):
+            psdf.groupby("A")["C"].mean()
+
+        with self.assertRaisesRegex(
+            TypeError, "Unaccepted data types of aggregation columns; numeric 
or bool expected."
+        ):
+            psdf.groupby("A")[["C"]].std()
+
+        with self.assertRaisesRegex(
+            TypeError, "Unaccepted data types of aggregation columns; numeric 
or bool expected."
+        ):
+            psdf.groupby("A")[["C"]].sem()
+
+        self.assert_eq(
+            psdf.groupby("A").std().sort_index(),
+            pdf.groupby("A").std(numeric_only=True).sort_index(),
+            check_exact=False,
+        )
+        self.assert_eq(
+            psdf.groupby("A").sem().sort_index(),
+            pdf.groupby("A").sem(numeric_only=True).sort_index(),
+            check_exact=False,
+        )
+
+        # TODO: fix bug of `sum` and re-enable the test below
+        # self._test_stat_func(lambda groupby_obj: groupby_obj.sum(), check_exact=False)
+        self.assert_eq(
+            psdf.groupby("A").sum().sort_index(),
+            pdf.groupby("A").sum().sort_index(),
+            check_exact=False,
+        )
+
+
+class FuncTests(
+    FuncTestsMixin,
+    PandasOnSparkTestCase,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.groupby.test_stat_func import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/groupby/test_stat_prod.py b/python/pyspark/pandas/tests/groupby/test_stat_prod.py
new file mode 100644
index 000000000000..31da55d26018
--- /dev/null
+++ b/python/pyspark/pandas/tests/groupby/test_stat_prod.py
@@ -0,0 +1,86 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.pandas.tests.groupby.test_stat import GroupbyStatTestingFuncMixin
+
+
+class ProdTestsMixin(GroupbyStatTestingFuncMixin):
+    @property
+    def pdf(self):
+        return pd.DataFrame(
+            {
+                "A": [1, 2, 1, 2],
+                "B": [3.1, 4.1, 4.1, 3.1],
+                "C": ["a", "b", "b", "a"],
+                "D": [True, False, False, True],
+            }
+        )
+
+    @property
+    def psdf(self):
+        return ps.from_pandas(self.pdf)
+
+    def test_prod(self):
+        pdf = pd.DataFrame(
+            {
+                "A": [1, 2, 1, 2, 1],
+                "B": [3.1, 4.1, 4.1, 3.1, 0.1],
+                "C": ["a", "b", "b", "a", "c"],
+                "D": [True, False, False, True, False],
+                "E": [-1, -2, 3, -4, -2],
+                "F": [-1.5, np.nan, -3.2, 0.1, 0],
+                "G": [np.nan, np.nan, np.nan, np.nan, np.nan],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        for n in [0, 1, 2, 128, -1, -2, -128]:
+            self._test_stat_func(
+                lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n),
+                check_exact=False,
+            )
+            self.assert_eq(
+                pdf.groupby("A").prod(min_count=n, 
numeric_only=True).sort_index(),
+                psdf.groupby("A").prod(min_count=n).sort_index(),
+                almost=True,
+            )
+
+
+class ProdTests(
+    ProdTestsMixin,
+    PandasOnSparkTestCase,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.groupby.test_stat_prod import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

