peter-toth commented on a change in pull request #24553: [SPARK-27604][SQL] Enhance constant propagation
URL: https://github.com/apache/spark/pull/24553#discussion_r282064927
 
 

 ##########
 File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
 ##########
 @@ -55,100 +55,95 @@ object ConstantFolding extends Rule[LogicalPlan] {
 }
 
 /**
- * Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding
+ * Substitutes [[Expression Expressions]] which can be statically evaluated with their corresponding
  * value in conjunctive [[Expression Expressions]]
  * eg.
  * {{{
- *   SELECT * FROM table WHERE i = 5 AND j = i + 3
- *   ==>  SELECT * FROM table WHERE i = 5 AND j = 8
+ *   SELECT * FROM table WHERE i = 5 AND j = i + 3             =>  ... WHERE i = 5 AND j = 8
+ *   SELECT * FROM table WHERE abs(i) = 5 AND j <= abs(i) + 3  =>  ... WHERE abs(i) = 5 AND j <= 8
  * }}}
  *
  * Approach used:
- * - Populate a mapping of attribute => constant value by looking at all the equals predicates
- * - Using this mapping, replace occurrence of the attributes with the corresponding constant values
- *   in the AND node.
+ * - Populate a mapping of expression => constant value by looking at all the deterministic equals
+ *   predicates
+ * - Using this mapping, replace occurrence of the expressions with the corresponding constant
+ *   values in the AND node.
  */
 object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case f: Filter =>
-      val (newCondition, _) = traverse(f.condition, replaceChildren = true)
-      if (newCondition.isDefined) {
-        f.copy(condition = newCondition.get)
-      } else {
+      val (newCondition, _) = traverse(f.condition)
+      if (newCondition fastEquals f.condition) {
         f
+      } else {
+        f.copy(condition = newCondition)
       }
   }
 
-  type EqualityPredicates = Seq[((AttributeReference, Literal), 
BinaryComparison)]
-
   /**
-   * Traverse a condition as a tree and replace attributes with constant values.
+   * Traverse a condition as a tree and replace expressions with constant values.
    * - On matching [[And]], recursively traverse each children and get propagated mappings.
    *   If the current node is not child of another [[And]], replace all occurrences of the
-   *   attributes with the corresponding constant values.
+   *   expressions with the corresponding constant values.
    * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping
-   *   of attribute => constant.
+   *   of expression => constant.
    * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping.
    * - Otherwise, stop traversal and propagate empty mapping.
-   * @param condition condition to be traversed
-   * @param replaceChildren whether to replace attributes with constant values in children
+   * @param expression expression to be traversed
    * @return A tuple including:
    *         1. Option[Expression]: optional changed condition after traversal
-   *         2. EqualityPredicates: propagated mapping of attribute => constant
+   *         2. Map[Expression, Literal]: propagated mapping of expression => constant
    */
-  private def traverse(condition: Expression, replaceChildren: Boolean)
-    : (Option[Expression], EqualityPredicates) =
-    condition match {
-      case e @ EqualTo(left: AttributeReference, right: Literal) => (None, 
Seq(((left, right), e)))
-      case e @ EqualTo(left: Literal, right: AttributeReference) => (None, 
Seq(((right, left), e)))
-      case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
-        (None, Seq(((left, right), e)))
-      case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
-        (None, Seq(((right, left), e)))
-      case a: And =>
-        val (newLeft, equalityPredicatesLeft) = traverse(a.left, 
replaceChildren = false)
-        val (newRight, equalityPredicatesRight) = traverse(a.right, 
replaceChildren = false)
-        val equalityPredicates = equalityPredicatesLeft ++ 
equalityPredicatesRight
-        val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
-          Some(And(replaceConstants(newLeft.getOrElse(a.left), 
equalityPredicates),
-            replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
+  private def traverse(expression: Expression): (Expression, Map[Expression, 
Literal]) =
+    expression match {
+      case e @ EqualTo(left, right: Literal) if !left.foldable && 
left.deterministic =>
+        (e, Map(left.canonicalized -> right))
+      case e @ EqualTo(left: Literal, right) if !right.foldable && 
right.deterministic =>
+        (e, Map(right.canonicalized -> left))
+      case e @ EqualNullSafe(left, right: Literal) if !left.foldable && 
left.deterministic =>
+        (e, Map(left.canonicalized -> right))
+      case e @ EqualNullSafe(left: Literal, right) if !right.foldable && 
right.deterministic =>
+        (e, Map(right.canonicalized -> left))
+      case a @ And(left, right) =>
+        val (newLeft, equalityPredicatesLeft) = traverse(left)
+        val replacedRight = replaceConstants(right, equalityPredicatesLeft)
 
 Review comment:
   Thanks, fixed.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to