This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new dfc83e6508c [SPARK-44640][PYTHON][3.5] Improve error messages for Python UDTF returning non Iterable dfc83e6508c is described below commit dfc83e6508c75f9aef2bec7a52d098ac6ba90c9b Author: allisonwang-db <allison.w...@databricks.com> AuthorDate: Fri Aug 4 15:07:07 2023 +0800 [SPARK-44640][PYTHON][3.5] Improve error messages for Python UDTF returning non Iterable ### What changes were proposed in this pull request? This PR cherry-picks https://github.com/apache/spark/commit/380c0f2033fb83b5e4f13693d2576d72c5cc01f2. It improves error messages when the result of a Python UDTF is not an Iterable. It also improves the error messages when a UDTF encounters an exception when executing `eval`. ### Why are the changes needed? To make Python UDTFs more user-friendly. ### Does this PR introduce _any_ user-facing change? Yes. For example, this UDTF: ``` @udtf(returnType="x: int") class TestUDTF: def eval(self, a): return a ``` Before this PR, it fails with this error for regular UDTFs: ``` return tuple(map(verify_and_convert_result, res)) TypeError: 'int' object is not iterable ``` And this error for arrow-optimized UDTFs: ``` raise ValueError("DataFrame constructor not properly called!") ValueError: DataFrame constructor not properly called! ``` After this PR, the error message will be: `pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_RETURN_NOT_ITERABLE] The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got 'int'. Please make sure that the UDTF returns one of these types.` ### How was this patch tested? New UTs. Closes #42337 from allisonwang-db/spark-44640-3.5. 
Authored-by: allisonwang-db <allison.w...@databricks.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/errors/error_classes.py | 5 ++++ python/pyspark/sql/tests/test_udtf.py | 42 +++++++++++++++++++++++++-- python/pyspark/sql/udtf.py | 40 ++++++++++++++++++++----- python/pyspark/worker.py | 53 ++++++++++++++++++---------------- 4 files changed, 105 insertions(+), 35 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 2a3f454452e..4ea3e678810 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -728,6 +728,11 @@ ERROR_CLASSES_JSON = """ "User defined table function encountered an error in the '<method_name>' method: <error>" ] }, + "UDTF_RETURN_NOT_ITERABLE" : { + "message" : [ + "The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types." + ] + }, "UDTF_RETURN_SCHEMA_MISMATCH" : { "message" : [ "The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the function have the same number of columns as specified in the output schema." 
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 5c33cb14834..4bab77038e0 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -155,6 +155,15 @@ class BaseUDTFTestsMixin: with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"): func(lit(1)).collect() + def test_udtf_with_invalid_return_value(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self, a): + return a + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): + TestUDTF(lit(1)).collect() + def test_udtf_eval_with_no_return(self): @udtf(returnType="a: int") class TestUDTF: @@ -350,6 +359,35 @@ class BaseUDTFTestsMixin: ], ) + def test_init_with_exception(self): + @udtf(returnType="x: int") + class TestUDTF: + def __init__(self): + raise Exception("error") + + def eval(self): + yield 1, + + with self.assertRaisesRegex( + PythonException, + r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " + r"in the '__init__' method: error", + ): + TestUDTF().show() + + def test_eval_with_exception(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self): + raise Exception("error") + + with self.assertRaisesRegex( + PythonException, + r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " + r"in the 'eval' method: error", + ): + TestUDTF().show() + def test_terminate_with_exceptions(self): @udtf(returnType="a: int, b: int") class TestUDTF: @@ -361,8 +399,8 @@ class BaseUDTFTestsMixin: with self.assertRaisesRegex( PythonException, - "User defined table function encountered an error in the 'terminate' " - "method: terminate error", + r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " + r"in the 'terminate' method: terminate error", ): TestUDTF(lit(1)).collect() diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index 50bba56880c..c2830d56db5 100644 --- a/python/pyspark/sql/udtf.py 
+++ b/python/pyspark/sql/udtf.py @@ -19,11 +19,12 @@ User-defined table function related classes and functions """ import sys import warnings -from typing import Any, Iterator, Type, TYPE_CHECKING, Optional, Union +from functools import wraps +from typing import Any, Iterable, Iterator, Type, TYPE_CHECKING, Optional, Union, Callable from py4j.java_gateway import JavaObject -from pyspark.errors import PySparkAttributeError, PySparkTypeError +from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError from pyspark.rdd import PythonEvalType from pyspark.sql.column import _to_java_column, _to_seq from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version @@ -107,23 +108,46 @@ def _vectorize_udtf(cls: Type) -> Type: """Vectorize a Python UDTF handler class.""" import pandas as pd + # Wrap the exception thrown from the UDTF in a PySparkRuntimeError. + def wrap_func(f: Callable[..., Any]) -> Callable[..., Any]: + @wraps(f) + def evaluate(*a: Any) -> Any: + try: + return f(*a) + except Exception as e: + raise PySparkRuntimeError( + error_class="UDTF_EXEC_ERROR", + message_parameters={"method_name": f.__name__, "error": str(e)}, + ) + + return evaluate + class VectorizedUDTF: def __init__(self) -> None: self.func = cls() def eval(self, *args: pd.Series) -> Iterator[pd.DataFrame]: if len(args) == 0: - yield pd.DataFrame(self.func.eval()) + yield pd.DataFrame(wrap_func(self.func.eval)()) else: # Create tuples from the input pandas Series, each tuple # represents a row across all Series. 
row_tuples = zip(*args) for row in row_tuples: - yield pd.DataFrame(self.func.eval(*row)) - - def terminate(self) -> Iterator[pd.DataFrame]: - if hasattr(self.func, "terminate"): - yield pd.DataFrame(self.func.terminate()) + res = wrap_func(self.func.eval)(*row) + if res is not None and not isinstance(res, Iterable): + raise PySparkRuntimeError( + error_class="UDTF_RETURN_NOT_ITERABLE", + message_parameters={ + "type": type(res).__name__, + }, + ) + yield pd.DataFrame(res) + + if hasattr(cls, "terminate"): + + def terminate(self) -> Iterator[pd.DataFrame]: + yield pd.DataFrame(wrap_func(self.func.terminate)()) vectorized_udtf = VectorizedUDTF vectorized_udtf.__name__ = cls.__name__ diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3dffdf2c642..c9eedd43b29 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -24,7 +24,7 @@ import time from inspect import currentframe, getframeinfo, getfullargspec import importlib import json -from typing import Iterator +from typing import Iterable, Iterator # 'resource' is a Unix specific module. has_resource_module = True @@ -602,6 +602,7 @@ def read_udtf(pickleSer, infile, eval_type): def wrap_arrow_udtf(f, return_type): arrow_return_type = to_arrow_type(return_type) + return_type_size = len(return_type) def verify_result(result): import pandas as pd @@ -610,7 +611,7 @@ def read_udtf(pickleSer, infile, eval_type): raise PySparkTypeError( error_class="INVALID_ARROW_UDTF_RETURN_TYPE", message_parameters={ - "type_name": type(result).__name_, + "type_name": type(result).__name__, "value": str(result), }, ) @@ -620,11 +621,11 @@ def read_udtf(pickleSer, infile, eval_type): # result dataframe may contain an empty row. For example, when a UDTF is # defined as follows: def eval(self): yield tuple(). 
if len(result) > 0 or len(result.columns) > 0: - if len(result.columns) != len(return_type): + if len(result.columns) != return_type_size: raise PySparkRuntimeError( error_class="UDTF_RETURN_SCHEMA_MISMATCH", message_parameters={ - "expected": str(len(return_type)), + "expected": str(return_type_size), "actual": str(len(result.columns)), }, ) @@ -652,13 +653,7 @@ def read_udtf(pickleSer, infile, eval_type): yield from eval(*[a[o] for o in arg_offsets]) finally: if terminate is not None: - try: - yield from terminate() - except BaseException as e: - raise PySparkRuntimeError( - error_class="UDTF_EXEC_ERROR", - message_parameters={"method_name": "terminate", "error": str(e)}, - ) + yield from terminate() return mapper, None, ser, ser @@ -667,15 +662,16 @@ def read_udtf(pickleSer, infile, eval_type): def wrap_udtf(f, return_type): assert return_type.needConversion() toInternal = return_type.toInternal + return_type_size = len(return_type) def verify_and_convert_result(result): # TODO(SPARK-44005): support returning non-tuple values if result is not None and hasattr(result, "__len__"): - if len(result) != len(return_type): + if len(result) != return_type_size: raise PySparkRuntimeError( error_class="UDTF_RETURN_SCHEMA_MISMATCH", message_parameters={ - "expected": str(len(return_type)), + "expected": str(return_type_size), "actual": str(len(result)), }, ) @@ -683,16 +679,29 @@ def read_udtf(pickleSer, infile, eval_type): # Evaluate the function and return a tuple back to the executor. def evaluate(*a) -> tuple: - res = f(*a) + try: + res = f(*a) + except Exception as e: + raise PySparkRuntimeError( + error_class="UDTF_EXEC_ERROR", + message_parameters={"method_name": f.__name__, "error": str(e)}, + ) + if res is None: # If the function returns None or does not have an explicit return statement, # an empty tuple is returned to the executor. # This is because directly constructing tuple(None) results in an exception. 
return tuple() - else: - # If the function returns a result, we map it to the internal representation and - # returns the results as a tuple. - return tuple(map(verify_and_convert_result, res)) + + if not isinstance(res, Iterable): + raise PySparkRuntimeError( + error_class="UDTF_RETURN_NOT_ITERABLE", + message_parameters={"type": type(res).__name__}, + ) + + # If the function returns a result, we map it to the internal representation and + # returns the results as a tuple. + return tuple(map(verify_and_convert_result, res)) return evaluate @@ -710,13 +719,7 @@ def read_udtf(pickleSer, infile, eval_type): yield eval(*[a[o] for o in arg_offsets]) finally: if terminate is not None: - try: - yield terminate() - except BaseException as e: - raise PySparkRuntimeError( - error_class="UDTF_EXEC_ERROR", - message_parameters={"method_name": "terminate", "error": str(e)}, - ) + yield terminate() return mapper, None, ser, ser --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org