Github user gatorsmile commented on a diff in the pull request:
https://github.com/apache/spark/pull/19912#discussion_r155332134
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
---
@@ -64,50 +64,91 @@ object ConstantFolding extends Rule[LogicalPlan] {
* }}}
*
* Approach used:
- * - Start from AND operator as the root
- * - Get all the children conjunctive predicates which are EqualTo /
EqualNullSafe such that they
- * don't have a `NOT` or `OR` operator in them
* - Populate a mapping of attribute => constant value by looking at all the equality predicates
* - Using this mapping, replace occurrences of the attributes with the corresponding constant values
* in the AND node.
*/
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
- private def containsNonConjunctionPredicates(expression: Expression):
Boolean = expression.find {
- case _: Not | _: Or => true
- case _ => false
- }.isDefined
-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case f: Filter => f transformExpressionsUp {
- case and: And =>
- val conjunctivePredicates =
- splitConjunctivePredicates(and)
- .filter(expr => expr.isInstanceOf[EqualTo] ||
expr.isInstanceOf[EqualNullSafe])
- .filterNot(expr => containsNonConjunctionPredicates(expr))
-
- val equalityPredicates = conjunctivePredicates.collect {
- case e @ EqualTo(left: AttributeReference, right: Literal) =>
((left, right), e)
- case e @ EqualTo(left: Literal, right: AttributeReference) =>
((right, left), e)
- case e @ EqualNullSafe(left: AttributeReference, right: Literal)
=> ((left, right), e)
- case e @ EqualNullSafe(left: Literal, right: AttributeReference)
=> ((right, left), e)
- }
-
- val constantsMap = AttributeMap(equalityPredicates.map(_._1))
- val predicates = equalityPredicates.map(_._2).toSet
+ case f: Filter =>
+ val (newCondition, _) = traverse(f.condition, replaceChildren = true)
+ if (newCondition.isDefined) {
+ f.copy(condition = newCondition.get)
+ } else {
+ f
+ }
+ }
- def replaceConstants(expression: Expression) = expression
transform {
- case a: AttributeReference =>
- constantsMap.get(a) match {
- case Some(literal) => literal
- case None => a
- }
+ /**
+ * Traverse a condition as a tree and replace attributes with constant values.
+ * - If the child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping
+ * of attribute => constant.
+ * - If the current [[And]] node is not a child of another [[And]], replace occurrences of the
+ * attributes with the corresponding constant values in both children using the propagated mapping.
+ * @param condition condition to be traversed
+ * @param replaceChildren whether to replace attributes with the corresponding constant values
+ */
+ private def traverse(condition: Expression, replaceChildren: Boolean)
+ : (Option[Expression], Seq[((AttributeReference, Literal),
BinaryComparison)]) =
+ condition match {
+ case e @ EqualTo(left: AttributeReference, right: Literal) => (None,
Seq(((left, right), e)))
+ case e @ EqualTo(left: Literal, right: AttributeReference) => (None,
Seq(((right, left), e)))
+ case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
+ (None, Seq(((left, right), e)))
+ case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
+ (None, Seq(((right, left), e)))
+ case a: And =>
+ val (newLeft, equalityPredicatesLeft) = traverse(a.left, false)
+ val (newRight, equalityPredicatesRight) = traverse(a.right, false)
+ val equalityPredicates = equalityPredicatesLeft ++
equalityPredicatesRight
+ val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
+ Some(And(replaceConstants(newLeft.getOrElse(a.left),
equalityPredicates),
+ replaceConstants(newRight.getOrElse(a.right),
equalityPredicates)))
+ } else {
+ if (newLeft.isDefined || newRight.isDefined) {
+ Some(And(newLeft.getOrElse(a.left),
newRight.getOrElse(a.right)))
+ } else {
+ None
+ }
+ }
+ (newSelf, equalityPredicates)
+ case o: Or =>
+ val (newLeft, _) = traverse(o.left, true)
+ val (newRight, _) = traverse(o.right, true)
+ val newSelf = if (newLeft.isDefined || newRight.isDefined) {
+ Some(Or(left = newLeft.getOrElse(o.left), right =
newRight.getOrElse((o.right))))
+ } else {
+ None
}
+ (newSelf, Seq.empty)
+ case n: Not =>
+ val (newChild, _) = traverse(n.child, true)
+ val newSelf = if (newChild.isDefined) {
+ Some(Not(newChild.get))
+ } else {
+ None
+ }
+ (newSelf, Seq.empty)
--- End diff --
```Scala
val (newChild, _) = traverse(n.child, replaceChildren = true)
(newChild.map(Not), Seq.empty)
```
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]