This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new f6cd38598078 [SPARK-52943][PYTHON] Enable arrow_cast for all pandas UDF eval types
f6cd38598078 is described below
commit f6cd38598078511fdafd969d31ced14bc4811bb0
Author: Ben Hurdelhey <[email protected]>
AuthorDate: Wed Aug 6 21:01:07 2025 +0800
[SPARK-52943][PYTHON] Enable arrow_cast for all pandas UDF eval types
### What changes were proposed in this pull request?
- This enables arrow_cast for all pandas_udf eval types.
- arrow_cast=True provides coherent type coercion behavior; it is a bit more lenient with mismatched types.
- arrow_cast was originally introduced in https://github.com/apache/spark/pull/41800, but until now it only applied to a subset of `udf` and `pandas_udf` eval types (see below).
- This should have no performance impact: the cast is only attempted in a second pass, when the initial pandas->arrow conversion fails (a sketch of this fallback follows below).
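
To make the "second attempt" concrete, here is a minimal sketch of the fallback using plain pyarrow. The helper name `series_to_arrow` and its structure are illustrative assumptions, not the actual serializer internals:

```python
import pandas as pd
import pyarrow as pa

def series_to_arrow(series: pd.Series, arrow_type: pa.DataType) -> pa.Array:
    """Hypothetical helper mirroring the fallback described above."""
    try:
        # First attempt: direct pandas -> arrow conversion to the target type.
        return pa.Array.from_pandas(series, type=arrow_type)
    except (pa.ArrowInvalid, pa.ArrowTypeError):
        # Second attempt (arrow_cast=True): let Arrow infer the type, then
        # cast explicitly, which is more lenient (e.g. string "123" -> int64).
        return pa.Array.from_pandas(series).cast(arrow_type, safe=False)
```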
### Why are the changes needed?
- This aligns `pandas_udf()` behavior with `udf(useArrow=True)` behavior, making PySpark more consistent.
### Does this PR introduce _any_ user-facing change?
- Yes, see the updated table in [functions.py](https://github.com/apache/spark/compare/benrobby:enable-arrow-cast). TL;DR: this change is additive and does not break workloads; it makes some pandas -> arrow conversions more lenient (see the example after the lists below). We now support:
- int <-> decimal
- float <-> decimal
- string with numbers <-> int,uint,float
Affected UDF types:
- Eval types that already had arrow_cast enabled before this PR:
- `SQL_ARROW_TABLE_UDF`
- `SQL_ARROW_BATCHED_UDF`
- All pandas_udf eval types adopt arrow_cast=True with this PR:
- `SQL_SCALAR_PANDAS_UDF`
- `SQL_SCALAR_PANDAS_ITER_UDF`
- `SQL_GROUPED_MAP_PANDAS_UDF`
- `SQL_MAP_PANDAS_ITER_UDF`
- `SQL_GROUPED_AGG_PANDAS_UDF`
- `SQL_WINDOW_AGG_PANDAS_UDF`
- `SQL_COGROUPED_MAP_PANDAS_UDF`
- `SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE`
- `SQL_TRANSFORM_WITH_STATE_PANDAS_UDF`
- Unaffected:
- Batched UDFs (useArrow=False)
- All other pure Arrow UDFs (`SQL_SCALAR_ARROW_UDF`, `SQL_SCALAR_ARROW_ITER_UDF`, `SQL_GROUPED_AGG_ARROW_UDF`, `SQL_GROUPED_MAP_ARROW_UDF`, `SQL_COGROUPED_MAP_ARROW_UDF`). For UDFs returning Arrow data directly, the expectation is that users supply exactly the right types.
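
As a hedged end-to-end illustration (assuming an active SparkSession bound to `spark`; the UDFs below are examples, not part of this PR), conversions like these now succeed for scalar pandas UDFs instead of failing the pandas -> arrow conversion:

```python
import pandas as pd
from pyspark.sql.functions import pandas_udf

# string with numbers -> long: Arrow casts "0", "1", ... to int64.
@pandas_udf("long")
def parse_number(s: pd.Series) -> pd.Series:
    return s.astype(str)

# int -> decimal: the returned int64 values are cast to decimal(10,0).
@pandas_udf("decimal(10,0)")
def to_decimal(s: pd.Series) -> pd.Series:
    return s

spark.range(3).select(parse_number("id"), to_decimal("id")).show()
```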
### How was this patch tested?
- Added unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #51635 from benrobby/enable-arrow-cast.
Authored-by: Ben Hurdelhey <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/pandas/functions.py | 40 +++++++++---------
python/pyspark/sql/pandas/serializers.py | 2 +
.../sql/tests/pandas/test_pandas_cogrouped_map.py | 2 +-
.../sql/tests/pandas/test_pandas_grouped_map.py | 47 +++++++++++++++++++++-
python/pyspark/sql/tests/pandas/test_pandas_map.py | 46 +++++++++++++--------
python/pyspark/sql/tests/pandas/test_pandas_udf.py | 35 +++++++---------
.../tests/pandas/test_pandas_udf_grouped_agg.py | 43 ++++++++++++++++++++
.../sql/tests/pandas/test_pandas_udf_scalar.py | 30 ++++++++++++++
.../sql/tests/pandas/test_pandas_udf_window.py | 45 +++++++++++++++++++++
python/pyspark/worker.py | 5 +--
10 files changed, 233 insertions(+), 62 deletions(-)
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index 09e283ba21da..e1caf61d3b10 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -634,27 +634,27 @@ def pandas_udf(f=None, returnType=None, functionType=None):
# The following table shows most of Pandas data and SQL type conversions in Pandas UDFs that
# are not yet visible to the user. Some of behaviors are buggy and might be changed in the near
# future. The table might have to be eventually documented externally.
- # Please see SPARK-28132's PR to see the codes in order to generate the table below.
+ # Please see SPARK-52943's PR to see the codes in order to generate the table below.
#
- # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+- [...]
- # |SQL Type \ Pandas Value(Type)|None(object(NoneType))| True(bool)| 1(int8)| 1(int16)| 1(int32)| 1(int64)| 1(uint8)| 1(uint16)| 1(uint32)| 1(uint64)| 1.0(float16)| 1.0(float32)| 1.0(float64)|1970-01-01 00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, US/Eastern])|a(object(string))| 1(object(Decimal))|[1 2 3](object(array[int32]))| 1.0(float128)|(1+0j)(complex64)|(1+0j)(complex128)| [...]
- # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+- [...]
- # | boolean| None| True| True| True| True| True| True| True| True| True| True| True| True| X| X| X| X| X| X| X| X| [...]
- # | tinyint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| [...]
- # | smallint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| [...]
- # | int| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| [...]
- # | bigint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 0| 18000000000000| X| 1| X| X| X| X| [...]
- # | float| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| [...]
- # | double| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| [...]
- # | date| None| X| X| X|datetime.date(197...| X| X| X| X| X| X| X| X| datetime.date(197...| datetime.date(197...| X|datetime.date(197...| X| X| X| X| [...]
- # | timestamp| None| X| X| X| X|datetime.datetime...| X| X| X| X| X| X| X| datetime.datetime...| datetime.datetime...| X|datetime.datetime...| X| X| X| X| [...]
- # | string| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| 'a'| X| X| X| X| X| [...]
- # | decimal(10,0)| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| Decimal('1')| X| X| X| X| [...]
- # | array<int>| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [1, 2, 3]| X| X| X| [...]
- # | map<string,int>| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [...]
- # | struct<_1:int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [...]
- # | binary| None|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')| bytearray(b'\x01')| bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'')|bytearray(b'')|bytearray(b'')| bytearray(b'')| bytearray(b'')| bytearray(b'a')| X| X|bytearray(b'')| bytearray(b'')| bytearray(b'')|b [...]
- # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+- [...]
+ # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+------------------+--------------------+-----------------------------+--------------+-----------------+- [...]
+ # |SQL Type \ Pandas Value(Type)|None(object(NoneType))| True(bool)| 1(int8)| 1(int16)| 1(int32)| 1(int64)| 1(uint8)| 1(uint16)| 1(uint32)| 1(uint64)| 1.0(float16)| 1.0(float32)| 1.0(float64)|1970-01-01 00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, US/Eastern])|a(object(string))|12(object(string))| 1(object(Decimal))|[1 2 3](object(array[int32]))| 1.0(float128)|(1+0j)(complex64)|( [...]
+ # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+------------------+--------------------+-----------------------------+--------------+-----------------+- [...]
+ # | boolean| None| True| True| True| True| True| True| True| True| True| True| True| True| X| X| X| X| X| X| X| X| [...]
+ # | tinyint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 12| 1| X| X| X| [...]
+ # | smallint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 12| 1| X| X| X| [...]
+ # | int| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 12| 1| X| X| X| [...]
+ # | bigint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 0| 18000000000000| X| 12| 1| X| X| X| [...]
+ # | float| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| 12.0| 1.0| X| X| X| [...]
+ # | double| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| 12.0| 1.0| X| X| X| [...]
+ # | date| None| X| X| X|datetime.date(197...| X| X| X| X| X| X| X| X| datetime.date(197...| datetime.date(197...| X| X|datetime.date(197...| X| X| X| [...]
+ # | timestamp| None| X| X| X| X|datetime.datetime...| X| X| X| X| X| X| X| datetime.datetime...| datetime.datetime...| X| X|datetime.datetime...| X| X| X| [...]
+ # | string| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| 'a'| '12'| X| X| X| X| [...]
+ # | decimal(10,0)| None| X| Decimal('1')| Decimal('1')| Decimal('1')| X| Decimal('1')| Decimal('1')| Decimal('1')| X| Decimal('1')| Decimal('1')| Decimal('1')| X| X| X| X| Decimal('1')| X| X| X| [...]
+ # | array<int>| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [1, 2]| X| [1, 2, 3]| X| X| [...]
+ # | map<string,int>| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [...]
+ # | struct<_1:int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [...]
+ # | binary| None|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')| bytearray(b'\x01')| bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'')|bytearray(b'')|bytearray(b'')| bytearray(b'')| bytearray(b'')| bytearray(b'a')| bytearray(b'12')| X| X|bytearray(b'')| bytearray(b'')| [...]
+ # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+------------------+--------------------+-----------------------------+--------------+-----------------+- [...]
#
# Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be
# used in `returnType`.
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index aefeea226596..769f5e043a77 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1103,6 +1103,7 @@ class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
safecheck,
assign_cols_by_name,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+ arrow_cast=True,
)
self.pickleSer = CPickleSerializer()
self.utf8_deserializer = UTF8Deserializer()
@@ -1483,6 +1484,7 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
safecheck,
assign_cols_by_name,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+ arrow_cast=True,
)
self.arrow_max_records_per_batch = arrow_max_records_per_batch
self.key_offsets = None
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index 67398a46bce8..d23252abf6a9 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -262,7 +262,7 @@ class CogroupedApplyInPandasTestsMixin:
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
self._test_merge_error(
- fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": ["2.0"]}),
+ fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": ["test_string"]}),
output_schema="id long, k double",
errorClass=PythonException,
error_message_regex=expected,
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 9965e2acc4b5..a7516bdd22b0 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -371,7 +371,7 @@ class GroupedApplyInPandasTestsMixin:
)
with self.assertRaisesRegex(PythonException, expected + "\n"):
self._test_apply_in_pandas(
- lambda key, pdf: pd.DataFrame([key + (str(pdf.v.mean()),)]),
+ lambda key, pdf: pd.DataFrame([key + ("test_string",)]),
output_schema="id long, mean double",
)
@@ -900,6 +900,51 @@ class GroupedApplyInPandasTestsMixin:
with self.assertRaisesRegex(PythonException, error):
self._test_apply_in_pandas_returning_empty_dataframe(empty_df)
+ def test_arrow_cast_enabled_numeric_to_decimal(self):
+ import numpy as np
+
+ columns = [
+ "int8",
+ "int16",
+ "int32",
+ "uint8",
+ "uint16",
+ "uint32",
+ "float64",
+ ]
+
+ pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
+ df = self.spark.range(2).repartition(1)
+
+ for column in columns:
+ with self.subTest(column=column):
+ v = pdf[column].iloc[:1]
+ schema_str = "id long, value decimal(10,0)"
+
+ @pandas_udf(schema_str, PandasUDFType.GROUPED_MAP)
+ def test(pdf):
+ return pdf.assign(**{"value": v})
+
+ row = df.groupby("id").apply(test).first()
+ res = row[1]
+ self.assertEqual(res, Decimal("1"))
+
+ def test_arrow_cast_enabled_str_to_numeric(self):
+ df = self.spark.range(2).repartition(1)
+
+ types = ["int", "long", "float", "double"]
+
+ for type_str in types:
+ with self.subTest(type=type_str):
+ schema_str = "id long, value " + type_str
+
+ @pandas_udf(schema_str, PandasUDFType.GROUPED_MAP)
+ def test(pdf):
+ return pdf.assign(value=pd.Series(["123"]))
+
+ row = df.groupby("id").apply(test).first()
+ self.assertEqual(row[1], 123)
+
class GroupedApplyInPandasTests(GroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 7debe8035f61..b241b91e02a2 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -276,16 +276,17 @@ class MapInPandasTestsMixin:
self.check_dataframes_with_incompatible_types()
def check_dataframes_with_incompatible_types(self):
- def func(iterator):
- for pdf in iterator:
- yield pdf.assign(id=pdf["id"].apply(str))
-
for safely in [True, False]:
with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
{"spark.sql.execution.pandas.convertToArrowArraySafely": safely}
):
# sometimes we see ValueErrors
with self.subTest(convert="string to double"):
+
+ def func(iterator):
+ for pdf in iterator:
+ yield pdf.assign(id="test_string")
+
expected = (
r"ValueError: Exception thrown when converting pandas.Series "
r"\(object\) with name 'id' to Arrow Array \(double\)."
@@ -304,18 +305,31 @@ class MapInPandasTestsMixin:
.collect()
)
- # sometimes we see TypeErrors
- with self.subTest(convert="double to string"):
- with self.assertRaisesRegex(
- PythonException,
- r"TypeError: Exception thrown when converting pandas.Series "
- r"\(float64\) with name 'id' to Arrow Array \(string\).\n",
- ):
- (
- self.spark.range(10, numPartitions=3)
- .select(col("id").cast("double"))
- .mapInPandas(self.identity_dataframes_iter("id"), "id string")
- .collect()
+ with self.subTest(convert="float to int precision loss"):
+
+ def func(iterator):
+ for pdf in iterator:
+ yield pdf.assign(id=pdf["id"] + 0.1)
+
+ df = (
+ self.spark.range(10, numPartitions=3)
+ .select(col("id").cast("double"))
+ .mapInPandas(func, "id int")
+ )
+ if safely:
+ expected = (
+ r"ValueError: Exception thrown when converting pandas.Series "
+ r"\(float64\) with name 'id' to Arrow Array \(int32\)."
+ " It can be caused by overflows or other "
+ "unsafe conversions warned by Arrow. Arrow safe type check "
+ "can be disabled by using SQL config "
+ "`spark.sql.execution.pandas.convertToArrowArraySafely`."
+ )
+ with self.assertRaisesRegex(PythonException, expected + "\n"):
+ df.collect()
+ else:
+ self.assertEqual(
+ df.collect(), self.spark.range(10, numPartitions=3).collect()
)
def test_empty_iterator(self):
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index d67bed462345..2fb802a60df0 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -376,27 +376,20 @@ class PandasUDFTestsMixin:
values = [1, 2, 3]
return pd.Series([values[int(val) % len(values)] for val in column])
- with self.sql_conf(
- {"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": True}
- ):
- result = df.withColumn("decimal_val", high_precision_udf("id")).collect()
- self.assertEqual(len(result), 3)
- self.assertEqual(result[0]["decimal_val"], Decimal("1.0"))
- self.assertEqual(result[1]["decimal_val"], Decimal("2.0"))
- self.assertEqual(result[2]["decimal_val"], Decimal("3.0"))
-
- with self.sql_conf(
- {"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
- ):
- # Also not supported.
- # This can be fixed by enabling arrow_cast
- # This is currently not the case for SQL_SCALAR_PANDAS_UDF and
- # SQL_SCALAR_PANDAS_ITER_UDF.
- self.assertRaisesRegex(
- PythonException,
- "Exception thrown when converting pandas.Series",
- df.withColumn("decimal_val", high_precision_udf("id")).collect,
- )
+ for intToDecimalCoercionEnabled in [True, False]:
+ # arrow_cast is enabled by default for SQL_SCALAR_PANDAS_UDF and
+ # SQL_SCALAR_PANDAS_ITER_UDF; arrow can do this cast safely.
+ # intToDecimalCoercionEnabled is not required for this case
+ with self.sql_conf(
+ {
+ "spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": intToDecimalCoercionEnabled  # noqa: E501
+ }
+ ):
+ result = df.withColumn("decimal_val", high_precision_udf("id")).collect()
+ self.assertEqual(len(result), 3)
+ self.assertEqual(result[0]["decimal_val"], Decimal("1.0"))
+ self.assertEqual(result[1]["decimal_val"], Decimal("2.0"))
+ self.assertEqual(result[2]["decimal_val"], Decimal("3.0"))
def test_pandas_udf_timestamp_ntz(self):
# SPARK-36626: Test TimestampNTZ in pandas UDF
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index e22b8f9ccacc..1059af59f4a8 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -718,6 +718,49 @@ class GroupedAggPandasUDFTestsMixin:
aggregated, df.groupby("id").agg((sum(df.v) + sum(df.w)).alias("s"))
)
+ def test_arrow_cast_enabled_numeric_to_decimal(self):
+ import numpy as np
+ from decimal import Decimal
+
+ columns = [
+ "int8",
+ "int16",
+ "int32",
+ "uint8",
+ "uint16",
+ "uint32",
+ "float64",
+ ]
+
+ pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
+ df = self.spark.range(2).repartition(1)
+
+ for column in columns:
+ with self.subTest(column=column):
+
+ @pandas_udf("decimal(10,0)", PandasUDFType.GROUPED_AGG)
+ def test(series):
+ return pdf[column].iloc[0]
+
+ row = df.groupby("id").agg(test(df.id)).first()
+ res = row[1]
+ self.assertEqual(res, Decimal("1"))
+
+ def test_arrow_cast_enabled_str_to_numeric(self):
+ df = self.spark.range(2).repartition(1)
+
+ types = ["int", "long", "float", "double"]
+
+ for type_str in types:
+ with self.subTest(type=type_str):
+
+ @pandas_udf(type_str, PandasUDFType.GROUPED_AGG)
+ def test(series):
+ return 123
+
+ row = df.groupby("id").agg(test(df.id)).first()
+ self.assertEqual(row[1], 123)
+
class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index efd9ae0eb9f5..e614d9039b61 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -1875,6 +1875,36 @@ class ScalarPandasUDFTestsMixin:
with self.subTest(with_b=True, query_no=i):
assertDataFrameEqual(df, [Row(0), Row(101)])
+ def test_arrow_cast_enabled_numeric_to_decimal(self):
+ import numpy as np
+
+ columns = [
+ "int8",
+ "int16",
+ "int32",
+ "uint8",
+ "uint16",
+ "uint32",
+ "float64",
+ ]
+
+ pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
+ df = self.spark.range(2).repartition(1)
+
+ t = DecimalType(10, 0)
+ for column in columns:
+ with self.subTest(column=column):
+ v = pdf[column].iloc[:1]
+ row = df.select(pandas_udf(lambda _: v, t)(df.id)).first()
+ assert (row[0] == v).all()
+
+ def test_arrow_cast_enabled_str_to_numeric(self):
+ df = self.spark.range(2).repartition(1)
+ for t in [IntegerType(), LongType(), FloatType(), DoubleType()]:
+ with self.subTest(type=t):
+ row = df.select(pandas_udf(lambda _: pd.Series(["123"]), t)(df.id)).first()
+ assert row[0] == 123
+
class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 9b3673d80d22..2f534b811b34 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -17,6 +17,7 @@
import unittest
from typing import cast
+from decimal import Decimal
from pyspark.errors import AnalysisException, PythonException
from pyspark.sql.functions import (
@@ -33,6 +34,13 @@ from pyspark.sql.functions import (
PandasUDFType,
)
from pyspark.sql.window import Window
+from pyspark.sql.types import (
+ DecimalType,
+ IntegerType,
+ LongType,
+ FloatType,
+ DoubleType,
+)
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
@@ -563,6 +571,43 @@ class WindowPandasUDFTestsMixin:
)
).show()
+ def test_arrow_cast_numeric_to_decimal(self):
+ import numpy as np
+ import pandas as pd
+
+ columns = [
+ "int8",
+ "int16",
+ "int32",
+ "uint8",
+ "uint16",
+ "uint32",
+ "float64",
+ ]
+
+ pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
+ df = self.data
+ w = self.unbounded_window
+
+ t = DecimalType(10, 0)
+ for column in columns:
+ with self.subTest(column=column):
+ value = pdf[column].iloc[0]
+ mean_udf = pandas_udf(lambda v: value, t, PandasUDFType.GROUPED_AGG)
+ result = df.select(mean_udf(df["v"]).over(w)).first()[0]
+ assert result == Decimal("1.0")
+ assert type(result) == Decimal
+
+ def test_arrow_cast_str_to_numeric(self):
+ df = self.data
+ w = self.unbounded_window
+
+ for t in [IntegerType(), LongType(), FloatType(), DoubleType()]:
+ with self.subTest(type=t):
+ mean_udf = pandas_udf(lambda v: "123", t, PandasUDFType.GROUPED_AGG)
+ result = df.select(mean_udf(df["v"]).over(w)).first()[0]
+ assert result == 123
+
class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 9517a72a7a70..207c2a999571 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -2285,6 +2285,7 @@ def read_udfs(pickleSer, infile, eval_type):
safecheck,
_assign_cols_by_name,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+ arrow_cast=True,
)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
arrow_max_records_per_batch = runner_conf.get(
@@ -2374,8 +2375,6 @@ def read_udfs(pickleSer, infile, eval_type):
"row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict"
)
ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
- # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion
- arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
# Arrow-optimized Python UDF takes input types
input_types = (
[f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile))]
@@ -2390,7 +2389,7 @@ def read_udfs(pickleSer, infile, eval_type):
df_for_struct,
struct_in_pandas,
ndarray_as_list,
- arrow_cast,
+ True,
input_types,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]