kelvinjian-db commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1566233143


##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala:
##########
@@ -229,4 +236,85 @@ class RewriteWithExpressionSuite extends PlanTest {
         .analyze
     )
   }
+
+  test("WITH expression in grouping exprs") {
+    val a = testRelation.output.head
+    val expr1 = With.create((a + 1, 0)) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr2 = With.create((a + 1, 1)) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr3 = With.create((a + 1, 2)) { case Seq(ref) =>
+      ref * ref
+    }
+    val plan = testRelation.groupBy(expr1)(
+      (expr2 + 2).as("col1"),
+      count(expr3 - 3).as("col2")
+    )
+    val commonExpr1Name = "_common_expr_0"
+    // Note that _common_expr_1 gets deduplicated by 
PullOutGroupingExpressions.
+    val commonExpr2Name = "_common_expr_2"
+    val groupingExprName = "_groupingexpression"
+    val countAlias = count(expr3 - 3).toString
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output :+
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*)
+        .select(testRelation.output ++ Seq($"$groupingExprName", (a + 
1).as(commonExpr2Name)): _*)
+        .groupBy($"$groupingExprName")(
+          $"$groupingExprName",
+          count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as(countAlias)
+        )
+        .select(($"$groupingExprName" + 2).as("col1"), 
$"`$countAlias`".as("col2"))
+        .analyze
+    )
+    // Running CollapseProject after the rule cleans up the unnecessary 
projections.
+    comparePlans(
+      CollapseProject(Optimizer.execute(plan)),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output ++ Seq(
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName),
+          (a + 1).as(commonExpr2Name)): _*)
+        .groupBy($"$groupingExprName")(
+          ($"$groupingExprName" + 2).as("col1"),
+          count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as("col2")
+        )
+        .analyze
+    )
+  }
+
+  test("WITH expression in aggregate exprs") {

Review Comment:
   Doesn't the test above already cover WITH in both the grouping and the aggregate
expressions? The test here is meant to exercise the motivating example described in
https://issues.apache.org/jira/browse/SPARK-47839.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to