zhengruifeng commented on code in PR #52716:
URL: https://github.com/apache/spark/pull/52716#discussion_r2471493597
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1245,6 +1245,88 @@ def __repr__(self):
return "GroupPandasUDFSerializer"
+class GroupPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
+ """
+ Serializer for grouped map Pandas iterator UDFs.
+
+ Loads grouped data as pandas.Series and serializes results from iterator UDFs.
+ Flattens the (dataframes_generator, arrow_type) tuple by iterating over the generator.
+ """
+
+ def __init__(
+ self,
+ timezone,
+ safecheck,
+ assign_cols_by_name,
+ int_to_decimal_coercion_enabled,
+ ):
+ super(GroupPandasIterUDFSerializer, self).__init__(
+ timezone=timezone,
+ safecheck=safecheck,
+ assign_cols_by_name=assign_cols_by_name,
+ df_for_struct=False,
+ struct_in_pandas="dict",
+ ndarray_as_list=False,
+ arrow_cast=True,
+ input_types=None,
+ int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+ )
+
+ def load_stream(self, stream):
+ """
+ Deserialize Grouped ArrowRecordBatches and yield a generator of pandas.Series lists
+ (one list per batch), allowing the iterator UDF to process data batch-by-batch.
+ """
+ import pyarrow as pa
+
+ dataframes_in_group = None
+
+ while dataframes_in_group is None or dataframes_in_group > 0:
+ dataframes_in_group = read_int(stream)
+
+ if dataframes_in_group == 1:
+ # Lazily read and convert Arrow batches one at a time from the stream
+ # This avoids loading all batches into memory for the group
+ batch_iter = ArrowStreamSerializer.load_stream(self, stream)
+
+ # Convert each Arrow batch to pandas Series list on-demand
+ series_batches_gen = (
+ [
+ self.arrow_to_pandas(c, i)
+ for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
+ ]
+ for batch in batch_iter
+ )
+
+ # Yield the generator for this group
+ # The generator must be fully consumed before the next group is processed
+ yield series_batches_gen
Review Comment:
let's keep in line with
https://github.com/apache/spark/blob/7bd18e3852f1a2160fcd0838f7d9937ea34926b4/python/pyspark/sql/pandas/serializers.py#L1123-L1146
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]