This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 642a62b6970f [SPARK-50224][SQL] The replacements of 
IsValidUTF8|ValidateUTF8|TryValidateUTF8|MakeValidUTF8 shall be NullIntolerant
642a62b6970f is described below

commit 642a62b6970fb30dba6f8d7d09eae4ea305bdb17
Author: Kent Yao <[email protected]>
AuthorDate: Tue Nov 5 19:57:52 2024 +0800

    [SPARK-50224][SQL] The replacements of 
IsValidUTF8|ValidateUTF8|TryValidateUTF8|MakeValidUTF8 shall be NullIntolerant
    
    ### What changes were proposed in this pull request?
    
    This PR makes the replacement expressions of the 
IsValidUTF8|ValidateUTF8|TryValidateUTF8|MakeValidUTF8 functions 
NullIntolerant, as the original expressions already are, so that we can 
actually construct IsNotNull constraints for them.
    
    This is also a common issue for other RuntimeReplaceable expressions; I 
will revisit them in groups (see SPARK-50223).
    
    ### Why are the changes needed?
    
    Constructing IsNotNull constraints from NullIntolerant expressions is a 
common strategy for performance improvement.
    
    ### Does this PR introduce _any_ user-facing change?
    NO
    
    ### How was this patch tested?
    
    New tests in StringFunctionsSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #48758 from yaooqinn/SPARK-50224.
    
    Authored-by: Kent Yao <[email protected]>
    Signed-off-by: Kent Yao <[email protected]>
---
 .../sql/catalyst/expressions/objects/objects.scala | 43 ++++++++++++++++++++++
 .../catalyst/expressions/stringExpressions.scala   | 32 +++++++++-------
 .../apache/spark/sql/StringFunctionsSuite.scala    | 30 ++++++++++++++-
 3 files changed, 90 insertions(+), 15 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 5c786bc5ddbf..f49fd697492a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -412,6 +412,28 @@ case class StaticInvoke(
     }.$functionName(${arguments.mkString(", ")}))"
 }
 
