zhengruifeng commented on code in PR #52716:
URL: https://github.com/apache/spark/pull/52716#discussion_r2467769396
##########
python/pyspark/worker.py:
##########
@@ -2625,6 +2664,12 @@ def read_udfs(pickleSer, infile, eval_type):
ser = GroupPandasUDFSerializer(
timezone, safecheck, _assign_cols_by_name,
int_to_decimal_coercion_enabled
)
+ elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
+ from pyspark.sql.pandas.serializers import GroupPandasIterUDFSerializer
Review Comment:
let's put the import here
https://github.com/apache/spark/blob/57b4cd2ccc646990e22643388d512e997cd4299e/python/pyspark/worker.py#L54-L73
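For illustration, something like the sketch below (the sibling import names are placeholders, not copied from worker.py):
```python
# Hypothetical sketch: add the new serializer to the existing top-of-file
# import block of python/pyspark/worker.py instead of importing it inside
# read_udfs(). The sibling names listed here are illustrative only.
from pyspark.sql.pandas.serializers import (
    ArrowStreamPandasUDFSerializer,
    GroupPandasUDFSerializer,
    GroupPandasIterUDFSerializer,  # new: used for SQL_GROUPED_MAP_PANDAS_ITER_UDF
)
```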
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -214,22 +220,86 @@ def applyInPandas(
| 2| 2| 3.0|
+---+-----------+----+
+ The function can also take and return an iterator of `pandas.DataFrame` using type
+ hints.
+
+ >>> from typing import Iterator # doctest: +SKIP
+ >>> df = spark.createDataFrame(
+ ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+ ... ("id", "v")) # doctest: +SKIP
+ >>> def filter_func(
+ ... batches: Iterator[pd.DataFrame]
+ ... ) -> Iterator[pd.DataFrame]: # doctest: +SKIP
+ ... for batch in batches:
+ ... # Process and yield each batch independently
+ ... filtered = batch[batch['v'] > 2.0]
+ ... if not filtered.empty:
+ ... yield filtered[['v']]
+ >>> df.groupby("id").applyInPandas(
+ ... filter_func, schema="v double").show() # doctest: +SKIP
+ +----+
+ |   v|
+ +----+
+ | 3.0|
+ | 5.0|
+ |10.0|
+ +----+
+
+ Alternatively, the user can pass a function that takes two arguments.
+ In this case, the grouping key(s) will be passed as the first argument and the data will
+ be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy
+ data types. The data will still be passed in as an iterator of `pandas.DataFrame`.
+
+ >>> from typing import Iterator, Tuple, Any # doctest: +SKIP
+ >>> def transform_func(
+ ... key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+ ... ) -> Iterator[pd.DataFrame]: # doctest: +SKIP
+ ... for batch in batches:
+ ... # Yield transformed results for each batch
+ ... result = batch.assign(id=key[0], v_doubled=batch['v'] * 2)
+ ... yield result[['id', 'v_doubled']]
+ >>> df.groupby("id").applyInPandas(
+ ... transform_func, schema="id long, v_doubled double").show() # doctest: +SKIP
+ +---+---------+
+ | id|v_doubled|
+ +---+---------+
+ |  1|      2.0|
+ |  1|      4.0|
+ |  2|      6.0|
+ |  2|     10.0|
+ |  2|     20.0|
+ +---+---------+
+
Notes
-----
- This function requires a full shuffle. All the data of a group will be loaded
- into memory, so the user should be aware of the potential OOM risk if data is skewed
- and certain groups are too large to fit in memory.
+ This function requires a full shuffle. If using the `pandas.DataFrame` API, all data of a
+ group will be loaded into memory, so the user should be aware of the potential OOM risk if
+ data is skewed and certain groups are too large to fit in memory, and can use the
+ iterator of `pandas.DataFrame` API to mitigate this.
See Also
--------
pyspark.sql.functions.pandas_udf
"""
from pyspark.sql import GroupedData
from pyspark.sql.functions import pandas_udf, PandasUDFType
+ from pyspark.sql.pandas.typehints import infer_group_pandas_eval_type_from_func
+ from pyspark.sql.pandas.functions import PythonEvalType
+ import warnings
Review Comment:
no need to re-import `PythonEvalType` and `warnings` here
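For example, something along these lines (assuming, as the comment implies, that `PythonEvalType` and `warnings` are already available at module level in group_ops.py):
```python
# Sketch of the suggested cleanup inside applyInPandas(): keep only the
# imports that are genuinely function-local; `PythonEvalType` and `warnings`
# are assumed to already be imported at the top of the module.
from pyspark.sql import GroupedData
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.pandas.typehints import infer_group_pandas_eval_type_from_func
```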
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1245,6 +1245,89 @@ def __repr__(self):
return "GroupPandasUDFSerializer"
+class GroupPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
+ """
+ Serializer for grouped map Pandas iterator UDFs.
+
+ Loads grouped data as pandas.Series and serializes results from iterator UDFs.
+ Flattens the (dataframes_generator, arrow_type) tuple by iterating over the generator.
+ """
+
+ def __init__(
+ self,
+ timezone,
+ safecheck,
+ assign_cols_by_name,
+ int_to_decimal_coercion_enabled,
+ ):
+ super(GroupPandasIterUDFSerializer, self).__init__(
+ timezone=timezone,
+ safecheck=safecheck,
+ assign_cols_by_name=assign_cols_by_name,
+ df_for_struct=False,
+ struct_in_pandas="dict",
+ ndarray_as_list=False,
+ arrow_cast=True,
+ input_types=None,
+ int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+ )
+
+ def load_stream(self, stream):
+ """
+ Deserialize Grouped ArrowRecordBatches and yield a generator of pandas.Series lists
+ (one list per batch), allowing the iterator UDF to process data batch-by-batch.
+ """
+ import pyarrow as pa
+
+ dataframes_in_group = None
+
+ while dataframes_in_group is None or dataframes_in_group > 0:
+ dataframes_in_group = read_int(stream)
+
+ if dataframes_in_group == 1:
+ # Read all Arrow batches for this group first (must read from stream synchronously)
+ batches = list(ArrowStreamSerializer.load_stream(self, stream))
Review Comment:
I think we cannot load all batches here. The iterator API is designed to avoid
loading all batches within a group, so that it can mitigate OOM.
You can refer to
https://github.com/apache/spark/blob/7bd18e3852f1a2160fcd0838f7d9937ea34926b4/python/pyspark/sql/pandas/serializers.py#L1136-L1140
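For illustration, a minimal sketch of the lazy per-group streaming this is asking for; the plain `to_pandas()` conversion is only a stand-in for the serializer's real Arrow-to-pandas path (timezone, safecheck, and decimal coercion handling omitted):
```python
# Sketch of a lazily-streamed GroupPandasIterUDFSerializer.load_stream.
# Batches are converted one at a time as the UDF pulls them, instead of being
# materialized up front with list(...).
def load_stream(self, stream):
    import pyarrow as pa

    def batches_to_series(batch_iter):
        # Convert a single Arrow batch only when the iterator UDF asks for it.
        for batch in batch_iter:
            pdf = pa.Table.from_batches([batch]).to_pandas()
            yield [pdf[col] for col in pdf.columns]

    dataframes_in_group = None
    while dataframes_in_group is None or dataframes_in_group > 0:
        dataframes_in_group = read_int(stream)
        if dataframes_in_group == 1:
            # Hand the UDF an un-materialized iterator; only the batch
            # currently being processed is held in memory.
            yield batches_to_series(ArrowStreamSerializer.load_stream(self, stream))
        elif dataframes_in_group != 0:
            raise ValueError(
                "Invalid number of pandas.DataFrames in group {0}".format(
                    dataframes_in_group
                )
            )
```
One caveat with this shape: each group's inner iterator has to be fully consumed before the next group header is read from the stream, which the UDF runner is expected to guarantee.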