Github user gatorsmile commented on a diff in the pull request: https://github.com/apache/spark/pull/18692#discussion_r153060595 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala --- @@ -152,3 +152,99 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) } } + +/** + * A rule that eliminates CROSS joins by inferring join conditions from propagated constraints. + * + * The optimization is applicable only to CROSS joins. For other join types, adding inferred join + * conditions would potentially shuffle children as child node's partitioning won't satisfy the JOIN + * node's requirements which otherwise could have. + * + * For instance, if there is a CROSS join, where the left relation has 'a = 1' and the right + * relation has 'b = 1', the rule infers 'a = b' as a join predicate. + */ +object EliminateCrossJoin extends Rule[LogicalPlan] with PredicateHelper { + + def apply(plan: LogicalPlan): LogicalPlan = { + if (SQLConf.get.constraintPropagationEnabled) { + eliminateCrossJoin(plan) + } else { + plan + } + } + + private def eliminateCrossJoin(plan: LogicalPlan): LogicalPlan = plan transform { + case join@Join(leftPlan, rightPlan, Cross, None) => + val leftConstraints = join.constraints.filter(_.references.subsetOf(leftPlan.outputSet)) + val rightConstraints = join.constraints.filter(_.references.subsetOf(rightPlan.outputSet)) + val inferredJoinPredicates = inferJoinPredicates(leftConstraints, rightConstraints) + val joinConditionOpt = inferredJoinPredicates.reduceOption(And) + if (joinConditionOpt.isDefined) Join(leftPlan, rightPlan, Inner, joinConditionOpt) else join + } + + private def inferJoinPredicates( + leftConstraints: Set[Expression], + rightConstraints: Set[Expression]): Set[EqualTo] = { + + // iterate through the left constraints and build a hash map that points semantically + // equivalent expressions 
into attributes + val emptyEquivalenceMap = Map.empty[SemanticExpression, Set[Attribute]] + val equivalenceMap = leftConstraints.foldLeft(emptyEquivalenceMap) { case (map, constraint) => + constraint match { + case EqualTo(attr: Attribute, expr: Expression) => + updateEquivalenceMap(map, attr, expr) + case EqualTo(expr: Expression, attr: Attribute) => + updateEquivalenceMap(map, attr, expr) + case _ => map + } + } + + // iterate through the right constraints and infer join conditions using the equivalence map + rightConstraints.foldLeft(Set.empty[EqualTo]) { case (joinConditions, constraint) => + constraint match { + case EqualTo(attr: Attribute, expr: Expression) => + appendJoinConditions(attr, expr, equivalenceMap, joinConditions) + case EqualTo(expr: Expression, attr: Attribute) => + appendJoinConditions(attr, expr, equivalenceMap, joinConditions) + case _ => joinConditions + } + } + } + + private def updateEquivalenceMap( + equivalenceMap: Map[SemanticExpression, Set[Attribute]], + attr: Attribute, + expr: Expression): Map[SemanticExpression, Set[Attribute]] = { + + val equivalentAttrs = equivalenceMap.getOrElse(expr, Set.empty[Attribute]) + if (equivalentAttrs.contains(attr)) { + equivalenceMap + } else { + equivalenceMap.updated(expr, equivalentAttrs + attr) + } + } + + private def appendJoinConditions( + attr: Attribute, + expr: Expression, + equivalenceMap: Map[SemanticExpression, Set[Attribute]], + joinConditions: Set[EqualTo]): Set[EqualTo] = { + + equivalenceMap.get(expr) match { + case Some(equivalentAttrs) => joinConditions ++ equivalentAttrs.map(EqualTo(attr, _)) + case None => joinConditions + } + } + + // the purpose of this class is to treat 'a === 1 and 1 === 'a as the same expressions + implicit class SemanticExpression(private val expr: Expression) { --- End diff -- Can we reuse `EquivalentExpressions`? You can search the code base and see how the others use it.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org