This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5cca6ba82950 [SPARK-46451][PS][TESTS] Reorganize `GroupbyStatTests`
5cca6ba82950 is described below
commit 5cca6ba82950903b3ba20c7ec34cc442e3eb2e9d
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Dec 19 19:57:20 2023 +0800
[SPARK-46451][PS][TESTS] Reorganize `GroupbyStatTests`
### What changes were proposed in this pull request?
Reorganize `GroupbyStatTests`
### Why are the changes needed?
`GroupbyStatTests` and its parity test are slow, so factor the slow tests out into separate modules to make them more suitable for parallel execution (the mixin pattern this relies on is sketched below).
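For illustration, a minimal sketch of the mixin pattern these test modules follow, with hypothetical names (not part of this change): the test bodies live in a mixin, and each small module pairs that mixin with a base test case, so the same tests can run both classically and, in the parity modules, against Spark Connect.

```python
# Hypothetical sketch of the mixin-based test split; names are illustrative.
import unittest


class StatProdTestsMixin:
    """Holds only the test methods; no TestCase machinery of its own."""

    def test_prod(self):
        data = [1, 2, 3, 4]
        result = 1
        for x in data:
            result *= x
        self.assertEqual(result, 24)


# Classic test module: the mixin paired with unittest.TestCase.
class StatProdTests(StatProdTestsMixin, unittest.TestCase):
    pass


# Parity test module: the same mixin paired with a different base class
# (ReusedConnectTestCase in Spark) re-runs the identical tests against
# Spark Connect without duplicating their bodies.
class StatProdParityTests(StatProdTestsMixin, unittest.TestCase):
    pass


if __name__ == "__main__":
    unittest.main()
```

Because each extracted module is registered individually in `dev/sparktestsupport/modules.py`, the test scheduler can run the modules as independent units in parallel instead of one long serial job.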
### Does this PR introduce _any_ user-facing change?
no, test-only
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44409 from zhengruifeng/ps_test_group_stat.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 10 +-
.../tests/connect/groupby/test_parity_stat.py | 6 +-
...test_parity_stat.py => test_parity_stat_adv.py} | 10 +-
...est_parity_stat.py => test_parity_stat_ddof.py} | 10 +-
...est_parity_stat.py => test_parity_stat_func.py} | 10 +-
...est_parity_stat.py => test_parity_stat_prod.py} | 10 +-
python/pyspark/pandas/tests/groupby/test_stat.py | 263 ++-------------------
.../pyspark/pandas/tests/groupby/test_stat_adv.py | 161 +++++++++++++
.../pyspark/pandas/tests/groupby/test_stat_ddof.py | 91 +++++++
.../pyspark/pandas/tests/groupby/test_stat_func.py | 119 ++++++++++
.../pyspark/pandas/tests/groupby/test_stat_prod.py | 86 +++++++
11 files changed, 519 insertions(+), 257 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 38d3d42b658c..1c632871ba61 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -852,6 +852,10 @@ pyspark_pandas_slow = Module(
"pyspark.pandas.tests.groupby.test_size",
"pyspark.pandas.tests.groupby.test_split_apply",
"pyspark.pandas.tests.groupby.test_stat",
+ "pyspark.pandas.tests.groupby.test_stat_adv",
+ "pyspark.pandas.tests.groupby.test_stat_ddof",
+ "pyspark.pandas.tests.groupby.test_stat_func",
+ "pyspark.pandas.tests.groupby.test_stat_prod",
"pyspark.pandas.tests.groupby.test_value_counts",
"pyspark.pandas.tests.test_indexing",
"pyspark.pandas.tests.test_ops_on_diff_frames",
@@ -1060,7 +1064,6 @@ pyspark_pandas_connect_part0 = Module(
"pyspark.pandas.tests.connect.computation.test_parity_describe",
"pyspark.pandas.tests.connect.computation.test_parity_eval",
"pyspark.pandas.tests.connect.computation.test_parity_melt",
- "pyspark.pandas.tests.connect.groupby.test_parity_stat",
"pyspark.pandas.tests.connect.frame.test_parity_attrs",
"pyspark.pandas.tests.connect.frame.test_parity_axis",
"pyspark.pandas.tests.connect.diff_frames_ops.test_parity_dot_frame",
@@ -1197,6 +1200,11 @@ pyspark_pandas_connect_part3 = Module(
"pyspark.pandas.tests.connect.io.test_parity_dataframe_conversion",
"pyspark.pandas.tests.connect.io.test_parity_dataframe_spark_io",
"pyspark.pandas.tests.connect.io.test_parity_series_conversion",
+ "pyspark.pandas.tests.connect.groupby.test_parity_stat",
+ "pyspark.pandas.tests.connect.groupby.test_parity_stat_adv",
+ "pyspark.pandas.tests.connect.groupby.test_parity_stat_ddof",
+ "pyspark.pandas.tests.connect.groupby.test_parity_stat_func",
+ "pyspark.pandas.tests.connect.groupby.test_parity_stat_prod",
"pyspark.pandas.tests.connect.indexes.test_parity_datetime",
"pyspark.pandas.tests.connect.indexes.test_parity_datetime_at",
"pyspark.pandas.tests.connect.indexes.test_parity_datetime_between",
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
index a7c2e10dc3f5..c0f64a990e6f 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
@@ -21,7 +21,11 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GroupbyParityStatTests(
+ GroupbyStatMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
pass
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_adv.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
copy to python/pyspark/pandas/tests/connect/groupby/test_parity_stat_adv.py
index a7c2e10dc3f5..2138e44263b2 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_adv.py
@@ -16,17 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.groupby.test_stat import GroupbyStatMixin
+from pyspark.pandas.tests.groupby.test_stat_adv import GroupbyStatAdvMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GroupbyStatAdvParityTests(
+ GroupbyStatAdvMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
pass
if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.groupby.test_parity_stat import *  # noqa: F401
+    from pyspark.pandas.tests.connect.groupby.test_parity_stat_adv import *  # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_ddof.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
copy to python/pyspark/pandas/tests/connect/groupby/test_parity_stat_ddof.py
index a7c2e10dc3f5..3abec1d7abd7 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_ddof.py
@@ -16,17 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.groupby.test_stat import GroupbyStatMixin
+from pyspark.pandas.tests.groupby.test_stat_ddof import DdofTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GroupbyStatDdofParityTests(
+ DdofTestsMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
pass
if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.groupby.test_parity_stat import *  # noqa: F401
+    from pyspark.pandas.tests.connect.groupby.test_parity_stat_ddof import *  # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_func.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
copy to python/pyspark/pandas/tests/connect/groupby/test_parity_stat_func.py
index a7c2e10dc3f5..41a7944aa49f 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_func.py
@@ -16,17 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.groupby.test_stat import GroupbyStatMixin
+from pyspark.pandas.tests.groupby.test_stat_func import FuncTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GroupbyStatFuncParityTests(
+ FuncTestsMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
pass
if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.groupby.test_parity_stat import *  # noqa: F401
+    from pyspark.pandas.tests.connect.groupby.test_parity_stat_func import *  # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_prod.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
copy to python/pyspark/pandas/tests/connect/groupby/test_parity_stat_prod.py
index a7c2e10dc3f5..f0abab3a2142 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat_prod.py
@@ -16,17 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.groupby.test_stat import GroupbyStatMixin
+from pyspark.pandas.tests.groupby.test_stat_prod import ProdTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GroupbyStatProdParityTests(
+ ProdTestsMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
pass
if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.groupby.test_parity_stat import *  # noqa: F401
+    from pyspark.pandas.tests.connect.groupby.test_parity_stat_prod import *  # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/groupby/test_stat.py b/python/pyspark/pandas/tests/groupby/test_stat.py
index 29991ae1d54c..61d8cc5357f6 100644
--- a/python/pyspark/pandas/tests/groupby/test_stat.py
+++ b/python/pyspark/pandas/tests/groupby/test_stat.py
@@ -20,26 +20,11 @@ import numpy as np
import pandas as pd
from pyspark import pandas as ps
-from pyspark.testing.pandasutils import ComparisonTestBase
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.sqlutils import SQLTestUtils
-class GroupbyStatMixin:
- @property
- def pdf(self):
- return pd.DataFrame(
- {
- "A": [1, 2, 1, 2],
- "B": [3.1, 4.1, 4.1, 3.1],
- "C": ["a", "b", "b", "a"],
- "D": [True, False, False, True],
- }
- )
-
- @property
- def psdf(self):
- return ps.from_pandas(self.pdf)
-
+class GroupbyStatTestingFuncMixin:
# TODO: All statistical functions should leverage this utility
def _test_stat_func(self, func, check_exact=True):
pdf, psdf = self.pdf, self.psdf
@@ -57,64 +42,22 @@ class GroupbyStatMixin:
check_exact=check_exact,
)
- def test_basic_stat_funcs(self):
- self._test_stat_func(
-            lambda groupby_obj: groupby_obj.var(numeric_only=True), check_exact=False
- )
-
- pdf, psdf = self.pdf, self.psdf
-        # Unlike pandas', the median in pandas-on-Spark is an approximated median based upon
-        # approximate percentile computation because computing median across a large dataset
-        # is extremely expensive.
-        expected = ps.DataFrame({"B": [3.1, 3.1], "D": [0, 0]}, index=pd.Index([1, 2], name="A"))
- self.assert_eq(
- psdf.groupby("A").median().sort_index(),
- expected,
- )
- self.assert_eq(
- psdf.groupby("A").median(numeric_only=None).sort_index(),
- expected,
- )
- self.assert_eq(
- psdf.groupby("A").median(numeric_only=False).sort_index(),
- expected,
- )
- self.assert_eq(
- psdf.groupby("A")["B"].median().sort_index(),
- expected.B,
- )
- with self.assertRaises(TypeError):
- psdf.groupby("A")["C"].mean()
-
- with self.assertRaisesRegex(
-            TypeError, "Unaccepted data types of aggregation columns; numeric or bool expected."
- ):
- psdf.groupby("A")[["C"]].std()
-
- with self.assertRaisesRegex(
-            TypeError, "Unaccepted data types of aggregation columns; numeric or bool expected."
- ):
- psdf.groupby("A")[["C"]].sem()
-
- self.assert_eq(
- psdf.groupby("A").std().sort_index(),
- pdf.groupby("A").std(numeric_only=True).sort_index(),
- check_exact=False,
- )
- self.assert_eq(
- psdf.groupby("A").sem().sort_index(),
- pdf.groupby("A").sem(numeric_only=True).sort_index(),
- check_exact=False,
+class GroupbyStatMixin(GroupbyStatTestingFuncMixin):
+ @property
+ def pdf(self):
+ return pd.DataFrame(
+ {
+ "A": [1, 2, 1, 2],
+ "B": [3.1, 4.1, 4.1, 3.1],
+ "C": ["a", "b", "b", "a"],
+ "D": [True, False, False, True],
+ }
)
- # TODO: fix bug of `sum` and re-enable the test below
-        # self._test_stat_func(lambda groupby_obj: groupby_obj.sum(), check_exact=False)
- self.assert_eq(
- psdf.groupby("A").sum().sort_index(),
- pdf.groupby("A").sum().sort_index(),
- check_exact=False,
- )
+ @property
+ def psdf(self):
+ return ps.from_pandas(self.pdf)
def test_mean(self):
         self._test_stat_func(lambda groupby_obj: groupby_obj.mean(numeric_only=True))
@@ -122,48 +65,6 @@ class GroupbyStatMixin:
with self.assertRaises(TypeError):
psdf.groupby("A")["C"].mean()
- def test_quantile(self):
- dfs = [
- pd.DataFrame(
-                [["a", 1], ["a", 2], ["a", 3], ["b", 1], ["b", 3], ["b", 5]], columns=["key", "val"]
- ),
- pd.DataFrame(
-                [["a", True], ["a", True], ["a", False], ["b", True], ["b", True], ["b", False]],
- columns=["key", "val"],
- ),
- ]
- for df in dfs:
- psdf = ps.from_pandas(df)
- # q accept float and int between 0 and 1
- for i in [0, 0.1, 0.5, 1]:
- self.assert_eq(
- df.groupby("key").quantile(q=i, interpolation="lower"),
- psdf.groupby("key").quantile(q=i),
- almost=True,
- )
- self.assert_eq(
-                    df.groupby("key")["val"].quantile(q=i, interpolation="lower"),
- psdf.groupby("key")["val"].quantile(q=i),
- almost=True,
- )
- # raise ValueError when q not in [0, 1]
- with self.assertRaises(ValueError):
- psdf.groupby("key").quantile(q=1.1)
- with self.assertRaises(ValueError):
- psdf.groupby("key").quantile(q=-0.1)
- with self.assertRaises(ValueError):
- psdf.groupby("key").quantile(q=2)
- with self.assertRaises(ValueError):
- psdf.groupby("key").quantile(q=np.nan)
- # raise TypeError when q type mismatch
- with self.assertRaises(TypeError):
- psdf.groupby("key").quantile(q="0.1")
- # raise NotImplementedError when q is list like type
- with self.assertRaises(NotImplementedError):
- psdf.groupby("key").quantile(q=(0.1, 0.5))
- with self.assertRaises(NotImplementedError):
- psdf.groupby("key").quantile(q=[0.1, 0.5])
-
def test_min(self):
self._test_stat_func(lambda groupby_obj: groupby_obj.min())
self._test_stat_func(lambda groupby_obj: groupby_obj.min(min_count=2))
@@ -197,88 +98,6 @@ class GroupbyStatMixin:
psdf.groupby("A").sum(min_count=3).sort_index(),
)
- def test_first(self):
- self._test_stat_func(lambda groupby_obj: groupby_obj.first())
-        self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=None))
-        self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=True))
-
- pdf = pd.DataFrame(
- {
- "A": [1, 2, 1, 2],
- "B": [-1.5, np.nan, -3.2, 0.1],
- }
- )
- psdf = ps.from_pandas(pdf)
- self.assert_eq(
-            pdf.groupby("A").first().sort_index(), psdf.groupby("A").first().sort_index()
- )
- self.assert_eq(
- pdf.groupby("A").first(min_count=1).sort_index(),
- psdf.groupby("A").first(min_count=1).sort_index(),
- )
- self.assert_eq(
- pdf.groupby("A").first(min_count=2).sort_index(),
- psdf.groupby("A").first(min_count=2).sort_index(),
- )
-
- def test_last(self):
- self._test_stat_func(lambda groupby_obj: groupby_obj.last())
-        self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=None))
-        self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=True))
-
- pdf = pd.DataFrame(
- {
- "A": [1, 2, 1, 2],
- "B": [-1.5, np.nan, -3.2, 0.1],
- }
- )
- psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.groupby("A").last().sort_index(), psdf.groupby("A").last().sort_index())
- self.assert_eq(
- pdf.groupby("A").last(min_count=1).sort_index(),
- psdf.groupby("A").last(min_count=1).sort_index(),
- )
- self.assert_eq(
- pdf.groupby("A").last(min_count=2).sort_index(),
- psdf.groupby("A").last(min_count=2).sort_index(),
- )
-
- def test_nth(self):
- for n in [0, 1, 2, 128, -1, -2, -128]:
- self._test_stat_func(lambda groupby_obj: groupby_obj.nth(n))
-
- with self.assertRaisesRegex(NotImplementedError, "slice or list"):
- self.psdf.groupby("B").nth(slice(0, 2))
- with self.assertRaisesRegex(NotImplementedError, "slice or list"):
- self.psdf.groupby("B").nth([0, 1, -1])
- with self.assertRaisesRegex(TypeError, "Invalid index"):
- self.psdf.groupby("B").nth("x")
-
- def test_prod(self):
- pdf = pd.DataFrame(
- {
- "A": [1, 2, 1, 2, 1],
- "B": [3.1, 4.1, 4.1, 3.1, 0.1],
- "C": ["a", "b", "b", "a", "c"],
- "D": [True, False, False, True, False],
- "E": [-1, -2, 3, -4, -2],
- "F": [-1.5, np.nan, -3.2, 0.1, 0],
- "G": [np.nan, np.nan, np.nan, np.nan, np.nan],
- }
- )
- psdf = ps.from_pandas(pdf)
-
- for n in [0, 1, 2, 128, -1, -2, -128]:
- self._test_stat_func(
-                lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n),
- check_exact=False,
- )
- self.assert_eq(
-                pdf.groupby("A").prod(min_count=n, numeric_only=True).sort_index(),
- psdf.groupby("A").prod(min_count=n).sort_index(),
- almost=True,
- )
-
def test_median(self):
psdf = ps.DataFrame(
{
@@ -303,54 +122,12 @@ class GroupbyStatMixin:
         with self.assertRaisesRegex(TypeError, "accuracy must be an integer; however"):
psdf.groupby("a").median(accuracy="a")
- def test_ddof(self):
- pdf = pd.DataFrame(
- {
- "a": [1, 1, 1, 1, 2, 2, 2, 3, 3, 3] * 3,
- "b": [2, 3, 1, 4, 6, 9, 8, 10, 7, 5] * 3,
- "c": [3, 5, 2, 5, 1, 2, 6, 4, 3, 6] * 3,
- },
- index=np.random.rand(10 * 3),
- )
- psdf = ps.from_pandas(pdf)
-
- for ddof in [-1, 0, 1, 2, 3]:
- # std
- self.assert_eq(
- pdf.groupby("a").std(ddof=ddof).sort_index(),
- psdf.groupby("a").std(ddof=ddof).sort_index(),
- check_exact=False,
- )
- self.assert_eq(
- pdf.groupby("a")["b"].std(ddof=ddof).sort_index(),
- psdf.groupby("a")["b"].std(ddof=ddof).sort_index(),
- check_exact=False,
- )
- # var
- self.assert_eq(
- pdf.groupby("a").var(ddof=ddof).sort_index(),
- psdf.groupby("a").var(ddof=ddof).sort_index(),
- check_exact=False,
- )
- self.assert_eq(
- pdf.groupby("a")["b"].var(ddof=ddof).sort_index(),
- psdf.groupby("a")["b"].var(ddof=ddof).sort_index(),
- check_exact=False,
- )
- # sem
- self.assert_eq(
- pdf.groupby("a").sem(ddof=ddof).sort_index(),
- psdf.groupby("a").sem(ddof=ddof).sort_index(),
- check_exact=False,
- )
- self.assert_eq(
- pdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
- psdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
- check_exact=False,
- )
-
-class GroupbyStatTests(GroupbyStatMixin, ComparisonTestBase, SQLTestUtils):
+class GroupbyStatTests(
+ GroupbyStatMixin,
+ PandasOnSparkTestCase,
+ SQLTestUtils,
+):
pass
diff --git a/python/pyspark/pandas/tests/groupby/test_stat_adv.py b/python/pyspark/pandas/tests/groupby/test_stat_adv.py
new file mode 100644
index 000000000000..2b124ada85b4
--- /dev/null
+++ b/python/pyspark/pandas/tests/groupby/test_stat_adv.py
@@ -0,0 +1,161 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.pandas.tests.groupby.test_stat import GroupbyStatTestingFuncMixin
+
+
+class GroupbyStatAdvMixin(GroupbyStatTestingFuncMixin):
+ @property
+ def pdf(self):
+ return pd.DataFrame(
+ {
+ "A": [1, 2, 1, 2],
+ "B": [3.1, 4.1, 4.1, 3.1],
+ "C": ["a", "b", "b", "a"],
+ "D": [True, False, False, True],
+ }
+ )
+
+ @property
+ def psdf(self):
+ return ps.from_pandas(self.pdf)
+
+ def test_quantile(self):
+ dfs = [
+ pd.DataFrame(
+                [["a", 1], ["a", 2], ["a", 3], ["b", 1], ["b", 3], ["b", 5]], columns=["key", "val"]
+ ),
+ pd.DataFrame(
+                [["a", True], ["a", True], ["a", False], ["b", True], ["b", True], ["b", False]],
+ columns=["key", "val"],
+ ),
+ ]
+ for df in dfs:
+ psdf = ps.from_pandas(df)
+ # q accept float and int between 0 and 1
+ for i in [0, 0.1, 0.5, 1]:
+ self.assert_eq(
+ df.groupby("key").quantile(q=i, interpolation="lower"),
+ psdf.groupby("key").quantile(q=i),
+ almost=True,
+ )
+ self.assert_eq(
+                    df.groupby("key")["val"].quantile(q=i, interpolation="lower"),
+ psdf.groupby("key")["val"].quantile(q=i),
+ almost=True,
+ )
+ # raise ValueError when q not in [0, 1]
+ with self.assertRaises(ValueError):
+ psdf.groupby("key").quantile(q=1.1)
+ with self.assertRaises(ValueError):
+ psdf.groupby("key").quantile(q=-0.1)
+ with self.assertRaises(ValueError):
+ psdf.groupby("key").quantile(q=2)
+ with self.assertRaises(ValueError):
+ psdf.groupby("key").quantile(q=np.nan)
+ # raise TypeError when q type mismatch
+ with self.assertRaises(TypeError):
+ psdf.groupby("key").quantile(q="0.1")
+ # raise NotImplementedError when q is list like type
+ with self.assertRaises(NotImplementedError):
+ psdf.groupby("key").quantile(q=(0.1, 0.5))
+ with self.assertRaises(NotImplementedError):
+ psdf.groupby("key").quantile(q=[0.1, 0.5])
+
+ def test_first(self):
+ self._test_stat_func(lambda groupby_obj: groupby_obj.first())
+        self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=None))
+        self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=True))
+
+ pdf = pd.DataFrame(
+ {
+ "A": [1, 2, 1, 2],
+ "B": [-1.5, np.nan, -3.2, 0.1],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+ self.assert_eq(
+            pdf.groupby("A").first().sort_index(), psdf.groupby("A").first().sort_index()
+ )
+ self.assert_eq(
+ pdf.groupby("A").first(min_count=1).sort_index(),
+ psdf.groupby("A").first(min_count=1).sort_index(),
+ )
+ self.assert_eq(
+ pdf.groupby("A").first(min_count=2).sort_index(),
+ psdf.groupby("A").first(min_count=2).sort_index(),
+ )
+
+ def test_last(self):
+ self._test_stat_func(lambda groupby_obj: groupby_obj.last())
+        self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=None))
+        self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=True))
+
+ pdf = pd.DataFrame(
+ {
+ "A": [1, 2, 1, 2],
+ "B": [-1.5, np.nan, -3.2, 0.1],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+        self.assert_eq(pdf.groupby("A").last().sort_index(), psdf.groupby("A").last().sort_index())
+ self.assert_eq(
+ pdf.groupby("A").last(min_count=1).sort_index(),
+ psdf.groupby("A").last(min_count=1).sort_index(),
+ )
+ self.assert_eq(
+ pdf.groupby("A").last(min_count=2).sort_index(),
+ psdf.groupby("A").last(min_count=2).sort_index(),
+ )
+
+ def test_nth(self):
+ for n in [0, 1, 2, 128, -1, -2, -128]:
+ self._test_stat_func(lambda groupby_obj: groupby_obj.nth(n))
+
+ with self.assertRaisesRegex(NotImplementedError, "slice or list"):
+ self.psdf.groupby("B").nth(slice(0, 2))
+ with self.assertRaisesRegex(NotImplementedError, "slice or list"):
+ self.psdf.groupby("B").nth([0, 1, -1])
+ with self.assertRaisesRegex(TypeError, "Invalid index"):
+ self.psdf.groupby("B").nth("x")
+
+
+class GroupbyStatAdvTests(
+ GroupbyStatAdvMixin,
+ PandasOnSparkTestCase,
+ SQLTestUtils,
+):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.pandas.tests.groupby.test_stat_adv import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/groupby/test_stat_ddof.py b/python/pyspark/pandas/tests/groupby/test_stat_ddof.py
new file mode 100644
index 000000000000..63e974bd69fc
--- /dev/null
+++ b/python/pyspark/pandas/tests/groupby/test_stat_ddof.py
@@ -0,0 +1,91 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+
+
+class DdofTestsMixin:
+ def test_ddof(self):
+ pdf = pd.DataFrame(
+ {
+ "a": [1, 1, 1, 1, 2, 2, 2, 3, 3, 3] * 3,
+ "b": [2, 3, 1, 4, 6, 9, 8, 10, 7, 5] * 3,
+ "c": [3, 5, 2, 5, 1, 2, 6, 4, 3, 6] * 3,
+ },
+ index=np.random.rand(10 * 3),
+ )
+ psdf = ps.from_pandas(pdf)
+
+ for ddof in [-1, 0, 1, 2, 3]:
+ # std
+ self.assert_eq(
+ pdf.groupby("a").std(ddof=ddof).sort_index(),
+ psdf.groupby("a").std(ddof=ddof).sort_index(),
+ check_exact=False,
+ )
+ self.assert_eq(
+ pdf.groupby("a")["b"].std(ddof=ddof).sort_index(),
+ psdf.groupby("a")["b"].std(ddof=ddof).sort_index(),
+ check_exact=False,
+ )
+ # var
+ self.assert_eq(
+ pdf.groupby("a").var(ddof=ddof).sort_index(),
+ psdf.groupby("a").var(ddof=ddof).sort_index(),
+ check_exact=False,
+ )
+ self.assert_eq(
+ pdf.groupby("a")["b"].var(ddof=ddof).sort_index(),
+ psdf.groupby("a")["b"].var(ddof=ddof).sort_index(),
+ check_exact=False,
+ )
+ # sem
+ self.assert_eq(
+ pdf.groupby("a").sem(ddof=ddof).sort_index(),
+ psdf.groupby("a").sem(ddof=ddof).sort_index(),
+ check_exact=False,
+ )
+ self.assert_eq(
+ pdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
+ psdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
+ check_exact=False,
+ )
+
+
+class DdofTests(
+ DdofTestsMixin,
+ PandasOnSparkTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.pandas.tests.groupby.test_stat_ddof import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/groupby/test_stat_func.py b/python/pyspark/pandas/tests/groupby/test_stat_func.py
new file mode 100644
index 000000000000..257394b59f51
--- /dev/null
+++ b/python/pyspark/pandas/tests/groupby/test_stat_func.py
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.pandas.tests.groupby.test_stat import GroupbyStatTestingFuncMixin
+
+
+class FuncTestsMixin(GroupbyStatTestingFuncMixin):
+ @property
+ def pdf(self):
+ return pd.DataFrame(
+ {
+ "A": [1, 2, 1, 2],
+ "B": [3.1, 4.1, 4.1, 3.1],
+ "C": ["a", "b", "b", "a"],
+ "D": [True, False, False, True],
+ }
+ )
+
+ @property
+ def psdf(self):
+ return ps.from_pandas(self.pdf)
+
+ def test_basic_stat_funcs(self):
+ self._test_stat_func(
+            lambda groupby_obj: groupby_obj.var(numeric_only=True), check_exact=False
+ )
+
+ pdf, psdf = self.pdf, self.psdf
+
+        # Unlike pandas', the median in pandas-on-Spark is an approximated median based upon
+        # approximate percentile computation because computing median across a large dataset
+        # is extremely expensive.
+        expected = ps.DataFrame({"B": [3.1, 3.1], "D": [0, 0]}, index=pd.Index([1, 2], name="A"))
+ self.assert_eq(
+ psdf.groupby("A").median().sort_index(),
+ expected,
+ )
+ self.assert_eq(
+ psdf.groupby("A").median(numeric_only=None).sort_index(),
+ expected,
+ )
+ self.assert_eq(
+ psdf.groupby("A").median(numeric_only=False).sort_index(),
+ expected,
+ )
+ self.assert_eq(
+ psdf.groupby("A")["B"].median().sort_index(),
+ expected.B,
+ )
+ with self.assertRaises(TypeError):
+ psdf.groupby("A")["C"].mean()
+
+ with self.assertRaisesRegex(
+            TypeError, "Unaccepted data types of aggregation columns; numeric or bool expected."
+ ):
+ psdf.groupby("A")[["C"]].std()
+
+ with self.assertRaisesRegex(
+            TypeError, "Unaccepted data types of aggregation columns; numeric or bool expected."
+ ):
+ psdf.groupby("A")[["C"]].sem()
+
+ self.assert_eq(
+ psdf.groupby("A").std().sort_index(),
+ pdf.groupby("A").std(numeric_only=True).sort_index(),
+ check_exact=False,
+ )
+ self.assert_eq(
+ psdf.groupby("A").sem().sort_index(),
+ pdf.groupby("A").sem(numeric_only=True).sort_index(),
+ check_exact=False,
+ )
+
+ # TODO: fix bug of `sum` and re-enable the test below
+        # self._test_stat_func(lambda groupby_obj: groupby_obj.sum(), check_exact=False)
+ self.assert_eq(
+ psdf.groupby("A").sum().sort_index(),
+ pdf.groupby("A").sum().sort_index(),
+ check_exact=False,
+ )
+
+
+class FuncTests(
+ FuncTestsMixin,
+ PandasOnSparkTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.pandas.tests.groupby.test_stat_func import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/groupby/test_stat_prod.py b/python/pyspark/pandas/tests/groupby/test_stat_prod.py
new file mode 100644
index 000000000000..31da55d26018
--- /dev/null
+++ b/python/pyspark/pandas/tests/groupby/test_stat_prod.py
@@ -0,0 +1,86 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.pandas.tests.groupby.test_stat import GroupbyStatTestingFuncMixin
+
+
+class ProdTestsMixin(GroupbyStatTestingFuncMixin):
+ @property
+ def pdf(self):
+ return pd.DataFrame(
+ {
+ "A": [1, 2, 1, 2],
+ "B": [3.1, 4.1, 4.1, 3.1],
+ "C": ["a", "b", "b", "a"],
+ "D": [True, False, False, True],
+ }
+ )
+
+ @property
+ def psdf(self):
+ return ps.from_pandas(self.pdf)
+
+ def test_prod(self):
+ pdf = pd.DataFrame(
+ {
+ "A": [1, 2, 1, 2, 1],
+ "B": [3.1, 4.1, 4.1, 3.1, 0.1],
+ "C": ["a", "b", "b", "a", "c"],
+ "D": [True, False, False, True, False],
+ "E": [-1, -2, 3, -4, -2],
+ "F": [-1.5, np.nan, -3.2, 0.1, 0],
+ "G": [np.nan, np.nan, np.nan, np.nan, np.nan],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ for n in [0, 1, 2, 128, -1, -2, -128]:
+ self._test_stat_func(
+                lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n),
+ check_exact=False,
+ )
+ self.assert_eq(
+                pdf.groupby("A").prod(min_count=n, numeric_only=True).sort_index(),
+ psdf.groupby("A").prod(min_count=n).sort_index(),
+ almost=True,
+ )
+
+
+class ProdTests(
+ ProdTestsMixin,
+ PandasOnSparkTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.pandas.tests.groupby.test_stat_prod import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]