This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 3e01d410b4f9 [SPARK-50091][SQL][3.5] Handle case of aggregates in 
left-hand operand of IN-subquery
3e01d410b4f9 is described below

commit 3e01d410b4f9b7bb98ad6f209b8311d4fd154164
Author: Bruce Robbins <[email protected]>
AuthorDate: Fri Jan 24 21:20:45 2025 -0800

    [SPARK-50091][SQL][3.5] Handle case of aggregates in left-hand operand of 
IN-subquery
    
    ### What changes were proposed in this pull request?
    
    This is a back-port of #48627.
    
    This PR adds code to `RewritePredicateSubquery#apply` to explicitly handle 
the case where an `Aggregate` node contains an aggregate expression in the 
left-hand operand of an IN-subquery expression. The explicit handler moves the 
IN-subquery expressions out of the `Aggregate` and into a parent `Project` 
node. The `Aggregate` will continue to perform the aggregations that were used 
as an operand to the IN-subquery expression, but will not include the 
IN-subquery expression itself. After [...]
    ```
    Project [col1#32, exists#42 AS (sum(col2) IN (listquery()))#40]
    +- Join ExistenceJoin(exists#42), (sum(col2)#41L = c2#39L)
       :- Aggregate [col1#32], [col1#32, sum(col2#33) AS sum(col2)#41L]
       :  +- LocalRelation [col1#32, col2#33]
       +- LocalRelation [c2#39L]
    ```
    `sum(col2)#41L` in the above join condition, despite how it looks, is the 
name of the attribute, not an aggregate expression.
    
    ### Why are the changes needed?
    
    The following query fails:
    ```
    create or replace temp view v1(c1, c2) as values (1, 2), (1, 3), (2, 2), 
(3, 7), (3, 1);
    create or replace temp view v2(col1, col2) as values (1, 2), (1, 3), (2, 
2), (3, 7), (3, 1);
    
    select col1, sum(col2) in (select c2 from v1)
    from v2 group by col1;
    ```
    It fails with this error:
    ```
    [INTERNAL_ERROR] Cannot generate code for expression: sum(input[1, int, 
false]) SQLSTATE: XX000
    ```
    With SPARK_TESTING=1, it fails with this error:
    ```
    [PLAN_VALIDATION_FAILED_RULE_IN_BATCH] Rule 
org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery in batch 
RewriteSubquery generated an invalid plan: Special expressions are placed in 
the wrong plan:
    Aggregate [col1#11], [col1#11, first(exists#20, false) AS (sum(col2) IN (listquery()))#19]
    +- Join ExistenceJoin(exists#20), (sum(col2#12) = c2#18L)
       :- LocalRelation [col1#11, col2#12]
       +- LocalRelation [c2#18L]
    ```
    The issue is that `RewritePredicateSubquery` builds a `Join` operator where 
the join condition contains an aggregate expression.
    
    The bug is in the handler for `UnaryNode` in 
`RewritePredicateSubquery#apply`, which adds a `Join` below the `Aggregate` and 
assumes that the left-hand operand of IN-subquery can be used in the join 
condition. This works fine for most cases, but not when the left-hand operand 
is an aggregate expression.
    
    This PR moves the offending IN-subqueries to a `Project` node, with the 
aggregates replaced by attributes referring to the aggregate expressions. The 
resulting join condition now uses those attributes rather than the actual 
aggregate expressions.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, other than allowing this type of query to succeed.
    
    ### How was this patch tested?
    
    New unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49663 from bersprockets/aggregate_in_set_issue_br35.
    
    Authored-by: Bruce Robbins <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../spark/sql/catalyst/optimizer/subquery.scala    | 96 ++++++++++++++++++++--
 .../catalyst/optimizer/RewriteSubquerySuite.scala  | 19 ++++-
 .../scala/org/apache/spark/sql/SubquerySuite.scala | 30 +++++++
 3 files changed, 136 insertions(+), 9 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index ee2005315781..0652ee221c35 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -26,6 +26,7 @@ import 
org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import 
org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery
+import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -100,6 +101,25 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] 
with PredicateHelper {
     }
   }
 
+  def exprsContainsAggregateInSubquery(exprs: Seq[Expression]): Boolean = {
+    exprs.exists { expr =>
+      exprContainsAggregateInSubquery(expr)
+    }
+  }
+
+  def exprContainsAggregateInSubquery(expr: Expression): Boolean = {
+    expr.exists {
+      case InSubquery(values, _) =>
+        values.exists { v =>
+          v.exists {
+            case _: AggregateExpression => true
+            case _ => false
+          }
+        }
+      case _ => false;
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
     _.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY)) {
     case Filter(condition, child)
@@ -162,15 +182,75 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] 
with PredicateHelper {
           Project(p.output, Filter(newCond.get, inputPlan))
       }
 
