This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6569f15490ec [SPARK-47002][PYTHON] Return better error message if UDTF 
'analyze' method 'orderBy' field accidentally returns a list of strings
6569f15490ec is described below

commit 6569f15490ec144b495b777d650c42ee4cc184d3
Author: Daniel Tenedorio <daniel.tenedo...@databricks.com>
AuthorDate: Thu Feb 8 11:17:37 2024 -0800

    [SPARK-47002][PYTHON] Return better error message if UDTF 'analyze' method 
'orderBy' field accidentally returns a list of strings
    
    ### What changes were proposed in this pull request?
    
    This PR updates the Python UDTF API to check and return a better error 
message if the `analyze` method returns an `AnalyzeResult` object with an 
`orderBy` field erroneously set to a list of strings, rather than 
`OrderingColumn` instances.
    
    For example, this UDTF accidentally sets the `orderBy` field in this way:
    
    ```
    from pyspark.sql.functions import AnalyzeResult, OrderingColumn, 
PartitioningColumn
    from pyspark.sql.types import IntegerType, Row, StructType
    class Udtf:
        def __init__(self):
            self._partition_col = None
            self._count = 0
            self._sum = 0
            self._last = None
    
        @staticmethod
        def analyze(row: Row):
            return AnalyzeResult(
                schema=StructType()
                    .add("user_id", IntegerType())
                    .add("count", IntegerType())
                    .add("total", IntegerType())
                    .add("last", IntegerType()),
                partitionBy=[
                    PartitioningColumn("user_id")
                ],
                orderBy=[
                    "timestamp"
                ],
                )
    
        def eval(self, row: Row):
            self._partition_col = row["partition_col"]
            self._count += 1
            self._last = row["input"]
            self._sum += row["input"]
    
        def terminate(self):
            yield self._partition_col, self._count, self._sum, self._last
    ```
    
    ### Why are the changes needed?
    
    This improves error messages and helps keep users from getting confused.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, see above.
    
    ### How was this patch tested?
    
    This PR adds test coverage.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #45062 from dtenedor/check-udtf-sort-columns.
    
    Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/sql/worker/analyze_udtf.py          | 48 ++++++++++----------
 .../sql-tests/analyzer-results/udtf/udtf.sql.out   | 20 +++++++++
 .../test/resources/sql-tests/inputs/udtf/udtf.sql  |  1 +
 .../resources/sql-tests/results/udtf/udtf.sql.out  | 22 ++++++++++
 .../apache/spark/sql/IntegratedUDFTestUtils.scala  | 51 ++++++++++++----------
 5 files changed, 95 insertions(+), 47 deletions(-)

diff --git a/python/pyspark/sql/worker/analyze_udtf.py 
b/python/pyspark/sql/worker/analyze_udtf.py
index f61330b806cd..a4b26e0bdc61 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -31,7 +31,7 @@ from pyspark.serializers import (
     write_with_length,
     SpecialLengths,
 )
-from pyspark.sql.functions import PartitioningColumn, SelectedColumn
+from pyspark.sql.functions import OrderingColumn, PartitioningColumn, 
SelectedColumn
 from pyspark.sql.types import _parse_datatype_json_string, StructType
 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
 from pyspark.util import handle_worker_exception
@@ -163,6 +163,18 @@ def main(infile: IO, outfile: IO) -> None:
                     but the 'schema' field had the wrong type: 
{type(result.schema)}"""
                 )
             )
+
+        def invalid_analyze_result_field(field_name: str, expected_field: str) 
-> PySparkValueError:
+            return PySparkValueError(
+                format_error(
+                    f"""
+                    {error_prefix} because the static 'analyze' method 
returned an
+                    'AnalyzeResult' object with the '{field_name}' field set 
to a value besides a
+                    list or tuple of '{expected_field}' objects. Please update 
the table function
+                    and then try the query again."""
+                )
+            )
+
         has_table_arg = any(arg.isTable for arg in args) or any(
             arg.isTable for arg in kwargs.values()
         )
@@ -190,32 +202,18 @@ def main(infile: IO, outfile: IO) -> None:
                     set to empty, and then try the query again."""
                 )
             )
