This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a6cda2302c29 [SPARK-45760][SQL][FOLLOWUP] Inline With inside conditional branches
a6cda2302c29 is described below

commit a6cda2302c2962072af104c5d012329b06cbf166
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Tue Nov 28 12:53:13 2023 +0100

    [SPARK-45760][SQL][FOLLOWUP] Inline With inside conditional branches
    
    ### What changes were proposed in this pull request?
    
    This is a followup of https://github.com/apache/spark/pull/43623 to fix a
    regression. A `With` expression inside a conditional branch may not be
    evaluated at all, so we should not pull its common expressions out into a
    `Project` (which is always evaluated); instead, we simply inline them.
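    
    For illustration, a minimal sketch of the intended behavior, mirroring the new
    test added in `RewriteWithExpressionSuite` (identifiers come from that test and
    are illustrative only):
    
    ```scala
    // Sketch only: a `With` in a branch that may be skipped gets inlined.
    val a = testRelation.output.head
    val commonExprDef = CommonExpressionDef(a + a)
    val ref = new CommonExpressionRef(commonExprDef)
    
    // The second Coalesce argument is evaluated only when `a` is null, so the
    // rule inlines the common expression instead of adding a Project:
    //   Coalesce(Seq(a, With(ref * ref, Seq(commonExprDef))))
    //     ==>  Coalesce(Seq(a, (a + a) * (a + a)))
    val expr = Coalesce(Seq(a, With(ref * ref, Seq(commonExprDef))))
    ```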
    
    ### Why are the changes needed?
    
    Avoid a performance regression: pulling the common expressions out of a
    conditional branch into a `Project` forces them to be evaluated even when
    the branch is never taken.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    A new test in `RewriteWithExpressionSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43978 from cloud-fan/with.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/expressions/Expression.scala      |   5 +
 .../expressions/conditionalExpressions.scala       |  19 +++-
 .../sql/catalyst/expressions/nullExpressions.scala |   8 ++
 .../catalyst/optimizer/RewriteWithExpression.scala | 119 ++++++++++++++-------
 .../optimizer/RewriteWithExpressionSuite.scala     |  79 +++++++++++++-
 5 files changed, 185 insertions(+), 45 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 0dc70c6c3947..2cc813bd3055 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -513,6 +513,11 @@ trait ConditionalExpression extends Expression {
    */
   def alwaysEvaluatedInputs: Seq[Expression]
 
+  /**
+   * Return a copy of itself with a new `alwaysEvaluatedInputs`.
+   */
+  def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): ConditionalExpression
+
   /**
   * Return groups of branches. For each group, at least one branch will be hit at runtime,
    * so that we can eagerly evaluate the common expressions of a group.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 28a7db51621f..9ee2f2bb4141 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -56,6 +56,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
    */
   override def alwaysEvaluatedInputs: Seq[Expression] = predicate :: Nil
 
+  override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): If = {
+    copy(predicate = alwaysEvaluatedInputs.head)
+  }
+
   override def branchGroups: Seq[Seq[Expression]] = Seq(Seq(trueValue, falseValue))
 
   final override val nodePatterns : Seq[TreePattern] = Seq(IF)
@@ -165,8 +169,15 @@ case class CaseWhen(
 
   final override val nodePatterns : Seq[TreePattern] = Seq(CASE_WHEN)
 
-  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
-    super.legacyWithNewChildren(newChildren)
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): CaseWhen = {
+    if (newChildren.length % 2 == 0) {
+      copy(branches = newChildren.grouped(2).map { case Seq(a, b) => (a, b) }.toSeq)
+    } else {
+      copy(
+        branches = newChildren.dropRight(1).grouped(2).map { case Seq(a, b) => (a, b) }.toSeq,
+        elseValue = newChildren.lastOption)
+    }
+  }
 
   // both then and else expressions should be considered.
   @transient
@@ -213,6 +224,10 @@ case class CaseWhen(
    */
   override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil
 
+  override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): CaseWhen = {
+    withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1))
+  }
+
   override def branchGroups: Seq[Seq[Expression]] = {
     // We look at subexpressions in conditions and values of `CaseWhen` separately. It is
     // because a subexpression in conditions will be run no matter which condition is matched
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 0e9e375b8acf..4ccb369f5e2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -70,6 +70,10 @@ case class Coalesce(children: Seq[Expression])
    */
   override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil
 
+  override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): Coalesce = {
+    withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1))
+  }
+
   override def branchGroups: Seq[Seq[Expression]] = if (children.length > 1) {
     // If there is only one child, the first child is already covered by
     // `alwaysEvaluatedInputs` and we should exclude it here.
@@ -290,6 +294,10 @@ case class NaNvl(left: Expression, right: Expression)
    */
   override def alwaysEvaluatedInputs: Seq[Expression] = left :: Nil
 
+  override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): NaNvl = {
+    copy(left = alwaysEvaluatedInputs.head)
+  }
+
   override def branchGroups: Seq[Seq[Expression]] = Seq(children)
 
   override def eval(input: InternalRow): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
index c5bd71b4a7d1..cf2c77069a19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, CommonExpressionDef, CommonExpressionRef, Expression, With}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
@@ -35,56 +36,92 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) {
       case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
-        var newChildren = p.children
-        var newPlan: LogicalPlan = p.transformExpressionsUp {
-          case With(child, defs) =>
-            val refToExpr = mutable.HashMap.empty[Long, Expression]
-            val childProjections = Array.fill(newChildren.size)(mutable.ArrayBuffer.empty[Alias])
+        val inputPlans = p.children.toArray
+        var newPlan: LogicalPlan = p.mapExpressions { expr =>
+          rewriteWithExprAndInputPlans(expr, inputPlans)
+        }
+        newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)
+        if (p.output == newPlan.output) {
+          newPlan
+        } else {
+          Project(p.output, newPlan)
+        }
+    }
+  }
+
+  private def rewriteWithExprAndInputPlans(
+      e: Expression,
+      inputPlans: Array[LogicalPlan]): Expression = {
+    if (!e.containsPattern(WITH_EXPRESSION)) return e
+    e match {
+      case w: With =>
+        // Rewrite nested With expressions first
+        val child = rewriteWithExprAndInputPlans(w.child, inputPlans)
+        val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans))
+        val refToExpr = mutable.HashMap.empty[Long, Expression]
+        val childProjections = Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias])
+
+        defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
+          if (child.containsPattern(COMMON_EXPR_REF)) {
+            throw SparkException.internalError(
+              "Common expression definition cannot reference other Common 
expression definitions")
+          }
 
-            defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
-              if (CollapseProject.isCheap(child)) {
-                refToExpr(id) = child
-              } else {
-                val childProjectionIndex = newChildren.indexWhere(
-                  c => child.references.subsetOf(c.outputSet)
-                )
-                if (childProjectionIndex == -1) {
-                  // When we cannot rewrite the common expressions, force to inline them so that the
-                  // query can still run. This can happen if the join condition contains `With` and
-                  // the common expression references columns from both join sides.
-                  // TODO: things can go wrong if the common expression is nondeterministic. We
-                  //       don't fix it for now to match the old buggy behavior when certain
-                  //       `RuntimeReplaceable` did not use the `With` expression.
-                  // TODO: we should calculate the ref count and also inline the common expression
-                  //       if it's ref count is 1.
-                  refToExpr(id) = child
-                } else {
-                  val alias = Alias(child, s"_common_expr_$index")()
-                  childProjections(childProjectionIndex) += alias
-                  refToExpr(id) = alias.toAttribute
-                }
-              }
+          if (CollapseProject.isCheap(child)) {
+            refToExpr(id) = child
+          } else {
+            val childProjectionIndex = inputPlans.indexWhere(
+              c => child.references.subsetOf(c.outputSet)
+            )
+            if (childProjectionIndex == -1) {
+              // When we cannot rewrite the common expressions, force to inline them so that the
+              // query can still run. This can happen if the join condition contains `With` and
+              // the common expression references columns from both join sides.
+              // TODO: things can go wrong if the common expression is nondeterministic. We
+              //       don't fix it for now to match the old buggy behavior when certain
+              //       `RuntimeReplaceable` did not use the `With` expression.
+              // TODO: we should calculate the ref count and also inline the common expression
+              //       if it's ref count is 1.
+              refToExpr(id) = child
+            } else {
+              val alias = Alias(child, s"_common_expr_$index")()
+              childProjections(childProjectionIndex) += alias
+              refToExpr(id) = alias.toAttribute
             }
+          }
+        }
+
+        for (i <- inputPlans.indices) {
+          val projectList = childProjections(i)
+          if (projectList.nonEmpty) {
+            inputPlans(i) = Project(inputPlans(i).output ++ projectList, inputPlans(i))
+          }
+        }
 
-            newChildren = newChildren.zip(childProjections).map { case (child, projections) =>
-              if (projections.nonEmpty) {
-                Project(child.output ++ projections, child)
-              } else {
-                child
-              }
+        child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
+          case ref: CommonExpressionRef =>
+            if (!refToExpr.contains(ref.id)) {
+              throw SparkException.internalError("Undefined common expression id " + ref.id)
             }
+            refToExpr(ref.id)
+        }
 
+      case c: ConditionalExpression =>
+        val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
+          rewriteWithExprAndInputPlans(_, inputPlans))
+        val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
+        // Use transformUp to handle nested With.
+        newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
+          case With(child, defs) =>
+            // For With in the conditional branches, they may not be evaluated at all and we can't
+            // pull the common expressions into a project which will always be evaluated. Inline it.
+            val refToExpr = defs.map(d => d.id -> d.child).toMap
             child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
               case ref: CommonExpressionRef => refToExpr(ref.id)
             }
         }
 
-        newPlan = newPlan.withNewChildren(newChildren)
-        if (p.output == newPlan.output) {
-          newPlan
-        } else {
-          Project(p.output, newPlan)
-        }
+      case other => other.mapChildren(rewriteWithExprAndInputPlans(_, inputPlans))
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
index c625379eb5ff..a386e9bf4efe 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CommonExpressionDef, CommonExpressionRef, With}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Coalesce, CommonExpressionDef, CommonExpressionRef, With}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -57,7 +58,7 @@ class RewriteWithExpressionSuite extends PlanTest {
     )
   }
 
-  test("nested WITH expression") {
+  test("nested WITH expression in the definition expression") {
     val a = testRelation.output.head
     val commonExprDef = CommonExpressionDef(a + a)
     val ref = new CommonExpressionRef(commonExprDef)
@@ -85,6 +86,57 @@ class RewriteWithExpressionSuite extends PlanTest {
     )
   }
 
+  test("nested WITH expression in the main expression") {
+    val a = testRelation.output.head
+    val commonExprDef = CommonExpressionDef(a + a)
+    val ref = new CommonExpressionRef(commonExprDef)
+    val innerExpr = With(ref + ref, Seq(commonExprDef))
+    val innerCommonExprName = "_common_expr_0"
+
+    val b = testRelation.output.last
+    val outerCommonExprDef = CommonExpressionDef(b + b)
+    val outerRef = new CommonExpressionRef(outerCommonExprDef)
+    val outerExpr = With(outerRef * outerRef + innerExpr, Seq(outerCommonExprDef))
+    val outerCommonExprName = "_common_expr_0"
+
+    val plan = testRelation.select(outerExpr.as("col"))
+    val rewrittenInnerExpr = (a + a).as(innerCommonExprName)
+    val rewrittenOuterExpr = (b + b).as(outerCommonExprName)
+    val finalExpr = rewrittenOuterExpr.toAttribute * rewrittenOuterExpr.toAttribute +
+      (rewrittenInnerExpr.toAttribute + rewrittenInnerExpr.toAttribute)
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select((testRelation.output :+ rewrittenInnerExpr): _*)
+        .select((testRelation.output :+ rewrittenInnerExpr.toAttribute :+ rewrittenOuterExpr): _*)
+        .select(finalExpr.as("col"))
+        .analyze
+    )
+  }
+
+  test("correlated nested WITH expression is not supported") {
+    val b = testRelation.output.last
+    val outerCommonExprDef = CommonExpressionDef(b + b)
+    val outerRef = new CommonExpressionRef(outerCommonExprDef)
+
+    val a = testRelation.output.head
+    // The inner expression definition references the outer expression
+    val commonExprDef1 = CommonExpressionDef(a + a + outerRef)
+    val ref1 = new CommonExpressionRef(commonExprDef1)
+    val innerExpr1 = With(ref1 + ref1, Seq(commonExprDef1))
+
+    val outerExpr1 = With(outerRef + innerExpr1, Seq(outerCommonExprDef))
+    intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr1.as("col"))))
+
+    val commonExprDef2 = CommonExpressionDef(a + a)
+    val ref2 = new CommonExpressionRef(commonExprDef2)
+    // The inner main expression references the outer expression
+    val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef1))
+
+    val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
+    intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr2.as("col"))))
+  }
+
   test("WITH expression in filter") {
     val a = testRelation.output.head
     val commonExprDef = CommonExpressionDef(a + a)
@@ -154,4 +206,27 @@ class RewriteWithExpressionSuite extends PlanTest {
         )
     )
   }
+
+  test("WITH expression inside conditional expression") {
+    val a = testRelation.output.head
+    val commonExprDef = CommonExpressionDef(a + a)
+    val ref = new CommonExpressionRef(commonExprDef)
+    val expr = Coalesce(Seq(a, With(ref * ref, Seq(commonExprDef))))
+    val inlinedExpr = Coalesce(Seq(a, (a + a) * (a + a)))
+    val plan = testRelation.select(expr.as("col"))
+    // With in the conditional branches is always inlined.
+    comparePlans(Optimizer.execute(plan), testRelation.select(inlinedExpr.as("col")))
+
+    val expr2 = Coalesce(Seq(With(ref * ref, Seq(commonExprDef)), a))
+    val plan2 = testRelation.select(expr2.as("col"))
+    val commonExprName = "_common_expr_0"
+    // With in the always-evaluated branches can still be optimized.
+    comparePlans(
+      Optimizer.execute(plan2),
+      testRelation
+        .select((testRelation.output :+ (a + a).as(commonExprName)): _*)
+        .select(Coalesce(Seq(($"$commonExprName" * $"$commonExprName"), a)).as("col"))
+        .analyze
+    )
+  }
 }

