allisonwang-db commented on a change in pull request #32072:
URL: https://github.com/apache/spark/pull/32072#discussion_r612696150
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
##########
@@ -656,3 +705,301 @@ object RewriteCorrelatedScalarSubquery extends
Rule[LogicalPlan] with AliasHelpe
}
}
}
+
+/**
+ * Decorrelate the inner query by eliminating outer references and creating domain joins.
+ * The implementation is based on the paper "Unnesting Arbitrary Queries" by Thomas Neumann
+ * and Alfons Kemper: https://dl.gi.de/handle/20.500.12116/2418.
+ *  (1) Recursively collect outer references from the inner query until reaching a node
+ *      that does not contain any correlated values.
+ *  (2) Insert an optional [[DomainJoin]] node to indicate whether a domain (inner) join is
+ *      needed between the outer query and the specific subtree of the inner query.
+ *  (3) Return a list of join conditions with the outer query and a mapping from outer
+ *      references to references inside the inner query. The parent nodes need to preserve
+ *      the references inside the join conditions and substitute all outer references using
+ *      the mapping.
+ *
+ * E.g. decorrelate an inner query with equality predicates:
+ *
+ * Aggregate [] [min(c2)]            Aggregate [c1] [min(c2), c1]
+ * +- Filter [outer(c3) = c1]   =>   +- Relation [t]
+ *    +- Relation [t]
+ *
+ * Join conditions: [c3 = c1]
+ *
+ * E.g. decorrelate an inner query with non-equality predicates:
+ *
+ * Aggregate [] [min(c2)]            Aggregate [c3'] [min(c2), c3']
+ * +- Filter [outer(c3) > c1]   =>   +- Filter [c3' > c1]
+ *    +- Relation [t]                   +- DomainJoin [c3']
+ *                                         +- Relation [t]
+ *
+ * Join conditions: [c3 <=> c3']
+ */
+object DecorrelateInnerQuery extends PredicateHelper {
+
+  /**
+   * Returns whether `expression` is matched by the [[Equality]] extractor, i.e. whether it
+   * is an equality comparison that can pair an outer reference with an inner attribute.
+   */
+  private def isEquality(expression: Expression): Boolean =
+    PartialFunction.cond(expression) { case Equality(_, _) => true }
+
+  /**
+   * Collect the attributes (via `toAttribute`) of every [[OuterReference]] that occurs
+   * anywhere in the given expression tree.
+   */
+  private def collectOuterReferences(expression: Expression): AttributeSet = {
+    val outerAttributes = expression.collect {
+      case outerRef: OuterReference => outerRef.toAttribute
+    }
+    AttributeSet(outerAttributes)
+  }
+
+  /**
+   * Collect the attributes of all [[OuterReference]]s appearing in any of the given
+   * expressions, merged into a single [[AttributeSet]].
+   */
+  private def collectOuterReferences(expressions: Seq[Expression]): AttributeSet = {
+    val perExpression = expressions.map(expr => collectOuterReferences(expr))
+    AttributeSet.fromAttributeSets(perExpression)
+  }
+
+  /**
+   * Build a map from outer-reference attributes to the inner-query attributes they are
+   * directly equated with. Only predicates shaped `outer(x) = attr` or `attr = outer(x)`
+   * contribute; every other predicate is ignored. If one outer attribute appears in several
+   * such predicates, the pair from the later predicate is kept (`toMap` semantics).
+   * E.g. [outer(a) = x, y = outer(b), outer(c) = z + 1] => {a -> x, b -> y}
+   */
+  private def collectEquivalentOuterReferences(
+      expressions: Seq[Expression]): Map[Attribute, Attribute] = {
+    val pairs = expressions.flatMap {
+      case Equality(outerRef: OuterReference, inner: Attribute) =>
+        Seq(outerRef.toAttribute -> inner.toAttribute)
+      case Equality(inner: Attribute, outerRef: OuterReference) =>
+        Seq(outerRef.toAttribute -> inner.toAttribute)
+      case _ =>
+        Nil
+    }
+    pairs.toMap
+  }
+
+  /**
+   * Rewrite `expression`, substituting every [[OuterReference]] whose attribute is a key of
+   * `outerReferenceMap` with the mapped attribute. Outer references without a mapping are
+   * left untouched. The result is cast back to the input's static type `E`.
+   */
+  private def replaceOuterReference[E <: Expression](
+      expression: E,
+      outerReferenceMap: Map[Attribute, Attribute]): E = {
+    val substitute: PartialFunction[Expression, Expression] = {
+      case ref: OuterReference => outerReferenceMap.getOrElse(ref.toAttribute, ref)
+    }
+    expression.transform(substitute).asInstanceOf[E]
+  }
+
+  /**
+   * Apply [[replaceOuterReference]] to each of the given expressions, keeping their order
+   * and element type.
+   */
+  private def replaceOuterReferences[E <: Expression](
+      expressions: Seq[E],
+      outerReferenceMap: Map[Attribute, Attribute]): Seq[E] = {
+    for (expr <- expressions) yield replaceOuterReference(expr, outerReferenceMap)
+  }
+
+  /**
+   * Return the attributes referenced by the join conditions that are not provided by the
+   * given expressions.
+   *
+   * Only expressions that are themselves named (attributes, aliases, ...) count as providing
+   * an attribute, via `toAttribute`. An attribute that is merely referenced inside a larger
+   * expression is still reported as missing. E.g. for grouping expressions [c1 + 1] and a
+   * join condition referencing `c1`, `c1` is returned. Grouping by `c1 + 1` alone does not
+   * make `c1` itself a grouping key, so the caller must add it to the grouping expressions
+   * before it can be added to the aggregate output.
+   *
+   * @param expressions expressions whose named outputs are already available
+   * @param joinCond    join conditions whose references must be preserved
+   * @return join-condition references not covered by `expressions`
+   */
+  private def missingReferences(
+      expressions: Seq[Expression],
+      joinCond: Seq[Expression]): AttributeSet = {
+    // Previously `AttributeSet(expressions)` used the *references* of the expressions, which
+    // wrongly treated attributes nested inside compound expressions as available.
+    val available = AttributeSet(expressions.collect {
+      case named: NamedExpression => named.toAttribute
+    })
+    AttributeSet(joinCond.flatMap(_.references)) -- available
+  }
+
+  /**
+   * Deduplicate the inner and the outer query attributes. Duplicated attributes can break
+   * the structural integrity when joining the inner and outer plan together.
+   *
+   * If the inner plan outputs any attribute that is also in `outerOutputSet`, the inner plan
+   * is wrapped in a [[Project]] that re-aliases each such attribute. Every attribute in
+   * `conditions` is then rewritten to the corresponding alias. Otherwise both inputs are
+   * returned unchanged.
+   *
+   * @param innerPlan      the inner query plan
+   * @param conditions     join conditions that may reference the inner plan's output
+   * @param outerOutputSet output attributes of the outer plan(s)
+   * @return the (possibly re-aliased) inner plan and the correspondingly rewritten conditions
+   */
+  def deduplicate(
+      innerPlan: LogicalPlan,
+      conditions: Seq[Expression],
+      outerOutputSet: AttributeSet): (LogicalPlan, Seq[Expression]) = {
+    val clashing = innerPlan.outputSet.intersect(outerOutputSet)
+    if (clashing.isEmpty) {
+      (innerPlan, conditions)
+    } else {
+      val aliasPairs = clashing.map(attr => attr -> Alias(attr, attr.toString)()).toSeq
+      val renamed = AttributeMap(aliasPairs)
+      val projectList = innerPlan.output.map(attr => renamed.getOrElse(attr, attr))
+      val projected = Project(projectList, innerPlan)
+      val rewrittenConditions = conditions.map { cond =>
+        cond.transform { case attr: Attribute => renamed.getOrElse(attr, attr).toAttribute }
+      }
+      (projected, rewrittenConditions)
+    }
+  }
+
+  /**
+   * Decorrelate `innerPlan` against a single outer plan. This is shorthand for the overload
+   * that takes a sequence of outer plans.
+   */
+  def apply(
+      innerPlan: LogicalPlan,
+      outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) =
+    apply(innerPlan, outerPlan :: Nil)
+
+ def apply(
+ innerPlan: LogicalPlan,
+ outerPlans: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
+ val outputSet = AttributeSet(outerPlans.flatMap(_.outputSet))
+
+ // The return type of the recursion.
+ // The first parameter is a new logical plan with correlation eliminated.
+ // The second parameter is a list of join conditions with the outer query.
+ // The third parameter is a mapping between the outer references and
equivalent
+ // expressions from the inner query that is used to replace outer
references.
+ type ReturnType = (LogicalPlan, Seq[Expression], Map[Attribute, Attribute])
+
+ // Recursively decorrelate the input plan with a set of parent outer
references and
+ // a boolean flag indicating whether the result of the plan will be
aggregated.
+ def decorrelate(
+ plan: LogicalPlan,
+ parentOuterReferences: AttributeSet,
+ aggregated: Boolean = false): ReturnType = {
+ val isCorrelated = hasOuterReferences(plan)
+ if (!isCorrelated) {
+ // We have reached a plan without correlation to the outer plan.
+ if (parentOuterReferences.isEmpty) {
+ // If there is no outer references from the parent nodes, it means
all outer
+ // attributes can be substituted by attributes from the inner plan.
So no
+ // domain join is needed.
+ (plan, Nil, Map.empty[Attribute, Attribute])
+ } else {
+ // Build the domain join with the parent outer references.
+ val attributes = parentOuterReferences.toSeq
+ val domains = attributes.map(_.newInstance())
+ // A placeholder to be rewritten into domain join.
+ val domainJoin = DomainJoin(domains, plan)
+ val outerReferenceMap = attributes.zip(domains).toMap
+ // Build join conditions between domain attributes and outer
references.
+ // EqualNullSafe is used to make sure null key can be joined
together. Note
+ // outer referenced attributes can be changed during the outer query
optimization.
+ // The equality conditions will also serve as an attribute mapping
between new
+ // outer references and domain attributes when rewriting the domain
joins.
+ // E.g. if the attribute a is changed to a1, the join condition a'
<=> outer(a)
+ // will become a' <=> a1, and we can construct the aliases based on
the condition:
+ // DomainJoin [a'] Join Inner
+ // +- InnerQuery => :- InnerQuery
+ // +- Aggregate [a1] [a1 AS a']
+ // +- OuterQuery
+ val conditions = outerReferenceMap.map {
+ case (o, a) => EqualNullSafe(a, OuterReference(o))
+ }
+ (domainJoin, conditions.toSeq, outerReferenceMap)
+ }
+ } else {
+ // Collect outer references from the current node.
+ val outerReferences = collectOuterReferences(plan.expressions)
+ plan match {
+ case Filter(condition, child) =>
+ val (correlated, uncorrelated) =
+ splitConjunctivePredicates(condition)
+ .partition(containsOuter)
+ val (equality, nonEquality) = correlated.partition(isEquality)
+ // Find equivalent outer reference relations and remove equivalent
attributes from
+ // parentOuterReferences since they can be replaced directly by
expressions
+ // inside the inner plan.
+ val equivalences = collectEquivalentOuterReferences(equality)
+ // When the results are aggregated, outer references inside the
non-equality
+ // predicates cannot be used directly as join conditions with the
outer query.
+ val outerReferences = if (aggregated) {
+ collectOuterReferences(nonEquality)
+ } else {
+ AttributeSet.empty
+ }
+ val newOuterReferences = parentOuterReferences ++ outerReferences
-- equivalences.keySet
+ val (newChild, joinCond, outerReferenceMap) =
+ decorrelate(child, newOuterReferences, aggregated)
+ // Add the mapping from the current node.
+ val newOuterReferenceMap = outerReferenceMap ++ equivalences
+ // Replace all outer references in non-equality filter conditions
using the domain
+ // attributes produced for inner query with aggregates. This step
is necessary
+ // for pushing down the non-equality filters into the domain join
as join conditions.
+ val (newFilterCond, newJoinCond) = if (aggregated) {
+ val nonEqualityCond = replaceOuterReferences(nonEquality,
newOuterReferenceMap)
+ (nonEqualityCond ++ uncorrelated, equality)
+ } else {
+ (uncorrelated, correlated)
+ }
+ val newFilter = newFilterCond match {
+ case Nil => newChild
+ case xs => Filter(xs.reduce(And), newChild)
+ }
+ (newFilter, joinCond ++ newJoinCond, newOuterReferenceMap)
+
+ case Project(projectList, child) =>
+ val newOuterReferences = parentOuterReferences ++ outerReferences
+ val (newChild, joinCond, outerReferenceMap) =
+ decorrelate(child, newOuterReferences, aggregated)
+ // Replace all outer references in the original project list.
+ val newProjectList = replaceOuterReferences(projectList,
outerReferenceMap)
+ // Preserve required domain attributes in the join condition by
adding the missing
+ // references to the new project list.
+ val referencesToAdd =
missingReferences(newProjectList.map(_.toAttribute), joinCond)
+ val newProject = Project(newProjectList ++ referencesToAdd,
newChild)
+ (newProject, joinCond, outerReferenceMap)
+
+ case a @ Aggregate(groupingExpressions, aggregateExpressions, child)
=>
+ val newOuterReferences = parentOuterReferences ++ outerReferences
+ val (newChild, joinCond, outerReferenceMap) =
+ decorrelate(child, newOuterReferences, aggregated = true)
+ // Replace all outer references in grouping and aggregate
expressions.
+ val newGroupingExpr = replaceOuterReferences(groupingExpressions,
outerReferenceMap)
+ val newAggExpr = replaceOuterReferences(aggregateExpressions,
outerReferenceMap)
+ // Add all required domain attributes to both grouping and
aggregate expressions.
+ val groupingExprToAdd = missingReferences(newGroupingExpr,
joinCond)
+ val aggExprToAdd =
missingReferences(newAggExpr.map(_.toAttribute), joinCond)
Review comment:
Good catch!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]