peter-toth commented on a change in pull request #24553: [SPARK-27604][SQL]
Enhance constant propagation
URL: https://github.com/apache/spark/pull/24553#discussion_r282534623
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
##########
@@ -55,100 +55,81 @@ object ConstantFolding extends Rule[LogicalPlan] {
}
/**
- * Substitutes [[Attribute Attributes]] which can be statically evaluated with
their corresponding
+ * Substitutes [[Expression Expressions]] which can be statically evaluated
with their corresponding
* value in conjunctive [[Expression Expressions]]
* eg.
* {{{
- * SELECT * FROM table WHERE i = 5 AND j = i + 3
- * ==> SELECT * FROM table WHERE i = 5 AND j = 8
+ * SELECT * FROM table WHERE i = 5 AND j = i + 3 => ... WHERE i
= 5 AND j = 8
+ * SELECT * FROM table WHERE abs(i) = 5 AND j <= abs(i) + 3 => ... WHERE
abs(i) = 5 AND j <= 8
* }}}
*
* Approach used:
- * - Populate a mapping of attribute => constant value by looking at all the
equals predicates
- * - Using this mapping, replace occurrence of the attributes with the
corresponding constant values
- * in the AND node.
+ * - Populate a mapping of expression => constant value by looking at all the
deterministic equals
+ * predicates
+ * - Using this mapping, replace occurrence of the expressions with the
corresponding constant
+ * values in the AND node.
*/
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f: Filter =>
- val (newCondition, _) = traverse(f.condition, replaceChildren = true)
- if (newCondition.isDefined) {
- f.copy(condition = newCondition.get)
- } else {
+ val (newCondition, _) = traverse(f.condition)
+ if (newCondition fastEquals f.condition) {
f
+ } else {
+ f.copy(condition = newCondition)
}
}
- type EqualityPredicates = Seq[((AttributeReference, Literal),
BinaryComparison)]
-
/**
- * Traverse a condition as a tree and replace attributes with constant
values.
+ * Traverse a condition as a tree and replace expressions with constant
values.
* - On matching [[And]], recursively traverse each children and get
propagated mappings.
* If the current node is not child of another [[And]], replace all
occurrences of the
- * attributes with the corresponding constant values.
+ * expressions with the corresponding constant values.
* - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate
the mapping
- * of attribute => constant.
+ * of expression => constant.
* - On matching [[Or]] or [[Not]], recursively traverse each children,
propagate empty mapping.
* - Otherwise, stop traversal and propagate empty mapping.
- * @param condition condition to be traversed
- * @param replaceChildren whether to replace attributes with constant values
in children
+ * @param expression expression to be traversed
* @return A tuple including:
* 1. Option[Expression]: optional changed condition after traversal
- * 2. EqualityPredicates: propagated mapping of attribute => constant
+ * 2. Map[Expression, Literal]: propagated mapping of expression =>
constant
*/
- private def traverse(condition: Expression, replaceChildren: Boolean)
- : (Option[Expression], EqualityPredicates) =
- condition match {
- case e @ EqualTo(left: AttributeReference, right: Literal) => (None,
Seq(((left, right), e)))
- case e @ EqualTo(left: Literal, right: AttributeReference) => (None,
Seq(((right, left), e)))
- case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
- (None, Seq(((left, right), e)))
- case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
- (None, Seq(((right, left), e)))
- case a: And =>
- val (newLeft, equalityPredicatesLeft) = traverse(a.left,
replaceChildren = false)
- val (newRight, equalityPredicatesRight) = traverse(a.right,
replaceChildren = false)
- val equalityPredicates = equalityPredicatesLeft ++
equalityPredicatesRight
- val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
- Some(And(replaceConstants(newLeft.getOrElse(a.left),
equalityPredicates),
- replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
- } else {
- if (newLeft.isDefined || newRight.isDefined) {
- Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
- } else {
- None
- }
- }
- (newSelf, equalityPredicates)
- case o: Or =>
- // Ignore the EqualityPredicates from children since they are only
propagated through And.
- val (newLeft, _) = traverse(o.left, replaceChildren = true)
- val (newRight, _) = traverse(o.right, replaceChildren = true)
- val newSelf = if (newLeft.isDefined || newRight.isDefined) {
- Some(Or(left = newLeft.getOrElse(o.left), right =
newRight.getOrElse((o.right))))
+ private def traverse(expression: Expression): (Expression, Map[Expression,
Literal]) =
+ expression match {
+ case e @ EqualTo(left, right: Literal) if !left.foldable &&
left.deterministic =>
+ (e, Map(left.canonicalized -> right))
+ case e @ EqualTo(left: Literal, right) if !right.foldable &&
right.deterministic =>
+ (e, Map(right.canonicalized -> left))
+ case e @ EqualNullSafe(left, right: Literal) if !left.foldable &&
left.deterministic =>
+ (e, Map(left.canonicalized -> right))
+ case e @ EqualNullSafe(left: Literal, right) if !right.foldable &&
right.deterministic =>
+ (e, Map(right.canonicalized -> left))
+ case a @ And(left, right) =>
+ val (newLeft, equalityPredicatesLeft) = traverse(left)
+ val replacedRight = replaceConstants(right, equalityPredicatesLeft)
+ val (replacedNewRight, equalityPredicatesRight) =
traverse(replacedRight)
+ val replacedNewLeft = replaceConstants(newLeft,
equalityPredicatesRight)
+ val newAnd = if ((replacedNewLeft fastEquals left) &&
(replacedNewRight fastEquals right)) {
+ a
} else {
- None
+ And(replacedNewLeft, replacedNewRight)
}
- (newSelf, Seq.empty)
- case n: Not =>
- // Ignore the EqualityPredicates from children since they are only
propagated through And.
- val (newChild, _) = traverse(n.child, replaceChildren = true)
- (newChild.map(Not), Seq.empty)
- case _ => (None, Seq.empty)
+ (newAnd, equalityPredicatesLeft ++= equalityPredicatesRight)
+ case o @ (_: Or | _: Not) =>
+ (o.mapChildren(traverse(_)._1), Map.empty)
+ case _ => (expression, Map.empty)
Review comment:
Actually, I don't see why we stop traversing at all. For example, `b` in
`... WHERE If(a = b AND b = 1, ..., ...)` doesn't get "resolved" now, because traversal stops at the `If` and never reaches the nested `And`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]