peter-toth commented on code in PR #41677:
URL: https://github.com/apache/spark/pull/41677#discussion_r1238116148


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala:
##########
@@ -32,209 +32,325 @@ import org.apache.spark.util.Utils
  * considered equal if for the same input(s), the same result is produced.
  */
 class EquivalentExpressions(
-    skipForShortcutEnable: Boolean = SQLConf.get.subexpressionEliminationSkipForShotcutExpr) {
+    skipForShortcutEnable: Boolean = SQLConf.get.subexpressionEliminationSkipForShotcutExpr,
+    minConditionalCount: Option[Double] =
+      Some(SQLConf.get.subexpressionEliminationMinExpectedConditionalEvaluationCount)
+        .filter(_ >= 0)) {
 
   // For each expression, the set of equivalent expressions.
   private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
 
   /**
-   * Adds each expression to this data structure, grouping them with existing equivalent
-   * expressions. Non-recursive.
-   * Returns true if there was already a matching expression.
+   * Adds each expression to this data structure and returns true if there was already a matching
+   * expression.
    */
   def addExpr(expr: Expression): Boolean = {
     if (supportedExpression(expr)) {
-      updateExprInMap(expr, equivalenceMap)
+      updateMapWithExpr(equivalenceMap, expr, 1, 0)
     } else {
       false
     }
   }
 
   /**
-   * Adds or removes an expression to/from the map and updates `useCount`.
-   * Returns true
-   * - if there was a matching expression in the map before add or
-   * - if there remained a matching expression in the map after remove (`useCount` remained > 0)
-   * to indicate there is no need to recurse in `updateExprTree`.
+   * Adds the expression to this data structure, including its children recursively.
    */
-  private def updateExprInMap(
+  def addExprTree(expr: Expression): Unit = {
+    if (supportedExpression(expr)) {
+      updateMapWithExprTree(equivalenceMap, expr, 1, 0)
+    }
+  }
+
+  private def supportedExpression(expr: Expression) = {
+    !expr.exists {
+      // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
+      // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
+      case _: LambdaVariable => true
+
+      // `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor,
+      // can cause error like NPE.
+      case _: PlanExpression[_] => Utils.isInRunningSparkTask
+
+      case _ => false
+    }
+  }
+
+  private def updateMapWithExprTree(
+      map: mutable.HashMap[ExpressionEquals, ExpressionStats],
       expr: Expression,
+      evalCount: Int,
+      conditionalEvalCount: Double): Unit = {
+    if (!expr.isInstanceOf[LeafExpression]) {
+      updateMapWithExpr(map, expr, evalCount, conditionalEvalCount)
+
+      expr match {
+        // CodegenFallback's children will not be used to generate code (call eval() instead)
+        case _: CodegenFallback =>
+
+        case c: CaseWhen =>
+          // Let's consider `CaseWhen(Seq((w1, t1), (w2, t2), (w3, t3), ... (wn, tn)), Some(e))`
+          // example and use `Wn`, `Tn` and `E` notations for the local equivalence maps built from
+          // `wn`, `tn` and `e` expressions respectively.
+          //
+          // Let's try to build a local equivalence map of the above `CaseWhen` example and then add
+          // that local map to `map`.
+          //
+          // We know that `w1` is surely evaluated so `W1` should be part of the local map.
+          // We also know that based on the result of `w1` either `t1` or `w2` is evaluated so the
+          // "intersection" between `T1` and `W2` should be also part of the local map.
+          // Please note that "intersection" might not describe well the operation that we need
+          // between `T1` and `W2`. It is an intersection in terms of surely evaluated
+          // subexpressions (`evalCount`) between `T1` and `W2` but it is also kind of an union
+          // between conditionally evaluated subexpressions (`conditionalEvalCount`). See the
+          // details in `intersectMapWithMap()`.
+          // So the local map can be calculated as `W1 | (T1 & W2)` so far, where `|` and `&` mean
+          // the "union" and "intersection" of equivalence maps.
+          // But we can continue the previous logic further because if `w2` is evaluated, then based
+          // on the result of `w2` either `t2` or `w3` is also evaluated.
+          // So eventually the local equivalence map can be calculated as
+          // `W1 | (T1 & W2 | T2 & (W3 | T3 & ... & (Wn | Tn & E)))`.

Review Comment:
   Good point, I fixed the comment in https://github.com/apache/spark/pull/41677/commits/bebfa21f5c530be39a6eab4bba8e4b78b5322ee0 to avoid confusion.
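For readers following the thread, here is a minimal Scala sketch of the equivalence-map algebra described in the `CaseWhen` comment. It is not the PR's code. `Stats`, `LocalMap`, `union`, `intersect` and `caseWhenMap` are hypothetical names, and keys are plain strings standing in for `ExpressionEquals`. The counting rule in `intersect` is an assumption, not necessarily what `intersectMapWithMap()` does: the surely evaluated count is the minimum of the two sides, and every other evaluation becomes conditional. The fold writes out each alternative's nesting explicitly, as `W1 | (T1 & (W2 | (T2 & ... (Wn | (Tn & E)))))`.

```scala
// Hypothetical sketch, not Spark's implementation: string keys stand in for ExpressionEquals.
object LocalEquivalenceMaps {

  /** Hypothetical stand-in for ExpressionStats: sure and conditional evaluation counts. */
  final case class Stats(evalCount: Int, conditionalEvalCount: Double)

  type LocalMap = Map[String, Stats]

  private val zero = Stats(0, 0.0)

  /** `|`: both sides are evaluated, so all counts add up. */
  def union(a: LocalMap, b: LocalMap): LocalMap =
    (a.keySet ++ b.keySet).iterator.map { k =>
      val x = a.getOrElse(k, zero)
      val y = b.getOrElse(k, zero)
      k -> Stats(x.evalCount + y.evalCount, x.conditionalEvalCount + y.conditionalEvalCount)
    }.toMap

  /**
   * `&`: exactly one of the two sides is evaluated. Only evaluations present on both
   * sides are surely evaluated (the minimum); all remaining evaluations, plus the
   * conditional counts of both sides, become conditional.
   * ASSUMPTION: this counting rule is illustrative, not taken from intersectMapWithMap().
   */
  def intersect(a: LocalMap, b: LocalMap): LocalMap =
    (a.keySet ++ b.keySet).iterator.map { k =>
      val x = a.getOrElse(k, zero)
      val y = b.getOrElse(k, zero)
      val sure = math.min(x.evalCount, y.evalCount)
      val leftover = (x.evalCount - sure) + (y.evalCount - sure)
      k -> Stats(sure, x.conditionalEvalCount + y.conditionalEvalCount + leftover)
    }.toMap

  /**
   * Builds W1 | (T1 & (W2 | (T2 & ( ... (Wn | (Tn & E)) ... )))) from the per-branch
   * (Wi, Ti) maps, folding from the last branch backwards. A missing ELSE is an empty map.
   */
  def caseWhenMap(branches: Seq[(LocalMap, LocalMap)], elseMap: Option[LocalMap]): LocalMap = {
    require(branches.nonEmpty, "CaseWhen needs at least one branch")
    val last = elseMap.getOrElse(Map.empty[String, Stats])
    branches.foldRight(last) { case ((w, t), rest) => union(w, intersect(t, rest)) }
  }
}
```

Under these assumptions, a subexpression that appears in both `t1` and `w2` comes out as surely evaluated once. Without an `ELSE`, everything that appears only in `tn` ends up conditional.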
