This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f6cd38598078 [SPARK-52943][PYTHON] Enable arrow_cast for all pandas 
UDF eval types
f6cd38598078 is described below

commit f6cd38598078511fdafd969d31ced14bc4811bb0
Author: Ben Hurdelhey <ben.hurdel...@databricks.com>
AuthorDate: Wed Aug 6 21:01:07 2025 +0800

    [SPARK-52943][PYTHON] Enable arrow_cast for all pandas UDF eval types
    
    ### What changes were proposed in this pull request?
    - this enables arrow_cast for all pandas_udfs
    - arrow_cast=True provides coherent type coercion behavior; it is a bit 
more lenient for mismatched types
    - arrow_cast was originally introduced in 
https://github.com/apache/spark/pull/41800, but up until now it only applied to 
a subset of `udf` and `pandas_udf` eval types (see below)
    - this should have no performance impact as the cast is only done in a 
second attempt when the pandas->arrow conversion fails.
    
    ### Why are the changes needed?
    - this aligns `pandas_udf()` behavior with `udf(useArrow=True)` behavior, 
which makes PySpark more consistent
    
    ### Does this PR introduce _any_ user-facing change?
    
    - Yes, see the updated table in 
[functions.py](https://github.com/apache/spark/compare/benrobby:enable-arrow-cast).
 TL;DR: this change is additive and does not break workloads. It makes some 
pandas -> arrow conversions more lenient. We now support:
      - int <-> decimal
      - float <-> decimal
      - string with numbers <-> int,uint,float
    
    Affected UDF types:
    - Eval types that already had arrow_cast enabled before this PR:
      - `SQL_ARROW_TABLE_UDF`
      - `SQL_ARROW_BATCHED_UDF`
    - All pandas_udf eval types adopt arrow_cast=True with this PR:
      - `SQL_SCALAR_PANDAS_UDF`
      - `SQL_SCALAR_PANDAS_ITER_UDF`
      - `SQL_GROUPED_MAP_PANDAS_UDF`
      - `SQL_MAP_PANDAS_ITER_UDF`
      - `SQL_GROUPED_AGG_PANDAS_UDF`
      - `SQL_WINDOW_AGG_PANDAS_UDF`
      - `SQL_COGROUPED_MAP_PANDAS_UDF`
      - `SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE`
      - `SQL_TRANSFORM_WITH_STATE_PANDAS_UDF`
    - unaffected:
      - Batched UDFs (useArrow=False)
      - All other pure arrow UDFs (`SQL_SCALAR_ARROW_UDF`, 
`SQL_SCALAR_ARROW_ITER_UDF`, `SQL_GROUPED_AGG_ARROW_UDF`, 
`SQL_GROUPED_MAP_ARROW_UDF`, `SQL_COGROUPED_MAP_ARROW_UDF`). For UDFs returning 
arrow data directly, the expectation is that users supply exactly the right 
types.
    
    ### How was this patch tested?
    
    - added unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #51635 from benrobby/enable-arrow-cast.
    
    Authored-by: Ben Hurdelhey <ben.hurdel...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/pandas/functions.py             | 40 +++++++++---------
 python/pyspark/sql/pandas/serializers.py           |  2 +
 .../sql/tests/pandas/test_pandas_cogrouped_map.py  |  2 +-
 .../sql/tests/pandas/test_pandas_grouped_map.py    | 47 +++++++++++++++++++++-
 python/pyspark/sql/tests/pandas/test_pandas_map.py | 46 +++++++++++++--------
 python/pyspark/sql/tests/pandas/test_pandas_udf.py | 35 +++++++---------
 .../tests/pandas/test_pandas_udf_grouped_agg.py    | 43 ++++++++++++++++++++
 .../sql/tests/pandas/test_pandas_udf_scalar.py     | 30 ++++++++++++++
 .../sql/tests/pandas/test_pandas_udf_window.py     | 45 +++++++++++++++++++++
 python/pyspark/worker.py                           |  5 +--
 10 files changed, 233 insertions(+), 62 deletions(-)

diff --git a/python/pyspark/sql/pandas/functions.py 
b/python/pyspark/sql/pandas/functions.py
index 09e283ba21da..e1caf61d3b10 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -634,27 +634,27 @@ def pandas_udf(f=None, returnType=None, 
functionType=None):
     # The following table shows most of Pandas data and SQL type conversions 
in Pandas UDFs that
     # are not yet visible to the user. Some of behaviors are buggy and might 
be changed in the near
     # future. The table might have to be eventually documented externally.
-    # Please see SPARK-28132's PR to see the codes in order to generate the 
table below.
+    # Please see SPARK-52943's PR to see the codes in order to generate the 
table below.
     #
-    # 
+-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-
 [...]
-    # |SQL Type \ Pandas Value(Type)|None(object(NoneType))|        
True(bool)|           1(int8)|          1(int16)|            1(int32)|          
  1(int64)|          1(uint8)|         1(uint16)|         1(uint32)|         
1(uint64)|  1.0(float16)|  1.0(float32)|  1.0(float64)|1970-01-01 
00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, 
US/Eastern])|a(object(string))|  1(object(Decimal))|[1 2 
3](object(array[int32]))| 1.0(float128)|(1+0j)(complex64)|(1+0j)(complex128)|  
[...]
-    # 
+-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-
 [...]
