This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3f41adcc579e [SPARK-54186][PYTHON][TESTS] Fix doctests for `PandasCogroupedOps.applyInPandas`
3f41adcc579e is described below

commit 3f41adcc579e92749356ffcff4337ae443b094a9
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Nov 5 15:50:02 2025 +0800

    [SPARK-54186][PYTHON][TESTS] Fix doctests for `PandasCogroupedOps.applyInPandas`
    
    ### What changes were proposed in this pull request?
    Enable doctests for `PandasGroupedOpsMixin.apply` and `PandasGroupedOpsMixin.applyInPandas`
    
    ### Why are the changes needed?
    To improve test coverage and make sure the examples are correct.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, doc-only changes.
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #52885 from zhengruifeng/enable_apply_in_pandas.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/pandas/group_ops.py | 89 +++++++++++++++++++---------------
 1 file changed, 50 insertions(+), 39 deletions(-)

diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 1b4aa8798727..ddad0450ec89 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -73,20 +73,20 @@ class PandasGroupedOpsMixin:
         >>> df = spark.createDataFrame(
         ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
         ...     ("id", "v"))
-        >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)  # doctest: +SKIP
+        >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
         ... def normalize(pdf):
         ...     v = pdf.v
         ...     return pdf.assign(v=(v - v.mean()) / v.std())
         ...
-        >>> df.groupby("id").apply(normalize).show()  # doctest: +SKIP
+        >>> df.groupby("id").apply(normalize).sort("id", "v").show()
         +---+-------------------+
         | id|                  v|
         +---+-------------------+
-        |  1|-0.7071067811865475|
-        |  1| 0.7071067811865475|
-        |  2|-0.8320502943378437|
-        |  2|-0.2773500981126146|
-        |  2| 1.1094003924504583|
+        |  1|-0.7071067811865...|
+        |  1| 0.7071067811865...|
+        |  2|-0.8320502943378...|
+        |  2|-0.2773500981126...|
+        |  2| 1.1094003924504...|
         +---+-------------------+
 
         See Also
@@ -159,25 +159,26 @@ class PandasGroupedOpsMixin:
 
         Examples
         --------
-        >>> import pandas as pd  # doctest: +SKIP
-        >>> from pyspark.sql.functions import ceil
+        >>> import pandas as pd
+        >>> from pyspark.sql import functions as sf
         >>> df = spark.createDataFrame(
         ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
-        ...     ("id", "v"))  # doctest: +SKIP
+        ...     ("id", "v"))
         >>> def normalize(pdf):
         ...     v = pdf.v
         ...     return pdf.assign(v=(v - v.mean()) / v.std())
         ...
         >>> df.groupby("id").applyInPandas(
-        ...     normalize, schema="id long, v double").show()  # doctest: +SKIP
+        ...     normalize, schema="id long, v double"
+        ... ).sort("id", "v").show()
         +---+-------------------+
         | id|                  v|
         +---+-------------------+
-        |  1|-0.7071067811865475|
-        |  1| 0.7071067811865475|
-        |  2|-0.8320502943378437|
-        |  2|-0.2773500981126146|
-        |  2| 1.1094003924504583|
+        |  1|-0.7071067811865...|
+        |  1| 0.7071067811865...|
+        |  2|-0.8320502943378...|
+        |  2|-0.2773500981126...|
+        |  2| 1.1094003924504...|
         +---+-------------------+
 
         Alternatively, the user can pass a function that takes two arguments.
@@ -189,14 +190,15 @@ class PandasGroupedOpsMixin:
 
         >>> df = spark.createDataFrame(
         ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
-        ...     ("id", "v"))  # doctest: +SKIP
+        ...     ("id", "v"))
         >>> def mean_func(key, pdf):
         ...     # key is a tuple of one numpy.int64, which is the value
         ...     # of 'id' for the current group
         ...     return pd.DataFrame([key + (pdf.v.mean(),)])
         ...
-        >>> df.groupby('id').applyInPandas(
-        ...     mean_func, schema="id long, v double").show()  # doctest: +SKIP
+        >>> df.groupby("id").applyInPandas(
+        ...     mean_func, schema="id long, v double"
+        ... ).sort("id").show()
         +---+---+
         | id|  v|
         +---+---+
