This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 42f2132d1fc9 [SPARK-48206][SQL][TESTS] Add tests for window rewrites with RewriteWithExpression
42f2132d1fc9 is described below

commit 42f2132d1fc99bf2ec5bd398d21dcbdbd5cbde47
Author: Kelvin Jiang <kelvin.ji...@databricks.com>
AuthorDate: Mon May 13 22:28:27 2024 +0800

    [SPARK-48206][SQL][TESTS] Add tests for window rewrites with RewriteWithExpression

    ### What changes were proposed in this pull request?

    This PR adds more testing for `RewriteWithExpression` around `Window` operators.

    ### Why are the changes needed?

    Adds more testing for `RewriteWithExpression`, which can be fragile around `WindowExpressions`.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Unit tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #46492 from kelvinjian-db/SPARK-48206-window.

    Authored-by: Kelvin Jiang <kelvin.ji...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../optimizer/RewriteWithExpressionSuite.scala | 223 +++++++++++++--------
 1 file changed, 135 insertions(+), 88 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
index 8f023fa4156b..aa8ffb2b0454 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.IntegerType

 class RewriteWithExpressionSuite extends PlanTest {
@@ -37,6 +36,20
@@ class RewriteWithExpressionSuite extends PlanTest { private val testRelation = LocalRelation($"a".int, $"b".int) private val testRelation2 = LocalRelation($"x".int, $"y".int) + private def normalizeCommonExpressionIds(plan: LogicalPlan): LogicalPlan = { + plan.transformAllExpressions { + case a: Alias if a.name.startsWith("_common_expr") => + a.withName("_common_expr_0") + case a: AttributeReference if a.name.startsWith("_common_expr") => + a.withName("_common_expr_0") + } + } + + override def comparePlans( + plan1: LogicalPlan, plan2: LogicalPlan, checkAnalysis: Boolean = true): Unit = { + super.comparePlans(normalizeCommonExpressionIds(plan1), normalizeCommonExpressionIds(plan2)) + } + test("simple common expression") { val a = testRelation.output.head val expr = With(a) { case Seq(ref) => @@ -52,65 +65,48 @@ class RewriteWithExpressionSuite extends PlanTest { ref * ref } val plan = testRelation.select(expr.as("col")) - val commonExprId = expr.defs.head.id.id - val commonExprName = s"_common_expr_$commonExprId" comparePlans( Optimizer.execute(plan), testRelation - .select((testRelation.output :+ (a + a).as(commonExprName)): _*) - .select(($"$commonExprName" * $"$commonExprName").as("col")) + .select((testRelation.output :+ (a + a).as("_common_expr_0")): _*) + .select(($"_common_expr_0" * $"_common_expr_0").as("col")) .analyze ) } test("nested WITH expression in the definition expression") { - val a = testRelation.output.head + val Seq(a, b) = testRelation.output val innerExpr = With(a + a) { case Seq(ref) => ref + ref } - val innerCommonExprId = innerExpr.defs.head.id.id - val innerCommonExprName = s"_common_expr_$innerCommonExprId" - - val b = testRelation.output.last val outerExpr = With(innerExpr + b) { case Seq(ref) => ref * ref } - val outerCommonExprId = outerExpr.defs.head.id.id - val outerCommonExprName = s"_common_expr_$outerCommonExprId" val plan = testRelation.select(outerExpr.as("col")) - val rewrittenOuterExpr = ($"$innerCommonExprName" + 
$"$innerCommonExprName" + b) - .as(outerCommonExprName) - val outerExprAttr = AttributeReference(outerCommonExprName, IntegerType)( - exprId = rewrittenOuterExpr.exprId) comparePlans( Optimizer.execute(plan), testRelation - .select((testRelation.output :+ (a + a).as(innerCommonExprName)): _*) - .select((testRelation.output :+ $"$innerCommonExprName" :+ rewrittenOuterExpr): _*) - .select((outerExprAttr * outerExprAttr).as("col")) + .select((testRelation.output :+ (a + a).as("_common_expr_0")): _*) + .select((testRelation.output ++ Seq($"_common_expr_0", + ($"_common_expr_0" + $"_common_expr_0" + b).as("_common_expr_1"))): _*) + .select(($"_common_expr_1" * $"_common_expr_1").as("col")) .analyze ) } test("nested WITH expression in the main expression") { - val a = testRelation.output.head + val Seq(a, b) = testRelation.output val innerExpr = With(a + a) { case Seq(ref) => ref + ref } - val innerCommonExprId = innerExpr.defs.head.id.id - val innerCommonExprName = s"_common_expr_$innerCommonExprId" - - val b = testRelation.output.last val outerExpr = With(b + b) { case Seq(ref) => ref * ref + innerExpr } - val outerCommonExprId = outerExpr.defs.head.id.id - val outerCommonExprName = s"_common_expr_$outerCommonExprId" val plan = testRelation.select(outerExpr.as("col")) - val rewrittenInnerExpr = (a + a).as(innerCommonExprName) - val rewrittenOuterExpr = (b + b).as(outerCommonExprName) + val rewrittenInnerExpr = (a + a).as("_common_expr_0") + val rewrittenOuterExpr = (b + b).as("_common_expr_1") val finalExpr = rewrittenOuterExpr.toAttribute * rewrittenOuterExpr.toAttribute + (rewrittenInnerExpr.toAttribute + rewrittenInnerExpr.toAttribute) comparePlans( @@ -124,11 +120,10 @@ class RewriteWithExpressionSuite extends PlanTest { } test("correlated nested WITH expression is not supported") { - val b = testRelation.output.last + val Seq(a, b) = testRelation.output val outerCommonExprDef = CommonExpressionDef(b + b, CommonExpressionId(0)) val outerRef = new 
CommonExpressionRef(outerCommonExprDef) - val a = testRelation.output.head // The inner expression definition references the outer expression val commonExprDef1 = CommonExpressionDef(a + a + outerRef, CommonExpressionId(1)) val ref1 = new CommonExpressionRef(commonExprDef1) @@ -152,13 +147,11 @@ class RewriteWithExpressionSuite extends PlanTest { ref < 10 && ref > 0 } val plan = testRelation.where(condition) - val commonExprId = condition.defs.head.id.id - val commonExprName = s"_common_expr_$commonExprId" comparePlans( Optimizer.execute(plan), testRelation - .select((testRelation.output :+ (a + a).as(commonExprName)): _*) - .where($"$commonExprName" < 10 && $"$commonExprName" > 0) + .select((testRelation.output :+ (a + a).as("_common_expr_0")): _*) + .where($"_common_expr_0" < 10 && $"_common_expr_0" > 0) .select(testRelation.output: _*) .analyze ) @@ -170,13 +163,11 @@ class RewriteWithExpressionSuite extends PlanTest { ref < 10 && ref > 0 } val plan = testRelation.join(testRelation2, condition = Some(condition)) - val commonExprId = condition.defs.head.id.id - val commonExprName = s"_common_expr_$commonExprId" comparePlans( Optimizer.execute(plan), testRelation - .select((testRelation.output :+ (a + a).as(commonExprName)): _*) - .join(testRelation2, condition = Some($"$commonExprName" < 10 && $"$commonExprName" > 0)) + .select((testRelation.output :+ (a + a).as("_common_expr_0")): _*) + .join(testRelation2, condition = Some($"_common_expr_0" < 10 && $"_common_expr_0" > 0)) .select((testRelation.output ++ testRelation2.output): _*) .analyze ) @@ -188,14 +179,12 @@ class RewriteWithExpressionSuite extends PlanTest { ref < 10 && ref > 0 } val plan = testRelation.join(testRelation2, condition = Some(condition)) - val commonExprId = condition.defs.head.id.id - val commonExprName = s"_common_expr_$commonExprId" comparePlans( Optimizer.execute(plan), testRelation .join( - testRelation2.select((testRelation2.output :+ (x + x).as(commonExprName)): _*), - condition = 
Some($"$commonExprName" < 10 && $"$commonExprName" > 0) + testRelation2.select((testRelation2.output :+ (x + x).as("_common_expr_0")): _*), + condition = Some($"_common_expr_0" < 10 && $"_common_expr_0" > 0) ) .select((testRelation.output ++ testRelation2.output): _*) .analyze @@ -234,14 +223,12 @@ class RewriteWithExpressionSuite extends PlanTest { ref * ref }, a)) val plan2 = testRelation.select(expr2.as("col")) - val commonExprId = expr2.children.head.asInstanceOf[With].defs.head.id.id - val commonExprName = s"_common_expr_$commonExprId" // With in the always-evaluated branches can still be optimized. comparePlans( Optimizer.execute(plan2), testRelation - .select((testRelation.output :+ (a + a).as(commonExprName)): _*) - .select(Coalesce(Seq(($"$commonExprName" * $"$commonExprName"), a)).as("col")) + .select((testRelation.output :+ (a + a).as("_common_expr_0")): _*) + .select(Coalesce(Seq(($"_common_expr_0" * $"_common_expr_0"), a)).as("col")) .analyze ) } @@ -261,38 +248,32 @@ class RewriteWithExpressionSuite extends PlanTest { (expr2 + 2).as("col1"), count(expr3 - 3).as("col2") ) - val commonExpr1Id = expr1.defs.head.id.id - val commonExpr1Name = s"_common_expr_$commonExpr1Id" - // Note that the common expression in expr2 gets de-duplicated by PullOutGroupingExpressions. 
- val commonExpr3Id = expr3.defs.head.id.id - val commonExpr3Name = s"_common_expr_$commonExpr3Id" - val groupingExprName = "_groupingexpression" - val aggExprName = "_aggregateexpression" comparePlans( Optimizer.execute(plan), testRelation - .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*) + .select(testRelation.output :+ (a + 1).as("_common_expr_0"): _*) .select(testRelation.output :+ - ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*) - .select(testRelation.output ++ Seq($"$groupingExprName", (a + 1).as(commonExpr3Name)): _*) - .groupBy($"$groupingExprName")( - $"$groupingExprName", - count($"$commonExpr3Name" * $"$commonExpr3Name" - 3).as(aggExprName) + ($"_common_expr_0" * $"_common_expr_0").as("_groupingexpression"): _*) + .select(testRelation.output ++ Seq($"_groupingexpression", + (a + 1).as("_common_expr_1")): _*) + .groupBy($"_groupingexpression")( + $"_groupingexpression", + count($"_common_expr_1" * $"_common_expr_1" - 3).as("_aggregateexpression") ) - .select(($"$groupingExprName" + 2).as("col1"), $"`$aggExprName`".as("col2")) + .select(($"_groupingexpression" + 2).as("col1"), $"_aggregateexpression".as("col2")) .analyze ) // Running CollapseProject after the rule cleans up the unnecessary projections. 
comparePlans( CollapseProject(Optimizer.execute(plan)), testRelation - .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*) + .select(testRelation.output :+ (a + 1).as("_common_expr_0"): _*) .select(testRelation.output ++ Seq( - ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName), - (a + 1).as(commonExpr3Name)): _*) - .groupBy($"$groupingExprName")( - ($"$groupingExprName" + 2).as("col1"), - count($"$commonExpr3Name" * $"$commonExpr3Name" - 3).as("col2") + ($"_common_expr_0" * $"_common_expr_0").as("_groupingexpression"), + (a + 1).as("_common_expr_1")): _*) + .groupBy($"_groupingexpression")( + ($"_groupingexpression" + 2).as("col1"), + count($"_common_expr_1" * $"_common_expr_1" - 3).as("col2") ) .analyze ) @@ -311,21 +292,16 @@ class RewriteWithExpressionSuite extends PlanTest { expr1.as("col2"), max(expr2).as("col3") ) - val commonExpr1Id = expr1.defs.head.id.id - val commonExpr1Name = s"_common_expr_$commonExpr1Id" - val commonExpr2Id = expr2.defs.head.id.id - val commonExpr2Name = s"_common_expr_$commonExpr2Id" - val aggExprName = "_aggregateexpression" comparePlans( Optimizer.execute(plan), testRelation - .select(testRelation.output :+ (b + 2).as(commonExpr2Name): _*) - .groupBy(a)(a, max($"$commonExpr2Name" * $"$commonExpr2Name").as(aggExprName)) - .select(a, $"`$aggExprName`", (a + 1).as(commonExpr1Name)) + .select(testRelation.output :+ (b + 2).as("_common_expr_0"): _*) + .groupBy(a)(a, max($"_common_expr_0" * $"_common_expr_0").as("_aggregateexpression")) + .select(a, $"_aggregateexpression", (a + 1).as("_common_expr_1")) .select( (a + 3).as("col1"), - ($"$commonExpr1Name" * $"$commonExpr1Name").as("col2"), - $"`$aggExprName`".as("col3") + ($"_common_expr_1" * $"_common_expr_1").as("col2"), + $"_aggregateexpression".as("col3") ) .analyze ) @@ -340,14 +316,13 @@ class RewriteWithExpressionSuite extends PlanTest { (a - 1).as("col1"), expr.as("col2") ) - val aggExprName = "_aggregateexpression" comparePlans( Optimizer.execute(plan), 
testRelation - .groupBy(a)(a, count(a - 1).as(aggExprName)) + .groupBy(a)(a, count(a - 1).as("_aggregateexpression")) .select( (a - 1).as("col1"), - ($"$aggExprName" * $"$aggExprName").as("col2") + ($"_aggregateexpression" * $"_aggregateexpression").as("col2") ) .analyze ) @@ -376,19 +351,91 @@ class RewriteWithExpressionSuite extends PlanTest { ref * max(expr) + ref } val plan = testRelation.groupBy(a)(nestedExpr.as("col")).analyze - val commonExpr1Id = expr.defs.head.id.id - val commonExpr1Name = s"_common_expr_$commonExpr1Id" - val commonExpr2Id = nestedExpr.defs.head.id.id - val commonExpr2Name = s"_common_expr_$commonExpr2Id" - val aggExprName = "_aggregateexpression" comparePlans( Optimizer.execute(plan), testRelation - .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*) - .groupBy(a)(a, max($"$commonExpr1Name" * $"$commonExpr1Name").as(aggExprName)) - .select($"a", $"$aggExprName", (a - 1).as(commonExpr2Name)) - .select(($"$commonExpr2Name" * $"$aggExprName" + $"$commonExpr2Name").as("col")) + .select(testRelation.output :+ (a + 1).as("_common_expr_0"): _*) + .groupBy(a)(a, max($"_common_expr_0" * $"_common_expr_0").as("_aggregateexpression")) + .select($"a", $"_aggregateexpression", (a - 1).as("_common_expr_1")) + .select(($"_common_expr_1" * $"_aggregateexpression" + $"_common_expr_1").as("col")) + .analyze + ) + } + + test("WITH expression in window exprs") { + val Seq(a, b) = testRelation.output + val expr1 = With(a + 1) { case Seq(ref) => + ref * ref + } + val expr2 = With(b + 2) { case Seq(ref) => + ref * ref + } + val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) + val plan = testRelation + .window( + Seq(windowExpr(count(a), windowSpec(Seq(expr2), Nil, frame)).as("col2")), + Seq(expr2), + Nil + ) + .window( + Seq(windowExpr(sum(expr1), windowSpec(Seq(a), Nil, frame)).as("col3")), + Seq(a), + Nil + ) + .select((a - 1).as("col1"), $"col2", $"col3") + .analyze + comparePlans( + Optimizer.execute(plan), + 
testRelation + .select(a, b, (b + 2).as("_common_expr_0")) + .select(a, b, $"_common_expr_0", (b + 2).as("_common_expr_1")) + .window( + Seq(windowExpr(count(a), windowSpec(Seq($"_common_expr_0" * $"_common_expr_0"), Nil, + frame)).as("col2")), + Seq($"_common_expr_1" * $"_common_expr_1"), + Nil + ) + .select(a, b, $"col2") + .select(a, b, $"col2", (a + 1).as("_common_expr_2")) + .window( + Seq(windowExpr(sum($"_common_expr_2" * $"_common_expr_2"), + windowSpec(Seq(a), Nil, frame)).as("col3")), + Seq(a), + Nil + ) + .select(a, b, $"col2", $"col3") + .select((a - 1).as("col1"), $"col2", $"col3") + .analyze + ) + } + + test("WITH common expression is window function") { + val a = testRelation.output.head + val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) + val winExpr = windowExpr(sum(a), windowSpec(Seq(a), Nil, frame)) + val expr = With(winExpr) { + case Seq(ref) => ref * ref + } + val plan = testRelation.select(expr.as("col")).analyze + comparePlans( + Optimizer.execute(plan), + testRelation + .select(a) + .window(Seq(winExpr.as("_we0")), Seq(a), Nil) + .select(a, $"_we0", ($"_we0" * $"_we0").as("col")) + .select($"col") .analyze ) } + + test("window functions in child of WITH expression with ref is not supported") { + val a = testRelation.output.head + intercept[java.lang.AssertionError] { + val expr = With(a - 1) { case Seq(ref) => + ref + windowExpr(sum(ref), windowSpec(Seq(a), Nil, UnspecifiedFrame)) + } + val plan = testRelation.window(Seq(expr.as("col")), Seq(a), Nil) + Optimizer.execute(plan) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org