This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 981af6029fe3 [SPARK-55160][PYTHON] Directly pass input schema to serializers
981af6029fe3 is described below

commit 981af6029fe3e03c8ece9950c03fcc34a3a1f480
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 26 09:58:32 2026 +0800

    [SPARK-55160][PYTHON] Directly pass input schema to serializers
    
    ### What changes were proposed in this pull request?
    Directly pass input schema to serializers
    
    ### Why are the changes needed?
    Information about the `spark field name` and `nullability` is always dropped in the existing implementation; keep it so that we can make use of it later.
    E.g. we currently always treat data conversion as `nullable=True`; once `nullability` is available, we can optimize the `nullable=False` cases.
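    
    As a rough illustration (the `convert_column` helper below is hypothetical, not part of this patch), a converter that can see the preserved `nullable` flag could skip per-value null checks:
    
    ```python
    from pyspark.sql.types import StructType, StructField, LongType

    def convert_column(values, field: StructField):
        # Hypothetical converter: use the field's nullability to pick a path.
        if field.nullable:
            # General path: every value may be None and must be checked.
            return [None if v is None else int(v) for v in values]
        # Fast path: the schema guarantees no nulls, so the check can be skipped.
        return [int(v) for v in values]

    schema = StructType([StructField("id", LongType(), nullable=False)])
    convert_column([1, 2, 3], schema["id"])  # takes the fast path
    ```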
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #53944 from zhengruifeng/pass_raw_type.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/pandas/serializers.py | 49 +++++++++++++++++---------------
 python/pyspark/worker.py                 | 24 +++++++---------
 2 files changed, 37 insertions(+), 36 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 3fbda556fffc..a8dba37b7da3 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -542,12 +542,12 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
         timezone,
         safecheck,
         assign_cols_by_name,
-        df_for_struct=False,
-        struct_in_pandas="dict",
-        ndarray_as_list=False,
-        arrow_cast=False,
-        input_types=None,
-        int_to_decimal_coercion_enabled=False,
+        df_for_struct: bool = False,
+        struct_in_pandas: str = "dict",
+        ndarray_as_list: bool = False,
+        arrow_cast: bool = False,
+        input_type: Optional[StructType] = None,
+        int_to_decimal_coercion_enabled: bool = False,
     ):
         super().__init__(timezone, safecheck, int_to_decimal_coercion_enabled)
         self._assign_cols_by_name = assign_cols_by_name
@@ -555,7 +555,9 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
         self._struct_in_pandas = struct_in_pandas
         self._ndarray_as_list = ndarray_as_list
         self._arrow_cast = arrow_cast
-        self._input_types = input_types
+        if input_type is not None:
+            assert isinstance(input_type, StructType)
+        self._input_type = input_type
 
     def arrow_to_pandas(self, arrow_column, idx):
         import pyarrow.types as types
@@ -579,8 +581,8 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
                     self._struct_in_pandas,
                     self._ndarray_as_list,
                     spark_type=(
-                        self._input_types[idx][i].dataType
-                        if self._input_types is not None
+                        self._input_type[idx].dataType[i].dataType
+                        if self._input_type is not None
                         else None
                     ),
                 )
@@ -594,7 +596,7 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
                 idx,
                 self._struct_in_pandas,
                 self._ndarray_as_list,
-                spark_type=self._input_types[idx] if self._input_types is not None else None,
+                spark_type=self._input_type[idx].dataType if self._input_type is not None else None,
             )
         return s
 
@@ -807,8 +809,8 @@ class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer):
     ----------
     safecheck : bool
         If True, conversion from Arrow to Pandas checks for overflow/truncation
-    input_types : list
-        List of input data types for the UDF
+    input_type : spark data type
+        input data type for the UDF, must be a StructType
     int_to_decimal_coercion_enabled : bool
         If True, applies additional coercions in Python before converting to Arrow
         This has performance penalties.
@@ -818,16 +820,17 @@ class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer):
 
     def __init__(
         self,
-        safecheck,
-        input_types,
-        int_to_decimal_coercion_enabled,
-        binary_as_bytes,
+        safecheck: bool,
+        input_type: StructType,
+        int_to_decimal_coercion_enabled: bool,
+        binary_as_bytes: bool,
     ):
         super().__init__(
             safecheck=safecheck,
             arrow_cast=True,
         )
-        self._input_types = input_types
+        assert isinstance(input_type, StructType)
+        self._input_type = input_type
         self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
         self._binary_as_bytes = binary_as_bytes
 
@@ -847,9 +850,9 @@ class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer):
         """
         converters = [
             ArrowTableToRowsConversion._create_converter(
-                dt, none_on_identity=True, binary_as_bytes=self._binary_as_bytes
+                f.dataType, none_on_identity=True, binary_as_bytes=self._binary_as_bytes
             )
-            for dt in self._input_types
+            for f in self._input_type
         ]
 
         for batch in super().load_stream(stream):
@@ -910,7 +913,7 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
     Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
     """
 
-    def __init__(self, timezone, safecheck, input_types, int_to_decimal_coercion_enabled):
+    def __init__(self, timezone, safecheck, input_type, int_to_decimal_coercion_enabled):
         super().__init__(
             timezone=timezone,
             safecheck=safecheck,
@@ -930,7 +933,7 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
             ndarray_as_list=True,
             # Enables explicit casting for mismatched return types of Arrow Python UDTFs.
             arrow_cast=True,
-            input_types=input_types,
+            input_type=input_type,
             # Enable additional coercions for UDTF serialization
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
@@ -1118,7 +1121,7 @@ class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
             struct_in_pandas="dict",
             ndarray_as_list=False,
             arrow_cast=True,
-            input_types=None,
+            input_type=None,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
 
@@ -1161,7 +1164,7 @@ class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
             struct_in_pandas="dict",
             ndarray_as_list=False,
             arrow_cast=True,
-            input_types=None,
+            input_type=None,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d093beffda95..dfb2a2d12c6d 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1527,15 +1527,13 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
 # the UDTF logic to input rows.
 def read_udtf(pickleSer, infile, eval_type, runner_conf):
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
-        input_types = [
-            field.dataType for field in _parse_datatype_json_string(utf8_deserializer.loads(infile))
-        ]
+        input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
         if runner_conf.use_legacy_pandas_udtf_conversion:
             # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
             ser = ArrowStreamPandasUDTFSerializer(
                 runner_conf.timezone,
                 runner_conf.safecheck,
-                input_types=input_types,
+                input_type=input_type,
                 int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
             )
         else:
@@ -2458,9 +2456,11 @@ def read_udtf(pickleSer, infile, eval_type, runner_conf):
             try:
                 converters = [
                     ArrowTableToRowsConversion._create_converter(
-                        dt, none_on_identity=True, binary_as_bytes=runner_conf.binary_as_bytes
+                        f.dataType,
+                        none_on_identity=True,
+                        binary_as_bytes=runner_conf.binary_as_bytes,
                     )
-                    for dt in input_types
+                    for f in input_type
                 ]
                 for a in it:
                     pylist = [
@@ -2827,12 +2827,10 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
             eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
             and not runner_conf.use_legacy_pandas_udf_conversion
         ):
-            input_types = [
-                f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile))
-            ]
+            input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
             ser = ArrowBatchUDFSerializer(
                 runner_conf.safecheck,
-                input_types,
+                input_type,
                 runner_conf.int_to_decimal_coercion_enabled,
                 runner_conf.binary_as_bytes,
             )
@@ -2850,8 +2848,8 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
             )
             ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
             # Arrow-optimized Python UDF takes input types
-            input_types = (
-                [f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile))]
+            input_type = (
+                _parse_datatype_json_string(utf8_deserializer.loads(infile))
                 if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
                 else None
             )
@@ -2864,7 +2862,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
                 struct_in_pandas,
                 ndarray_as_list,
                 True,
-                input_types,
+                input_type,
                 int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
             )
     else:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
