wangyum commented on code in PR #40268:
URL: https://github.com/apache/spark/pull/40268#discussion_r1151346831
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala:
##########
@@ -113,15 +114,13 @@ object ConstantPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsAllPatterns(LITERAL, FILTER), ruleId) {
case f: Filter =>
- val (newCondition, _) = traverse(f.condition, replaceChildren = true,
nullIsFalse = true)
- if (newCondition.isDefined) {
- f.copy(condition = newCondition.get)
- } else {
- f
- }
+ f.mapExpressions(traverse(_, nullIsFalse = true, None))
Review Comment:
This is our internal change:
```scala
/**
 * Optimizer rule that rewrites the condition of every [[Filter]]: inside a conjunction, an
 * attribute that an [[EqualTo]] / [[EqualNullSafe]] predicate compares with a [[Literal]] is
 * substituted by that literal in the other binary comparisons of the same conjunction.
 */
object ConstantPropagation extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
    _.containsAllPatterns(LITERAL, FILTER), ruleId) {
    case f: Filter =>
      val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true)
      // `None` means the traversal changed nothing, so the original node is returned untouched.
      newCondition.map(c => f.copy(condition = c)).getOrElse(f)
  }

  /**
   * Traverse a condition as a tree and replace attributes with constant values.
   * - On matching [[And]], recursively traverse each children and get propagated mappings.
   *   If the current node is not child of another [[And]], replace all occurrences of the
   *   attributes with the corresponding constant values.
   * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping
   *   of attribute => constant.
   * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping.
   * - Otherwise, stop traversal and propagate empty mapping.
   *
   * @param condition condition to be traversed
   * @param replaceChildren whether to replace attributes with constant values in children
   * @param nullIsFalse whether a boolean expression result can be considered to false e.g. in the
   *                    case of `WHERE e`, null result of expression `e` means the same as if it
   *                    resulted false
   * @return A tuple including:
   *         1. Option[Expression]: optional changed condition after traversal (`None` when the
   *            condition was left unchanged)
   *         2. AttributeMap[(Literal, BinaryComparison)]: propagated mapping of
   *            attribute => (constant, the equality predicate the mapping was taken from)
   */
  private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean)
    : (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) =
    condition match {
      // An equality between an attribute and a literal (in either operand order) is never
      // rewritten itself; it only contributes a mapping to the enclosing conjunction.
      case e @ EqualTo(left: AttributeReference, right: Literal)
        if safeToReplace(left, nullIsFalse) =>
        (None, AttributeMap(Map(left -> (right, e))))
      case e @ EqualTo(left: Literal, right: AttributeReference)
        if safeToReplace(right, nullIsFalse) =>
        (None, AttributeMap(Map(right -> (left, e))))
      case e @ EqualNullSafe(left: AttributeReference, right: Literal)
        if safeToReplace(left, nullIsFalse) =>
        (None, AttributeMap(Map(left -> (right, e))))
      case e @ EqualNullSafe(left: Literal, right: AttributeReference)
        if safeToReplace(right, nullIsFalse) =>
        (None, AttributeMap(Map(right -> (left, e))))
      case a: And =>
        // Children are traversed with `replaceChildren = false` so that the substitution happens
        // only once, at the topmost `And` of the conjunction, with the mappings of both sides.
        val (newLeft, equalityPredicatesLeft) =
          traverse(a.left, replaceChildren = false, nullIsFalse)
        val (newRight, equalityPredicatesRight) =
          traverse(a.right, replaceChildren = false, nullIsFalse)
        val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
        val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
          Some(And(
            replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
            replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
        } else if (newLeft.isDefined || newRight.isDefined) {
          Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
        } else {
          None
        }
        (newSelf, equalityPredicates)
      case o: Or =>
        // Ignore the EqualityPredicates from children since they are only propagated through And.
        val (newLeft, _) = traverse(o.left, replaceChildren = true, nullIsFalse)
        val (newRight, _) = traverse(o.right, replaceChildren = true, nullIsFalse)
        val newSelf = if (newLeft.isDefined || newRight.isDefined) {
          Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse(o.right)))
        } else {
          None
        }
        (newSelf, AttributeMap.empty)
      case n: Not =>
        // Ignore the EqualityPredicates from children since they are only propagated through And.
        val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false)
        (newChild.map(Not), AttributeMap.empty)
      case _ => (None, AttributeMap.empty)
    }

  // We need to take into account if an attribute is nullable and the context of the conjunctive
  // expression. E.g. `SELECT * FROM t WHERE NOT(c = 1 AND c + 1 = 1)` where attribute `c` can be
  // substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable then the enclosing
  // NOT prevents us to do the substitution as NOT flips the context (`nullIsFalse`) of what a
  // null result of the enclosed expression means.
  private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean): Boolean =
    !ar.nullable || nullIsFalse

  /**
   * Replaces, inside `condition`, every attribute that has an entry in `equalityPredicates` with
   * the mapped literal. Only [[BinaryComparison]] nodes are rewritten, and the comparisons the
   * mappings were taken from are skipped.
   *
   * @param condition expression to rewrite
   * @param equalityPredicates mapping of attribute => (constant, originating equality predicate)
   * @return the rewritten expression
   */
  private def replaceConstants(
      condition: Expression,
      equalityPredicates: AttributeMap[(Literal, BinaryComparison)]): Expression = {
    val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit })
    // The predicates that supplied the mappings must stay as they are.
    val predicates = equalityPredicates.values.map(_._2).toSet
    condition transform {
      case b: BinaryComparison if !predicates.contains(b) => b transform {
        case a: AttributeReference => constantsMap.getOrElse(a, a)
      }
    }
  }
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]