GitHub user gatorsmile commented on a diff in the pull request:

    https://github.com/apache/spark/pull/18692#discussion_r153342778
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala ---
    @@ -152,3 +152,99 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
           if (j.joinType == newJoinType) f else Filter(condition, 
j.copy(joinType = newJoinType))
       }
     }
    +
    +/**
    + * A rule that eliminates CROSS joins by inferring join conditions from 
propagated constraints.
    + *
    + * The optimization is applicable only to CROSS joins. For other join 
types, adding inferred join
    + * conditions would potentially shuffle children as child node's 
partitioning won't satisfy the JOIN
    + * node's requirements which otherwise could have.
    + *
    + * For instance, if there is a CROSS join, where the left relation has 'a 
= 1' and the right
    + * relation has 'b = 1', the rule infers 'a = b' as a join predicate.
    + */
    +object EliminateCrossJoin extends Rule[LogicalPlan] with PredicateHelper {
    +
    +  def apply(plan: LogicalPlan): LogicalPlan = {
    +    if (SQLConf.get.constraintPropagationEnabled) {
    +      eliminateCrossJoin(plan)
    +    } else {
    +      plan
    +    }
    +  }
    +
    +  private def eliminateCrossJoin(plan: LogicalPlan): LogicalPlan = plan 
transform {
    +    case join@Join(leftPlan, rightPlan, Cross, None) =>
    +      val leftConstraints = 
join.constraints.filter(_.references.subsetOf(leftPlan.outputSet))
    +      val rightConstraints = 
join.constraints.filter(_.references.subsetOf(rightPlan.outputSet))
    +      val inferredJoinPredicates = inferJoinPredicates(leftConstraints, 
rightConstraints)
    +      val joinConditionOpt = inferredJoinPredicates.reduceOption(And)
    +      if (joinConditionOpt.isDefined) Join(leftPlan, rightPlan, Inner, 
joinConditionOpt) else join
    +  }
    +
    +  private def inferJoinPredicates(
    +      leftConstraints: Set[Expression],
    +      rightConstraints: Set[Expression]): Set[EqualTo] = {
    +
    +    // iterate through the left constraints and build a hash map that 
points semantically
    +    // equivalent expressions into attributes
    +    val emptyEquivalenceMap = Map.empty[SemanticExpression, Set[Attribute]]
    +    val equivalenceMap = leftConstraints.foldLeft(emptyEquivalenceMap) { 
case (map, constraint) =>
    +      constraint match {
    +        case EqualTo(attr: Attribute, expr: Expression) =>
    +          updateEquivalenceMap(map, attr, expr)
    +        case EqualTo(expr: Expression, attr: Attribute) =>
    +          updateEquivalenceMap(map, attr, expr)
    +        case _ => map
    +      }
    +    }
    +
    +    // iterate through the right constraints and infer join conditions 
using the equivalence map
    +    rightConstraints.foldLeft(Set.empty[EqualTo]) { case (joinConditions, 
constraint) =>
    +      constraint match {
    +        case EqualTo(attr: Attribute, expr: Expression) =>
    +          appendJoinConditions(attr, expr, equivalenceMap, joinConditions)
    +        case EqualTo(expr: Expression, attr: Attribute) =>
    +          appendJoinConditions(attr, expr, equivalenceMap, joinConditions)
    +        case _ => joinConditions
    +      }
    +    }
    +  }
    +
    +  private def updateEquivalenceMap(
    +      equivalenceMap: Map[SemanticExpression, Set[Attribute]],
    +      attr: Attribute,
    +      expr: Expression): Map[SemanticExpression, Set[Attribute]] = {
    +
    +    val equivalentAttrs = equivalenceMap.getOrElse(expr, 
Set.empty[Attribute])
    +    if (equivalentAttrs.contains(attr)) {
    +      equivalenceMap
    +    } else {
    +      equivalenceMap.updated(expr, equivalentAttrs + attr)
    +    }
    +  }
    +
    +  private def appendJoinConditions(
    +      attr: Attribute,
    +      expr: Expression,
    +      equivalenceMap: Map[SemanticExpression, Set[Attribute]],
    +      joinConditions: Set[EqualTo]): Set[EqualTo] = {
    +
    +    equivalenceMap.get(expr) match {
    +      case Some(equivalentAttrs) => joinConditions ++ 
equivalentAttrs.map(EqualTo(attr, _))
    +      case None => joinConditions
    +    }
    +  }
    +
    +  // the purpose of this class is to treat 'a === 1 and 1 === 'a as the 
same expressions
    +  implicit class SemanticExpression(private val expr: Expression) {
    --- End diff --
    
    I did not check this carefully, but could the existing `ExpressionSet` be used here instead of introducing a new `SemanticExpression` wrapper class?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to