This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 981af6029fe3 [SPARK-55160][PYTHON] Directly pass input schema to serializers
981af6029fe3 is described below
commit 981af6029fe3e03c8ece9950c03fcc34a3a1f480
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 26 09:58:32 2026 +0800
[SPARK-55160][PYTHON] Directly pass input schema to serializers
### What changes were proposed in this pull request?
Directly pass input schema to serializers
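For illustration, a minimal sketch of the difference (the `schema` value is a hypothetical example; the variable names follow the diff below):

```python
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# Hypothetical input schema, as parsed from the JSON the JVM side sends.
schema = StructType([
    StructField("id", IntegerType(), nullable=False),
    StructField("name", StringType(), nullable=True),
])

# Before: only the bare data types were extracted, dropping the field
# names and nullability.
input_types = [f.dataType for f in schema]  # [IntegerType(), StringType()]

# After: the whole StructType is handed to the serializer, so the
# StructField metadata (name, nullable) stays available.
input_type = schema
```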
### Why are the changes needed?
The Spark field name and nullability information is always dropped in the
existing implementation; keep it so that we can make use of it later.
E.g., data conversion currently always assumes `nullable=True`; we could
optimize the `nullable=False` cases once the nullability is available, as sketched below.
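A hypothetical sketch of the optimization this enables (not part of this commit; the converter bodies are made up for illustration):

```python
from pyspark.sql.types import StructField

def make_converter(field: StructField):
    # With the StructField available, a serializer could skip the
    # per-value None check when the column is declared non-nullable.
    if field.nullable:
        return lambda v: None if v is None else str(v)
    return lambda v: str(v)  # hypothetical fast path: no None check needed
```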
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #53944 from zhengruifeng/pass_raw_type.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/pandas/serializers.py | 49 +++++++++++++++++---------------
python/pyspark/worker.py | 24 +++++++---------
2 files changed, 37 insertions(+), 36 deletions(-)
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 3fbda556fffc..a8dba37b7da3 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -542,12 +542,12 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
timezone,
safecheck,
assign_cols_by_name,
- df_for_struct=False,
- struct_in_pandas="dict",
- ndarray_as_list=False,
- arrow_cast=False,
- input_types=None,
- int_to_decimal_coercion_enabled=False,
+ df_for_struct: bool = False,
+ struct_in_pandas: str = "dict",
+ ndarray_as_list: bool = False,
+ arrow_cast: bool = False,
+ input_type: Optional[StructType] = None,
+ int_to_decimal_coercion_enabled: bool = False,
):
super().__init__(timezone, safecheck, int_to_decimal_coercion_enabled)
self._assign_cols_by_name = assign_cols_by_name
@@ -555,7 +555,9 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
self._struct_in_pandas = struct_in_pandas
self._ndarray_as_list = ndarray_as_list
self._arrow_cast = arrow_cast
- self._input_types = input_types
+ if input_type is not None:
+ assert isinstance(input_type, StructType)
+ self._input_type = input_type
def arrow_to_pandas(self, arrow_column, idx):
import pyarrow.types as types
@@ -579,8 +581,8 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
self._struct_in_pandas,
self._ndarray_as_list,
spark_type=(
- self._input_types[idx][i].dataType
- if self._input_types is not None
+ self._input_type[idx].dataType[i].dataType
+ if self._input_type is not None
else None
),
)
@@ -594,7 +596,7 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
idx,
self._struct_in_pandas,
self._ndarray_as_list,
- spark_type=self._input_types[idx] if self._input_types is not None else None,
+ spark_type=self._input_type[idx].dataType if self._input_type is not None else None,
)
return s
@@ -807,8 +809,8 @@ class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer):
----------
safecheck : bool
If True, conversion from Arrow to Pandas checks for overflow/truncation
- input_types : list
- List of input data types for the UDF
+ input_type : StructType
+ Input schema of the UDF; must be a StructType
int_to_decimal_coercion_enabled : bool
If True, applies additional coercions in Python before converting to Arrow.
This has performance penalties.
@@ -818,16 +820,17 @@ class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer):
def __init__(
self,
- safecheck,
- input_types,
- int_to_decimal_coercion_enabled,
- binary_as_bytes,
+ safecheck: bool,
+ input_type: StructType,
+ int_to_decimal_coercion_enabled: bool,
+ binary_as_bytes: bool,
):
super().__init__(
safecheck=safecheck,
arrow_cast=True,
)
- self._input_types = input_types
+ assert isinstance(input_type, StructType)
+ self._input_type = input_type
self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
self._binary_as_bytes = binary_as_bytes
@@ -847,9 +850,9 @@ class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer):
"""
converters = [
ArrowTableToRowsConversion._create_converter(
- dt, none_on_identity=True, binary_as_bytes=self._binary_as_bytes
+ f.dataType, none_on_identity=True, binary_as_bytes=self._binary_as_bytes
)
- for dt in self._input_types
+ for f in self._input_type
]
for batch in super().load_stream(stream):
@@ -910,7 +913,7 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
"""
- def __init__(self, timezone, safecheck, input_types, int_to_decimal_coercion_enabled):
+ def __init__(self, timezone, safecheck, input_type, int_to_decimal_coercion_enabled):
super().__init__(
timezone=timezone,
safecheck=safecheck,
@@ -930,7 +933,7 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
ndarray_as_list=True,
# Enables explicit casting for mismatched return types of Arrow Python UDTFs.
arrow_cast=True,
- input_types=input_types,
+ input_type=input_type,
# Enable additional coercions for UDTF serialization
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
@@ -1118,7 +1121,7 @@ class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
struct_in_pandas="dict",
ndarray_as_list=False,
arrow_cast=True,
- input_types=None,
+ input_type=None,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
@@ -1161,7 +1164,7 @@ class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
struct_in_pandas="dict",
ndarray_as_list=False,
arrow_cast=True,
- input_types=None,
+ input_type=None,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d093beffda95..dfb2a2d12c6d 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1527,15 +1527,13 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
# the UDTF logic to input rows.
def read_udtf(pickleSer, infile, eval_type, runner_conf):
if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
- input_types = [
- field.dataType for field in _parse_datatype_json_string(utf8_deserializer.loads(infile))
- ]
+ input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
if runner_conf.use_legacy_pandas_udtf_conversion:
# NOTE: if timezone is set here, that implies respectSessionTimeZone is True
ser = ArrowStreamPandasUDTFSerializer(
runner_conf.timezone,
runner_conf.safecheck,
- input_types=input_types,
+ input_type=input_type,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
)
else:
@@ -2458,9 +2456,11 @@ def read_udtf(pickleSer, infile, eval_type, runner_conf):
try:
converters = [
ArrowTableToRowsConversion._create_converter(
- dt, none_on_identity=True, binary_as_bytes=runner_conf.binary_as_bytes
+ f.dataType,
+ none_on_identity=True,
+ binary_as_bytes=runner_conf.binary_as_bytes,
)
- for dt in input_types
+ for f in input_type
]
for a in it:
pylist = [
@@ -2827,12 +2827,10 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
and not runner_conf.use_legacy_pandas_udf_conversion
):
- input_types = [
- f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile))
- ]
+ input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
ser = ArrowBatchUDFSerializer(
runner_conf.safecheck,
- input_types,
+ input_type,
runner_conf.int_to_decimal_coercion_enabled,
runner_conf.binary_as_bytes,
)
@@ -2850,8 +2848,8 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
)
ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
# Arrow-optimized Python UDF takes input types
- input_types = (
- [f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile))]
+ input_type = (
+ _parse_datatype_json_string(utf8_deserializer.loads(infile))
if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
else None
)
@@ -2864,7 +2862,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
struct_in_pandas,
ndarray_as_list,
True,
- input_types,
+ input_type,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
)
else:
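For reference, the chained indexing used in `arrow_to_pandas` above (`self._input_type[idx].dataType[i].dataType`) relies on `StructType` returning a `StructField` when indexed by position; a minimal sketch:

```python
from pyspark.sql.types import StructType, StructField, LongType

schema = StructType([
    StructField("outer", StructType([StructField("a", LongType())])),
])
field = schema[0]          # StructField("outer", StructType(...), True)
inner = field.dataType[0]  # StructField("a", LongType(), True)
assert inner.dataType == LongType()
```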
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]