This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 464f8b576aa8 [SPARK-55156][PS] Deal with `include_groups` for
`groupby.apply`
464f8b576aa8 is described below
commit 464f8b576aa8147a94afca4fc97029a2fc63b2dd
Author: Takuya Ueshin <[email protected]>
AuthorDate: Fri Feb 13 10:57:31 2026 +0800
[SPARK-55156][PS] Deal with `include_groups` for `groupby.apply`
### What changes were proposed in this pull request?
Deals with `include_groups` for `groupby.apply`.
- Added `include_groups` for `groupby.apply`
- with pandas 2
- `True` by default
- If set to `False`, it behaves like pandas 3
- with pandas 3
- `False` by default
- If set to `True`, it raises an exception
### Why are the changes needed?
In pandas 3, `DataFrame.groupby(...).apply()` has `include_groups=False` by default,
which differs from the current pandas-on-Spark behavior.
For example:
```py
>>> import pandas as pd
>>> pdf = pd.DataFrame(
... {"a": [1, 2, 3, 4, 5, 6], "b": [1, 1, 2, 3, 5, 8], "c": [1, 4, 9,
16, 25, 36]},
... columns=["a", "b", "c"],
... )
```
- pandas 2
```py
>>> pd.__version__
'2.3.3'
>>> pdf.groupby("b").apply(lambda x: x + x.min())
a b c
b
1 0 2 2 2
1 3 2 5
2 2 6 4 18
3 3 8 6 32
5 4 10 10 50
8 5 12 16 72
>>> pdf.groupby("b").apply(lambda x: x + x.min(), include_groups=False)
a c
b
1 0 2 2
1 3 5
2 2 6 18
3 3 8 32
5 4 10 50
8 5 12 72
```
- pandas 3
```py
>>> pd.__version__
'3.0.0'
>>> pdf.groupby("b").apply(lambda x: x + x.min())
a c
b
1 0 2 2
1 3 5
2 2 6 18
3 3 8 32
5 4 10 50
8 5 12 72
>>> pdf.groupby("b").apply(lambda x: x + x.min(), include_groups=True)
Traceback (most recent call last):
...
ValueError: include_groups=True is no longer allowed.
```
### Does this PR introduce _any_ user-facing change?
Yes. `groupby.apply` now accepts an `include_groups` parameter and behaves more like pandas 3.
### How was this patch tested?
Updated the related tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54285 from ueshin/issues/SPARK-55156/include_groups.
Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/pandas/groupby.py | 45 ++-
.../tests/diff_frames_ops/test_groupby_apply.py | 4 +
.../pandas/tests/groupby/test_apply_func.py | 363 ++++++++++++++++-----
python/pyspark/pandas/tests/test_categorical.py | 12 +-
4 files changed, 337 insertions(+), 87 deletions(-)
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index ff2b39fe0133..37993e5f2499 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -58,6 +58,7 @@ from pyspark.sql.types import (
StringType,
)
from pyspark import pandas as ps # For running doctests and reference
resolution in PyCharm.
+from pyspark._globals import _NoValue, _NoValueType
from pyspark.loose_version import LooseVersion
from pyspark.pandas._typing import Axis, FrameLike, Label, Name
from pyspark.pandas.typedef import infer_return_type, DataFrameType,
ScalarType, SeriesType
@@ -1808,7 +1809,13 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
numeric_only=True,
)
- def apply(self, func: Callable, *args: Any, **kwargs: Any) ->
Union[DataFrame, Series]:
+ def apply(
+ self,
+ func: Callable,
+ *args: Any,
+ include_groups: Union[bool, _NoValueType] = _NoValue,
+ **kwargs: Any,
+ ) -> Union[DataFrame, Series]:
"""
Apply function `func` group-wise and combine the results together.
@@ -1963,6 +1970,24 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
if not callable(func):
raise TypeError("%s object is not callable" % type(func).__name__)
+ if LooseVersion(pd.__version__) < "3.0.0":
+ if include_groups is _NoValue:
+ include_groups = True
+ if include_groups:
+ warnings.warn(
+ "DataFrameGroupBy.apply operated on the grouping columns. "
+ "This behavior is deprecated, and in a future version of
pandas "
+ "the grouping columns will be excluded from the operation.
"
+ "Either pass `include_groups=False` to exclude the
groupings or "
+ "explicitly select the grouping columns after groupby to
silence this warning.",
+ FutureWarning,
+ )
+ else:
+ if include_groups is _NoValue:
+ include_groups = False
+ if include_groups:
+ raise ValueError("include_groups=True is no longer allowed.")
+
spec = inspect.getfullargspec(func)
return_sig = spec.annotations.get("return", None)
should_infer_schema = return_sig is None
@@ -1980,10 +2005,15 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
for label in psdf._internal.column_labels
if label not in self._column_labels_to_exclude
]
+ if not include_groups:
+ agg_columns = [
+ col for col in agg_columns if all(col is not gkey for gkey
in self._groupkeys)
+ ]
psdf, groupkey_labels, groupkey_names =
GroupBy._prepare_group_map_apply(
psdf, self._groupkeys, agg_columns
)
+ groupkey_psser_names = [psser.name for psser in self._groupkeys]
if LooseVersion(pd.__version__) < "3.0.0":
from pandas.core.common import is_builtin_func # type:
ignore[import-not-found]
@@ -2014,8 +2044,8 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
sample_limit = limit + 1 if limit else 2
pdf = psdf.head(sample_limit)._to_internal_pandas()
groupkeys = [
- pdf[groupkey_name].rename(psser.name)
- for groupkey_name, psser in zip(groupkey_names,
self._groupkeys)
+ pdf[groupkey_name].rename(name)
+ for groupkey_name, name in zip(groupkey_names,
groupkey_psser_names)
]
grouped = pdf.groupby(groupkeys)
if is_series_groupby:
@@ -2086,10 +2116,15 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
return_schema = StructType([field.struct_field for field in
data_fields])
def pandas_groupby_apply(pdf: pd.DataFrame) -> pd.DataFrame:
+ groupkeys = [
+ pdf[groupkey_name].rename(name)
+ for groupkey_name, name in zip(groupkey_names,
groupkey_psser_names)
+ ]
+ grouped = pdf.groupby(groupkeys)
if is_series_groupby:
- pdf_or_ser =
pdf.groupby(groupkey_names)[name].apply(pandas_apply, *args, **kwargs)
+ pdf_or_ser = grouped[name].apply(pandas_apply, *args, **kwargs)
else:
- pdf_or_ser = pdf.groupby(groupkey_names).apply(pandas_apply,
*args, **kwargs)
+ pdf_or_ser = grouped.apply(pandas_apply, *args, **kwargs)
if should_return_series and isinstance(pdf_or_ser,
pd.DataFrame):
pdf_or_ser = pdf_or_ser.stack()
diff --git a/python/pyspark/pandas/tests/diff_frames_ops/test_groupby_apply.py
b/python/pyspark/pandas/tests/diff_frames_ops/test_groupby_apply.py
index b5465423c6d8..552580b0ceed 100644
--- a/python/pyspark/pandas/tests/diff_frames_ops/test_groupby_apply.py
+++ b/python/pyspark/pandas/tests/diff_frames_ops/test_groupby_apply.py
@@ -60,6 +60,10 @@ class GroupByApplyMixin:
pdf.groupby(["a", pkey]).apply(lambda x: x + x.min()).sort_index(),
)
+ def test_apply_without_shortcut(self):
+ with ps.option_context("compute.shortcut_limit", 0):
+ self.test_apply()
+
class GroupByApplyTests(
GroupByApplyMixin,
diff --git a/python/pyspark/pandas/tests/groupby/test_apply_func.py
b/python/pyspark/pandas/tests/groupby/test_apply_func.py
index 25cc21423f32..5716e574cb44 100644
--- a/python/pyspark/pandas/tests/groupby/test_apply_func.py
+++ b/python/pyspark/pandas/tests/groupby/test_apply_func.py
@@ -19,6 +19,7 @@ import numpy as np
import pandas as pd
from pyspark import pandas as ps
+from pyspark.loose_version import LooseVersion
from pyspark.pandas.config import option_context
from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.sqlutils import SQLTestUtils
@@ -31,100 +32,180 @@ class GroupbyApplyFuncMixin:
columns=["a", "b", "c"],
)
psdf = ps.from_pandas(pdf)
+
+ if LooseVersion(pd.__version__) < "3.0.0":
+ for include_groups in [True, False]:
+ with self.subTest(include_groups=include_groups):
+ self._check_apply(psdf, pdf, include_groups)
+ else:
+ self._check_apply(psdf, pdf, include_groups=False)
+ with self.assertRaises(ValueError):
+ psdf.groupby("b").apply(lambda x: x + x.min(),
include_groups=True)
+
+ def _check_apply(self, psdf, pdf, include_groups):
self.assert_eq(
- psdf.groupby("b").apply(lambda x: x + x.min()).sort_index(),
- pdf.groupby("b").apply(lambda x: x + x.min()).sort_index(),
+ psdf.groupby("b")
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby("b")
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
)
self.assert_eq(
- psdf.groupby("b").apply(len).sort_index(),
- pdf.groupby("b").apply(len).sort_index(),
+ psdf.groupby("b").apply(len,
include_groups=include_groups).sort_index(),
+ pdf.groupby("b").apply(len,
include_groups=include_groups).sort_index(),
)
self.assert_eq(
psdf.groupby("b")["a"]
- .apply(lambda x, y, z: x + x.min() + y * z, 10, z=20)
+ .apply(lambda x, y, z: x + x.min() + y * z, 10, z=20,
include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby("b")["a"]
+ .apply(lambda x, y, z: x + x.min() + y * z, 10, z=20,
include_groups=include_groups)
.sort_index(),
- pdf.groupby("b")["a"].apply(lambda x, y, z: x + x.min() + y * z,
10, z=20).sort_index(),
)
self.assert_eq(
- psdf.groupby("b")[["a"]].apply(lambda x: x + x.min()).sort_index(),
- pdf.groupby("b")[["a"]].apply(lambda x: x + x.min()).sort_index(),
+ psdf.groupby("b")[["a"]]
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby("b")[["a"]]
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
)
self.assert_eq(
psdf.groupby(["a", "b"])
- .apply(lambda x, y, z: x + x.min() + y + z, 1, z=2)
+ .apply(lambda x, y, z: x + x.min() + y + z, 1, z=2,
include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby(["a", "b"])
+ .apply(lambda x, y, z: x + x.min() + y + z, 1, z=2,
include_groups=include_groups)
.sort_index(),
- pdf.groupby(["a", "b"]).apply(lambda x, y, z: x + x.min() + y + z,
1, z=2).sort_index(),
)
self.assert_eq(
- psdf.groupby(["b"])["c"].apply(lambda x: 1).sort_index(),
- pdf.groupby(["b"])["c"].apply(lambda x: 1).sort_index(),
+ psdf.groupby(["b"])["c"].apply(lambda x: 1,
include_groups=include_groups).sort_index(),
+ pdf.groupby(["b"])["c"].apply(lambda x: 1,
include_groups=include_groups).sort_index(),
)
self.assert_eq(
- psdf.groupby(["b"])["c"].apply(len).sort_index(),
- pdf.groupby(["b"])["c"].apply(len).sort_index(),
+ psdf.groupby(["b"])["c"].apply(len,
include_groups=include_groups).sort_index(),
+ pdf.groupby(["b"])["c"].apply(len,
include_groups=include_groups).sort_index(),
)
self.assert_eq(
- psdf.groupby(psdf.b // 5).apply(lambda x: x +
x.min()).sort_index(),
- pdf.groupby(pdf.b // 5).apply(lambda x: x + x.min()).sort_index(),
+ psdf.groupby(psdf.b // 5)
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby(pdf.b // 5)
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
almost=True,
)
self.assert_eq(
- psdf.groupby(psdf.b // 5)["a"].apply(lambda x: x +
x.min()).sort_index(),
- pdf.groupby(pdf.b // 5)["a"].apply(lambda x: x +
x.min()).sort_index(),
+ psdf.groupby(psdf.b // 5)["a"]
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby(pdf.b // 5)["a"]
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
almost=True,
)
self.assert_eq(
- psdf.groupby(psdf.b // 5)[["a"]].apply(lambda x: x +
x.min()).sort_index(),
- pdf.groupby(pdf.b // 5)[["a"]].apply(lambda x: x +
x.min()).sort_index(),
+ psdf.groupby(psdf.b // 5)[["a"]]
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby(pdf.b // 5)[["a"]]
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
almost=True,
)
self.assert_eq(
- psdf.groupby(psdf.b // 5)[["a"]].apply(len).sort_index(),
- pdf.groupby(pdf.b // 5)[["a"]].apply(len).sort_index(),
+ psdf.groupby(psdf.b // 5)[["a"]].apply(len,
include_groups=include_groups).sort_index(),
+ pdf.groupby(pdf.b // 5)[["a"]].apply(len,
include_groups=include_groups).sort_index(),
almost=True,
)
self.assert_eq(
- psdf.a.rename().groupby(psdf.b).apply(lambda x: x +
x.min()).sort_index(),
- pdf.a.rename().groupby(pdf.b).apply(lambda x: x +
x.min()).sort_index(),
+ psdf.a.rename()
+ .groupby(psdf.b)
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
+ pdf.a.rename()
+ .groupby(pdf.b)
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
)
self.assert_eq(
- psdf.a.groupby(psdf.b.rename()).apply(lambda x: x +
x.min()).sort_index(),
- pdf.a.groupby(pdf.b.rename()).apply(lambda x: x +
x.min()).sort_index(),
+ psdf.a.groupby(psdf.b.rename())
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
+ pdf.a.groupby(pdf.b.rename())
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
)
self.assert_eq(
- psdf.a.rename().groupby(psdf.b.rename()).apply(lambda x: x +
x.min()).sort_index(),
- pdf.a.rename().groupby(pdf.b.rename()).apply(lambda x: x +
x.min()).sort_index(),
+ psdf.a.rename()
+ .groupby(psdf.b.rename())
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
+ pdf.a.rename()
+ .groupby(pdf.b.rename())
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
)
with self.assertRaisesRegex(TypeError, "int object is not callable"):
- psdf.groupby("b").apply(1)
+ psdf.groupby("b").apply(1, include_groups=include_groups)
+
+ def test_apply_with_multi_index_columns(self):
+ pdf = pd.DataFrame(
+ {"a": [1, 2, 3, 4, 5, 6], "b": [1, 1, 2, 3, 5, 8], "c": [1, 4, 9,
16, 25, 36]},
+ columns=["a", "b", "c"],
+ )
+ psdf = ps.from_pandas(pdf)
# multi-index columns
columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("y",
"c")])
pdf.columns = columns
psdf.columns = columns
+ if LooseVersion(pd.__version__) < "3.0.0":
+ for include_groups in [True, False]:
+ with self.subTest(include_groups=include_groups):
+ self._check_apply_with_multi_index_columns(psdf, pdf,
include_groups)
+ else:
+ self._check_apply_with_multi_index_columns(psdf, pdf,
include_groups=False)
+ with self.assertRaises(ValueError):
+ psdf.groupby(("x", "b")).apply(lambda x: x + x.min(),
include_groups=True)
+
+ def _check_apply_with_multi_index_columns(self, psdf, pdf, include_groups):
self.assert_eq(
- psdf.groupby(("x", "b")).apply(lambda x: 1).sort_index(),
- pdf.groupby(("x", "b")).apply(lambda x: 1).sort_index(),
+ psdf.groupby(("x", "b")).apply(lambda x: 1,
include_groups=include_groups).sort_index(),
+ pdf.groupby(("x", "b")).apply(lambda x: 1,
include_groups=include_groups).sort_index(),
)
self.assert_eq(
- psdf.groupby([("x", "a"), ("x", "b")]).apply(lambda x: x +
x.min()).sort_index(),
- pdf.groupby([("x", "a"), ("x", "b")]).apply(lambda x: x +
x.min()).sort_index(),
+ psdf.groupby([("x", "a"), ("x", "b")])
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby([("x", "a"), ("x", "b")])
+ .apply(lambda x: x + x.min(), include_groups=include_groups)
+ .sort_index(),
)
self.assert_eq(
- psdf.groupby(("x", "b")).apply(len).sort_index(),
- pdf.groupby(("x", "b")).apply(len).sort_index(),
+ psdf.groupby(("x", "b")).apply(len,
include_groups=include_groups).sort_index(),
+ pdf.groupby(("x", "b")).apply(len,
include_groups=include_groups).sort_index(),
)
self.assert_eq(
- psdf.groupby([("x", "a"), ("x", "b")]).apply(len).sort_index(),
- pdf.groupby([("x", "a"), ("x", "b")]).apply(len).sort_index(),
+ psdf.groupby([("x", "a"), ("x", "b")])
+ .apply(len, include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby([("x", "a"), ("x", "b")])
+ .apply(len, include_groups=include_groups)
+ .sort_index(),
)
def test_apply_without_shortcut(self):
with option_context("compute.shortcut_limit", 0):
self.test_apply()
+ def test_apply_with_multi_index_columns_without_shortcut(self):
+ with option_context("compute.shortcut_limit", 0):
+ self.test_apply_with_multi_index_columns()
+
def test_apply_with_type_hint(self):
pdf = pd.DataFrame(
{"a": [1, 2, 3, 4, 5, 6], "b": [1, 1, 2, 3, 5, 8], "c": [1, 4, 9,
16, 25, 36]},
@@ -132,27 +213,61 @@ class GroupbyApplyFuncMixin:
)
psdf = ps.from_pandas(pdf)
- def add_max1(x) -> ps.DataFrame[int, int, int]:
- return x + x.min()
+ if LooseVersion(pd.__version__) < "3.0.0":
+ for include_groups in [True, False]:
+ with self.subTest(include_groups=include_groups):
+ self._check_apply_with_type_hint(psdf, pdf, include_groups)
+ else:
+ self._check_apply_with_type_hint(psdf, pdf, include_groups=False)
+
+ def _check_apply_with_type_hint(self, psdf, pdf, include_groups):
+ if include_groups:
+
+ def add_max1(x) -> ps.DataFrame[int, int, int]:
+ return x + x.min()
+
+ else:
+
+ def add_max1(x) -> ps.DataFrame[int, int]:
+ return x + x.min()
# Type hints set the default column names, and we use default index for
# pandas API on Spark. Here we ignore both diff.
- actual = psdf.groupby("b").apply(add_max1).sort_index()
- expected = pdf.groupby("b").apply(add_max1).sort_index()
- self.assert_eq(sorted(actual["c0"].to_numpy()),
sorted(expected["a"].to_numpy()))
- self.assert_eq(sorted(actual["c1"].to_numpy()),
sorted(expected["b"].to_numpy()))
- self.assert_eq(sorted(actual["c2"].to_numpy()),
sorted(expected["c"].to_numpy()))
-
- def add_max2(
- x,
- ) -> ps.DataFrame[slice("a", int), slice("b", int), slice("c", int)]:
- return x + x.min()
-
- actual = psdf.groupby("b").apply(add_max2).sort_index()
- expected = pdf.groupby("b").apply(add_max2).sort_index()
- self.assert_eq(sorted(actual["a"].to_numpy()),
sorted(expected["a"].to_numpy()))
- self.assert_eq(sorted(actual["c"].to_numpy()),
sorted(expected["c"].to_numpy()))
- self.assert_eq(sorted(actual["c"].to_numpy()),
sorted(expected["c"].to_numpy()))
+ actual = psdf.groupby("b").apply(add_max1,
include_groups=include_groups).sort_index()
+ expected = pdf.groupby("b").apply(add_max1,
include_groups=include_groups).sort_index()
+
+ if include_groups:
+ self.assert_eq(sorted(actual["c0"].to_numpy()),
sorted(expected["a"].to_numpy()))
+ self.assert_eq(sorted(actual["c1"].to_numpy()),
sorted(expected["b"].to_numpy()))
+ self.assert_eq(sorted(actual["c2"].to_numpy()),
sorted(expected["c"].to_numpy()))
+ else:
+ self.assert_eq(sorted(actual["c0"].to_numpy()),
sorted(expected["a"].to_numpy()))
+ self.assert_eq(sorted(actual["c1"].to_numpy()),
sorted(expected["c"].to_numpy()))
+
+ if include_groups:
+
+ def add_max2(
+ x,
+ ) -> ps.DataFrame[slice("a", int), slice("b", int), slice("c",
int)]:
+ return x + x.min()
+
+ else:
+
+ def add_max2(
+ x,
+ ) -> ps.DataFrame[slice("a", int), slice("c", int)]:
+ return x + x.min()
+
+ actual = psdf.groupby("b").apply(add_max2,
include_groups=include_groups).sort_index()
+ expected = pdf.groupby("b").apply(add_max2,
include_groups=include_groups).sort_index()
+
+ if include_groups:
+ self.assert_eq(sorted(actual["a"].to_numpy()),
sorted(expected["a"].to_numpy()))
+ self.assert_eq(sorted(actual["b"].to_numpy()),
sorted(expected["b"].to_numpy()))
+ self.assert_eq(sorted(actual["c"].to_numpy()),
sorted(expected["c"].to_numpy()))
+ else:
+ self.assert_eq(sorted(actual["a"].to_numpy()),
sorted(expected["a"].to_numpy()))
+ self.assert_eq(sorted(actual["c"].to_numpy()),
sorted(expected["c"].to_numpy()))
def test_apply_negative(self):
def func(_) -> ps.Series[int]:
@@ -245,18 +360,38 @@ class GroupbyApplyFuncMixin:
)
psdf = ps.from_pandas(pdf)
+ if LooseVersion(pd.__version__) < "3.0.0":
+ for include_groups in [True, False]:
+ with self.subTest(include_groups=include_groups):
+ self._check_apply_with_side_effect(psdf, pdf,
include_groups)
+ else:
+ self._check_apply_with_side_effect(psdf, pdf, include_groups=False)
+
+ def _check_apply_with_side_effect(self, psdf, pdf, include_groups):
acc = ps.utils.default_session().sparkContext.accumulator(0)
- def sum_with_acc_frame(x) -> ps.DataFrame[np.float64, np.float64]:
- nonlocal acc
- acc += 1
- return np.sum(x)
+ if include_groups:
+
+ def sum_with_acc_frame(x) -> ps.DataFrame[np.float64, np.float64]:
+ nonlocal acc
+ acc += 1
+ return np.sum(x)
- actual = psdf.groupby("d").apply(sum_with_acc_frame)
- actual.columns = ["d", "v"]
+ else:
+
+ def sum_with_acc_frame(x) -> ps.DataFrame[np.float64]:
+ nonlocal acc
+ acc += 1
+ return np.sum(x)
+
+ actual = psdf.groupby("d").apply(sum_with_acc_frame,
include_groups=include_groups)
+ actual.columns = ["d", "v"] if include_groups else ["v"]
self.assert_eq(
actual._to_pandas().sort_index(),
- pdf.groupby("d").apply(sum).sort_index().reset_index(drop=True),
+ pdf.groupby("d")
+ .apply(sum, include_groups=include_groups)
+ .sort_index()
+ .reset_index(drop=True),
)
self.assert_eq(acc.value, 2)
@@ -266,8 +401,14 @@ class GroupbyApplyFuncMixin:
return np.sum(x)
self.assert_eq(
-
psdf.groupby("d")["v"].apply(sum_with_acc_series)._to_pandas().sort_index(),
-
pdf.groupby("d")["v"].apply(sum).sort_index().reset_index(drop=True),
+ psdf.groupby("d")["v"]
+ .apply(sum_with_acc_series, include_groups=include_groups)
+ ._to_pandas()
+ .sort_index(),
+ pdf.groupby("d")["v"]
+ .apply(sum, include_groups=include_groups)
+ .sort_index()
+ .reset_index(drop=True),
)
self.assert_eq(acc.value, 4)
@@ -279,43 +420,101 @@ class GroupbyApplyFuncMixin:
)
psdf = ps.from_pandas(pdf)
+ if LooseVersion(pd.__version__) < "3.0.0":
+ for include_groups in [True, False]:
+ with self.subTest(include_groups=include_groups):
+ self._check_apply_return_series(psdf, pdf, include_groups)
+ else:
+ self._check_apply_return_series(psdf, pdf, include_groups=False)
+ with self.assertRaises(ValueError):
+ psdf.groupby("b").apply(lambda x: x + x.iloc[0],
include_groups=True)
+
+ def _check_apply_return_series(self, psdf, pdf, include_groups):
self.assert_eq(
- psdf.groupby("b").apply(lambda x: x.iloc[0]).sort_index(),
- pdf.groupby("b").apply(lambda x: x.iloc[0]).sort_index(),
+ psdf.groupby("b")
+ .apply(lambda x: x.iloc[0], include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby("b").apply(lambda x: x.iloc[0],
include_groups=include_groups).sort_index(),
)
self.assert_eq(
- psdf.groupby("b").apply(lambda x: x["a"]).sort_index(),
- pdf.groupby("b").apply(lambda x: x["a"]).sort_index(),
+ psdf.groupby("b").apply(lambda x: x["a"],
include_groups=include_groups).sort_index(),
+ pdf.groupby("b").apply(lambda x: x["a"],
include_groups=include_groups).sort_index(),
)
self.assert_eq(
- psdf.groupby(["b", "c"]).apply(lambda x: x.iloc[0]).sort_index(),
- pdf.groupby(["b", "c"]).apply(lambda x: x.iloc[0]).sort_index(),
+ psdf.groupby(["b", "c"])
+ .apply(lambda x: x.iloc[0], include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby(["b", "c"])
+ .apply(lambda x: x.iloc[0], include_groups=include_groups)
+ .sort_index(),
)
self.assert_eq(
- psdf.groupby(["b", "c"]).apply(lambda x: x["a"]).sort_index(),
- pdf.groupby(["b", "c"]).apply(lambda x: x["a"]).sort_index(),
+ psdf.groupby(["b", "c"])
+ .apply(lambda x: x["a"], include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby(["b", "c"])
+ .apply(lambda x: x["a"], include_groups=include_groups)
+ .sort_index(),
)
+ def test_apply_return_series_with_multi_index_columns(self):
+ # SPARK-36907: Fix DataFrameGroupBy.apply without shortcut.
+ pdf = pd.DataFrame(
+ {"a": [1, 2, 3, 4, 5, 6], "b": [1, 1, 2, 3, 5, 8], "c": [1, 4, 9,
16, 25, 36]},
+ columns=["a", "b", "c"],
+ )
+ psdf = ps.from_pandas(pdf)
+
# multi-index columns
columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("y",
"c")])
pdf.columns = columns
psdf.columns = columns
+ if LooseVersion(pd.__version__) < "3.0.0":
+ for include_groups in [True, False]:
+ with self.subTest(include_groups=include_groups):
+ self._check_apply_return_series_with_multi_index_columns(
+ psdf, pdf, include_groups
+ )
+ else:
+ self._check_apply_return_series_with_multi_index_columns(
+ psdf, pdf, include_groups=False
+ )
+ with self.assertRaises(ValueError):
+ psdf.groupby(("x", "b")).apply(lambda x: x + x.iloc[0],
include_groups=True)
+
+ def _check_apply_return_series_with_multi_index_columns(self, psdf, pdf,
include_groups):
self.assert_eq(
- psdf.groupby(("x", "b")).apply(lambda x: x.iloc[0]).sort_index(),
- pdf.groupby(("x", "b")).apply(lambda x: x.iloc[0]).sort_index(),
+ psdf.groupby(("x", "b"))
+ .apply(lambda x: x.iloc[0], include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby(("x", "b"))
+ .apply(lambda x: x.iloc[0], include_groups=include_groups)
+ .sort_index(),
)
self.assert_eq(
- psdf.groupby(("x", "b")).apply(lambda x: x[("x",
"a")]).sort_index(),
- pdf.groupby(("x", "b")).apply(lambda x: x[("x",
"a")]).sort_index(),
+ psdf.groupby(("x", "b"))
+ .apply(lambda x: x[("x", "a")], include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby(("x", "b"))
+ .apply(lambda x: x[("x", "a")], include_groups=include_groups)
+ .sort_index(),
)
self.assert_eq(
- psdf.groupby([("x", "b"), ("y", "c")]).apply(lambda x:
x.iloc[0]).sort_index(),
- pdf.groupby([("x", "b"), ("y", "c")]).apply(lambda x:
x.iloc[0]).sort_index(),
+ psdf.groupby([("x", "b"), ("y", "c")])
+ .apply(lambda x: x.iloc[0], include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby([("x", "b"), ("y", "c")])
+ .apply(lambda x: x.iloc[0], include_groups=include_groups)
+ .sort_index(),
)
self.assert_eq(
- psdf.groupby([("x", "b"), ("y", "c")]).apply(lambda x: x[("x",
"a")]).sort_index(),
- pdf.groupby([("x", "b"), ("y", "c")]).apply(lambda x: x[("x",
"a")]).sort_index(),
+ psdf.groupby([("x", "b"), ("y", "c")])
+ .apply(lambda x: x[("x", "a")], include_groups=include_groups)
+ .sort_index(),
+ pdf.groupby([("x", "b"), ("y", "c")])
+ .apply(lambda x: x[("x", "a")], include_groups=include_groups)
+ .sort_index(),
)
def test_apply_return_series_without_shortcut(self):
@@ -323,6 +522,10 @@ class GroupbyApplyFuncMixin:
with ps.option_context("compute.shortcut_limit", 2):
self.test_apply_return_series()
+ def
test_apply_return_series_with_multi_index_columns_without_shortcut(self):
+ with ps.option_context("compute.shortcut_limit", 2):
+ self.test_apply_return_series_with_multi_index_columns()
+
def test_apply_explicitly_infer(self):
# SPARK-39317
from pyspark.pandas.utils import SPARK_CONF_ARROW_ENABLED
diff --git a/python/pyspark/pandas/tests/test_categorical.py
b/python/pyspark/pandas/tests/test_categorical.py
index 04b8c3eeae1e..7a78e02e0519 100644
--- a/python/pyspark/pandas/tests/test_categorical.py
+++ b/python/pyspark/pandas/tests/test_categorical.py
@@ -20,6 +20,7 @@ import pandas as pd
from pandas.api.types import CategoricalDtype
import pyspark.pandas as ps
+from pyspark.loose_version import LooseVersion
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
@@ -340,8 +341,15 @@ class CategoricalTestsMixin:
pdf, psdf = self.df_pair
- def identity(df) -> ps.DataFrame[zip(psdf.columns, psdf.dtypes)]:
- return df
+ if LooseVersion(pd.__version__) < "3.0.0":
+
+ def identity(df) -> ps.DataFrame[zip(psdf.columns, psdf.dtypes)]:
+ return df
+
+ else:
+
+ def identity(df) -> ps.DataFrame[zip(psdf.columns[1:],
psdf.dtypes[1:])]:
+ return df
self.assert_eq(
psdf.groupby("a").apply(identity).sort_values(["b"]).reset_index(drop=True),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]