This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7d564b7d1e53 [SPARK-49477][PYTHON] Improve pandas udf invalid return type error message
7d564b7d1e53 is described below

commit 7d564b7d1e535d1c6c1a828ce35411dfda1037ec
Author: allisonwang-db <[email protected]>
AuthorDate: Wed Sep 4 09:26:20 2024 +0900

    [SPARK-49477][PYTHON] Improve pandas udf invalid return type error message
    
    ### What changes were proposed in this pull request?
    
    This PR improves the error message raised when the value returned by a pandas UDF does not match the UDF's specified return type.
    
    ### Why are the changes needed?
    
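    To make the error message actionable. The previous message was phrased in terms of an internal field type and did not tell users that the UDF's return value must agree with its declared return type.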
    
    Before this PR:
    `pyspark.errors.exceptions.base.PySparkValueError: A field of type StructType expects a pandas.DataFrame, but got: <class 'pandas.core.series.Series'>`
    
    After this PR:
    `pyspark.errors.exceptions.base.PySparkValueError: Invalid return type. Please make sure that the UDF returns a pandas.DataFrame when the specified return type is StructType.`
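The check that produces this message can be illustrated in isolation. The sketch below is a simplified stand-in, not the actual serializer code: `check_struct_result` is a hypothetical helper that mirrors the `isinstance` test from the patched `ArrowStreamPandasUDFSerializer`, and it raises a plain `ValueError` (which `PySparkValueError` subclasses) so the sketch runs with only pandas installed.

```python
import pandas as pd

# Message text copied from the patched serializer.
INVALID_RETURN_TYPE_MSG = (
    "Invalid return type. Please make sure that the UDF returns a "
    "pandas.DataFrame when the specified return type is StructType."
)


def check_struct_result(result):
    """Hypothetical stand-in for the StructType check in the serializer.

    When the declared return type is a StructType, the UDF result must be a
    pandas.DataFrame; anything else (e.g. a pandas.Series) is rejected.
    """
    if not isinstance(result, pd.DataFrame):
        # PySpark raises PySparkValueError here; ValueError is used so the
        # sketch does not depend on PySpark.
        raise ValueError(INVALID_RETURN_TYPE_MSG)
    return result


series_result = pd.Series(["a"]).str.upper()
try:
    check_struct_result(series_result)  # a Series is rejected
except ValueError as e:
    print(e)

frame_result = series_result.to_frame("s")  # one-column DataFrame named "s"
print(type(check_struct_result(frame_result)).__name__)  # DataFrame
```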
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a new unit test, `test_pandas_udf_return_type_error`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #47942 from allisonwang-db/spark-49477-pandas-udf-err-msg.
    
    Authored-by: allisonwang-db <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/pandas/serializers.py           |  4 ++--
 python/pyspark/sql/tests/pandas/test_pandas_udf.py | 13 +++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 6203d4d19d86..076226865f3a 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -510,8 +510,8 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
                 # If it returns a pd.Series, it should throw an error.
                 if not isinstance(s, pd.DataFrame):
                     raise PySparkValueError(
-                        "A field of type StructType expects a pandas.DataFrame, "
-                        "but got: %s" % str(type(s))
+                        "Invalid return type. Please make sure that the UDF returns a "
+                        "pandas.DataFrame when the specified return type is StructType."
                     )
                 arrs.append(self._create_struct_array(s, t))
             else:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index 6720dfc37d0c..228fc30b497c 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -339,6 +339,19 @@ class PandasUDFTestsMixin:
         self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second")
         self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123))
 
+    def test_pandas_udf_return_type_error(self):
+        import pandas as pd
+
+        @pandas_udf("s string")
+        def upper(s: pd.Series) -> pd.Series:
+            return s.str.upper()
+
+        df = self.spark.createDataFrame([("a",)], schema="s string")
+
+        self.assertRaisesRegex(
+            PythonException, "Invalid return type", df.select(upper("s")).collect
+        )
+
 
 class PandasUDFTests(PandasUDFTestsMixin, ReusedSQLTestCase):
     pass
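The new test declares the struct return type `"s string"` but returns a `pd.Series`. Two ways a user might resolve the error are sketched below as plain pandas functions. The function names are hypothetical, and the `@pandas_udf` decorators are deliberately omitted (noted in comments) so the bodies run without a Spark session: either declare a scalar type such as `"string"` and keep returning a Series, or keep `"s string"` and return a DataFrame whose columns line up with the struct's fields.

```python
import pandas as pd


# Option 1 (assumed usage): declare a scalar return type, e.g.
# @pandas_udf("string"), and keep returning a pandas.Series.
def upper_scalar(s: pd.Series) -> pd.Series:
    return s.str.upper()


# Option 2 (assumed usage): keep the struct return type, e.g.
# @pandas_udf("s string"), and return a pandas.DataFrame with one
# column per struct field.
def upper_struct(s: pd.Series) -> pd.DataFrame:
    return s.str.upper().to_frame("s")


batch = pd.Series(["a", "b"])
print(upper_scalar(batch).tolist())          # ['A', 'B']
print(upper_struct(batch).columns.tolist())  # ['s']
```

Option 1 is the simpler choice when only one value is produced per row; Option 2 matters when the UDF genuinely needs to return several fields.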

