This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1790508f6740 [SPARK-55788][PYTHON] Support ExtensionDType for integers 
in Pandas UDF
1790508f6740 is described below

commit 1790508f6740e1fd412d3103b00be89ddb540ea3
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Mar 4 20:21:12 2026 -0800

    [SPARK-55788][PYTHON] Support ExtensionDType for integers in Pandas UDF
    
    ### What changes were proposed in this pull request?
    Add an option to always use Pandas ExtensionDTypes (nullable `Int8`/`Int16`/`Int32`/`Int64`) for integer inputs in Pandas UDFs.
    
    ### Why are the changes needed?
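    The current dtype for integer inputs is not predictable: it depends on whether the **current batch** contains nulls (e.g. a tinyint column becomes `int8` for a batch without nulls, but `Int8` for a batch with nulls).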
    
    ### Does this PR introduce _any_ user-facing change?
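    Yes, controlled by the new config `spark.sql.execution.pythonUDF.pandas.preferIntExtensionDtype`.
    When it is enabled, integer columns are always converted to nullable ExtensionDTypes (`Int8`/`Int16`/`Int32`/`Int64`), regardless of nulls.
    When it is disabled, integer columns are always converted with plain `to_pandas()`, so columns containing nulls now become `float64` instead of a nullable ExtensionDType (see the updated golden files).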
    
    ### How was this patch tested?
    Added tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #54568 from zhengruifeng/add_config.
    
    Lead-authored-by: Ruifeng Zheng <[email protected]>
    Co-authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Takuya Ueshin <[email protected]>
---
 python/pyspark/sql/conversion.py                   | 18 +++++++++---
 python/pyspark/sql/pandas/serializers.py           | 34 +++++++++++++++++++++-
 .../golden_pandas_udf_input_type_coercion_base.csv |  8 ++---
 .../golden_pandas_udf_input_type_coercion_base.md  |  8 ++---
 ...f_input_type_coercion_with_arrow_and_pandas.csv |  8 ++---
 ...df_input_type_coercion_with_arrow_and_pandas.md |  8 ++---
 python/pyspark/sql/tests/pandas/test_pandas_udf.py |  9 +++---
 python/pyspark/worker.py                           | 33 +++++++++++++++++++--
 .../org/apache/spark/sql/internal/SQLConf.scala    | 11 +++++++
 .../sql/execution/python/ArrowPythonRunner.scala   |  5 +++-
 10 files changed, 113 insertions(+), 29 deletions(-)

diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index ca6d6c61f1ef..aa22594585bb 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -115,6 +115,7 @@ class ArrowBatchTransformer:
         schema: Optional["StructType"] = None,
         struct_in_pandas: str = "dict",
         ndarray_as_list: bool = False,
+        prefer_int_ext_dtype: bool = False,
         df_for_struct: bool = False,
     ) -> List[Union["pd.Series", "pd.DataFrame"]]:
         """
@@ -132,6 +133,8 @@ class ArrowBatchTransformer:
             How to represent struct in pandas ("dict", "row", etc.)
         ndarray_as_list : bool
             Whether to convert ndarray as list.
+        prefer_int_ext_dtype : bool, optional
+            Whether to convert integers to Pandas ExtensionDType.
         df_for_struct : bool
             If True, convert struct columns to DataFrame instead of Series.
 
@@ -156,6 +159,7 @@ class ArrowBatchTransformer:
                 timezone=timezone,
                 struct_in_pandas=struct_in_pandas,
                 ndarray_as_list=ndarray_as_list,
+                prefer_int_ext_dtype=prefer_int_ext_dtype,
                 df_for_struct=df_for_struct,
             )
             for i in range(batch.num_columns)
@@ -1427,6 +1431,7 @@ class ArrowArrayToPandasConversion:
         timezone: Optional[str] = None,
         struct_in_pandas: str = "dict",
         ndarray_as_list: bool = False,
+        prefer_int_ext_dtype: bool = False,
         df_for_struct: bool = False,
     ) -> Union["pd.Series", "pd.DataFrame"]:
         """
@@ -1447,6 +1452,8 @@ class ArrowArrayToPandasConversion:
             Default is "dict".
         ndarray_as_list : bool, optional
             Whether to convert numpy ndarrays to Python lists. Default is 
False.
+        prefer_int_ext_dtype : bool, optional
+            Whether to convert integers to Pandas ExtensionDType.
         df_for_struct : bool, optional
             If True, convert struct columns to a DataFrame with columns 
corresponding
             to struct fields instead of a Series. Default is False.
@@ -1465,6 +1472,7 @@ class ArrowArrayToPandasConversion:
                 timezone=timezone,
                 struct_in_pandas=struct_in_pandas,
                 ndarray_as_list=ndarray_as_list,
+                prefer_int_ext_dtype=prefer_int_ext_dtype,
                 df_for_struct=df_for_struct,
             )
 
@@ -1615,6 +1623,7 @@ class ArrowArrayToPandasConversion:
         timezone: Optional[str] = None,
         struct_in_pandas: Optional[str] = None,
         ndarray_as_list: bool = False,
+        prefer_int_ext_dtype: bool = False,
         df_for_struct: bool = False,
     ) -> Union["pd.Series", "pd.DataFrame"]:
         import pyarrow as pa
@@ -1637,6 +1646,7 @@ class ArrowArrayToPandasConversion:
                         timezone=timezone,
                         struct_in_pandas=struct_in_pandas,
                         ndarray_as_list=ndarray_as_list,
+                        prefer_int_ext_dtype=prefer_int_ext_dtype,
                         df_for_struct=False,  # always False for child fields
                     )
                     for field_arr, field in zip(arr.flatten(), spark_type)
@@ -1657,22 +1667,22 @@ class ArrowArrayToPandasConversion:
 
         # conversion methods are selected based on benchmark 
python/benchmarks/bench_arrow.py
         if isinstance(spark_type, ByteType):
-            if arr.null_count > 0:
+            if prefer_int_ext_dtype:
                 series = 
arr.to_pandas(types_mapper=pd.ArrowDtype).astype(pd.Int8Dtype())
             else:
                 series = arr.to_pandas()
         elif isinstance(spark_type, ShortType):
-            if arr.null_count > 0:
+            if prefer_int_ext_dtype:
                 series = 
arr.to_pandas(types_mapper=pd.ArrowDtype).astype(pd.Int16Dtype())
             else:
                 series = arr.to_pandas()
         elif isinstance(spark_type, IntegerType):
-            if arr.null_count > 0:
+            if prefer_int_ext_dtype:
                 series = 
arr.to_pandas(types_mapper=pd.ArrowDtype).astype(pd.Int32Dtype())
             else:
                 series = arr.to_pandas()
         elif isinstance(spark_type, LongType):
-            if arr.null_count > 0:
+            if prefer_int_ext_dtype:
                 series = 
arr.to_pandas(types_mapper=pd.ArrowDtype).astype(pd.Int64Dtype())
             else:
                 series = arr.to_pandas()
diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index aac6df47a3b8..019a88472648 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -419,6 +419,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         How to represent struct in pandas ("dict", "row", etc.). Default is 
"dict".
     ndarray_as_list : bool, optional
         Whether to convert ndarray as list. Default is False.
+    prefer_int_ext_dtype : bool, optional
+        Whether to convert integers to Pandas ExtensionDType. Default is False.
     df_for_struct : bool, optional
         If True, convert struct columns to DataFrame instead of Series. 
Default is False.
     """
@@ -431,6 +433,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         prefers_large_types: bool = False,
         struct_in_pandas: str = "dict",
         ndarray_as_list: bool = False,
+        prefer_int_ext_dtype: bool = False,
         df_for_struct: bool = False,
         input_type: Optional["StructType"] = None,
         arrow_cast: bool = False,
@@ -442,6 +445,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         self._prefers_large_types = prefers_large_types
         self._struct_in_pandas = struct_in_pandas
         self._ndarray_as_list = ndarray_as_list
+        self._prefer_int_ext_dtype = prefer_int_ext_dtype
         self._df_for_struct = df_for_struct
         if input_type is not None:
             assert isinstance(input_type, StructType)
@@ -486,6 +490,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
                 schema=self._input_type,
                 struct_in_pandas=self._struct_in_pandas,
                 ndarray_as_list=self._ndarray_as_list,
+                prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                 df_for_struct=self._df_for_struct,
             ),
             super().load_stream(stream),
@@ -508,6 +513,7 @@ class 
ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
         df_for_struct: bool = False,
         struct_in_pandas: str = "dict",
         ndarray_as_list: bool = False,
+        prefer_int_ext_dtype: bool = False,
         arrow_cast: bool = False,
         input_type: Optional[StructType] = None,
         int_to_decimal_coercion_enabled: bool = False,
@@ -522,6 +528,7 @@ class 
ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
             prefers_large_types,
             struct_in_pandas,
             ndarray_as_list,
+            prefer_int_ext_dtype,
             df_for_struct,
             input_type,
             arrow_cast,
@@ -742,7 +749,14 @@ class 
ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
     Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
     """
 
-    def __init__(self, timezone, safecheck, input_type, 
int_to_decimal_coercion_enabled):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        input_type,
+        prefer_int_ext_dtype,
+        int_to_decimal_coercion_enabled,
+    ):
         super().__init__(
             timezone=timezone,
             safecheck=safecheck,
@@ -760,6 +774,7 @@ class 
ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
             # To ensure consistency across regular and arrow-optimized UDTFs, 
we further
             # convert these numpy.ndarrays into Python lists.
             ndarray_as_list=True,
+            prefer_int_ext_dtype=prefer_int_ext_dtype,
             # Enables explicit casting for mismatched return types of Arrow 
Python UDTFs.
             arrow_cast=True,
             input_type=input_type,
@@ -806,6 +821,7 @@ class 
ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
         timezone,
         safecheck,
         assign_cols_by_name,
+        prefer_int_ext_dtype,
         int_to_decimal_coercion_enabled,
     ):
         super().__init__(
@@ -815,6 +831,7 @@ class 
ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
             df_for_struct=False,
             struct_in_pandas="dict",
             ndarray_as_list=False,
+            prefer_int_ext_dtype=prefer_int_ext_dtype,
             arrow_cast=True,
             input_type=None,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
@@ -837,6 +854,7 @@ class 
ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
                         schema=self._input_type,
                         struct_in_pandas=self._struct_in_pandas,
                         ndarray_as_list=self._ndarray_as_list,
+                        prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                         df_for_struct=self._df_for_struct,
                     )
                 ),
@@ -858,6 +876,7 @@ class 
GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
         timezone,
         safecheck,
         assign_cols_by_name,
+        prefer_int_ext_dtype,
         int_to_decimal_coercion_enabled,
     ):
         super().__init__(
@@ -867,6 +886,7 @@ class 
GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
             df_for_struct=False,
             struct_in_pandas="dict",
             ndarray_as_list=False,
+            prefer_int_ext_dtype=prefer_int_ext_dtype,
             arrow_cast=True,
             input_type=None,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
@@ -934,6 +954,7 @@ class 
CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
                     schema=from_arrow_schema(left_table.schema),
                     struct_in_pandas=self._struct_in_pandas,
                     ndarray_as_list=self._ndarray_as_list,
+                    prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                     df_for_struct=self._df_for_struct,
                 ),
                 ArrowBatchTransformer.to_pandas(
@@ -942,6 +963,7 @@ class 
CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
                     schema=from_arrow_schema(right_table.schema),
                     struct_in_pandas=self._struct_in_pandas,
                     ndarray_as_list=self._ndarray_as_list,
+                    prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                     df_for_struct=self._df_for_struct,
                 ),
             )
@@ -970,6 +992,7 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
         timezone,
         safecheck,
         assign_cols_by_name,
+        prefer_int_ext_dtype,
         state_object_schema,
         arrow_max_records_per_batch,
         prefers_large_var_types,
@@ -982,6 +1005,7 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
             df_for_struct=False,
             struct_in_pandas="dict",
             ndarray_as_list=False,
+            prefer_int_ext_dtype=prefer_int_ext_dtype,
             arrow_cast=True,
             input_type=None,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
@@ -1114,6 +1138,7 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
                     schema=None,
                     struct_in_pandas=self._struct_in_pandas,
                     ndarray_as_list=self._ndarray_as_list,
+                    prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                     df_for_struct=self._df_for_struct,
                 )[0]
 
@@ -1151,6 +1176,7 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
                         schema=None,
                         struct_in_pandas=self._struct_in_pandas,
                         ndarray_as_list=self._ndarray_as_list,
+                        prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                         df_for_struct=self._df_for_struct,
                     )
 
@@ -1365,6 +1391,7 @@ class 
TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
         timezone,
         safecheck,
         assign_cols_by_name,
+        prefer_int_ext_dtype,
         arrow_max_records_per_batch,
         arrow_max_bytes_per_batch,
         int_to_decimal_coercion_enabled,
@@ -1376,6 +1403,7 @@ class 
TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
             df_for_struct=False,
             struct_in_pandas="dict",
             ndarray_as_list=False,
+            prefer_int_ext_dtype=prefer_int_ext_dtype,
             arrow_cast=True,
             input_type=None,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
@@ -1436,6 +1464,7 @@ class 
TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
                         schema=self._input_type,
                         struct_in_pandas=self._struct_in_pandas,
                         ndarray_as_list=self._ndarray_as_list,
+                        prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                         df_for_struct=self._df_for_struct,
                     )
                     for row in pd.concat(data_pandas, 
axis=1).itertuples(index=False):
@@ -1497,6 +1526,7 @@ class 
TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
         timezone,
         safecheck,
         assign_cols_by_name,
+        prefer_int_ext_dtype,
         arrow_max_records_per_batch,
         arrow_max_bytes_per_batch,
         int_to_decimal_coercion_enabled,
@@ -1505,6 +1535,7 @@ class 
TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
             timezone,
             safecheck,
             assign_cols_by_name,
+            prefer_int_ext_dtype,
             arrow_max_records_per_batch,
             arrow_max_bytes_per_batch,
             int_to_decimal_coercion_enabled,
@@ -1566,6 +1597,7 @@ class 
TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
                     schema=self._input_type,
                     struct_in_pandas=self._struct_in_pandas,
                     ndarray_as_list=self._ndarray_as_list,
+                    prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                     df_for_struct=self._df_for_struct,
                 )
 
diff --git 
a/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.csv
 
b/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.csv
index 965213ba4820..4c2817525801 100644
--- 
a/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.csv
+++ 
b/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.csv
@@ -1,12 +1,12 @@
        Test Case       Spark Type      Spark Value     Python Type     Python 
Value
 0      byte_values     tinyint [-128, 127, 0]  ['int8', 'int8', 'int8']        
[-128, 127, 0]
-1      byte_null       tinyint [None, 42]      ['Int8', 'Int8']        [None, 
42]
+1      byte_null       tinyint [None, 42]      ['float64', 'float64']  [None, 
42]
 2      short_values    smallint        [-32768, 32767, 0]      ['int16', 
'int16', 'int16']     [-32768, 32767, 0]
-3      short_null      smallint        [None, 123]     ['Int16', 'Int16']      
[None, 123]
+3      short_null      smallint        [None, 123]     ['float64', 'float64']  
[None, 123]
 4      int_values      int     [-2147483648, 2147483647, 0]    ['int32', 
'int32', 'int32']     [-2147483648, 2147483647, 0]
-5      int_null        int     [None, 456]     ['Int32', 'Int32']      [None, 
456]
+5      int_null        int     [None, 456]     ['float64', 'float64']  [None, 
456]
 6      long_values     bigint  [-9223372036854775808, 9223372036854775807, 0]  
['int64', 'int64', 'int64']     [-9223372036854775808, 9223372036854775807, 0]
-7      long_null       bigint  [None, 789]     ['Int64', 'Int64']      [None, 
789]
+7      long_null       bigint  [None, 789]     ['float64', 'float64']  [None, 
789]
 8      float_values    float   [0.0, 1.0, 3.140000104904175]   ['float32', 
'float32', 'float32']       [0.0, 1.0, 3.140000104904175]
 9      float_null      float   [None, 3.140000104904175]       ['float32', 
'float32']  [None, 3.140000104904175]
 10     double_values   double  [0.0, 1.0, 0.3333333333333333]  ['float64', 
'float64', 'float64']       [0.0, 1.0, 0.3333333333333333]
diff --git 
a/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.md
 
b/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.md
index 5240057fbfd8..6a028a978fe4 100644
--- 
a/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.md
+++ 
b/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.md
@@ -1,13 +1,13 @@
 |    | Test Case                 | Spark Type                               | 
Spark Value                                                               | 
Python Type                          | Python Value                             
                                 |
 
|----|---------------------------|------------------------------------------|---------------------------------------------------------------------------|--------------------------------------|---------------------------------------------------------------------------|
 |  0 | byte_values               | tinyint                                  | 
[-128, 127, 0]                                                            | 
['int8', 'int8', 'int8']             | [-128, 127, 0]                           
                                 |
-|  1 | byte_null                 | tinyint                                  | 
[None, 42]                                                                | 
['Int8', 'Int8']                     | [None, 42]                               
                                 |
+|  1 | byte_null                 | tinyint                                  | 
[None, 42]                                                                | 
['float64', 'float64']               | [None, 42]                               
                                 |
 |  2 | short_values              | smallint                                 | 
[-32768, 32767, 0]                                                        | 
['int16', 'int16', 'int16']          | [-32768, 32767, 0]                       
                                 |
-|  3 | short_null                | smallint                                 | 
[None, 123]                                                               | 
['Int16', 'Int16']                   | [None, 123]                              
                                 |
+|  3 | short_null                | smallint                                 | 
[None, 123]                                                               | 
['float64', 'float64']               | [None, 123]                              
                                 |
 |  4 | int_values                | int                                      | 
[-2147483648, 2147483647, 0]                                              | 
['int32', 'int32', 'int32']          | [-2147483648, 2147483647, 0]             
                                 |
-|  5 | int_null                  | int                                      | 
[None, 456]                                                               | 
['Int32', 'Int32']                   | [None, 456]                              
                                 |
+|  5 | int_null                  | int                                      | 
[None, 456]                                                               | 
['float64', 'float64']               | [None, 456]                              
                                 |
 |  6 | long_values               | bigint                                   | 
[-9223372036854775808, 9223372036854775807, 0]                            | 
['int64', 'int64', 'int64']          | [-9223372036854775808, 
9223372036854775807, 0]                            |
-|  7 | long_null                 | bigint                                   | 
[None, 789]                                                               | 
['Int64', 'Int64']                   | [None, 789]                              
                                 |
+|  7 | long_null                 | bigint                                   | 
[None, 789]                                                               | 
['float64', 'float64']               | [None, 789]                              
                                 |
 |  8 | float_values              | float                                    | 
[0.0, 1.0, 3.140000104904175]                                             | 
['float32', 'float32', 'float32']    | [0.0, 1.0, 3.140000104904175]            
                                 |
 |  9 | float_null                | float                                    | 
[None, 3.140000104904175]                                                 | 
['float32', 'float32']               | [None, 3.140000104904175]                
                                 |
 | 10 | double_values             | double                                   | 
[0.0, 1.0, 0.3333333333333333]                                            | 
['float64', 'float64', 'float64']    | [0.0, 1.0, 0.3333333333333333]           
                                 |
diff --git 
a/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.csv
 
b/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.csv
index 9ed7dcba95c8..be52b9885f76 100644
--- 
a/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.csv
+++ 
b/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.csv
@@ -1,12 +1,12 @@
        Test Case       Spark Type      Spark Value     Python Type     Python 
Value
 0      byte_values     tinyint [-128, 127, 0]  ['int', 'int', 'int']   
['-128', '127', '0']
-1      byte_null       tinyint [None, 42]      ['NAType', 'int8']      
['<NA>', '42']
+1      byte_null       tinyint [None, 42]      ['float', 'float']      ['nan', 
'42.0']
 2      short_values    smallint        [-32768, 32767, 0]      ['int', 'int', 
'int']   ['-32768', '32767', '0']
-3      short_null      smallint        [None, 123]     ['NAType', 'int16']     
['<NA>', '123']
+3      short_null      smallint        [None, 123]     ['float', 'float']      
['nan', '123.0']
 4      int_values      int     [-2147483648, 2147483647, 0]    ['int', 'int', 
'int']   ['-2147483648', '2147483647', '0']
-5      int_null        int     [None, 456]     ['NAType', 'int32']     
['<NA>', '456']
+5      int_null        int     [None, 456]     ['float', 'float']      ['nan', 
'456.0']
 6      long_values     bigint  [-9223372036854775808, 9223372036854775807, 0]  
['int', 'int', 'int']   ['-9223372036854775808', '9223372036854775807', '0']
-7      long_null       bigint  [None, 789]     ['NAType', 'int64']     
['<NA>', '789']
+7      long_null       bigint  [None, 789]     ['float', 'float']      ['nan', 
'789.0']
 8      float_values    float   [0.0, 1.0, 3.140000104904175]   ['float', 
'float', 'float']     ['0.0', '1.0', '3.140000104904175']
 9      float_null      float   [None, 3.140000104904175]       ['float', 
'float']      ['nan', '3.140000104904175']
 10     double_values   double  [0.0, 1.0, 0.3333333333333333]  ['float', 
'float', 'float']     ['0.0', '1.0', '0.3333333333333333']
diff --git 
a/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.md
 
b/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.md
index 9b70503e679a..8bd2ffd5e1f1 100644
--- 
a/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.md
+++ 
b/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.md
@@ -1,13 +1,13 @@
 |    | Test Case                 | Spark Type                               | 
Spark Value                                                               | 
Python Type                 | Python Value                                      
                                                    |
 
|----|---------------------------|------------------------------------------|---------------------------------------------------------------------------|-----------------------------|-------------------------------------------------------------------------------------------------------|
 |  0 | byte_values               | tinyint                                  | 
[-128, 127, 0]                                                            | 
['int', 'int', 'int']       | ['-128', '127', '0']                              
                                                    |
-|  1 | byte_null                 | tinyint                                  | 
[None, 42]                                                                | 
['NAType', 'int8']          | ['<NA>', '42']                                    
                                                    |
+|  1 | byte_null                 | tinyint                                  | 
[None, 42]                                                                | 
['float', 'float']          | ['nan', '42.0']                                   
                                                    |
 |  2 | short_values              | smallint                                 | 
[-32768, 32767, 0]                                                        | 
['int', 'int', 'int']       | ['-32768', '32767', '0']                          
                                                    |
-|  3 | short_null                | smallint                                 | 
[None, 123]                                                               | 
['NAType', 'int16']         | ['<NA>', '123']                                   
                                                    |
+|  3 | short_null                | smallint                                 | 
[None, 123]                                                               | 
['float', 'float']          | ['nan', '123.0']                                  
                                                    |
 |  4 | int_values                | int                                      | 
[-2147483648, 2147483647, 0]                                              | 
['int', 'int', 'int']       | ['-2147483648', '2147483647', '0']                
                                                    |
-|  5 | int_null                  | int                                      | 
[None, 456]                                                               | 
['NAType', 'int32']         | ['<NA>', '456']                                   
                                                    |
+|  5 | int_null                  | int                                      | 
[None, 456]                                                               | 
['float', 'float']          | ['nan', '456.0']                                  
                                                    |
 |  6 | long_values               | bigint                                   | 
[-9223372036854775808, 9223372036854775807, 0]                            | 
['int', 'int', 'int']       | ['-9223372036854775808', '9223372036854775807', 
'0']                                                  |
-|  7 | long_null                 | bigint                                   | 
[None, 789]                                                               | 
['NAType', 'int64']         | ['<NA>', '789']                                   
                                                    |
+|  7 | long_null                 | bigint                                   | 
[None, 789]                                                               | 
['float', 'float']          | ['nan', '789.0']                                  
                                                    |
 |  8 | float_values              | float                                    | 
[0.0, 1.0, 3.140000104904175]                                             | 
['float', 'float', 'float'] | ['0.0', '1.0', '3.140000104904175']               
                                                    |
 |  9 | float_null                | float                                    | 
[None, 3.140000104904175]                                                 | 
['float', 'float']          | ['nan', '3.140000104904175']                      
                                                    |
 | 10 | double_values             | double                                   | 
[0.0, 1.0, 0.3333333333333333]                                            | 
['float', 'float', 'float'] | ['0.0', '1.0', '0.3333333333333333']              
                                                    |
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index 370629696807..db5d2072a4bf 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -472,10 +472,11 @@ class PandasUDFTestsMixin:
             AS tab(a, b)
             """
 
-        df = self.spark.sql(query).repartition(1).sortWithinPartitions("b")
-        expected = df.select("a").collect()
-        results = df.select(identity("a").alias("a")).collect()
-        self.assertEqual(results, expected)
+        with 
self.sql_conf({"spark.sql.execution.pythonUDF.pandas.preferIntExtensionDtype": 
True}):
+            df = self.spark.sql(query).repartition(1).sortWithinPartitions("b")
+            expected = df.select("a").collect()
+            results = df.select(identity("a").alias("a")).collect()
+            self.assertEqual(results, expected)
 
 
 class PandasUDFTests(PandasUDFTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 380cfb96db48..1892dcbf3bf6 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -156,6 +156,13 @@ class RunnerConf(Conf):
             == "true"
         )
 
+    @property
+    def prefer_int_ext_dtype(self) -> bool:
+        return (
+            
self.get("spark.sql.execution.pythonUDF.pandas.preferIntExtensionDtype", 
"false")
+            == "true"
+        )
+
     @property
     def timezone(self) -> Optional[str]:
         return self.get("spark.sql.session.timeZone", None, lower_str=False)
@@ -1489,6 +1496,7 @@ def read_udtf(pickleSer, infile, eval_type, runner_conf):
                 runner_conf.timezone,
                 runner_conf.safecheck,
                 input_type=input_type,
+                prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
                 
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
             )
         else:
@@ -2693,6 +2701,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
                 runner_conf.timezone,
                 runner_conf.safecheck,
                 runner_conf.assign_cols_by_name,
+                runner_conf.prefer_int_ext_dtype,
                 runner_conf.int_to_decimal_coercion_enabled,
             )
         elif (
@@ -2703,6 +2712,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
                 runner_conf.timezone,
                 runner_conf.safecheck,
                 runner_conf.assign_cols_by_name,
+                runner_conf.prefer_int_ext_dtype,
                 runner_conf.int_to_decimal_coercion_enabled,
             )
         elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
@@ -2712,6 +2722,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
                 runner_conf.timezone,
                 runner_conf.safecheck,
                 runner_conf.assign_cols_by_name,
+                prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
                 
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
                 arrow_cast=True,
             )
@@ -2720,6 +2731,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
                 runner_conf.timezone,
                 runner_conf.safecheck,
                 runner_conf.assign_cols_by_name,
+                runner_conf.prefer_int_ext_dtype,
                 eval_conf.state_value_schema,
                 runner_conf.arrow_max_records_per_batch,
                 runner_conf.use_large_var_types,
@@ -2730,6 +2742,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
                 runner_conf.timezone,
                 runner_conf.safecheck,
                 runner_conf.assign_cols_by_name,
+                runner_conf.prefer_int_ext_dtype,
                 runner_conf.arrow_max_records_per_batch,
                 runner_conf.arrow_max_bytes_per_batch,
                 
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
@@ -2739,6 +2752,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
                 runner_conf.timezone,
                 runner_conf.safecheck,
                 runner_conf.assign_cols_by_name,
+                runner_conf.prefer_int_ext_dtype,
                 runner_conf.arrow_max_records_per_batch,
                 runner_conf.arrow_max_bytes_per_batch,
                 
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
@@ -2795,6 +2809,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
                 df_for_struct,
                 struct_in_pandas,
                 ndarray_as_list,
+                runner_conf.prefer_int_ext_dtype,
                 True,
                 input_type,
                 
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
@@ -2955,7 +2970,11 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
             else:
                 table = pa.table({})
             # Convert to pandas once for the entire group
-            all_series = ArrowBatchTransformer.to_pandas(table, 
timezone=ser._timezone)
+            all_series = ArrowBatchTransformer.to_pandas(
+                table,
+                timezone=ser._timezone,
+                prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
+            )
             key_series = [all_series[o] for o in key_offsets]
             value_series = [all_series[o] for o in value_offsets]
             yield from f(key_series, value_series)
@@ -2974,14 +2993,22 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
 
         def mapper(batch_iter):
             # Convert first Arrow batch to pandas to extract keys
-            first_series = ArrowBatchTransformer.to_pandas(next(batch_iter), 
timezone=ser._timezone)
+            first_series = ArrowBatchTransformer.to_pandas(
+                next(batch_iter),
+                timezone=ser._timezone,
+                prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
+            )
             key_series = [first_series[o] for o in parsed_offsets[0][0]]
 
             # Lazily convert remaining Arrow batches to pandas Series
             def value_series_gen():
                 yield [first_series[o] for o in parsed_offsets[0][1]]
                 for batch in batch_iter:
-                    series = ArrowBatchTransformer.to_pandas(batch, 
timezone=ser._timezone)
+                    series = ArrowBatchTransformer.to_pandas(
+                        batch,
+                        timezone=ser._timezone,
+                        prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
+                    )
                     yield [series[o] for o in parsed_offsets[0][1]]
 
             yield from f(key_series, value_series_gen())
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 0a0a448ecd2d..060ee811dd9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4469,6 +4469,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val PYTHON_UDF_PANDAS_PREFER_INT_EXTENSION_DTYPE =
+    buildConf("spark.sql.execution.pythonUDF.pandas.preferIntExtensionDtype")
+      .doc("When true, convert integers to Pandas ExtensionDtype (e.g. 
pandas.Int64Dtype) " +
+        "for Pandas UDF execution. Otherwise, depends on the behavior of " +
+        "pyarrow.Array.to_pandas on each input arrow batch.")
+      .version("4.2.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val PYTHON_TABLE_UDF_ARROW_ENABLED =
     buildConf("spark.sql.execution.pythonUDTF.arrow.enabled")
       .doc("Enable Arrow optimization for Python UDTFs.")
@@ -7886,6 +7895,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def arrowSafeTypeConversion: Boolean = 
getConf(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION)
 
+  def preferIntExtDtype: Boolean = 
getConf(SQLConf.PYTHON_UDF_PANDAS_PREFER_INT_EXTENSION_DTYPE)
+
   def pysparkWorkerPythonExecutable: Option[String] =
     getConf(SQLConf.PYSPARK_WORKER_PYTHON_EXECUTABLE)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 94354e815ad3..9af4299c12f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -165,6 +165,9 @@ object ArrowPythonRunner {
     val intToDecimalCoercion = Seq(
       SQLConf.PYTHON_UDF_PANDAS_INT_TO_DECIMAL_COERCION_ENABLED.key ->
       conf.getConf(SQLConf.PYTHON_UDF_PANDAS_INT_TO_DECIMAL_COERCION_ENABLED, 
false).toString)
+    val preferIntExtDtype = Seq(
+      SQLConf.PYTHON_UDF_PANDAS_PREFER_INT_EXTENSION_DTYPE.key ->
+        conf.preferIntExtDtype.toString)
     val binaryAsBytes = Seq(
       SQLConf.PYSPARK_BINARY_AS_BYTES.key ->
       conf.pysparkBinaryAsBytes.toString)
@@ -176,7 +179,7 @@ object ArrowPythonRunner {
     ).getOrElse(Seq.empty)
     Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++
       arrowAyncParallelism ++ useLargeVarTypes ++
-      intToDecimalCoercion ++ binaryAsBytes ++
+      intToDecimalCoercion ++ preferIntExtDtype ++ binaryAsBytes ++
       legacyPandasConversion ++ legacyPandasConversionUDF ++
       udfProfiler ++ dataSourceProfiler: _*)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to