This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7d564b7d1e53 [SPARK-49477][PYTHON] Improve pandas udf invalid return type error message
7d564b7d1e53 is described below
commit 7d564b7d1e535d1c6c1a828ce35411dfda1037ec
Author: allisonwang-db <[email protected]>
AuthorDate: Wed Sep 4 09:26:20 2024 +0900
[SPARK-49477][PYTHON] Improve pandas udf invalid return type error message
### What changes were proposed in this pull request?
This PR improves the error message raised when the specified return type of a
pandas UDF does not match the actual return type.
### Why are the changes needed?
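To make the error message more actionable. The previous message did not say that the UDF's declared return type and its actual return value disagree, nor what the UDF should return instead.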
Before this PR:
`pyspark.errors.exceptions.base.PySparkValueError: A field of type StructType expects a pandas.DataFrame, but got: <class 'pandas.core.series.Series'>`
After this PR:
`pyspark.errors.exceptions.base.PySparkValueError: Invalid return type. Please make sure that the UDF returns a pandas.DataFrame when the specified return type is StructType.`
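The condition that triggers this message can be sketched in plain pandas, without Spark. The helper name `check_struct_result` is hypothetical: it only mirrors the `isinstance(s, pd.DataFrame)` check in the `serializers.py` hunk below and raises a built-in `ValueError` in place of `PySparkValueError`.

```python
import pandas as pd


def check_struct_result(result):
    """Hypothetical stand-in for the serializer check: a pandas UDF whose
    declared return type is StructType must return a pandas.DataFrame."""
    if not isinstance(result, pd.DataFrame):
        # Same wording as the new PySparkValueError message.
        raise ValueError(
            "Invalid return type. Please make sure that the UDF returns a "
            "pandas.DataFrame when the specified return type is StructType."
        )
    return result


series_result = pd.Series(["A"])           # what a Series-returning UDF produces
frame_result = pd.DataFrame({"s": ["A"]})  # what a StructType UDF should produce

check_struct_result(frame_result)  # accepted
try:
    check_struct_result(series_result)  # rejected with the new message
except ValueError as err:
    print(err)
```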
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a new unit test, `test_pandas_udf_return_type_error`.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #47942 from allisonwang-db/spark-49477-pandas-udf-err-msg.
Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/pandas/serializers.py | 4 ++--
python/pyspark/sql/tests/pandas/test_pandas_udf.py | 13 +++++++++++++
2 files changed, 15 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 6203d4d19d86..076226865f3a 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -510,8 +510,8 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
# If it returns a pd.Series, it should throw an error.
if not isinstance(s, pd.DataFrame):
raise PySparkValueError(
- "A field of type StructType expects a pandas.DataFrame, "
- "but got: %s" % str(type(s))
+ "Invalid return type. Please make sure that the UDF returns a "
+ "pandas.DataFrame when the specified return type is StructType."
)
arrs.append(self._create_struct_array(s, t))
else:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index 6720dfc37d0c..228fc30b497c 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -339,6 +339,19 @@ class PandasUDFTestsMixin:
self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second")
self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123))
+ def test_pandas_udf_return_type_error(self):
+ import pandas as pd
+
+ @pandas_udf("s string")
+ def upper(s: pd.Series) -> pd.Series:
+ return s.str.upper()
+
+ df = self.spark.createDataFrame([("a",)], schema="s string")
+
+ self.assertRaisesRegex(
+ PythonException, "Invalid return type", df.select(upper("s")).collect
+ )
+
class PandasUDFTests(PandasUDFTestsMixin, ReusedSQLTestCase):
pass
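The new test fails by design: the DDL string `"s string"` describes a struct with a single string field `s`, while `upper` returns a `pd.Series`. A minimal sketch of two ways a user could reconcile the two, shown as plain pandas function bodies so it runs without Spark. The names `upper_struct` and `upper_scalar` are hypothetical; in PySpark they would be decorated with `@pandas_udf("s string")` and `@pandas_udf("string")` respectively.

```python
import pandas as pd


def upper_struct(s: pd.Series) -> pd.DataFrame:
    # Keep the struct return type "s string": wrap the upper-cased values in a
    # DataFrame whose single column "s" matches the struct's single field.
    return pd.DataFrame({"s": s.str.upper()})


def upper_scalar(s: pd.Series) -> pd.Series:
    # Alternative: keep returning a Series and declare the plain type "string".
    return s.str.upper()


print(upper_struct(pd.Series(["a", "b"])))
print(upper_scalar(pd.Series(["a", "b"])))
```

Returning a DataFrame satisfies the `isinstance(s, pd.DataFrame)` check in the serializer; changing the declared type instead avoids the StructType branch entirely.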