This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new 3e01d410b4f9 [SPARK-50091][SQL][3.5] Handle case of aggregates in
left-hand operand of IN-subquery
3e01d410b4f9 is described below
commit 3e01d410b4f9b7bb98ad6f209b8311d4fd154164
Author: Bruce Robbins <[email protected]>
AuthorDate: Fri Jan 24 21:20:45 2025 -0800
[SPARK-50091][SQL][3.5] Handle case of aggregates in left-hand operand of
IN-subquery
### What changes were proposed in this pull request?
This is a back-port of #48627.
This PR adds code to `RewritePredicateSubquery#apply` to explicitly handle
the case where an `Aggregate` node contains an aggregate expression in the
left-hand operand of an IN-subquery expression. The explicit handler moves the
IN-subquery expressions out of the `Aggregate` and into a parent `Project`
node. The `Aggregate` will continue to perform the aggregations that were used
as an operand to the IN-subquery expression, but will not include the
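IN-subquery expression itself. After [...] (the remainder of this paragraph was truncated in this notification). With this change, the query shown under "Why are the changes needed?" gets the following plan, in which the join condition refers to an attribute produced by the `Aggregate`: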
```
Project [col1#32, exists#42 AS (sum(col2) IN (listquery()))#40]
+- Join ExistenceJoin(exists#42), (sum(col2)#41L = c2#39L)
:- Aggregate [col1#32], [col1#32, sum(col2#33) AS sum(col2)#41L]
: +- LocalRelation [col1#32, col2#33]
+- LocalRelation [c2#39L]
```
`sum(col2)#41L` in the above join condition, despite how it looks, is the
name of the attribute, not an aggregate expression.
### Why are the changes needed?
The following query fails:
```
create or replace temp view v1(c1, c2) as values (1, 2), (1, 3), (2, 2),
(3, 7), (3, 1);
create or replace temp view v2(col1, col2) as values (1, 2), (1, 3), (2,
2), (3, 7), (3, 1);
select col1, sum(col2) in (select c2 from v1)
from v2 group by col1;
```
It fails with this error:
```
[INTERNAL_ERROR] Cannot generate code for expression: sum(input[1, int,
false]) SQLSTATE: XX000
```
With SPARK_TESTING=1, it fails with this error:
```
[PLAN_VALIDATION_FAILED_RULE_IN_BATCH] Rule
org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery in batch
RewriteSubquery generated an invalid plan: Special expressions are placed in
the wrong plan:
Aggregate [col1#11], [col1#11, first(exists#20, false) AS (sum(col2) IN (listquery()))#19]
+- Join ExistenceJoin(exists#20), (sum(col2#12) = c2#18L)
:- LocalRelation [col1#11, col2#12]
+- LocalRelation [c2#18L]
```
The issue is that `RewritePredicateSubquery` builds a `Join` operator where
the join condition contains an aggregate expression.
The bug is in the handler for `UnaryNode` in
`RewritePredicateSubquery#apply`, which adds a `Join` below the `Aggregate` and
assumes that the left-hand operand of an IN-subquery can be used in the join
condition. This works fine for most cases, but not when the left-hand operand
is an aggregate expression.
This PR moves the offending IN-subqueries to a `Project` node, with the
aggregates replaced by attributes referring to the aggregate expressions. The
resulting join condition now uses those attributes rather than the actual
aggregate expressions.
### Does this PR introduce _any_ user-facing change?
No, other than allowing this type of query to succeed.
### How was this patch tested?
New unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49663 from bersprockets/aggregate_in_set_issue_br35.
Authored-by: Bruce Robbins <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../spark/sql/catalyst/optimizer/subquery.scala | 96 ++++++++++++++++++++--
.../catalyst/optimizer/RewriteSubquerySuite.scala | 19 ++++-
.../scala/org/apache/spark/sql/SubquerySuite.scala | 30 +++++++
3 files changed, 136 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index ee2005315781..0652ee221c35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -26,6 +26,7 @@ import
org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import
org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery
+import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -100,6 +101,25 @@ object RewritePredicateSubquery extends Rule[LogicalPlan]
with PredicateHelper {
}
}
+ def exprsContainsAggregateInSubquery(exprs: Seq[Expression]): Boolean = {
+ exprs.exists { expr =>
+ exprContainsAggregateInSubquery(expr)
+ }
+ }
+
+ def exprContainsAggregateInSubquery(expr: Expression): Boolean = {
+ expr.exists {
+ case InSubquery(values, _) =>
+ values.exists { v =>
+ v.exists {
+ case _: AggregateExpression => true
+ case _ => false
+ }
+ }
+ case _ => false;
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY)) {
case Filter(condition, child)
@@ -162,15 +182,75 @@ object RewritePredicateSubquery extends Rule[LogicalPlan]
with PredicateHelper {
Project(p.output, Filter(newCond.get, inputPlan))
}
+ // Handle the case where the left-hand side of an IN-subquery contains an
aggregate.
+ //
+ // If an Aggregate node contains such an IN-subquery, this handler will
pull up all
+ // expressions from the Aggregate node into a new Project node. The new
Project node
+ // will then be handled by the Unary node handler.
+ //
+ // The Unary node handler uses the left-hand side of the IN-subquery in a
+ // join condition. Thus, without this pre-transformation, the join
condition
+ // contains an aggregate, which is illegal. With this pre-transformation,
the
+ // join condition contains an attribute from the left-hand side of the
+ // IN-subquery contained in the Project node.
+ //
+ // For example:
+ //
+ // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
+ // FROM v2;
+ //
+ // The above query has this plan on entry to
RewritePredicateSubquery#apply:
+ //
+ // Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS
x#13]
+ // : +- LocalRelation [c3#28L]
+ // +- LocalRelation [col2#18, col3#19]
+ //
+ // Note that the Aggregate node contains the IN-subquery and the left-hand
+ // side of the IN-subquery is an aggregate expression sum(col2#18)).
+ //
+ // This handler transforms the above plan into the following:
+ // scalastyle:off line.size.limit
+ //
+ // Project [(_aggregateexpression#20L IN (list#12 []) AND
(_aggregateexpression#21L > -1)) AS x#13]
+ // : +- LocalRelation [c3#28L]
+ // +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19)
AS _aggregateexpression#21L]
+ // +- LocalRelation [col2#18, col3#19]
+ //
+ // scalastyle:on
+ // Note that both the IN-subquery and the greater-than expressions have
been
+ // pulled up into the Project node. These expressions use attributes
+ // (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the
aggregations
+ // which are still performed in the Aggregate node (sum(col2#18) and
sum(col3#19)).
+ case p @ PhysicalAggregation(
+ groupingExpressions, aggregateExpressions, resultExpressions, child)
+ if exprsContainsAggregateInSubquery(p.expressions) =>
+ val aggExprs = aggregateExpressions.map(
+ ae => Alias(ae, "_aggregateexpression")(ae.resultId))
+ val aggExprIds = aggExprs.map(_.exprId).toSet
+ val resExprs = resultExpressions.map(_.transform {
+ case a: AttributeReference if aggExprIds.contains(a.exprId) =>
+ a.withName("_aggregateexpression")
+ }.asInstanceOf[NamedExpression])
+ // Rewrite the projection and the aggregate separately and then piece
them together.
+ val newAgg = Aggregate(groupingExpressions, groupingExpressions ++
aggExprs, child)
+ val newProj = Project(resExprs, newAgg)
+ handleUnaryNode(newProj)
+
case u: UnaryNode if u.expressions.exists(
- SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
- var newChild = u.child
- u.mapExpressions(expr => {
- val (newExpr, p) = rewriteExistentialExpr(Seq(expr), newChild)
- newChild = p
- // The newExpr can not be None
- newExpr.get
- }).withNewChildren(Seq(newChild))
+ SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
handleUnaryNode(u)
+ }
+
+ /**
+ * Handle the unary node case
+ */
+ private def handleUnaryNode(u: UnaryNode): LogicalPlan = {
+ var newChild = u.child
+ u.mapExpressions(expr => {
+ val (newExpr, p) = rewriteExistentialExpr(Seq(expr), newChild)
+ newChild = p
+ // The newExpr can not be None
+ newExpr.get
+ }).withNewChildren(Seq(newChild))
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
index 17547bbcb940..c45a761353c8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not}
+import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.LongType
class RewriteSubquerySuite extends PlanTest {
@@ -79,4 +80,20 @@ class RewriteSubquerySuite extends PlanTest {
Optimize.executeAndTrack(query.analyze, tracker)
assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations
== 0)
}
+
+ test("SPARK-50091: Don't put aggregate expression in join condition") {
+ val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int)
+ val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int)
+ val plan =
relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3"))))
+ val optimized = Optimize.execute(plan.analyze)
+ val aggregate = relation2
+ .select($"col2")
+ .groupBy()(sum($"col2").as("_aggregateexpression"))
+ val correctAnswer = aggregate
+ .join(relation1.select(Cast($"c3", LongType).as("c3")),
+ ExistenceJoin($"exists".boolean.withNullability(false)),
+ Some($"_aggregateexpression" === $"c3"))
+ .select($"exists".as("(sum(col2) IN (listquery()))")).analyze
+ comparePlans(optimized, correctAnswer)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 260c992f1aed..04702201f82f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -2800,4 +2800,34 @@ class SubquerySuite extends QueryTest
checkAnswer(df3, Row(7))
}
}
+
+ test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") {
+ withView("v1", "v2") {
+ Seq((1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8))
+ .toDF("c1", "c2", "c3")
+ .createOrReplaceTempView("v1")
+ Seq((1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1))
+ .toDF("col1", "col2", "col3")
+ .createOrReplaceTempView("v2")
+
+ val df1 = sql("SELECT col1, SUM(col2) IN (SELECT c3 FROM v1) FROM v2
GROUP BY col1")
+ checkAnswer(df1,
+ Row(1, false) :: Row(2, true) :: Row(3, true) :: Nil)
+
+ val df2 = sql("""SELECT
+ | col1,
+ | SUM(col2) IN (SELECT c3 FROM v1) and SUM(col3) IN
(SELECT c2 FROM v1) AS x
+ |FROM v2 GROUP BY col1
+ |ORDER BY col1""".stripMargin)
+ checkAnswer(df2,
+ Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)
+
+ val df3 = sql("""SELECT col1, (SUM(col2), SUM(col3)) IN (SELECT c3, c2
FROM v1) AS x
+ |FROM v2
+ |GROUP BY col1
+ |ORDER BY col1""".stripMargin)
+ checkAnswer(df3,
+ Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]