This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 154a2708e988 [SPARK-54531][PYTHON] Introduce
ArrowStreamAggPandasUDFSerializer
154a2708e988 is described below
commit 154a2708e98840ba4d69b5afdd549e695081e4ce
Author: Yicong-Huang <[email protected]>
AuthorDate: Thu Nov 27 16:44:38 2025 +0800
[SPARK-54531][PYTHON] Introduce ArrowStreamAggPandasUDFSerializer
### What changes were proposed in this pull request?
This PR separates `SQL_GROUPED_AGG_PANDAS_UDF` and
`SQL_WINDOW_AGG_PANDAS_UDF` into a dedicated serializer
`ArrowStreamAggPandasUDFSerializer`, aligning with the existing
`ArrowStreamAggArrowUDFSerializer` architecture.
### Why are the changes needed?
1. **Input/Output type differences**: Aggregation UDFs
(`SQL_GROUPED_AGG_PANDAS_UDF` and `SQL_WINDOW_AGG_PANDAS_UDF`) have different
input/output types compared to grouped map UDFs:
- Aggregation UDFs: Input is `pd.Series` (entire group/partition),
output is scalar
- Grouped map UDFs: Input is `(keys, vals)` where `vals` is
`pd.DataFrame`, output is `pd.DataFrame`
2. **Multi-UDF support**: Aggregation UDFs support multiple UDFs in a
single projection/aggregation, while grouped map UDFs do not.
### Does this PR introduce _any_ user-facing change?
No. This is an internal refactoring that does not change the public API or
behavior. The serialization logic remains functionally equivalent.
### How was this patch tested?
All existing tests continue to pass, and a new multi-UDF test
(`test_pandas_udf_window.py::WindowPandasUDFTestsMixin::test_multiple_udfs_in_single_projection`)
was added.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53239 from
Yicong-Huang/SPARK-54531/feat/introduce-ArrowStreamAggArrowUDFSerializer.
Authored-by: Yicong-Huang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/pandas/serializers.py | 58 ++++++++++++++++++++++
.../sql/tests/pandas/test_pandas_udf_window.py | 26 ++++++++++
python/pyspark/worker.py | 6 ++-
3 files changed, 89 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/pandas/serializers.py
b/python/pyspark/sql/pandas/serializers.py
index 12a1f3c288b4..f757ba4f696f 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1185,6 +1185,64 @@ class
ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
return "ArrowStreamAggArrowUDFSerializer"
+# Serializer for SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF
+class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
+ def __init__(
+ self,
+ timezone,
+ safecheck,
+ assign_cols_by_name,
+ int_to_decimal_coercion_enabled,
+ ):
+ super(ArrowStreamAggPandasUDFSerializer, self).__init__(
+ timezone=timezone,
+ safecheck=safecheck,
+ assign_cols_by_name=False,
+ df_for_struct=False,
+ struct_in_pandas="dict",
+ ndarray_as_list=False,
+ arrow_cast=True,
+ input_types=None,
+ int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+ )
+ self._timezone = timezone
+ self._safecheck = safecheck
+ self._assign_cols_by_name = assign_cols_by_name
+
+ def load_stream(self, stream):
+ """
+ Deserialize Grouped ArrowRecordBatches and yield as a list of
pandas.Series.
+ """
+ import pyarrow as pa
+
+ dataframes_in_group = None
+
+ while dataframes_in_group is None or dataframes_in_group > 0:
+ dataframes_in_group = read_int(stream)
+
+ if dataframes_in_group == 1:
+ yield (
+ [
+ self.arrow_to_pandas(c, i)
+ for i, c in enumerate(
+ pa.Table.from_batches(
+ ArrowStreamSerializer.load_stream(self, stream)
+ ).itercolumns()
+ )
+ ]
+ )
+
+ elif dataframes_in_group != 0:
+ raise PySparkValueError(
+ errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
+ messageParameters={"dataframes_in_group":
str(dataframes_in_group)},
+ )
+
+ def __repr__(self):
+ return "ArrowStreamAggPandasUDFSerializer"
+
+
+# Serializer for SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_PANDAS_ITER_UDF
class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
def __init__(
self,
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 6e1cbdaf73cf..7cfcb29f50c1 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -190,6 +190,32 @@ class WindowPandasUDFTestsMixin:
assert_frame_equal(expected1.toPandas(), result1.toPandas())
+ def test_multiple_udfs_in_single_projection(self):
+ """
+ Test multiple window aggregate pandas UDFs in a single
select/projection.
+ """
+ df = self.data
+ w = self.unbounded_window
+
+ # Use select() with multiple window UDFs in the same projection
+ result1 = df.select(
+ df["id"],
+ df["v"],
+ self.pandas_agg_mean_udf(df["v"]).over(w).alias("mean_v"),
+ self.pandas_agg_max_udf(df["v"]).over(w).alias("max_v"),
+ self.pandas_agg_min_udf(df["w"]).over(w).alias("min_w"),
+ )
+
+ expected1 = df.select(
+ df["id"],
+ df["v"],
+ sf.mean(df["v"]).over(w).alias("mean_v"),
+ sf.max(df["v"]).over(w).alias("max_v"),
+ sf.min(df["w"]).over(w).alias("min_w"),
+ )
+
+ assert_frame_equal(expected1.toPandas(), result1.toPandas())
+
def test_replace_existing(self):
df = self.data
w = self.unbounded_window
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 94e3b2728d08..96aac6083bf2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -67,6 +67,7 @@ from pyspark.sql.pandas.serializers import (
TransformWithStateInPySparkRowSerializer,
TransformWithStateInPySparkRowInitStateSerializer,
ArrowStreamArrowUDFSerializer,
+ ArrowStreamAggPandasUDFSerializer,
ArrowStreamAggArrowUDFSerializer,
ArrowBatchUDFSerializer,
ArrowStreamUDTFSerializer,
@@ -2721,10 +2722,13 @@ def read_udfs(pickleSer, infile, eval_type):
):
ser = ArrowStreamAggArrowUDFSerializer(timezone, True,
_assign_cols_by_name, True)
elif eval_type in (
- PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
):
+ ser = ArrowStreamAggPandasUDFSerializer(
+ timezone, safecheck, _assign_cols_by_name,
int_to_decimal_coercion_enabled
+ )
+ elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
ser = GroupPandasUDFSerializer(
timezone, safecheck, _assign_cols_by_name,
int_to_decimal_coercion_enabled
)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]