This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new d0605bf3bf7b [SPARK-47001][SQL] Pushdown verification in optimizer d0605bf3bf7b is described below commit d0605bf3bf7baf4e00924923cee70f729f3aa635 Author: Holden Karau <hka...@netflix.com> AuthorDate: Thu Apr 11 10:38:39 2024 +0800 [SPARK-47001][SQL] Pushdown verification in optimizer ### What changes were proposed in this pull request? Changes how we evaluate candidate elements for filter pushdown past unions. ### Why are the changes needed? Union's type promotion combined with a reference to the head child dataframe can result in errors. ### Does this PR introduce _any_ user-facing change? Yes: slightly more filters will be pushed down (these would have previously thrown an exception). ### How was this patch tested? New test added. ### Was this patch authored or co-authored using generative AI tooling? No Closes #45146 from holdenk/SPARK-47001-pushdown-verification-in-optimizer. Authored-by: Holden Karau <hka...@netflix.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 10 +++++++-- .../catalyst/optimizer/FilterPushdownSuite.scala | 26 +++++++++++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3a4002127df1..cacde9f5a712 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1824,22 +1824,28 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe if (pushDown.nonEmpty) { val pushDownCond = pushDown.reduceLeft(And) + // The union is the child of the filter so its children are grandchildren. 
+ // Moves filters down to the grandchild if there is an element in the grandchild's + // output which is semantically equal to the filter being evaluated. val output = union.output val newGrandChildren = union.children.map { grandchild => val newCond = pushDownCond transform { - case e if output.exists(_.semanticEquals(e)) => - grandchild.output(output.indexWhere(_.semanticEquals(e))) + case a: Attribute if output.exists(_.exprId == a.exprId) => + grandchild.output(output.indexWhere(_.exprId == a.exprId)) } assert(newCond.references.subsetOf(grandchild.outputSet)) Filter(newCond, grandchild) } val newUnion = union.withNewChildren(newGrandChildren) if (stayUp.nonEmpty) { + // If there are any filters we can't push, evaluate them post-union. Filter(stayUp.reduceLeft(And), newUnion) } else { + // If we pushed all filters then just return the new union. newUnion } } else { + // If we can't push anything just return the initial filter. filter } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index bd2ac28a049f..03e65412d166 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval class FilterPushdownSuite extends PlanTest { @@ -882,6 +882,30 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("union filter pushdown w/reference to grand-child field") { + val 
nonNullableArray = StructField("a", ArrayType(IntegerType, false)) + val bField = StructField("b", IntegerType) + val testRelationNonNull = LocalRelation(nonNullableArray, bField) + val testRelationNull = LocalRelation($"c".array(IntegerType), $"d".int) + + val nonNullArrayRef = AttributeReference("a", ArrayType(IntegerType, false))( + testRelationNonNull.output(0).exprId, List()) + + + val originalQuery = Union(Seq(testRelationNonNull, testRelationNull)) + .where(IsNotNull(nonNullArrayRef)) + + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Union(Seq( + testRelationNonNull.where(IsNotNull($"a")), + testRelationNull.where(IsNotNull($"c")))) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("expand") { val agg = testRelation .groupBy(Cube(Seq(Seq($"a"), Seq($"b"))))($"a", $"b", sum($"c")) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org