This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new fa3ef03a0734 [SPARK-47463][SQL] Use V2Predicate to wrap expression with return type of boolean fa3ef03a0734 is described below commit fa3ef03a073407966765544c936a9c65401e955a Author: Zhen Wang <643348...@qq.com> AuthorDate: Tue Apr 16 13:41:53 2024 +0800 [SPARK-47463][SQL] Use V2Predicate to wrap expression with return type of boolean ### What changes were proposed in this pull request? Use `V2Predicate` instead of `GeneralScalarExpression` to wrap `If` expressions (and other expressions whose return type is boolean) when `V2ExpressionBuilder` is building a predicate. ### Why are the changes needed? The `PushFoldableIntoBranches` optimizer rule may fold a predicate into the branches of an `If` or `CaseWhen` expression, so the resulting boolean-typed `If` is itself pushed down as a filter. However, `V2ExpressionBuilder` wraps `If` as a `GeneralScalarExpression`, which causes the assertion in `PushablePredicate.unapply` to fail. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added a unit test (`SPARK-47463: Pushed down v2 filter with if expression` in `DataSourceV2Suite`). ### Was this patch authored or co-authored using generative AI tooling? No Closes #45589 from wForget/SPARK-47463. Authored-by: Zhen Wang <643348...@qq.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/util/V2ExpressionBuilder.scala | 159 +++++++++++---------- .../spark/sql/connector/DataSourceV2Suite.scala | 10 ++ 2 files changed, 97 insertions(+), 72 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 3942d193a328..398f21e01b80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} import 
org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} import org.apache.spark.sql.execution.datasources.PushableExpression -import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType} +import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, StringType} /** * The builder to generate V2 expressions from catalyst expressions. @@ -98,45 +98,45 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { generateExpression(child).map(v => new V2Cast(v, dataType)) case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) => generateAggregateFunc(aggregateFunction, isDistinct) - case Abs(child, true) => generateExpressionWithName("ABS", Seq(child)) - case Coalesce(children) => generateExpressionWithName("COALESCE", children) - case Greatest(children) => generateExpressionWithName("GREATEST", children) - case Least(children) => generateExpressionWithName("LEAST", children) - case Rand(child, hideSeed) => + case Abs(_, true) => generateExpressionWithName("ABS", expr, isPredicate) + case _: Coalesce => generateExpressionWithName("COALESCE", expr, isPredicate) + case _: Greatest => generateExpressionWithName("GREATEST", expr, isPredicate) + case _: Least => generateExpressionWithName("LEAST", expr, isPredicate) + case Rand(_, hideSeed) => if (hideSeed) { Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression])) } else { - generateExpressionWithName("RAND", Seq(child)) + generateExpressionWithName("RAND", expr, isPredicate) } - case log: Logarithm => generateExpressionWithName("LOG", log.children) - case Log10(child) => generateExpressionWithName("LOG10", Seq(child)) - case Log2(child) => generateExpressionWithName("LOG2", Seq(child)) - case Log(child) => generateExpressionWithName("LN", Seq(child)) - case Exp(child) => generateExpressionWithName("EXP", Seq(child)) - case pow: Pow => generateExpressionWithName("POWER", 
pow.children) - case Sqrt(child) => generateExpressionWithName("SQRT", Seq(child)) - case Floor(child) => generateExpressionWithName("FLOOR", Seq(child)) - case Ceil(child) => generateExpressionWithName("CEIL", Seq(child)) - case round: Round => generateExpressionWithName("ROUND", round.children) - case Sin(child) => generateExpressionWithName("SIN", Seq(child)) - case Sinh(child) => generateExpressionWithName("SINH", Seq(child)) - case Cos(child) => generateExpressionWithName("COS", Seq(child)) - case Cosh(child) => generateExpressionWithName("COSH", Seq(child)) - case Tan(child) => generateExpressionWithName("TAN", Seq(child)) - case Tanh(child) => generateExpressionWithName("TANH", Seq(child)) - case Cot(child) => generateExpressionWithName("COT", Seq(child)) - case Asin(child) => generateExpressionWithName("ASIN", Seq(child)) - case Asinh(child) => generateExpressionWithName("ASINH", Seq(child)) - case Acos(child) => generateExpressionWithName("ACOS", Seq(child)) - case Acosh(child) => generateExpressionWithName("ACOSH", Seq(child)) - case Atan(child) => generateExpressionWithName("ATAN", Seq(child)) - case Atanh(child) => generateExpressionWithName("ATANH", Seq(child)) - case atan2: Atan2 => generateExpressionWithName("ATAN2", atan2.children) - case Cbrt(child) => generateExpressionWithName("CBRT", Seq(child)) - case ToDegrees(child) => generateExpressionWithName("DEGREES", Seq(child)) - case ToRadians(child) => generateExpressionWithName("RADIANS", Seq(child)) - case Signum(child) => generateExpressionWithName("SIGN", Seq(child)) - case wb: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", wb.children) + case _: Logarithm => generateExpressionWithName("LOG", expr, isPredicate) + case _: Log10 => generateExpressionWithName("LOG10", expr, isPredicate) + case _: Log2 => generateExpressionWithName("LOG2", expr, isPredicate) + case _: Log => generateExpressionWithName("LN", expr, isPredicate) + case _: Exp => generateExpressionWithName("EXP", expr, 
isPredicate) + case _: Pow => generateExpressionWithName("POWER", expr, isPredicate) + case _: Sqrt => generateExpressionWithName("SQRT", expr, isPredicate) + case _: Floor => generateExpressionWithName("FLOOR", expr, isPredicate) + case _: Ceil => generateExpressionWithName("CEIL", expr, isPredicate) + case _: Round => generateExpressionWithName("ROUND", expr, isPredicate) + case _: Sin => generateExpressionWithName("SIN", expr, isPredicate) + case _: Sinh => generateExpressionWithName("SINH", expr, isPredicate) + case _: Cos => generateExpressionWithName("COS", expr, isPredicate) + case _: Cosh => generateExpressionWithName("COSH", expr, isPredicate) + case _: Tan => generateExpressionWithName("TAN", expr, isPredicate) + case _: Tanh => generateExpressionWithName("TANH", expr, isPredicate) + case _: Cot => generateExpressionWithName("COT", expr, isPredicate) + case _: Asin => generateExpressionWithName("ASIN", expr, isPredicate) + case _: Asinh => generateExpressionWithName("ASINH", expr, isPredicate) + case _: Acos => generateExpressionWithName("ACOS", expr, isPredicate) + case _: Acosh => generateExpressionWithName("ACOSH", expr, isPredicate) + case _: Atan => generateExpressionWithName("ATAN", expr, isPredicate) + case _: Atanh => generateExpressionWithName("ATANH", expr, isPredicate) + case _: Atan2 => generateExpressionWithName("ATAN2", expr, isPredicate) + case _: Cbrt => generateExpressionWithName("CBRT", expr, isPredicate) + case _: ToDegrees => generateExpressionWithName("DEGREES", expr, isPredicate) + case _: ToRadians => generateExpressionWithName("RADIANS", expr, isPredicate) + case _: Signum => generateExpressionWithName("SIGN", expr, isPredicate) + case _: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", expr, isPredicate) case and: And => // AND expects predicate val l = generateExpression(and.left, true) @@ -187,57 +187,56 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { assert(v.isInstanceOf[V2Predicate]) new 
V2Not(v.asInstanceOf[V2Predicate]) } - case UnaryMinus(child, true) => generateExpressionWithName("-", Seq(child)) - case BitwiseNot(child) => generateExpressionWithName("~", Seq(child)) - case CaseWhen(branches, elseValue) => + case UnaryMinus(_, true) => generateExpressionWithName("-", expr, isPredicate) + case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate) + case caseWhen @ CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) - val values = branches.map(_._2).flatMap(generateExpression(_, true)) - if (conditions.length == branches.length && values.length == branches.length) { + val values = branches.map(_._2).flatMap(generateExpression(_)) + val elseExprOpt = elseValue.flatMap(generateExpression(_)) + if (conditions.length == branches.length && values.length == branches.length && + elseExprOpt.size == elseValue.size) { val branchExpressions = conditions.zip(values).flatMap { case (c, v) => Seq[V2Expression](c, v) } - if (elseValue.isDefined) { - elseValue.flatMap(generateExpression(_)).map { v => - val children = (branchExpressions :+ v).toArray[V2Expression] - // The children looks like [condition1, value1, ..., conditionN, valueN, elseValue] - new V2Predicate("CASE_WHEN", children) - } + val children = (branchExpressions ++ elseExprOpt).toArray[V2Expression] + // The children looks like [condition1, value1, ..., conditionN, valueN (, elseValue)] + if (isPredicate && caseWhen.dataType.isInstanceOf[BooleanType]) { + Some(new V2Predicate("CASE_WHEN", children)) } else { - // The children looks like [condition1, value1, ..., conditionN, valueN] - Some(new V2Predicate("CASE_WHEN", branchExpressions.toArray[V2Expression])) + Some(new GeneralScalarExpression("CASE_WHEN", children)) } } else { None } - case iff: If => generateExpressionWithName("CASE_WHEN", iff.children) + case _: If => generateExpressionWithName("CASE_WHEN", expr, isPredicate) case substring: Substring => val children = if 
(substring.len == Literal(Integer.MAX_VALUE)) { Seq(substring.str, substring.pos) } else { substring.children } - generateExpressionWithName("SUBSTRING", children) - case Upper(child) => generateExpressionWithName("UPPER", Seq(child)) - case Lower(child) => generateExpressionWithName("LOWER", Seq(child)) + generateExpressionWithNameByChildren("SUBSTRING", children, substring.dataType, isPredicate) + case _: Upper => generateExpressionWithName("UPPER", expr, isPredicate) + case _: Lower => generateExpressionWithName("LOWER", expr, isPredicate) case BitLength(child) if child.dataType.isInstanceOf[StringType] => - generateExpressionWithName("BIT_LENGTH", Seq(child)) + generateExpressionWithName("BIT_LENGTH", expr, isPredicate) case Length(child) if child.dataType.isInstanceOf[StringType] => - generateExpressionWithName("CHAR_LENGTH", Seq(child)) - case concat: Concat => generateExpressionWithName("CONCAT", concat.children) - case translate: StringTranslate => generateExpressionWithName("TRANSLATE", translate.children) - case trim: StringTrim => generateExpressionWithName("TRIM", trim.children) - case trim: StringTrimLeft => generateExpressionWithName("LTRIM", trim.children) - case trim: StringTrimRight => generateExpressionWithName("RTRIM", trim.children) + generateExpressionWithName("CHAR_LENGTH", expr, isPredicate) + case _: Concat => generateExpressionWithName("CONCAT", expr, isPredicate) + case _: StringTranslate => generateExpressionWithName("TRANSLATE", expr, isPredicate) + case _: StringTrim => generateExpressionWithName("TRIM", expr, isPredicate) + case _: StringTrimLeft => generateExpressionWithName("LTRIM", expr, isPredicate) + case _: StringTrimRight => generateExpressionWithName("RTRIM", expr, isPredicate) case overlay: Overlay => val children = if (overlay.len == Literal(-1)) { Seq(overlay.input, overlay.replace, overlay.pos) } else { overlay.children } - generateExpressionWithName("OVERLAY", children) - case date: DateAdd => 
generateExpressionWithName("DATE_ADD", date.children) - case date: DateDiff => generateExpressionWithName("DATE_DIFF", date.children) - case date: TruncDate => generateExpressionWithName("TRUNC", date.children) + generateExpressionWithNameByChildren("OVERLAY", children, overlay.dataType, isPredicate) + case _: DateAdd => generateExpressionWithName("DATE_ADD", expr, isPredicate) + case _: DateDiff => generateExpressionWithName("DATE_DIFF", expr, isPredicate) + case _: TruncDate => generateExpressionWithName("TRUNC", expr, isPredicate) case Second(child, _) => generateExpression(child).map(v => new V2Extract("SECOND", v)) case Minute(child, _) => @@ -270,12 +269,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { generateExpression(child).map(v => new V2Extract("WEEK", v)) case YearOfWeek(child) => generateExpression(child).map(v => new V2Extract("YEAR_OF_WEEK", v)) - case encrypt: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", encrypt.children) - case decrypt: AesDecrypt => generateExpressionWithName("AES_DECRYPT", decrypt.children) - case Crc32(child) => generateExpressionWithName("CRC32", Seq(child)) - case Md5(child) => generateExpressionWithName("MD5", Seq(child)) - case Sha1(child) => generateExpressionWithName("SHA1", Seq(child)) - case sha2: Sha2 => generateExpressionWithName("SHA2", sha2.children) + case _: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", expr, isPredicate) + case _: AesDecrypt => generateExpressionWithName("AES_DECRYPT", expr, isPredicate) + case _: Crc32 => generateExpressionWithName("CRC32", expr, isPredicate) + case _: Md5 => generateExpressionWithName("MD5", expr, isPredicate) + case _: Sha1 => generateExpressionWithName("SHA1", expr, isPredicate) + case _: Sha2 => generateExpressionWithName("SHA2", expr, isPredicate) // TODO supports other expressions case ApplyFunctionExpression(function, children) => val childrenExpressions = children.flatMap(generateExpression(_)) @@ -380,10 +379,26 @@ 
class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } private def generateExpressionWithName( - v2ExpressionName: String, children: Seq[Expression]): Option[V2Expression] = { + v2ExpressionName: String, + expr: Expression, + isPredicate: Boolean): Option[V2Expression] = { + generateExpressionWithNameByChildren( + v2ExpressionName, expr.children, expr.dataType, isPredicate) + } + + private def generateExpressionWithNameByChildren( + v2ExpressionName: String, + children: Seq[Expression], + dataType: DataType, + isPredicate: Boolean): Option[V2Expression] = { val childrenExpressions = children.flatMap(generateExpression(_)) if (childrenExpressions.length == children.length) { - Some(new GeneralScalarExpression(v2ExpressionName, childrenExpressions.toArray[V2Expression])) + if (isPredicate && dataType.isInstanceOf[BooleanType]) { + Some(new V2Predicate(v2ExpressionName, childrenExpressions.toArray[V2Expression])) + } else { + Some(new GeneralScalarExpression( + v2ExpressionName, childrenExpressions.toArray[V2Expression])) + } } else { None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index a7fb2c054e80..1de535df246b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -966,6 +966,16 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS ) } } + + test("SPARK-47463: Pushed down v2 filter with if expression") { + withTempView("t1") { + spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load() + .createTempView("t1") + val df = sql("SELECT * FROM t1 WHERE if(i = 1, i, 0) > 0") + val result = df.collect() + assert(result.length == 1) + } + } } case class RangeInputPartition(start: Int, end: Int) extends InputPartition 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org