This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1790508f6740 [SPARK-55788][PYTHON] Support ExtensionDType for integers
in Pandas UDF
1790508f6740 is described below
commit 1790508f6740e1fd412d3103b00be89ddb540ea3
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Mar 4 20:21:12 2026 -0800
[SPARK-55788][PYTHON] Support ExtensionDType for integers in Pandas UDF
### What changes were proposed in this pull request?
Always use Pandas ExtensionDType for integers in Pandas UDF, controlled by a new config (disabled by default)
### Why are the changes needed?
The current DType for integers is not predictable: it depends on the
nullability of the **current batch**
### Does this PR introduce _any_ user-facing change?
Yes, controlled by the new config
`spark.sql.execution.pythonUDF.pandas.preferIntExtensionDtype`
### How was this patch tested?
Added tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #54568 from zhengruifeng/add_config.
Lead-authored-by: Ruifeng Zheng <[email protected]>
Co-authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
---
python/pyspark/sql/conversion.py | 18 +++++++++---
python/pyspark/sql/pandas/serializers.py | 34 +++++++++++++++++++++-
.../golden_pandas_udf_input_type_coercion_base.csv | 8 ++---
.../golden_pandas_udf_input_type_coercion_base.md | 8 ++---
...f_input_type_coercion_with_arrow_and_pandas.csv | 8 ++---
...df_input_type_coercion_with_arrow_and_pandas.md | 8 ++---
python/pyspark/sql/tests/pandas/test_pandas_udf.py | 9 +++---
python/pyspark/worker.py | 33 +++++++++++++++++++--
.../org/apache/spark/sql/internal/SQLConf.scala | 11 +++++++
.../sql/execution/python/ArrowPythonRunner.scala | 5 +++-
10 files changed, 113 insertions(+), 29 deletions(-)
diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index ca6d6c61f1ef..aa22594585bb 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -115,6 +115,7 @@ class ArrowBatchTransformer:
schema: Optional["StructType"] = None,
struct_in_pandas: str = "dict",
ndarray_as_list: bool = False,
+ prefer_int_ext_dtype: bool = False,
df_for_struct: bool = False,
) -> List[Union["pd.Series", "pd.DataFrame"]]:
"""
@@ -132,6 +133,8 @@ class ArrowBatchTransformer:
How to represent struct in pandas ("dict", "row", etc.)
ndarray_as_list : bool
Whether to convert ndarray as list.
+ prefer_int_ext_dtype : bool, optional
+ Whether to convert integers to Pandas ExtensionDType.
df_for_struct : bool
If True, convert struct columns to DataFrame instead of Series.
@@ -156,6 +159,7 @@ class ArrowBatchTransformer:
timezone=timezone,
struct_in_pandas=struct_in_pandas,
ndarray_as_list=ndarray_as_list,
+ prefer_int_ext_dtype=prefer_int_ext_dtype,
df_for_struct=df_for_struct,
)
for i in range(batch.num_columns)
@@ -1427,6 +1431,7 @@ class ArrowArrayToPandasConversion:
timezone: Optional[str] = None,
struct_in_pandas: str = "dict",
ndarray_as_list: bool = False,
+ prefer_int_ext_dtype: bool = False,
df_for_struct: bool = False,
) -> Union["pd.Series", "pd.DataFrame"]:
"""
@@ -1447,6 +1452,8 @@ class ArrowArrayToPandasConversion:
Default is "dict".
ndarray_as_list : bool, optional
Whether to convert numpy ndarrays to Python lists. Default is
False.
+ prefer_int_ext_dtype : bool, optional
+ Whether to convert integers to Pandas ExtensionDType.
df_for_struct : bool, optional
If True, convert struct columns to a DataFrame with columns
corresponding
to struct fields instead of a Series. Default is False.
@@ -1465,6 +1472,7 @@ class ArrowArrayToPandasConversion:
timezone=timezone,
struct_in_pandas=struct_in_pandas,
ndarray_as_list=ndarray_as_list,
+ prefer_int_ext_dtype=prefer_int_ext_dtype,
df_for_struct=df_for_struct,
)
@@ -1615,6 +1623,7 @@ class ArrowArrayToPandasConversion:
timezone: Optional[str] = None,
struct_in_pandas: Optional[str] = None,
ndarray_as_list: bool = False,
+ prefer_int_ext_dtype: bool = False,
df_for_struct: bool = False,
) -> Union["pd.Series", "pd.DataFrame"]:
import pyarrow as pa
@@ -1637,6 +1646,7 @@ class ArrowArrayToPandasConversion:
timezone=timezone,
struct_in_pandas=struct_in_pandas,
ndarray_as_list=ndarray_as_list,
+ prefer_int_ext_dtype=prefer_int_ext_dtype,
df_for_struct=False, # always False for child fields
)
for field_arr, field in zip(arr.flatten(), spark_type)
@@ -1657,22 +1667,22 @@ class ArrowArrayToPandasConversion:
# conversion methods are selected based on benchmark
python/benchmarks/bench_arrow.py
if isinstance(spark_type, ByteType):
- if arr.null_count > 0:
+ if prefer_int_ext_dtype:
series =
arr.to_pandas(types_mapper=pd.ArrowDtype).astype(pd.Int8Dtype())
else:
series = arr.to_pandas()
elif isinstance(spark_type, ShortType):
- if arr.null_count > 0:
+ if prefer_int_ext_dtype:
series =
arr.to_pandas(types_mapper=pd.ArrowDtype).astype(pd.Int16Dtype())
else:
series = arr.to_pandas()
elif isinstance(spark_type, IntegerType):
- if arr.null_count > 0:
+ if prefer_int_ext_dtype:
series =
arr.to_pandas(types_mapper=pd.ArrowDtype).astype(pd.Int32Dtype())
else:
series = arr.to_pandas()
elif isinstance(spark_type, LongType):
- if arr.null_count > 0:
+ if prefer_int_ext_dtype:
series =
arr.to_pandas(types_mapper=pd.ArrowDtype).astype(pd.Int64Dtype())
else:
series = arr.to_pandas()
diff --git a/python/pyspark/sql/pandas/serializers.py
b/python/pyspark/sql/pandas/serializers.py
index aac6df47a3b8..019a88472648 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -419,6 +419,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
How to represent struct in pandas ("dict", "row", etc.). Default is
"dict".
ndarray_as_list : bool, optional
Whether to convert ndarray as list. Default is False.
+ prefer_int_ext_dtype : bool, optional
+ Whether to convert integers to Pandas ExtensionDType. Default is False.
df_for_struct : bool, optional
If True, convert struct columns to DataFrame instead of Series.
Default is False.
"""
@@ -431,6 +433,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
prefers_large_types: bool = False,
struct_in_pandas: str = "dict",
ndarray_as_list: bool = False,
+ prefer_int_ext_dtype: bool = False,
df_for_struct: bool = False,
input_type: Optional["StructType"] = None,
arrow_cast: bool = False,
@@ -442,6 +445,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
self._prefers_large_types = prefers_large_types
self._struct_in_pandas = struct_in_pandas
self._ndarray_as_list = ndarray_as_list
+ self._prefer_int_ext_dtype = prefer_int_ext_dtype
self._df_for_struct = df_for_struct
if input_type is not None:
assert isinstance(input_type, StructType)
@@ -486,6 +490,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
schema=self._input_type,
struct_in_pandas=self._struct_in_pandas,
ndarray_as_list=self._ndarray_as_list,
+ prefer_int_ext_dtype=self._prefer_int_ext_dtype,
df_for_struct=self._df_for_struct,
),
super().load_stream(stream),
@@ -508,6 +513,7 @@ class
ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
df_for_struct: bool = False,
struct_in_pandas: str = "dict",
ndarray_as_list: bool = False,
+ prefer_int_ext_dtype: bool = False,
arrow_cast: bool = False,
input_type: Optional[StructType] = None,
int_to_decimal_coercion_enabled: bool = False,
@@ -522,6 +528,7 @@ class
ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
prefers_large_types,
struct_in_pandas,
ndarray_as_list,
+ prefer_int_ext_dtype,
df_for_struct,
input_type,
arrow_cast,
@@ -742,7 +749,14 @@ class
ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
"""
- def __init__(self, timezone, safecheck, input_type,
int_to_decimal_coercion_enabled):
+ def __init__(
+ self,
+ timezone,
+ safecheck,
+ input_type,
+ prefer_int_ext_dtype,
+ int_to_decimal_coercion_enabled,
+ ):
super().__init__(
timezone=timezone,
safecheck=safecheck,
@@ -760,6 +774,7 @@ class
ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
# To ensure consistency across regular and arrow-optimized UDTFs,
we further
# convert these numpy.ndarrays into Python lists.
ndarray_as_list=True,
+ prefer_int_ext_dtype=prefer_int_ext_dtype,
# Enables explicit casting for mismatched return types of Arrow
Python UDTFs.
arrow_cast=True,
input_type=input_type,
@@ -806,6 +821,7 @@ class
ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
timezone,
safecheck,
assign_cols_by_name,
+ prefer_int_ext_dtype,
int_to_decimal_coercion_enabled,
):
super().__init__(
@@ -815,6 +831,7 @@ class
ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
df_for_struct=False,
struct_in_pandas="dict",
ndarray_as_list=False,
+ prefer_int_ext_dtype=prefer_int_ext_dtype,
arrow_cast=True,
input_type=None,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
@@ -837,6 +854,7 @@ class
ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
schema=self._input_type,
struct_in_pandas=self._struct_in_pandas,
ndarray_as_list=self._ndarray_as_list,
+ prefer_int_ext_dtype=self._prefer_int_ext_dtype,
df_for_struct=self._df_for_struct,
)
),
@@ -858,6 +876,7 @@ class
GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
timezone,
safecheck,
assign_cols_by_name,
+ prefer_int_ext_dtype,
int_to_decimal_coercion_enabled,
):
super().__init__(
@@ -867,6 +886,7 @@ class
GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
df_for_struct=False,
struct_in_pandas="dict",
ndarray_as_list=False,
+ prefer_int_ext_dtype=prefer_int_ext_dtype,
arrow_cast=True,
input_type=None,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
@@ -934,6 +954,7 @@ class
CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
schema=from_arrow_schema(left_table.schema),
struct_in_pandas=self._struct_in_pandas,
ndarray_as_list=self._ndarray_as_list,
+ prefer_int_ext_dtype=self._prefer_int_ext_dtype,
df_for_struct=self._df_for_struct,
),
ArrowBatchTransformer.to_pandas(
@@ -942,6 +963,7 @@ class
CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
schema=from_arrow_schema(right_table.schema),
struct_in_pandas=self._struct_in_pandas,
ndarray_as_list=self._ndarray_as_list,
+ prefer_int_ext_dtype=self._prefer_int_ext_dtype,
df_for_struct=self._df_for_struct,
),
)
@@ -970,6 +992,7 @@ class
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
timezone,
safecheck,
assign_cols_by_name,
+ prefer_int_ext_dtype,
state_object_schema,
arrow_max_records_per_batch,
prefers_large_var_types,
@@ -982,6 +1005,7 @@ class
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
df_for_struct=False,
struct_in_pandas="dict",
ndarray_as_list=False,
+ prefer_int_ext_dtype=prefer_int_ext_dtype,
arrow_cast=True,
input_type=None,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
@@ -1114,6 +1138,7 @@ class
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
schema=None,
struct_in_pandas=self._struct_in_pandas,
ndarray_as_list=self._ndarray_as_list,
+ prefer_int_ext_dtype=self._prefer_int_ext_dtype,
df_for_struct=self._df_for_struct,
)[0]
@@ -1151,6 +1176,7 @@ class
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
schema=None,
struct_in_pandas=self._struct_in_pandas,
ndarray_as_list=self._ndarray_as_list,
+ prefer_int_ext_dtype=self._prefer_int_ext_dtype,
df_for_struct=self._df_for_struct,
)
@@ -1365,6 +1391,7 @@ class
TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
timezone,
safecheck,
assign_cols_by_name,
+ prefer_int_ext_dtype,
arrow_max_records_per_batch,
arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled,
@@ -1376,6 +1403,7 @@ class
TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
df_for_struct=False,
struct_in_pandas="dict",
ndarray_as_list=False,
+ prefer_int_ext_dtype=prefer_int_ext_dtype,
arrow_cast=True,
input_type=None,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
@@ -1436,6 +1464,7 @@ class
TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
schema=self._input_type,
struct_in_pandas=self._struct_in_pandas,
ndarray_as_list=self._ndarray_as_list,
+ prefer_int_ext_dtype=self._prefer_int_ext_dtype,
df_for_struct=self._df_for_struct,
)
for row in pd.concat(data_pandas,
axis=1).itertuples(index=False):
@@ -1497,6 +1526,7 @@ class
TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
timezone,
safecheck,
assign_cols_by_name,
+ prefer_int_ext_dtype,
arrow_max_records_per_batch,
arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled,
@@ -1505,6 +1535,7 @@ class
TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
timezone,
safecheck,
assign_cols_by_name,
+ prefer_int_ext_dtype,
arrow_max_records_per_batch,
arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled,
@@ -1566,6 +1597,7 @@ class
TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
schema=self._input_type,
struct_in_pandas=self._struct_in_pandas,
ndarray_as_list=self._ndarray_as_list,
+ prefer_int_ext_dtype=self._prefer_int_ext_dtype,
df_for_struct=self._df_for_struct,
)
diff --git
a/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.csv
b/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.csv
index 965213ba4820..4c2817525801 100644
---
a/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.csv
+++
b/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.csv
@@ -1,12 +1,12 @@
Test Case Spark Type Spark Value Python Type Python
Value
0 byte_values tinyint [-128, 127, 0] ['int8', 'int8', 'int8']
[-128, 127, 0]
-1 byte_null tinyint [None, 42] ['Int8', 'Int8'] [None,
42]
+1 byte_null tinyint [None, 42] ['float64', 'float64'] [None,
42]
2 short_values smallint [-32768, 32767, 0] ['int16',
'int16', 'int16'] [-32768, 32767, 0]
-3 short_null smallint [None, 123] ['Int16', 'Int16']
[None, 123]
+3 short_null smallint [None, 123] ['float64', 'float64']
[None, 123]
4 int_values int [-2147483648, 2147483647, 0] ['int32',
'int32', 'int32'] [-2147483648, 2147483647, 0]
-5 int_null int [None, 456] ['Int32', 'Int32'] [None,
456]
+5 int_null int [None, 456] ['float64', 'float64'] [None,
456]
6 long_values bigint [-9223372036854775808, 9223372036854775807, 0]
['int64', 'int64', 'int64'] [-9223372036854775808, 9223372036854775807, 0]
-7 long_null bigint [None, 789] ['Int64', 'Int64'] [None,
789]
+7 long_null bigint [None, 789] ['float64', 'float64'] [None,
789]
8 float_values float [0.0, 1.0, 3.140000104904175] ['float32',
'float32', 'float32'] [0.0, 1.0, 3.140000104904175]
9 float_null float [None, 3.140000104904175] ['float32',
'float32'] [None, 3.140000104904175]
10 double_values double [0.0, 1.0, 0.3333333333333333] ['float64',
'float64', 'float64'] [0.0, 1.0, 0.3333333333333333]
diff --git
a/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.md
b/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.md
index 5240057fbfd8..6a028a978fe4 100644
---
a/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.md
+++
b/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.md
@@ -1,13 +1,13 @@
| | Test Case | Spark Type |
Spark Value |
Python Type | Python Value
|
|----|---------------------------|------------------------------------------|---------------------------------------------------------------------------|--------------------------------------|---------------------------------------------------------------------------|
| 0 | byte_values | tinyint |
[-128, 127, 0] |
['int8', 'int8', 'int8'] | [-128, 127, 0]
|
-| 1 | byte_null | tinyint |
[None, 42] |
['Int8', 'Int8'] | [None, 42]
|
+| 1 | byte_null | tinyint |
[None, 42] |
['float64', 'float64'] | [None, 42]
|
| 2 | short_values | smallint |
[-32768, 32767, 0] |
['int16', 'int16', 'int16'] | [-32768, 32767, 0]
|
-| 3 | short_null | smallint |
[None, 123] |
['Int16', 'Int16'] | [None, 123]
|
+| 3 | short_null | smallint |
[None, 123] |
['float64', 'float64'] | [None, 123]
|
| 4 | int_values | int |
[-2147483648, 2147483647, 0] |
['int32', 'int32', 'int32'] | [-2147483648, 2147483647, 0]
|
-| 5 | int_null | int |
[None, 456] |
['Int32', 'Int32'] | [None, 456]
|
+| 5 | int_null | int |
[None, 456] |
['float64', 'float64'] | [None, 456]
|
| 6 | long_values | bigint |
[-9223372036854775808, 9223372036854775807, 0] |
['int64', 'int64', 'int64'] | [-9223372036854775808,
9223372036854775807, 0] |
-| 7 | long_null | bigint |
[None, 789] |
['Int64', 'Int64'] | [None, 789]
|
+| 7 | long_null | bigint |
[None, 789] |
['float64', 'float64'] | [None, 789]
|
| 8 | float_values | float |
[0.0, 1.0, 3.140000104904175] |
['float32', 'float32', 'float32'] | [0.0, 1.0, 3.140000104904175]
|
| 9 | float_null | float |
[None, 3.140000104904175] |
['float32', 'float32'] | [None, 3.140000104904175]
|
| 10 | double_values | double |
[0.0, 1.0, 0.3333333333333333] |
['float64', 'float64', 'float64'] | [0.0, 1.0, 0.3333333333333333]
|
diff --git
a/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.csv
b/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.csv
index 9ed7dcba95c8..be52b9885f76 100644
---
a/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.csv
+++
b/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.csv
@@ -1,12 +1,12 @@
Test Case Spark Type Spark Value Python Type Python
Value
0 byte_values tinyint [-128, 127, 0] ['int', 'int', 'int']
['-128', '127', '0']
-1 byte_null tinyint [None, 42] ['NAType', 'int8']
['<NA>', '42']
+1 byte_null tinyint [None, 42] ['float', 'float'] ['nan',
'42.0']
2 short_values smallint [-32768, 32767, 0] ['int', 'int',
'int'] ['-32768', '32767', '0']
-3 short_null smallint [None, 123] ['NAType', 'int16']
['<NA>', '123']
+3 short_null smallint [None, 123] ['float', 'float']
['nan', '123.0']
4 int_values int [-2147483648, 2147483647, 0] ['int', 'int',
'int'] ['-2147483648', '2147483647', '0']
-5 int_null int [None, 456] ['NAType', 'int32']
['<NA>', '456']
+5 int_null int [None, 456] ['float', 'float'] ['nan',
'456.0']
6 long_values bigint [-9223372036854775808, 9223372036854775807, 0]
['int', 'int', 'int'] ['-9223372036854775808', '9223372036854775807', '0']
-7 long_null bigint [None, 789] ['NAType', 'int64']
['<NA>', '789']
+7 long_null bigint [None, 789] ['float', 'float'] ['nan',
'789.0']
8 float_values float [0.0, 1.0, 3.140000104904175] ['float',
'float', 'float'] ['0.0', '1.0', '3.140000104904175']
9 float_null float [None, 3.140000104904175] ['float',
'float'] ['nan', '3.140000104904175']
10 double_values double [0.0, 1.0, 0.3333333333333333] ['float',
'float', 'float'] ['0.0', '1.0', '0.3333333333333333']
diff --git
a/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.md
b/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.md
index 9b70503e679a..8bd2ffd5e1f1 100644
---
a/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.md
+++
b/python/pyspark/sql/tests/coercion/golden_python_udf_input_type_coercion_with_arrow_and_pandas.md
@@ -1,13 +1,13 @@
| | Test Case | Spark Type |
Spark Value |
Python Type | Python Value
|
|----|---------------------------|------------------------------------------|---------------------------------------------------------------------------|-----------------------------|-------------------------------------------------------------------------------------------------------|
| 0 | byte_values | tinyint |
[-128, 127, 0] |
['int', 'int', 'int'] | ['-128', '127', '0']
|
-| 1 | byte_null | tinyint |
[None, 42] |
['NAType', 'int8'] | ['<NA>', '42']
|
+| 1 | byte_null | tinyint |
[None, 42] |
['float', 'float'] | ['nan', '42.0']
|
| 2 | short_values | smallint |
[-32768, 32767, 0] |
['int', 'int', 'int'] | ['-32768', '32767', '0']
|
-| 3 | short_null | smallint |
[None, 123] |
['NAType', 'int16'] | ['<NA>', '123']
|
+| 3 | short_null | smallint |
[None, 123] |
['float', 'float'] | ['nan', '123.0']
|
| 4 | int_values | int |
[-2147483648, 2147483647, 0] |
['int', 'int', 'int'] | ['-2147483648', '2147483647', '0']
|
-| 5 | int_null | int |
[None, 456] |
['NAType', 'int32'] | ['<NA>', '456']
|
+| 5 | int_null | int |
[None, 456] |
['float', 'float'] | ['nan', '456.0']
|
| 6 | long_values | bigint |
[-9223372036854775808, 9223372036854775807, 0] |
['int', 'int', 'int'] | ['-9223372036854775808', '9223372036854775807',
'0'] |
-| 7 | long_null | bigint |
[None, 789] |
['NAType', 'int64'] | ['<NA>', '789']
|
+| 7 | long_null | bigint |
[None, 789] |
['float', 'float'] | ['nan', '789.0']
|
| 8 | float_values | float |
[0.0, 1.0, 3.140000104904175] |
['float', 'float', 'float'] | ['0.0', '1.0', '3.140000104904175']
|
| 9 | float_null | float |
[None, 3.140000104904175] |
['float', 'float'] | ['nan', '3.140000104904175']
|
| 10 | double_values | double |
[0.0, 1.0, 0.3333333333333333] |
['float', 'float', 'float'] | ['0.0', '1.0', '0.3333333333333333']
|
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index 370629696807..db5d2072a4bf 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -472,10 +472,11 @@ class PandasUDFTestsMixin:
AS tab(a, b)
"""
- df = self.spark.sql(query).repartition(1).sortWithinPartitions("b")
- expected = df.select("a").collect()
- results = df.select(identity("a").alias("a")).collect()
- self.assertEqual(results, expected)
+ with
self.sql_conf({"spark.sql.execution.pythonUDF.pandas.preferIntExtensionDtype":
True}):
+ df = self.spark.sql(query).repartition(1).sortWithinPartitions("b")
+ expected = df.select("a").collect()
+ results = df.select(identity("a").alias("a")).collect()
+ self.assertEqual(results, expected)
class PandasUDFTests(PandasUDFTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 380cfb96db48..1892dcbf3bf6 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -156,6 +156,13 @@ class RunnerConf(Conf):
== "true"
)
+ @property
+ def prefer_int_ext_dtype(self) -> bool:
+ return (
+
self.get("spark.sql.execution.pythonUDF.pandas.preferIntExtensionDtype",
"false")
+ == "true"
+ )
+
@property
def timezone(self) -> Optional[str]:
return self.get("spark.sql.session.timeZone", None, lower_str=False)
@@ -1489,6 +1496,7 @@ def read_udtf(pickleSer, infile, eval_type, runner_conf):
runner_conf.timezone,
runner_conf.safecheck,
input_type=input_type,
+ prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
)
else:
@@ -2693,6 +2701,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf,
eval_conf):
runner_conf.timezone,
runner_conf.safecheck,
runner_conf.assign_cols_by_name,
+ runner_conf.prefer_int_ext_dtype,
runner_conf.int_to_decimal_coercion_enabled,
)
elif (
@@ -2703,6 +2712,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf,
eval_conf):
runner_conf.timezone,
runner_conf.safecheck,
runner_conf.assign_cols_by_name,
+ runner_conf.prefer_int_ext_dtype,
runner_conf.int_to_decimal_coercion_enabled,
)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
@@ -2712,6 +2722,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf,
eval_conf):
runner_conf.timezone,
runner_conf.safecheck,
runner_conf.assign_cols_by_name,
+ prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
arrow_cast=True,
)
@@ -2720,6 +2731,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf,
eval_conf):
runner_conf.timezone,
runner_conf.safecheck,
runner_conf.assign_cols_by_name,
+ runner_conf.prefer_int_ext_dtype,
eval_conf.state_value_schema,
runner_conf.arrow_max_records_per_batch,
runner_conf.use_large_var_types,
@@ -2730,6 +2742,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf,
eval_conf):
runner_conf.timezone,
runner_conf.safecheck,
runner_conf.assign_cols_by_name,
+ runner_conf.prefer_int_ext_dtype,
runner_conf.arrow_max_records_per_batch,
runner_conf.arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
@@ -2739,6 +2752,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf,
eval_conf):
runner_conf.timezone,
runner_conf.safecheck,
runner_conf.assign_cols_by_name,
+ runner_conf.prefer_int_ext_dtype,
runner_conf.arrow_max_records_per_batch,
runner_conf.arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
@@ -2795,6 +2809,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf,
eval_conf):
df_for_struct,
struct_in_pandas,
ndarray_as_list,
+ runner_conf.prefer_int_ext_dtype,
True,
input_type,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
@@ -2955,7 +2970,11 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf,
eval_conf):
else:
table = pa.table({})
# Convert to pandas once for the entire group
- all_series = ArrowBatchTransformer.to_pandas(table,
timezone=ser._timezone)
+ all_series = ArrowBatchTransformer.to_pandas(
+ table,
+ timezone=ser._timezone,
+ prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
+ )
key_series = [all_series[o] for o in key_offsets]
value_series = [all_series[o] for o in value_offsets]
yield from f(key_series, value_series)
@@ -2974,14 +2993,22 @@ def read_udfs(pickleSer, infile, eval_type,
runner_conf, eval_conf):
def mapper(batch_iter):
# Convert first Arrow batch to pandas to extract keys
- first_series = ArrowBatchTransformer.to_pandas(next(batch_iter),
timezone=ser._timezone)
+ first_series = ArrowBatchTransformer.to_pandas(
+ next(batch_iter),
+ timezone=ser._timezone,
+ prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
+ )
key_series = [first_series[o] for o in parsed_offsets[0][0]]
# Lazily convert remaining Arrow batches to pandas Series
def value_series_gen():
yield [first_series[o] for o in parsed_offsets[0][1]]
for batch in batch_iter:
- series = ArrowBatchTransformer.to_pandas(batch,
timezone=ser._timezone)
+ series = ArrowBatchTransformer.to_pandas(
+ batch,
+ timezone=ser._timezone,
+ prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
+ )
yield [series[o] for o in parsed_offsets[0][1]]
yield from f(key_series, value_series_gen())
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 0a0a448ecd2d..060ee811dd9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4469,6 +4469,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val PYTHON_UDF_PANDAS_PREFER_INT_EXTENSION_DTYPE =
+ buildConf("spark.sql.execution.pythonUDF.pandas.preferIntExtensionDtype")
+ .doc("When true, convert integers to Pandas ExtensionDtype (e.g.
pandas.Int64Dtype) " +
+ "for Pandas UDF execution. Otherwise, depends on the behavior of " +
+ "pyarrow.Array.to_pandas on each input arrow batch.")
+ .version("4.2.0")
+ .booleanConf
+ .createWithDefault(false)
+
val PYTHON_TABLE_UDF_ARROW_ENABLED =
buildConf("spark.sql.execution.pythonUDTF.arrow.enabled")
.doc("Enable Arrow optimization for Python UDTFs.")
@@ -7886,6 +7895,8 @@ class SQLConf extends Serializable with Logging with
SqlApiConf {
def arrowSafeTypeConversion: Boolean =
getConf(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION)
+ def preferIntExtDtype: Boolean =
getConf(SQLConf.PYTHON_UDF_PANDAS_PREFER_INT_EXTENSION_DTYPE)
+
def pysparkWorkerPythonExecutable: Option[String] =
getConf(SQLConf.PYSPARK_WORKER_PYTHON_EXECUTABLE)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 94354e815ad3..9af4299c12f3 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -165,6 +165,9 @@ object ArrowPythonRunner {
val intToDecimalCoercion = Seq(
SQLConf.PYTHON_UDF_PANDAS_INT_TO_DECIMAL_COERCION_ENABLED.key ->
conf.getConf(SQLConf.PYTHON_UDF_PANDAS_INT_TO_DECIMAL_COERCION_ENABLED,
false).toString)
+ val preferIntExtDtype = Seq(
+ SQLConf.PYTHON_UDF_PANDAS_PREFER_INT_EXTENSION_DTYPE.key ->
+ conf.preferIntExtDtype.toString)
val binaryAsBytes = Seq(
SQLConf.PYSPARK_BINARY_AS_BYTES.key ->
conf.pysparkBinaryAsBytes.toString)
@@ -176,7 +179,7 @@ object ArrowPythonRunner {
).getOrElse(Seq.empty)
Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++
arrowAyncParallelism ++ useLargeVarTypes ++
- intToDecimalCoercion ++ binaryAsBytes ++
+ intToDecimalCoercion ++ preferIntExtDtype ++ binaryAsBytes ++
legacyPandasConversion ++ legacyPandasConversionUDF ++
udfProfiler ++ dataSourceProfiler: _*)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]