This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3dffd12ef594 [SPARK-54438][PYTHON] Consolidate
ArrowStreamAggArrowIterUDFSerializer into ArrowStreamAggArrowUDFSerializer
3dffd12ef594 is described below
commit 3dffd12ef594dc229c1a99bc4baa4dd8c5d8d34a
Author: Yicong-Huang <[email protected]>
AuthorDate: Tue Dec 9 14:12:12 2025 +0800
[SPARK-54438][PYTHON] Consolidate ArrowStreamAggArrowIterUDFSerializer into
ArrowStreamAggArrowUDFSerializer
### What changes were proposed in this pull request?
This PR consolidates `ArrowStreamAggArrowIterUDFSerializer` with
`ArrowStreamAggArrowUDFSerializer`.
### Why are the changes needed?
When the iterator API was added for Arrow grouped aggregation UDFs, a new
`ArrowStreamAggArrowIterUDFSerializer` class was created. However, this class
is nearly identical to `ArrowStreamAggArrowUDFSerializer`, differing only in
whether batches are processed lazily (iterator mode) or all at once (regular
mode). By consolidating these two classes, we reduce code duplication and
maintain consistency with similar serializer consolidations.
### Does this PR introduce _any_ user-facing change?
No, this is an internal refactoring that maintains backward compatibility.
The API behavior remains the same from the user's perspective.
### How was this patch tested?
Existing Tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53328 from
Yicong-Huang/SPARK-54438/refactor/consolidate-serde-for-sql-grouped-agg-arrow.
Authored-by: Yicong-Huang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/pandas/serializers.py | 56 +++-----------------------------
python/pyspark/worker.py | 44 ++++++++++++++++++++++---
2 files changed, 44 insertions(+), 56 deletions(-)
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index dc854cb1985d..aae1c34b47af 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1125,7 +1125,8 @@ class GroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
return "GroupArrowUDFSerializer"
-# Serializer for SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF
+# Serializer for SQL_GROUPED_AGG_ARROW_UDF, SQL_WINDOW_AGG_ARROW_UDF,
+# and SQL_GROUPED_AGG_ARROW_ITER_UDF
class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
def __init__(
self,
@@ -1143,55 +1144,8 @@ class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
def load_stream(self, stream):
"""
- Flatten the struct into Arrow's record batches.
- """
- import pyarrow as pa
-
- dataframes_in_group = None
-
- while dataframes_in_group is None or dataframes_in_group > 0:
- dataframes_in_group = read_int(stream)
-
- if dataframes_in_group == 1:
- batches = ArrowStreamSerializer.load_stream(self, stream)
- if hasattr(pa, "concat_batches"):
- yield pa.concat_batches(batches)
- else:
- # pyarrow.concat_batches not supported in old versions
- yield pa.RecordBatch.from_struct_array(
-                        pa.concat_arrays([b.to_struct_array() for b in batches])
- )
-
- elif dataframes_in_group != 0:
- raise PySparkValueError(
- errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
-                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
- )
-
- def __repr__(self):
- return "ArrowStreamAggArrowUDFSerializer"
-
-
-# Serializer for SQL_GROUPED_AGG_ARROW_ITER_UDF
-class ArrowStreamAggArrowIterUDFSerializer(ArrowStreamArrowUDFSerializer):
- def __init__(
- self,
- timezone,
- safecheck,
- assign_cols_by_name,
- arrow_cast,
- ):
- super().__init__(
- timezone=timezone,
- safecheck=safecheck,
- assign_cols_by_name=assign_cols_by_name,
- arrow_cast=arrow_cast,
- )
-
- def load_stream(self, stream):
- """
- Yield an iterator that produces one list of column arrays per batch.
-        Each group yields Iterator[List[pa.Array]], allowing UDF to process batches one by one
+ Yield an iterator that produces one tuple of column arrays per batch.
+        Each group yields Iterator[Tuple[pa.Array, ...]], allowing UDF to process batches one by one
without consuming all batches upfront.
"""
dataframes_in_group = None
@@ -1217,7 +1171,7 @@ class ArrowStreamAggArrowIterUDFSerializer(ArrowStreamArrowUDFSerializer):
)
def __repr__(self):
- return "ArrowStreamAggArrowIterUDFSerializer"
+ return "ArrowStreamAggArrowUDFSerializer"
# Serializer for SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 65dcbbbf23e6..8de39b973802 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -66,7 +66,6 @@ from pyspark.sql.pandas.serializers import (
ArrowStreamArrowUDFSerializer,
ArrowStreamAggPandasUDFSerializer,
ArrowStreamAggArrowUDFSerializer,
- ArrowStreamAggArrowIterUDFSerializer,
ArrowBatchUDFSerializer,
ArrowStreamUDTFSerializer,
ArrowStreamArrowUDTFSerializer,
@@ -2732,12 +2731,9 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
):
ser = GroupArrowUDFSerializer(runner_conf.assign_cols_by_name)
- elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF:
- ser = ArrowStreamAggArrowIterUDFSerializer(
-            runner_conf.timezone, True, runner_conf.assign_cols_by_name, True
- )
elif eval_type in (
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+ PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
):
ser = ArrowStreamAggArrowUDFSerializer(
@@ -3259,6 +3255,44 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
            batch_iter = (tuple(batch_columns[o] for o in arg_offsets) for batch_columns in a)
return f(batch_iter)
+ elif eval_type in (
+ PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+ PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
+ ):
+ import pyarrow as pa
+
+ # For SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF,
+ # convert iterator of batch columns to a concatenated RecordBatch
+ def mapper(a):
+ # a is Iterator[Tuple[pa.Array, ...]] - convert to RecordBatch
+ batches = []
+ for batch_columns in a:
+                # batch_columns is Tuple[pa.Array, ...] - convert to RecordBatch
+ batch = pa.RecordBatch.from_arrays(
+                    batch_columns, names=["_%d" % i for i in range(len(batch_columns))]
+ )
+ batches.append(batch)
+
+ # Concatenate all batches into one
+ if hasattr(pa, "concat_batches"):
+ concatenated_batch = pa.concat_batches(batches)
+ else:
+ # pyarrow.concat_batches not supported in old versions
+ concatenated_batch = pa.RecordBatch.from_struct_array(
+ pa.concat_arrays([b.to_struct_array() for b in batches])
+ )
+
+            # Extract series using offsets (concatenated_batch.columns[o] gives pa.Array)
+ result = tuple(
+                f(*[concatenated_batch.columns[o] for o in arg_offsets]) for arg_offsets, f in udfs
+ )
+            # In the special case of a single UDF this will return a single result rather
+            # than a tuple of results; this is the format that the JVM side expects.
+ if len(result) == 1:
+ return result[0]
+ else:
+ return result
+
else:
def mapper(a):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]