Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19912#discussion_r155543551
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
 ---
    @@ -64,50 +64,94 @@ object ConstantFolding extends Rule[LogicalPlan] {
      * }}}
      *
      * Approach used:
    - * - Start from AND operator as the root
    - * - Get all the children conjunctive predicates which are EqualTo / 
EqualNullSafe such that they
    - *   don't have a `NOT` or `OR` operator in them
      * - Populate a mapping of attribute => constant value by looking at all 
the equals predicates
      * - Using this mapping, replace occurrence of the attributes with the 
corresponding constant values
      *   in the AND node.
      */
     object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
    -  private def containsNonConjunctionPredicates(expression: Expression): 
Boolean = expression.find {
    -    case _: Not | _: Or => true
    -    case _ => false
    -  }.isDefined
    -
       def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    -    case f: Filter => f transformExpressionsUp {
    -      case and: And =>
    -        val conjunctivePredicates =
    -          splitConjunctivePredicates(and)
    -            .filter(expr => expr.isInstanceOf[EqualTo] || 
expr.isInstanceOf[EqualNullSafe])
    -            .filterNot(expr => containsNonConjunctionPredicates(expr))
    -
    -        val equalityPredicates = conjunctivePredicates.collect {
    -          case e @ EqualTo(left: AttributeReference, right: Literal) => 
((left, right), e)
    -          case e @ EqualTo(left: Literal, right: AttributeReference) => 
((right, left), e)
    -          case e @ EqualNullSafe(left: AttributeReference, right: Literal) 
=> ((left, right), e)
    -          case e @ EqualNullSafe(left: Literal, right: AttributeReference) 
=> ((right, left), e)
    -        }
    +    case f: Filter =>
    +      val (newCondition, _) = traverse(f.condition, replaceChildren = true)
    +      if (newCondition.isDefined) {
    +        f.copy(condition = newCondition.get)
    +      } else {
    +        f
    +      }
    +  }
     
    -        val constantsMap = AttributeMap(equalityPredicates.map(_._1))
    -        val predicates = equalityPredicates.map(_._2).toSet
    +  type EqualityPredicates = Seq[((AttributeReference, Literal), 
BinaryComparison)]
     
    -        def replaceConstants(expression: Expression) = expression 
transform {
    -          case a: AttributeReference =>
    -            constantsMap.get(a) match {
    -              case Some(literal) => literal
    -              case None => a
    -            }
    +  /**
    +   * Traverse a condition as a tree and replace attributes with constant 
values.
    +   * - On matching [[And]], recursively traverse each children and get 
propagated mappings.
    +   *   If the current node is not child of another [[And]], replace all 
occurrences of the
    +   *   attributes with the corresponding constant values.
    +   * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], 
propagate the mapping
    +   *   of attribute => constant.
    +   * - On matching [[Or]] or [[Not]], recursively traverse each children, 
propagate empty mapping.
    +   * - Otherwise, stop traversal and propagate empty mapping.
    +   * @param condition condition to be traversed
    +   * @param replaceChildren whether to replace attributes with constant 
values in children
    +   * @return A tuple including:
    +   *         1. Option[Expression]: optional changed condition after 
traversal
    +   *         2. EqualityPredicates: propagated mapping of attribute => 
constant
    +   */
    +  private def traverse(condition: Expression, replaceChildren: Boolean)
    +    : (Option[Expression], EqualityPredicates) =
    +    condition match {
    +      case e @ EqualTo(left: AttributeReference, right: Literal) => (None, 
Seq(((left, right), e)))
    +      case e @ EqualTo(left: Literal, right: AttributeReference) => (None, 
Seq(((right, left), e)))
    +      case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
    +        (None, Seq(((left, right), e)))
    +      case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
    +        (None, Seq(((right, left), e)))
    +      case a: And =>
    +        val (newLeft, equalityPredicatesLeft) = traverse(a.left, 
replaceChildren = false)
    +        val (newRight, equalityPredicatesRight) = traverse(a.right, 
replaceChildren = false)
    +        val equalityPredicates = equalityPredicatesLeft ++ 
equalityPredicatesRight
    +        val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
    +          Some(And(replaceConstants(newLeft.getOrElse(a.left), 
equalityPredicates),
    +            replaceConstants(newRight.getOrElse(a.right), 
equalityPredicates)))
    +        } else {
    +          if (newLeft.isDefined || newRight.isDefined) {
    +            Some(And(newLeft.getOrElse(a.left), 
newRight.getOrElse(a.right)))
    +          } else {
    +            None
    +          }
             }
    +        (newSelf, equalityPredicates)
    +      case o: Or =>
    +        // Ignore the EqualityPredicates from children since they are only 
propagated through And.
    +        val (newLeft, _) = traverse(o.left, replaceChildren = true)
    +        val (newRight, _) = traverse(o.right, replaceChildren = true)
    +        val newSelf = if (newLeft.isDefined || newRight.isDefined) {
    +          Some(Or(left = newLeft.getOrElse(o.left), right = 
newRight.getOrElse((o.right))))
    +        } else {
    +          None
    +        }
    +        (newSelf, Seq.empty)
    +      case n: Not =>
    +        // Ignore the EqualityPredicates from children since they are only 
propagated through And.
    +        val (newChild, _) = traverse(n.child, replaceChildren = true)
    +        (newChild.map(Not), Seq.empty)
    +      case _ => (None, Seq.empty)
    +    }
     
    -        and transform {
    -          case e @ EqualTo(_, _) if !predicates.contains(e) => 
replaceConstants(e)
    -          case e @ EqualNullSafe(_, _) if !predicates.contains(e) => 
replaceConstants(e)
    +  private def replaceConstants(condition: Expression, equalityPredicates: 
EqualityPredicates)
    +    : Expression = {
    +    val constantsMap = AttributeMap(equalityPredicates.map(_._1))
    +    val predicates = equalityPredicates.map(_._2).toSet
    +    def _replaceConstants(expression: Expression) = expression transform {
    +      case a: AttributeReference =>
    +        constantsMap.get(a) match {
    --- End diff --
    
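    Nit: the `constantsMap.get(a) match { case Some(literal) => literal; case None => a }` can be simplified to `constantsMap.getOrElse(a, a)`.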


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to