This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7ef0440ef221 [SPARK-48146][SQL] Fix aggregate function in With 
expression child assertion
7ef0440ef221 is described below

commit 7ef0440ef22161a6160f7b9000c70b26c84eecf7
Author: Kelvin Jiang <kelvin.ji...@databricks.com>
AuthorDate: Fri May 10 22:39:15 2024 +0800

    [SPARK-48146][SQL] Fix aggregate function in With expression child assertion
    
    ### What changes were proposed in this pull request?
    
    In https://github.com/apache/spark/pull/46034, there was a complicated edge 
case where common expression references in aggregate functions in the child of 
a `With` expression could become dangling. An assertion was added to prevent 
that case, but it was too broad: it rejected any aggregate function in the child 
of a `With` expression, so a query like:
    ```
    select
      id between max(if(id between 1 and 2, 2, 1)) over () and id
    from range(10)
    ```
    would fail the assertion.
    
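    This PR narrows the assertion so that it only fails when an aggregate 
expression in the child of a `With` expression references a common expression 
defined by that same `With` expression. References to common expressions of a 
nested `With` expression inside the aggregate are allowed.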
    
    ### Why are the changes needed?
    
    This fixes a regression introduced by https://github.com/apache/spark/pull/46034.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46443 from kelvinjian-db/SPARK-48146-agg.
    
    Authored-by: Kelvin Jiang <kelvin.ji...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/expressions/With.scala      | 26 +++++++++++++++++----
 .../optimizer/RewriteWithExpressionSuite.scala     | 27 +++++++++++++++++++++-
 2 files changed, 48 insertions(+), 5 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala
index 14deedd9c70f..29794b33641c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, 
COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, 
TreePattern, WITH_EXPRESSION}
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -27,9 +28,11 @@ import org.apache.spark.sql.types.DataType
  */
 case class With(child: Expression, defs: Seq[CommonExpressionDef])
   extends Expression with Unevaluable {
-  // We do not allow With to be created with an AggregateExpression in the 
child, as this would
-  // create a dangling CommonExpressionRef after rewriting it in 
RewriteWithExpression.
-  assert(!child.containsPattern(AGGREGATE_EXPRESSION))
+  // We do not allow creating a With expression with an AggregateExpression 
that contains a
+  // reference to a common expression defined in that scope (note that it can 
contain another With
+  // expression with a common expression ref of the inner With). This is to 
prevent the creation of
+  // a dangling CommonExpressionRef after rewriting it in 
RewriteWithExpression.
+  assert(!With.childContainsUnsupportedAggExpr(this))
 
   override val nodePatterns: Seq[TreePattern] = Seq(WITH_EXPRESSION)
   override def dataType: DataType = child.dataType
@@ -92,6 +95,21 @@ object With {
     val commonExprRefs = commonExprDefs.map(new CommonExpressionRef(_))
     With(replaced(commonExprRefs), commonExprDefs)
   }
+
+  private[sql] def childContainsUnsupportedAggExpr(withExpr: With): Boolean = {
+    lazy val commonExprIds = withExpr.defs.map(_.id).toSet
+    withExpr.child.exists {
+      case agg: AggregateExpression =>
+        // Check that the aggregate expression does not contain a reference to 
a common expression
+        // in the outer With expression (it is ok if it contains a reference 
to a common expression
+        // for a nested With expression).
+        agg.exists {
+          case r: CommonExpressionRef => commonExprIds.contains(r.id)
+          case _ => false
+        }
+      case _ => false
+    }
+  }
 }
 
 case class CommonExpressionId(id: Long = CommonExpressionId.newId, 
canonicalized: Boolean = false) {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
index d482b18d9331..8f023fa4156b 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
@@ -353,7 +353,7 @@ class RewriteWithExpressionSuite extends PlanTest {
     )
   }
 
-  test("aggregate functions in child of WITH expression is not supported") {
+  test("aggregate functions in child of WITH expression with ref is not 
supported") {
     val a = testRelation.output.head
     intercept[java.lang.AssertionError] {
       val expr = With(a - 1) { case Seq(ref) =>
@@ -366,4 +366,29 @@ class RewriteWithExpressionSuite extends PlanTest {
       Optimizer.execute(plan)
     }
   }
+
+  test("WITH expression nested in aggregate function") {
+    val a = testRelation.output.head
+    val expr = With(a + 1) { case Seq(ref) =>
+      ref * ref
+    }
+    val nestedExpr = With(a - 1) { case Seq(ref) =>
+      ref * max(expr) + ref
+    }
+    val plan = testRelation.groupBy(a)(nestedExpr.as("col")).analyze
+    val commonExpr1Id = expr.defs.head.id.id
+    val commonExpr1Name = s"_common_expr_$commonExpr1Id"
+    val commonExpr2Id = nestedExpr.defs.head.id.id
+    val commonExpr2Name = s"_common_expr_$commonExpr2Id"
+    val aggExprName = "_aggregateexpression"
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .groupBy(a)(a, max($"$commonExpr1Name" * 
$"$commonExpr1Name").as(aggExprName))
+        .select($"a", $"$aggExprName", (a - 1).as(commonExpr2Name))
+        .select(($"$commonExpr2Name" * $"$aggExprName" + 
$"$commonExpr2Name").as("col"))
+        .analyze
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to