+    // Handle the case where the left-hand side of an IN-subquery contains an 
aggregate.
+    //
+    // If an Aggregate node contains such an IN-subquery, this handler will 
pull up all
+    // expressions from the Aggregate node into a new Project node. The new 
Project node
+    // will then be handled by the Unary node handler.
+    //
+    // The Unary node handler uses the left-hand side of the IN-subquery in a
+    // join condition. Thus, without this pre-transformation, the join 
condition
+    // contains an aggregate, which is illegal. With this pre-transformation, 
the
+    // join condition contains an attribute from the left-hand side of the
+    // IN-subquery contained in the Project node.
+    //
+    // For example:
+    //
+    //   SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
+    //   FROM v2;
+    //
+    // The above query has this plan on entry to 
RewritePredicateSubquery#apply:
+    //
+    //   Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS 
x#13]
+    //   :  +- LocalRelation [c3#28L]
+    //   +- LocalRelation [col2#18, col3#19]
+    //
+    // Note that the Aggregate node contains the IN-subquery and the left-hand
+    // side of the IN-subquery is an aggregate expression sum(col2#18)).
+    //
+    // This handler transforms the above plan into the following:
+    // scalastyle:off line.size.limit
+    //
+    //   Project [(_aggregateexpression#20L IN (list#12 []) AND 
(_aggregateexpression#21L > -1)) AS x#13]
+    //   :  +- LocalRelation [c3#28L]
+    //   +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19) 
AS _aggregateexpression#21L]
+    //      +- LocalRelation [col2#18, col3#19]
+    //
+    // scalastyle:on
+    // Note that both the IN-subquery and the greater-than expressions have 
been
+    // pulled up into the Project node. These expressions use attributes
+    // (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the 
aggregations
+    // which are still performed in the Aggregate node (sum(col2#18) and 
sum(col3#19)).
+    case p @ PhysicalAggregation(
+        groupingExpressions, aggregateExpressions, resultExpressions, child)
+        if exprsContainsAggregateInSubquery(p.expressions) =>
+      val aggExprs = aggregateExpressions.map(
+        ae => Alias(ae, "_aggregateexpression")(ae.resultId))
+      val aggExprIds = aggExprs.map(_.exprId).toSet
+      val resExprs = resultExpressions.map(_.transform {
+        case a: AttributeReference if aggExprIds.contains(a.exprId) =>
+          a.withName("_aggregateexpression")
+      }.asInstanceOf[NamedExpression])
+      // Rewrite the projection and the aggregate separately and then piece 
them together.
+      val newAgg = Aggregate(groupingExpressions, groupingExpressions ++ 
aggExprs, child)
+      val newProj = Project(resExprs, newAgg)
+      handleUnaryNode(newProj)
+
     case u: UnaryNode if u.expressions.exists(
-        SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
-      var newChild = u.child
-      u.mapExpressions(expr => {
-        val (newExpr, p) = rewriteExistentialExpr(Seq(expr), newChild)
-        newChild = p
-        // The newExpr can not be None
-        newExpr.get
-      }).withNewChildren(Seq(newChild))
+        SubqueryExpression.hasInOrCorrelatedExistsSubquery) => 
handleUnaryNode(u)
+  }
+
+  /**
+   * Handle the unary node case
+   */
+  private def handleUnaryNode(u: UnaryNode): LogicalPlan = {
+    var newChild = u.child
+    u.mapExpressions(expr => {
+      val (newExpr, p) = rewriteExistentialExpr(Seq(expr), newChild)
+      newChild = p
+      // The newExpr can not be None
+      newExpr.get
+    }).withNewChildren(Seq(newChild))
   }
 
   /**
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
index 17547bbcb940..c45a761353c8 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.QueryPlanningTracker
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not}
+import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not}
 import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.LongType
 
 
 class RewriteSubquerySuite extends PlanTest {
@@ -79,4 +80,20 @@ class RewriteSubquerySuite extends PlanTest {
     Optimize.executeAndTrack(query.analyze, tracker)
     
assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations 
== 0)
   }
+
+  test("SPARK-50091: Don't put aggregate expression in join condition") {
+    val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int)
+    val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int)
+    val plan = 
relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3"))))
+    val optimized = Optimize.execute(plan.analyze)
+    val aggregate = relation2
+      .select($"col2")
+      .groupBy()(sum($"col2").as("_aggregateexpression"))
+    val correctAnswer = aggregate
+      .join(relation1.select(Cast($"c3", LongType).as("c3")),
+        ExistenceJoin($"exists".boolean.withNullability(false)),
+        Some($"_aggregateexpression" === $"c3"))
+      .select($"exists".as("(sum(col2) IN (listquery()))")).analyze
+    comparePlans(optimized, correctAnswer)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 260c992f1aed..04702201f82f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -2800,4 +2800,34 @@ class SubquerySuite extends QueryTest
       checkAnswer(df3, Row(7))
     }
   }
+
+  test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") {
+    withView("v1", "v2") {
+      Seq((1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8))
+        .toDF("c1", "c2", "c3")
+        .createOrReplaceTempView("v1")
+      Seq((1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1))
+        .toDF("col1", "col2", "col3")
+        .createOrReplaceTempView("v2")
+
+      val df1 = sql("SELECT col1, SUM(col2) IN (SELECT c3 FROM v1) FROM v2 
GROUP BY col1")
+      checkAnswer(df1,
+        Row(1, false) :: Row(2, true) :: Row(3, true) :: Nil)
+
+      val df2 = sql("""SELECT
+                      |  col1,
+                      |  SUM(col2) IN (SELECT c3 FROM v1) and SUM(col3) IN 
(SELECT c2 FROM v1) AS x
+                      |FROM v2 GROUP BY col1
+                      |ORDER BY col1""".stripMargin)
+      checkAnswer(df2,
+        Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)
+
+      val df3 = sql("""SELECT col1, (SUM(col2), SUM(col3)) IN (SELECT c3, c2 
FROM v1) AS x
+                      |FROM v2
+                      |GROUP BY col1
+                      |ORDER BY col1""".stripMargin)
+      checkAnswer(df3,
+        Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to