jchen5 commented on code in PR #41301:
URL: https://github.com/apache/spark/pull/41301#discussion_r1273821132


##########
sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out:
##########
@@ -572,6 +572,36 @@ struct<c1:int,c2:int>
 0      1
 
 
+-- !query
+SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1)
+-- !query schema
+struct<c1:int,c2:int,c1:int,c2:int,c1:int,c2:int>
+-- !query output
+0      1       0       2       0       1
+0      1       0       2       0       2
+0      1       0       3       0       1
+0      1       0       3       0       2
+
+
+-- !query
+SELECT * FROM t1 LEFT JOIN lateral (SELECT * FROM t4 LEFT JOIN t2 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1)
+-- !query schema
+struct<c1:int,c2:int,c1:int,c2:int,c1:int,c2:int>
+-- !query output
+0      1       0       1       0       2
+0      1       0       1       0       3
+0      1       0       2       0       2
+0      1       0       2       0       3
+0      1       1       1       NULL    NULL
+0      1       1       3       NULL    NULL
+1      2       0       1       0       2

Review Comment:
   Looks like these results are incorrect. For example, this tuple has t1.c1 = 1, t2.c1 = 0, t4.c1 = 0. Since `t2.c1 = t1.c1` is false, the t2 side of `t4 LEFT JOIN t2` should be a NULL-padded row here.
   
   I checked out the PR, ran it locally, and actually got a different set of results, which look correct. So maybe you just need to update the golden file?
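   
   To double-check the expected semantics, here is a minimal Scala sketch using plain collections (not Spark) that evaluates the lateral subquery `t4 LEFT JOIN t2 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1` for a single outer `t1` row. The table contents are hypothetical, chosen only to reproduce the flagged tuple (t1.c1 = 1, t4.c1 = 0, t2.c1 = 0). `Row` and `lateralLeftJoin` are illustrative names, not Spark APIs.
   
   ```scala
   // Hypothetical (c1, c2) row; on the t2 side, None stands for a NULL-padded row.
   final case class Row(c1: Int, c2: Int)
   
   object LateralLeftJoinSketch {
     // For one outer t1 row, evaluate t4 LEFT JOIN t2 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1.
     // A t4 row with no matching t2 row is emitted once, paired with None.
     def lateralLeftJoin(t1Row: Row, t4: Seq[Row], t2: Seq[Row]): Seq[(Row, Option[Row])] =
       t4.flatMap { l =>
         val matches = t2.filter(r => r.c1 == l.c1 && r.c1 == t1Row.c1)
         if (matches.isEmpty) Seq(l -> Option.empty[Row])
         else matches.map(r => l -> Option(r))
       }
   
     def main(args: Array[String]): Unit = {
       val t4 = Seq(Row(0, 1)) // hypothetical contents
       val t2 = Seq(Row(0, 2)) // hypothetical contents
       val outer = Row(1, 2)   // t1.c1 = 1, as in the flagged tuple
       // t2.c1 = t1.c1 is false (0 != 1), so the t2 side must be NULL.
       println(lateralLeftJoin(outer, t4, t2)) // prints List((Row(0,1),None))
     }
   }
   ```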



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -826,8 +882,13 @@ object DecorrelateInnerQuery extends PredicateHelper {
             } else {
               (right, Nil, AttributeMap.empty[Attribute])
             }
-            val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap
-            val newJoinCond = leftJoinCond ++ rightJoinCond
+            val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap ++
+              equivalences
+            val newCorrelated =
+              if (shouldDecorrelatePredicates) {
+                replaceOuterReferences(predicates, newOuterReferenceMap)

Review Comment:
   Here we are pulling up the `equalityCond` predicates (`predicates` is everything that isn't in `equalityCond`): they get moved from the join ON cond to the top-level join.
   
   I was thinking that it might not be safe to pull them up; we might need to leave them in the ON cond. There could be a problem if other columns of the outer table are referenced after the join, because whether those columns are null depends on the outer join cond, so we can't just pull that cond all the way up to the top-level join.
   
   Example: what about the inner query `t1 left join t2 on t1.x = t2.x and t2.y = outer(a) where t2.z is null`? It looks like the current code would pull the predicate up, turning the query into `t1 left join t2 on t1.x = t2.x where t2.z is null`. The set of rows that match the ON cond would change, and so would the set of rows that pass the `t2.z is null` filter.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -804,18 +804,67 @@ object DecorrelateInnerQuery extends PredicateHelper {
             (d.copy(child = newChild), joinCond, outerReferenceMap)
 
           case j @ Join(left, right, joinType, condition, _) =>
-            val outerReferences = collectOuterReferences(j.expressions)
-            // Join condition containing outer references is not supported.
-            assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
-            val newOuterReferences = parentOuterReferences ++ outerReferences
-            val shouldPushToLeft = joinType match {
+            def splitCorrelatedPredicate(condition: Option[Expression],
+                                         isInnerJoin: Boolean,
+                                         shouldDecorrelatePredicates: Boolean):
+            (Seq[Expression], Seq[Expression], Seq[Expression],
+              Seq[Expression], AttributeMap[Attribute]) = {
+              // Similar to Filters above, we split the join condition (if present) into correlated
+              // and uncorrelated predicates, and separately handle joins under set and aggregation
+              // operations.
+              if (shouldDecorrelatePredicates) {
+                val conditions =
+                  if (condition.isDefined) splitConjunctivePredicates(condition.get)
+                  else Seq.empty[Expression]
+                val (correlated, uncorrelated) = conditions.partition(containsOuter)
+                val equivalences =
+                  if (underSetOp) AttributeMap.empty[Attribute]
+                  else collectEquivalentOuterReferences(correlated)
+                var (equalityCond, predicates) =
+                  if (underSetOp) (Seq.empty[Expression], correlated)
+                  else correlated.partition(canPullUpOverAgg)

Review Comment:
   Sorry, I mixed up some code in what I wrote. What I meant was something like this:
   ```
   val (predicatesPulledUp, predicatesNotPulledUp) =
     if (underSetOp || !isInnerJoin) (Seq.empty[Expression], correlated)
     else if (aggregated) correlated.partition(canPullUpOverAgg)
     else (correlated, Seq.empty[Expression])
   ```
   However, on second thought, I'm not sure we can actually pull up those predicates safely for non-inner joins; see my other comment.
   
   But for Filter conds or inner join conds, I think this would work. If we are under an aggregate, only equality conds can be pulled up; if we're not under an aggregate or set op, we can pull all conds up.
   
   If we walk through the code in the Filter case for `if (aggregated || underSetOp)` and compare it to the `else` case, I think this works out to be equivalent. In the aggregate case, we add the pull-up-able predicates (`equalityCond`) to `newJoinCond`, and the non-pull-up-able predicates stay in `newFilter`. In the `else` case (the non-aggregate, non-set-op case), we add all of `correlated` to `newJoinCond` and remove it all from `newFilter`.
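   
   The rule in the snippet above can be exercised in isolation. This is a hypothetical Scala sketch: `Pred` stands in for a correlated `Expression`, `isEquality` plays the role of `canPullUpOverAgg`, and a Filter is treated as `isInnerJoin = true`. None of these names are Spark APIs.
   
   ```scala
   // Hypothetical stand-in for a correlated predicate; `isEquality` models
   // canPullUpOverAgg (only equalities can move above an aggregate).
   final case class Pred(sql: String, isEquality: Boolean)
   
   object PullUpRuleSketch {
     // Returns (predicatesPulledUp, predicatesNotPulledUp).
     def splitForPullUp(
         correlated: Seq[Pred],
         underSetOp: Boolean,
         isInnerJoin: Boolean,
         aggregated: Boolean): (Seq[Pred], Seq[Pred]) =
       if (underSetOp || !isInnerJoin) (Seq.empty[Pred], correlated) // keep in ON cond / filter
       else if (aggregated) correlated.partition(_.isEquality)       // only equalities move up
       else (correlated, Seq.empty[Pred])                            // everything moves up
   
     def main(args: Array[String]): Unit = {
       val eq = Pred("t2.c1 = outer(t1.c1)", isEquality = true)
       val lt = Pred("t2.c2 < outer(t1.c2)", isEquality = false)
       val preds = Seq(eq, lt)
       // Inner join (or Filter), no aggregate: both predicates are pulled up.
       println(splitForPullUp(preds, underSetOp = false, isInnerJoin = true, aggregated = false))
       // Under an aggregate: only the equality is pulled up.
       println(splitForPullUp(preds, underSetOp = false, isInnerJoin = true, aggregated = true))
       // Outer join: nothing is pulled up.
       println(splitForPullUp(preds, underSetOp = false, isInnerJoin = false, aggregated = false))
     }
   }
   ```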


