Github user davies commented on a diff in the pull request:
https://github.com/apache/spark/pull/12306#discussion_r60009438
--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---
@@ -1447,3 +1450,133 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] {
    }
  }
}
+
+/**
+ * This rule rewrites predicate sub-queries into left semi/anti joins. The following predicates
+ * are supported:
+ * a. EXISTS/NOT EXISTS will be rewritten as semi/anti join, unresolved conditions in Filter
+ *    will be pulled out as the join conditions.
+ * b. IN/NOT IN will be rewritten as semi/anti join, unresolved conditions in the Filter will
+ *    be pulled out as join conditions, value = selected column will also be used as join
+ *    condition.
+ */
+object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
+  /**
+   * Pull out all correlated predicates from a given sub-query. This method removes the correlated
+   * predicates from sub-query [[Filter]]s and adds the references of these predicates to
+   * all intermediate [[Project]] clauses (if they are missing) in order to be able to evaluate the
+   * predicates in the join condition.
+   *
+   * This method returns the rewritten sub-query and the combined (AND) extracted predicate.
+   */
+  private def pullOutCorrelatedPredicates(
+      subquery: LogicalPlan,
+      query: LogicalPlan): (LogicalPlan, Option[Expression]) = {
+    val references: Set[Expression] = query.output.toSet
+    val predicateMap = mutable.Map.empty[LogicalPlan, Seq[Expression]]
+    val transformed = subquery transformUp {
+      case f @ Filter(cond, child) =>
+        // Find all correlated predicates.
+        val (correlated, local) = splitConjunctivePredicates(cond).partition { e =>
+          e.find(references.contains).isDefined
+        }
+        // Rewrite the filter without the correlated predicates if any.
+        correlated match {
+          case Nil => f
+          case xs if local.nonEmpty =>
+            val newFilter = Filter(local.reduce(And), child)
+            predicateMap += newFilter -> correlated
+            newFilter
+          case xs =>
+            predicateMap += child -> correlated
+            child
+        }
+      case p @ Project(expressions, child) =>
+        // Find all pulled out predicates defined in the Project's subtree.
+        val localPredicates = p.collect(predicateMap).flatten
+
+        // Determine which correlated predicate references are missing from this project.
+        val localPredicateReferences = localPredicates
+          .map(_.references)
+          .reduceOption(_ ++ _)
+          .getOrElse(AttributeSet.empty)
+        val missingReferences = localPredicateReferences -- p.references -- query.outputSet
+
+        // Create a new project if we need to add missing references.
+        if (missingReferences.nonEmpty) {
+          Project(expressions ++ missingReferences, child)
+        } else {
+          p
+        }
+    }
+    (transformed, predicateMap.values.flatten.reduceOption(And))
+  }
+
+  /**
+   * Prepare an [[InSubQuery]] by rewriting it (in case of correlated predicates) and by
+   * constructing the required join condition. Both the rewritten subquery and the constructed
+   * join condition are returned.
+   */
+  private def rewriteInSubquery(
+      subquery: InSubQuery,
+      query: LogicalPlan): (LogicalPlan, Expression) = {
+    val expressions = subquery.expressions
+    val (resolved, joinCondition) = pullOutCorrelatedPredicates(subquery.query, query)
+    val conditions = joinCondition.toSeq ++ expressions.zip(resolved.output).map(EqualTo.tupled)
+    (resolved, conditions.reduceLeft(And))
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case f @ Filter(condition, child) =>
+      val (withSubquery, withoutSubquery) =
+        splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery)
+
+      // Construct the pruned filter condition.
+      val newFilter: LogicalPlan = withoutSubquery match {
+        case Nil => child
+        case conditions => Filter(conditions.reduce(And), child)
+      }
+
+      // Filter the plan by applying left semi and left anti joins.
+      withSubquery.foldLeft(newFilter) {
+        case (p, Exists(sub)) =>
+          val (resolved, joinCondition) = pullOutCorrelatedPredicates(sub, p)
+          Join(p, resolved, LeftSemi, joinCondition)
+        case (p, Not(Exists(sub))) =>
+          val (resolved, joinCondition) = pullOutCorrelatedPredicates(sub, p)
+          Join(p, resolved, LeftAnti, joinCondition)
+        case (p, in: InSubQuery) =>
+          val (resolved, cond) = rewriteInSubquery(in, p)
+          Join(p, resolved, LeftSemi, Option(cond))
+        case (p, Not(in: InSubQuery)) =>
+          // This is a NULL-aware (left) anti join (NAAJ).
+          // We currently only allow subqueries with non-nullable fields. In this way we can
+          // plan a regular ANTI join, instead of a much more complex NAAJ (which is not yet
+          // available in Spark SQL). In order to get the NAAJ semantically right, we need to
+          // add a filter to the left hand side of the query that checks that either all
+          // columns are non-null or that the right hand side is empty.
+          val (resolved, cond) = rewriteInSubquery(in, p)
+
+          // Make absolutely sure that the rewritten query contains no nullable fields. We
+          // re-check this here because the rewritten query can contain pulled-up nullable
+          // columns.
+          if (resolved.output.exists(_.nullable)) {
+            throw new AnalysisException("NOT IN with nullable subquery is not supported. " +
+              "Please use a non-nullable sub-query or rewrite this using NOT EXISTS.")
+          }
+
+          // Construct filter for the left hand side
--- End diff --
Should the anti join not output the row if there is any null in the left key?
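
For context on this question, here is a minimal sketch (not part of this PR) of the NULL semantics involved. It assumes a local SparkSession, made-up table names, and a Spark version that accepts NOT IN over a nullable subquery; the code in this diff would instead hit the AnalysisException shown above for such a query.

import org.apache.spark.sql.SparkSession

object NotInNullSemanticsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("naaj-sketch").getOrCreate()

    // Two tiny tables; both join keys contain a NULL.
    spark.sql(
      "CREATE OR REPLACE TEMPORARY VIEW l AS SELECT * FROM VALUES (1), (2), (CAST(NULL AS INT)) AS t(a)")
    spark.sql(
      "CREATE OR REPLACE TEMPORARY VIEW r AS SELECT * FROM VALUES (2), (CAST(NULL AS INT)) AS t(b)")

    // Standard three-valued logic: the NULL in r.b makes `1 NOT IN (...)` evaluate to NULL,
    // and a NULL value of l.a can never compare TRUE, so this query returns no rows.
    spark.sql("SELECT * FROM l WHERE a NOT IN (SELECT b FROM r)").show()

    // A plain left anti join on a = b would instead keep the rows a = 1 and a = NULL,
    // which is the gap that a NULL-aware anti join (or the extra left hand side filter
    // described in the diff) has to close.
    spark.sql("SELECT * FROM l LEFT ANTI JOIN r ON l.a = r.b").show()
  }
}

In other words, under NOT IN a NULL on either side suppresses the row, while a plain anti join would emit it, which is what the question about a null in the left key is getting at.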