andylam-db commented on code in PR #42725:
URL: https://github.com/apache/spark/pull/42725#discussion_r1309371876


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala:
##########
@@ -158,6 +159,50 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] 
with PredicateHelper {
           val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
           Project(p.output, Filter(newCond.get, inputPlan))
       }
+    // This case takes care of predicate subqueries in join conditions that 
are not pushed down
+    // to the children nodes by [[PushDownPredicates]].
+    case j: Join if j.condition.exists(cond =>
+      SubqueryExpression.hasInOrCorrelatedExistsSubquery(cond)) &&
+      conf.getConf(DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION) =>
+      var (newLeft, newRight) = (j.left, j.right)
+      val (withSubquery, withoutSubquery) = 
splitConjunctivePredicates(j.condition.get)
+        .partition(SubqueryExpression.hasInOrCorrelatedExistsSubquery)
+      // Using `transformUp` (instead of `transformDown`) is important here 
because we throw an
+      // exception if `expr` references both join inputs.
+      // For example, (x1 = y1) OR x2 in (...) references both sides.
+      // `transformUp` will transform the subquery first, while `transformDown` 
will try to transform
+      // the OR expression first and throw exception.
+      val newSubqueryPredicates = withSubquery.map(_.transformUp {
+        case expr =>
+          val referenceLeft = 
expr.references.intersect(j.left.outputSet).nonEmpty
+          val referenceRight = 
expr.references.intersect(j.right.outputSet).nonEmpty
+          if (referenceLeft && referenceRight &&
+            SubqueryExpression.hasInOrCorrelatedExistsSubquery(expr)) {
+            throw new IllegalStateException(
+              s"Unable to optimize predicate subquery in join condition 
references both " +
+                s"join children.")
+          } else if (referenceLeft) {
+            val (newCond, newInputPlan) = rewriteExistentialExpr(Seq(expr), 
newLeft)
+            newLeft = newInputPlan
+            newCond.get
+          } else if (referenceRight) {
+            val (newCond, newInputPlan) = rewriteExistentialExpr(Seq(expr), 
newRight)
+            newRight = newInputPlan
+            newCond.get
+          } else {
+            expr
+          }
+      })
+      val withoutSubqueryPredicate = withoutSubquery.reduceOption(And)
+      val withSubqueryPredicate = newSubqueryPredicates.reduceOption(And)
+      val newCondition = (withoutSubqueryPredicate, withSubqueryPredicate) 
match {
+        case (Some(a), Some(b)) => Some(And(a, b))
+        case (Some(a), None) => Some(a)
+        case (None, Some(b)) => Some(b)
+        case (None, None) => None

Review Comment:
   Nice! Changed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to