-        elif isinstance(result.partitionBy, (list, tuple)) and (
-            len(result.partitionBy) > 0
-            and not all([isinstance(val, PartitioningColumn) for val in 
result.partitionBy])
+        elif not isinstance(result.partitionBy, (list, tuple)) or not all(
+            isinstance(val, PartitioningColumn) for val in result.partitionBy
         ):
-            raise PySparkValueError(
-                format_error(
-                    f"""
-                    {error_prefix} because the static 'analyze' method 
returned an
-                    'AnalyzeResult' object with the 'partitionBy' field set to 
a value besides a
-                    list or tuple of 'PartitioningColumn' objects. Please 
update the table function
-                    and then try the query again."""
-                )
-            )
-        elif isinstance(result.select, (list, tuple)) and (
-            len(result.select) > 0
-            and not all([isinstance(val, SelectedColumn) for val in 
result.select])
+            raise invalid_analyze_result_field("partitionBy", 
"PartitioningColumn")
+        elif not isinstance(result.orderBy, (list, tuple)) or not all(
+            isinstance(val, OrderingColumn) for val in result.orderBy
         ):
-            raise PySparkValueError(
-                format_error(
-                    f"""
-                    {error_prefix} because the static 'analyze' method 
returned an
-                    'AnalyzeResult' object with the 'select' field set to a 
value besides a
-                    list or tuple of 'SelectedColumn' objects. Please update 
the table function
-                    and then try the query again."""
-                )
-            )
+            raise invalid_analyze_result_field("orderBy", "OrderingColumn")
+        elif not isinstance(result.select, (list, tuple)) or not all(
+            isinstance(val, SelectedColumn) for val in result.select
+        ):
+            raise invalid_analyze_result_field("select", "SelectedColumn")
 
         # Return the analyzed schema.
         write_with_length(result.schema.json().encode("utf-8"), outfile)
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out
index 8cf567c14b3b..cdfa4f69f6e7 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out
@@ -383,6 +383,26 @@ org.apache.spark.sql.AnalysisException
 }
 
 
+-- !query
+SELECT * FROM UDTFInvalidOrderByStringList(TABLE(t2))
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON",
+  "sqlState" : "38000",
+  "messageParameters" : {
+    "msg" : "Failed to evaluate the user-defined table function 
'UDTFInvalidOrderByStringList' because the static 'analyze' method returned an 
'AnalyzeResult' object with the 'orderBy' field set to a value besides a list 
or tuple of 'OrderingColumn' objects. Please update the table function and then 
try the query again."
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 15,
+    "stopIndex" : 53,
+    "fragment" : "UDTFInvalidOrderByStringList(TABLE(t2))"
+  } ]
+}
+
+
 -- !query
 SELECT * FROM UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2))
 -- !query analysis
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql 
b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
index 03d5b001d102..c83481f10dca 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
@@ -83,6 +83,7 @@ SELECT * FROM UDTFInvalidSelectExprParseError(TABLE(t2));
 SELECT * FROM UDTFInvalidSelectExprStringValue(TABLE(t2));
 SELECT * FROM UDTFInvalidComplexSelectExprMissingAlias(TABLE(t2));
 SELECT * FROM UDTFInvalidOrderByAscKeyword(TABLE(t2));
+SELECT * FROM UDTFInvalidOrderByStringList(TABLE(t2));
 -- As a reminder, UDTFInvalidPartitionByAndWithSinglePartition returns this 
analyze result:
 --     AnalyzeResult(
 --         schema=StructType()
diff --git a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out 
b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
index c85eacdc348b..78ad8b7c02cd 100644
--- a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
@@ -464,6 +464,28 @@ org.apache.spark.sql.AnalysisException
 }
 
 