+object StaticInvoke {
+  def withNullIntolerant(
+      staticObject: Class[_],
+      dataType: DataType,
+      functionName: String,
+      arguments: Seq[Expression] = Nil,
+      inputTypes: Seq[AbstractDataType] = Nil,
+      propagateNull: Boolean = true,
+      returnNullable: Boolean = true,
+      isDeterministic: Boolean = true,
+      scalarFunction: Option[ScalarFunction[_]] = None): StaticInvoke =
+    new StaticInvoke(
+      staticObject,
+      dataType,
+      functionName,
+      arguments,
+      inputTypes,
+      propagateNull,
+      returnNullable,
+      isDeterministic, scalarFunction) with NullIntolerant
+}
+
 /**
  * Calls the specified function on an object, optionally passing arguments.  
If the `targetObject`
  * expression evaluates to null then null will be returned.
@@ -555,6 +577,27 @@ case class Invoke(
     copy(targetObject = newChildren.head, arguments = newChildren.tail)
 }
 
+object Invoke {
+  def withNullIntolerant(
+      targetObject: Expression,
+      functionName: String,
+      dataType: DataType,
+      arguments: Seq[Expression] = Nil,
+      methodInputTypes: Seq[AbstractDataType] = Nil,
+      propagateNull: Boolean = true,
+      returnNullable: Boolean = true,
+      isDeterministic: Boolean = true): Invoke =
+    new Invoke(
+      targetObject,
+      functionName,
+      dataType,
+      arguments,
+      methodInputTypes,
+      propagateNull,
+      returnNullable,
+      isDeterministic) with NullIntolerant
+}
+
 object NewInstance {
   def apply(
       cls: Class[_],
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 2452da5d6968..8e8d3a957466 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -747,7 +747,8 @@ case class EndsWith(left: Expression, right: Expression) 
extends StringPredicate
 case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with 
ImplicitCastInputTypes
   with UnaryLike[Expression] with NullIntolerant {
 
-  override lazy val replacement: Expression = Invoke(input, "isValid", 
BooleanType)
+  override lazy val replacement: Expression =
+    Invoke.withNullIntolerant(input, "isValid", BooleanType)
 
   override def inputTypes: Seq[AbstractDataType] =
     Seq(StringTypeWithCollation(supportsTrimCollation = true))
@@ -795,7 +796,8 @@ case class IsValidUTF8(input: Expression) extends 
RuntimeReplaceable with Implic
 case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with 
ImplicitCastInputTypes
   with UnaryLike[Expression] with NullIntolerant {
 
-  override lazy val replacement: Expression = Invoke(input, "makeValid", 
input.dataType)
+  override lazy val replacement: Expression =
+    Invoke.withNullIntolerant(input, "makeValid", input.dataType)
 
   override def inputTypes: Seq[AbstractDataType] =
     Seq(StringTypeWithCollation(supportsTrimCollation = true))
@@ -836,12 +838,13 @@ case class MakeValidUTF8(input: Expression) extends 
RuntimeReplaceable with Impl
 case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with 
ImplicitCastInputTypes
   with UnaryLike[Expression] with NullIntolerant {
 
-  override lazy val replacement: Expression = StaticInvoke(
-    classOf[ExpressionImplUtils],
-    input.dataType,
-    "validateUTF8String",
-    Seq(input),
-    inputTypes)
+  override lazy val replacement: Expression =
+    StaticInvoke.withNullIntolerant(
+      classOf[ExpressionImplUtils],
+      input.dataType,
+      "validateUTF8String",
+      Seq(input),
+      inputTypes)
 
   override def inputTypes: Seq[AbstractDataType] =
     Seq(StringTypeWithCollation(supportsTrimCollation = true))
@@ -886,12 +889,13 @@ case class ValidateUTF8(input: Expression) extends 
RuntimeReplaceable with Impli
 case class TryValidateUTF8(input: Expression) extends RuntimeReplaceable with 
ImplicitCastInputTypes
   with UnaryLike[Expression] with NullIntolerant {
 
-  override lazy val replacement: Expression = StaticInvoke(
-    classOf[ExpressionImplUtils],
-    input.dataType,
-    "tryValidateUTF8String",
-    Seq(input),
-    inputTypes)
+  override lazy val replacement: Expression =
+    StaticInvoke.withNullIntolerant(
+      classOf[ExpressionImplUtils],
+      input.dataType,
+      "tryValidateUTF8String",
+      Seq(input),
+      inputTypes)
 
   override def inputTypes: Seq[AbstractDataType] =
     Seq(StringTypeWithCollation(supportsTrimCollation = true))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 54d9fdbf8c23..2e91d60e4ba0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql
 
 import org.apache.spark.{SPARK_DOC_ROOT, SparkIllegalArgumentException, 
SparkRuntimeException}
 import org.apache.spark.sql.catalyst.expressions.Cast._
-import org.apache.spark.sql.execution.{FormattedMode, WholeStageCodegenExec}
+import org.apache.spark.sql.catalyst.expressions.IsNotNull
+import org.apache.spark.sql.execution.{FilterExec, FormattedMode, 
WholeStageCodegenExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -1424,4 +1425,31 @@ class StringFunctionsSuite extends QueryTest with 
SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-50224: The replacement of validate utf8 functions should be 
NullIntolerant") {
+    def check(df: DataFrame, expected: Seq[Row]): Unit = {
+      val filter = df.queryExecution
+        .sparkPlan
+        .find(_.isInstanceOf[FilterExec])
+        .get.asInstanceOf[FilterExec]
+      assert(filter.condition.find(_.isInstanceOf[IsNotNull]).nonEmpty)
+      checkAnswer(df, expected)
+    }
+    withTable("test_table") {
+      sql("CREATE TABLE test_table" +
+        " AS SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl'), ('mno', 
NULL) T(a, b)")
+      check(
+        sql("SELECT * FROM test_table WHERE is_valid_utf8(b)"),
+        Seq(Row("abc", "def"), Row("ghi", "jkl")))
+      check(
+        sql("SELECT * FROM test_table WHERE make_valid_utf8(b) = 'def'"),
+        Seq(Row("abc", "def")))
+      check(
+        sql("SELECT * FROM test_table WHERE validate_utf8(b) = 'jkl'"),
+        Seq(Row("ghi", "jkl")))
+      check(
+        sql("SELECT * FROM test_table WHERE try_validate_utf8(b) = 'def'"),
+        Seq(Row("abc", "def")))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to