This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new cf3d10156ff [SPARK-44930][SQL] Deterministic ApplyFunctionExpression
should be foldable
cf3d10156ff is described below
commit cf3d10156ff86e1b0c27cdc5706b345e84867cf9
Author: xianyangliu <[email protected]>
AuthorDate: Fri Aug 25 15:01:35 2023 +0800
[SPARK-44930][SQL] Deterministic ApplyFunctionExpression should be foldable
### What changes were proposed in this pull request?
Currently, ApplyFunctionExpression is unfoldable because it inherits the
default value from Expression. However, it should be foldable for a
deterministic ApplyFunctionExpression.
### Why are the changes needed?
This could help optimize the usage of V2 UDFs applied to constant
expressions.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New UT.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #42629 from ConeyLiu/constant-fold-v2-udf.
Authored-by: xianyangliu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 994389f42a40d292a72482e3d76d29bada82d8ec)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../expressions/ApplyFunctionExpression.scala | 1 +
.../sql/connector/DataSourceV2FunctionSuite.scala | 22 ++++++++++++----------
2 files changed, 13 insertions(+), 10 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
index da4000f53e3..a1815cf3b3d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
@@ -33,6 +33,7 @@ case class ApplyFunctionExpression(
override def inputTypes: Seq[AbstractDataType] = function.inputTypes().toSeq
override lazy val deterministic: Boolean = function.isDeterministic &&
children.forall(_.deterministic)
+ override def foldable: Boolean = deterministic && children.forall(_.foldable)
private lazy val reusedRow = new SpecificInternalRow(function.inputTypes())
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index 32391eac9a8..b74d7318a92 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -24,7 +24,6 @@ import
test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd._
import test.org.apache.spark.sql.connector.catalog.functions.JavaRandomAdd._
import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
-import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import
org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK,
NO_CODEGEN}
@@ -322,14 +321,7 @@ class DataSourceV2FunctionSuite extends
DatasourceV2SQLBase {
test("scalar function: bad magic method") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"),
emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadMagic))
- // TODO assign a error-classes name
- checkError(
- exception = intercept[SparkException] {
- sql("SELECT testcat.ns.strlen('abc')").collect()
- },
- errorClass = null,
- parameters = Map.empty
- )
+ intercept[UnsupportedOperationException](sql("SELECT
testcat.ns.strlen('abc')").collect())
}
test("scalar function: bad magic method with default impl") {
@@ -341,7 +333,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase
{
test("scalar function: no implementation found") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"),
emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenNoImpl))
- intercept[SparkException](sql("SELECT testcat.ns.strlen('abc')").collect())
+ intercept[UnsupportedOperationException](sql("SELECT
testcat.ns.strlen('abc')").collect())
}
test("scalar function: invalid parameter type or length") {
@@ -688,6 +680,16 @@ class DataSourceV2FunctionSuite extends
DatasourceV2SQLBase {
}
}
+ test("SPARK-44930: Fold deterministic ApplyFunctionExpression") {
+
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"),
emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
+
+ val df1 = sql("SELECT testcat.ns.strlen('abc') as col1")
+ val df2 = sql("SELECT 3 as col1")
+ comparePlans(df1.queryExecution.optimizedPlan,
df2.queryExecution.optimizedPlan)
+ checkAnswer(df1, Row(3) :: Nil)
+ }
+
private case object StrLenDefault extends ScalarFunction[Int] {
override def inputTypes(): Array[DataType] = Array(StringType)
override def resultType(): DataType = IntegerType
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]