jchen5 commented on code in PR #41301:
URL: https://github.com/apache/spark/pull/41301#discussion_r1214962726


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -804,18 +804,67 @@ object DecorrelateInnerQuery extends PredicateHelper {
             (d.copy(child = newChild), joinCond, outerReferenceMap)
 
           case j @ Join(left, right, joinType, condition, _) =>
-            val outerReferences = collectOuterReferences(j.expressions)
-            // Join condition containing outer references is not supported.
-            assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
-            val newOuterReferences = parentOuterReferences ++ outerReferences
-            val shouldPushToLeft = joinType match {
+            def splitCorrelatedPredicate(condition: Option[Expression],
+                                         isInnerJoin: Boolean,
+                                         shouldDecorrelatePredicates: Boolean):
+            (Seq[Expression], Seq[Expression], Seq[Expression],
+              Seq[Expression], AttributeMap[Attribute]) = {
+              // Similar to Filters above, we split the join condition (if present) into correlated
+              // and uncorrelated predicates, and separately handle joins under set and aggregation
+              // operations.
+              if (shouldDecorrelatePredicates) {
+                val conditions =
+                  if (condition.isDefined) splitConjunctivePredicates(condition.get)
+                  else Seq.empty[Expression]
+                val (correlated, uncorrelated) = conditions.partition(containsOuter)
+                val equivalences =
+                  if (underSetOp) AttributeMap.empty[Attribute]
+                  else collectEquivalentOuterReferences(correlated)
+                var (equalityCond, predicates) =
+                  if (underSetOp) (Seq.empty[Expression], correlated)
+                  else correlated.partition(canPullUpOverAgg)

Review Comment:
   I think we can do
   ```
   else if (aggregated) collectEquivalentOuterReferences(correlated)
   else correlated
   ```
   so that if it's not aggregated, we can directly use all the correlated 
predicates without adding a DomainJoin, equivalent to the Filter logic around
   ```
                 // Results of this sub-tree is not aggregated, so all correlated predicates
                 // can be directly used as outer query join conditions.
   ```
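
One reading of this suggestion, assuming it is meant to mirror the Filter case, is a three-way choice for the `(equalityCond, predicates)` pair. The sketch below uses a toy `Pred` type instead of Catalyst expressions; `Pred`, `split`, and the flag names are hypothetical stand-ins, not code from the PR.

```scala
// Toy model only: a predicate is a label plus two flags. None of these
// names come from Spark; they loosely mirror the identifiers in the diff.
object SplitSketch {
  final case class Pred(sql: String, hasOuterRef: Boolean, isEquality: Boolean)

  // Stand-in for canPullUpOverAgg: in this toy, only equalities qualify.
  def canPullUpOverAgg(p: Pred): Boolean = p.isEquality

  // Returns (equalityCond, predicates) for the correlated conjuncts.
  def split(
      correlated: Seq[Pred],
      underSetOp: Boolean,
      aggregated: Boolean): (Seq[Pred], Seq[Pred]) =
    if (underSetOp) (Seq.empty, correlated)                      // nothing is pulled up
    else if (aggregated) correlated.partition(canPullUpOverAgg)  // only equalities are pulled up
    else (correlated, Seq.empty)                                 // everything is pulled up
}
```

In the non-aggregated branch the second element is empty, so nothing is left that would have to be rewritten with domain attributes.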



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -804,18 +804,67 @@ object DecorrelateInnerQuery extends PredicateHelper {
             (d.copy(child = newChild), joinCond, outerReferenceMap)
 
           case j @ Join(left, right, joinType, condition, _) =>
-            val outerReferences = collectOuterReferences(j.expressions)
-            // Join condition containing outer references is not supported.
-            assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
-            val newOuterReferences = parentOuterReferences ++ outerReferences
-            val shouldPushToLeft = joinType match {
+            def splitCorrelatedPredicate(condition: Option[Expression],
+                                         isInnerJoin: Boolean,
+                                         shouldDecorrelatePredicates: Boolean):
+            (Seq[Expression], Seq[Expression], Seq[Expression],
+              Seq[Expression], AttributeMap[Attribute]) = {
+              // Similar to Filters above, we split the join condition (if present) into correlated
+              // and uncorrelated predicates, and separately handle joins under set and aggregation
+              // operations.
+              if (shouldDecorrelatePredicates) {
+                val conditions =
+                  if (condition.isDefined) splitConjunctivePredicates(condition.get)
+                  else Seq.empty[Expression]
+                val (correlated, uncorrelated) = conditions.partition(containsOuter)
+                val equivalences =
+                  if (underSetOp) AttributeMap.empty[Attribute]
+                  else collectEquivalentOuterReferences(correlated)
+                var (equalityCond, predicates) =
+                  if (underSetOp) (Seq.empty[Expression], correlated)
+                  else correlated.partition(canPullUpOverAgg)
+                // Fully preserve the join predicate for non-inner joins.
+                if (!isInnerJoin) {
+                  predicates = predicates ++ equalityCond

Review Comment:
   Maybe slightly better to set `predicates = correlated`, to preserve the 
ordering of the original predicates?
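
A small illustration of the ordering point, using plain strings in place of Catalyst expressions. The `canPullUp` test (an equality check on the text) and the sample predicates are made up for the example.

```scala
// Toy model: predicates are strings, and "pullable" just means "is an equality".
object OrderSketch {
  def canPullUp(p: String): Boolean = p.contains(" = ")

  val correlated: Seq[String] = Seq("y = outer(a)", "x > outer(b)", "z = outer(c)")
  val (equalityCond, rest) = correlated.partition(canPullUp)

  // What the diff does for non-inner joins: non-equalities first, then equalities.
  val concatenated: Seq[String] = rest ++ equalityCond

  // What the comment proposes: reuse the original sequence as-is.
  val preserved: Seq[String] = correlated
}
```

Both sequences hold the same conjuncts, but `concatenated` moves the non-equality to the front, while `preserved` keeps the order in which the join condition was written.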



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -804,18 +804,67 @@ object DecorrelateInnerQuery extends PredicateHelper {
             (d.copy(child = newChild), joinCond, outerReferenceMap)
 
           case j @ Join(left, right, joinType, condition, _) =>
-            val outerReferences = collectOuterReferences(j.expressions)
-            // Join condition containing outer references is not supported.
-            assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
-            val newOuterReferences = parentOuterReferences ++ outerReferences
-            val shouldPushToLeft = joinType match {
+            def splitCorrelatedPredicate(condition: Option[Expression],
+                                         isInnerJoin: Boolean,
+                                         shouldDecorrelatePredicates: Boolean):
+            (Seq[Expression], Seq[Expression], Seq[Expression],
+              Seq[Expression], AttributeMap[Attribute]) = {

Review Comment:
   Can you add a comment stating that the results are (correlated, uncorrelated, 
equalityCond, predicates, equivalences), with a brief description of each?
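
A sketch of what such a comment could look like, attached to a toy version of the helper. The field descriptions are my reading of the diff rather than wording from the PR, and `Expr`, the string-based checks, and the `Map` result are hypothetical simplifications.

```scala
object DocSketch {
  type Expr = String // hypothetical stand-in for Catalyst's Expression

  /**
   * Splits the conjuncts of a join condition. Returns a 5-tuple:
   *  1. correlated:   conjuncts that contain outer references
   *  2. uncorrelated: conjuncts that contain no outer references
   *  3. equalityCond: correlated equalities that can be pulled up over an aggregate
   *  4. predicates:   the remaining correlated conjuncts
   *  5. equivalences: outer reference name -> equivalent inner column
   */
  def splitCorrelatedPredicate(
      conjuncts: Seq[Expr]): (Seq[Expr], Seq[Expr], Seq[Expr], Seq[Expr], Map[String, String]) = {
    val (correlated, uncorrelated) = conjuncts.partition(_.contains("outer("))
    val (equalityCond, predicates) = correlated.partition(_.contains(" = "))
    // Toy equivalences: for "col = outer(name)", record name -> col.
    val equivalences = equalityCond.map { e =>
      val parts = e.split(" = ")
      parts(1).stripPrefix("outer(").stripSuffix(")") -> parts(0)
    }.toMap
    (correlated, uncorrelated, equalityCond, predicates, equivalences)
  }
}
```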



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -804,18 +804,67 @@ object DecorrelateInnerQuery extends PredicateHelper {
             (d.copy(child = newChild), joinCond, outerReferenceMap)
 
           case j @ Join(left, right, joinType, condition, _) =>
-            val outerReferences = collectOuterReferences(j.expressions)
-            // Join condition containing outer references is not supported.
-            assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
-            val newOuterReferences = parentOuterReferences ++ outerReferences
-            val shouldPushToLeft = joinType match {
+            def splitCorrelatedPredicate(condition: Option[Expression],
+                                         isInnerJoin: Boolean,
+                                         shouldDecorrelatePredicates: Boolean):
+            (Seq[Expression], Seq[Expression], Seq[Expression],
+              Seq[Expression], AttributeMap[Attribute]) = {
+              // Similar to Filters above, we split the join condition (if present) into correlated
+              // and uncorrelated predicates, and separately handle joins under set and aggregation
+              // operations.
+              if (shouldDecorrelatePredicates) {
+                val conditions =
+                  if (condition.isDefined) splitConjunctivePredicates(condition.get)
+                  else Seq.empty[Expression]
+                val (correlated, uncorrelated) = conditions.partition(containsOuter)
+                val equivalences =
+                  if (underSetOp) AttributeMap.empty[Attribute]
+                  else collectEquivalentOuterReferences(correlated)
+                var (equalityCond, predicates) =
+                  if (underSetOp) (Seq.empty[Expression], correlated)
+                  else correlated.partition(canPullUpOverAgg)
+                // Fully preserve the join predicate for non-inner joins.
+                if (!isInnerJoin) {
+                  predicates = predicates ++ equalityCond
+                }
+                (correlated, uncorrelated, equalityCond, predicates, equivalences)
+              } else {
+                (Seq.empty[Expression],
+                  if (condition.isEmpty) Seq.empty[Expression] else Seq(condition.get),
+                  Seq.empty[Expression],
+                  Seq.empty[Expression],
+                  AttributeMap.empty[Attribute])
+              }
+            }
+
+            val shouldDecorrelatePredicates =
+              SQLConf.get.getConf(SQLConf.DECORRELATE_JOIN_PREDICATE_ENABLED)
+            if (!shouldDecorrelatePredicates) {
+              val outerReferences = collectOuterReferences(j.expressions)
+              // Join condition containing outer references is not supported.
+              assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
+            }
+            val (correlated, uncorrelated, equalityCond, predicates, equivalences) =
+              splitCorrelatedPredicate(condition, joinType == Inner, shouldDecorrelatePredicates)
+            val outerReferences = collectOuterReferences(j.expressions) ++
+              collectOuterReferences(predicates)
+            val newOuterReferences =
+              parentOuterReferences ++ outerReferences -- equivalences.keySet
+            var shouldPushToLeft = joinType match {
               case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true
               case _ => hasOuterReferences(left)
             }
             val shouldPushToRight = joinType match {
               case RightOuter | FullOuter => true
               case _ => hasOuterReferences(right)
             }
+            if (shouldDecorrelatePredicates && !shouldPushToLeft && !shouldPushToRight
+              && !correlated.isEmpty) {

Review Comment:
   Can we change this check from `!correlated.isEmpty` to `!predicates.isEmpty`? 
That is, if all the correlated predicates are in `equalityCond` and we can use 
them directly as join conditions, I don't think we need to add a DomainJoin.
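
To make the difference concrete, here is a toy version of the guard. `needsDomainJoin` is a hypothetical helper that only reproduces the boolean shape of the condition in the diff, with strings standing in for expressions.

```scala
object DomainJoinCheckSketch {
  // Mirrors the shape of the guard in the diff: the flag is on, neither child
  // has outer references, and some sequence of predicates is non-empty.
  def needsDomainJoin(
      enabled: Boolean,
      pushLeft: Boolean,
      pushRight: Boolean,
      remaining: Seq[String]): Boolean =
    enabled && !pushLeft && !pushRight && remaining.nonEmpty

  // Case from the comment: every correlated conjunct is a pullable equality.
  val correlated: Seq[String] = Seq("y = outer(a)")
  val predicates: Seq[String] = Seq.empty

  val currentCheck: Boolean = needsDomainJoin(true, false, false, correlated)
  val proposedCheck: Boolean = needsDomainJoin(true, false, false, predicates)
}
```

With the check on `correlated` the guard fires even though nothing is left to rewrite; with the check on `predicates` it does not.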



##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala:
##########
@@ -454,4 +460,85 @@ class DecorrelateInnerQuerySuite extends PlanTest {
             DomainJoin(Seq(x), testRelation))))
     check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x))
   }
+
+  test("SPARK-43780: aggregation in subquery with correlated equi-join") {
+    // Join in the subquery is on equi-predicates, so all the correlated references can be
+    // substituted by equivalent ones from the outer query, and domain join is not needed.
+    val outerPlan = testRelation
+    val innerPlan =
+      Aggregate(
+        Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
+        Project(Seq(x, y, a3, b3),
+          Join(testRelation2, testRelation3, Inner,
+            Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)))
+
+    val correctAnswer =
+      Aggregate(
+        Seq(y), Seq(Alias(count(Literal(1)), "a")(), y),
+        Project(Seq(x, y, a3, b3),
+          Join(testRelation2, testRelation3, Inner, Some(x === a3), JoinHint.NONE)))
+    check(innerPlan, outerPlan, correctAnswer, Seq(y === a))
+  }
+
+  test("SPARK-43780: aggregation in subquery with correlated non-equi-join") {
+    // Join in the subquery is on non-equi-predicate, so we introduce a DomainJoin.
+    val outerPlan = testRelation
+    val innerPlan =
+      Aggregate(
+        Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
+        Project(Seq(x, y, a3, b3),
+          Join(testRelation2, testRelation3, Inner,
+            Some(And(x === a3, y > OuterReference(a))), JoinHint.NONE)))
+    val correctAnswer =
+      Aggregate(
+        Seq(a), Seq(Alias(count(Literal(1)), "a")(), a),
+        Project(Seq(x, y, a3, b3, a),
+          Join(
+            DomainJoin(Seq(a), testRelation2),
+            testRelation3, Inner, Some(And(x === a3, y > a)), JoinHint.NONE)))
+    check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
+  }
+
+  test("SPARK-43780: aggregation in subquery with correlated left join") {
+    // Join in the subquery is on equi-predicates, so all the correlated references can be
+    // substituted by equivalent ones from the outer query, and domain join is not needed.
+    val outerPlan = testRelation
+    val innerPlan =
+      Aggregate(
+        Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
+        Project(Seq(x, y, a3, b3),
+          Join(testRelation2, testRelation3, LeftOuter,
+            Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)))

Review Comment:
   Can we also add a test with left join where the correlated predicate 
involves the right (outer) side? Then we would need to have a DomainJoin, right?