-    # |                      boolean|                  None|              
True|              True|              True|                True|                
True|              True|              True|              True|              
True|          True|          True|          True|                              
    X|                                                    X|                X|  
                 X|                            X|             X|                
X|                 X|  [...]
-    # |                      tinyint|                  None|                 
1|                 1|                 1|                   1|                   
1|                 1|                 1|                 1|                 1|  
           1|             1|             1|                                  X| 
                                                   X|                X|         
          1|                            X|             X|                X|     
            X|  [...]
-    # |                     smallint|                  None|                 
1|                 1|                 1|                   1|                   
1|                 1|                 1|                 1|                 1|  
           1|             1|             1|                                  X| 
                                                   X|                X|         
          1|                            X|             X|                X|     
            X|  [...]
-    # |                          int|                  None|                 
1|                 1|                 1|                   1|                   
1|                 1|                 1|                 1|                 1|  
           1|             1|             1|                                  X| 
                                                   X|                X|         
          1|                            X|             X|                X|     
            X|  [...]
-    # |                       bigint|                  None|                 
1|                 1|                 1|                   1|                   
1|                 1|                 1|                 1|                 1|  
           1|             1|             1|                                  0| 
                                      18000000000000|                X|         
          1|                            X|             X|                X|     
            X|  [...]
-    # |                        float|                  None|               
1.0|               1.0|               1.0|                 1.0|                 
1.0|               1.0|               1.0|               1.0|               
1.0|           1.0|           1.0|           1.0|                               
   X|                                                    X|                X|   
                X|                            X|             X|                
X|                 X|  [...]
-    # |                       double|                  None|               
1.0|               1.0|               1.0|                 1.0|                 
1.0|               1.0|               1.0|               1.0|               
1.0|           1.0|           1.0|           1.0|                               
   X|                                                    X|                X|   
                X|                            X|             X|                
