cloud-fan commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1565741611


##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala:
##########
@@ -229,4 +236,85 @@ class RewriteWithExpressionSuite extends PlanTest {
         .analyze
     )
   }
+
+  test("WITH expression in grouping exprs") {
+    val a = testRelation.output.head
+    // Three WITH expressions over the same common expression `a + 1`, each with its own
+    // common-expression id (0, 1, 2): expr1 is the grouping key, expr2 appears in a plain
+    // projection of the aggregate list, and expr3 appears inside an aggregate function.
+    val expr1 = With.create((a + 1, 0)) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr2 = With.create((a + 1, 1)) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr3 = With.create((a + 1, 2)) { case Seq(ref) =>
+      ref * ref
+    }
+    val plan = testRelation.groupBy(expr1)(
+      (expr2 + 2).as("col1"),
+      count(expr3 - 3).as("col2")
+    )
+    val commonExpr1Name = "_common_expr_0"
+    // Note that _common_expr_1 gets deduplicated by PullOutGroupingExpressions.
+    val commonExpr2Name = "_common_expr_2"
+    val groupingExprName = "_groupingexpression"
+    // The count is first aliased by its generated name and only renamed to "col2" in the
+    // final projection; that name is backtick-quoted below since it is not a plain identifier.
+    val countAlias = count(expr3 - 3).toString
+    // Raw optimizer output: separate projections below the Aggregate materialize the common
+    // expressions and the grouping key, and a projection above it computes the final outputs.
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output :+
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*)
+        .select(testRelation.output ++ Seq($"$groupingExprName", (a + 1).as(commonExpr2Name)): _*)
+        .groupBy($"$groupingExprName")(
+          $"$groupingExprName",
+          count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as(countAlias)
+        )
+        .select(($"$groupingExprName" + 2).as("col1"), $"`$countAlias`".as("col2"))
+        .analyze
+    )
+    // Running CollapseProject after the rule cleans up the unnecessary projections.
+    comparePlans(
+      CollapseProject(Optimizer.execute(plan)),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output ++ Seq(
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName),
+          (a + 1).as(commonExpr2Name)): _*)
+        .groupBy($"$groupingExprName")(
+          ($"$groupingExprName" + 2).as("col1"),
+          count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as("col2")
+        )
+        .analyze
+    )
+  }
+
+  test("WITH expression in aggregate exprs") {
+    val Seq(a, b) = testRelation.output
+    val expr1 = With.create((a + 1, 0)) { case Seq(ref) =>
+      ref * ref

Review Comment:
   Can we also add a test where the common expression is an aggregate function?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to