This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f8a67849e398 [SPARK-55194][PYTHON] Remove GroupArrowUDFSerializer by moving flatten logic to mapper
f8a67849e398 is described below

commit f8a67849e3989c235201b839c6cfb3b27a06d2fe
Author: Yicong-Huang <[email protected]>
AuthorDate: Wed Jan 28 08:24:50 2026 +0800

    [SPARK-55194][PYTHON] Remove GroupArrowUDFSerializer by moving flatten logic to mapper
    
    ### What changes were proposed in this pull request?
    
    This PR removes `GroupArrowUDFSerializer` by moving the `flatten_struct` call from the serializer to the mapper in `worker.py`.
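
    Conceptually, the flatten step being moved takes each group's single struct-column `pyarrow.RecordBatch` and promotes the struct fields to top-level columns. A minimal standalone sketch of that transformation (illustrative only, not the actual `ArrowBatchTransformer.flatten_struct` implementation):

    ```python
    import pyarrow as pa

    def flatten_struct(batch: pa.RecordBatch) -> pa.RecordBatch:
        # The batch carries a single struct column; expose its fields as columns.
        struct = batch.column(0)
        return pa.RecordBatch.from_arrays(
            struct.flatten(),
            names=[struct.type.field(i).name for i in range(struct.type.num_fields)],
        )

    # A one-column batch wrapping a struct of (id, v)
    struct_arr = pa.StructArray.from_arrays(
        [pa.array([1, 2]), pa.array([1.0, 2.0])], names=["id", "v"]
    )
    batch = pa.RecordBatch.from_arrays([struct_arr], names=["_0"])
    print(flatten_struct(batch).schema.names)  # ['id', 'v']
    ```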
    
    ### Why are the changes needed?
    
    This is part of [SPARK-55159](https://issues.apache.org/jira/browse/SPARK-55159) Phase 3: simplifying the serializer hierarchy by moving transformations to other layers.
    
    `GroupArrowUDFSerializer` existed only to add a `flatten_struct` call in `load_stream`, inheriting everything else from `ArrowStreamGroupUDFSerializer`. This created an unnecessary inheritance layer.
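
    For context, the serializer being consolidated backs the grouped-map Arrow UDF path (e.g. `GroupedData.applyInArrow`). A small usage sketch that exercises this code path, assuming a running `SparkSession` (behavior is unchanged by this PR):

    ```python
    import pyarrow as pa
    import pyarrow.compute as pc
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

    def center(table: pa.Table) -> pa.Table:
        # Each group arrives as one pyarrow.Table; subtract the group mean of "v".
        v = table.column("v")
        return table.set_column(1, "v", pc.subtract(v, pc.mean(v)))

    df.groupBy("id").applyInArrow(center, schema="id long, v double").show()
    ```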
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #53974 from Yicong-Huang/SPARK-55194/refactor/remove-group-arrow-udf-serializer.
    
    Lead-authored-by: Yicong-Huang <[email protected]>
    Co-authored-by: Yicong Huang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/conversion.py         |  4 +--
 python/pyspark/sql/pandas/serializers.py | 43 +++++++++++++++++---------------
 python/pyspark/worker.py                 | 22 ++++++++++------
 3 files changed, 39 insertions(+), 30 deletions(-)

diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index 31c43ddd2797..99a3b3893410 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -69,8 +69,8 @@ class ArrowBatchTransformer:
 
         Used by:
             - ArrowStreamUDFSerializer.load_stream
-            - GroupArrowUDFSerializer.load_stream
-            - ArrowStreamArrowUDTFSerializer.load_stream
+            - SQL_GROUPED_MAP_ARROW_UDF mapper
+            - SQL_GROUPED_MAP_ARROW_ITER_UDF mapper
         """
         import pyarrow as pa
 
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index fc0496a5bcc3..563521a41bae 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -334,21 +334,40 @@ class ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
 
 class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
     """
-    Serializes pyarrow.RecordBatch data with Arrow streaming format.
+    Serializer for grouped Arrow UDFs.
+
+    Deserializes:
+        ``Iterator[Iterator[pa.RecordBatch]]`` - one inner iterator per group.
+        Each batch contains a single struct column.
+
+    Serializes:
+        ``Iterator[Tuple[Iterator[pa.RecordBatch], pa.DataType]]``
+        Each tuple contains iterator of flattened batches and their Arrow type.
 
-    Loads Arrow record batches as ``[[pyarrow.RecordBatch]]`` (one ``[pyarrow.RecordBatch]`` per
-    group) and serializes ``[([pyarrow.RecordBatch], arrow_type)]``.
+    Used by:
+        - SQL_GROUPED_MAP_ARROW_UDF
+        - SQL_GROUPED_MAP_ARROW_ITER_UDF
 
     Parameters
     ----------
     assign_cols_by_name : bool
-        If True, then DataFrames will get columns by name
+        If True, reorder serialized columns by schema name.
     """
 
     def __init__(self, assign_cols_by_name):
         super().__init__()
         self._assign_cols_by_name = assign_cols_by_name
 
+    def load_stream(self, stream):
+        """
+        Load grouped Arrow record batches from stream.
+        """
+        for (batches,) in self._load_group_dataframes(stream, num_dfs=1):
+            yield batches
+            # Make sure the batches are fully iterated before getting the next group
+            for _ in batches:
+                pass
+
     def dump_stream(self, iterator, stream):
         import pyarrow as pa
 
@@ -1056,22 +1075,6 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
         return "ArrowStreamPandasUDTFSerializer"
 
 
-class GroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
-    def load_stream(self, stream):
-        """
-        Flatten the struct into Arrow's record batches.
-        """
-        for (batches,) in self._load_group_dataframes(stream, num_dfs=1):
-            batch_iter = map(ArrowBatchTransformer.flatten_struct, batches)
-            yield batch_iter
-            # Make sure the batches are fully iterated before getting the next group
-            for _ in batch_iter:
-                pass
-
-    def __repr__(self):
-        return "GroupArrowUDFSerializer"
-
-
 # Serializer for SQL_GROUPED_AGG_ARROW_UDF, SQL_WINDOW_AGG_ARROW_UDF,
 # and SQL_GROUPED_AGG_ARROW_ITER_UDF
 class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index dfb2a2d12c6d..03bc1366e875 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -47,12 +47,16 @@ from pyspark.serializers import (
     CPickleSerializer,
     BatchedSerializer,
 )
-from pyspark.sql.conversion import LocalDataToArrowConversion, ArrowTableToRowsConversion
+from pyspark.sql.conversion import (
+    LocalDataToArrowConversion,
+    ArrowTableToRowsConversion,
+    ArrowBatchTransformer,
+)
 from pyspark.sql.functions import SkipRestOfInputTableException
 from pyspark.sql.pandas.serializers import (
     ArrowStreamPandasUDFSerializer,
     ArrowStreamPandasUDTFSerializer,
-    GroupArrowUDFSerializer,
+    ArrowStreamGroupUDFSerializer,
     GroupPandasUDFSerializer,
     CogroupArrowUDFSerializer,
     CogroupPandasUDFSerializer,
@@ -2743,7 +2747,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
             eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
             or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
         ):
-            ser = GroupArrowUDFSerializer(runner_conf.assign_cols_by_name)
+            ser = ArrowStreamGroupUDFSerializer(runner_conf.assign_cols_by_name)
         elif eval_type in (
             PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
             PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
@@ -3149,15 +3153,17 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
                 names=[batch.schema.names[o] for o in offsets],
             )
 
-        def mapper(a):
-            batch_iter = iter(a)
+        def mapper(batches):
+            # Flatten struct column into separate columns
+            flattened = map(ArrowBatchTransformer.flatten_struct, batches)
+
             # Need to materialize the first batch to get the keys
-            first_batch = next(batch_iter)
+            first_batch = next(flattened)
 
             keys = batch_from_offset(first_batch, parsed_offsets[0][0])
             value_batches = (
-                batch_from_offset(b, parsed_offsets[0][1])
-                for b in itertools.chain((first_batch,), batch_iter)
+                batch_from_offset(batch, parsed_offsets[0][1])
+                for batch in itertools.chain((first_batch,), flattened)
             )
 
             return f(keys, value_batches)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