X|                 X|  [...]
-    # |                         date|                  None|                 
X|                 X|                 X|datetime.date(197...|                   
X|                 X|                 X|                 X|                 X|  
           X|             X|             X|               datetime.date(197...| 
                                datetime.date(197...|                
X|datetime.date(197...|                            X|             X|            
    X|                 X|  [...]
-    # |                    timestamp|                  None|                 
X|                 X|                 X|                   
X|datetime.datetime...|                 X|                 X|                 
X|                 X|             X|             X|             X|              
 datetime.datetime...|                                 datetime.datetime...|    
            X|datetime.datetime...|                            X|             
X|                X|                 X|  [...]
-    # |                       string|                  None|                 
X|                 X|                 X|                   X|                   
X|                 X|                 X|                 X|                 X|  
           X|             X|             X|                                  X| 
                                                   X|              'a'|         
          X|                            X|             X|                X|     
            X|  [...]
-    # |                decimal(10,0)|                  None|                 
X|                 X|                 X|                   X|                   
X|                 X|                 X|                 X|                 X|  
           X|             X|             X|                                  X| 
                                                   X|                X|        
Decimal('1')|                            X|             X|                X|    
             X|  [...]
-    # |                   array<int>|                  None|                 
X|                 X|                 X|                   X|                   
X|                 X|                 X|                 X|                 X|  
           X|             X|             X|                                  X| 
                                                   X|                X|         
          X|                    [1, 2, 3]|             X|                X|     
            X|  [...]
-    # |              map<string,int>|                  None|                 
X|                 X|                 X|                   X|                   
X|                 X|                 X|                 X|                 X|  
           X|             X|             X|                                  X| 
                                                   X|                X|         
          X|                            X|             X|                X|     
            X|  [...]
-    # |               struct<_1:int>|                     X|                 
X|                 X|                 X|                   X|                   
X|                 X|                 X|                 X|                 X|  
           X|             X|             X|                                  X| 
                                                   X|                X|         
          X|                            X|             X|                X|     
            X|  [...]
-    # |                       binary|                  
None|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|  
bytearray(b'\x01')|  
bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'')|bytearray(b'')|bytearray(b'')|
                     bytearray(b'')|                                       
bytearray(b'')|  bytearray(b'a')|                   X|                          
  X|bytearray(b'')|   bytearray(b'')|    bytearray(b'')|b [...]
-    # 
+-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-
 [...]
+    # 
+-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+------------------+--------------------+-----------------------------+--------------+-----------------+-
 [...]
+    # |SQL Type \ Pandas Value(Type)|None(object(NoneType))|        
True(bool)|           1(int8)|          1(int16)|            1(int32)|          
  1(int64)|          1(uint8)|         1(uint16)|         1(uint32)|         
1(uint64)|  1.0(float16)|  1.0(float32)|  1.0(float64)|1970-01-01 
00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, 
US/Eastern])|a(object(string))|12(object(string))|  1(object(Decimal))|[1 2 
3](object(array[int32]))| 1.0(float128)|(1+0j)(complex64)|( [...]
+    # 
+-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+------------------+--------------------+-----------------------------+--------------+-----------------+-
 [...]
+    # |                      boolean|                  None|              
True|              True|              True|                True|                
True|              True|              True|              True|              
True|          True|          True|          True|                              
    X|                                                    X|                X|  
               X|                   X|                            X|            
 X|                X|  [...]
+    # |                      tinyint|                  None|                 
1|                 1|                 1|                   1|                   
1|                 1|                 1|                 1|                 1|  
           1|             1|             1|                                  X| 
                                                   X|                X|         
       12|                   1|                            X|             X|    
            X|  [...]
+    # |                     smallint|                  None|                 
1|                 1|                 1|                   1|                   
1|                 1|                 1|                 1|                 1|  
           1|             1|             1|                                  X| 
                                                   X|                X|         
       12|                   1|                            X|             X|    
            X|  [...]
+    # |                          int|                  None|                 
1|                 1|                 1|                   1|                   
1|                 1|                 1|                 1|                 1|  
           1|             1|             1|                                  X| 
                                                   X|                X|         
       12|                   1|                            X|             X|    
            X|  [...]
+    # |                       bigint|                  None|                 
1|                 1|                 1|                   1|                   
1|                 1|                 1|                 1|                 1|  
           1|             1|             1|                                  0| 
                                      18000000000000|                X|         
       12|                   1|                            X|             X|    
            X|  [...]
+    # |                        float|                  None|               
1.0|               1.0|               1.0|                 1.0|                 
1.0|               1.0|               1.0|               1.0|               
1.0|           1.0|           1.0|           1.0|                               
   X|                                                    X|                X|   
           12.0|                 1.0|                            X|             
X|                X|  [...]
+    # |                       double|                  None|               
1.0|               1.0|               1.0|                 1.0|                 
1.0|               1.0|               1.0|               1.0|               
1.0|           1.0|           1.0|           1.0|                               
   X|                                                    X|                X|   
           12.0|                 1.0|                            X|             
X|                X|  [...]
+    # |                         date|                  None|                 
X|                 X|                 X|datetime.date(197...|                   
X|                 X|                 X|                 X|                 X|  
           X|             X|             X|               datetime.date(197...| 
                                datetime.date(197...|                X|         
        X|datetime.date(197...|                            X|             X|    
            X|  [...]
+    # |                    timestamp|                  None|                 
X|                 X|                 X|                   
X|datetime.datetime...|                 X|                 X|                 
X|                 X|             X|             X|             X|              
 datetime.datetime...|                                 datetime.datetime...|    
            X|                 X|datetime.datetime...|                          
  X|             X|                X|  [...]
+    # |                       string|                  None|                 
X|                 X|                 X|                   X|                   
X|                 X|                 X|                 X|                 X|  
           X|             X|             X|                                  X| 
                                                   X|              'a'|         
     '12'|                   X|                            X|             X|    
            X|  [...]
+    # |                decimal(10,0)|                  None|                 
X|      Decimal('1')|      Decimal('1')|        Decimal('1')|                   
X|      Decimal('1')|      Decimal('1')|      Decimal('1')|                 X|  
Decimal('1')|  Decimal('1')|  Decimal('1')|                                  X| 
                                                   X|                X|         
        X|        Decimal('1')|                            X|             X|    
            X|  [...]
+    # |                   array<int>|                  None|                 
X|                 X|                 X|                   X|                   
X|                 X|                 X|                 X|                 X|  
           X|             X|             X|                                  X| 
                                                   X|                X|         
   [1, 2]|                   X|                    [1, 2, 3]|             X|    
            X|  [...]
+    # |              map<string,int>|                  None|                 
X|                 X|                 X|                   X|                   
X|                 X|                 X|                 X|                 X|  
           X|             X|             X|                                  X| 
                                                   X|                X|         
        X|                   X|                            X|             X|    
            X|  [...]
+    # |               struct<_1:int>|                     X|                 
X|                 X|                 X|                   X|                   
X|                 X|                 X|                 X|                 X|  
           X|             X|             X|                                  X| 
                                                   X|                X|         
        X|                   X|                            X|             X|    
            X|  [...]
+    # |                       binary|                  
None|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|  
bytearray(b'\x01')|  
bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'')|bytearray(b'')|bytearray(b'')|
                     bytearray(b'')|                                       
bytearray(b'')|  bytearray(b'a')|  bytearray(b'12')|                   X|       
                     X|bytearray(b'')|   bytearray(b'')|  [...]
+    # 
+-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+------------------+--------------------+-----------------------------+--------------+-----------------+-
 [...]
     #
     # Note: DDL formatted string is used for 'SQL Type' for simplicity. This 
string can be
     #       used in `returnType`.
diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index aefeea226596..769f5e043a77 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1103,6 +1103,7 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
             safecheck,
             assign_cols_by_name,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+            arrow_cast=True,
         )
         self.pickleSer = CPickleSerializer()
         self.utf8_deserializer = UTF8Deserializer()
@@ -1483,6 +1484,7 @@ class 
TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
             safecheck,
             assign_cols_by_name,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+            arrow_cast=True,
         )
         self.arrow_max_records_per_batch = arrow_max_records_per_batch
         self.key_offsets = None
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index 67398a46bce8..d23252abf6a9 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -262,7 +262,7 @@ class CogroupedApplyInPandasTestsMixin:
                             
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
                         )
                     self._test_merge_error(
-                        fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": 
["2.0"]}),
+                        fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": 
["test_string"]}),
                         output_schema="id long, k double",
                         errorClass=PythonException,
                         error_message_regex=expected,
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 9965e2acc4b5..a7516bdd22b0 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -371,7 +371,7 @@ class GroupedApplyInPandasTestsMixin:
                         )
                     with self.assertRaisesRegex(PythonException, expected + 
"\n"):
                         self._test_apply_in_pandas(
-                            lambda key, pdf: pd.DataFrame([key + 
(str(pdf.v.mean()),)]),
+                            lambda key, pdf: pd.DataFrame([key + 
("test_string",)]),
                             output_schema="id long, mean double",
                         )
 
