cloud-fan commented on code in PR #48627:
URL: https://github.com/apache/spark/pull/48627#discussion_r1901465121


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala:
##########
@@ -246,6 +267,197 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
         }
       }
 
+    // Handle the case where the left-hand side of an IN-subquery contains an aggregate.
+    //
+    // This handler pulls up any expression containing such an IN-subquery into a new Project
+    // node, replacing aggregate expressions with attributes, and then re-enters
+    // RewritePredicateSubquery#apply, where the new Project node will be handled
+    // by the Unary node handler.
+    //
+    // The Unary node handler uses the left-hand side of the IN-subquery in a
+    // join condition. Thus, without this pre-transformation, the join condition
+    // contains an aggregate, which is illegal. With this pre-transformation, the
+    // join condition contains an attribute from the left-hand side of the
+    // IN-subquery contained in the Project node.
+    //
+    // For example:
+    //
+    //   SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) as x
+    //   FROM v2 GROUP BY col1;
+    //
+    // The above query has this plan on entry to RewritePredicateSubquery#apply:
+    //
+    //   Aggregate [col1#28], [col1#28, sum(col2#29) IN (list#24 []) AS x#25]
+    //   :  +- LocalRelation [c2#35L]
+    //   +- LocalRelation [col1#28, col2#29]
+    //
+    // Note that the Aggregate node contains the IN-subquery and the left-hand
+    // side of the IN-subquery is an aggregate expression (sum(col2#29)).
+    //
+    // This handler transforms the above plan into the following:
+    //
+    //   Project [col1#28, sum(col2)#36L IN (list#24 []) AS x#25]
+    //   :  +- LocalRelation [c2#35L]
+    //   +- Aggregate [col1#28], [col1#28, sum(col2#29) AS sum(col2)#36L]
+    //      +- LocalRelation [col1#28, col2#29]
+    //
+    // The transformation pulled the IN-subquery up into a Project. The left-hand side of the
+    // IN-subquery is now an attribute (sum(col2)#36L) that refers to the actual aggregation
+    // which is still performed in the Aggregate node (sum(col2#29) AS sum(col2)#36L). The Unary
+    // node handler will use that attribute in the join condition (rather than the aggregate
+    // expression).
+    //
+    // If the IN-subquery is nested in a larger expression, that entire larger
+    // expression is pulled up into the Project. For example:
+    //
+    //   SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
+    //   FROM v2;
+    //
+    // The input to RewritePredicateSubquery#apply is the following plan:
+    //
+    //   Aggregate [(sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29]
+    //   :  +- LocalRelation [c3#44L]
+    //   +- LocalRelation [col2#34, col3#35]
+    //
+    // This handler transforms the plan into:
+    //
+    //   Project [(sum(col2)#45L IN (list#28 []) AND (sum(col3)#46L > -1)) AS x#29]
+    //   :  +- LocalRelation [c3#44L]
+    //   +- Aggregate [sum(col2#34) AS sum(col2)#45L, sum(col3#35) AS sum(col3)#46L]
+    //      +- LocalRelation [col2#34, col3#35]
+    //
+    // Note that the entire AND expression was pulled up into the Project, but the Aggregate
+    // node continues to perform the aggregations (but without the IN-subquery expression).
+    case a: Aggregate if exprsContainsAggregateInSubquery(a.aggregateExpressions) =>
+      // Find any interesting expressions from Aggregate.aggregateExpressions.
+      //
+      // An interesting expression is one that contains an IN-subquery whose left-hand
+      // operand contains aggregates. For example:
+      //
+      //   SELECT col1, SUM(col2) IN (SELECT c2 FROM v1)
+      //   FROM v2 GROUP BY col1;
+      //
+      // withInSubquery will be a List containing a single Alias expression:
+      //
+      //   List(sum(col2#12) IN (list#8 []) AS (...)#19)
+      val withInSubquery = a.aggregateExpressions.filter(exprContainsAggregateInSubquery(_))
+
+      // Extract the aggregate expressions from each interesting expression. This will include
+      // any aggregate expressions that were not part of the IN-subquery but were part
+      // of the larger containing expression.
+      val inSubqueryMapping = withInSubquery.map { e =>
+        (e, extractAggregateExpressions(e))
+      }
+
+      // Map each interesting expression to its contained aggregate expressions.
+      //
+      // Example #1:
+      //
+      //  SELECT col1, SUM(col2) IN (SELECT c2 FROM v1)
+      //  FROM v2 GROUP BY col1;
+      //
+      // inSubqueryMap will have a single entry mapping an Alias expression to a Vector
+      // with a single aggregate expression:
+      //
+      //   Map(
+      //     sum(col2#100) IN (list []) AS (...)#107 -> Vector(sum(col2#100))
+      //   )
+      //
+      // Example #2:
+      //
+      //   SELECT (SUM(col1), SUM(col2)) IN (SELECT c1, c2 FROM v1)
+      //   FROM v2;
+      //
+      // inSubqueryMap will have a single entry mapping an Alias expression to a Vector
+      // with two aggregate expressions:
+      //
+      //   Map(
+      //     named_struct(_0, sum(col1#169), _1, sum(col2#170)) IN (list#166 []) AS (...)#179
+      //       -> Vector(sum(col1#169), sum(col2#170))
+      //   )
+      //
+      // Example #3:
+      //
+      // select SUM(col1) IN (SELECT c1 FROM v1), SUM(col2) IN (SELECT c2 FROM v1)
+      // FROM v2;
+      //
+      // inSubqueryMap will have two entries, each mapping an Alias expression to a Vector
+      // with a single aggregate expression:
+      //
+      //   Map(
+      //     sum(col1#193) IN (list#189 []) AS (...)#207 -> Vector(sum(col1#193)),
+      //     sum(col2#194) IN (list#190 []) AS (...)#208 -> Vector(sum(col2#194))
+      //   )
+      //
+      // Example #5:
+      //
+      //   SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
+      //   FROM v2;
+      //
+      // inSubqueryMap will contain a single AND expression that maps to two aggregate
+      // expressions, even though only one of those aggregate expressions is used as
+      // the left-hand operand of the IN-subquery expression.
+      //
+      //   Map(
+      //     (sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29
+      //       -> Vector(sum(col2#34), sum(col3#35))
+      //   )
+      //
+      // The keys of inSubqueryMap will be used to determine which expressions in
+      // the old Aggregate node are interesting. The values of inSubqueryMap, after
+      // being wrapped in Alias expressions, will replace their associated interesting
+      // expressions in a new Aggregate node.
+      val inSubqueryMap = inSubqueryMapping.toMap
+
+      // Get all aggregate expressions associated with interesting expressions.
+      val aggregateExprs = inSubqueryMapping.flatMap(_._2)
+      // Create aliases for each above aggregate expression. We can't use the aggregate
+      // expressions directly in the new Aggregate node because Aggregate.aggregateExpressions
+      // has the type Seq[NamedExpression].
+      val aggregateExprAliases = aggregateExprs.map(a => Alias(a, toPrettySQL(a))())
+      // Create a mapping from each aggregate expression to its alias.
+      val aggregateExprAliasMap = aggregateExprs.zip(aggregateExprAliases).toMap
+      // Create attributes from those aliases of aggregate expressions. These attributes
+      // will be used in the new Project node to refer to the aliased aggregate expressions
+      // in the new Aggregate node.
+      val aggregateExprAttrs = aggregateExprAliases.map(_.toAttribute)
+      // Create a mapping from aggregate expressions to attributes. This will be
+      // used when patching the interesting expressions after they are pulled up
+      // into the new Project node: aggregate expressions will be replaced by attributes.
+      val aggregateExprAttrMap = aggregateExprs.zip(aggregateExprAttrs).toMap
+
+      // Create an Aggregate node without the interesting expressions, just
+      // the associated aggregate expressions plus any other group-by or aggregate expressions
+      // that were not involved in the interesting expressions.
+      val newAggregateExpressions = a.aggregateExpressions.flatMap {
+        // If this expression contains IN-subqueries with aggregates in the left-hand
+        // operand, replace with just the aggregates.
+        case ae: Expression if inSubqueryMap.contains(ae) =>
+          // Replace the expression with an aliased aggregate expression.
+          inSubqueryMap(ae).map(aggregateExprAliasMap(_))
+        case ae => Seq(ae)
+      }
+      val newAggregate = a.copy(aggregateExpressions = newAggregateExpressions)
+
+      // Create a projection with the IN-subquery expressions that contain aggregates, replacing
+      // the aggregate expressions with attribute references to the output of the new Aggregate
+      // operator. Also include the other output of the Aggregate operator.
+      val projList = a.aggregateExpressions.map {
+        // If this expression contains an IN-subquery that uses an aggregate, we
+        // need to do something special
+        case ae: Expression if inSubqueryMap.contains(ae) =>
+          ae.transform {
+            // Patch any aggregate expression with its corresponding attribute.
+            case a: AggregateExpression => aggregateExprAttrMap(a)
+          }.asInstanceOf[NamedExpression]
+        case ae => ae.toAttribute
+      }
+      val newProj = Project(projList, newAggregate)
+
+      // Reapply this rule, but now with all interesting expressions
+      // from Aggregate.aggregateExpressions pulled up into a Project node.
+      apply(newProj)

Review Comment:
   This reminds me of the rule `RewriteWithExpression`, which also needs to
   rewrite `Aggregate` first. We should not call `apply` here in the middle of
   the plan traversal: `apply` transforms the whole plan again, which leads to
   O(n^2) complexity. Instead, we should add a util function that rewrites a
   single `UnaryNode` (without transforming the full tree) and call it both
   here and in the original case match for `UnaryNode`.
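
   The shape of this suggestion can be sketched with a toy plan model instead
   of Catalyst classes. Every name below (`Plan`, `Rel`, `Proj`, `Agg`,
   `SemiJoin`, `pullUp`, `rewriteUnaryNode`, the visit counter) is hypothetical
   and only illustrates the idea; it is not the actual Spark code or the final
   change in this PR.

```scala
object RewriteSketch {
  // Toy plan nodes (hypothetical stand-ins, not Catalyst classes).
  sealed trait Plan
  case class Rel(name: String) extends Plan
  // inSubquery = true models a unary node holding an IN-subquery predicate.
  case class Proj(inSubquery: Boolean, child: Plan) extends Plan
  // aggInSubquery = true models an IN-subquery whose left-hand side has an aggregate.
  case class Agg(aggInSubquery: Boolean, child: Plan) extends Plan
  // Stand-in for the join that the unary-node rewrite produces.
  case class SemiJoin(child: Plan) extends Plan

  private var visits = 0

  // Bottom-up traversal: rewrite the children first, then apply the rule to the node.
  private def transformUp(plan: Plan)(rule: PartialFunction[Plan, Plan]): Plan = {
    visits += 1
    val rebuilt = plan match {
      case Proj(s, c) => Proj(s, transformUp(c)(rule))
      case Agg(s, c) => Agg(s, transformUp(c)(rule))
      case SemiJoin(c) => SemiJoin(transformUp(c)(rule))
      case r: Rel => r
    }
    rule.applyOrElse(rebuilt, (p: Plan) => p)
  }

  // Node-local helper: rewrites one unary node and never walks the tree.
  def rewriteUnaryNode(p: Proj): Plan =
    if (p.inSubquery) Proj(inSubquery = false, child = SemiJoin(p.child)) else p

  // Pulls the IN-subquery out of the Agg into a new Proj above it.
  def pullUp(a: Agg): Proj =
    Proj(inSubquery = true, child = Agg(aggInSubquery = false, child = a.child))

  // Re-entrant version: each Agg case restarts a full traversal of its subtree.
  def applyReentrant(plan: Plan): Plan = transformUp(plan) {
    case a @ Agg(true, _) => applyReentrant(pullUp(a))
    case p: Proj => rewriteUnaryNode(p)
  }

  // Helper-based version: both cases share the node-local rewrite.
  def applyWithHelper(plan: Plan): Plan = transformUp(plan) {
    case a @ Agg(true, _) => rewriteUnaryNode(pullUp(a))
    case p: Proj => rewriteUnaryNode(p)
  }

  // Runs a rule and reports how many nodes the traversal visited.
  def countVisits(rule: Plan => Plan, plan: Plan): (Plan, Int) = {
    visits = 0
    val result = rule(plan)
    (result, visits)
  }
}
```

   In this toy model both versions produce the same plan. On a chain of nested
   aggregates the helper-based version visits each input node once, while the
   re-entrant version walks already rewritten subtrees again at every level.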



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

