andylam-db commented on code in PR #42725:
URL: https://github.com/apache/spark/pull/42725#discussion_r1309371876
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala:
##########
@@ -158,6 +159,50 @@ object RewritePredicateSubquery extends Rule[LogicalPlan]
with PredicateHelper {
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
Project(p.output, Filter(newCond.get, inputPlan))
}
+ // This case takes care of predicate subqueries in join conditions that
are not pushed down
+ // to the children nodes by [[PushDownPredicates]].
+ case j: Join if j.condition.exists(cond =>
+ SubqueryExpression.hasInOrCorrelatedExistsSubquery(cond)) &&
+ conf.getConf(DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION) =>
+ var (newLeft, newRight) = (j.left, j.right)
+ val (withSubquery, withoutSubquery) =
splitConjunctivePredicates(j.condition.get)
+ .partition(SubqueryExpression.hasInOrCorrelatedExistsSubquery)
+ // Using `transformUp` (instead of `transformDown`) is important here
because we throw an
+ // exception if `expr` references both join inputs.
+ // For example, (x1 = y1) OR x2 in (...) references both sides.
+ // `transformUp` will transform the subquery first, while `transformDown`
will try to transform
+ // the OR expression first and throw exception.
+ val newSubqueryPredicates = withSubquery.map(_.transformUp {
+ case expr =>
+ val referenceLeft =
expr.references.intersect(j.left.outputSet).nonEmpty
+ val referenceRight =
expr.references.intersect(j.right.outputSet).nonEmpty
+ if (referenceLeft && referenceRight &&
+ SubqueryExpression.hasInOrCorrelatedExistsSubquery(expr)) {
+ throw new IllegalStateException(
+ s"Unable to optimize predicate subquery in join condition
references both " +
+ s"join children.")
+ } else if (referenceLeft) {
+ val (newCond, newInputPlan) = rewriteExistentialExpr(Seq(expr),
newLeft)
+ newLeft = newInputPlan
+ newCond.get
+ } else if (referenceRight) {
+ val (newCond, newInputPlan) = rewriteExistentialExpr(Seq(expr),
newRight)
+ newRight = newInputPlan
+ newCond.get
+ } else {
+ expr
+ }
+ })
+ val withoutSubqueryPredicate = withoutSubquery.reduceOption(And)
+ val withSubqueryPredicate = newSubqueryPredicates.reduceOption(And)
+ val newCondition = (withoutSubqueryPredicate, withSubqueryPredicate)
match {
+ case (Some(a), Some(b)) => Some(And(a, b))
+ case (Some(a), None) => Some(a)
+ case (None, Some(b)) => Some(b)
+ case (None, None) => None
Review Comment:
Nice! changed
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]