@@ -900,6 +900,51 @@ class GroupedApplyInPandasTestsMixin:
             with self.assertRaisesRegex(PythonException, error):
                 self._test_apply_in_pandas_returning_empty_dataframe(empty_df)
 
+    def test_arrow_cast_enabled_numeric_to_decimal(self):
+        import numpy as np
+
+        columns = [
+            "int8",
+            "int16",
+            "int32",
+            "uint8",
+            "uint16",
+            "uint32",
+            "float64",
+        ]
+
+        pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in 
columns})
+        df = self.spark.range(2).repartition(1)
+
+        for column in columns:
+            with self.subTest(column=column):
+                v = pdf[column].iloc[:1]
+                schema_str = "id long, value decimal(10,0)"
+
+                @pandas_udf(schema_str, PandasUDFType.GROUPED_MAP)
+                def test(pdf):
+                    return pdf.assign(**{"value": v})
+
+                row = df.groupby("id").apply(test).first()
+                res = row[1]
+                self.assertEqual(res, Decimal("1"))
+
+    def test_arrow_cast_enabled_str_to_numeric(self):
+        df = self.spark.range(2).repartition(1)
+
+        types = ["int", "long", "float", "double"]
+
+        for type_str in types:
+            with self.subTest(type=type_str):
+                schema_str = "id long, value " + type_str
+
+                @pandas_udf(schema_str, PandasUDFType.GROUPED_MAP)
+                def test(pdf):
+                    return pdf.assign(value=pd.Series(["123"]))
+
+                row = df.groupby("id").apply(test).first()
+                self.assertEqual(row[1], 123)
+
 
 class GroupedApplyInPandasTests(GroupedApplyInPandasTestsMixin, 
ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 7debe8035f61..b241b91e02a2 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -276,16 +276,17 @@ class MapInPandasTestsMixin:
             self.check_dataframes_with_incompatible_types()
 
     def check_dataframes_with_incompatible_types(self):
-        def func(iterator):
-            for pdf in iterator:
-                yield pdf.assign(id=pdf["id"].apply(str))
-
         for safely in [True, False]:
             with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
                 {"spark.sql.execution.pandas.convertToArrowArraySafely": 
safely}
             ):
                 # sometimes we see ValueErrors
                 with self.subTest(convert="string to double"):
