dtenedor commented on code in PR #43611:
URL: https://github.com/apache/spark/pull/43611#discussion_r1399572799


##########
python/pyspark/sql/worker/analyze_udtf.py:
##########
@@ -116,12 +118,94 @@ def main(infile: IO, outfile: IO) -> None:
         handler = read_udtf(infile)
         args, kwargs = read_arguments(infile)
 
+        error_prefix = f"Failed to evaluate the user-defined table function '{handler.__name__}'"
+
+        def format_error(msg: str) -> str:
+            return dedent(msg).replace("\n", " ")
+
+        # Check that the arguments provided to the UDTF call match the expected parameters defined
+        # in the static 'analyze' method signature.
+        try:
+            inspect.signature(handler.analyze).bind(*args, **kwargs)
+        except TypeError as e:
+            # The UDTF call's arguments did not match the expected signature.
+            raise PySparkValueError(
+                format_error(
+                    f"""
+                    {error_prefix} because the function arguments did not match the expected
+                    signature of the static 'analyze' method ({e}). Please update the query so that
+                    this table function call provides arguments matching the expected signature, or
+                    else update the table function so that its static 'analyze' method accepts the
+                    provided arguments, and then try the query again."""
+                )
+            )
+
+        # Invoke the UDTF's 'analyze' method.
         result = handler.analyze(*args, **kwargs)  # type: ignore[attr-defined]
 
+        # Check invariants about the 'analyze' method after running it.
         if not isinstance(result, AnalyzeResult):
             raise PySparkValueError(
-                "Output of `analyze` static method of Python UDTFs expects "
-                f"a pyspark.sql.udtf.AnalyzeResult but got: {type(result)}"
+                format_error(
+                    f"""
+                    {error_prefix} because the static 'analyze' method expects a result of type
+                    pyspark.sql.udtf.AnalyzeResult, but instead this method returned a value of
+                    type: {type(result)}"""
+                )
+            )
+        elif not isinstance(result.schema, StructType):
+            raise PySparkValueError(
+                format_error(
+                    f"""
+                    {error_prefix} because the static 'analyze' method expects a result of type
+                    pyspark.sql.udtf.AnalyzeResult with a 'schema' field comprising a StructType,
+                    but the 'schema' field had the wrong type: {type(result.schema)}"""
+                )
+            )
+        has_table_arg = (
+            len([arg for arg in args if arg.isTable])
+            + len([arg for arg in kwargs.items() if arg[-1].isTable])
+        ) > 0
+        if not has_table_arg and result.withSinglePartition:
+            raise PySparkValueError(
+                format_error(
+                    f"""
+                    {error_prefix} because the static 'analyze' method returned an
+                    'AnalyzeResult' object with the 'withSinglePartition' field set to 'true', but
+                    the function call did not provide any table argument. Please update the query so
+                    that it provides a table argument, or else update the table function so that its
+                    'analyze' method returns an 'AnalyzeResult' object with the
+                    'withSinglePartition' field set to 'false', and then try the query again."""
+                )
+            )
+        elif not has_table_arg and len(result.partitionBy) > 0:
+            raise PySparkValueError(
+                format_error(
+                    f"""
+                    {error_prefix} because the static 'analyze' method returned an
+                    'AnalyzeResult' object with the 'partitionBy' list set to non-empty, but the
+                    function call did not provide any table argument. Please update the query so
+                    that it provides a table argument, or else update the table function so that its
+                    'analyze' method returns an 'AnalyzeResult' object with the 'partitionBy' list
+                    set to empty, and then try the query again."""
+                )
+            )
+        elif (
+            hasattr(result, "partitionBy")
Review Comment:
   Good point, removed this check.



##########
python/pyspark/sql/worker/analyze_udtf.py:
##########
@@ -116,12 +118,94 @@ def main(infile: IO, outfile: IO) -> None:
         handler = read_udtf(infile)
         args, kwargs = read_arguments(infile)
 
+        error_prefix = f"Failed to evaluate the user-defined table function '{handler.__name__}'"
+
+        def format_error(msg: str) -> str:
+            return dedent(msg).replace("\n", " ")
+
+        # Check that the arguments provided to the UDTF call match the expected parameters defined
+        # in the static 'analyze' method signature.
+        try:
+            inspect.signature(handler.analyze).bind(*args, **kwargs)
+        except TypeError as e:
+            # The UDTF call's arguments did not match the expected signature.
+            raise PySparkValueError(
+                format_error(
+                    f"""
+                    {error_prefix} because the function arguments did not match the expected
+                    signature of the static 'analyze' method ({e}). Please update the query so that
+                    this table function call provides arguments matching the expected signature, or
+                    else update the table function so that its static 'analyze' method accepts the
+                    provided arguments, and then try the query again."""
+                )
+            )
+
+        # Invoke the UDTF's 'analyze' method.
         result = handler.analyze(*args, **kwargs)  # type: ignore[attr-defined]
 
+        # Check invariants about the 'analyze' method after running it.
         if not isinstance(result, AnalyzeResult):
             raise PySparkValueError(
-                "Output of `analyze` static method of Python UDTFs expects "
-                f"a pyspark.sql.udtf.AnalyzeResult but got: {type(result)}"
+                format_error(
+                    f"""
+                    {error_prefix} because the static 'analyze' method expects a result of type
+                    pyspark.sql.udtf.AnalyzeResult, but instead this method returned a value of
+                    type: {type(result)}"""
+                )
+            )
+        elif not isinstance(result.schema, StructType):
+            raise PySparkValueError(
+                format_error(
+                    f"""
+                    {error_prefix} because the static 'analyze' method expects a result of type
+                    pyspark.sql.udtf.AnalyzeResult with a 'schema' field comprising a StructType,
+                    but the 'schema' field had the wrong type: {type(result.schema)}"""
+                )
+            )
+        has_table_arg = (
+            len([arg for arg in args if arg.isTable])
+            + len([arg for arg in kwargs.items() if arg[-1].isTable])
+        ) > 0

Review Comment:
   Sounds good, done (also re-ran `dev/reformat-python` afterwards).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to