This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 541e1c4da131 [SPARK-48197][SQL] Avoid assert error for invalid lambda 
function
541e1c4da131 is described below

commit 541e1c4da131ce737b9cf554028cf292bebbcf04
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Thu May 9 10:56:21 2024 +0800

    [SPARK-48197][SQL] Avoid assert error for invalid lambda function
    
    ### What changes were proposed in this pull request?
    
    `ExpressionBuilder` asserts that all of its input expressions are resolved 
during function lookup. This assumption does not always hold: the analyzer rule 
`ResolveFunctions` can trigger function lookup even when the input expressions 
contain unresolved lambda functions.
    
    This PR updates the assert to check only the non-lambda inputs, and fails 
earlier with an `AnalysisException` if the input contains lambda functions that 
the builder does not support. In the future, if we use `ExpressionBuilder` to 
register higher-order functions, we can relax this check.
    
    ### Why are the changes needed?
    
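    Better error message: passing a lambda function to a non-higher-order 
function registered via `ExpressionBuilder` now fails with 
`INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION` instead of an 
assertion error.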
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it only changes the error message.
    
    ### How was this patch tested?
    
    new test
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #46475 from cloud-fan/minor.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 7e79e91dc8c531ee9135f0e32a9aa2e1f80c4bbf)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/analysis/FunctionRegistry.scala   |  9 ++++++++-
 .../plans/logical/FunctionBuilderBase.scala        |  2 ++
 .../ansi/higher-order-functions.sql.out            | 20 ++++++++++++++++++++
 .../higher-order-functions.sql.out                 | 20 ++++++++++++++++++++
 .../sql-tests/inputs/higher-order-functions.sql    |  2 ++
 .../results/ansi/higher-order-functions.sql.out    | 22 ++++++++++++++++++++++
 .../results/higher-order-functions.sql.out         | 22 ++++++++++++++++++++++
 7 files changed, 96 insertions(+), 1 deletion(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 558579cdb80a..aaf718fab941 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -930,7 +930,14 @@ object FunctionRegistry {
       since: Option[String] = None): (String, (ExpressionInfo, 
FunctionBuilder)) = {
     val info = FunctionRegistryBase.expressionInfo[T](name, since)
     val funcBuilder = (expressions: Seq[Expression]) => {
-      assert(expressions.forall(_.resolved), "function arguments must be 
resolved.")
+      val (lambdas, others) = 
expressions.partition(_.isInstanceOf[LambdaFunction])
+      if (lambdas.nonEmpty && !builder.supportsLambda) {
+        throw new AnalysisException(
+          errorClass = 
"INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
+          messageParameters = Map(
+            "class" -> builder.getClass.getCanonicalName))
+      }
+      assert(others.forall(_.resolved), "function arguments must be resolved.")
       val rearrangedExpressions = rearrangeExpressions(name, builder, 
expressions)
       val expr = builder.build(name, rearrangedExpressions)
       if (setAlias) expr.setTagValue(FUNC_ALIAS, name)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
index 1088655f60cd..a901fa5a72c5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
@@ -69,6 +69,8 @@ trait FunctionBuilderBase[T] {
   }
 
   def build(funcName: String, expressions: Seq[Expression]): T
+
+  def supportsLambda: Boolean = false
 }
 
 object NamedParametersSupport {
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out
index 08d3be615b31..3fafb9858e5a 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out
@@ -34,6 +34,26 @@ org.apache.spark.sql.AnalysisException
 }
 
 
+-- !query
+select ceil(x -> x) as v
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
+  "sqlState" : "42K0D",
+  "messageParameters" : {
+    "class" : 
"org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 19,
+    "fragment" : "ceil(x -> x)"
+  } ]
+}
+
+
 -- !query
 select transform(zs, z -> z) as v from nested
 -- !query analysis
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out
index f656716a843e..d9e88ac618aa 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out
@@ -34,6 +34,26 @@ org.apache.spark.sql.AnalysisException
 }
 
 
+-- !query
+select ceil(x -> x) as v
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
+  "sqlState" : "42K0D",
+  "messageParameters" : {
+    "class" : 
"org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 19,
+    "fragment" : "ceil(x -> x)"
+  } ]
+}
+
+
 -- !query
 select transform(zs, z -> z) as v from nested
 -- !query analysis
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql 
b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
index 7925a21de04c..37081de012e9 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
@@ -11,6 +11,8 @@ create or replace temporary view nested as values
 
 -- Only allow lambda's in higher order functions.
 select upper(x -> x) as v;
+-- Also test functions registered with `ExpressionBuilder`.
+select ceil(x -> x) as v;
 
 -- Identity transform an array
 select transform(zs, z -> z) as v from nested;
diff --git 
a/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
 
b/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
index e479b49463e7..eb9c454109f0 100644
--- 
a/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
@@ -32,6 +32,28 @@ org.apache.spark.sql.AnalysisException
 }
 
 
+-- !query
+select ceil(x -> x) as v
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
+  "sqlState" : "42K0D",
+  "messageParameters" : {
+    "class" : 
"org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 19,
+    "fragment" : "ceil(x -> x)"
+  } ]
+}
+
+
 -- !query
 select transform(zs, z -> z) as v from nested
 -- !query schema
diff --git 
a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out 
b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
index e479b49463e7..eb9c454109f0 100644
--- 
a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
@@ -32,6 +32,28 @@ org.apache.spark.sql.AnalysisException
 }
 
 
+-- !query
+select ceil(x -> x) as v
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
+  "sqlState" : "42K0D",
+  "messageParameters" : {
+    "class" : 
"org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 19,
+    "fragment" : "ceil(x -> x)"
+  } ]
+}
+
+
 -- !query
 select transform(zs, z -> z) as v from nested
 -- !query schema


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to