mgaido91 commented on a change in pull request #24553: [SPARK-27604][SQL] 
Enhance constant propagation
URL: https://github.com/apache/spark/pull/24553#discussion_r282058717
 
 

 ##########
 File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
 ##########
 @@ -55,100 +55,95 @@ object ConstantFolding extends Rule[LogicalPlan] {
 }
 
 /**
- * Substitutes [[Attribute Attributes]] which can be statically evaluated with 
their corresponding
+ * Substitutes [[Expression Expressions]] which can be statically evaluated 
with their corresponding
  * value in conjunctive [[Expression Expressions]]
  * eg.
  * {{{
- *   SELECT * FROM table WHERE i = 5 AND j = i + 3
- *   ==>  SELECT * FROM table WHERE i = 5 AND j = 8
+ *   SELECT * FROM table WHERE i = 5 AND j = i + 3             =>  ... WHERE i 
= 5 AND j = 8
+ *   SELECT * FROM table WHERE abs(i) = 5 AND j <= abs(i) + 3  =>  ... WHERE 
abs(i) = 5 AND j <= 8
  * }}}
  *
  * Approach used:
- * - Populate a mapping of attribute => constant value by looking at all the 
equals predicates
- * - Using this mapping, replace occurrence of the attributes with the 
corresponding constant values
- *   in the AND node.
+ * - Populate a mapping of expression => constant value by looking at all the 
deterministic equals
+ *   predicates
+ * - Using this mapping, replace occurrence of the expressions with the 
corresponding constant
+ *   values in the AND node.
  */
 object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case f: Filter =>
-      val (newCondition, _) = traverse(f.condition, replaceChildren = true)
-      if (newCondition.isDefined) {
-        f.copy(condition = newCondition.get)
-      } else {
+      val (newCondition, _) = traverse(f.condition)
+      if (newCondition fastEquals f.condition) {
         f
+      } else {
+        f.copy(condition = newCondition)
       }
   }
 
-  type EqualityPredicates = Seq[((AttributeReference, Literal), 
BinaryComparison)]
-
   /**
-   * Traverse a condition as a tree and replace attributes with constant 
values.
+   * Traverse a condition as a tree and replace expressions with constant 
values.
    * - On matching [[And]], recursively traverse each children and get 
propagated mappings.
    *   If the current node is not child of another [[And]], replace all 
occurrences of the
-   *   attributes with the corresponding constant values.
+   *   expressions with the corresponding constant values.
    * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate 
the mapping
-   *   of attribute => constant.
+   *   of expression => constant.
    * - On matching [[Or]] or [[Not]], recursively traverse each children, 
propagate empty mapping.
    * - Otherwise, stop traversal and propagate empty mapping.
-   * @param condition condition to be traversed
-   * @param replaceChildren whether to replace attributes with constant values 
in children
+   * @param expression expression to be traversed
    * @return A tuple including:
    *         1. Option[Expression]: optional changed condition after traversal
-   *         2. EqualityPredicates: propagated mapping of attribute => constant
+   *         2. Map[Expression, Literal]: propagated mapping of expression => 
constant
    */
-  private def traverse(condition: Expression, replaceChildren: Boolean)
-    : (Option[Expression], EqualityPredicates) =
-    condition match {
-      case e @ EqualTo(left: AttributeReference, right: Literal) => (None, 
Seq(((left, right), e)))
-      case e @ EqualTo(left: Literal, right: AttributeReference) => (None, 
Seq(((right, left), e)))
-      case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
-        (None, Seq(((left, right), e)))
-      case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
-        (None, Seq(((right, left), e)))
-      case a: And =>
-        val (newLeft, equalityPredicatesLeft) = traverse(a.left, 
replaceChildren = false)
-        val (newRight, equalityPredicatesRight) = traverse(a.right, 
replaceChildren = false)
-        val equalityPredicates = equalityPredicatesLeft ++ 
equalityPredicatesRight
-        val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
-          Some(And(replaceConstants(newLeft.getOrElse(a.left), 
equalityPredicates),
-            replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
+  private def traverse(expression: Expression): (Expression, Map[Expression, 
Literal]) =
+    expression match {
+      case e @ EqualTo(left, right: Literal) if !left.foldable && 
left.deterministic =>
+        (e, Map(left.canonicalized -> right))
+      case e @ EqualTo(left: Literal, right) if !right.foldable && 
right.deterministic =>
+        (e, Map(right.canonicalized -> left))
+      case e @ EqualNullSafe(left, right: Literal) if !left.foldable && 
left.deterministic =>
+        (e, Map(left.canonicalized -> right))
+      case e @ EqualNullSafe(left: Literal, right) if !right.foldable && 
right.deterministic =>
+        (e, Map(right.canonicalized -> left))
+      case a @ And(left, right) =>
+        val (newLeft, equalityPredicatesLeft) = traverse(left)
+        val replacedRight = replaceConstants(right, equalityPredicatesLeft)
+        val (replacedNewRight, equalityPredicatesRight) = 
traverse(replacedRight)
+        val replacedNewLeft = replaceConstants(newLeft, 
equalityPredicatesRight)
+        val newAnd = if ((replacedNewLeft fastEquals left) && 
(replacedNewRight fastEquals right)) {
+          a
         } else {
-          if (newLeft.isDefined || newRight.isDefined) {
-            Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
-          } else {
-            None
-          }
+          And(replacedNewLeft, replacedNewRight)
         }
-        (newSelf, equalityPredicates)
-      case o: Or =>
+        (newAnd, equalityPredicatesLeft ++= equalityPredicatesRight)
 
 Review comment:
  Yes, this is a `mutable.Map`. What you want is for this map to be populated
bottom-up. But `++=` modifies `equalityPredicatesLeft` in place, so any other
reference to that map also sees the entries added here, and I am not sure this
pattern is safe. I have not been able to come up with an example that causes
problems, but I am not confident that none exists. Have you thought about this,
and can you confirm (or explain) why it is safe? If not, building a new map
instead (`equalityPredicatesLeft ++ equalityPredicatesRight`) would avoid the
question entirely.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to