This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new df08177de2fd [SPARK-48416][SQL] Support nested correlated With
expression
df08177de2fd is described below
commit df08177de2fd2b177caf79ca533eb0cd2c6a4ba6
Author: Wenchen Fan <[email protected]>
AuthorDate: Thu Dec 12 15:36:09 2024 -0800
[SPARK-48416][SQL] Support nested correlated With expression
### What changes were proposed in this pull request?
The inner `With` may reference common expressions of an outer `With`. This
PR supports this case by making the rule `RewriteWithExpression` rewrite only
top-level `With` expressions and by running the rule repeatedly (to a fixed
point), so that an inner `With` expression becomes a top-level `With` after one
iteration and gets rewritten in the next iteration.
### Why are the changes needed?
To support optimized filter pushdown with `With` expressions.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Updated the unit tests.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49093 from cloud-fan/with.
Lead-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/optimizer/Optimizer.scala | 2 +-
.../catalyst/optimizer/RewriteWithExpression.scala | 25 ++++-----
.../optimizer/RewriteWithExpressionSuite.scala | 61 ++++++++++++++--------
3 files changed, 50 insertions(+), 38 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 7ec467badce5..31c1f8917763 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -160,7 +160,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
Batch("Finish Analysis", FixedPoint(1), FinishAnalysis) ::
// We must run this batch after `ReplaceExpressions`, as
`RuntimeReplaceable` expression
// may produce `With` expressions that need to be rewritten.
- Batch("Rewrite With expression", Once, RewriteWithExpression) ::
+ Batch("Rewrite With expression", fixedPoint, RewriteWithExpression) ::
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
index 393a66f7c1e4..d0c5d8158644 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
@@ -85,21 +85,19 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
private def rewriteWithExprAndInputPlans(
e: Expression,
- inputPlans: Array[LogicalPlan]): Expression = {
+ inputPlans: Array[LogicalPlan],
+ isNestedWith: Boolean = false): Expression = {
if (!e.containsPattern(WITH_EXPRESSION)) return e
e match {
- case w: With =>
+ // Do not handle nested With in one pass. Leave it to the next rule
executor batch.
+ case w: With if !isNestedWith =>
// Rewrite nested With expressions first
- val child = rewriteWithExprAndInputPlans(w.child, inputPlans)
- val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans))
+ val child = rewriteWithExprAndInputPlans(w.child, inputPlans,
isNestedWith = true)
+ val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans,
isNestedWith = true))
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]
val childProjections =
Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias])
defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id),
index) =>
- if (child.containsPattern(COMMON_EXPR_REF)) {
- throw SparkException.internalError(
- "Common expression definition cannot reference other Common
expression definitions")
- }
if (id.canonicalized) {
throw SparkException.internalError(
"Cannot rewrite canonicalized Common expression definitions")
@@ -148,10 +146,9 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
}
child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
- case ref: CommonExpressionRef =>
- if (!refToExpr.contains(ref.id)) {
- throw SparkException.internalError("Undefined common expression
id " + ref.id)
- }
+ // `child` may contain nested With and we only replace
`CommonExpressionRef` that
+ // references common expressions in the current `With`.
+ case ref: CommonExpressionRef if refToExpr.contains(ref.id) =>
if (ref.id.canonicalized) {
throw SparkException.internalError(
"Cannot rewrite canonicalized Common expression references")
@@ -161,7 +158,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
case c: ConditionalExpression =>
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
- rewriteWithExprAndInputPlans(_, inputPlans))
+ rewriteWithExprAndInputPlans(_, inputPlans, isNestedWith))
val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
// Use transformUp to handle nested With.
newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
@@ -174,7 +171,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
}
}
- case other => other.mapChildren(rewriteWithExprAndInputPlans(_,
inputPlans))
+ case other => other.mapChildren(rewriteWithExprAndInputPlans(_,
inputPlans, isNestedWith))
}
}
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
index 0aeca961aa51..0be6ae649464 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.analysis.TempResolvedColumn
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -29,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
class RewriteWithExpressionSuite extends PlanTest {
object Optimizer extends RuleExecutor[LogicalPlan] {
- val batches = Batch("Rewrite With expression", Once,
+ val batches = Batch("Rewrite With expression", FixedPoint(5),
PullOutGroupingExpressions,
RewriteWithExpression) :: Nil
}
@@ -84,13 +83,11 @@ class RewriteWithExpressionSuite extends PlanTest {
ref * ref
}
- val plan = testRelation.select(outerExpr.as("col"))
comparePlans(
- Optimizer.execute(plan),
+ Optimizer.execute(testRelation.select(outerExpr.as("col"))),
testRelation
- .select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
- .select((testRelation.output ++ Seq($"_common_expr_0",
- ($"_common_expr_0" + $"_common_expr_0" + b).as("_common_expr_1"))):
_*)
+ .select(star(), (a + a).as("_common_expr_0"))
+ .select(a, b, ($"_common_expr_0" + $"_common_expr_0" +
b).as("_common_expr_1"))
.select(($"_common_expr_1" * $"_common_expr_1").as("col"))
.analyze
)
@@ -104,42 +101,60 @@ class RewriteWithExpressionSuite extends PlanTest {
val outerExpr = With(b + b) { case Seq(ref) =>
ref * ref + innerExpr
}
-
- val plan = testRelation.select(outerExpr.as("col"))
- val rewrittenInnerExpr = (a + a).as("_common_expr_0")
- val rewrittenOuterExpr = (b + b).as("_common_expr_1")
- val finalExpr = rewrittenOuterExpr.toAttribute *
rewrittenOuterExpr.toAttribute +
- (rewrittenInnerExpr.toAttribute + rewrittenInnerExpr.toAttribute)
+ val finalExpr = $"_common_expr_1" * $"_common_expr_1" + ($"_common_expr_0"
+ $"_common_expr_0")
comparePlans(
- Optimizer.execute(plan),
+ Optimizer.execute(testRelation.select(outerExpr.as("col"))),
testRelation
- .select((testRelation.output :+ rewrittenInnerExpr): _*)
- .select((testRelation.output :+ rewrittenInnerExpr.toAttribute :+
rewrittenOuterExpr): _*)
+ .select(star(), (b + b).as("_common_expr_1"))
+ .select(star(), (a + a).as("_common_expr_0"))
.select(finalExpr.as("col"))
.analyze
)
}
- test("correlated nested WITH expression is not supported") {
+ test("correlated nested WITH expression is supported") {
val Seq(a, b) = testRelation.output
val outerCommonExprDef = CommonExpressionDef(b + b, CommonExpressionId(0))
val outerRef = new CommonExpressionRef(outerCommonExprDef)
+ val rewrittenOuterExpr = (b + b).as("_common_expr_0")
// The inner expression definition references the outer expression
val commonExprDef1 = CommonExpressionDef(a + a + outerRef,
CommonExpressionId(1))
val ref1 = new CommonExpressionRef(commonExprDef1)
val innerExpr1 = With(ref1 + ref1, Seq(commonExprDef1))
-
val outerExpr1 = With(outerRef + innerExpr1, Seq(outerCommonExprDef))
-
intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr1.as("col"))))
+ comparePlans(
+ Optimizer.execute(testRelation.select(outerExpr1.as("col"))),
+ testRelation
+ // The first Project contains the common expression of the outer With
+ .select(star(), rewrittenOuterExpr)
+ // The second Project contains the common expression of the inner
With, which references
+ // the common expression of the outer With.
+ .select(star(), (a + a + $"_common_expr_0").as("_common_expr_1"))
+ // The final Project contains the final result expression, which
references both common
+ // expressions.
+ .select(($"_common_expr_0" + ($"_common_expr_1" +
$"_common_expr_1")).as("col"))
+ .analyze
+ )
- val commonExprDef2 = CommonExpressionDef(a + a)
+ val commonExprDef2 = CommonExpressionDef(a + a, CommonExpressionId(2))
val ref2 = new CommonExpressionRef(commonExprDef2)
// The inner main expression references the outer expression
- val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef1))
-
+ val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef2))
val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
-
intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr2.as("col"))))
+ comparePlans(
+ Optimizer.execute(testRelation.select(outerExpr2.as("col"))),
+ testRelation
+ // The first Project contains the common expression of the outer With
+ .select(star(), rewrittenOuterExpr)
+ // The second Project contains the common expression of the inner
With, which does not
+ // reference the common expression of the outer With.
+ .select(star(), (a + a).as("_common_expr_2"))
+ // The final Project contains the final result expression, which
references both common
+ // expressions.
+ .select(($"_common_expr_0" + ($"_common_expr_2" +
$"_common_expr_0")).as("col"))
+ .analyze
+ )
}
test("WITH expression in filter") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]