This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new dfc83e6508c [SPARK-44640][PYTHON][3.5] Improve error messages for Python UDTF returning non Iterable
dfc83e6508c is described below

commit dfc83e6508c75f9aef2bec7a52d098ac6ba90c9b
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Fri Aug 4 15:07:07 2023 +0800

    [SPARK-44640][PYTHON][3.5] Improve error messages for Python UDTF returning non Iterable
    
    ### What changes were proposed in this pull request?
    
    This PR cherry-picks https://github.com/apache/spark/commit/380c0f2033fb83b5e4f13693d2576d72c5cc01f2. It improves the error message raised when the result of a Python UDTF is not an Iterable, and the error message raised when a UDTF throws an exception while executing `eval`.
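    
    A minimal sketch of the new worker-side check (simplified from the `worker.py` change below; `check_udtf_result` is an illustrative helper name, not a real PySpark API):
    ```
    from typing import Iterable

    def check_udtf_result(res):
        # None (no explicit return) is mapped to an empty tuple, i.e. no rows.
        if res is None:
            return tuple()
        # A bare non-iterable value such as an int is rejected up front.
        if not isinstance(res, Iterable):
            raise RuntimeError(
                f"[UDTF_RETURN_NOT_ITERABLE] got '{type(res).__name__}'"
            )
        return tuple(res)
    ```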
    
    ### Why are the changes needed?
    
    To make Python UDTFs more user-friendly.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. For example, this UDTF:
    ```
    @udtf(returnType="x: int")
    class TestUDTF:
        def eval(self, a):
            return a
    ```
    Before this PR, it fails with this error for regular UDTFs:
    ```
        return tuple(map(verify_and_convert_result, res))
    TypeError: 'int' object is not iterable
    ```
    And this error for arrow-optimized UDTFs:
    ```
        raise ValueError("DataFrame constructor not properly called!")
    ValueError: DataFrame constructor not properly called!
    ```
    
    After this PR, the error message will be:
    `pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_RETURN_NOT_ITERABLE] The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got 'int'. Please make sure that the UDTF returns one of these types.`
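    
    For reference, a minimal fix for the example above is to yield a one-column row rather than return a bare value, mirroring the pattern used in the new tests:
    ```
    @udtf(returnType="x: int")
    class TestUDTF:
        def eval(self, a):
            yield a,
    ```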
    
    ### How was this patch tested?
    
    New unit tests.
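    
    For a quick local check (assuming an active SparkSession with this patch applied; `BadUDTF` is an illustrative name), the new error can be reproduced as follows:
    ```
    from pyspark.errors import PythonException
    from pyspark.sql.functions import lit, udtf

    @udtf(returnType="x: int")
    class BadUDTF:
        def eval(self, a):
            return a  # a bare int, not an iterable of rows

    try:
        BadUDTF(lit(1)).collect()
    except PythonException as e:
        assert "UDTF_RETURN_NOT_ITERABLE" in str(e)
    ```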
    
    Closes #42337 from allisonwang-db/spark-44640-3.5.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/errors/error_classes.py |  5 ++++
 python/pyspark/sql/tests/test_udtf.py  | 42 +++++++++++++++++++++++++--
 python/pyspark/sql/udtf.py             | 40 ++++++++++++++++++++-----
 python/pyspark/worker.py               | 53 ++++++++++++++++++----------------
 4 files changed, 105 insertions(+), 35 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 2a3f454452e..4ea3e678810 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -728,6 +728,11 @@ ERROR_CLASSES_JSON = """
       "User defined table function encountered an error in the '<method_name>' 
method: <error>"
     ]
   },
+  "UDTF_RETURN_NOT_ITERABLE" : {
+    "message" : [
+      "The return value of the UDTF is invalid. It should be an iterable 
(e.g., generator or list), but got '<type>'. Please make sure that the UDTF 
returns one of these types."
+    ]
+  },
   "UDTF_RETURN_SCHEMA_MISMATCH" : {
     "message" : [
       "The number of columns in the result does not match the specified 
schema. Expected column count: <expected>, Actual column count: <actual>. 
Please make sure the values returned by the function have the same number of 
columns as specified in the output schema."
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 5c33cb14834..4bab77038e0 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -155,6 +155,15 @@ class BaseUDTFTestsMixin:
         with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with 
StructType"):
             func(lit(1)).collect()
 
+    def test_udtf_with_invalid_return_value(self):
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def eval(self, a):
+                return a
+
+        with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
+            TestUDTF(lit(1)).collect()
+
     def test_udtf_eval_with_no_return(self):
         @udtf(returnType="a: int")
         class TestUDTF:
@@ -350,6 +359,35 @@ class BaseUDTFTestsMixin:
             ],
         )
 
+    def test_init_with_exception(self):
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def __init__(self):
+                raise Exception("error")
+
+            def eval(self):
+                yield 1,
+
+        with self.assertRaisesRegex(
+            PythonException,
+            r"\[UDTF_EXEC_ERROR\] User defined table function encountered an 
error "
+            r"in the '__init__' method: error",
+        ):
+            TestUDTF().show()
+
+    def test_eval_with_exception(self):
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def eval(self):
+                raise Exception("error")
+
+        with self.assertRaisesRegex(
+            PythonException,
+            r"\[UDTF_EXEC_ERROR\] User defined table function encountered an 
error "
+            r"in the 'eval' method: error",
+        ):
+            TestUDTF().show()
+
     def test_terminate_with_exceptions(self):
         @udtf(returnType="a: int, b: int")
         class TestUDTF:
@@ -361,8 +399,8 @@ class BaseUDTFTestsMixin:
 
         with self.assertRaisesRegex(
             PythonException,
-            "User defined table function encountered an error in the 
'terminate' "
-            "method: terminate error",
+            r"\[UDTF_EXEC_ERROR\] User defined table function encountered an 
error "
+            r"in the 'terminate' method: terminate error",
         ):
             TestUDTF(lit(1)).collect()
 
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index 50bba56880c..c2830d56db5 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -19,11 +19,12 @@ User-defined table function related classes and functions
 """
 import sys
 import warnings
-from typing import Any, Iterator, Type, TYPE_CHECKING, Optional, Union
+from functools import wraps
+from typing import Any, Iterable, Iterator, Type, TYPE_CHECKING, Optional, Union, Callable
 
 from py4j.java_gateway import JavaObject
 
-from pyspark.errors import PySparkAttributeError, PySparkTypeError
+from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.column import _to_java_column, _to_seq
 from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
@@ -107,23 +108,46 @@ def _vectorize_udtf(cls: Type) -> Type:
     """Vectorize a Python UDTF handler class."""
     import pandas as pd
 
+    # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
+    def wrap_func(f: Callable[..., Any]) -> Callable[..., Any]:
+        @wraps(f)
+        def evaluate(*a: Any) -> Any:
+            try:
+                return f(*a)
+            except Exception as e:
+                raise PySparkRuntimeError(
+                    error_class="UDTF_EXEC_ERROR",
+                    message_parameters={"method_name": f.__name__, "error": str(e)},
+                )
+
+        return evaluate
+
     class VectorizedUDTF:
         def __init__(self) -> None:
             self.func = cls()
 
         def eval(self, *args: pd.Series) -> Iterator[pd.DataFrame]:
             if len(args) == 0:
-                yield pd.DataFrame(self.func.eval())
+                yield pd.DataFrame(wrap_func(self.func.eval)())
             else:
                 # Create tuples from the input pandas Series, each tuple
                 # represents a row across all Series.
                 row_tuples = zip(*args)
                 for row in row_tuples:
-                    yield pd.DataFrame(self.func.eval(*row))
-
-        def terminate(self) -> Iterator[pd.DataFrame]:
-            if hasattr(self.func, "terminate"):
-                yield pd.DataFrame(self.func.terminate())
+                    res = wrap_func(self.func.eval)(*row)
+                    if res is not None and not isinstance(res, Iterable):
+                        raise PySparkRuntimeError(
+                            error_class="UDTF_RETURN_NOT_ITERABLE",
+                            message_parameters={
+                                "type": type(res).__name__,
+                            },
+                        )
+                    yield pd.DataFrame(res)
+
+        if hasattr(cls, "terminate"):
+
+            def terminate(self) -> Iterator[pd.DataFrame]:
+                yield pd.DataFrame(wrap_func(self.func.terminate)())
 
     vectorized_udtf = VectorizedUDTF
     vectorized_udtf.__name__ = cls.__name__
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 3dffdf2c642..c9eedd43b29 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -24,7 +24,7 @@ import time
 from inspect import currentframe, getframeinfo, getfullargspec
 import importlib
 import json
-from typing import Iterator
+from typing import Iterable, Iterator
 
 # 'resource' is a Unix specific module.
 has_resource_module = True
@@ -602,6 +602,7 @@ def read_udtf(pickleSer, infile, eval_type):
 
         def wrap_arrow_udtf(f, return_type):
             arrow_return_type = to_arrow_type(return_type)
+            return_type_size = len(return_type)
 
             def verify_result(result):
                 import pandas as pd
@@ -610,7 +611,7 @@ def read_udtf(pickleSer, infile, eval_type):
                     raise PySparkTypeError(
                         error_class="INVALID_ARROW_UDTF_RETURN_TYPE",
                         message_parameters={
-                            "type_name": type(result).__name_,
+                            "type_name": type(result).__name__,
                             "value": str(result),
                         },
                     )
@@ -620,11 +621,11 @@ def read_udtf(pickleSer, infile, eval_type):
                 # result dataframe may contain an empty row. For example, when a UDTF is
                 # defined as follows: def eval(self): yield tuple().
                 if len(result) > 0 or len(result.columns) > 0:
-                    if len(result.columns) != len(return_type):
+                    if len(result.columns) != return_type_size:
                         raise PySparkRuntimeError(
                             error_class="UDTF_RETURN_SCHEMA_MISMATCH",
                             message_parameters={
-                                "expected": str(len(return_type)),
+                                "expected": str(return_type_size),
                                 "actual": str(len(result.columns)),
                             },
                         )
@@ -652,13 +653,7 @@ def read_udtf(pickleSer, infile, eval_type):
                     yield from eval(*[a[o] for o in arg_offsets])
             finally:
                 if terminate is not None:
-                    try:
-                        yield from terminate()
-                    except BaseException as e:
-                        raise PySparkRuntimeError(
-                            error_class="UDTF_EXEC_ERROR",
-                            message_parameters={"method_name": "terminate", "error": str(e)},
-                        )
+                    yield from terminate()
 
         return mapper, None, ser, ser
 
@@ -667,15 +662,16 @@ def read_udtf(pickleSer, infile, eval_type):
         def wrap_udtf(f, return_type):
             assert return_type.needConversion()
             toInternal = return_type.toInternal
+            return_type_size = len(return_type)
 
             def verify_and_convert_result(result):
                 # TODO(SPARK-44005): support returning non-tuple values
                 if result is not None and hasattr(result, "__len__"):
-                    if len(result) != len(return_type):
+                    if len(result) != return_type_size:
                         raise PySparkRuntimeError(
                             error_class="UDTF_RETURN_SCHEMA_MISMATCH",
                             message_parameters={
-                                "expected": str(len(return_type)),
+                                "expected": str(return_type_size),
                                 "actual": str(len(result)),
                             },
                         )
@@ -683,16 +679,29 @@ def read_udtf(pickleSer, infile, eval_type):
 
             # Evaluate the function and return a tuple back to the executor.
             def evaluate(*a) -> tuple:
-                res = f(*a)
+                try:
+                    res = f(*a)
+                except Exception as e:
+                    raise PySparkRuntimeError(
+                        error_class="UDTF_EXEC_ERROR",
+                        message_parameters={"method_name": f.__name__, "error": str(e)},
+                    )
+
                 if res is None:
                     # If the function returns None or does not have an explicit return statement,
                     # an empty tuple is returned to the executor.
                     # This is because directly constructing tuple(None) results in an exception.
                     return tuple()
-                else:
-                    # If the function returns a result, we map it to the internal representation and
-                    # returns the results as a tuple.
-                    return tuple(map(verify_and_convert_result, res))
+
+                if not isinstance(res, Iterable):
+                    raise PySparkRuntimeError(
+                        error_class="UDTF_RETURN_NOT_ITERABLE",
+                        message_parameters={"type": type(res).__name__},
+                    )
+
+                # If the function returns a result, we map it to the internal representation and
+                # returns the results as a tuple.
+                return tuple(map(verify_and_convert_result, res))
 
             return evaluate
 
@@ -710,13 +719,7 @@ def read_udtf(pickleSer, infile, eval_type):
                     yield eval(*[a[o] for o in arg_offsets])
             finally:
                 if terminate is not None:
-                    try:
-                        yield terminate()
-                    except BaseException as e:
-                        raise PySparkRuntimeError(
-                            error_class="UDTF_EXEC_ERROR",
-                            message_parameters={"method_name": "terminate", "error": str(e)},
-                        )
+                    yield terminate()
 
         return mapper, None, ser, ser
 

