This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 77628fc1d01a [SPARK-55390][PYTHON] Consolidate SQL_SCALAR_ARROW_UDF
wrapper, mapper, and serializer logic
77628fc1d01a is described below
commit 77628fc1d01a56f293185a25c52693e9b6b110a6
Author: Yicong Huang <[email protected]>
AuthorDate: Tue Mar 10 09:14:50 2026 +0800
[SPARK-55390][PYTHON] Consolidate SQL_SCALAR_ARROW_UDF wrapper, mapper, and
serializer logic
### What changes were proposed in this pull request?
This PR consolidates the `SQL_SCALAR_ARROW_UDF` execution path by:
1. Extracting `verify_scalar_result()` as a reusable helper to replace
inline `verify_result_type` and `verify_result_length` closures in
`wrap_scalar_arrow_udf`
2. Removing the dedicated `wrap_scalar_arrow_udf` wrapper and replacing it
with the general `ArrowStreamGroupSerializer`-based path
3. Adding `ArrowBatchTransformer.enforce_schema()` to handle schema
enforcement (column reordering and type coercion) in a centralized way
### Why are the changes needed?
The scalar Arrow UDF path had its own dedicated wrapper
(`wrap_scalar_arrow_udf`), mapper, and serializer logic that duplicated
patterns already available in the consolidated `ArrowStreamGroupSerializer`
infrastructure. This refactoring reduces code duplication and makes the UDF
execution paths more consistent.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing tests for scalar Arrow UDFs.
### Benchmark Results
ASV microbenchmark comparison (`ScalarArrowUDFTimeBench`, `repeat=(3, 5,
5.0)`):
| Scenario | UDF | Before (master) | After (PR) | Delta |
|---|---|---|---|---|
| sm_batch_few_col | identity | 71.2±0.5ms | 71.1±0.8ms | -0.1% |
| sm_batch_few_col | sort | 186±1ms | 184±0.3ms | -1.1% |
| sm_batch_few_col | nullcheck | 55.7±0.9ms | 56.4±0.4ms | +1.3% |
| sm_batch_many_col | identity | 24.2±0.2ms | 23.6±0.06ms | -2.5% |
| sm_batch_many_col | sort | 43.0±0.8ms | 42.1±0.6ms | -2.1% |
| sm_batch_many_col | nullcheck | 20.6±0.3ms | 20.6±0.1ms | 0% |
| lg_batch_few_col | identity | 465±1ms | 465±9ms | 0% |
| lg_batch_few_col | sort | 824±1ms | 825±8ms | +0.1% |
| lg_batch_few_col | nullcheck | 271±2ms | 278±3ms | +2.6% |
| lg_batch_many_col | identity | 323±0.1ms | 321±0.9ms | -0.6% |
| lg_batch_many_col | sort | 358±3ms | 362±4ms | +1.1% |
| lg_batch_many_col | nullcheck | 326±2ms | 330±0.4ms | +1.2% |
| pure_ints | identity | 112±2ms | 113±4ms | +0.9% |
| pure_ints | sort | 179±0.6ms | 174±0.3ms | -2.8% |
| pure_ints | nullcheck | 89.9±0.6ms | 91.8±0.4ms | +2.1% |
| pure_floats | identity | 108±1ms | 108±0.2ms | 0% |
| pure_floats | sort | 569±1ms | 568±1ms | -0.2% |
| pure_floats | nullcheck | 88.6±0.6ms | 90.4±0.3ms | +2.0% |
| pure_strings | identity | 120±2ms | 120±0.3ms | 0% |
| pure_strings | sort | 522±0.9ms | 516±0.7ms | -1.2% |
| pure_strings | nullcheck | 97.5±0.4ms | 100.0±0.6ms | +2.6% |
| pure_ts | identity | 110±0.5ms | 110±1ms | 0% |
| pure_ts | sort | 216±0.7ms | 215±0.9ms | -0.5% |
| pure_ts | nullcheck | 89.0±0.2ms | 89.4±0.2ms | +0.4% |
| mixed_types | identity | 105±0.4ms | 105±0.2ms | 0% |
| mixed_types | sort | 166±0.6ms | 166±0.8ms | 0% |
| mixed_types | nullcheck | 84.2±0.6ms | 83.1±0.3ms | -1.3% |
Peak memory: no change (within 1 MB). All deltas are within ±3% measurement
noise — no performance regression.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #54296 from Yicong-Huang/SPARK-55390/refactor/scalar-arrow-udf.
Lead-authored-by: Yicong Huang
<[email protected]>
Co-authored-by: Yicong-Huang
<[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/conversion.py | 66 ++++++++++++++++++++-
python/pyspark/worker.py | 121 ++++++++++++++++++++++++---------------
2 files changed, 139 insertions(+), 48 deletions(-)
diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index 249bd2ffcb20..efd03089bbf8 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -21,7 +21,7 @@ import decimal
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence,
Union, overload
import pyspark
-from pyspark.errors import PySparkValueError
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
from pyspark.sql.pandas.types import (
_dedup_names,
_deduplicate_field_names,
@@ -107,6 +107,70 @@ class ArrowBatchTransformer:
struct = pa.StructArray.from_arrays(batch.columns,
fields=pa.struct(list(batch.schema)))
return pa.RecordBatch.from_arrays([struct], ["_0"])
+ @classmethod
+ def enforce_schema(
+ cls,
+ batch: "pa.RecordBatch",
+ arrow_schema: "pa.Schema",
+ safecheck: bool = True,
+ ) -> "pa.RecordBatch":
+ """
+ Enforce target schema on a RecordBatch by reordering columns and
coercing types.
+
+ .. note::
+ Currently this function is only used by UDTF. The error messages
+ are UDTF-specific (see SPARK-55723).
+
+ Parameters
+ ----------
+ batch : pa.RecordBatch
+ Input RecordBatch to transform.
+ arrow_schema : pa.Schema
+ Target Arrow schema. Callers should pre-compute this once via
+ to_arrow_schema() to avoid repeated conversion.
+ safecheck : bool, default True
+ If True, use safe casting (fails on overflow/truncation).
+
+ Returns
+ -------
+ pa.RecordBatch
+ RecordBatch with columns reordered and types coerced to match
target schema.
+ """
+ import pyarrow as pa
+
+ if batch.num_columns == 0 or len(arrow_schema) == 0:
+ return batch
+
+ # Fast path: schema already matches (ignoring metadata), no work needed
+ if batch.schema.equals(arrow_schema, check_metadata=False):
+ return batch
+
+ # Check if columns are in the same order (by name) as the target
schema.
+ # If so, use index-based access (faster than name lookup).
+ batch_names = [batch.schema.field(i).name for i in
range(batch.num_columns)]
+ target_names = [field.name for field in arrow_schema]
+ use_index = batch_names == target_names
+
+ coerced_arrays = []
+ for i, field in enumerate(arrow_schema):
+ arr = batch.column(i) if use_index else batch.column(field.name)
+ if arr.type != field.type:
+ try:
+ arr = arr.cast(target_type=field.type, safe=safecheck)
+ except (pa.ArrowInvalid, pa.ArrowTypeError):
+ # TODO(SPARK-55723): Unify error messages for all UDF
types,
+ # not just UDTF.
+ raise PySparkRuntimeError(
+ errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF",
+ messageParameters={
+ "expected": str(field.type),
+ "actual": str(arr.type),
+ },
+ )
+ coerced_arrays.append(arr)
+
+ return pa.RecordBatch.from_arrays(coerced_arrays, names=target_names)
+
@classmethod
def to_pandas(
cls,
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 7fbe0849ee63..ddd0b45d9020 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -72,7 +72,7 @@ from pyspark.sql.pandas.serializers import (
ArrowStreamUDTFSerializer,
ArrowStreamArrowUDTFSerializer,
)
-from pyspark.sql.pandas.types import to_arrow_type
+from pyspark.sql.pandas.types import to_arrow_schema, to_arrow_type
from pyspark.sql.types import (
ArrayType,
BinaryType,
@@ -80,6 +80,7 @@ from pyspark.sql.types import (
MapType,
Row,
StringType,
+ StructField,
StructType,
_create_row,
_parse_datatype_json_string,
@@ -266,6 +267,39 @@ def verify_result(expected_type: type) -> Callable[[Any],
Iterator]:
return check
+def verify_scalar_result(result: Any, num_rows: int) -> Any:
+ """
+ Verify a scalar UDF result is array-like and has the expected number of
rows.
+
+ Parameters
+ ----------
+ result : Any
+ The UDF result to verify.
+ num_rows : int
+ Expected number of rows (must match input batch size).
+ """
+ try:
+ result_length = len(result)
+ except TypeError:
+ raise PySparkTypeError(
+ errorClass="UDF_RETURN_TYPE",
+ messageParameters={
+ "expected": "array-like object",
+ "actual": type(result).__name__,
+ },
+ )
+ if result_length != num_rows:
+ raise PySparkRuntimeError(
+ errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
+ messageParameters={
+ "udf_type": "arrow_udf",
+ "expected": str(num_rows),
+ "actual": str(result_length),
+ },
+ )
+ return result
+
+
def wrap_udf(f, args_offsets, kwargs_offsets, return_type):
func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets,
kwargs_offsets)
@@ -312,46 +346,6 @@ def wrap_scalar_pandas_udf(f, args_offsets,
kwargs_offsets, return_type, runner_
)
-def wrap_scalar_arrow_udf(f, args_offsets, kwargs_offsets, return_type,
runner_conf):
- func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets,
kwargs_offsets)
-
- arrow_return_type = to_arrow_type(
- return_type, timezone="UTC",
prefers_large_types=runner_conf.use_large_var_types
- )
-
- def verify_result_type(result):
- if not hasattr(result, "__len__"):
- pd_type = "pyarrow.Array"
- raise PySparkTypeError(
- errorClass="UDF_RETURN_TYPE",
- messageParameters={
- "expected": pd_type,
- "actual": type(result).__name__,
- },
- )
- return result
-
- def verify_result_length(result, length):
- if len(result) != length:
- raise PySparkRuntimeError(
- errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
- messageParameters={
- "udf_type": "arrow_udf",
- "expected": str(length),
- "actual": str(len(result)),
- },
- )
- return result
-
- return (
- args_kwargs_offsets,
- lambda *a: (
- verify_result_length(verify_result_type(func(*a)), len(a[0])),
- arrow_return_type,
- ),
- )
-
-
def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets, return_type,
runner_conf):
if runner_conf.use_legacy_pandas_udf_conversion:
return wrap_arrow_batch_udf_legacy(
@@ -1403,7 +1397,7 @@ def read_single_udf(pickleSer, infile, eval_type,
runner_conf, udf_index):
if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets,
return_type, runner_conf)
elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF:
- return wrap_scalar_arrow_udf(func, args_offsets, kwargs_offsets,
return_type, runner_conf)
+ return func, args_offsets, kwargs_offsets, return_type
elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
return wrap_arrow_batch_udf(func, args_offsets, kwargs_offsets,
return_type, runner_conf)
elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
@@ -1413,7 +1407,7 @@ def read_single_udf(pickleSer, infile, eval_type,
runner_conf, udf_index):
elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
return args_offsets, wrap_pandas_batch_iter_udf(func, return_type,
runner_conf)
elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
- return func
+ return func, None, None, None
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
argspec = inspect.getfullargspec(chained_func) # signature was lost
when wrapping it
return args_offsets, wrap_grouped_map_pandas_udf(func, return_type,
argspec, runner_conf)
@@ -2764,12 +2758,12 @@ def read_udfs(pickleSer, infile, eval_type,
runner_conf, eval_conf):
ser = TransformWithStateInPySparkRowInitStateSerializer(
arrow_max_records_per_batch=runner_conf.arrow_max_records_per_batch
)
- elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
- ser = ArrowStreamSerializer(write_start_stream=True)
elif eval_type in (
+ PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_SCALAR_ARROW_UDF,
- PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
):
+ ser = ArrowStreamSerializer(write_start_stream=True)
+ elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF:
# Arrow cast and safe check are always enabled
ser = ArrowStreamArrowUDFSerializer(safecheck=True,
arrow_cast=True)
elif (
@@ -2831,7 +2825,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf,
eval_conf):
import pyarrow as pa
assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."
- udf_func: Callable[[Iterator[pa.RecordBatch]],
Iterator[pa.RecordBatch]] = udfs[0]
+ udf_func: Callable[[Iterator[pa.RecordBatch]],
Iterator[pa.RecordBatch]] = udfs[0][0]
def func(split_index: int, batches: Iterator[pa.RecordBatch]) ->
Iterator[pa.RecordBatch]:
"""Apply mapInArrow UDF"""
@@ -2851,6 +2845,39 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf,
eval_conf):
# profiling is not supported for UDF
return func, None, ser, ser
+ if eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF:
+ import pyarrow as pa
+
+ col_names = ["_%d" % i for i in range(len(udfs))]
+ combined_arrow_schema = to_arrow_schema(
+ StructType([StructField(n, rt) for n, (_, _, _, rt) in
zip(col_names, udfs)]),
+ timezone="UTC",
+ prefers_large_types=runner_conf.use_large_var_types,
+ )
+
+ def func(split_index: int, batches: Iterator[pa.RecordBatch]) ->
Iterator[pa.RecordBatch]:
+ """Apply scalar Arrow UDFs"""
+
+ for input_batch in batches:
+ output_batch = pa.RecordBatch.from_arrays(
+ [
+ udf_func(
+ *[input_batch.column(o) for o in args_offsets],
+ **{k: input_batch.column(v) for k, v in
kwargs_offsets.items()},
+ )
+ for udf_func, args_offsets, kwargs_offsets, _ in udfs
+ ],
+ col_names,
+ )
+ output_batch = ArrowBatchTransformer.enforce_schema(
+ output_batch, combined_arrow_schema
+ )
+ verify_scalar_result(output_batch, input_batch.num_rows)
+ yield output_batch
+
+ # profiling is not supported for UDF
+ return func, None, ser, ser
+
is_scalar_iter = eval_type in (
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]