Github user aokolnychyi commented on a diff in the pull request:

    https://github.com/apache/spark/pull/18692#discussion_r153066992
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala 
---
    @@ -152,3 +152,99 @@ object EliminateOuterJoin extends Rule[LogicalPlan] 
with PredicateHelper {
           if (j.joinType == newJoinType) f else Filter(condition, 
j.copy(joinType = newJoinType))
       }
     }
    +
    +/**
    + * A rule that eliminates CROSS joins by inferring join conditions from 
propagated constraints.
    + *
    + * The optimization is applicable only to CROSS joins. For other join 
types, adding inferred join
    + * conditions would potentially shuffle children as child node's 
partitioning won't satisfy the JOIN
    + * node's requirements which otherwise could have.
    + *
    + * For instance, if there is a CROSS join, where the left relation has 'a 
= 1' and the right
    + * relation has 'b = 1', the rule infers 'a = b' as a join predicate.
    + */
    +object EliminateCrossJoin extends Rule[LogicalPlan] with PredicateHelper {
    +
    +  def apply(plan: LogicalPlan): LogicalPlan = {
    +    if (SQLConf.get.constraintPropagationEnabled) {
    +      eliminateCrossJoin(plan)
    +    } else {
    +      plan
    +    }
    +  }
    +
    +  private def eliminateCrossJoin(plan: LogicalPlan): LogicalPlan = plan 
transform {
    +    case join@Join(leftPlan, rightPlan, Cross, None) =>
    +      val leftConstraints = 
join.constraints.filter(_.references.subsetOf(leftPlan.outputSet))
    +      val rightConstraints = 
join.constraints.filter(_.references.subsetOf(rightPlan.outputSet))
    +      val inferredJoinPredicates = inferJoinPredicates(leftConstraints, 
rightConstraints)
    +      val joinConditionOpt = inferredJoinPredicates.reduceOption(And)
    +      if (joinConditionOpt.isDefined) Join(leftPlan, rightPlan, Inner, 
joinConditionOpt) else join
    +  }
    +
    +  private def inferJoinPredicates(
    +      leftConstraints: Set[Expression],
    +      rightConstraints: Set[Expression]): Set[EqualTo] = {
    +
    +    // iterate through the left constraints and build a hash map that 
points semantically
    +    // equivalent expressions into attributes
    +    val emptyEquivalenceMap = Map.empty[SemanticExpression, Set[Attribute]]
    +    val equivalenceMap = leftConstraints.foldLeft(emptyEquivalenceMap) { 
case (map, constraint) =>
    +      constraint match {
    +        case EqualTo(attr: Attribute, expr: Expression) =>
    +          updateEquivalenceMap(map, attr, expr)
    +        case EqualTo(expr: Expression, attr: Attribute) =>
    +          updateEquivalenceMap(map, attr, expr)
    +        case _ => map
    +      }
    +    }
    +
    +    // iterate through the right constraints and infer join conditions 
using the equivalence map
    +    rightConstraints.foldLeft(Set.empty[EqualTo]) { case (joinConditions, 
constraint) =>
    +      constraint match {
    +        case EqualTo(attr: Attribute, expr: Expression) =>
    +          appendJoinConditions(attr, expr, equivalenceMap, joinConditions)
    +        case EqualTo(expr: Expression, attr: Attribute) =>
    +          appendJoinConditions(attr, expr, equivalenceMap, joinConditions)
    +        case _ => joinConditions
    +      }
    +    }
    +  }
    +
    +  private def updateEquivalenceMap(
    +      equivalenceMap: Map[SemanticExpression, Set[Attribute]],
    +      attr: Attribute,
    +      expr: Expression): Map[SemanticExpression, Set[Attribute]] = {
    +
    +    val equivalentAttrs = equivalenceMap.getOrElse(expr, 
Set.empty[Attribute])
    +    if (equivalentAttrs.contains(attr)) {
    +      equivalenceMap
    +    } else {
    +      equivalenceMap.updated(expr, equivalentAttrs + attr)
    +    }
    +  }
    +
    +  private def appendJoinConditions(
    +      attr: Attribute,
    +      expr: Expression,
    +      equivalenceMap: Map[SemanticExpression, Set[Attribute]],
    +      joinConditions: Set[EqualTo]): Set[EqualTo] = {
    +
    +    equivalenceMap.get(expr) match {
    +      case Some(equivalentAttrs) => joinConditions ++ 
equivalentAttrs.map(EqualTo(attr, _))
    +      case None => joinConditions
    +    }
    +  }
    +
    +  // the purpose of this class is to treat 'a === 1 and 1 === 'a as the 
same expressions
    +  implicit class SemanticExpression(private val expr: Expression) {
    --- End diff --
    
    @gatorsmile 
    
    I think we only need the case class inside ``EquivalentExpressions``, since 
we have to map all semantically equivalent expressions to a set of attributes 
(as opposed to mapping an expression to a set of equivalent expressions). 
    
    I see two ways to go:
    
    1. Expose the case class inside ``EquivalentExpressions`` with minimal 
changes to the code base (e.g., by moving it into a companion object):
    
    ````
    object EquivalentExpressions {
    
      /**
       * Wrapper around an Expression that provides semantic equality.
       */
      implicit class SemanticExpr(private val e: Expression) {
        override def equals(o: Any): Boolean = o match {
          case other: SemanticExpr => e.semanticEquals(other.e)
          case _ => false
        }
    
        override def hashCode: Int = e.semanticHash()
      }
    }
    ````
    
    2. Keep ``EquivalentExpressions`` as it is and maintain a separate map from 
expressions to attributes in the proposed rule.
    
    Personally, I lean toward the first option, since ``SemanticExpr`` might be 
useful on its own. However, there may be other drawbacks that I haven't thought 
of.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to