cloud-fan commented on code in PR #48627:
URL: https://github.com/apache/spark/pull/48627#discussion_r1903631161


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala:
##########
@@ -246,46 +267,242 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
         }
       }
 
-    case u: UnaryNode if u.expressions.exists(
-        SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
-      var newChild = u.child
-      var introducedAttrs = Seq.empty[Attribute]
-      val updatedNode = u.mapExpressions(expr => {
-        val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
-        newChild = p
-        introducedAttrs ++= newAttrs
-        // The newExpr can not be None
-        newExpr.get
-      }).withNewChildren(Seq(newChild))
-      updatedNode match {
-        case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
-          // If we have introduced new `exists`-attributes that are referenced by
-          // aggregateExpressions within a non-aggregateFunction expression, we wrap them in
-          // first() aggregate function. first() is Spark's executable version of any_value()
-          // aggregate function.
-          // We do this to keep the aggregation valid, i.e avoid references outside of aggregate
-          // functions that are not in grouping expressions.
-          // Note that the same `exists` attr will never appear in groupingExpressions due to
-          // PullOutGroupingExpressions rule.
-          // Also note: the value of `exists` is functionally determined by grouping expressions,
-          // so applying any aggregate function is semantically safe.
-          val aggFunctionReferences = a.aggregateExpressions.
-            flatMap(extractAggregateExpressions).
-            flatMap(_.references).toSet
-          val nonAggFuncReferences =
-            a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
-          val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)
-
-          // Replace all eligible `exists` by `First(exists)` among aggregateExpressions.
-          val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
-            aggExpr.transformUp {
-              case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
-                new First(attr).toAggregateExpression()
-            }.asInstanceOf[NamedExpression]
-          }
-          a.copy(aggregateExpressions = newAggregateExpressions)
-        case _ => updatedNode
+    // Handle the case where the left-hand side of an IN-subquery contains an aggregate.
+    //
+    // This handler pulls up any expression containing such an IN-subquery into a new Project
+    // node, replacing aggregate expressions with attributes. The new Project node will be
+    // handled by the Unary node handler.
+    //
+    // The Unary node handler uses the left-hand side of the IN-subquery in a
+    // join condition. Thus, without this pre-transformation, the join condition
+    // contains an aggregate, which is illegal. With this pre-transformation, the
+    // join condition contains an attribute from the left-hand side of the
+    // IN-subquery contained in the Project node.
+    //
+    // For example:
+    //
+    //   SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) as x
+    //   FROM v2 GROUP BY col1;
+    //
+    // The above query has this plan on entry to RewritePredicateSubquery#apply:
+    //
+    //   Aggregate [col1#28], [col1#28, sum(col2#29) IN (list#24 []) AS x#25]
+    //   :  +- LocalRelation [c2#35L]
+    //   +- LocalRelation [col1#28, col2#29]
+    //
+    // Note that the Aggregate node contains the IN-subquery and the left-hand
+    // side of the IN-subquery is an aggregate expression (sum(col2#29)).
+    //
+    // This handler transforms the above plan into the following:
+    //
+    //   Project [col1#28, sum(col2)#36L IN (list#24 []) AS x#25]
+    //   :  +- LocalRelation [c2#35L]
+    //   +- Aggregate [col1#28], [col1#28, sum(col2#29) AS sum(col2)#36L]
+    //      +- LocalRelation [col1#28, col2#29]
+    //
+    // The transformation pulled the IN-subquery up into a Project. The left-hand side of the
+    // IN-subquery is now an attribute (sum(col2)#36L) that refers to the actual aggregation
+    // which is still performed in the Aggregate node (sum(col2#29) AS sum(col2)#36L). The Unary
+    // node handler will use that attribute in the join condition (rather than the aggregate
+    // expression).
+    //
+    // If the IN-subquery is nested in a larger expression, that entire larger
+    // expression is pulled up into the Project. For example:
+    //
+    //   SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
+    //   FROM v2;
+    //
+    // The input to RewritePredicateSubquery#apply is the following plan:
+    //
+    //   Aggregate [(sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29]
+    //   :  +- LocalRelation [c3#44L]
+    //   +- LocalRelation [col2#34, col3#35]
+    //
+    // This handler transforms the plan into:
+    //
+    //   Project [(sum(col2)#45L IN (list#28 []) AND (sum(col3)#46L > -1)) AS x#29]
+    //   :  +- LocalRelation [c3#44L]
+    //   +- Aggregate [sum(col2#34) AS sum(col2)#45L, sum(col3#35) AS sum(col3)#46L]
+    //      +- LocalRelation [col2#34, col3#35]
+    //
+    // Note that the entire AND expression was pulled up into the Project, but the Aggregate
+    // node continues to perform the aggregations (but without the IN-subquery expression).
+    case a: Aggregate if exprsContainsAggregateInSubquery(a.aggregateExpressions) =>
+      // Find any interesting expressions from Aggregate.aggregateExpressions.
+      //
+      // An interesting expression is one that contains an IN-subquery whose left-hand
+      // operand contains aggregates. For example:
+      //
+      //   SELECT col1, SUM(col2) IN (SELECT c2 FROM v1)
+      //   FROM v2 GROUP BY col1;
+      //
+      // withInSubquery will be a List containing a single Alias expression:
+      //
+      //   List(sum(col2#12) IN (list#8 []) AS (...)#19)
+      val withInSubquery = a.aggregateExpressions.filter(exprContainsAggregateInSubquery(_))

Review Comment:
   BTW, I think it's better to always build the query plan tree in this normalized form (`Aggregate` should only do grouping and aggregation; projection should always happen in `Project`), but this is a much bigger topic.
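
   To make the normalized form concrete, here is a minimal sketch on a toy plan model. It is not Catalyst code: `Expr`, `Plan`, `NormalizeAggregate`, and the generated `agg0`-style names are all hypothetical stand-ins, and `Sum` stands in for any aggregate function. It only illustrates splitting one `Aggregate` into a `Project` over an `Aggregate` that does nothing but group and aggregate.

```scala
import scala.collection.mutable.LinkedHashMap

// Toy model only: hypothetical stand-ins, not Catalyst classes.
sealed trait Expr
case class Attr(name: String) extends Expr
case class Sum(child: Expr) extends Expr // stands in for any aggregate function
case class InSubquery(value: Expr, subquery: String) extends Expr
case class Alias(child: Expr, name: String) extends Expr

sealed trait Plan
case class Relation(output: Seq[String]) extends Plan
case class Aggregate(grouping: Seq[Expr], output: Seq[Expr], child: Plan) extends Plan
case class Project(output: Seq[Expr], child: Plan) extends Plan

object NormalizeAggregate {
  // Replace each Sum(...) inside `e` with an attribute and remember which
  // aggregate that attribute stands for.
  private def pullOut(e: Expr, found: LinkedHashMap[Sum, Attr]): Expr = e match {
    case s: Sum           => found.getOrElseUpdate(s, Attr(s"agg${found.size}"))
    case InSubquery(v, q) => InSubquery(pullOut(v, found), q)
    case Alias(c, n)      => Alias(pullOut(c, found), n)
    case a: Attr          => a
  }

  // Split one Aggregate into Project(Aggregate(...)): the lower node only
  // groups and aggregates, the upper node evaluates everything else.
  def apply(agg: Aggregate): Project = {
    val found = LinkedHashMap.empty[Sum, Attr]
    val projectList = agg.output.map(pullOut(_, found))
    val aggList = agg.grouping ++ found.toSeq.map { case (s, a) => Alias(s, a.name) }
    Project(projectList, Aggregate(agg.grouping, aggList, agg.child))
  }
}
```

   Applied to a toy version of `SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) AS x FROM v2 GROUP BY col1`, the IN-subquery ends up in the `Project` with an attribute on its left-hand side, while the `Aggregate` below keeps only `col1` and the aliased `Sum`.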



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

