cloud-fan commented on code in PR #46034 (URL: https://github.com/apache/spark/pull/46034#discussion_r1565741611)
########## sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala: ########## @@ -229,4 +236,85 @@ class RewriteWithExpressionSuite extends PlanTest { .analyze ) } + + test("WITH expression in grouping exprs") { + val a = testRelation.output.head + val expr1 = With.create((a + 1, 0)) { case Seq(ref) => + ref * ref + } + val expr2 = With.create((a + 1, 1)) { case Seq(ref) => + ref * ref + } + val expr3 = With.create((a + 1, 2)) { case Seq(ref) => + ref * ref + } + val plan = testRelation.groupBy(expr1)( + (expr2 + 2).as("col1"), + count(expr3 - 3).as("col2") + ) + val commonExpr1Name = "_common_expr_0" + // Note that _common_expr_1 gets deduplicated by PullOutGroupingExpressions. + val commonExpr2Name = "_common_expr_2" + val groupingExprName = "_groupingexpression" + val countAlias = count(expr3 - 3).toString + comparePlans( + Optimizer.execute(plan), + testRelation + .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*) + .select(testRelation.output :+ + ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*) + .select(testRelation.output ++ Seq($"$groupingExprName", (a + 1).as(commonExpr2Name)): _*) + .groupBy($"$groupingExprName")( + $"$groupingExprName", + count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as(countAlias) + ) + .select(($"$groupingExprName" + 2).as("col1"), $"`$countAlias`".as("col2")) + .analyze + ) + // Running CollapseProject after the rule cleans up the unnecessary projections. 
+ comparePlans( + CollapseProject(Optimizer.execute(plan)), + testRelation + .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*) + .select(testRelation.output ++ Seq( + ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName), + (a + 1).as(commonExpr2Name)): _*) + .groupBy($"$groupingExprName")( + ($"$groupingExprName" + 2).as("col1"), + count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as("col2") + ) + .analyze + ) + } + + test("WITH expression in aggregate exprs") { + val Seq(a, b) = testRelation.output + val expr1 = With.create((a + 1, 0)) { case Seq(ref) => + ref * ref Review Comment: can we test aggregate function as the common expression? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org