+-- !query
+SELECT * FROM UDTFInvalidOrderByStringList(TABLE(t2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON",
+  "sqlState" : "38000",
+  "messageParameters" : {
+    "msg" : "Failed to evaluate the user-defined table function 
'UDTFInvalidOrderByStringList' because the static 'analyze' method returned an 
'AnalyzeResult' object with the 'orderBy' field set to a value besides a list 
or tuple of 'OrderingColumn' objects. Please update the table function and then 
try the query again."
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 15,
+    "stopIndex" : 53,
+    "fragment" : "UDTFInvalidOrderByStringList(TABLE(t2))"
+  } ]
+}
+
+
 -- !query
 SELECT * FROM UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2))
 -- !query schema
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index 848c38c95da5..c1ca48162d20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -609,10 +609,10 @@ object IntegratedUDFTestUtils extends SQLHelper {
         |                .add("total", IntegerType())
         |                .add("last", IntegerType()),
         |            partitionBy=[
-        |                PartitioningColumn("$partitionBy")
+        |                $partitionBy
         |            ],
         |            orderBy=[
-        |                OrderingColumn("$orderBy")
+        |                $orderBy
         |            ],
         |            select=[
         |                $select
@@ -631,65 +631,71 @@ object IntegratedUDFTestUtils extends SQLHelper {
 
   object UDTFPartitionByOrderBy
     extends TestPythonUDTFPartitionByOrderByBase(
-      partitionBy = "partition_col",
-      orderBy = "input",
+      partitionBy = "PartitioningColumn(\"partition_col\")",
+      orderBy = "OrderingColumn(\"input\")",
       select = "")
 
   object UDTFPartitionByOrderByComplexExpr
     extends TestPythonUDTFPartitionByOrderByBase(
-      partitionBy = "partition_col + 1",
-      orderBy = "RANDOM(42)",
+      partitionBy = "PartitioningColumn(\"partition_col + 1\")",
+      orderBy = "OrderingColumn(\"RANDOM(42)\")",
       select = "")
 
   object UDTFPartitionByOrderBySelectExpr
     extends TestPythonUDTFPartitionByOrderByBase(
-      partitionBy = "partition_col",
-      orderBy = "input",
+      partitionBy = "PartitioningColumn(\"partition_col\")",
+      orderBy = "OrderingColumn(\"input\")",
       select = "SelectedColumn(\"partition_col\"), SelectedColumn(\"input\")")
 
   object UDTFPartitionByOrderBySelectComplexExpr
     extends TestPythonUDTFPartitionByOrderByBase(
-      partitionBy = "partition_col + 1",
-      orderBy = "RANDOM(42)",
+      partitionBy = "PartitioningColumn(\"partition_col + 1\")",
+      orderBy = "OrderingColumn(\"RANDOM(42)\")",
       select = "SelectedColumn(\"partition_col\"), " +
         "SelectedColumn(name=\"input + 1\", alias=\"input\")")
 
   object UDTFPartitionByOrderBySelectExprOnlyPartitionColumn
     extends TestPythonUDTFPartitionByOrderByBase(
-      partitionBy = "partition_col",
-      orderBy = "input",
+      partitionBy = "PartitioningColumn(\"partition_col\")",
+      orderBy = "OrderingColumn(\"input\")",
       select = "SelectedColumn(\"partition_col\")")
 
   object UDTFInvalidPartitionByOrderByParseError
     extends TestPythonUDTFPartitionByOrderByBase(
-      partitionBy = "unparsable",
-      orderBy = "input",
+      partitionBy = "PartitioningColumn(\"unparsable\")",
+      orderBy = "OrderingColumn(\"input\")",
       select = "")
 
   object UDTFInvalidOrderByAscKeyword
     extends TestPythonUDTFPartitionByOrderByBase(
-      partitionBy = "partition_col",
-      orderBy = "partition_col ASC",
+      partitionBy = "PartitioningColumn(\"partition_col\")",
+      orderBy = "OrderingColumn(\"partition_col ASC\")",
       select = "")
 
   object UDTFInvalidSelectExprParseError
     extends TestPythonUDTFPartitionByOrderByBase(
-      partitionBy = "partition_col",
-      orderBy = "input",
+      partitionBy = "PartitioningColumn(\"partition_col\")",
+      orderBy = "OrderingColumn(\"input\")",
       select = "SelectedColumn(\"unparsable\")")
 
   object UDTFInvalidSelectExprStringValue
     extends TestPythonUDTFPartitionByOrderByBase(
-      partitionBy = "partition_col",
-      orderBy = "input",
+      partitionBy = "PartitioningColumn(\"partition_col\")",
+      orderBy = "OrderingColumn(\"input\")",
       select = "\"partition_cll\"")
 
   object UDTFInvalidComplexSelectExprMissingAlias
     extends TestPythonUDTFPartitionByOrderByBase(
-      partitionBy = "partition_col + 1",
-      orderBy = "RANDOM(42)",
+      partitionBy = "PartitioningColumn(\"partition_col + 1\")",
+      orderBy = "OrderingColumn(\"RANDOM(42)\")",
       select = "SelectedColumn(name=\"input + 1\")")
 
+  object UDTFInvalidOrderByStringList
+    extends TestPythonUDTFPartitionByOrderByBase(
+      partitionBy = "PartitioningColumn(\"partition_col\")",
+      orderBy = "\"partition_col\"",
+      select = "")
+
   object UDTFInvalidPartitionByAndWithSinglePartition extends TestUDTF {
     val pythonScript: String =
       s"""
@@ -1197,6 +1203,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
     UDTFWithSinglePartition,
     UDTFPartitionByOrderBy,
     UDTFInvalidOrderByAscKeyword,
+    UDTFInvalidOrderByStringList,
     UDTFInvalidSelectExprParseError,
     UDTFInvalidSelectExprStringValue,
     UDTFInvalidComplexSelectExprMissingAlias,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to