This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7ef0440ef221 [SPARK-48146][SQL] Fix aggregate function in With expression child assertion 7ef0440ef221 is described below commit 7ef0440ef22161a6160f7b9000c70b26c84eecf7 Author: Kelvin Jiang <kelvin.ji...@databricks.com> AuthorDate: Fri May 10 22:39:15 2024 +0800 [SPARK-48146][SQL] Fix aggregate function in With expression child assertion ### What changes were proposed in this pull request? In https://github.com/apache/spark/pull/46034, there was a complicated edge case where common expression references in aggregate functions in the child of a `With` expression could become dangling. An assertion was added to prevent that case from happening, but the assertion wasn't fully accurate, as a query like: ``` select id between max(if(id between 1 and 2, 2, 1)) over () and id from range(10) ``` would fail the assertion. This PR fixes the assertion to be more accurate. ### Why are the changes needed? This addresses a regression in https://github.com/apache/spark/pull/46034. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46443 from kelvinjian-db/SPARK-48146-agg. 
Authored-by: Kelvin Jiang <kelvin.ji...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/expressions/With.scala | 26 +++++++++++++++++---- .../optimizer/RewriteWithExpressionSuite.scala | 27 +++++++++++++++++++++- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala index 14deedd9c70f..29794b33641c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION} import org.apache.spark.sql.types.DataType /** @@ -27,9 +28,11 @@ import org.apache.spark.sql.types.DataType */ case class With(child: Expression, defs: Seq[CommonExpressionDef]) extends Expression with Unevaluable { - // We do not allow With to be created with an AggregateExpression in the child, as this would - // create a dangling CommonExpressionRef after rewriting it in RewriteWithExpression. - assert(!child.containsPattern(AGGREGATE_EXPRESSION)) + // We do not allow creating a With expression with an AggregateExpression that contains a + // reference to a common expression defined in that scope (note that it can contain another With + // expression with a common expression ref of the inner With). This is to prevent the creation of + // a dangling CommonExpressionRef after rewriting it in RewriteWithExpression. 
+ assert(!With.childContainsUnsupportedAggExpr(this)) override val nodePatterns: Seq[TreePattern] = Seq(WITH_EXPRESSION) override def dataType: DataType = child.dataType @@ -92,6 +95,21 @@ object With { val commonExprRefs = commonExprDefs.map(new CommonExpressionRef(_)) With(replaced(commonExprRefs), commonExprDefs) } + + private[sql] def childContainsUnsupportedAggExpr(withExpr: With): Boolean = { + lazy val commonExprIds = withExpr.defs.map(_.id).toSet + withExpr.child.exists { + case agg: AggregateExpression => + // Check that the aggregate expression does not contain a reference to a common expression + // in the outer With expression (it is ok if it contains a reference to a common expression + // for a nested With expression). + agg.exists { + case r: CommonExpressionRef => commonExprIds.contains(r.id) + case _ => false + } + case _ => false + } + } } case class CommonExpressionId(id: Long = CommonExpressionId.newId, canonicalized: Boolean = false) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala index d482b18d9331..8f023fa4156b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -353,7 +353,7 @@ class RewriteWithExpressionSuite extends PlanTest { ) } - test("aggregate functions in child of WITH expression is not supported") { + test("aggregate functions in child of WITH expression with ref is not supported") { val a = testRelation.output.head intercept[java.lang.AssertionError] { val expr = With(a - 1) { case Seq(ref) => @@ -366,4 +366,29 @@ class RewriteWithExpressionSuite extends PlanTest { Optimizer.execute(plan) } } + + test("WITH expression nested in aggregate function") { + val a = testRelation.output.head + val expr 
= With(a + 1) { case Seq(ref) => + ref * ref + } + val nestedExpr = With(a - 1) { case Seq(ref) => + ref * max(expr) + ref + } + val plan = testRelation.groupBy(a)(nestedExpr.as("col")).analyze + val commonExpr1Id = expr.defs.head.id.id + val commonExpr1Name = s"_common_expr_$commonExpr1Id" + val commonExpr2Id = nestedExpr.defs.head.id.id + val commonExpr2Name = s"_common_expr_$commonExpr2Id" + val aggExprName = "_aggregateexpression" + comparePlans( + Optimizer.execute(plan), + testRelation + .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*) + .groupBy(a)(a, max($"$commonExpr1Name" * $"$commonExpr1Name").as(aggExprName)) + .select($"a", $"$aggExprName", (a - 1).as(commonExpr2Name)) + .select(($"$commonExpr2Name" * $"$aggExprName" + $"$commonExpr2Name").as("col")) + .analyze + ) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org