+
+                    def func(iterator):
+                        for pdf in iterator:
+                            yield pdf.assign(id="test_string")
+
                     expected = (
                         r"ValueError: Exception thrown when converting 
pandas.Series "
                         r"\(object\) with name 'id' to Arrow Array \(double\)."
@@ -304,18 +305,31 @@ class MapInPandasTestsMixin:
                             .collect()
                         )
 
-                # sometimes we see TypeErrors
-                with self.subTest(convert="double to string"):
-                    with self.assertRaisesRegex(
-                        PythonException,
-                        r"TypeError: Exception thrown when converting 
pandas.Series "
-                        r"\(float64\) with name 'id' to Arrow Array 
\(string\).\n",
-                    ):
-                        (
-                            self.spark.range(10, numPartitions=3)
-                            .select(col("id").cast("double"))
-                            .mapInPandas(self.identity_dataframes_iter("id"), 
"id string")
-                            .collect()
+                with self.subTest(convert="float to int precision loss"):
+
+                    def func(iterator):
+                        for pdf in iterator:
+                            yield pdf.assign(id=pdf["id"] + 0.1)
+
+                    df = (
+                        self.spark.range(10, numPartitions=3)
+                        .select(col("id").cast("double"))
+                        .mapInPandas(func, "id int")
+                    )
+                    if safely:
+                        expected = (
+                            r"ValueError: Exception thrown when converting 
pandas.Series "
+                            r"\(float64\) with name 'id' to Arrow Array 
\(int32\)."
+                            " It can be caused by overflows or other "
+                            "unsafe conversions warned by Arrow. Arrow safe 
type check "
+                            "can be disabled by using SQL config "
+                            
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
+                        )
+                        with self.assertRaisesRegex(PythonException, expected 
+ "\n"):
+                            df.collect()
+                    else:
+                        self.assertEqual(
+                            df.collect(), self.spark.range(10, 
numPartitions=3).collect()
                         )
 
     def test_empty_iterator(self):
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index d67bed462345..2fb802a60df0 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -376,27 +376,20 @@ class PandasUDFTestsMixin:
             values = [1, 2, 3]
             return pd.Series([values[int(val) % len(values)] for val in 
column])
 
-        with self.sql_conf(
-            
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": True}
-        ):
-            result = df.withColumn("decimal_val", 
high_precision_udf("id")).collect()
-            self.assertEqual(len(result), 3)
-            self.assertEqual(result[0]["decimal_val"], Decimal("1.0"))
-            self.assertEqual(result[1]["decimal_val"], Decimal("2.0"))
-            self.assertEqual(result[2]["decimal_val"], Decimal("3.0"))
-
-        with self.sql_conf(
-            
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
-        ):
-            # Also not supported.
-            # This can be fixed by enabling arrow_cast
-            # This is currently not the case for SQL_SCALAR_PANDAS_UDF and
-            # SQL_SCALAR_PANDAS_ITER_UDF.
-            self.assertRaisesRegex(
-                PythonException,
-                "Exception thrown when converting pandas.Series",
-                df.withColumn("decimal_val", high_precision_udf("id")).collect,
-            )
+        for intToDecimalCoercionEnabled in [True, False]:
+            # arrow_cast is enabled by default for SQL_SCALAR_PANDAS_UDF and
+            # SQL_SCALAR_PANDAS_ITER_UDF, so Arrow can do this cast safely.
+            # intToDecimalCoercionEnabled is not required for this case
+            with self.sql_conf(
+                {
+                    
"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": 
intToDecimalCoercionEnabled  # noqa: E501
+                }
+            ):
+                result = df.withColumn("decimal_val", 
high_precision_udf("id")).collect()
+                self.assertEqual(len(result), 3)
+                self.assertEqual(result[0]["decimal_val"], Decimal("1.0"))
+                self.assertEqual(result[1]["decimal_val"], Decimal("2.0"))
+                self.assertEqual(result[2]["decimal_val"], Decimal("3.0"))
 
     def test_pandas_udf_timestamp_ntz(self):
         # SPARK-36626: Test TimestampNTZ in pandas UDF
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index e22b8f9ccacc..1059af59f4a8 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -718,6 +718,49 @@ class GroupedAggPandasUDFTestsMixin:
                         aggregated, df.groupby("id").agg((sum(df.v) + 
sum(df.w)).alias("s"))
                     )
 
