jchen5 commented on code in PR #46839:
URL: https://github.com/apache/spark/pull/46839#discussion_r1623678864


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala:
##########
@@ -249,6 +250,73 @@ object SubExprUtils extends PredicateHelper {
       }
     }
   }
+
+  /**
+   * Returns the inner query attributes that are guaranteed to have a single value for each
+   * outer row. Therefore, a scalar subquery is allowed to group-by on these attributes.
+   * We can derive these from correlated equality predicates, though we need to take care when
+   * propagating this through operators like OUTER JOIN or UNION.
+   *
+   * Positive examples: x = outer(a) AND y = outer(b)
+   * Negative examples:
+   * - x <= outer(a)
+   * - x + y = outer(a)
+   * - x = outer(a) OR y = outer(b)
+   * - y = outer(b) + 1 (this and similar expressions could be supported, but very carefully)
+   * - An equality under the right side of a LEFT OUTER JOIN, e.g.
+   *   select *, (select count(*) from y left join
+   *     (select * from z where z1 = x1) sub on y2 = z2 group by z1) from x;
+   * - An equality under UNION, e.g.
+   *   select *, (select count(*) from
+   *     (select * from y where y1 = x1 union all select * from y) group by y1) from x;
+   */
+  def getCorrelatedEquivalentInnerColumns(plan: LogicalPlan): AttributeSet = {
+    plan match {
+      case Filter(cond, child) =>
+        val correlated = AttributeSet(splitConjunctivePredicates(cond)
+          .filter(containsOuter) // TODO: can remove this line to allow e.g. where x = 1 group by x
+          .filter(DecorrelateInnerQuery.canPullUpOverAgg)
+          .flatMap(_.references))
+        correlated ++ getCorrelatedEquivalentInnerColumns(child)
+
+      case Join(left, right, joinType, _, _) =>
+        joinType match {
+          case _: InnerLike =>
+            AttributeSet(plan.children.flatMap(child => getCorrelatedEquivalentInnerColumns(child)))
+          case LeftOuter => getCorrelatedEquivalentInnerColumns(left)
+          case RightOuter => getCorrelatedEquivalentInnerColumns(right)
+          case FullOuter => AttributeSet.empty
+          case LeftSemi => getCorrelatedEquivalentInnerColumns(left)
+          case LeftAnti => getCorrelatedEquivalentInnerColumns(left)
+          case _ => AttributeSet.empty
+        }
+
+      case _: Union => AttributeSet.empty
+      case Except(left, right, _) => getCorrelatedEquivalentInnerColumns(left)
+
+      case
+        _: Aggregate |
+        _: Distinct |
+        _: Intersect |
+        _: GlobalLimit |
+        _: LocalLimit |
+        _: Offset |
+        _: Project |
+        _: Repartition |
+        _: RepartitionByExpression |
+        _: RebalancePartitions |
+        _: Sample |
+        _: Sort |
+        _: Window |
+        _: Tail |
+        _: WithCTE |
+        _: Range |
+        _: SubqueryAlias =>
+        AttributeSet(plan.children.flatMap(child => getCorrelatedEquivalentInnerColumns(child)))
+
+      case _ => AttributeSet.empty

Review Comment:
   The list of operators handled here is by no means comprehensive, and ensuring it covers enough is tricky. I used the list in LogicalPlanVisitor as a starting point, but in my testing I discovered that, e.g., SubqueryAlias also needs to be handled to cover cases with FROM subqueries inside the scalar subquery.
   
   Suggestions for other important operators to handle, or for other approaches, are welcome. (In the long run, I think we need to replace this entire check with a runtime check, as described in https://issues.apache.org/jira/browse/SPARK-48501, but that's highly nontrivial.)
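To make the cost of an unlisted operator concrete, here is a minimal Scala sketch of the traversal over a hypothetical toy plan ADT (`Plan`, `Filter`, `Join`, `Scan`, `Unhandled` and `CorrelationSketch` are illustrative names, not Catalyst's classes). Each toy `Filter` simply carries the inner columns its correlated equalities pin. The sketch mirrors the join and union rules in the diff and the conservative `case _ => AttributeSet.empty` default, and it omits most of the pass-through operators the real code lists.

```scala
// Hypothetical toy model for illustration only; these are NOT Catalyst classes.
sealed trait JoinKind
case object Inner extends JoinKind
case object LeftOuter extends JoinKind
case object RightOuter extends JoinKind
case object FullOuter extends JoinKind

sealed trait Plan
// `pinned` stands in for the inner columns fixed by correlated equalities
// in this Filter's condition, e.g. `z1 = outer(x1)` pins z1.
final case class Filter(pinned: Set[String], child: Plan) extends Plan
final case class Join(left: Plan, right: Plan, kind: JoinKind) extends Plan
final case class Union(children: Seq[Plan]) extends Plan
final case class SubqueryAlias(alias: String, child: Plan) extends Plan
final case class Scan(table: String) extends Plan
// Stands in for any operator the traversal has no explicit case for.
final case class Unhandled(child: Plan) extends Plan

object CorrelationSketch {
  def equivalentInnerColumns(plan: Plan): Set[String] = plan match {
    case Filter(pinned, child) => pinned ++ equivalentInnerColumns(child)
    // Every output row of an inner join satisfies both sides' filters.
    case Join(l, r, Inner) => equivalentInnerColumns(l) ++ equivalentInnerColumns(r)
    // Only the preserved side keeps its guarantees; the other side may be null-extended.
    case Join(l, _, LeftOuter) => equivalentInnerColumns(l)
    case Join(_, r, RightOuter) => equivalentInnerColumns(r)
    case Join(_, _, FullOuter) => Set.empty
    // A branch without the correlated filter can contribute arbitrary values.
    case Union(_) => Set.empty
    // Pass-through operator: it only helps if it is listed explicitly.
    case SubqueryAlias(_, child) => equivalentInnerColumns(child)
    case Scan(_) => Set.empty
    // Conservative default: an unrecognized operator yields no guarantees.
    case Unhandled(_) => Set.empty
  }
}
```

Modeling the doc comment's LEFT OUTER JOIN example as `Join(Scan("y"), SubqueryAlias("sub", Filter(Set("z1"), Scan("z"))), LeftOuter)` gives an empty set, and switching the join kind to `Inner` gives `Set("z1")`. Wrapping the same filter in `Unhandled` instead of `SubqueryAlias` also gives an empty set, which is the FROM-subquery failure mode described above: the query is rejected not because it is unsafe, but because the traversal stopped at an operator it did not recognize.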



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala:
##########
+          .filter(containsOuter) // TODO: can remove this line to allow e.g. where x = 1 group by x

Review Comment:
   I intend to enable that in a separate PR, to reduce risk here.
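To illustrate what dropping the `.filter(containsOuter)` line would change, here is a minimal sketch over a hypothetical toy predicate AST (`Expr`, `Col`, `Outer`, `Lit`, `Bin` and `PredicateSketch` are illustrative names, not Catalyst expressions). The equality rule in `pinnedBy` is an assumed simplification standing in for `DecorrelateInnerQuery.canPullUpOverAgg`, not its actual logic. It reproduces the doc comment's positive and negative examples and, assuming the real check accepts a literal on one side, shows that `x = 1` pins `x` only once the outer-reference requirement is dropped.

```scala
// Hypothetical toy predicate AST; NOT Catalyst's Expression classes.
sealed trait Expr
final case class Col(name: String) extends Expr   // inner-query column
final case class Outer(name: String) extends Expr // outer(...) reference
final case class Lit(value: Int) extends Expr
// op is one of "=", "<=", "+", "AND", "OR" in this sketch.
final case class Bin(op: String, l: Expr, r: Expr) extends Expr

object PredicateSketch {
  // Split on top-level ANDs, in the spirit of splitConjunctivePredicates.
  def conjuncts(e: Expr): Seq[Expr] = e match {
    case Bin("AND", l, r) => conjuncts(l) ++ conjuncts(r)
    case other => Seq(other)
  }

  def containsOuter(e: Expr): Boolean = e match {
    case Outer(_) => true
    case Bin(_, l, r) => containsOuter(l) || containsOuter(r)
    case _ => false
  }

  // Assumed stand-in for the real check: only a bare inner column equated
  // to a bare outer reference or a literal pins that column.
  def pinnedBy(p: Expr): Option[String] = p match {
    case Bin("=", Col(c), Outer(_) | Lit(_)) => Some(c)
    case Bin("=", Outer(_) | Lit(_), Col(c)) => Some(c)
    case _ => None
  }

  // requireOuter = true mirrors keeping the `.filter(containsOuter)` line.
  def pinnedColumns(cond: Expr, requireOuter: Boolean): Set[String] =
    conjuncts(cond)
      .filter(p => !requireOuter || containsOuter(p))
      .flatMap(p => pinnedBy(p))
      .toSet
}
```

With the filter kept, `x = outer(a) AND y = outer(b)` pins `Set("x", "y")`, while `x <= outer(a)`, `x + y = outer(a)`, the OR form and `y = outer(b) + 1` pin nothing. `x = 1` pins nothing with `requireOuter = true` and `Set("x")` with `requireOuter = false`, which is the behavior deferred to the separate PR.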



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
