peter-toth commented on a change in pull request #33281:
URL: https://github.com/apache/spark/pull/33281#discussion_r667364410



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
##########
@@ -70,34 +96,39 @@ class EquivalentExpressions {
    * For example, if `((a + b) + c)` and `(a + b)` are common expressions, we only add
    * `((a + b) + c)`.
    */
-  private def addCommonExprs(
+  private def updateCommonExprs(
       exprs: Seq[Expression],
-      map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Unit = {
+      map: mutable.HashMap[ExpressionEquals, ExpressionStats],
+      useCount: Int): Unit = {
     assert(exprs.length > 1)
    var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
-    addExprTree(exprs.head, localEquivalenceMap)
+    updateExprTree(exprs.head, localEquivalenceMap)
 
     exprs.tail.foreach { expr =>
      val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
-      addExprTree(expr, otherLocalEquivalenceMap)
+      updateExprTree(expr, otherLocalEquivalenceMap)
       localEquivalenceMap = localEquivalenceMap.filter { case (key, _) =>
         otherLocalEquivalenceMap.contains(key)
       }
     }
 
-    localEquivalenceMap.foreach { case (commonExpr, state) =>
-      val possibleParents = localEquivalenceMap.filter { case (_, v) => v.height > state.height }
-      val notChild = possibleParents.forall { case (k, _) =>
-        k == commonExpr || k.e.find(_.semanticEquals(commonExpr.e)).isEmpty
-      }
-      if (notChild) {
-        // If the `commonExpr` already appears in the equivalence map, calling `addExprTree` will
-        // increase the `useCount` and mark it as a common subexpression. Otherwise, `addExprTree`
-        // will recursively add `commonExpr` and its descendant to the equivalence map, in case
-        // they also appear in other places. For example, `If(a + b > 1, a + b + c, a + b + c)`,
-        // `a + b` also appears in the condition and should be treated as common subexpression.
-        addExprTree(commonExpr.e, map)
-      }
+    // Start with the highest common expression, update `map` with the expression and remove it (and
+    // its children recursively if required) from `localEquivalenceMap`. The remaining highest
+    // expression in `localEquivalenceMap` is also common expression.
+    var statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
+    while (statsOption.nonEmpty) {
+      val stats = statsOption.get
+      updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount)

Review comment:
      Honestly, I'm not sure how significant the difference is. I think that if we have deep expressions in `localEquivalenceMap`, then filtering by height (as done before this PR) might not help much.
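To make the cost of the pre-PR check concrete, here is a hypothetical sketch on a toy expression model (`Expr`, `Attr`, `Op`, `OldCheckSketch` and `topLevelCommon` are made-up names, not Spark APIs; case-class equality stands in for `semanticEquals`). Each candidate is tested against every strictly higher common expression, and each test may walk that expression's whole subtree. For the lower entries of a deep chain of nested expressions, almost every other entry passes the height filter, so it prunes little.

```scala
// Hypothetical toy model standing in for Catalyst's Expression; case-class
// equality plays the role of semanticEquals.
sealed trait Expr {
  def children: Seq[Expr]
  // Leaves have height 0 here (a simplification).
  lazy val height: Int =
    if (children.isEmpty) 0 else children.map(_.height).max + 1
  // Pre-order search, loosely analogous to TreeNode.find.
  def find(p: Expr => Boolean): Option[Expr] =
    if (p(this)) Some(this)
    else children.iterator.map(_.find(p)).collectFirst { case Some(x) => x }
}
final case class Attr(name: String) extends Expr {
  def children: Seq[Expr] = Nil
}
final case class Op(sym: String, left: Expr, right: Expr) extends Expr {
  def children: Seq[Expr] = Seq(left, right)
}

object OldCheckSketch {
  // Keeps a common expression only if no strictly higher common expression
  // contains it. Every candidate is tested against all higher entries, and
  // each test may walk a whole subtree.
  def topLevelCommon(common: Set[Expr]): Set[Expr] =
    common.filter { e =>
      val possibleParents = common.filter(_.height > e.height)
      possibleParents.forall(_.find(_ == e).isEmpty)
    }
}
```

With `ab = Op("+", Attr("a"), Attr("b"))` and `abc = Op("+", ab, Attr("c"))`, `OldCheckSketch.topLevelCommon(Set(abc, ab))` keeps only `abc`, matching the `((a + b) + c)` example in the doc comment.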
   The new code in this PR might look a bit complex at first, but it is actually very simple: we just remove expressions from `localEquivalenceMap` by applying the reverse of `addExprTree()`.
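A minimal sketch of the "reverse of `addExprTree()`" idea, on a hypothetical toy model (all names are illustrative, not Spark code, and Spark's handling of conditional children is ignored). A positive `delta` in `updateTree` counts uses, recursing into children only when an expression is first inserted. A negative `delta` undoes those uses, recursing only when a count drops to zero. `commonExprs` intersects the per-branch maps, then repeatedly takes the highest remaining entry and removes it together with the descendant uses it accounts for.

```scala
import scala.collection.mutable

// Hypothetical toy model standing in for Catalyst's Expression; case-class
// equality plays the role of semanticEquals.
sealed trait Expr {
  def children: Seq[Expr]
  // Leaves have height 0 here (a simplification).
  lazy val height: Int =
    if (children.isEmpty) 0 else children.map(_.height).max + 1
}
final case class Attr(name: String) extends Expr {
  def children: Seq[Expr] = Nil
}
final case class Op(sym: String, left: Expr, right: Expr) extends Expr {
  def children: Seq[Expr] = Seq(left, right)
}

object CommonExprSketch {
  type Counts = mutable.HashMap[Expr, Int]

  // delta > 0 counts uses; delta < 0 is the exact reverse. Leaves are not
  // tracked. Children are visited only when `e` is first inserted or when
  // its count drops back to zero.
  def updateTree(e: Expr, map: Counts, delta: Int): Unit =
    if (delta != 0 && e.children.nonEmpty) {
      val old = map.getOrElse(e, 0)
      val updated = old + delta
      require(updated >= 0, s"count of $e would become negative")
      if (updated == 0) map -= e
      else map(e) = updated
      if (old == 0 || updated == 0) {
        e.children.foreach(updateTree(_, map, Integer.signum(delta)))
      }
    }

  // Common subexpressions of all branches, highest first.
  def commonExprs(branches: Seq[Expr]): List[Expr] = {
    require(branches.length > 1)
    var local: Counts = mutable.HashMap.empty
    updateTree(branches.head, local, 1)
    branches.tail.foreach { branch =>
      val other: Counts = mutable.HashMap.empty
      updateTree(branch, other, 1)
      // Keep only keys seen in every branch; counts come from the first branch.
      local = local.filter { case (k, _) => other.contains(k) }
    }
    val found = mutable.ListBuffer.empty[Expr]
    while (local.nonEmpty) {
      val (top, count) = local.maxBy(_._1.height)
      // Drop `top` entirely and undo the descendant uses it accounts for.
      updateTree(top, local, -count)
      found += top
    }
    found.toList
  }
}
```

With `ab = Op("+", Attr("a"), Attr("b"))` and `abc = Op("+", ab, Attr("c"))`, two identical `abc` branches yield `List(abc)`. For the branches `Op("*", ab, abc)` and `abc`, the toy yields `List(abc, ab)`: removing `abc` only decrements `ab` from 2 to 1, because the first branch also uses `a + b` outside `(a + b) + c`.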
   
   This new approach also fixes a bug, which is tested here: https://github.com/apache/spark/pull/33281/files#r667343014



