This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3f41adcc579e [SPARK-54186][PYTHON][TESTS] Fix doctests for `PandasCogroupedOps.applyInPandas`
3f41adcc579e is described below
commit 3f41adcc579e92749356ffcff4337ae443b094a9
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Nov 5 15:50:02 2025 +0800
[SPARK-54186][PYTHON][TESTS] Fix doctests for `PandasCogroupedOps.applyInPandas`
### What changes were proposed in this pull request?
Enable doctests for `PandasCogroupedOps.applyInPandas`
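The updated examples add explicit `.sort(...)` calls so the row order is deterministic, and they truncate floating-point output with `...`, which doctest matches loosely when its `ELLIPSIS` option is enabled (in the module itself the flag is presumably enabled globally via the `doctest.testmod(...)` option flags rather than per example). A minimal standalone illustration of the mechanism, not the patched code itself:

```python
import doctest

def unit_norm():
    """
    >>> 1 / 2 ** 0.5  # doctest: +ELLIPSIS
    0.7071067811865...
    """

if __name__ == "__main__":
    # Passes: with ELLIPSIS enabled, "..." absorbs the remaining digits.
    doctest.testmod()
```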
### Why are the changes needed?
To improve test coverage and to make sure the examples are correct.
### Does this PR introduce _any_ user-facing change?
Yes, doc-only changes.
### How was this patch tested?
CI.
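For a local run, the module's doctests can be exercised directly. A minimal sketch, assuming a PySpark dev environment with `pandas` and `pyarrow` installed and the repository's usual `if __name__ == "__main__": _test()` entry point at the bottom of the module:

```python
# Invoke the module's _test() helper, which starts a local SparkSession
# and runs doctest.testmod over the examples. Without pandas/pyarrow,
# _test() deletes the docstrings so the examples are skipped instead of
# failing on the missing imports (see the diff below).
import pyspark.sql.pandas.group_ops as group_ops

group_ops._test()
```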
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52885 from zhengruifeng/enable_apply_in_pandas.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/pandas/group_ops.py | 89 +++++++++++++++++++---------------
1 file changed, 50 insertions(+), 39 deletions(-)
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 1b4aa8798727..ddad0450ec89 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -73,20 +73,20 @@ class PandasGroupedOpsMixin:
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
- >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
+ >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
... def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
...
- >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP
+ >>> df.groupby("id").apply(normalize).sort("id", "v").show()
+---+-------------------+
| id| v|
+---+-------------------+
- | 1|-0.7071067811865475|
- | 1| 0.7071067811865475|
- | 2|-0.8320502943378437|
- | 2|-0.2773500981126146|
- | 2| 1.1094003924504583|
+ | 1|-0.7071067811865...|
+ | 1| 0.7071067811865...|
+ | 2|-0.8320502943378...|
+ | 2|-0.2773500981126...|
+ | 2| 1.1094003924504...|
+---+-------------------+
See Also
@@ -159,25 +159,26 @@ class PandasGroupedOpsMixin:
Examples
--------
- >>> import pandas as pd # doctest: +SKIP
- >>> from pyspark.sql.functions import ceil
+ >>> import pandas as pd
+ >>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
- ... ("id", "v")) # doctest: +SKIP
+ ... ("id", "v"))
>>> def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
...
>>> df.groupby("id").applyInPandas(
- ... normalize, schema="id long, v double").show() # doctest: +SKIP
+ ... normalize, schema="id long, v double"
+ ... ).sort("id", "v").show()
+---+-------------------+
| id| v|
+---+-------------------+
- | 1|-0.7071067811865475|
- | 1| 0.7071067811865475|
- | 2|-0.8320502943378437|
- | 2|-0.2773500981126146|
- | 2| 1.1094003924504583|
+ | 1|-0.7071067811865...|
+ | 1| 0.7071067811865...|
+ | 2|-0.8320502943378...|
+ | 2|-0.2773500981126...|
+ | 2| 1.1094003924504...|
+---+-------------------+
Alternatively, the user can pass a function that takes two arguments.
@@ -189,14 +190,15 @@ class PandasGroupedOpsMixin:
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
- ... ("id", "v")) # doctest: +SKIP
+ ... ("id", "v"))
>>> def mean_func(key, pdf):
... # key is a tuple of one numpy.int64, which is the value
... # of 'id' for the current group
... return pd.DataFrame([key + (pdf.v.mean(),)])
...
- >>> df.groupby('id').applyInPandas(
- ... mean_func, schema="id long, v double").show() # doctest: +SKIP
+ >>> df.groupby("id").applyInPandas(
+ ... mean_func, schema="id long, v double"
+ ... ).sort("id").show()
+---+---+
| id| v|
+---+---+
@@ -209,34 +211,36 @@ class PandasGroupedOpsMixin:
... # of 'id' and 'ceil(df.v / 2)' for the current group
... return pd.DataFrame([key + (pdf.v.sum(),)])
...
- >>> df.groupby(df.id, ceil(df.v / 2)).applyInPandas(
- ... sum_func, schema="id long, `ceil(v / 2)` long, v double").show() # doctest: +SKIP
+ >>> df.groupby(df.id, sf.ceil(df.v / 2)).applyInPandas(
+ ... sum_func, schema="id long, `ceil(v / 2)` long, v double"
+ ... ).sort("id", "v").show()
+---+-----------+----+
| id|ceil(v / 2)| v|
+---+-----------+----+
- | 2| 5|10.0|
| 1| 1| 3.0|
- | 2| 3| 5.0|
| 2| 2| 3.0|
+ | 2| 3| 5.0|
+ | 2| 5|10.0|
+---+-----------+----+
The function can also take and return an iterator of `pandas.DataFrame` using type hints.
- >>> from typing import Iterator # doctest: +SKIP
+ >>> from typing import Iterator
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
- ... ("id", "v")) # doctest: +SKIP
+ ... ("id", "v"))
>>> def filter_func(
... batches: Iterator[pd.DataFrame]
- ... ) -> Iterator[pd.DataFrame]: # doctest: +SKIP
+ ... ) -> Iterator[pd.DataFrame]:
... for batch in batches:
... # Process and yield each batch independently
... filtered = batch[batch['v'] > 2.0]
... if not filtered.empty:
... yield filtered[['v']]
>>> df.groupby("id").applyInPandas(
- ... filter_func, schema="v double").show() # doctest: +SKIP
+ ... filter_func, schema="v double"
+ ... ).sort("v").show()
+----+
| v|
+----+
@@ -250,25 +254,26 @@ class PandasGroupedOpsMixin:
be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy data types. The data will still be passed in as an iterator of `pandas.DataFrame`.
- >>> from typing import Iterator, Tuple, Any # doctest: +SKIP
+ >>> from typing import Iterator, Tuple, Any
>>> def transform_func(
... key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
- ... ) -> Iterator[pd.DataFrame]: # doctest: +SKIP
+ ... ) -> Iterator[pd.DataFrame]:
... for batch in batches:
... # Yield transformed results for each batch
... result = batch.assign(id=key[0], v_doubled=batch['v'] * 2)
... yield result[['id', 'v_doubled']]
>>> df.groupby("id").applyInPandas(
- ... transform_func, schema="id long, v_doubled double").show() # doctest: +SKIP
- +---+----------+
- | id|v_doubled |
- +---+----------+
- | 1| 2.0|
- | 1| 4.0|
- | 2| 6.0|
- | 2| 10.0|
- | 2| 20.0|
- +---+----------+
+ ... transform_func, schema="id long, v_doubled double"
+ ... ).sort("id", "v_doubled").show()
+ +---+---------+
+ | id|v_doubled|
+ +---+---------+
+ | 1| 2.0|
+ | 1| 4.0|
+ | 2| 6.0|
+ | 2| 10.0|
+ | 2| 20.0|
+ +---+---------+
Notes
-----
@@ -1187,8 +1192,14 @@ def _test() -> None:
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.pandas.group_ops
+ from pyspark.testing.utils import have_pandas, have_pyarrow
globs = pyspark.sql.pandas.group_ops.__dict__.copy()
+
+ if not have_pandas or not have_pyarrow:
+ del pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.apply.__doc__
+ del pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.applyInPandas.__doc__
+
spark = SparkSession.builder.master("local[4]").appName("sql.pandas.group tests").getOrCreate()
globs["spark"] = spark
(failure_count, test_count) = doctest.testmod(
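The hunk above gates the doctests on optional dependencies: when `pandas` or `pyarrow` is unavailable, the docstrings are deleted so `doctest.testmod` finds no examples to run. A minimal standalone sketch of the same pattern (the names here are illustrative, not the module's):

```python
import doctest

try:
    import pandas  # noqa: F401  (optional dependency)
    have_pandas = True
except ImportError:
    have_pandas = False

def example_with_doctest():
    """
    >>> 1 + 1
    2
    """

if not have_pandas:
    # With the docstring removed, doctest discovers zero examples for
    # this function, so the suite passes instead of erroring out.
    del example_with_doctest.__doc__

if __name__ == "__main__":
    doctest.testmod()
```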