peter-toth commented on a change in pull request #24553: [SPARK-27604][SQL] 
Enhance constant propagation
URL: https://github.com/apache/spark/pull/24553#discussion_r282068349
 
 

 ##########
 File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
 ##########
 @@ -55,100 +55,95 @@ object ConstantFolding extends Rule[LogicalPlan] {
 }
 
 /**
- * Substitutes [[Attribute Attributes]] which can be statically evaluated with 
their corresponding
+ * Substitutes [[Expression Expressions]] which can be statically evaluated 
with their corresponding
  * value in conjunctive [[Expression Expressions]]
  * eg.
  * {{{
- *   SELECT * FROM table WHERE i = 5 AND j = i + 3
- *   ==>  SELECT * FROM table WHERE i = 5 AND j = 8
+ *   SELECT * FROM table WHERE i = 5 AND j = i + 3             =>  ... WHERE i 
= 5 AND j = 8
+ *   SELECT * FROM table WHERE abs(i) = 5 AND j <= abs(i) + 3  =>  ... WHERE 
abs(i) = 5 AND j <= 8
  * }}}
  *
  * Approach used:
- * - Populate a mapping of attribute => constant value by looking at all the 
equals predicates
- * - Using this mapping, replace occurrence of the attributes with the 
corresponding constant values
- *   in the AND node.
+ * - Populate a mapping of expression => constant value by looking at all the 
deterministic equals
+ *   predicates
+ * - Using this mapping, replace occurrence of the expressions with the 
corresponding constant
+ *   values in the AND node.
  */
 object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case f: Filter =>
-      val (newCondition, _) = traverse(f.condition, replaceChildren = true)
-      if (newCondition.isDefined) {
-        f.copy(condition = newCondition.get)
-      } else {
+      val (newCondition, _) = traverse(f.condition)
+      if (newCondition fastEquals f.condition) {
         f
+      } else {
+        f.copy(condition = newCondition)
       }
   }
 
-  type EqualityPredicates = Seq[((AttributeReference, Literal), 
BinaryComparison)]
-
   /**
-   * Traverse a condition as a tree and replace attributes with constant 
values.
+   * Traverse a condition as a tree and replace expressions with constant 
values.
    * - On matching [[And]], recursively traverse each children and get 
propagated mappings.
    *   If the current node is not child of another [[And]], replace all 
occurrences of the
-   *   attributes with the corresponding constant values.
+   *   expressions with the corresponding constant values.
    * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate 
the mapping
-   *   of attribute => constant.
+   *   of expression => constant.
    * - On matching [[Or]] or [[Not]], recursively traverse each children, 
propagate empty mapping.
    * - Otherwise, stop traversal and propagate empty mapping.
-   * @param condition condition to be traversed
-   * @param replaceChildren whether to replace attributes with constant values 
in children
+   * @param expression expression to be traversed
    * @return A tuple including:
    *         1. Option[Expression]: optional changed condition after traversal
-   *         2. EqualityPredicates: propagated mapping of attribute => constant
+   *         2. Map[Expression, Literal]: propagated mapping of expression => 
constant
    */
-  private def traverse(condition: Expression, replaceChildren: Boolean)
-    : (Option[Expression], EqualityPredicates) =
-    condition match {
-      case e @ EqualTo(left: AttributeReference, right: Literal) => (None, 
Seq(((left, right), e)))
-      case e @ EqualTo(left: Literal, right: AttributeReference) => (None, 
Seq(((right, left), e)))
-      case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
-        (None, Seq(((left, right), e)))
-      case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
-        (None, Seq(((right, left), e)))
-      case a: And =>
-        val (newLeft, equalityPredicatesLeft) = traverse(a.left, 
replaceChildren = false)
-        val (newRight, equalityPredicatesRight) = traverse(a.right, 
replaceChildren = false)
-        val equalityPredicates = equalityPredicatesLeft ++ 
equalityPredicatesRight
-        val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
-          Some(And(replaceConstants(newLeft.getOrElse(a.left), 
equalityPredicates),
-            replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
+  private def traverse(expression: Expression): (Expression, Map[Expression, 
Literal]) =
+    expression match {
+      case e @ EqualTo(left, right: Literal) if !left.foldable && 
left.deterministic =>
+        (e, Map(left.canonicalized -> right))
+      case e @ EqualTo(left: Literal, right) if !right.foldable && 
right.deterministic =>
+        (e, Map(right.canonicalized -> left))
+      case e @ EqualNullSafe(left, right: Literal) if !left.foldable && 
left.deterministic =>
+        (e, Map(left.canonicalized -> right))
+      case e @ EqualNullSafe(left: Literal, right) if !right.foldable && 
right.deterministic =>
+        (e, Map(right.canonicalized -> left))
+      case a @ And(left, right) =>
+        val (newLeft, equalityPredicatesLeft) = traverse(left)
+        val replacedRight = replaceConstants(right, equalityPredicatesLeft)
+        val (replacedNewRight, equalityPredicatesRight) = 
traverse(replacedRight)
+        val replacedNewLeft = replaceConstants(newLeft, 
equalityPredicatesRight)
+        val newAnd = if ((replacedNewLeft fastEquals left) && 
(replacedNewRight fastEquals right)) {
+          a
         } else {
-          if (newLeft.isDefined || newRight.isDefined) {
-            Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
-          } else {
-            None
-          }
+          And(replacedNewLeft, replacedNewRight)
         }
-        (newSelf, equalityPredicates)
-      case o: Or =>
+        (newAnd, equalityPredicatesLeft ++= equalityPredicatesRight)
 
 Review comment:
   I see your point, but isn't it the case that I use `equalityPredicatesLeft` 
for replacing before I update its content with `equalityPredicatesRight` and 
propagate it up? I mean, if `transformUp` were executed on a different thread 
then it would be an issue, but that's not the case IMHO.
   Do you still have this concern?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to