@@ -209,34 +211,36 @@ class PandasGroupedOpsMixin:
         ...     # of 'id' and 'ceil(df.v / 2)' for the current group
         ...     return pd.DataFrame([key + (pdf.v.sum(),)])
         ...
-        >>> df.groupby(df.id, ceil(df.v / 2)).applyInPandas(
-        ...     sum_func, schema="id long, `ceil(v / 2)` long, v double").show()  # doctest: +SKIP
+        >>> df.groupby(df.id, sf.ceil(df.v / 2)).applyInPandas(
+        ...     sum_func, schema="id long, `ceil(v / 2)` long, v double"
+        ... ).sort("id", "v").show()
         +---+-----------+----+
         | id|ceil(v / 2)|   v|
         +---+-----------+----+
-        |  2|          5|10.0|
         |  1|          1| 3.0|
-        |  2|          3| 5.0|
         |  2|          2| 3.0|
+        |  2|          3| 5.0|
+        |  2|          5|10.0|
         +---+-----------+----+
 
         The function can also take and return an iterator of `pandas.DataFrame` using type
         hints.
 
-        >>> from typing import Iterator  # doctest: +SKIP
+        >>> from typing import Iterator
         >>> df = spark.createDataFrame(
         ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
-        ...     ("id", "v"))  # doctest: +SKIP
+        ...     ("id", "v"))
         >>> def filter_func(
         ...     batches: Iterator[pd.DataFrame]
-        ... ) -> Iterator[pd.DataFrame]:  # doctest: +SKIP
+        ... ) -> Iterator[pd.DataFrame]:
         ...     for batch in batches:
         ...         # Process and yield each batch independently
         ...         filtered = batch[batch['v'] > 2.0]
         ...         if not filtered.empty:
         ...             yield filtered[['v']]
         >>> df.groupby("id").applyInPandas(
-        ...     filter_func, schema="v double").show()  # doctest: +SKIP
+        ...     filter_func, schema="v double"
+        ... ).sort("v").show()
         +----+
         |   v|
         +----+
@@ -250,25 +254,26 @@ class PandasGroupedOpsMixin:
         be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy
         data types. The data will still be passed in as an iterator of `pandas.DataFrame`.
 
-        >>> from typing import Iterator, Tuple, Any  # doctest: +SKIP
+        >>> from typing import Iterator, Tuple, Any
         >>> def transform_func(
         ...     key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
-        ... ) -> Iterator[pd.DataFrame]:  # doctest: +SKIP
+        ... ) -> Iterator[pd.DataFrame]:
         ...     for batch in batches:
         ...         # Yield transformed results for each batch
         ...         result = batch.assign(id=key[0], v_doubled=batch['v'] * 2)
         ...         yield result[['id', 'v_doubled']]
         >>> df.groupby("id").applyInPandas(
-        ...     transform_func, schema="id long, v_doubled double").show()  # doctest: +SKIP
-        +---+----------+
-        | id|v_doubled |
-        +---+----------+
-        |  1|       2.0|
-        |  1|       4.0|
-        |  2|       6.0|
-        |  2|      10.0|
-        |  2|      20.0|
-        +---+----------+
+        ...     transform_func, schema="id long, v_doubled double"
+        ... ).sort("id", "v_doubled").show()
+        +---+---------+
+        | id|v_doubled|
+        +---+---------+
+        |  1|      2.0|
+        |  1|      4.0|
+        |  2|      6.0|
+        |  2|     10.0|
+        |  2|     20.0|
+        +---+---------+
 
         Notes
         -----
@@ -1187,8 +1192,14 @@ def _test() -> None:
     import doctest
     from pyspark.sql import SparkSession
     import pyspark.sql.pandas.group_ops
+    from pyspark.testing.utils import have_pandas, have_pyarrow
 
     globs = pyspark.sql.pandas.group_ops.__dict__.copy()
+
+    if not have_pandas or not have_pyarrow:
+        del pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.apply.__doc__
+        del pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.applyInPandas.__doc__
+
     spark = SparkSession.builder.master("local[4]").appName("sql.pandas.group tests").getOrCreate()
     globs["spark"] = spark
     (failure_count, test_count) = doctest.testmod(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to