Github user gatorsmile commented on a diff in the pull request:
https://github.com/apache/spark/pull/18692#discussion_r153060595
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
---
@@ -152,3 +152,99 @@ object EliminateOuterJoin extends Rule[LogicalPlan]
with PredicateHelper {
if (j.joinType == newJoinType) f else Filter(condition,
j.copy(joinType = newJoinType))
}
}
+
+/**
+ * A rule that eliminates CROSS joins by inferring join conditions from propagated constraints.
+ *
+ * The optimization is applied only to CROSS joins. For other join types, adding inferred join
+ * conditions could introduce extra shuffles: the children's existing partitioning, which might
+ * otherwise satisfy the JOIN node's requirements, may no longer do so.
+ *
+ * For instance, if there is a CROSS join where the left relation has 'a = 1' and the right
+ * relation has 'b = 1', the rule infers 'a = b' as a join predicate.
+ */
+object EliminateCrossJoin extends Rule[LogicalPlan] with PredicateHelper {
+
+  /**
+   * Entry point of the rule. The rewrite reads the join's propagated constraints, so it is a
+   * no-op (the plan is returned as-is) when `SQLConf.get.constraintPropagationEnabled` is false.
+   */
+  def apply(plan: LogicalPlan): LogicalPlan =
+    if (!SQLConf.get.constraintPropagationEnabled) plan else eliminateCrossJoin(plan)
+
+  /**
+   * Rewrites each CROSS join that has no join condition into an INNER join. The new join's
+   * condition is the conjunction of the predicates inferred from the join's constraints. A
+   * CROSS join for which no predicate can be inferred is returned unchanged.
+   */
+  private def eliminateCrossJoin(plan: LogicalPlan): LogicalPlan = plan transform {
+    case join @ Join(leftPlan, rightPlan, Cross, None) =>
+      // Constraints whose references all come from a single side's output.
+      val leftConstraints = join.constraints.filter(_.references.subsetOf(leftPlan.outputSet))
+      val rightConstraints = join.constraints.filter(_.references.subsetOf(rightPlan.outputSet))
+      val inferredJoinPredicates = inferJoinPredicates(leftConstraints, rightConstraints)
+      // Fold the Option directly rather than testing isDefined and re-wrapping it.
+      inferredJoinPredicates.reduceOption(And).fold[LogicalPlan](join) { joinCondition =>
+        Join(leftPlan, rightPlan, Inner, Some(joinCondition))
+      }
+  }
+
+  /**
+   * Infers equality predicates between attributes of the two constraint sets.
+   *
+   * For every left constraint of the form `attr = expr` (or `expr = attr`), `attr` is recorded
+   * under `expr` in an equivalence map keyed by `SemanticExpression`. For every right constraint
+   * of the same form, each attribute recorded under a matching key yields
+   * `rightAttr = leftAttr`.
+   *
+   * @param leftConstraints constraints on the left side (as passed by `eliminateCrossJoin`)
+   * @param rightConstraints constraints on the right side (as passed by `eliminateCrossJoin`)
+   * @return the inferred predicates; empty when nothing can be inferred
+   */
+  private def inferJoinPredicates(
+      leftConstraints: Set[Expression],
+      rightConstraints: Set[Expression]): Set[EqualTo] = {
+
+    // (attribute, expression) pairs from equalities with an attribute on either side. When
+    // both sides are attributes, the left operand is taken as the attribute.
+    def attributeBindings(constraints: Set[Expression]): Set[(Attribute, Expression)] =
+      constraints.collect {
+        case EqualTo(attr: Attribute, expr) => (attr, expr)
+        case EqualTo(expr, attr: Attribute) => (attr, expr)
+      }
+
+    val leftEquivalences = attributeBindings(leftConstraints)
+      .foldLeft(Map.empty[SemanticExpression, Set[Attribute]]) {
+        case (acc, (attr, expr)) => updateEquivalenceMap(acc, attr, expr)
+      }
+
+    attributeBindings(rightConstraints).foldLeft(Set.empty[EqualTo]) {
+      case (acc, (attr, expr)) => appendJoinConditions(attr, expr, leftEquivalences, acc)
+    }
+  }
+
+  /**
+   * Records `attr` under the key `expr` (converted to a `SemanticExpression`).
+   *
+   * @return the same map instance if `attr` is already recorded under `expr`; otherwise a map
+   *         whose entry for `expr` additionally contains `attr`
+   */
+  private def updateEquivalenceMap(
+      equivalenceMap: Map[SemanticExpression, Set[Attribute]],
+      attr: Attribute,
+      expr: Expression): Map[SemanticExpression, Set[Attribute]] = {
+    equivalenceMap.get(expr) match {
+      case Some(knownAttrs) if knownAttrs.contains(attr) =>
+        equivalenceMap
+      case knownAttrsOpt =>
+        equivalenceMap.updated(expr, knownAttrsOpt.getOrElse(Set.empty[Attribute]) + attr)
+    }
+  }
+
+  /**
+   * Adds `attr = recordedAttr` to `joinConditions` for every attribute recorded in
+   * `equivalenceMap` under `expr` (looked up as a `SemanticExpression` key).
+   *
+   * @return `joinConditions` extended with the new predicates, or `joinConditions` itself when
+   *         `expr` has no entry in the map
+   */
+  private def appendJoinConditions(
+      attr: Attribute,
+      expr: Expression,
+      equivalenceMap: Map[SemanticExpression, Set[Attribute]],
+      joinConditions: Set[EqualTo]): Set[EqualTo] =
+    equivalenceMap.get(expr).fold(joinConditions) { equivalentAttrs =>
+      joinConditions ++ equivalentAttrs.map(recordedAttr => EqualTo(attr, recordedAttr))
+    }
+
+  // Lets semantically equivalent expressions, such as 'a === 1 and 1 === 'a, be treated as
+  // the same expression (e.g. when used as equivalence-map keys).
+ implicit class SemanticExpression(private val expr: Expression) {
--- End diff --
Can we reuse `EquivalentExpressions`? You can search the code base to see
how other code uses it.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]