Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-17 Thread via GitHub


cloud-fan closed pull request #46034: [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression
URL: https://github.com/apache/spark/pull/46034





Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-17 Thread via GitHub


cloud-fan commented on PR #46034:
URL: https://github.com/apache/spark/pull/46034#issuecomment-2062852813

   thanks, merging to master!





Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-16 Thread via GitHub


cloud-fan commented on PR #46034:
URL: https://github.com/apache/spark/pull/46034#issuecomment-2060437664

   The test fails: `org.apache.spark.sql.connect.ProtoToParsedPlanTestSuite`





Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-16 Thread via GitHub


cloud-fan commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1568259745


##
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala:
##
@@ -21,36 +21,68 @@ import scala.collection.mutable
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PlanHelper, Project}
+import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, PlanHelper, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
 
 /**
  * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or
  * just inline them if they are cheap.
  *
+ * Since this rule can introduce new `Project` operators, it is advised to run [[CollapseProject]]
+ * after this rule.
+ *
  * Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its
  *   usage, we should support aggregate/window functions as well.
  */
 object RewriteWithExpression extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.transformDownWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
+    plan.transformUpWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
+      // For aggregates, separate the computation of the aggregations themselves from the final
+      // result by moving the final result computation into a projection above it. This prevents
+      // this rule from producing an invalid Aggregate operator.
+      case p @ PhysicalAggregation(
+          groupingExpressions, aggregateExpressions, resultExpressions, child)
+          if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
+        // There should not be dangling common expression references in the aggregate expressions.
+        // This can happen if a With is created with an aggregate function in its child.
+        assert(!aggregateExpressions.exists(ae =>

Review Comment:
   Shall we do the assert in the constructor of `With`, to fail earlier?
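
   A minimal sketch of the eager check being suggested here, written as a standalone helper rather than as a change to the real `With` class (the object name, method, and message are illustrative assumptions, not Spark API):

       import org.apache.spark.sql.catalyst.expressions.Expression
       import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression

       // Hypothetical helper: run the "no aggregate function inside the With child"
       // check at construction time, so a misuse fails immediately instead of only
       // when RewriteWithExpression later hits dangling CommonExpressionRefs.
       object WithConstructionCheck {
         def validate(child: Expression): Unit = {
           assert(!child.exists(_.isInstanceOf[AggregateExpression]),
             "With.child should not contain aggregate functions, otherwise " +
               "PhysicalAggregation can strand its common expression references")
         }
       }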






Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-16 Thread via GitHub


kelvinjian-db commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1567767274


##
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala:
##
@@ -29,25 +29,29 @@ import org.apache.spark.sql.types.IntegerType
 class RewriteWithExpressionSuite extends PlanTest {
 
   object Optimizer extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil
+    val batches = Batch("Rewrite With expression", Once,
+      PullOutGroupingExpressions,
+      RewriteWithExpression) :: Nil
   }
 
   private val testRelation = LocalRelation($"a".int, $"b".int)
   private val testRelation2 = LocalRelation($"x".int, $"y".int)
 
   test("simple common expression") {
     val a = testRelation.output.head
-    val commonExprDef = CommonExpressionDef(a)
-    val ref = new CommonExpressionRef(commonExprDef)
-    val plan = testRelation.select(With(ref + ref, Seq(commonExprDef)).as("col"))
+    val expr = With.create((a, 0)) { case Seq(ref) =>

Review Comment:
   i ended up changing the tests to manually extract the common expression IDs so we don't need `With.create` anymore
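
   For readers following the thread, a rough sketch of what "manually extracting the common expression IDs" can look like, using only the constructors already shown in this diff (the variable names, and deriving the expected alias from the id, are assumptions rather than the suite's literal code):

       // Build the With by hand: define the common expression once, reference it
       // twice, and keep the definition's id so the expected pre-evaluation column
       // name (e.g. s"_common_expr_${commonExprDef.id}") can be derived in the test.
       val a = testRelation.output.head
       val commonExprDef = CommonExpressionDef(a)
       val ref = new CommonExpressionRef(commonExprDef)
       val plan = testRelation.select(With(ref + ref, Seq(commonExprDef)).as("col"))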






Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-15 Thread via GitHub


kelvinjian-db commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1566230161


##
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala:
##
@@ -29,25 +29,29 @@ import org.apache.spark.sql.types.IntegerType
 class RewriteWithExpressionSuite extends PlanTest {
 
   object Optimizer extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil
+    val batches = Batch("Rewrite With expression", Once,
+      PullOutGroupingExpressions,
+      RewriteWithExpression) :: Nil
   }
 
   private val testRelation = LocalRelation($"a".int, $"b".int)
   private val testRelation2 = LocalRelation($"x".int, $"y".int)
 
   test("simple common expression") {
     val a = testRelation.output.head
-    val commonExprDef = CommonExpressionDef(a)
-    val ref = new CommonExpressionRef(commonExprDef)
-    val plan = testRelation.select(With(ref + ref, Seq(commonExprDef)).as("col"))
+    val expr = With.create((a, 0)) { case Seq(ref) =>

Review Comment:
   ~~good point, i think it doesn't, i'll remove these~~ never mind, looks like it does matter since the IDs are used in the aliases due to [this change](https://github.com/apache/spark/pull/46034/files#diff-1eba63e1edc5cae9e89b8e326d7f5797ae5dc9314707268760484a3ed3d57038L96-R117). i originally made that change since it was possible to introduce duplicate alias names using `index`, whereas using the common expression ID was a more natural ID to use in the alias name. however, if the aliases are only used for bookkeeping, i can revert that change
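
   For context, the alias-naming change being referenced is roughly the difference between these two shapes (a paraphrased sketch; `index` and `id` stand for the projection index and the CommonExpressionDef id, and the real code is in the linked diff):

       // index-based name: can collide when two different common expressions
       // end up with the same index in the same operator
       Alias(child, s"_common_expr_$index")()
       // id-based name: the definition's id is unique per common expression, so
       // generated columns like "_common_expr_0" / "_common_expr_2" stay distinct
       Alias(child, s"_common_expr_$id")()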






Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-15 Thread via GitHub


kelvinjian-db commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1566233683


##
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala:
##
@@ -229,4 +236,85 @@ class RewriteWithExpressionSuite extends PlanTest {
         .analyze
     )
   }
+
+  test("WITH expression in grouping exprs") {
+    val a = testRelation.output.head
+    val expr1 = With.create((a + 1, 0)) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr2 = With.create((a + 1, 1)) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr3 = With.create((a + 1, 2)) { case Seq(ref) =>
+      ref * ref
+    }
+    val plan = testRelation.groupBy(expr1)(
+      (expr2 + 2).as("col1"),
+      count(expr3 - 3).as("col2")
+    )
+    val commonExpr1Name = "_common_expr_0"
+    // Note that _common_expr_1 gets deduplicated by PullOutGroupingExpressions.
+    val commonExpr2Name = "_common_expr_2"
+    val groupingExprName = "_groupingexpression"
+    val countAlias = count(expr3 - 3).toString
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output :+
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*)
+        .select(testRelation.output ++ Seq($"$groupingExprName", (a + 1).as(commonExpr2Name)): _*)
+        .groupBy($"$groupingExprName")(
+          $"$groupingExprName",
+          count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as(countAlias)
+        )
+        .select(($"$groupingExprName" + 2).as("col1"), $"`$countAlias`".as("col2"))
+        .analyze
+    )
+    // Running CollapseProject after the rule cleans up the unnecessary projections.
+    comparePlans(
+      CollapseProject(Optimizer.execute(plan)),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output ++ Seq(
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName),
+          (a + 1).as(commonExpr2Name)): _*)
+        .groupBy($"$groupingExprName")(
+          ($"$groupingExprName" + 2).as("col1"),
+          count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as("col2")
+        )
+        .analyze
+    )
+  }
+
+  test("WITH expression in aggregate exprs") {
+    val Seq(a, b) = testRelation.output
+    val expr1 = With.create((a + 1, 0)) { case Seq(ref) =>
+      ref * ref

Review Comment:
   yes let me add a test for that
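
   A sketch of what such a test could look like, reusing the helpers visible in this diff (the expected rewritten plan is left as a comment because it depends on the final alias naming; treat this as an illustration, not the test that was eventually committed):

       test("aggregate function as the common expression") {
         val a = testRelation.output.head
         // The shared expression is itself an aggregate function, so the rewrite must
         // keep its evaluation inside the Aggregate and only reuse the result above it.
         val expr = With.create((count(a), 0)) { case Seq(ref) =>
           ref + ref
         }
         val plan = testRelation.groupBy()(expr.as("col"))
         val rewritten = Optimizer.execute(plan)
         // Expected shape, informally: the Aggregate computes count(a) once, and a
         // Project above it computes ref + ref; comparePlans(...) would pin it down.
       }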






Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-15 Thread via GitHub


kelvinjian-db commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1566233143


##
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala:
##
@@ -229,4 +236,85 @@ class RewriteWithExpressionSuite extends PlanTest {
         .analyze
     )
   }
+
+  test("WITH expression in grouping exprs") {
+    val a = testRelation.output.head
+    val expr1 = With.create((a + 1, 0)) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr2 = With.create((a + 1, 1)) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr3 = With.create((a + 1, 2)) { case Seq(ref) =>
+      ref * ref
+    }
+    val plan = testRelation.groupBy(expr1)(
+      (expr2 + 2).as("col1"),
+      count(expr3 - 3).as("col2")
+    )
+    val commonExpr1Name = "_common_expr_0"
+    // Note that _common_expr_1 gets deduplicated by PullOutGroupingExpressions.
+    val commonExpr2Name = "_common_expr_2"
+    val groupingExprName = "_groupingexpression"
+    val countAlias = count(expr3 - 3).toString
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output :+
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*)
+        .select(testRelation.output ++ Seq($"$groupingExprName", (a + 1).as(commonExpr2Name)): _*)
+        .groupBy($"$groupingExprName")(
+          $"$groupingExprName",
+          count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as(countAlias)
+        )
+        .select(($"$groupingExprName" + 2).as("col1"), $"`$countAlias`".as("col2"))
+        .analyze
+    )
+    // Running CollapseProject after the rule cleans up the unnecessary projections.
+    comparePlans(
+      CollapseProject(Optimizer.execute(plan)),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output ++ Seq(
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName),
+          (a + 1).as(commonExpr2Name)): _*)
+        .groupBy($"$groupingExprName")(
+          ($"$groupingExprName" + 2).as("col1"),
+          count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as("col2")
+        )
+        .analyze
+    )
+  }
+
+  test("WITH expression in aggregate exprs") {

Review Comment:
   doesn't the test above test WITH in both grouping and aggregate expressions? the test here is for testing the motivating example mentioned in https://issues.apache.org/jira/browse/SPARK-47839






Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-15 Thread via GitHub


kelvinjian-db commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1566230161


##
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala:
##
@@ -29,25 +29,29 @@ import org.apache.spark.sql.types.IntegerType
 class RewriteWithExpressionSuite extends PlanTest {
 
   object Optimizer extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil
+    val batches = Batch("Rewrite With expression", Once,
+      PullOutGroupingExpressions,
+      RewriteWithExpression) :: Nil
   }
 
   private val testRelation = LocalRelation($"a".int, $"b".int)
   private val testRelation2 = LocalRelation($"x".int, $"y".int)
 
   test("simple common expression") {
     val a = testRelation.output.head
-    val commonExprDef = CommonExpressionDef(a)
-    val ref = new CommonExpressionRef(commonExprDef)
-    val plan = testRelation.select(With(ref + ref, Seq(commonExprDef)).as("col"))
+    val expr = With.create((a, 0)) { case Seq(ref) =>

Review Comment:
   good point, i think it doesn't, i'll remove these






Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-15 Thread via GitHub


kelvinjian-db commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1566228919


##
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala:
##
@@ -21,36 +21,57 @@ import scala.collection.mutable
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PlanHelper, Project}
+import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, PlanHelper, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
 
 /**
  * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or
  * just inline them if they are cheap.
  *
+ * Since this rule can introduce new `Project` operators, it is advised to run [[CollapseProject]]
+ * after this rule.
+ *
  * Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its
  *   usage, we should support aggregate/window functions as well.
  */
 object RewriteWithExpression extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.transformDownWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
+    plan.transformUpWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
+      case p @ PhysicalAggregation(
+          groupingExpressions, aggregateExpressions, resultExpressions, child)
+          if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
+        // For aggregates, separate computation of the aggregations themselves from the final
+        // result by moving the final result computation into a projection above. This prevents
+        // this rule from producing an invalid Aggregate operator.
+        // TODO: the names of these aliases will become outdated after the rewrite
+        val aggExprs = aggregateExpressions.map(ae => Alias(ae, ae.toString)(ae.resultId))
+        // Rewrite the projection and the aggregate separately and then piece them together.
+        val agg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child)
+        val rewrittenAgg = applyInternal(agg)
+        val proj = Project(resultExpressions, rewrittenAgg)
+        applyInternal(proj)
       case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
-        val inputPlans = p.children.toArray
-        var newPlan: LogicalPlan = p.mapExpressions { expr =>
-          rewriteWithExprAndInputPlans(expr, inputPlans)
-        }
-        newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)
-        // Since we add extra Projects with extra columns to pre-evaluate the common expressions,
-        // the current operator may have extra columns if it inherits the output columns from its
-        // child, and we need to project away the extra columns to keep the plan schema unchanged.
-        assert(p.output.length <= newPlan.output.length)
-        if (p.output.length < newPlan.output.length) {
-          assert(p.outputSet.subsetOf(newPlan.outputSet))
-          Project(p.output, newPlan)
-        } else {
-          newPlan
-        }
+        applyInternal(p)
+    }
+  }
+
+  private def applyInternal(p: LogicalPlan): LogicalPlan = {
+    val inputPlans = p.children.toArray
+    var newPlan: LogicalPlan = p.mapExpressions { expr =>
+      rewriteWithExprAndInputPlans(expr, inputPlans)
+    }
+    newPlan = newPlan.withNewChildren(inputPlans)

Review Comment:
   my understanding is that we need to call `withNewChildren` because `rewriteWithExprAndInputPlans` puts the new projections in `inputPlans` and we need to update our current plan with them?

   (btw this was largely just refactored from the previous code)
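
   A condensed sketch of the pattern being described, with names taken from the diff above (the Project mentioned in the comment is conceptual; the real wrapping happens inside rewriteWithExprAndInputPlans):

       // The expression rewrite mutates inputPlans in place, e.g. replacing a child
       // with a Project that pre-evaluates the common expressions, so the operator
       // must be re-attached to the updated children afterwards.
       val inputPlans = p.children.toArray
       var newPlan: LogicalPlan = p.mapExpressions { expr =>
         rewriteWithExprAndInputPlans(expr, inputPlans)
       }
       newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)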






Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-15 Thread via GitHub


kelvinjian-db commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1566227660


##
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala:
##
@@ -21,36 +21,57 @@ import scala.collection.mutable
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PlanHelper, Project}
+import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, PlanHelper, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
 
 /**
  * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or
  * just inline them if they are cheap.
  *
+ * Since this rule can introduce new `Project` operators, it is advised to run [[CollapseProject]]
+ * after this rule.
+ *
  * Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its
  *   usage, we should support aggregate/window functions as well.
  */
 object RewriteWithExpression extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.transformDownWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
+    plan.transformUpWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
+      case p @ PhysicalAggregation(
+          groupingExpressions, aggregateExpressions, resultExpressions, child)
+          if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
+        // For aggregates, separate computation of the aggregations themselves from the final
+        // result by moving the final result computation into a projection above. This prevents
+        // this rule from producing an invalid Aggregate operator.
+        // TODO: the names of these aliases will become outdated after the rewrite
+        val aggExprs = aggregateExpressions.map(ae => Alias(ae, ae.toString)(ae.resultId))

Review Comment:
   good idea, i'll change this






Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-15 Thread via GitHub


cloud-fan commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1565741611


##
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala:
##
@@ -229,4 +236,85 @@ class RewriteWithExpressionSuite extends PlanTest {
         .analyze
     )
   }
+
+  test("WITH expression in grouping exprs") {
+    val a = testRelation.output.head
+    val expr1 = With.create((a + 1, 0)) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr2 = With.create((a + 1, 1)) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr3 = With.create((a + 1, 2)) { case Seq(ref) =>
+      ref * ref
+    }
+    val plan = testRelation.groupBy(expr1)(
+      (expr2 + 2).as("col1"),
+      count(expr3 - 3).as("col2")
+    )
+    val commonExpr1Name = "_common_expr_0"
+    // Note that _common_expr_1 gets deduplicated by PullOutGroupingExpressions.
+    val commonExpr2Name = "_common_expr_2"
+    val groupingExprName = "_groupingexpression"
+    val countAlias = count(expr3 - 3).toString
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output :+
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*)
+        .select(testRelation.output ++ Seq($"$groupingExprName", (a + 1).as(commonExpr2Name)): _*)
+        .groupBy($"$groupingExprName")(
+          $"$groupingExprName",
+          count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as(countAlias)
+        )
+        .select(($"$groupingExprName" + 2).as("col1"), $"`$countAlias`".as("col2"))
+        .analyze
+    )
+    // Running CollapseProject after the rule cleans up the unnecessary projections.
+    comparePlans(
+      CollapseProject(Optimizer.execute(plan)),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output ++ Seq(
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName),
+          (a + 1).as(commonExpr2Name)): _*)
+        .groupBy($"$groupingExprName")(
+          ($"$groupingExprName" + 2).as("col1"),
+          count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as("col2")
+        )
+        .analyze
+    )
+  }
+
+  test("WITH expression in aggregate exprs") {
+    val Seq(a, b) = testRelation.output
+    val expr1 = With.create((a + 1, 0)) { case Seq(ref) =>
+      ref * ref

Review Comment:
   can we test aggregate function as the common expression?






Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-15 Thread via GitHub


cloud-fan commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1565740751


##
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala:
##
@@ -229,4 +236,85 @@ class RewriteWithExpressionSuite extends PlanTest {
         .analyze
     )
   }
+
+  test("WITH expression in grouping exprs") {
+    val a = testRelation.output.head
+    val expr1 = With.create((a + 1, 0)) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr2 = With.create((a + 1, 1)) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr3 = With.create((a + 1, 2)) { case Seq(ref) =>
+      ref * ref
+    }
+    val plan = testRelation.groupBy(expr1)(
+      (expr2 + 2).as("col1"),
+      count(expr3 - 3).as("col2")
+    )
+    val commonExpr1Name = "_common_expr_0"
+    // Note that _common_expr_1 gets deduplicated by PullOutGroupingExpressions.
+    val commonExpr2Name = "_common_expr_2"
+    val groupingExprName = "_groupingexpression"
+    val countAlias = count(expr3 - 3).toString
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output :+
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*)
+        .select(testRelation.output ++ Seq($"$groupingExprName", (a + 1).as(commonExpr2Name)): _*)
+        .groupBy($"$groupingExprName")(
+          $"$groupingExprName",
+          count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as(countAlias)
+        )
+        .select(($"$groupingExprName" + 2).as("col1"), $"`$countAlias`".as("col2"))
+        .analyze
+    )
+    // Running CollapseProject after the rule cleans up the unnecessary projections.
+    comparePlans(
+      CollapseProject(Optimizer.execute(plan)),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output ++ Seq(
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName),
+          (a + 1).as(commonExpr2Name)): _*)
+        .groupBy($"$groupingExprName")(
+          ($"$groupingExprName" + 2).as("col1"),
+          count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as("col2")
+        )
+        .analyze
+    )
+  }
+
+  test("WITH expression in aggregate exprs") {

Review Comment:
   can we merge the two tests and have WITH in both grouping and aggregate expressions?






Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-15 Thread via GitHub


cloud-fan commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1565739557


##
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala:
##
@@ -21,36 +21,57 @@ import scala.collection.mutable
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PlanHelper, Project}
+import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, PlanHelper, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
 
 /**
  * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or
  * just inline them if they are cheap.
  *
+ * Since this rule can introduce new `Project` operators, it is advised to run [[CollapseProject]]
+ * after this rule.
+ *
  * Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its
  *   usage, we should support aggregate/window functions as well.
  */
 object RewriteWithExpression extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.transformDownWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
+    plan.transformUpWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
+      case p @ PhysicalAggregation(
+          groupingExpressions, aggregateExpressions, resultExpressions, child)
+          if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
+        // For aggregates, separate computation of the aggregations themselves from the final
+        // result by moving the final result computation into a projection above. This prevents
+        // this rule from producing an invalid Aggregate operator.
+        // TODO: the names of these aliases will become outdated after the rewrite
+        val aggExprs = aggregateExpressions.map(ae => Alias(ae, ae.toString)(ae.resultId))

Review Comment:
   The alias name doesn't matter as it's only for internal bookkeeping. `.toString` can be super long if the aggregate function input is a complex expression. Shall we follow `PullOutGroupingExpressions` and use consistent naming? like `_aggregateExpression`?
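
   Concretely, the suggestion amounts to something like the following (a sketch of one possible spelling, not necessarily the code that was finally merged):

       // Short, stable alias names in the style of PullOutGroupingExpressions,
       // instead of embedding the potentially huge ae.toString in the name.
       val aggExprs = aggregateExpressions.map(ae => Alias(ae, "_aggregateexpression")(ae.resultId))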






Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-15 Thread via GitHub


cloud-fan commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1565732326


##
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala:
##
@@ -29,25 +29,29 @@ import org.apache.spark.sql.types.IntegerType
 class RewriteWithExpressionSuite extends PlanTest {
 
   object Optimizer extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil
+    val batches = Batch("Rewrite With expression", Once,
+      PullOutGroupingExpressions,
+      RewriteWithExpression) :: Nil
   }
 
   private val testRelation = LocalRelation($"a".int, $"b".int)
   private val testRelation2 = LocalRelation($"x".int, $"y".int)
 
   test("simple common expression") {
     val a = testRelation.output.head
-    val commonExprDef = CommonExpressionDef(a)
-    val ref = new CommonExpressionRef(commonExprDef)
-    val plan = testRelation.select(With(ref + ref, Seq(commonExprDef)).as("col"))
+    val expr = With.create((a, 0)) { case Seq(ref) =>

Review Comment:
   does the id matter for this test?






Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-15 Thread via GitHub


cloud-fan commented on code in PR #46034:
URL: https://github.com/apache/spark/pull/46034#discussion_r1565727826


##
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala:
##
@@ -21,36 +21,57 @@ import scala.collection.mutable
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PlanHelper, Project}
+import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, PlanHelper, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
 
 /**
  * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or
  * just inline them if they are cheap.
  *
+ * Since this rule can introduce new `Project` operators, it is advised to run [[CollapseProject]]
+ * after this rule.
+ *
  * Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its
  *   usage, we should support aggregate/window functions as well.
  */
 object RewriteWithExpression extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.transformDownWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
+    plan.transformUpWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
+      case p @ PhysicalAggregation(
+          groupingExpressions, aggregateExpressions, resultExpressions, child)
+          if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
+        // For aggregates, separate computation of the aggregations themselves from the final
+        // result by moving the final result computation into a projection above. This prevents
+        // this rule from producing an invalid Aggregate operator.
+        // TODO: the names of these aliases will become outdated after the rewrite
+        val aggExprs = aggregateExpressions.map(ae => Alias(ae, ae.toString)(ae.resultId))
+        // Rewrite the projection and the aggregate separately and then piece them together.
+        val agg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child)
+        val rewrittenAgg = applyInternal(agg)
+        val proj = Project(resultExpressions, rewrittenAgg)
+        applyInternal(proj)
       case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
-        val inputPlans = p.children.toArray
-        var newPlan: LogicalPlan = p.mapExpressions { expr =>
-          rewriteWithExprAndInputPlans(expr, inputPlans)
-        }
-        newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)
-        // Since we add extra Projects with extra columns to pre-evaluate the common expressions,
-        // the current operator may have extra columns if it inherits the output columns from its
-        // child, and we need to project away the extra columns to keep the plan schema unchanged.
-        assert(p.output.length <= newPlan.output.length)
-        if (p.output.length < newPlan.output.length) {
-          assert(p.outputSet.subsetOf(newPlan.outputSet))
-          Project(p.output, newPlan)
-        } else {
-          newPlan
-        }
+        applyInternal(p)
+    }
+  }
+
+  private def applyInternal(p: LogicalPlan): LogicalPlan = {
+    val inputPlans = p.children.toArray
+    var newPlan: LogicalPlan = p.mapExpressions { expr =>
+      rewriteWithExprAndInputPlans(expr, inputPlans)
+    }
+    newPlan = newPlan.withNewChildren(inputPlans)

Review Comment:
   the children do not change, why calling `withNewChildren`?






Re: [PR] [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression [spark]

2024-04-12 Thread via GitHub


kelvinjian-db commented on PR #46034:
URL: https://github.com/apache/spark/pull/46034#issuecomment-2052701361

   cc @cloud-fan @jchen5

