This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new a9d601cf357 [SPARK-44640][PYTHON][FOLLOW-UP][3.5] Update UDTF error messages to include method name a9d601cf357 is described below commit a9d601cf35706c61e30ef1f1daae34a51e6bb3b0 Author: allisonwang-db <allison.w...@databricks.com> AuthorDate: Thu Sep 7 10:42:20 2023 +0900 [SPARK-44640][PYTHON][FOLLOW-UP][3.5] Update UDTF error messages to include method name (cherry picked from commit 3e22c8653d728a6b8523051faddcca437accfc22) ### What changes were proposed in this pull request? This PR is a follow-up for SPARK-44640 to make the error message of a few UDTF errors more informative by including the method name in the error message (`eval` or `terminate`). ### Why are the changes needed? To improve error messages. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #42840 from allisonwang-db/spark-44640-follow-up-3.5. Authored-by: allisonwang-db <allison.w...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/errors/error_classes.py | 8 ++++---- python/pyspark/sql/tests/test_udtf.py | 21 +++++++++++++++++++ python/pyspark/worker.py | 37 +++++++++++++++++++++++++--------- 3 files changed, 52 insertions(+), 14 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 4709f01ba06..0fbe489f623 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -244,7 +244,7 @@ ERROR_CLASSES_JSON = """ }, "INVALID_ARROW_UDTF_RETURN_TYPE" : { "message" : [ - "The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the function returned a value of type <type_name> with value: <value>." 
+ "The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the '<func>' method returned a value of type <type_name> with value: <value>." ] }, "INVALID_BROADCAST_OPERATION": { @@ -730,17 +730,17 @@ ERROR_CLASSES_JSON = """ }, "UDTF_INVALID_OUTPUT_ROW_TYPE" : { "message" : [ - "The type of an individual output row in the UDTF is invalid. Each row should be a tuple, list, or dict, but got '<type>'. Please make sure that the output rows are of the correct type." + "The type of an individual output row in the '<func>' method of the UDTF is invalid. Each row should be a tuple, list, or dict, but got '<type>'. Please make sure that the output rows are of the correct type." ] }, "UDTF_RETURN_NOT_ITERABLE" : { "message" : [ - "The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types." + "The return value of the '<func>' method of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types." ] }, "UDTF_RETURN_SCHEMA_MISMATCH" : { "message" : [ - "The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the function have the same number of columns as specified in the output schema." + "The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the '<func>' method have the same number of columns as specified in the output schema." 
] }, "UDTF_RETURN_TYPE_MISMATCH" : { diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 1ff9e55dd78..944ce6d85b8 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -164,6 +164,27 @@ class BaseUDTFTestsMixin: with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): TestUDTF(lit(1)).collect() + def test_udtf_with_zero_arg_and_invalid_return_value(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self): + return 1 + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): + TestUDTF().collect() + + def test_udtf_with_invalid_return_value_in_terminate(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self, a): + ... + + def terminate(self): + return 1 + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): + TestUDTF(lit(1)).collect() + def test_udtf_eval_with_no_return(self): @udtf(returnType="a: int") class TestUDTF: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d2ea18c45c9..90b11d06231 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -655,6 +655,7 @@ def read_udtf(pickleSer, infile, eval_type): message_parameters={ "type_name": type(result).__name__, "value": str(result), + "func": f.__name__, }, ) @@ -669,6 +670,7 @@ def read_udtf(pickleSer, infile, eval_type): message_parameters={ "expected": str(return_type_size), "actual": str(len(result.columns)), + "func": f.__name__, }, ) @@ -688,22 +690,30 @@ def read_udtf(pickleSer, infile, eval_type): message_parameters={"method_name": f.__name__, "error": str(e)}, ) + def check_return_value(res): + # Check whether the result of an arrow UDTF is iterable before + # using it to construct a pandas DataFrame. 
+ if res is not None and not isinstance(res, Iterable): + raise PySparkRuntimeError( + error_class="UDTF_RETURN_NOT_ITERABLE", + message_parameters={ + "type": type(res).__name__, + "func": f.__name__, + }, + ) + def evaluate(*args: pd.Series): if len(args) == 0: - yield verify_result(pd.DataFrame(func())), arrow_return_type + res = func() + check_return_value(res) + yield verify_result(pd.DataFrame(res)), arrow_return_type else: # Create tuples from the input pandas Series, each tuple # represents a row across all Series. row_tuples = zip(*args) for row in row_tuples: res = func(*row) - if res is not None and not isinstance(res, Iterable): - raise PySparkRuntimeError( - error_class="UDTF_RETURN_NOT_ITERABLE", - message_parameters={ - "type": type(res).__name__, - }, - ) + check_return_value(res) yield verify_result(pd.DataFrame(res)), arrow_return_type return evaluate @@ -742,13 +752,17 @@ def read_udtf(pickleSer, infile, eval_type): message_parameters={ "expected": str(return_type_size), "actual": str(len(result)), + "func": f.__name__, }, ) if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")): raise PySparkRuntimeError( error_class="UDTF_INVALID_OUTPUT_ROW_TYPE", - message_parameters={"type": type(result).__name__}, + message_parameters={ + "type": type(result).__name__, + "func": f.__name__, + }, ) return toInternal(result) @@ -772,7 +786,10 @@ def read_udtf(pickleSer, infile, eval_type): if not isinstance(res, Iterable): raise PySparkRuntimeError( error_class="UDTF_RETURN_NOT_ITERABLE", - message_parameters={"type": type(res).__name__}, + message_parameters={ + "type": type(res).__name__, + "func": f.__name__, + }, ) # If the function returns a result, we map it to the internal representation and --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org