This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d7b5144948e [SPARK-55162][PYTHON] Extract transformers from ArrowStreamUDFSerializer
3d7b5144948e is described below

commit 3d7b5144948e25baf6ae5c75193ff3fa3f55f677
Author: Yicong-Huang <[email protected]>
AuthorDate: Mon Jan 26 09:37:37 2026 +0800

    [SPARK-55162][PYTHON] Extract transformers from ArrowStreamUDFSerializer
    
    ### What changes were proposed in this pull request?
    
    Extract struct transformation logic from `ArrowStreamUDFSerializer` into a new `ArrowBatchTransformer` class in `conversion.py`:
    
    - `flatten_struct(batch)`: Flatten a struct column into separate columns
    - `wrap_struct(batch)`: Wrap columns into a single struct column
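    
    A minimal round-trip sketch of the two transformers (illustrative batch contents; assumes `pyarrow` is installed):
    
    ```python
    import pyarrow as pa
    
    from pyspark.sql.conversion import ArrowBatchTransformer
    
    # An illustrative two-column batch.
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array(["a", "b"])], names=["x", "y"]
    )
    
    # wrap_struct packs all columns into a single struct column named "_0";
    # flatten_struct undoes it, restoring the original columns.
    wrapped = ArrowBatchTransformer.wrap_struct(batch)
    assert wrapped.schema.names == ["_0"]
    restored = ArrowBatchTransformer.flatten_struct(wrapped)
    assert restored.equals(batch)
    ```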
    
    ### Why are the changes needed?
    
    This is part of [SPARK-55159](https://issues.apache.org/jira/browse/SPARK-55159) to improve the composability of Arrow serializers by separating data transformation from serialization.
    
    Benefits:
    - Clear separation of concerns (serialization vs transformation)
    - Transformers are reusable and testable in isolation
    - Easier to understand data flow as a pipeline
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added unit tests for `ArrowBatchTransformer` in `test_conversion.py`
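    
    The new suite can be run locally with Spark's Python test runner, e.g. `python/run-tests --testnames 'pyspark.sql.tests.test_conversion'` (command form assumed from the standard Spark dev setup).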
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #53946 from Yicong-Huang/SPARK-55162/refactor/extract-transformers-from-arrow-stream-udf-serializer.
    
    Authored-by: Yicong-Huang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/conversion.py            | 36 +++++++++++++
 python/pyspark/sql/pandas/serializers.py    | 42 +++++----------
 python/pyspark/sql/tests/test_conversion.py | 80 +++++++++++++++++++++++++++++
 3 files changed, 128 insertions(+), 30 deletions(-)

diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index 0a6478d49431..81ceea857e19 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -56,6 +56,42 @@ if TYPE_CHECKING:
     import pandas as pd
 
 
+class ArrowBatchTransformer:
+    """
+    Pure functions that transform RecordBatch -> RecordBatch.
+    They should have no side effects (no I/O, no writing to streams).
+    """
+
+    @staticmethod
+    def flatten_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch":
+        """
+        Flatten a single struct column into a RecordBatch.
+
+        Used by: ArrowStreamUDFSerializer.load_stream
+        """
+        import pyarrow as pa
+
+        struct = batch.column(0)
+        return pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))
+
+    @staticmethod
+    def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch":
+        """
+        Wrap a RecordBatch's columns into a single struct column.
+
+        Used by: ArrowStreamUDFSerializer.dump_stream
+        """
+        import pyarrow as pa
+
+        if batch.num_columns == 0:
+            # When batch has no column, it should still create
+            # an empty batch with the number of rows set.
+            struct = pa.array([{}] * batch.num_rows)
+        else:
+            struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema)))
+        return pa.RecordBatch.from_arrays([struct], ["_0"])
+
+
 class LocalDataToArrowConversion:
     """
     Conversion from local data (except pandas DataFrame and numpy ndarray) to Arrow.
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index ffdd6c9901ea..703602e33d16 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -36,6 +36,7 @@ from pyspark.sql.conversion import (
     LocalDataToArrowConversion,
     ArrowTableToRowsConversion,
     ArrowArrayToPandasConversion,
+    ArrowBatchTransformer,
 )
 from pyspark.sql.pandas.types import (
     from_arrow_type,
@@ -149,44 +150,25 @@ class ArrowStreamUDFSerializer(ArrowStreamSerializer):
         """
         Flatten the struct into Arrow's record batches.
         """
-        import pyarrow as pa
-
         batches = super().load_stream(stream)
-        for batch in batches:
-            struct = batch.column(0)
-            yield [pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))]
+        flattened = map(ArrowBatchTransformer.flatten_struct, batches)
+        return map(lambda b: [b], flattened)
 
     def dump_stream(self, iterator, stream):
         """
-        Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
-        This should be sent after creating the first record batch so in case of an error, it can
-        be sent back to the JVM before the Arrow stream starts.
         """
-        import pyarrow as pa
-
-        def wrap_and_init_stream():
-            should_write_start_length = True
-            for batch, _ in iterator:
-                assert isinstance(batch, pa.RecordBatch)
-
-                # Wrap the root struct
-                if batch.num_columns == 0:
-                    # When batch has no column, it should still create
-                    # an empty batch with the number of rows set.
-                    struct = pa.array([{}] * batch.num_rows)
-                else:
-                    struct = pa.StructArray.from_arrays(
-                        batch.columns, fields=pa.struct(list(batch.schema))
-                    )
-                batch = pa.RecordBatch.from_arrays([struct], ["_0"])
+        import itertools
 
-                # Write the first record batch with initialization.
-                if should_write_start_length:
-                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
-                    should_write_start_length = False
-                yield batch
+        first = next(iterator, None)
+        if first is None:
+            return
 
-        return super().dump_stream(wrap_and_init_stream(), stream)
+        write_int(SpecialLengths.START_ARROW_STREAM, stream)
+        batches = map(
+            lambda x: ArrowBatchTransformer.wrap_struct(x[0]), itertools.chain([first], iterator)
+        )
+        return super().dump_stream(batches, stream)
 
 
 class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer):
diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py
index 9773b2154c63..c3fa1fd19304 100644
--- a/python/pyspark/sql/tests/test_conversion.py
+++ b/python/pyspark/sql/tests/test_conversion.py
@@ -23,6 +23,7 @@ from pyspark.sql.conversion import (
     ArrowTableToRowsConversion,
     LocalDataToArrowConversion,
     ArrowTimestampConversion,
+    ArrowBatchTransformer,
 )
 from pyspark.sql.types import (
     ArrayType,
@@ -64,6 +65,85 @@ class Score:
         return self.score == other.score
 
 
+@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
+class ArrowBatchTransformerTests(unittest.TestCase):
+    def test_flatten_struct_basic(self):
+        """Test flattening a struct column into separate columns."""
+        import pyarrow as pa
+
+        struct_array = pa.StructArray.from_arrays(
+            [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
+            names=["x", "y"],
+        )
+        batch = pa.RecordBatch.from_arrays([struct_array], ["_0"])
+
+        flattened = ArrowBatchTransformer.flatten_struct(batch)
+
+        self.assertEqual(flattened.num_columns, 2)
+        self.assertEqual(flattened.column(0).to_pylist(), [1, 2, 3])
+        self.assertEqual(flattened.column(1).to_pylist(), ["a", "b", "c"])
+        self.assertEqual(flattened.schema.names, ["x", "y"])
+
+    def test_flatten_struct_empty_batch(self):
+        """Test flattening an empty batch."""
+        import pyarrow as pa
+
+        struct_type = pa.struct([("x", pa.int64()), ("y", pa.string())])
+        struct_array = pa.array([], type=struct_type)
+        batch = pa.RecordBatch.from_arrays([struct_array], ["_0"])
+
+        flattened = ArrowBatchTransformer.flatten_struct(batch)
+
+        self.assertEqual(flattened.num_rows, 0)
+        self.assertEqual(flattened.num_columns, 2)
+
+    def test_wrap_struct_basic(self):
+        """Test wrapping columns into a struct."""
+        import pyarrow as pa
+
+        batch = pa.RecordBatch.from_arrays(
+            [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
+            names=["x", "y"],
+        )
+
+        wrapped = ArrowBatchTransformer.wrap_struct(batch)
+
+        self.assertEqual(wrapped.num_columns, 1)
+        self.assertEqual(wrapped.schema.names, ["_0"])
+
+        struct_col = wrapped.column(0)
+        self.assertEqual(len(struct_col), 3)
+        self.assertEqual(struct_col.field(0).to_pylist(), [1, 2, 3])
+        self.assertEqual(struct_col.field(1).to_pylist(), ["a", "b", "c"])
+
+    def test_wrap_struct_empty_columns(self):
+        """Test wrapping a batch with no columns."""
+        import pyarrow as pa
+
+        schema = pa.schema([])
+        batch = pa.RecordBatch.from_arrays([], schema=schema)
+
+        wrapped = ArrowBatchTransformer.wrap_struct(batch)
+
+        self.assertEqual(wrapped.num_columns, 1)
+        self.assertEqual(wrapped.num_rows, 0)
+
+    def test_wrap_struct_empty_batch(self):
+        """Test wrapping an empty batch with schema."""
+        import pyarrow as pa
+
+        schema = pa.schema([("x", pa.int64()), ("y", pa.string())])
+        batch = pa.RecordBatch.from_arrays(
+            [pa.array([], type=pa.int64()), pa.array([], type=pa.string())],
+            schema=schema,
+        )
+
+        wrapped = ArrowBatchTransformer.wrap_struct(batch)
+
+        self.assertEqual(wrapped.num_rows, 0)
+        self.assertEqual(wrapped.num_columns, 1)
+
+
 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
 class ConversionTests(unittest.TestCase):
     def test_conversion(self):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
