This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4faeb803670d [SPARK-46413][PYTHON] Validate returnType of Arrow Python UDF
4faeb803670d is described below
commit 4faeb803670d77de8a47fe1683a6680c3ee3f454
Author: Xinrong Meng <[email protected]>
AuthorDate: Thu Dec 21 08:18:13 2023 +0900
[SPARK-46413][PYTHON] Validate returnType of Arrow Python UDF
### What changes were proposed in this pull request?
Validate returnType of Arrow Python UDF
### Why are the changes needed?
Better error handling and consistency with other types of UDFs.
### Does this PR introduce _any_ user-facing change?
Yes, now we raise an error when the given returnType is not supported.
```py
>>> udf(lambda x: x, returnType=VarcharType(10), useArrow=True)
Traceback (most recent call last):
...
pyspark.errors.exceptions.base.PySparkNotImplementedError:
[NOT_IMPLEMENTED] Invalid return type with Arrow-optimized Python UDF:
VarcharType(10) is not implemented.
```
### How was this patch tested?
Unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44362 from xinrong-meng/test_more_udf.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../sql/tests/connect/test_parity_arrow_python_udf.py | 4 ++++
python/pyspark/sql/tests/test_arrow_python_udf.py | 15 ++++++++++++++-
python/pyspark/sql/udf.py | 14 ++++++++++++--
3 files changed, 30 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py
index fa329b598d98..f5bd99fa22cf 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py
@@ -58,6 +58,10 @@ class ArrowPythonUDFParityTests(UDFParityTests, PythonUDFArrowTestsMixin):
with self.assertRaises(PythonException):
self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()
+ @unittest.skip("Spark Connect does not validate return type in client.")
+ def test_err_return_type(self):
+ super().test_err_return_type()
+
if __name__ == "__main__":
import unittest
diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py
b/python/pyspark/sql/tests/test_arrow_python_udf.py
index f853b15ce6f8..c59326edc31a 100644
--- a/python/pyspark/sql/tests/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -17,10 +17,11 @@
import unittest
-from pyspark.errors import PythonException
+from pyspark.errors import PythonException, PySparkNotImplementedError
from pyspark.sql import Row
from pyspark.sql.functions import udf
from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
+from pyspark.sql.types import VarcharType
from pyspark.testing.sqlutils import (
have_pandas,
have_pyarrow,
@@ -175,6 +176,18 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
with self.assertRaises(PythonException):
df_floating_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect()
+ def test_err_return_type(self):
+ with self.assertRaises(PySparkNotImplementedError) as pe:
+ udf(lambda x: x, VarcharType(10), useArrow=True)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_IMPLEMENTED",
+ message_parameters={
+ "feature": "Invalid return type with Arrow-optimized Python UDF: VarcharType(10)"
+ },
+ )
+
class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index b9a00d432671..16605bc12acc 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -213,8 +213,18 @@ class UserDefinedFunction:
self._returnType_placeholder = self._returnType
else:
self._returnType_placeholder = _parse_datatype_string(self._returnType)
-
- if (
+ if self.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF:
+ try:
+ to_arrow_type(self._returnType_placeholder)
+ except TypeError:
+ raise PySparkNotImplementedError(
+ error_class="NOT_IMPLEMENTED",
+ message_parameters={
+ "feature": f"Invalid return type with Arrow-optimized Python UDF: "
+ f"{self._returnType_placeholder}"
+ },
+ )
+ elif (
self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
or self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]