This is an automated email from the ASF dual-hosted git repository.
yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 642a62b6970f [SPARK-50224][SQL] The replacements of
IsValidUTF8|ValidateUTF8|TryValidateUTF8|MakeValidUTF8 shall be NullIntolerant
642a62b6970f is described below
commit 642a62b6970fb30dba6f8d7d09eae4ea305bdb17
Author: Kent Yao <[email protected]>
AuthorDate: Tue Nov 5 19:57:52 2024 +0800
[SPARK-50224][SQL] The replacements of
IsValidUTF8|ValidateUTF8|TryValidateUTF8|MakeValidUTF8 shall be NullIntolerant
### What changes were proposed in this pull request?
This PR makes the replacements of the
IsValidUTF8|ValidateUTF8|TryValidateUTF8|MakeValidUTF8 functions
NullIntolerant, inheriting this property from the original expressions, so that
IsNotNull constraints can actually be constructed for them.
This is also a common issue for other RuntimeReplaceable expressions; I
will revisit them in groups (tracked by SPARK-50223).
### Why are the changes needed?
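It is a common performance optimization: when these replacements are
NullIntolerant, the optimizer can infer IsNotNull constraints (filters) for
predicates that use them.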
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
New tests in StringFunctionsSuite.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48758 from yaooqinn/SPARK-50224.
Authored-by: Kent Yao <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
---
.../sql/catalyst/expressions/objects/objects.scala | 43 ++++++++++++++++++++++
.../catalyst/expressions/stringExpressions.scala | 32 +++++++++-------
.../apache/spark/sql/StringFunctionsSuite.scala | 30 ++++++++++++++-
3 files changed, 90 insertions(+), 15 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 5c786bc5ddbf..f49fd697492a 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -412,6 +412,28 @@ case class StaticInvoke(
}.$functionName(${arguments.mkString(", ")}))"
}
+object StaticInvoke {
+ def withNullIntolerant(
+ staticObject: Class[_],
+ dataType: DataType,
+ functionName: String,
+ arguments: Seq[Expression] = Nil,
+ inputTypes: Seq[AbstractDataType] = Nil,
+ propagateNull: Boolean = true,
+ returnNullable: Boolean = true,
+ isDeterministic: Boolean = true,
+ scalarFunction: Option[ScalarFunction[_]] = None): StaticInvoke =
+ new StaticInvoke(
+ staticObject,
+ dataType,
+ functionName,
+ arguments,
+ inputTypes,
+ propagateNull,
+ returnNullable,
+ isDeterministic, scalarFunction) with NullIntolerant
+}
+
/**
* Calls the specified function on an object, optionally passing arguments.
If the `targetObject`
* expression evaluates to null then null will be returned.
@@ -555,6 +577,27 @@ case class Invoke(
copy(targetObject = newChildren.head, arguments = newChildren.tail)
}
+object Invoke {
+ def withNullIntolerant(
+ targetObject: Expression,
+ functionName: String,
+ dataType: DataType,
+ arguments: Seq[Expression] = Nil,
+ methodInputTypes: Seq[AbstractDataType] = Nil,
+ propagateNull: Boolean = true,
+ returnNullable: Boolean = true,
+ isDeterministic: Boolean = true): Invoke =
+ new Invoke(
+ targetObject,
+ functionName,
+ dataType,
+ arguments,
+ methodInputTypes,
+ propagateNull,
+ returnNullable,
+ isDeterministic) with NullIntolerant
+}
+
object NewInstance {
def apply(
cls: Class[_],
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 2452da5d6968..8e8d3a957466 100755
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -747,7 +747,8 @@ case class EndsWith(left: Expression, right: Expression)
extends StringPredicate
case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with
ImplicitCastInputTypes
with UnaryLike[Expression] with NullIntolerant {
- override lazy val replacement: Expression = Invoke(input, "isValid",
BooleanType)
+ override lazy val replacement: Expression =
+ Invoke.withNullIntolerant(input, "isValid", BooleanType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))
@@ -795,7 +796,8 @@ case class IsValidUTF8(input: Expression) extends
RuntimeReplaceable with Implic
case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with
ImplicitCastInputTypes
with UnaryLike[Expression] with NullIntolerant {
- override lazy val replacement: Expression = Invoke(input, "makeValid",
input.dataType)
+ override lazy val replacement: Expression =
+ Invoke.withNullIntolerant(input, "makeValid", input.dataType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))
@@ -836,12 +838,13 @@ case class MakeValidUTF8(input: Expression) extends
RuntimeReplaceable with Impl
case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with
ImplicitCastInputTypes
with UnaryLike[Expression] with NullIntolerant {
- override lazy val replacement: Expression = StaticInvoke(
- classOf[ExpressionImplUtils],
- input.dataType,
- "validateUTF8String",
- Seq(input),
- inputTypes)
+ override lazy val replacement: Expression =
+ StaticInvoke.withNullIntolerant(
+ classOf[ExpressionImplUtils],
+ input.dataType,
+ "validateUTF8String",
+ Seq(input),
+ inputTypes)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))
@@ -886,12 +889,13 @@ case class ValidateUTF8(input: Expression) extends
RuntimeReplaceable with Impli
case class TryValidateUTF8(input: Expression) extends RuntimeReplaceable with
ImplicitCastInputTypes
with UnaryLike[Expression] with NullIntolerant {
- override lazy val replacement: Expression = StaticInvoke(
- classOf[ExpressionImplUtils],
- input.dataType,
- "tryValidateUTF8String",
- Seq(input),
- inputTypes)
+ override lazy val replacement: Expression =
+ StaticInvoke.withNullIntolerant(
+ classOf[ExpressionImplUtils],
+ input.dataType,
+ "tryValidateUTF8String",
+ Seq(input),
+ inputTypes)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 54d9fdbf8c23..2e91d60e4ba0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql
import org.apache.spark.{SPARK_DOC_ROOT, SparkIllegalArgumentException,
SparkRuntimeException}
import org.apache.spark.sql.catalyst.expressions.Cast._
-import org.apache.spark.sql.execution.{FormattedMode, WholeStageCodegenExec}
+import org.apache.spark.sql.catalyst.expressions.IsNotNull
+import org.apache.spark.sql.execution.{FilterExec, FormattedMode,
WholeStageCodegenExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -1424,4 +1425,31 @@ class StringFunctionsSuite extends QueryTest with
SharedSparkSession {
}
}
}
+
+ test("SPARK-50224: The replacement of validate utf8 functions should be
NullIntolerant") {
+ def check(df: DataFrame, expected: Seq[Row]): Unit = {
+ val filter = df.queryExecution
+ .sparkPlan
+ .find(_.isInstanceOf[FilterExec])
+ .get.asInstanceOf[FilterExec]
+ assert(filter.condition.find(_.isInstanceOf[IsNotNull]).nonEmpty)
+ checkAnswer(df, expected)
+ }
+ withTable("test_table") {
+ sql("CREATE TABLE test_table" +
+ " AS SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl'), ('mno',
NULL) T(a, b)")
+ check(
+ sql("SELECT * FROM test_table WHERE is_valid_utf8(b)"),
+ Seq(Row("abc", "def"), Row("ghi", "jkl")))
+ check(
+ sql("SELECT * FROM test_table WHERE make_valid_utf8(b) = 'def'"),
+ Seq(Row("abc", "def")))
+ check(
+ sql("SELECT * FROM test_table WHERE validate_utf8(b) = 'jkl'"),
+ Seq(Row("ghi", "jkl")))
+ check(
+ sql("SELECT * FROM test_table WHERE try_validate_utf8(b) = 'def'"),
+ Seq(Row("abc", "def")))
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]