This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 464f8b576aa8 [SPARK-55156][PS] Deal with `include_groups` for 
`groupby.apply`
464f8b576aa8 is described below

commit 464f8b576aa8147a94afca4fc97029a2fc63b2dd
Author: Takuya Ueshin <[email protected]>
AuthorDate: Fri Feb 13 10:57:31 2026 +0800

    [SPARK-55156][PS] Deal with `include_groups` for `groupby.apply`
    
    ### What changes were proposed in this pull request?
    
    Handles the `include_groups` parameter of `groupby.apply` consistently with pandas.
    
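    - Added an `include_groups` parameter to `groupby.apply`:
        - with pandas 2:
          - defaults to `True`; when `True` (default or explicit), a `FutureWarning` is emitted
          - if set to `False`, the grouping columns are excluded, as in pandas 3
        - with pandas 3:
          - defaults to `False`
          - if set to `True`, raises a `ValueError`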
    
    ### Why are the changes needed?
    
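    In pandas 3, `DataFrameGroupBy.apply()` defaults to `include_groups=False` and
    no longer accepts `include_groups=True`. This differs from our current behavior,
    which always passes the grouping columns to the applied function.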
    
    For example:
    
    ```py
    >>> import pandas as pd
    >>> pdf = pd.DataFrame(
    ...     {"a": [1, 2, 3, 4, 5, 6], "b": [1, 1, 2, 3, 5, 8], "c": [1, 4, 9, 
16, 25, 36]},
    ...     columns=["a", "b", "c"],
    ... )
    ```
    
    - pandas 2
    
    ```py
    >>> pd.__version__
    '2.3.3'
    >>> pdf.groupby("b").apply(lambda x: x + x.min())
          a   b   c
    b
    1 0   2   2   2
      1   3   2   5
    2 2   6   4  18
    3 3   8   6  32
    5 4  10  10  50
    8 5  12  16  72
    
    >>> pdf.groupby("b").apply(lambda x: x + x.min(), include_groups=False)
          a   c
    b
    1 0   2   2
      1   3   5
    2 2   6  18
    3 3   8  32
    5 4  10  50
    8 5  12  72
    ```
    
    - pandas 3
    
    ```py
    >>> pd.__version__
    '3.0.0'
    >>> pdf.groupby("b").apply(lambda x: x + x.min())
          a   c
    b
    1 0   2   2
      1   3   5
    2 2   6  18
    3 3   8  32
    5 4  10  50
    8 5  12  72
    
    >>> pdf.groupby("b").apply(lambda x: x + x.min(), include_groups=True)
    Traceback (most recent call last):
    ...
    ValueError: include_groups=True is no longer allowed.
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
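    Yes. `groupby.apply` now accepts `include_groups`. With pandas 3 it excludes the
    grouping columns by default and rejects `include_groups=True`, matching pandas 3.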
    
    ### How was this patch tested?
    
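    Updated the related tests:
    - with pandas 2, they cover both `include_groups=True` and `include_groups=False`;
    - with pandas 3, they cover the new default and the `ValueError` for `include_groups=True`;
    - added variants that run without the shortcut path (`compute.shortcut_limit`).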
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #54285 from ueshin/issues/SPARK-55156/include_groups.
    
    Authored-by: Takuya Ueshin <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/pandas/groupby.py                   |  45 ++-
 .../tests/diff_frames_ops/test_groupby_apply.py    |   4 +
 .../pandas/tests/groupby/test_apply_func.py        | 363 ++++++++++++++++-----
 python/pyspark/pandas/tests/test_categorical.py    |  12 +-
 4 files changed, 337 insertions(+), 87 deletions(-)

diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index ff2b39fe0133..37993e5f2499 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -58,6 +58,7 @@ from pyspark.sql.types import (
     StringType,
 )
 from pyspark import pandas as ps  # For running doctests and reference 
resolution in PyCharm.
+from pyspark._globals import _NoValue, _NoValueType
 from pyspark.loose_version import LooseVersion
 from pyspark.pandas._typing import Axis, FrameLike, Label, Name
 from pyspark.pandas.typedef import infer_return_type, DataFrameType, 
ScalarType, SeriesType
@@ -1808,7 +1809,13 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
             numeric_only=True,
         )
 
-    def apply(self, func: Callable, *args: Any, **kwargs: Any) -> 
Union[DataFrame, Series]:
+    def apply(
+        self,
+        func: Callable,
+        *args: Any,
+        include_groups: Union[bool, _NoValueType] = _NoValue,
+        **kwargs: Any,
+    ) -> Union[DataFrame, Series]:
         """
         Apply function `func` group-wise and combine the results together.
 
@@ -1963,6 +1970,24 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         if not callable(func):
             raise TypeError("%s object is not callable" % type(func).__name__)
 
+        if LooseVersion(pd.__version__) < "3.0.0":
+            if include_groups is _NoValue:
+                include_groups = True
+            if include_groups:
+                warnings.warn(
+                    "DataFrameGroupBy.apply operated on the grouping columns. "
+                    "This behavior is deprecated, and in a future version of 
pandas "
+                    "the grouping columns will be excluded from the operation. 
"
+                    "Either pass `include_groups=False` to exclude the 
groupings or "
+                    "explicitly select the grouping columns after groupby to 
silence this warning.",
+                    FutureWarning,
+                )
+        else:
+            if include_groups is _NoValue:
+                include_groups = False
+            if include_groups:
+                raise ValueError("include_groups=True is no longer allowed.")
+
         spec = inspect.getfullargspec(func)
         return_sig = spec.annotations.get("return", None)
         should_infer_schema = return_sig is None
@@ -1980,10 +2005,15 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
                 for label in psdf._internal.column_labels
                 if label not in self._column_labels_to_exclude
             ]
+            if not include_groups:
+                agg_columns = [
+                    col for col in agg_columns if all(col is not gkey for gkey 
in self._groupkeys)
+                ]
 
         psdf, groupkey_labels, groupkey_names = 
GroupBy._prepare_group_map_apply(
             psdf, self._groupkeys, agg_columns
         )
+        groupkey_psser_names = [psser.name for psser in self._groupkeys]
 
         if LooseVersion(pd.__version__) < "3.0.0":
             from pandas.core.common import is_builtin_func  # type: 
ignore[import-not-found]
@@ -2014,8 +2044,8 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
             sample_limit = limit + 1 if limit else 2
             pdf = psdf.head(sample_limit)._to_internal_pandas()
             groupkeys = [
-                pdf[groupkey_name].rename(psser.name)
-                for groupkey_name, psser in zip(groupkey_names, 
self._groupkeys)
+                pdf[groupkey_name].rename(name)
+                for groupkey_name, name in zip(groupkey_names, 
groupkey_psser_names)
             ]
             grouped = pdf.groupby(groupkeys)
             if is_series_groupby:
@@ -2086,10 +2116,15 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
                 return_schema = StructType([field.struct_field for field in 
data_fields])
 
         def pandas_groupby_apply(pdf: pd.DataFrame) -> pd.DataFrame:
+            groupkeys = [
+                pdf[groupkey_name].rename(name)
+                for groupkey_name, name in zip(groupkey_names, 
groupkey_psser_names)
+            ]
+            grouped = pdf.groupby(groupkeys)
             if is_series_groupby:
-                pdf_or_ser = 
pdf.groupby(groupkey_names)[name].apply(pandas_apply, *args, **kwargs)
+                pdf_or_ser = grouped[name].apply(pandas_apply, *args, **kwargs)
             else:
-                pdf_or_ser = pdf.groupby(groupkey_names).apply(pandas_apply, 
*args, **kwargs)
+                pdf_or_ser = grouped.apply(pandas_apply, *args, **kwargs)
                 if should_return_series and isinstance(pdf_or_ser, 
pd.DataFrame):
                     pdf_or_ser = pdf_or_ser.stack()
 
diff --git a/python/pyspark/pandas/tests/diff_frames_ops/test_groupby_apply.py 
b/python/pyspark/pandas/tests/diff_frames_ops/test_groupby_apply.py
index b5465423c6d8..552580b0ceed 100644
--- a/python/pyspark/pandas/tests/diff_frames_ops/test_groupby_apply.py
+++ b/python/pyspark/pandas/tests/diff_frames_ops/test_groupby_apply.py
@@ -60,6 +60,10 @@ class GroupByApplyMixin:
             pdf.groupby(["a", pkey]).apply(lambda x: x + x.min()).sort_index(),
         )
 
+    def test_apply_without_shortcut(self):
+        with ps.option_context("compute.shortcut_limit", 0):
+            self.test_apply()
+
 
 class GroupByApplyTests(
     GroupByApplyMixin,
diff --git a/python/pyspark/pandas/tests/groupby/test_apply_func.py 
b/python/pyspark/pandas/tests/groupby/test_apply_func.py
index 25cc21423f32..5716e574cb44 100644
--- a/python/pyspark/pandas/tests/groupby/test_apply_func.py
+++ b/python/pyspark/pandas/tests/groupby/test_apply_func.py
@@ -19,6 +19,7 @@ import numpy as np
 import pandas as pd
 
 from pyspark import pandas as ps
+from pyspark.loose_version import LooseVersion
 from pyspark.pandas.config import option_context
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
@@ -31,100 +32,180 @@ class GroupbyApplyFuncMixin:
             columns=["a", "b", "c"],
         )
         psdf = ps.from_pandas(pdf)
+
+        if LooseVersion(pd.__version__) < "3.0.0":
+            for include_groups in [True, False]:
+                with self.subTest(include_groups=include_groups):
+                    self._check_apply(psdf, pdf, include_groups)
+        else:
+            self._check_apply(psdf, pdf, include_groups=False)
+            with self.assertRaises(ValueError):
+                psdf.groupby("b").apply(lambda x: x + x.min(), 
include_groups=True)
+
+    def _check_apply(self, psdf, pdf, include_groups):
         self.assert_eq(
-            psdf.groupby("b").apply(lambda x: x + x.min()).sort_index(),
-            pdf.groupby("b").apply(lambda x: x + x.min()).sort_index(),
+            psdf.groupby("b")
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby("b")
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
         )
         self.assert_eq(
-            psdf.groupby("b").apply(len).sort_index(),
-            pdf.groupby("b").apply(len).sort_index(),
+            psdf.groupby("b").apply(len, 
include_groups=include_groups).sort_index(),
+            pdf.groupby("b").apply(len, 
include_groups=include_groups).sort_index(),
         )
         self.assert_eq(
             psdf.groupby("b")["a"]
-            .apply(lambda x, y, z: x + x.min() + y * z, 10, z=20)
+            .apply(lambda x, y, z: x + x.min() + y * z, 10, z=20, 
include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby("b")["a"]
+            .apply(lambda x, y, z: x + x.min() + y * z, 10, z=20, 
include_groups=include_groups)
             .sort_index(),
-            pdf.groupby("b")["a"].apply(lambda x, y, z: x + x.min() + y * z, 
10, z=20).sort_index(),
         )
         self.assert_eq(
-            psdf.groupby("b")[["a"]].apply(lambda x: x + x.min()).sort_index(),
-            pdf.groupby("b")[["a"]].apply(lambda x: x + x.min()).sort_index(),
+            psdf.groupby("b")[["a"]]
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby("b")[["a"]]
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
         )
         self.assert_eq(
             psdf.groupby(["a", "b"])
-            .apply(lambda x, y, z: x + x.min() + y + z, 1, z=2)
+            .apply(lambda x, y, z: x + x.min() + y + z, 1, z=2, 
include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby(["a", "b"])
+            .apply(lambda x, y, z: x + x.min() + y + z, 1, z=2, 
include_groups=include_groups)
             .sort_index(),
-            pdf.groupby(["a", "b"]).apply(lambda x, y, z: x + x.min() + y + z, 
1, z=2).sort_index(),
         )
         self.assert_eq(
-            psdf.groupby(["b"])["c"].apply(lambda x: 1).sort_index(),
-            pdf.groupby(["b"])["c"].apply(lambda x: 1).sort_index(),
+            psdf.groupby(["b"])["c"].apply(lambda x: 1, 
include_groups=include_groups).sort_index(),
+            pdf.groupby(["b"])["c"].apply(lambda x: 1, 
include_groups=include_groups).sort_index(),
         )
         self.assert_eq(
-            psdf.groupby(["b"])["c"].apply(len).sort_index(),
-            pdf.groupby(["b"])["c"].apply(len).sort_index(),
+            psdf.groupby(["b"])["c"].apply(len, 
include_groups=include_groups).sort_index(),
+            pdf.groupby(["b"])["c"].apply(len, 
include_groups=include_groups).sort_index(),
         )
         self.assert_eq(
-            psdf.groupby(psdf.b // 5).apply(lambda x: x + 
x.min()).sort_index(),
-            pdf.groupby(pdf.b // 5).apply(lambda x: x + x.min()).sort_index(),
+            psdf.groupby(psdf.b // 5)
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby(pdf.b // 5)
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
             almost=True,
         )
         self.assert_eq(
-            psdf.groupby(psdf.b // 5)["a"].apply(lambda x: x + 
x.min()).sort_index(),
-            pdf.groupby(pdf.b // 5)["a"].apply(lambda x: x + 
x.min()).sort_index(),
+            psdf.groupby(psdf.b // 5)["a"]
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby(pdf.b // 5)["a"]
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
             almost=True,
         )
         self.assert_eq(
-            psdf.groupby(psdf.b // 5)[["a"]].apply(lambda x: x + 
x.min()).sort_index(),
-            pdf.groupby(pdf.b // 5)[["a"]].apply(lambda x: x + 
x.min()).sort_index(),
+            psdf.groupby(psdf.b // 5)[["a"]]
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby(pdf.b // 5)[["a"]]
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
             almost=True,
         )
         self.assert_eq(
-            psdf.groupby(psdf.b // 5)[["a"]].apply(len).sort_index(),
-            pdf.groupby(pdf.b // 5)[["a"]].apply(len).sort_index(),
+            psdf.groupby(psdf.b // 5)[["a"]].apply(len, 
include_groups=include_groups).sort_index(),
+            pdf.groupby(pdf.b // 5)[["a"]].apply(len, 
include_groups=include_groups).sort_index(),
             almost=True,
         )
         self.assert_eq(
-            psdf.a.rename().groupby(psdf.b).apply(lambda x: x + 
x.min()).sort_index(),
-            pdf.a.rename().groupby(pdf.b).apply(lambda x: x + 
x.min()).sort_index(),
+            psdf.a.rename()
+            .groupby(psdf.b)
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
+            pdf.a.rename()
+            .groupby(pdf.b)
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
         )
         self.assert_eq(
-            psdf.a.groupby(psdf.b.rename()).apply(lambda x: x + 
x.min()).sort_index(),
-            pdf.a.groupby(pdf.b.rename()).apply(lambda x: x + 
x.min()).sort_index(),
+            psdf.a.groupby(psdf.b.rename())
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
+            pdf.a.groupby(pdf.b.rename())
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
         )
         self.assert_eq(
-            psdf.a.rename().groupby(psdf.b.rename()).apply(lambda x: x + 
x.min()).sort_index(),
-            pdf.a.rename().groupby(pdf.b.rename()).apply(lambda x: x + 
x.min()).sort_index(),
+            psdf.a.rename()
+            .groupby(psdf.b.rename())
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
+            pdf.a.rename()
+            .groupby(pdf.b.rename())
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
         )
 
         with self.assertRaisesRegex(TypeError, "int object is not callable"):
-            psdf.groupby("b").apply(1)
+            psdf.groupby("b").apply(1, include_groups=include_groups)
+
+    def test_apply_with_multi_index_columns(self):
+        pdf = pd.DataFrame(
+            {"a": [1, 2, 3, 4, 5, 6], "b": [1, 1, 2, 3, 5, 8], "c": [1, 4, 9, 
16, 25, 36]},
+            columns=["a", "b", "c"],
+        )
+        psdf = ps.from_pandas(pdf)
 
         # multi-index columns
         columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("y", 
"c")])
         pdf.columns = columns
         psdf.columns = columns
 
+        if LooseVersion(pd.__version__) < "3.0.0":
+            for include_groups in [True, False]:
+                with self.subTest(include_groups=include_groups):
+                    self._check_apply_with_multi_index_columns(psdf, pdf, 
include_groups)
+        else:
+            self._check_apply_with_multi_index_columns(psdf, pdf, 
include_groups=False)
+            with self.assertRaises(ValueError):
+                psdf.groupby(("x", "b")).apply(lambda x: x + x.min(), 
include_groups=True)
+
+    def _check_apply_with_multi_index_columns(self, psdf, pdf, include_groups):
         self.assert_eq(
-            psdf.groupby(("x", "b")).apply(lambda x: 1).sort_index(),
-            pdf.groupby(("x", "b")).apply(lambda x: 1).sort_index(),
+            psdf.groupby(("x", "b")).apply(lambda x: 1, 
include_groups=include_groups).sort_index(),
+            pdf.groupby(("x", "b")).apply(lambda x: 1, 
include_groups=include_groups).sort_index(),
         )
         self.assert_eq(
-            psdf.groupby([("x", "a"), ("x", "b")]).apply(lambda x: x + 
x.min()).sort_index(),
-            pdf.groupby([("x", "a"), ("x", "b")]).apply(lambda x: x + 
x.min()).sort_index(),
+            psdf.groupby([("x", "a"), ("x", "b")])
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby([("x", "a"), ("x", "b")])
+            .apply(lambda x: x + x.min(), include_groups=include_groups)
+            .sort_index(),
         )
         self.assert_eq(
-            psdf.groupby(("x", "b")).apply(len).sort_index(),
-            pdf.groupby(("x", "b")).apply(len).sort_index(),
+            psdf.groupby(("x", "b")).apply(len, 
include_groups=include_groups).sort_index(),
+            pdf.groupby(("x", "b")).apply(len, 
include_groups=include_groups).sort_index(),
         )
         self.assert_eq(
-            psdf.groupby([("x", "a"), ("x", "b")]).apply(len).sort_index(),
-            pdf.groupby([("x", "a"), ("x", "b")]).apply(len).sort_index(),
+            psdf.groupby([("x", "a"), ("x", "b")])
+            .apply(len, include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby([("x", "a"), ("x", "b")])
+            .apply(len, include_groups=include_groups)
+            .sort_index(),
         )
 
     def test_apply_without_shortcut(self):
         with option_context("compute.shortcut_limit", 0):
             self.test_apply()
 
+    def test_apply_with_multi_index_columns_without_shortcut(self):
+        with option_context("compute.shortcut_limit", 0):
+            self.test_apply_with_multi_index_columns()
+
     def test_apply_with_type_hint(self):
         pdf = pd.DataFrame(
             {"a": [1, 2, 3, 4, 5, 6], "b": [1, 1, 2, 3, 5, 8], "c": [1, 4, 9, 
16, 25, 36]},
@@ -132,27 +213,61 @@ class GroupbyApplyFuncMixin:
         )
         psdf = ps.from_pandas(pdf)
 
-        def add_max1(x) -> ps.DataFrame[int, int, int]:
-            return x + x.min()
+        if LooseVersion(pd.__version__) < "3.0.0":
+            for include_groups in [True, False]:
+                with self.subTest(include_groups=include_groups):
+                    self._check_apply_with_type_hint(psdf, pdf, include_groups)
+        else:
+            self._check_apply_with_type_hint(psdf, pdf, include_groups=False)
+
+    def _check_apply_with_type_hint(self, psdf, pdf, include_groups):
+        if include_groups:
+
+            def add_max1(x) -> ps.DataFrame[int, int, int]:
+                return x + x.min()
+
+        else:
+
+            def add_max1(x) -> ps.DataFrame[int, int]:
+                return x + x.min()
 
         # Type hints set the default column names, and we use default index for
         # pandas API on Spark. Here we ignore both diff.
-        actual = psdf.groupby("b").apply(add_max1).sort_index()
-        expected = pdf.groupby("b").apply(add_max1).sort_index()
-        self.assert_eq(sorted(actual["c0"].to_numpy()), 
sorted(expected["a"].to_numpy()))
-        self.assert_eq(sorted(actual["c1"].to_numpy()), 
sorted(expected["b"].to_numpy()))
-        self.assert_eq(sorted(actual["c2"].to_numpy()), 
sorted(expected["c"].to_numpy()))
-
-        def add_max2(
-            x,
-        ) -> ps.DataFrame[slice("a", int), slice("b", int), slice("c", int)]:
-            return x + x.min()
-
-        actual = psdf.groupby("b").apply(add_max2).sort_index()
-        expected = pdf.groupby("b").apply(add_max2).sort_index()
-        self.assert_eq(sorted(actual["a"].to_numpy()), 
sorted(expected["a"].to_numpy()))
-        self.assert_eq(sorted(actual["c"].to_numpy()), 
sorted(expected["c"].to_numpy()))
-        self.assert_eq(sorted(actual["c"].to_numpy()), 
sorted(expected["c"].to_numpy()))
+        actual = psdf.groupby("b").apply(add_max1, 
include_groups=include_groups).sort_index()
+        expected = pdf.groupby("b").apply(add_max1, 
include_groups=include_groups).sort_index()
+
+        if include_groups:
+            self.assert_eq(sorted(actual["c0"].to_numpy()), 
sorted(expected["a"].to_numpy()))
+            self.assert_eq(sorted(actual["c1"].to_numpy()), 
sorted(expected["b"].to_numpy()))
+            self.assert_eq(sorted(actual["c2"].to_numpy()), 
sorted(expected["c"].to_numpy()))
+        else:
+            self.assert_eq(sorted(actual["c0"].to_numpy()), 
sorted(expected["a"].to_numpy()))
+            self.assert_eq(sorted(actual["c1"].to_numpy()), 
sorted(expected["c"].to_numpy()))
+
+        if include_groups:
+
+            def add_max2(
+                x,
+            ) -> ps.DataFrame[slice("a", int), slice("b", int), slice("c", 
int)]:
+                return x + x.min()
+
+        else:
+
+            def add_max2(
+                x,
+            ) -> ps.DataFrame[slice("a", int), slice("c", int)]:
+                return x + x.min()
+
+        actual = psdf.groupby("b").apply(add_max2, 
include_groups=include_groups).sort_index()
+        expected = pdf.groupby("b").apply(add_max2, 
include_groups=include_groups).sort_index()
+
+        if include_groups:
+            self.assert_eq(sorted(actual["a"].to_numpy()), 
sorted(expected["a"].to_numpy()))
+            self.assert_eq(sorted(actual["b"].to_numpy()), 
sorted(expected["b"].to_numpy()))
+            self.assert_eq(sorted(actual["c"].to_numpy()), 
sorted(expected["c"].to_numpy()))
+        else:
+            self.assert_eq(sorted(actual["a"].to_numpy()), 
sorted(expected["a"].to_numpy()))
+            self.assert_eq(sorted(actual["c"].to_numpy()), 
sorted(expected["c"].to_numpy()))
 
     def test_apply_negative(self):
         def func(_) -> ps.Series[int]:
@@ -245,18 +360,38 @@ class GroupbyApplyFuncMixin:
         )
         psdf = ps.from_pandas(pdf)
 
+        if LooseVersion(pd.__version__) < "3.0.0":
+            for include_groups in [True, False]:
+                with self.subTest(include_groups=include_groups):
+                    self._check_apply_with_side_effect(psdf, pdf, 
include_groups)
+        else:
+            self._check_apply_with_side_effect(psdf, pdf, include_groups=False)
+
+    def _check_apply_with_side_effect(self, psdf, pdf, include_groups):
         acc = ps.utils.default_session().sparkContext.accumulator(0)
 
-        def sum_with_acc_frame(x) -> ps.DataFrame[np.float64, np.float64]:
-            nonlocal acc
-            acc += 1
-            return np.sum(x)
+        if include_groups:
+
+            def sum_with_acc_frame(x) -> ps.DataFrame[np.float64, np.float64]:
+                nonlocal acc
+                acc += 1
+                return np.sum(x)
 
-        actual = psdf.groupby("d").apply(sum_with_acc_frame)
-        actual.columns = ["d", "v"]
+        else:
+
+            def sum_with_acc_frame(x) -> ps.DataFrame[np.float64]:
+                nonlocal acc
+                acc += 1
+                return np.sum(x)
+
+        actual = psdf.groupby("d").apply(sum_with_acc_frame, 
include_groups=include_groups)
+        actual.columns = ["d", "v"] if include_groups else ["v"]
         self.assert_eq(
             actual._to_pandas().sort_index(),
-            pdf.groupby("d").apply(sum).sort_index().reset_index(drop=True),
+            pdf.groupby("d")
+            .apply(sum, include_groups=include_groups)
+            .sort_index()
+            .reset_index(drop=True),
         )
         self.assert_eq(acc.value, 2)
 
@@ -266,8 +401,14 @@ class GroupbyApplyFuncMixin:
             return np.sum(x)
 
         self.assert_eq(
-            
psdf.groupby("d")["v"].apply(sum_with_acc_series)._to_pandas().sort_index(),
-            
pdf.groupby("d")["v"].apply(sum).sort_index().reset_index(drop=True),
+            psdf.groupby("d")["v"]
+            .apply(sum_with_acc_series, include_groups=include_groups)
+            ._to_pandas()
+            .sort_index(),
+            pdf.groupby("d")["v"]
+            .apply(sum, include_groups=include_groups)
+            .sort_index()
+            .reset_index(drop=True),
         )
         self.assert_eq(acc.value, 4)
 
@@ -279,43 +420,101 @@ class GroupbyApplyFuncMixin:
         )
         psdf = ps.from_pandas(pdf)
 
+        if LooseVersion(pd.__version__) < "3.0.0":
+            for include_groups in [True, False]:
+                with self.subTest(include_groups=include_groups):
+                    self._check_apply_return_series(psdf, pdf, include_groups)
+        else:
+            self._check_apply_return_series(psdf, pdf, include_groups=False)
+            with self.assertRaises(ValueError):
+                psdf.groupby("b").apply(lambda x: x + x.iloc[0], 
include_groups=True)
+
+    def _check_apply_return_series(self, psdf, pdf, include_groups):
         self.assert_eq(
-            psdf.groupby("b").apply(lambda x: x.iloc[0]).sort_index(),
-            pdf.groupby("b").apply(lambda x: x.iloc[0]).sort_index(),
+            psdf.groupby("b")
+            .apply(lambda x: x.iloc[0], include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby("b").apply(lambda x: x.iloc[0], 
include_groups=include_groups).sort_index(),
         )
         self.assert_eq(
-            psdf.groupby("b").apply(lambda x: x["a"]).sort_index(),
-            pdf.groupby("b").apply(lambda x: x["a"]).sort_index(),
+            psdf.groupby("b").apply(lambda x: x["a"], 
include_groups=include_groups).sort_index(),
+            pdf.groupby("b").apply(lambda x: x["a"], 
include_groups=include_groups).sort_index(),
         )
         self.assert_eq(
-            psdf.groupby(["b", "c"]).apply(lambda x: x.iloc[0]).sort_index(),
-            pdf.groupby(["b", "c"]).apply(lambda x: x.iloc[0]).sort_index(),
+            psdf.groupby(["b", "c"])
+            .apply(lambda x: x.iloc[0], include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby(["b", "c"])
+            .apply(lambda x: x.iloc[0], include_groups=include_groups)
+            .sort_index(),
         )
         self.assert_eq(
-            psdf.groupby(["b", "c"]).apply(lambda x: x["a"]).sort_index(),
-            pdf.groupby(["b", "c"]).apply(lambda x: x["a"]).sort_index(),
+            psdf.groupby(["b", "c"])
+            .apply(lambda x: x["a"], include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby(["b", "c"])
+            .apply(lambda x: x["a"], include_groups=include_groups)
+            .sort_index(),
         )
 
+    def test_apply_return_series_with_multi_index_columns(self):
+        # SPARK-36907: Fix DataFrameGroupBy.apply without shortcut.
+        pdf = pd.DataFrame(
+            {"a": [1, 2, 3, 4, 5, 6], "b": [1, 1, 2, 3, 5, 8], "c": [1, 4, 9, 
16, 25, 36]},
+            columns=["a", "b", "c"],
+        )
+        psdf = ps.from_pandas(pdf)
+
         # multi-index columns
         columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("y", 
"c")])
         pdf.columns = columns
         psdf.columns = columns
 
+        if LooseVersion(pd.__version__) < "3.0.0":
+            for include_groups in [True, False]:
+                with self.subTest(include_groups=include_groups):
+                    self._check_apply_return_series_with_multi_index_columns(
+                        psdf, pdf, include_groups
+                    )
+        else:
+            self._check_apply_return_series_with_multi_index_columns(
+                psdf, pdf, include_groups=False
+            )
+            with self.assertRaises(ValueError):
+                psdf.groupby(("x", "b")).apply(lambda x: x + x.iloc[0], 
include_groups=True)
+
+    def _check_apply_return_series_with_multi_index_columns(self, psdf, pdf, 
include_groups):
         self.assert_eq(
-            psdf.groupby(("x", "b")).apply(lambda x: x.iloc[0]).sort_index(),
-            pdf.groupby(("x", "b")).apply(lambda x: x.iloc[0]).sort_index(),
+            psdf.groupby(("x", "b"))
+            .apply(lambda x: x.iloc[0], include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby(("x", "b"))
+            .apply(lambda x: x.iloc[0], include_groups=include_groups)
+            .sort_index(),
         )
         self.assert_eq(
-            psdf.groupby(("x", "b")).apply(lambda x: x[("x", 
"a")]).sort_index(),
-            pdf.groupby(("x", "b")).apply(lambda x: x[("x", 
"a")]).sort_index(),
+            psdf.groupby(("x", "b"))
+            .apply(lambda x: x[("x", "a")], include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby(("x", "b"))
+            .apply(lambda x: x[("x", "a")], include_groups=include_groups)
+            .sort_index(),
         )
         self.assert_eq(
-            psdf.groupby([("x", "b"), ("y", "c")]).apply(lambda x: 
x.iloc[0]).sort_index(),
-            pdf.groupby([("x", "b"), ("y", "c")]).apply(lambda x: 
x.iloc[0]).sort_index(),
+            psdf.groupby([("x", "b"), ("y", "c")])
+            .apply(lambda x: x.iloc[0], include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby([("x", "b"), ("y", "c")])
+            .apply(lambda x: x.iloc[0], include_groups=include_groups)
+            .sort_index(),
         )
         self.assert_eq(
-            psdf.groupby([("x", "b"), ("y", "c")]).apply(lambda x: x[("x", 
"a")]).sort_index(),
-            pdf.groupby([("x", "b"), ("y", "c")]).apply(lambda x: x[("x", 
"a")]).sort_index(),
+            psdf.groupby([("x", "b"), ("y", "c")])
+            .apply(lambda x: x[("x", "a")], include_groups=include_groups)
+            .sort_index(),
+            pdf.groupby([("x", "b"), ("y", "c")])
+            .apply(lambda x: x[("x", "a")], include_groups=include_groups)
+            .sort_index(),
         )
 
     def test_apply_return_series_without_shortcut(self):
@@ -323,6 +522,10 @@ class GroupbyApplyFuncMixin:
         with ps.option_context("compute.shortcut_limit", 2):
             self.test_apply_return_series()
 
+    def 
test_apply_return_series_with_multi_index_columns_without_shortcut(self):
+        with ps.option_context("compute.shortcut_limit", 2):
+            self.test_apply_return_series_with_multi_index_columns()
+
     def test_apply_explicitly_infer(self):
         # SPARK-39317
         from pyspark.pandas.utils import SPARK_CONF_ARROW_ENABLED
diff --git a/python/pyspark/pandas/tests/test_categorical.py 
b/python/pyspark/pandas/tests/test_categorical.py
index 04b8c3eeae1e..7a78e02e0519 100644
--- a/python/pyspark/pandas/tests/test_categorical.py
+++ b/python/pyspark/pandas/tests/test_categorical.py
@@ -20,6 +20,7 @@ import pandas as pd
 from pandas.api.types import CategoricalDtype
 
 import pyspark.pandas as ps
+from pyspark.loose_version import LooseVersion
 from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
 
 
@@ -340,8 +341,15 @@ class CategoricalTestsMixin:
 
         pdf, psdf = self.df_pair
 
-        def identity(df) -> ps.DataFrame[zip(psdf.columns, psdf.dtypes)]:
-            return df
+        if LooseVersion(pd.__version__) < "3.0.0":
+
+            def identity(df) -> ps.DataFrame[zip(psdf.columns, psdf.dtypes)]:
+                return df
+
+        else:
+
+            def identity(df) -> ps.DataFrame[zip(psdf.columns[1:], 
psdf.dtypes[1:])]:
+                return df
 
         self.assert_eq(
             
psdf.groupby("a").apply(identity).sort_values(["b"]).reset_index(drop=True),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to