This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new f6cd38598078 [SPARK-52943][PYTHON] Enable arrow_cast for all pandas UDF eval types
f6cd38598078 is described below

commit f6cd38598078511fdafd969d31ced14bc4811bb0
Author: Ben Hurdelhey <ben.hurdel...@databricks.com>
AuthorDate: Wed Aug 6 21:01:07 2025 +0800

    [SPARK-52943][PYTHON] Enable arrow_cast for all pandas UDF eval types

    ### What changes were proposed in this pull request?

    - This enables arrow_cast for all pandas_udf eval types.
    - arrow_cast=True provides coherent type coercion behavior; it is somewhat more lenient for mismatched types.
    - arrow_cast was originally introduced in https://github.com/apache/spark/pull/41800, but until now it applied only to a subset of `udf` and `pandas_udf` eval types (see below).
    - This should have no performance impact: the cast is only attempted in a second pass, when the pandas->arrow conversion fails.

    ### Why are the changes needed?

    - This aligns `pandas_udf()` behavior with `udf(useArrow=True)` behavior and makes PySpark more consistent.

    ### Does this PR introduce _any_ user-facing change?

    - Yes, see the updated table in [functions.py](https://github.com/apache/spark/compare/benrobby:enable-arrow-cast). TL;DR: this change is additive; it does not break workloads. It makes some pandas -> arrow conversions more lenient. We now support (illustrated in the sketch below):
      - int <-> decimal
      - float <-> decimal
      - string with numbers <-> int, uint, float
    - Affected UDF types:
      - Eval types that already had arrow_cast enabled before this PR:
        - `SQL_ARROW_TABLE_UDF`
        - `SQL_ARROW_BATCHED_UDF`
      - All pandas_udf eval types adopt arrow_cast=True with this PR:
        - `SQL_SCALAR_PANDAS_UDF`
        - `SQL_SCALAR_PANDAS_ITER_UDF`
        - `SQL_GROUPED_MAP_PANDAS_UDF`
        - `SQL_MAP_PANDAS_ITER_UDF`
        - `SQL_GROUPED_AGG_PANDAS_UDF`
        - `SQL_WINDOW_AGG_PANDAS_UDF`
        - `SQL_COGROUPED_MAP_PANDAS_UDF`
        - `SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE`
        - `SQL_TRANSFORM_WITH_STATE_PANDAS_UDF`
      - Unaffected:
        - Batched UDFs (useArrow=False)
        - All other pure arrow UDFs (`SQL_SCALAR_ARROW_UDF`, `SQL_SCALAR_ARROW_ITER_UDF`, `SQL_GROUPED_AGG_ARROW_UDF`, `SQL_GROUPED_MAP_ARROW_UDF`, `SQL_COGROUPED_MAP_ARROW_UDF`). For UDFs returning arrow data directly, the expectation is that users supply exactly the right types.

    ### How was this patch tested?

    - Added unit tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #51635 from benrobby/enable-arrow-cast.
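    For illustration, a minimal sketch of the newly permitted coercions
    (assuming a running SparkSession bound to `spark`; the UDF names here are
    illustrative, not part of the patch):

    ```python
    import pandas as pd
    from pyspark.sql.functions import pandas_udf

    # string with numbers -> long: previously failed for pandas UDFs;
    # Arrow now casts the string result on the pandas -> arrow path.
    @pandas_udf("long")
    def digits_as_long(s: pd.Series) -> pd.Series:
        return s.astype(str)

    # int32 -> decimal: Arrow casts the int result to decimal(10,0).
    @pandas_udf("decimal(10,0)")
    def as_decimal(s: pd.Series) -> pd.Series:
        return (s + 1).astype("int32")

    spark.range(3).select(digits_as_long("id"), as_decimal("id")).show()
    ```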
Authored-by: Ben Hurdelhey <ben.hurdel...@databricks.com>
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/pandas/functions.py             | 40 +++++++++---------
 python/pyspark/sql/pandas/serializers.py           |  2 +
 .../sql/tests/pandas/test_pandas_cogrouped_map.py  |  2 +-
 .../sql/tests/pandas/test_pandas_grouped_map.py    | 47 +++++++++++++++++++++-
 python/pyspark/sql/tests/pandas/test_pandas_map.py | 46 +++++++++++++--------
 python/pyspark/sql/tests/pandas/test_pandas_udf.py | 35 +++++++---------
 .../tests/pandas/test_pandas_udf_grouped_agg.py    | 43 ++++++++++++++++++++
 .../sql/tests/pandas/test_pandas_udf_scalar.py     | 30 ++++++++++++++
 .../sql/tests/pandas/test_pandas_udf_window.py     | 45 +++++++++++++++++++++
 python/pyspark/worker.py                           |  5 +--
 10 files changed, 233 insertions(+), 62 deletions(-)

diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index 09e283ba21da..e1caf61d3b10 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -634,27 +634,27 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     # The following table shows most of Pandas data and SQL type conversions in Pandas UDFs that
     # are not yet visible to the user. Some of the behaviors are buggy and might be changed in the
     # near future. The table might have to be eventually documented externally.
-    # Please see SPARK-28132's PR to see the codes in order to generate the table below.
+    # Please see SPARK-52943's PR to see the codes in order to generate the table below.
     #
-    # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+- [...]
-    # |SQL Type \ Pandas Value(Type)|None(object(NoneType))| True(bool)| 1(int8)| 1(int16)| 1(int32)| 1(int64)| 1(uint8)| 1(uint16)| 1(uint32)| 1(uint64)| 1.0(float16)| 1.0(float32)| 1.0(float64)|1970-01-01 00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, US/Eastern])|a(object(string))| 1(object(Decimal))|[1 2 3](object(array[int32]))| 1.0(float128)|(1+0j)(complex64)|(1+0j)(complex128)| [...]
-    # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+- [...]
-    # | boolean| None| True| True| True| True| True| True| True| True| True| True| True| True| X| X| X| X| X| X| X| X| [...]
-    # | tinyint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| [...]
-    # | smallint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| [...]
-    # | int| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| [...]
-    # | bigint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 0| 18000000000000| X| 1| X| X| X| X| [...]
-    # | float| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| [...]
-    # | double| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| [...]
-    # | date| None| X| X| X|datetime.date(197...| X| X| X| X| X| X| X| X| datetime.date(197...| datetime.date(197...| X|datetime.date(197...| X| X| X| X| [...]
-    # | timestamp| None| X| X| X| X|datetime.datetime...| X| X| X| X| X| X| X| datetime.datetime...| datetime.datetime...| X|datetime.datetime...| X| X| X| X| [...]
-    # | string| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| 'a'| X| X| X| X| X| [...]
-    # | decimal(10,0)| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| Decimal('1')| X| X| X| X| [...]
-    # | array<int>| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [1, 2, 3]| X| X| X| [...]
-    # | map<string,int>| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [...]
-    # | struct<_1:int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [...]
-    # | binary| None|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')| bytearray(b'\x01')| bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'')|bytearray(b'')|bytearray(b'')| bytearray(b'')| bytearray(b'')| bytearray(b'a')| X| X|bytearray(b'')| bytearray(b'')| bytearray(b'')|b [...]
-    # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+- [...]
+    # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+------------------+--------------------+-----------------------------+--------------+-----------------+- [...]
+    # |SQL Type \ Pandas Value(Type)|None(object(NoneType))| True(bool)| 1(int8)| 1(int16)| 1(int32)| 1(int64)| 1(uint8)| 1(uint16)| 1(uint32)| 1(uint64)| 1.0(float16)| 1.0(float32)| 1.0(float64)|1970-01-01 00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, US/Eastern])|a(object(string))|12(object(string))| 1(object(Decimal))|[1 2 3](object(array[int32]))| 1.0(float128)|(1+0j)(complex64)|( [...]
+    # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+------------------+--------------------+-----------------------------+--------------+-----------------+- [...]
+    # | boolean| None| True| True| True| True| True| True| True| True| True| True| True| True| X| X| X| X| X| X| X| X| [...]
+    # | tinyint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 12| 1| X| X| X| [...]
+    # | smallint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 12| 1| X| X| X| [...]
+    # | int| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 12| 1| X| X| X| [...]
+    # | bigint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 0| 18000000000000| X| 12| 1| X| X| X| [...]
+    # | float| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| 12.0| 1.0| X| X| X| [...]
+    # | double| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| 12.0| 1.0| X| X| X| [...]
+    # | date| None| X| X| X|datetime.date(197...| X| X| X| X| X| X| X| X| datetime.date(197...| datetime.date(197...| X| X|datetime.date(197...| X| X| X| [...]
+    # | timestamp| None| X| X| X| X|datetime.datetime...| X| X| X| X| X| X| X| datetime.datetime...| datetime.datetime...| X| X|datetime.datetime...| X| X| X| [...]
+    # | string| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| 'a'| '12'| X| X| X| X| [...]
+    # | decimal(10,0)| None| X| Decimal('1')| Decimal('1')| Decimal('1')| X| Decimal('1')| Decimal('1')| Decimal('1')| X| Decimal('1')| Decimal('1')| Decimal('1')| X| X| X| X| Decimal('1')| X| X| X| [...]
+    # | array<int>| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [1, 2]| X| [1, 2, 3]| X| X| [...]
+    # | map<string,int>| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [...]
+    # | struct<_1:int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [...]
+    # | binary| None|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')| bytearray(b'\x01')| bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'')|bytearray(b'')|bytearray(b'')| bytearray(b'')| bytearray(b'')| bytearray(b'a')| bytearray(b'12')| X| X|bytearray(b'')| bytearray(b'')| [...]
+    # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+------------------+--------------------+-----------------------------+--------------+-----------------+- [...]
     #
     # Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be
     # used in `returnType`.
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index aefeea226596..769f5e043a77 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1103,6 +1103,7 @@ class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
             safecheck,
             assign_cols_by_name,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+            arrow_cast=True,
         )
         self.pickleSer = CPickleSerializer()
         self.utf8_deserializer = UTF8Deserializer()
@@ -1483,6 +1484,7 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
             safecheck,
             assign_cols_by_name,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+            arrow_cast=True,
         )
         self.arrow_max_records_per_batch = arrow_max_records_per_batch
         self.key_offsets = None
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index 67398a46bce8..d23252abf6a9 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -262,7 +262,7 @@ class CogroupedApplyInPandasTestsMixin:
             "`spark.sql.execution.pandas.convertToArrowArraySafely`."
         )
         self._test_merge_error(
-            fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": ["2.0"]}),
+            fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": ["test_string"]}),
             output_schema="id long, k double",
             errorClass=PythonException,
             error_message_regex=expected,
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 9965e2acc4b5..a7516bdd22b0 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -371,7 +371,7 @@ class GroupedApplyInPandasTestsMixin:
         )
         with self.assertRaisesRegex(PythonException, expected + "\n"):
             self._test_apply_in_pandas(
-                lambda key, pdf: pd.DataFrame([key + (str(pdf.v.mean()),)]),
+                lambda key, pdf: pd.DataFrame([key + ("test_string",)]),
                 output_schema="id long, mean double",
             )
 
@@ -900,6 +900,51 @@ class GroupedApplyInPandasTestsMixin:
         with self.assertRaisesRegex(PythonException, error):
             self._test_apply_in_pandas_returning_empty_dataframe(empty_df)
 
+    def test_arrow_cast_enabled_numeric_to_decimal(self):
+        import numpy as np
+
+        columns = [
+            "int8",
+            "int16",
+            "int32",
+            "uint8",
+            "uint16",
+            "uint32",
+            "float64",
+        ]
+
+        pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
+        df = self.spark.range(2).repartition(1)
+
+        for column in columns:
+            with self.subTest(column=column):
+                v = pdf[column].iloc[:1]
+                schema_str = "id long, value decimal(10,0)"
+
+                @pandas_udf(schema_str, PandasUDFType.GROUPED_MAP)
+                def test(pdf):
+                    return pdf.assign(**{"value": v})
+
+                row = df.groupby("id").apply(test).first()
+                res = row[1]
+                self.assertEqual(res, Decimal("1"))
+
+    def test_arrow_cast_enabled_str_to_numeric(self):
+        df = self.spark.range(2).repartition(1)
+
+        types = ["int", "long", "float", "double"]
+
+        for type_str in types:
+            with self.subTest(type=type_str):
+                schema_str = "id long, value " + type_str
+
+                @pandas_udf(schema_str, PandasUDFType.GROUPED_MAP)
+                def test(pdf):
+                    return pdf.assign(value=pd.Series(["123"]))
+
+                row = df.groupby("id").apply(test).first()
+                self.assertEqual(row[1], 123)
+
 
 class GroupedApplyInPandasTests(GroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 7debe8035f61..b241b91e02a2 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -276,16 +276,17 @@ class MapInPandasTestsMixin:
         self.check_dataframes_with_incompatible_types()
 
     def check_dataframes_with_incompatible_types(self):
-        def func(iterator):
-            for pdf in iterator:
-                yield pdf.assign(id=pdf["id"].apply(str))
-
         for safely in [True, False]:
             with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
                 {"spark.sql.execution.pandas.convertToArrowArraySafely": safely}
             ):
                 # sometimes we see ValueErrors
                 with self.subTest(convert="string to double"):
+
+                    def func(iterator):
+                        for pdf in iterator:
+                            yield pdf.assign(id="test_string")
+
                     expected = (
                         r"ValueError: Exception thrown when converting pandas.Series "
                         r"\(object\) with name 'id' to Arrow Array \(double\)."
@@ -304,18 +305,31 @@ class MapInPandasTestsMixin:
                         .collect()
                     )
 
-            # sometimes we see TypeErrors
-            with self.subTest(convert="double to string"):
-                with self.assertRaisesRegex(
-                    PythonException,
-                    r"TypeError: Exception thrown when converting pandas.Series "
-                    r"\(float64\) with name 'id' to Arrow Array \(string\).\n",
-                ):
-                    (
-                        self.spark.range(10, numPartitions=3)
-                        .select(col("id").cast("double"))
-                        .mapInPandas(self.identity_dataframes_iter("id"), "id string")
-                        .collect()
+                with self.subTest(convert="float to int precision loss"):
+
+                    def func(iterator):
+                        for pdf in iterator:
+                            yield pdf.assign(id=pdf["id"] + 0.1)
+
+                    df = (
+                        self.spark.range(10, numPartitions=3)
+                        .select(col("id").cast("double"))
+                        .mapInPandas(func, "id int")
+                    )
+                    if safely:
+                        expected = (
+                            r"ValueError: Exception thrown when converting pandas.Series "
+                            r"\(float64\) with name 'id' to Arrow Array \(int32\)."
+                            " It can be caused by overflows or other "
+                            "unsafe conversions warned by Arrow. Arrow safe type check "
+                            "can be disabled by using SQL config "
+                            "`spark.sql.execution.pandas.convertToArrowArraySafely`."
+                        )
+                        with self.assertRaisesRegex(PythonException, expected + "\n"):
+                            df.collect()
+                    else:
+                        self.assertEqual(
+                            df.collect(), self.spark.range(10, numPartitions=3).collect()
                     )
 
     def test_empty_iterator(self):
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index d67bed462345..2fb802a60df0 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -376,27 +376,20 @@ class PandasUDFTestsMixin:
             values = [1, 2, 3]
             return pd.Series([values[int(val) % len(values)] for val in column])
 
-        with self.sql_conf(
-            {"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": True}
-        ):
-            result = df.withColumn("decimal_val", high_precision_udf("id")).collect()
-            self.assertEqual(len(result), 3)
-            self.assertEqual(result[0]["decimal_val"], Decimal("1.0"))
-            self.assertEqual(result[1]["decimal_val"], Decimal("2.0"))
-            self.assertEqual(result[2]["decimal_val"], Decimal("3.0"))
-
-        with self.sql_conf(
-            {"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
-        ):
-            # Also not supported.
-            # This can be fixed by enabling arrow_cast
-            # This is currently not the case for SQL_SCALAR_PANDAS_UDF and
-            # SQL_SCALAR_PANDAS_ITER_UDF.
-            self.assertRaisesRegex(
-                PythonException,
-                "Exception thrown when converting pandas.Series",
-                df.withColumn("decimal_val", high_precision_udf("id")).collect,
-            )
+        for intToDecimalCoercionEnabled in [True, False]:
+            # arrow_cast is enabled by default for SQL_SCALAR_PANDAS_UDF and
+            # SQL_SCALAR_PANDAS_ITER_UDF; arrow can do this cast safely.
+            # intToDecimalCoercionEnabled is not required for this case
+            with self.sql_conf(
+                {
+                    "spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": intToDecimalCoercionEnabled  # noqa: E501
+                }
+            ):
+                result = df.withColumn("decimal_val", high_precision_udf("id")).collect()
+                self.assertEqual(len(result), 3)
+                self.assertEqual(result[0]["decimal_val"], Decimal("1.0"))
+                self.assertEqual(result[1]["decimal_val"], Decimal("2.0"))
+                self.assertEqual(result[2]["decimal_val"], Decimal("3.0"))
 
     def test_pandas_udf_timestamp_ntz(self):
         # SPARK-36626: Test TimestampNTZ in pandas UDF
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index e22b8f9ccacc..1059af59f4a8 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -718,6 +718,49 @@ class GroupedAggPandasUDFTestsMixin:
             aggregated, df.groupby("id").agg((sum(df.v) + sum(df.w)).alias("s"))
         )
 
+    def test_arrow_cast_enabled_numeric_to_decimal(self):
+        import numpy as np
+        from decimal import Decimal
+
+        columns = [
+            "int8",
+            "int16",
+            "int32",
+            "uint8",
+            "uint16",
+            "uint32",
+            "float64",
+        ]
+
+        pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
+        df = self.spark.range(2).repartition(1)
+
+        for column in columns:
+            with self.subTest(column=column):
+
+                @pandas_udf("decimal(10,0)", PandasUDFType.GROUPED_AGG)
+                def test(series):
+                    return pdf[column].iloc[0]
+
+                row = df.groupby("id").agg(test(df.id)).first()
+                res = row[1]
+                self.assertEqual(res, Decimal("1"))
+
+    def test_arrow_cast_enabled_str_to_numeric(self):
+        df = self.spark.range(2).repartition(1)
+
+        types = ["int", "long", "float", "double"]
+
+        for type_str in types:
+            with self.subTest(type=type_str):
+
+                @pandas_udf(type_str, PandasUDFType.GROUPED_AGG)
+                def test(series):
+                    return 123
+
+                row = df.groupby("id").agg(test(df.id)).first()
+                self.assertEqual(row[1], 123)
+
 
 class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index efd9ae0eb9f5..e614d9039b61 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -1875,6 +1875,36 @@ class ScalarPandasUDFTestsMixin:
                 with self.subTest(with_b=True, query_no=i):
                     assertDataFrameEqual(df, [Row(0), Row(101)])
 
+    def test_arrow_cast_enabled_numeric_to_decimal(self):
+        import numpy as np
+
+        columns = [
+            "int8",
+            "int16",
+            "int32",
+            "uint8",
+            "uint16",
+            "uint32",
+            "float64",
+        ]
+
+        pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
+        df = self.spark.range(2).repartition(1)
+
+        t = DecimalType(10, 0)
+        for column in columns:
+            with self.subTest(column=column):
+                v = pdf[column].iloc[:1]
+                row = df.select(pandas_udf(lambda _: v, t)(df.id)).first()
+                assert (row[0] == v).all()
+
+    def test_arrow_cast_enabled_str_to_numeric(self):
+        df = self.spark.range(2).repartition(1)
+        for t in [IntegerType(), LongType(), FloatType(), DoubleType()]:
+            with self.subTest(type=t):
+                row = df.select(pandas_udf(lambda _: pd.Series(["123"]), t)(df.id)).first()
+                assert row[0] == 123
+
 
 class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 9b3673d80d22..2f534b811b34 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -17,6 +17,7 @@
 
 import unittest
 from typing import cast
+from decimal import Decimal
 
 from pyspark.errors import AnalysisException, PythonException
 from pyspark.sql.functions import (
@@ -33,6 +34,13 @@ from pyspark.sql.functions import (
     PandasUDFType,
 )
 from pyspark.sql.window import Window
+from pyspark.sql.types import (
+    DecimalType,
+    IntegerType,
+    LongType,
+    FloatType,
+    DoubleType,
+)
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pandas,
@@ -563,6 +571,43 @@ class WindowPandasUDFTestsMixin:
                 )
             ).show()
 
+    def test_arrow_cast_numeric_to_decimal(self):
+        import numpy as np
+        import pandas as pd
+
+        columns = [
+            "int8",
+            "int16",
+            "int32",
+            "uint8",
+            "uint16",
+            "uint32",
+            "float64",
+        ]
+
+        pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
+        df = self.data
+        w = self.unbounded_window
+
+        t = DecimalType(10, 0)
+        for column in columns:
+            with self.subTest(column=column):
+                value = pdf[column].iloc[0]
+                mean_udf = pandas_udf(lambda v: value, t, PandasUDFType.GROUPED_AGG)
+                result = df.select(mean_udf(df["v"]).over(w)).first()[0]
+                assert result == Decimal("1.0")
+                assert type(result) == Decimal
+
+    def test_arrow_cast_str_to_numeric(self):
+        df = self.data
+        w = self.unbounded_window
+
+        for t in [IntegerType(), LongType(), FloatType(), DoubleType()]:
+            with self.subTest(type=t):
+                mean_udf = pandas_udf(lambda v: "123", t, PandasUDFType.GROUPED_AGG)
+                result = df.select(mean_udf(df["v"]).over(w)).first()[0]
+                assert result == 123
+
 
 class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 9517a72a7a70..207c2a999571 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -2285,6 +2285,7 @@ def read_udfs(pickleSer, infile, eval_type):
             safecheck,
             _assign_cols_by_name,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+            arrow_cast=True,
         )
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
         arrow_max_records_per_batch = runner_conf.get(
@@ -2374,8 +2375,6 @@ def read_udfs(pickleSer, infile, eval_type):
             "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict"
         )
         ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
-        # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion
-        arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
         # Arrow-optimized Python UDF takes input types
         input_types = (
             [f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile))]
@@ -2390,7 +2389,7 @@ def read_udfs(pickleSer, infile, eval_type):
             df_for_struct,
             struct_in_pandas,
             ndarray_as_list,
-            arrow_cast,
+            True,
             input_types,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org