GitHub user gatorsmile commented on a diff in the pull request:
https://github.com/apache/spark/pull/18692#discussion_r152415250
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
---
@@ -152,3 +152,79 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
if (j.joinType == newJoinType) f else Filter(condition,
j.copy(joinType = newJoinType))
}
}
+
+/**
+ * A rule that uses propagated constraints to infer join conditions. The optimization is applicable
+ * only to CROSS joins. For other join types, adding inferred join conditions could introduce an
+ * extra shuffle: a child's existing partitioning might no longer satisfy the JOIN node's
+ * requirements, even though it would have satisfied them without the extra conditions.
+ *
+ * For instance, if there is a CROSS join where the left relation has 'a = 1' and the right
+ * relation has 'b = 1', the rule infers 'a = b' as a join predicate and turns the CROSS join into
+ * an INNER join.
+ */
+object InferJoinConditionsFromConstraints extends Rule[LogicalPlan] with
PredicateHelper {
+
+  /**
+   * Entry point of the rule. Join-condition inference is driven by propagated constraints, so
+   * when `SQLConf.get.constraintPropagationEnabled` is false the plan is returned untouched.
+   */
+  def apply(plan: LogicalPlan): LogicalPlan =
+    if (!SQLConf.get.constraintPropagationEnabled) plan else inferJoinConditions(plan)
+
+  /**
+   * Walks `plan` and rewrites each CROSS join for which an equality between a left-side attribute
+   * and a right-side attribute can be derived from the join's constraints.
+   *
+   * The derived predicates are and-ed together:
+   *  - If the join has no condition, they become its condition.
+   *  - Otherwise, the predicates kept by `findNewPredicates` are and-ed in front of the existing
+   *    condition. That helper presumably drops ones already present; its body is outside this view.
+   *
+   * Whenever a new condition results, the join is turned into an INNER join. Otherwise the join
+   * is returned unchanged.
+   */
+  private def inferJoinConditions(plan: LogicalPlan): LogicalPlan = plan transform {
+    case join @ Join(leftChild, rightChild, Cross, conditionOpt) =>
+      val constraints = join.constraints
+
+      // Equalities in which at least one side is an attribute produced by the right child.
+      val rightSideEqualities = constraints.collect {
+        case equality @ EqualTo(a: Attribute, _) if isAttributeContainedInPlan(a, rightChild) =>
+          equality
+        case equality @ EqualTo(_, a: Attribute) if isAttributeContainedInPlan(a, rightChild) =>
+          equality
+      }
+
+      // For each equality pinning a left-side attribute to some expression, pair that attribute
+      // with every right-side attribute pinned to a semantically equal expression.
+      val derived = constraints.flatMap {
+        case EqualTo(leftAttr: Attribute, pinnedTo)
+            if isAttributeContainedInPlan(leftAttr, leftChild) =>
+          collectJoinPredicates(leftAttr, pinnedTo, rightChild, rightSideEqualities)
+        case EqualTo(pinnedTo, leftAttr: Attribute)
+            if isAttributeContainedInPlan(leftAttr, leftChild) =>
+          collectJoinPredicates(leftAttr, pinnedTo, rightChild, rightSideEqualities)
+        case _ => Nil
+      }
+
+      val rewrittenCondition = conditionOpt match {
+        case None =>
+          derived.reduceOption(And)
+        case Some(existing) =>
+          val fresh = findNewPredicates(derived, splitConjunctivePredicates(existing))
+          fresh.reduceOption(And).map(And(_, existing))
+      }
+
+      rewrittenCondition match {
+        case Some(_) => Join(leftChild, rightChild, Inner, rewrittenCondition)
+        case None => join
+      }
+  }
+
+  /**
+   * Builds a `leftAttr = rightAttr` predicate for every right-side attribute that one of the
+   * given equalities pins to an expression semantically equal to `equivalentExpr`.
+   *
+   * @param leftAttr                   attribute whose value is known to equal `equivalentExpr`
+   *                                   (the caller passes an attribute of the join's left child)
+   * @param equivalentExpr             expression that `leftAttr` is known to equal
+   * @param rightPlan                  plan whose output the matched attribute must belong to
+   * @param rightPlanEqualToPredicates candidate equalities to search
+   * @return the inferred `EqualTo` predicates; empty when nothing matches
+   */
+  private def collectJoinPredicates(
+      leftAttr: Attribute,
+      equivalentExpr: Expression,
+      rightPlan: LogicalPlan,
+      rightPlanEqualToPredicates: Set[EqualTo]): Set[EqualTo] = {
+
+    // True when `candidate` is produced by `rightPlan` and `other` matches `equivalentExpr`.
+    def pinnedToSameValue(candidate: Attribute, other: Expression): Boolean =
+      other.semanticEquals(equivalentExpr) && isAttributeContainedInPlan(candidate, rightPlan)
+
+    rightPlanEqualToPredicates.collect {
+      case EqualTo(rightAttr: Attribute, other) if pinnedToSameValue(rightAttr, other) =>
+        EqualTo(leftAttr, rightAttr)
+      case EqualTo(other, rightAttr: Attribute) if pinnedToSameValue(rightAttr, other) =>
+        EqualTo(leftAttr, rightAttr)
+    }
+  }
+
+  /**
+   * Returns true when every attribute referenced by `attr` is part of `logicalPlan`'s output.
+   */
+  private def isAttributeContainedInPlan(attr: Attribute, logicalPlan: LogicalPlan): Boolean = {
+    val available = logicalPlan.outputSet
+    attr.references.subsetOf(available)
+  }
+
+ private def findNewPredicates(
+ inferredPredicates: Set[EqualTo],
+ existingPredicates: Seq[Expression]) : Set[EqualTo] = {
--- End diff --
Nit: drop the space before the return-type colon, i.e. change
`existingPredicates: Seq[Expression]) : Set[EqualTo] = {`
to `existingPredicates: Seq[Expression]): Set[EqualTo] = {`
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]