This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fa3ef03a0734 [SPARK-47463][SQL] Use V2Predicate to wrap expression 
with return type of boolean
fa3ef03a0734 is described below

commit fa3ef03a073407966765544c936a9c65401e955a
Author: Zhen Wang <643348...@qq.com>
AuthorDate: Tue Apr 16 13:41:53 2024 +0800

    [SPARK-47463][SQL] Use V2Predicate to wrap expression with return type of 
boolean
    
    ### What changes were proposed in this pull request?
    
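    Use `V2Predicate` instead of `GeneralScalarExpression` to wrap `If` (and other expressions whose return type is boolean) when building v2 expressions in a predicate context.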
    
    ### Why are the changes needed?
    
    The `PushFoldableIntoBranches` optimizer rule may fold a predicate into
    (if / case) branches, and `V2ExpressionBuilder` wraps `If` as a
    `GeneralScalarExpression`, which causes the assertion in
    `PushablePredicate.unapply` to fail.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a unit test to `DataSourceV2Suite`: "SPARK-47463: Pushed down v2 filter with if expression".
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45589 from wForget/SPARK-47463.
    
    Authored-by: Zhen Wang <643348...@qq.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/util/V2ExpressionBuilder.scala    | 159 +++++++++++----------
 .../spark/sql/connector/DataSourceV2Suite.scala    |  10 ++
 2 files changed, 97 insertions(+), 72 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 3942d193a328..398f21e01b80 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.connector.expressions.{Cast => 
V2Cast, Expression =>
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, 
Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, 
UserDefinedAggregateFunc}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, 
AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.execution.datasources.PushableExpression
-import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType}
+import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, 
StringType}
 
 /**
  * The builder to generate V2 expressions from catalyst expressions.
@@ -98,45 +98,45 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) {
       generateExpression(child).map(v => new V2Cast(v, dataType))
     case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) 
=>
       generateAggregateFunc(aggregateFunction, isDistinct)
-    case Abs(child, true) => generateExpressionWithName("ABS", Seq(child))
-    case Coalesce(children) => generateExpressionWithName("COALESCE", children)
-    case Greatest(children) => generateExpressionWithName("GREATEST", children)
-    case Least(children) => generateExpressionWithName("LEAST", children)
-    case Rand(child, hideSeed) =>
+    case Abs(_, true) => generateExpressionWithName("ABS", expr, isPredicate)
+    case _: Coalesce => generateExpressionWithName("COALESCE", expr, 
isPredicate)
+    case _: Greatest => generateExpressionWithName("GREATEST", expr, 
isPredicate)
+    case _: Least => generateExpressionWithName("LEAST", expr, isPredicate)
+    case Rand(_, hideSeed) =>
       if (hideSeed) {
         Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression]))
       } else {
-        generateExpressionWithName("RAND", Seq(child))
+        generateExpressionWithName("RAND", expr, isPredicate)
       }
-    case log: Logarithm => generateExpressionWithName("LOG", log.children)
-    case Log10(child) => generateExpressionWithName("LOG10", Seq(child))
-    case Log2(child) => generateExpressionWithName("LOG2", Seq(child))
-    case Log(child) => generateExpressionWithName("LN", Seq(child))
-    case Exp(child) => generateExpressionWithName("EXP", Seq(child))
-    case pow: Pow => generateExpressionWithName("POWER", pow.children)
-    case Sqrt(child) => generateExpressionWithName("SQRT", Seq(child))
-    case Floor(child) => generateExpressionWithName("FLOOR", Seq(child))
-    case Ceil(child) => generateExpressionWithName("CEIL", Seq(child))
-    case round: Round => generateExpressionWithName("ROUND", round.children)
-    case Sin(child) => generateExpressionWithName("SIN", Seq(child))
-    case Sinh(child) => generateExpressionWithName("SINH", Seq(child))
-    case Cos(child) => generateExpressionWithName("COS", Seq(child))
-    case Cosh(child) => generateExpressionWithName("COSH", Seq(child))
-    case Tan(child) => generateExpressionWithName("TAN", Seq(child))
-    case Tanh(child) => generateExpressionWithName("TANH", Seq(child))
-    case Cot(child) => generateExpressionWithName("COT", Seq(child))
-    case Asin(child) => generateExpressionWithName("ASIN", Seq(child))
-    case Asinh(child) => generateExpressionWithName("ASINH", Seq(child))
-    case Acos(child) => generateExpressionWithName("ACOS", Seq(child))
-    case Acosh(child) => generateExpressionWithName("ACOSH", Seq(child))
-    case Atan(child) => generateExpressionWithName("ATAN", Seq(child))
-    case Atanh(child) => generateExpressionWithName("ATANH", Seq(child))
-    case atan2: Atan2 => generateExpressionWithName("ATAN2", atan2.children)
-    case Cbrt(child) => generateExpressionWithName("CBRT", Seq(child))
-    case ToDegrees(child) => generateExpressionWithName("DEGREES", Seq(child))
-    case ToRadians(child) => generateExpressionWithName("RADIANS", Seq(child))
-    case Signum(child) => generateExpressionWithName("SIGN", Seq(child))
-    case wb: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", 
wb.children)
+    case _: Logarithm => generateExpressionWithName("LOG", expr, isPredicate)
+    case _: Log10 => generateExpressionWithName("LOG10", expr, isPredicate)
+    case _: Log2 => generateExpressionWithName("LOG2", expr, isPredicate)
+    case _: Log => generateExpressionWithName("LN", expr, isPredicate)
+    case _: Exp => generateExpressionWithName("EXP", expr, isPredicate)
+    case _: Pow => generateExpressionWithName("POWER", expr, isPredicate)
+    case _: Sqrt => generateExpressionWithName("SQRT", expr, isPredicate)
+    case _: Floor => generateExpressionWithName("FLOOR", expr, isPredicate)
+    case _: Ceil => generateExpressionWithName("CEIL", expr, isPredicate)
+    case _: Round => generateExpressionWithName("ROUND", expr, isPredicate)
+    case _: Sin => generateExpressionWithName("SIN", expr, isPredicate)
+    case _: Sinh => generateExpressionWithName("SINH", expr, isPredicate)
+    case _: Cos => generateExpressionWithName("COS", expr, isPredicate)
+    case _: Cosh => generateExpressionWithName("COSH", expr, isPredicate)
+    case _: Tan => generateExpressionWithName("TAN", expr, isPredicate)
+    case _: Tanh => generateExpressionWithName("TANH", expr, isPredicate)
+    case _: Cot => generateExpressionWithName("COT", expr, isPredicate)
+    case _: Asin => generateExpressionWithName("ASIN", expr, isPredicate)
+    case _: Asinh => generateExpressionWithName("ASINH", expr, isPredicate)
+    case _: Acos => generateExpressionWithName("ACOS", expr, isPredicate)
+    case _: Acosh => generateExpressionWithName("ACOSH", expr, isPredicate)
+    case _: Atan => generateExpressionWithName("ATAN", expr, isPredicate)
+    case _: Atanh => generateExpressionWithName("ATANH", expr, isPredicate)
+    case _: Atan2 => generateExpressionWithName("ATAN2", expr, isPredicate)
+    case _: Cbrt => generateExpressionWithName("CBRT", expr, isPredicate)
+    case _: ToDegrees => generateExpressionWithName("DEGREES", expr, 
isPredicate)
+    case _: ToRadians => generateExpressionWithName("RADIANS", expr, 
isPredicate)
+    case _: Signum => generateExpressionWithName("SIGN", expr, isPredicate)
+    case _: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", expr, 
isPredicate)
     case and: And =>
       // AND expects predicate
       val l = generateExpression(and.left, true)
@@ -187,57 +187,56 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) {
         assert(v.isInstanceOf[V2Predicate])
         new V2Not(v.asInstanceOf[V2Predicate])
       }
-    case UnaryMinus(child, true) => generateExpressionWithName("-", Seq(child))
-    case BitwiseNot(child) => generateExpressionWithName("~", Seq(child))
-    case CaseWhen(branches, elseValue) =>
+    case UnaryMinus(_, true) => generateExpressionWithName("-", expr, 
isPredicate)
+    case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate)
+    case caseWhen @ CaseWhen(branches, elseValue) =>
       val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
-      val values = branches.map(_._2).flatMap(generateExpression(_, true))
-      if (conditions.length == branches.length && values.length == 
branches.length) {
+      val values = branches.map(_._2).flatMap(generateExpression(_))
+      val elseExprOpt = elseValue.flatMap(generateExpression(_))
+      if (conditions.length == branches.length && values.length == 
branches.length &&
+          elseExprOpt.size == elseValue.size) {
         val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
           Seq[V2Expression](c, v)
         }
-        if (elseValue.isDefined) {
-          elseValue.flatMap(generateExpression(_)).map { v =>
-            val children = (branchExpressions :+ v).toArray[V2Expression]
-            // The children looks like [condition1, value1, ..., conditionN, 
valueN, elseValue]
-            new V2Predicate("CASE_WHEN", children)
-          }
+        val children = (branchExpressions ++ elseExprOpt).toArray[V2Expression]
+        // The children looks like [condition1, value1, ..., conditionN, 
valueN (, elseValue)]
+        if (isPredicate && caseWhen.dataType.isInstanceOf[BooleanType]) {
+          Some(new V2Predicate("CASE_WHEN", children))
         } else {
-          // The children looks like [condition1, value1, ..., conditionN, 
valueN]
-          Some(new V2Predicate("CASE_WHEN", 
branchExpressions.toArray[V2Expression]))
+          Some(new GeneralScalarExpression("CASE_WHEN", children))
         }
       } else {
         None
       }
-    case iff: If => generateExpressionWithName("CASE_WHEN", iff.children)
+    case _: If => generateExpressionWithName("CASE_WHEN", expr, isPredicate)
     case substring: Substring =>
       val children = if (substring.len == Literal(Integer.MAX_VALUE)) {
         Seq(substring.str, substring.pos)
       } else {
         substring.children
       }
-      generateExpressionWithName("SUBSTRING", children)
-    case Upper(child) => generateExpressionWithName("UPPER", Seq(child))
-    case Lower(child) => generateExpressionWithName("LOWER", Seq(child))
+      generateExpressionWithNameByChildren("SUBSTRING", children, 
substring.dataType, isPredicate)
+    case _: Upper => generateExpressionWithName("UPPER", expr, isPredicate)
+    case _: Lower => generateExpressionWithName("LOWER", expr, isPredicate)
     case BitLength(child) if child.dataType.isInstanceOf[StringType] =>
-      generateExpressionWithName("BIT_LENGTH", Seq(child))
+      generateExpressionWithName("BIT_LENGTH", expr, isPredicate)
     case Length(child) if child.dataType.isInstanceOf[StringType] =>
-      generateExpressionWithName("CHAR_LENGTH", Seq(child))
-    case concat: Concat => generateExpressionWithName("CONCAT", 
concat.children)
-    case translate: StringTranslate => generateExpressionWithName("TRANSLATE", 
translate.children)
-    case trim: StringTrim => generateExpressionWithName("TRIM", trim.children)
-    case trim: StringTrimLeft => generateExpressionWithName("LTRIM", 
trim.children)
-    case trim: StringTrimRight => generateExpressionWithName("RTRIM", 
trim.children)
+      generateExpressionWithName("CHAR_LENGTH", expr, isPredicate)
+    case _: Concat => generateExpressionWithName("CONCAT", expr, isPredicate)
+    case _: StringTranslate => generateExpressionWithName("TRANSLATE", expr, 
isPredicate)
+    case _: StringTrim => generateExpressionWithName("TRIM", expr, isPredicate)
+    case _: StringTrimLeft => generateExpressionWithName("LTRIM", expr, 
isPredicate)
+    case _: StringTrimRight => generateExpressionWithName("RTRIM", expr, 
isPredicate)
     case overlay: Overlay =>
       val children = if (overlay.len == Literal(-1)) {
         Seq(overlay.input, overlay.replace, overlay.pos)
       } else {
         overlay.children
       }
-      generateExpressionWithName("OVERLAY", children)
-    case date: DateAdd => generateExpressionWithName("DATE_ADD", date.children)
-    case date: DateDiff => generateExpressionWithName("DATE_DIFF", 
date.children)
-    case date: TruncDate => generateExpressionWithName("TRUNC", date.children)
+      generateExpressionWithNameByChildren("OVERLAY", children, 
overlay.dataType, isPredicate)
+    case _: DateAdd => generateExpressionWithName("DATE_ADD", expr, 
isPredicate)
+    case _: DateDiff => generateExpressionWithName("DATE_DIFF", expr, 
isPredicate)
+    case _: TruncDate => generateExpressionWithName("TRUNC", expr, isPredicate)
     case Second(child, _) =>
       generateExpression(child).map(v => new V2Extract("SECOND", v))
     case Minute(child, _) =>
@@ -270,12 +269,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) {
       generateExpression(child).map(v => new V2Extract("WEEK", v))
     case YearOfWeek(child) =>
       generateExpression(child).map(v => new V2Extract("YEAR_OF_WEEK", v))
-    case encrypt: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", 
encrypt.children)
-    case decrypt: AesDecrypt => generateExpressionWithName("AES_DECRYPT", 
decrypt.children)
-    case Crc32(child) => generateExpressionWithName("CRC32", Seq(child))
-    case Md5(child) => generateExpressionWithName("MD5", Seq(child))
-    case Sha1(child) => generateExpressionWithName("SHA1", Seq(child))
-    case sha2: Sha2 => generateExpressionWithName("SHA2", sha2.children)
+    case _: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", expr, 
isPredicate)
+    case _: AesDecrypt => generateExpressionWithName("AES_DECRYPT", expr, 
isPredicate)
+    case _: Crc32 => generateExpressionWithName("CRC32", expr, isPredicate)
+    case _: Md5 => generateExpressionWithName("MD5", expr, isPredicate)
+    case _: Sha1 => generateExpressionWithName("SHA1", expr, isPredicate)
+    case _: Sha2 => generateExpressionWithName("SHA2", expr, isPredicate)
     // TODO supports other expressions
     case ApplyFunctionExpression(function, children) =>
       val childrenExpressions = children.flatMap(generateExpression(_))
@@ -380,10 +379,26 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) {
   }
 
   private def generateExpressionWithName(
-      v2ExpressionName: String, children: Seq[Expression]): 
Option[V2Expression] = {
+      v2ExpressionName: String,
+      expr: Expression,
+      isPredicate: Boolean): Option[V2Expression] = {
+    generateExpressionWithNameByChildren(
+      v2ExpressionName, expr.children, expr.dataType, isPredicate)
+  }
+
+  private def generateExpressionWithNameByChildren(
+      v2ExpressionName: String,
+      children: Seq[Expression],
+      dataType: DataType,
+      isPredicate: Boolean): Option[V2Expression] = {
     val childrenExpressions = children.flatMap(generateExpression(_))
     if (childrenExpressions.length == children.length) {
-      Some(new GeneralScalarExpression(v2ExpressionName, 
childrenExpressions.toArray[V2Expression]))
+      if (isPredicate && dataType.isInstanceOf[BooleanType]) {
+        Some(new V2Predicate(v2ExpressionName, 
childrenExpressions.toArray[V2Expression]))
+      } else {
+        Some(new GeneralScalarExpression(
+          v2ExpressionName, childrenExpressions.toArray[V2Expression]))
+      }
     } else {
       None
     }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index a7fb2c054e80..1de535df246b 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -966,6 +966,16 @@ class DataSourceV2Suite extends QueryTest with 
SharedSparkSession with AdaptiveS
       )
     }
   }
+
+  test("SPARK-47463: Pushed down v2 filter with if expression") {
+    withTempView("t1") {
+      
spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load()
+        .createTempView("t1")
+      val df = sql("SELECT * FROM  t1 WHERE if(i = 1, i, 0) > 0")
+      val result = df.collect()
+      assert(result.length == 1)
+    }
+  }
 }
 
 case class RangeInputPartition(start: Int, end: Int) extends InputPartition


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to