+    def test_arrow_cast_enabled_numeric_to_decimal(self):
+        import numpy as np
+        from decimal import Decimal
+
+        columns = [
+            "int8",
+            "int16",
+            "int32",
+            "uint8",
+            "uint16",
+            "uint32",
+            "float64",
+        ]
+
+        pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in 
columns})
+        df = self.spark.range(2).repartition(1)
+
+        for column in columns:
+            with self.subTest(column=column):
+
+                @pandas_udf("decimal(10,0)", PandasUDFType.GROUPED_AGG)
+                def test(series):
+                    return pdf[column].iloc[0]
+
+                row = df.groupby("id").agg(test(df.id)).first()
+                res = row[1]
+                self.assertEqual(res, Decimal("1"))
+
+    def test_arrow_cast_enabled_str_to_numeric(self):
+        df = self.spark.range(2).repartition(1)
+
+        types = ["int", "long", "float", "double"]
+
+        for type_str in types:
+            with self.subTest(type=type_str):
+
+                @pandas_udf(type_str, PandasUDFType.GROUPED_AGG)
+                def test(series):
+                    return 123
+
+                row = df.groupby("id").agg(test(df.id)).first()
+                self.assertEqual(row[1], 123)
+
 
 class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, 
ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index efd9ae0eb9f5..e614d9039b61 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -1875,6 +1875,36 @@ class ScalarPandasUDFTestsMixin:
             with self.subTest(with_b=True, query_no=i):
                 assertDataFrameEqual(df, [Row(0), Row(101)])
 
