This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4faeb803670d [SPARK-46413][PYTHON] Validate returnType of Arrow Python UDF
4faeb803670d is described below

commit 4faeb803670d77de8a47fe1683a6680c3ee3f454
Author: Xinrong Meng <[email protected]>
AuthorDate: Thu Dec 21 08:18:13 2023 +0900

    [SPARK-46413][PYTHON] Validate returnType of Arrow Python UDF
    
    ### What changes were proposed in this pull request?
    Validate returnType of Arrow Python UDF
    
    ### Why are the changes needed?
    Better error handling and consistency with other types of UDFs.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, now we raise an error when the given returnType is not supported.
    
    ```py
    >>> udf(lambda x: x, returnType=VarcharType(10), useArrow=True)
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.base.PySparkNotImplementedError: [NOT_IMPLEMENTED] Invalid return type with Arrow-optimized Python UDF: VarcharType(10) is not implemented.
    ```
    
    ### How was this patch tested?
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #44362 from xinrong-meng/test_more_udf.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../sql/tests/connect/test_parity_arrow_python_udf.py     |  4 ++++
 python/pyspark/sql/tests/test_arrow_python_udf.py         | 15 ++++++++++++++-
 python/pyspark/sql/udf.py                                 | 14 ++++++++++++--
 3 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py
index fa329b598d98..f5bd99fa22cf 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py
@@ -58,6 +58,10 @@ class ArrowPythonUDFParityTests(UDFParityTests, PythonUDFArrowTestsMixin):
         with self.assertRaises(PythonException):
             self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()
 
+    @unittest.skip("Spark Connect does not validate return type in client.")
+    def test_err_return_type(self):
+        super().test_err_return_type()
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py
index f853b15ce6f8..c59326edc31a 100644
--- a/python/pyspark/sql/tests/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -17,10 +17,11 @@
 
 import unittest
 
-from pyspark.errors import PythonException
+from pyspark.errors import PythonException, PySparkNotImplementedError
 from pyspark.sql import Row
 from pyspark.sql.functions import udf
 from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
+from pyspark.sql.types import VarcharType
 from pyspark.testing.sqlutils import (
     have_pandas,
     have_pyarrow,
@@ -175,6 +176,18 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
         with self.assertRaises(PythonException):
             df_floating_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect()
 
+    def test_err_return_type(self):
+        with self.assertRaises(PySparkNotImplementedError) as pe:
+            udf(lambda x: x, VarcharType(10), useArrow=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={
+                "feature": "Invalid return type with Arrow-optimized Python UDF: VarcharType(10)"
+            },
+        )
+
 
 class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index b9a00d432671..16605bc12acc 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -213,8 +213,18 @@ class UserDefinedFunction:
                 self._returnType_placeholder = self._returnType
             else:
                 self._returnType_placeholder = _parse_datatype_string(self._returnType)
-
-        if (
+        if self.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF:
+            try:
+                to_arrow_type(self._returnType_placeholder)
+            except TypeError:
+                raise PySparkNotImplementedError(
+                    error_class="NOT_IMPLEMENTED",
+                    message_parameters={
+                        "feature": f"Invalid return type with Arrow-optimized Python UDF: "
+                        f"{self._returnType_placeholder}"
+                    },
+                )
+        elif (
             self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
             or self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
         ):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to