This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3d7b5144948e [SPARK-55162][PYTHON] Extract transformers from ArrowStreamUDFSerializer
3d7b5144948e is described below
commit 3d7b5144948e25baf6ae5c75193ff3fa3f55f677
Author: Yicong-Huang <[email protected]>
AuthorDate: Mon Jan 26 09:37:37 2026 +0800
[SPARK-55162][PYTHON] Extract transformers from ArrowStreamUDFSerializer
### What changes were proposed in this pull request?
Extract struct transformation logic from `ArrowStreamUDFSerializer` into a
new `ArrowBatchTransformer` class in `conversion.py` (a usage sketch follows the list):
- `flatten_struct(batch)`: Flatten a struct column into separate columns
- `wrap_struct(batch)`: Wrap columns into a single struct column
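For illustration, a minimal round-trip sketch of the two methods (the sample batch and variable names here are only illustrative, not part of the change):

```python
# Minimal sketch, assuming pyarrow is installed and the new class is importable.
import pyarrow as pa

from pyspark.sql.conversion import ArrowBatchTransformer

# A plain two-column batch.
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["x", "y"]
)

# Wrap the columns into a single struct column named "_0" ...
wrapped = ArrowBatchTransformer.wrap_struct(batch)
assert wrapped.schema.names == ["_0"]

# ... and flatten it back into separate columns.
flattened = ArrowBatchTransformer.flatten_struct(wrapped)
assert flattened.schema.names == ["x", "y"]
```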
### Why are the changes needed?
This is part of
[SPARK-55159](https://issues.apache.org/jira/browse/SPARK-55159) to improve the
composability of Arrow serializers by separating data transformation from
serialization.
Benefits:
- Clear separation of concerns (serialization vs transformation)
- Transformers are reusable and testable in isolation
- Easier to understand data flow as a pipeline (see the sketch below)
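As a rough illustration of that pipeline shape (the batch source below is a hypothetical stand-in; only `ArrowBatchTransformer.flatten_struct` comes from this change):

```python
# Sketch: a pure per-batch transformation composed with a batch source via map,
# keeping stream I/O out of the transformation step.
import pyarrow as pa

from pyspark.sql.conversion import ArrowBatchTransformer

def struct_batches():
    # Stand-in for batches read from an Arrow stream (one struct column per batch).
    data = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}])
    yield pa.RecordBatch.from_arrays([data], ["_0"])

flattened = map(ArrowBatchTransformer.flatten_struct, struct_batches())
for batch in flattened:
    print(batch.schema.names)  # ['x', 'y']
```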
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added unit tests for `ArrowBatchTransformer` in `test_conversion.py`
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53946 from Yicong-Huang/SPARK-55162/refactor/extract-transformers-from-arrow-stream-udf-serializer.
Authored-by: Yicong-Huang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/conversion.py | 36 +++++++++++++
python/pyspark/sql/pandas/serializers.py | 42 +++++----------
python/pyspark/sql/tests/test_conversion.py | 80 +++++++++++++++++++++++++++++
3 files changed, 128 insertions(+), 30 deletions(-)
diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index 0a6478d49431..81ceea857e19 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -56,6 +56,42 @@ if TYPE_CHECKING:
import pandas as pd
+class ArrowBatchTransformer:
+ """
+ Pure functions that transform RecordBatch -> RecordBatch.
+ They should have no side effects (no I/O, no writing to streams).
+ """
+
+ @staticmethod
+ def flatten_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch":
+ """
+ Flatten a single struct column into a RecordBatch.
+
+ Used by: ArrowStreamUDFSerializer.load_stream
+ """
+ import pyarrow as pa
+
+ struct = batch.column(0)
+ return pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))
+
+ @staticmethod
+ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch":
+ """
+ Wrap a RecordBatch's columns into a single struct column.
+
+ Used by: ArrowStreamUDFSerializer.dump_stream
+ """
+ import pyarrow as pa
+
+ if batch.num_columns == 0:
+ # When batch has no column, it should still create
+ # an empty batch with the number of rows set.
+ struct = pa.array([{}] * batch.num_rows)
+ else:
+ struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema)))
+ return pa.RecordBatch.from_arrays([struct], ["_0"])
+
+
class LocalDataToArrowConversion:
"""
Conversion from local data (except pandas DataFrame and numpy ndarray) to Arrow.
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index ffdd6c9901ea..703602e33d16 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -36,6 +36,7 @@ from pyspark.sql.conversion import (
LocalDataToArrowConversion,
ArrowTableToRowsConversion,
ArrowArrayToPandasConversion,
+ ArrowBatchTransformer,
)
from pyspark.sql.pandas.types import (
from_arrow_type,
@@ -149,44 +150,25 @@ class ArrowStreamUDFSerializer(ArrowStreamSerializer):
"""
Flatten the struct into Arrow's record batches.
"""
- import pyarrow as pa
-
batches = super().load_stream(stream)
- for batch in batches:
- struct = batch.column(0)
- yield [pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))]
+ flattened = map(ArrowBatchTransformer.flatten_struct, batches)
+ return map(lambda b: [b], flattened)
def dump_stream(self, iterator, stream):
"""
Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
- This should be sent after creating the first record batch so in case of an error, it can
- be sent back to the JVM before the Arrow stream starts.
"""
- import pyarrow as pa
-
- def wrap_and_init_stream():
- should_write_start_length = True
- for batch, _ in iterator:
- assert isinstance(batch, pa.RecordBatch)
-
- # Wrap the root struct
- if batch.num_columns == 0:
- # When batch has no column, it should still create
- # an empty batch with the number of rows set.
- struct = pa.array([{}] * batch.num_rows)
- else:
- struct = pa.StructArray.from_arrays(
- batch.columns, fields=pa.struct(list(batch.schema))
- )
- batch = pa.RecordBatch.from_arrays([struct], ["_0"])
+ import itertools
- # Write the first record batch with initialization.
- if should_write_start_length:
- write_int(SpecialLengths.START_ARROW_STREAM, stream)
- should_write_start_length = False
- yield batch
+ first = next(iterator, None)
+ if first is None:
+ return
- return super().dump_stream(wrap_and_init_stream(), stream)
+ write_int(SpecialLengths.START_ARROW_STREAM, stream)
+ batches = map(
+ lambda x: ArrowBatchTransformer.wrap_struct(x[0]), itertools.chain([first], iterator)
+ )
+ return super().dump_stream(batches, stream)
class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer):
diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py
index 9773b2154c63..c3fa1fd19304 100644
--- a/python/pyspark/sql/tests/test_conversion.py
+++ b/python/pyspark/sql/tests/test_conversion.py
@@ -23,6 +23,7 @@ from pyspark.sql.conversion import (
ArrowTableToRowsConversion,
LocalDataToArrowConversion,
ArrowTimestampConversion,
+ ArrowBatchTransformer,
)
from pyspark.sql.types import (
ArrayType,
@@ -64,6 +65,85 @@ class Score:
return self.score == other.score
+@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
+class ArrowBatchTransformerTests(unittest.TestCase):
+ def test_flatten_struct_basic(self):
+ """Test flattening a struct column into separate columns."""
+ import pyarrow as pa
+
+ struct_array = pa.StructArray.from_arrays(
+ [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
+ names=["x", "y"],
+ )
+ batch = pa.RecordBatch.from_arrays([struct_array], ["_0"])
+
+ flattened = ArrowBatchTransformer.flatten_struct(batch)
+
+ self.assertEqual(flattened.num_columns, 2)
+ self.assertEqual(flattened.column(0).to_pylist(), [1, 2, 3])
+ self.assertEqual(flattened.column(1).to_pylist(), ["a", "b", "c"])
+ self.assertEqual(flattened.schema.names, ["x", "y"])
+
+ def test_flatten_struct_empty_batch(self):
+ """Test flattening an empty batch."""
+ import pyarrow as pa
+
+ struct_type = pa.struct([("x", pa.int64()), ("y", pa.string())])
+ struct_array = pa.array([], type=struct_type)
+ batch = pa.RecordBatch.from_arrays([struct_array], ["_0"])
+
+ flattened = ArrowBatchTransformer.flatten_struct(batch)
+
+ self.assertEqual(flattened.num_rows, 0)
+ self.assertEqual(flattened.num_columns, 2)
+
+ def test_wrap_struct_basic(self):
+ """Test wrapping columns into a struct."""
+ import pyarrow as pa
+
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
+ names=["x", "y"],
+ )
+
+ wrapped = ArrowBatchTransformer.wrap_struct(batch)
+
+ self.assertEqual(wrapped.num_columns, 1)
+ self.assertEqual(wrapped.schema.names, ["_0"])
+
+ struct_col = wrapped.column(0)
+ self.assertEqual(len(struct_col), 3)
+ self.assertEqual(struct_col.field(0).to_pylist(), [1, 2, 3])
+ self.assertEqual(struct_col.field(1).to_pylist(), ["a", "b", "c"])
+
+ def test_wrap_struct_empty_columns(self):
+ """Test wrapping a batch with no columns."""
+ import pyarrow as pa
+
+ schema = pa.schema([])
+ batch = pa.RecordBatch.from_arrays([], schema=schema)
+
+ wrapped = ArrowBatchTransformer.wrap_struct(batch)
+
+ self.assertEqual(wrapped.num_columns, 1)
+ self.assertEqual(wrapped.num_rows, 0)
+
+ def test_wrap_struct_empty_batch(self):
+ """Test wrapping an empty batch with schema."""
+ import pyarrow as pa
+
+ schema = pa.schema([("x", pa.int64()), ("y", pa.string())])
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([], type=pa.int64()), pa.array([], type=pa.string())],
+ schema=schema,
+ )
+
+ wrapped = ArrowBatchTransformer.wrap_struct(batch)
+
+ self.assertEqual(wrapped.num_rows, 0)
+ self.assertEqual(wrapped.num_columns, 1)
+
+
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class ConversionTests(unittest.TestCase):
def test_conversion(self):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]