cloud-fan commented on code in PR #48627:
URL: https://github.com/apache/spark/pull/48627#discussion_r1901465121
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala:
##########
@@ -246,6 +267,197 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}
+ // Handle the case where the left-hand side of an IN-subquery contains an aggregate.
+ //
+ // This handler pulls up any expression containing such an IN-subquery into a new Project
+ // node, replacing aggregate expressions with attributes, and then re-enters
+ // RewritePredicateSubquery#apply, where the new Project node will be handled
+ // by the Unary node handler.
+ //
+ // The Unary node handler uses the left-hand side of the IN-subquery in a
+ // join condition. Thus, without this pre-transformation, the join condition
+ // would contain an aggregate, which is illegal. With this pre-transformation, the
+ // join condition contains an attribute from the left-hand side of the
+ // IN-subquery contained in the Project node.
+ //
+ // For example:
+ //
+ // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) as x
+ // FROM v2 GROUP BY col1;
+ //
+ // The above query has this plan on entry to RewritePredicateSubquery#apply:
+ //
+ // Aggregate [col1#28], [col1#28, sum(col2#29) IN (list#24 []) AS x#25]
+ // : +- LocalRelation [c2#35L]
+ // +- LocalRelation [col1#28, col2#29]
+ //
+ // Note that the Aggregate node contains the IN-subquery and the left-hand
+ // side of the IN-subquery is an aggregate expression (sum(col2#29)).
+ //
+ // This handler transforms the above plan into the following:
+ //
+ // Project [col1#28, sum(col2)#36L IN (list#24 []) AS x#25]
+ // : +- LocalRelation [c2#35L]
+ // +- Aggregate [col1#28], [col1#28, sum(col2#29) AS sum(col2)#36L]
+ //    +- LocalRelation [col1#28, col2#29]
+ //
+ // The transformation pulled the IN-subquery up into a Project. The left-hand side of the
+ // IN-subquery is now an attribute (sum(col2)#36L) that refers to the actual aggregation,
+ // which is still performed in the Aggregate node (sum(col2#29) AS sum(col2)#36L). The Unary
+ // node handler will use that attribute in the join condition (rather than the aggregate
+ // expression).
+ //
+ // If the IN-subquery is nested in a larger expression, that entire larger
+ // expression is pulled up into the Project. For example:
+ //
+ // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
+ // FROM v2;
+ //
+ // The input to RewritePredicateSubquery#apply is the following plan:
+ //
+ // Aggregate [(sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29]
+ // : +- LocalRelation [c3#44L]
+ // +- LocalRelation [col2#34, col3#35]
+ //
+ // This handler transforms the plan into:
+ //
+ // Project [(sum(col2)#45L IN (list#28 []) AND (sum(col3)#46L > -1)) AS x#29]
+ // : +- LocalRelation [c3#44L]
+ // +- Aggregate [sum(col2#34) AS sum(col2)#45L, sum(col3#35) AS sum(col3)#46L]
+ //    +- LocalRelation [col2#34, col3#35]
+ //
+ // Note that the entire AND expression was pulled up into the Project, while the Aggregate
+ // node continues to perform the aggregations (now without the IN-subquery expression).
+ case a: Aggregate if exprsContainsAggregateInSubquery(a.aggregateExpressions) =>
+ // Find any interesting expressions from Aggregate.aggregateExpressions.
+ //
+ // An interesting expression is one that contains an IN-subquery whose left-hand
+ // operand contains aggregates. For example:
+ //
+ // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1)
+ // FROM v2 GROUP BY col1;
+ //
+ // withInSubquery will be a List containing a single Alias expression:
+ //
+ // List(sum(col2#12) IN (list#8 []) AS (...)#19)
+ val withInSubquery = a.aggregateExpressions.filter(exprContainsAggregateInSubquery(_))
+
+ // Extract the aggregate expressions from each interesting expression. This will include
+ // any aggregate expressions that were not part of the IN-subquery but were part
+ // of the larger containing expression.
+ val inSubqueryMapping = withInSubquery.map { e =>
+ (e, extractAggregateExpressions(e))
+ }
+
+ // Map each interesting expression to its contained aggregate expressions.
+ //
+ // Example #1:
+ //
+ // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1)
+ // FROM v2 GROUP BY col1;
+ //
+ // inSubqueryMap will have a single entry mapping an Alias expression to a Vector
+ // with a single aggregate expression:
+ //
+ // Map(
+ // sum(col2#100) IN (list []) AS (...)#107 -> Vector(sum(col2#100))
+ // )
+ //
+ // Example #2:
+ //
+ // SELECT (SUM(col1), SUM(col2)) IN (SELECT c1, c2 FROM v1)
+ // FROM v2;
+ //
+ // inSubqueryMap will have a single entry mapping an Alias expression to a Vector
+ // with two aggregate expressions:
+ //
+ // Map(
+ // named_struct(_0, sum(col1#169), _1, sum(col2#170)) IN (list#166 []) AS (...)#179
+ // -> Vector(sum(col1#169), sum(col2#170))
+ // )
+ //
+ // Example #3:
+ //
+ // SELECT SUM(col1) IN (SELECT c1 FROM v1), SUM(col2) IN (SELECT c2 FROM v1)
+ // FROM v2;
+ //
+ // inSubqueryMap will have two entries, each mapping an Alias expression to a Vector
+ // with a single aggregate expression:
+ //
+ // Map(
+ // sum(col1#193) IN (list#189 []) AS (...)#207 -> Vector(sum(col1#193)),
+ // sum(col2#194) IN (list#190 []) AS (...)#208 -> Vector(sum(col2#194))
+ // )
+ //
+ // Example #4:
+ //
+ // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
+ // FROM v2;
+ //
+ // inSubqueryMap will have a single entry mapping an aliased AND expression to two
+ // aggregate expressions, even though only one of those aggregate expressions is used as
+ // the left-hand operand of the IN-subquery expression.
+ //
+ // Map(
+ // (sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29
+ // -> Vector(sum(col2#34), sum(col3#35))
+ // )
+ //
+ // The keys of inSubqueryMap will be used to determine which expressions in
+ // the old Aggregate node are interesting. The values of inSubqueryMap, after
+ // being wrapped in Alias expressions, will replace their associated interesting
+ // expressions in a new Aggregate node.
+ val inSubqueryMap = inSubqueryMapping.toMap
+
+ // Get all aggregate expressions associated with interesting expressions.
+ val aggregateExprs = inSubqueryMapping.flatMap(_._2)
+ // Create an alias for each of the above aggregate expressions. We can't use the aggregate
+ // expressions directly in the new Aggregate node because Aggregate.aggregateExpressions
+ // has the type Seq[NamedExpression].
+ val aggregateExprAliases = aggregateExprs.map(a => Alias(a, toPrettySQL(a))())
+ // Create a mapping from each aggregate expression to its alias.
+ val aggregateExprAliasMap = aggregateExprs.zip(aggregateExprAliases).toMap
+ // Create attributes from those aliases of aggregate expressions. These attributes
+ // will be used in the new Project node to refer to the aliased aggregate expressions
+ // in the new Aggregate node.
+ val aggregateExprAttrs = aggregateExprAliases.map(_.toAttribute)
+ // Create a mapping from aggregate expressions to attributes. This will be
+ // used when patching the interesting expressions after they are pulled up
+ // into the new Project node: aggregate expressions will be replaced by attributes.
+ val aggregateExprAttrMap = aggregateExprs.zip(aggregateExprAttrs).toMap
+
+ // Create an Aggregate node without the interesting expressions, containing only
+ // the associated aggregate expressions plus any other group-by or aggregate expressions
+ // that were not involved in the interesting expressions.
+ val newAggregateExpressions = a.aggregateExpressions.flatMap {
+ // If this expression contains IN-subqueries with aggregates in the left-hand
+ // operand, replace it with just the aggregates.
+ case ae: Expression if inSubqueryMap.contains(ae) =>
+ // Replace the expression with an aliased aggregate expression.
+ inSubqueryMap(ae).map(aggregateExprAliasMap(_))
+ case ae => Seq(ae)
+ }
+ val newAggregate = a.copy(aggregateExpressions = newAggregateExpressions)
+
+ // Create a projection with the IN-subquery expressions that contain aggregates, replacing
+ // the aggregate expressions with attribute references to the output of the new Aggregate
+ // operator. Also include the other output of the Aggregate operator.
+ val projList = a.aggregateExpressions.map {
+ // If this expression contains an IN-subquery that uses an aggregate, replace each
+ // aggregate expression with the attribute produced by the new Aggregate node.
+ case ae: Expression if inSubqueryMap.contains(ae) =>
+ ae.transform {
+ // Patch any aggregate expression with its corresponding attribute.
+ case a: AggregateExpression => aggregateExprAttrMap(a)
+ }.asInstanceOf[NamedExpression]
+ case ae => ae.toAttribute
+ }
+ val newProj = Project(projList, newAggregate)
+
+ // Reapply this rule, but now with all interesting expressions
+ // from Aggregate.aggregateExpressions pulled up into a Project node.
+ apply(newProj)
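To make the data flow of the handler above easier to follow, here is a self-contained Scala sketch of the same pull-up on a toy model. It is only a sketch under stated assumptions: the case classes (`Col`, `Attr`, `Sum`, `InSub`, `Alias`, `Aggregate`, `Project`) and the functions `containsAggregateInSubquery`, `collectAggregates`, and `pullUpInSubqueries` are hypothetical stand-ins, not Catalyst's API. The sketch stops after the pull-up; it does not model expression IDs or the later join rewrite done by the Unary node handler.

```scala
// Hypothetical toy model: these case classes only mimic the shape of Catalyst's
// expressions and operators; they are NOT Spark's API.
sealed trait Expr
case class Col(name: String) extends Expr                    // input column (already an attribute)
case class Attr(name: String) extends Expr                   // reference to an aliased Aggregate output
case class Sum(child: Expr) extends Expr                     // stands in for an AggregateExpression
case class InSub(value: Expr, subquery: String) extends Expr // value IN (subquery)
case class And(left: Expr, right: Expr) extends Expr
case class Gt(left: Expr, right: Expr) extends Expr
case class Lit(value: Int) extends Expr
case class Alias(child: Expr, name: String) extends Expr

case class Relation(cols: Seq[String])
case class Aggregate(groupBy: Seq[Expr], aggs: Seq[Expr], child: Relation)
case class Project(projList: Seq[Expr], child: Aggregate)

def children(e: Expr): Seq[Expr] = e match {
  case Sum(c)      => Seq(c)
  case InSub(v, _) => Seq(v)
  case And(l, r)   => Seq(l, r)
  case Gt(l, r)    => Seq(l, r)
  case Alias(c, _) => Seq(c)
  case _           => Seq.empty
}

def exists(e: Expr)(p: Expr => Boolean): Boolean =
  p(e) || children(e).exists(c => exists(c)(p))

// "Interesting": contains an IN-subquery whose left operand contains an aggregate.
def containsAggregateInSubquery(e: Expr): Boolean = exists(e) {
  case InSub(v, _) => exists(v)(_.isInstanceOf[Sum])
  case _           => false
}

// Every aggregate in the expression, including ones outside the IN-subquery.
def collectAggregates(e: Expr): Seq[Expr] = e match {
  case s: Sum => Seq(s)
  case other  => children(other).flatMap(collectAggregates)
}

// Top-down rewrite replacing each aggregate via `m` (the `transform` step in the diff).
def replaceAggregates(e: Expr, m: Map[Expr, Expr]): Expr = e match {
  case s: Sum      => m(s)
  case InSub(v, q) => InSub(replaceAggregates(v, m), q)
  case And(l, r)   => And(replaceAggregates(l, m), replaceAggregates(r, m))
  case Gt(l, r)    => Gt(replaceAggregates(l, m), replaceAggregates(r, m))
  case Alias(c, n) => Alias(replaceAggregates(c, m), n)
  case leaf        => leaf
}

def toAttribute(e: Expr): Expr = e match {
  case Alias(_, n) => Attr(n)
  case other       => other // columns are already attributes in this toy model
}

def pretty(e: Expr): String = e match {
  case Col(n) => n
  case Sum(c) => s"sum(${pretty(c)})"
  case other  => other.toString
}

def pullUpInSubqueries(a: Aggregate): Project = {
  val mapping   = a.aggs.filter(containsAggregateInSubquery).map(e => e -> collectAggregates(e))
  val inSubqMap = mapping.toMap
  val aggExprs  = mapping.flatMap(_._2)
  val aliases   = aggExprs.map(ae => Alias(ae, pretty(ae)))
  val aliasMap: Map[Expr, Expr] = aggExprs.zip(aliases).toMap
  val attrMap: Map[Expr, Expr]  = aggExprs.zip(aliases.map(toAttribute)).toMap
  // New Aggregate: each interesting expression is replaced by its aliased aggregates.
  val newAggs = a.aggs.flatMap(e => inSubqMap.get(e).map(_.map(aliasMap)).getOrElse(Seq(e)))
  // New Project: interesting expressions now refer to the aggregates' output attributes.
  val projList = a.aggs.map { e =>
    if (inSubqMap.contains(e)) replaceAggregates(e, attrMap) else toAttribute(e)
  }
  Project(projList, a.copy(aggs = newAggs))
}

// SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) AS x FROM v2 GROUP BY col1
val rel = Relation(Seq("col1", "col2"))
val agg = Aggregate(
  groupBy = Seq(Col("col1")),
  aggs = Seq(Col("col1"), Alias(InSub(Sum(Col("col2")), "SELECT c2 FROM v1"), "x")),
  child = rel)
val rewritten = pullUpInSubqueries(agg)
println(rewritten)
```

On the first example from the comment, the resulting `Project` uses `Attr("sum(col2)")` as the IN-subquery operand, and the new `Aggregate` outputs `Alias(Sum(Col("col2")), "sum(col2)")`, mirroring the before/after plans shown in the code comments.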
Review Comment:
This reminds me of the rule `RewriteWithExpression`, which also needs to
rewrite `Aggregate` first. We should not call `apply` here in the middle of
plan traversal, as `apply` transforms the plan again, which leads to O(n^2)
complexity. Instead, we should also add a util function that rewrites a single
`UnaryNode` (rather than transforming the full tree) and call it both here and in
the original case match for `UnaryNode`.
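The complexity argument can be checked on a toy model. The Scala sketch below uses hypothetical `Plan` classes and a hypothetical `rewriteUnaryNode` helper; neither is Catalyst's API nor the actual util the review asks for. Over a chain of `n` aggregates, the re-entrant rule restarts a full bottom-up traversal on every rewritten subtree (mirroring `apply(newProj)`), while the single-pass rule rewrites the pulled-up `Project` directly.

```scala
// Hypothetical toy plan nodes (NOT Catalyst's classes). The flags mark whether a
// node still carries an IN-subquery that has to be rewritten.
sealed trait Plan
case object Leaf extends Plan
case class Aggregate(hasAggInSubquery: Boolean, child: Plan) extends Plan
case class Project(hasInSubquery: Boolean, child: Plan) extends Plan
case class SemiJoin(child: Plan) extends Plan // stands in for the join the rewrite produces

var visits = 0 // number of nodes touched by traversals

// Bottom-up traversal over the whole tree, applying `rule` where it is defined.
def transformUp(p: Plan)(rule: PartialFunction[Plan, Plan]): Plan = {
  visits += 1
  val withNewChild = p match {
    case Leaf            => Leaf
    case Aggregate(f, c) => Aggregate(f, transformUp(c)(rule))
    case Project(f, c)   => Project(f, transformUp(c)(rule))
    case SemiJoin(c)     => SemiJoin(transformUp(c)(rule))
  }
  rule.applyOrElse(withNewChild, identity[Plan])
}

// Pull the IN-subquery up into a Project above the Aggregate.
def pullUp(a: Aggregate): Project =
  Project(hasInSubquery = true, Aggregate(hasAggInSubquery = false, a.child))

// Hypothetical helper: rewrites ONE unary node and does not traverse its children.
def rewriteUnaryNode(p: Project): Plan =
  Project(hasInSubquery = false, SemiJoin(p.child))

// Re-entrant: run the full traversal again on the new subtree (like `apply(newProj)`).
def reentrantRule: PartialFunction[Plan, Plan] = {
  case a @ Aggregate(true, _) => transformUp(pullUp(a))(reentrantRule)
  case p @ Project(true, _)   => rewriteUnaryNode(p)
}

// Single pass: hand the pulled-up Project straight to the single-node helper.
def singlePassRule: PartialFunction[Plan, Plan] = {
  case a @ Aggregate(true, _) => rewriteUnaryNode(pullUp(a))
  case p @ Project(true, _)   => rewriteUnaryNode(p)
}

def run(rule: PartialFunction[Plan, Plan], plan: Plan): (Plan, Int) = {
  visits = 0
  val out = transformUp(plan)(rule)
  (out, visits)
}

// A chain of n Aggregates that each need the rewrite.
def chain(n: Int): Plan =
  (1 to n).foldLeft(Leaf: Plan)((c, _) => Aggregate(hasAggInSubquery = true, c))

val n = 100
val (reentrantPlan, reentrantVisits) = run(reentrantRule, chain(n))
val (singlePassPlan, singlePassVisits) = run(singlePassRule, chain(n))
println(s"re-entrant: $reentrantVisits visits, single pass: $singlePassVisits visits")
```

Both rules produce the same plan. In this model the single-pass rule visits n + 1 nodes, while the re-entrant rule visits n + 1 + 3n(n + 1)/2, because each re-traversal walks the whole already-rewritten subtree below it.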
--
This is an automated message from the Apache Git Service.