+    def test_arrow_cast_enabled_numeric_to_decimal(self):
+        import numpy as np
+
+        columns = [
+            "int8",
+            "int16",
+            "int32",
+            "uint8",
+            "uint16",
+            "uint32",
+            "float64",
+        ]
+
+        pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in 
columns})
+        df = self.spark.range(2).repartition(1)
+
+        t = DecimalType(10, 0)
+        for column in columns:
+            with self.subTest(column=column):
+                v = pdf[column].iloc[:1]
+                row = df.select(pandas_udf(lambda _: v, t)(df.id)).first()
+                assert (row[0] == v).all()
+
+    def test_arrow_cast_enabled_str_to_numeric(self):
+        df = self.spark.range(2).repartition(1)
+        for t in [IntegerType(), LongType(), FloatType(), DoubleType()]:
+            with self.subTest(type=t):
+                row = df.select(pandas_udf(lambda _: pd.Series(["123"]), 
t)(df.id)).first()
+                assert row[0] == 123
+
 
 class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 9b3673d80d22..2f534b811b34 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -17,6 +17,7 @@
 
 import unittest
 from typing import cast
+from decimal import Decimal
 
 from pyspark.errors import AnalysisException, PythonException
 from pyspark.sql.functions import (
@@ -33,6 +34,13 @@ from pyspark.sql.functions import (
     PandasUDFType,
 )
 from pyspark.sql.window import Window
+from pyspark.sql.types import (
+    DecimalType,
+    IntegerType,
+    LongType,
+    FloatType,
+    DoubleType,
+)
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pandas,
@@ -563,6 +571,43 @@ class WindowPandasUDFTestsMixin:
                             )
                         ).show()
 
+    def test_arrow_cast_numeric_to_decimal(self):
+        import numpy as np
+        import pandas as pd
+
+        columns = [
+            "int8",
+            "int16",
+            "int32",
+            "uint8",
+            "uint16",
+            "uint32",
+            "float64",
+        ]
+
+        pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in 
columns})
+        df = self.data
+        w = self.unbounded_window
+
+        t = DecimalType(10, 0)
+        for column in columns:
+            with self.subTest(column=column):
+                value = pdf[column].iloc[0]
+                mean_udf = pandas_udf(lambda v: value, t, 
PandasUDFType.GROUPED_AGG)
+                result = df.select(mean_udf(df["v"]).over(w)).first()[0]
+                assert result == Decimal("1.0")
+                assert type(result) == Decimal
+
+    def test_arrow_cast_str_to_numeric(self):
+        df = self.data
+        w = self.unbounded_window
+
+        for t in [IntegerType(), LongType(), FloatType(), DoubleType()]:
+            with self.subTest(type=t):
+                mean_udf = pandas_udf(lambda v: "123", t, 
PandasUDFType.GROUPED_AGG)
+                result = df.select(mean_udf(df["v"]).over(w)).first()[0]
+                assert result == 123
+
 
 class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 9517a72a7a70..207c2a999571 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -2285,6 +2285,7 @@ def read_udfs(pickleSer, infile, eval_type):
                 safecheck,
                 _assign_cols_by_name,
                 
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+                arrow_cast=True,
             )
         elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
             arrow_max_records_per_batch = runner_conf.get(
@@ -2374,8 +2375,6 @@ def read_udfs(pickleSer, infile, eval_type):
                 "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF 
else "dict"
             )
             ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
-            # Arrow-optimized Python UDF uses explicit Arrow cast for type 
coercion
-            arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
             # Arrow-optimized Python UDF takes input types
             input_types = (
                 [f.dataType for f in 
_parse_datatype_json_string(utf8_deserializer.loads(infile))]
@@ -2390,7 +2389,7 @@ def read_udfs(pickleSer, infile, eval_type):
                 df_for_struct,
                 struct_in_pandas,
                 ndarray_as_list,
-                arrow_cast,
+                True,
                 input_types,
                 
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
             )


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to