cloud-fan commented on code in PR #41677:
URL: https://github.com/apache/spark/pull/41677#discussion_r1384971935


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala:
##########
@@ -30,211 +30,382 @@ import org.apache.spark.util.Utils
  * This class is used to compute equality of (sub)expression trees. 
Expressions can be added
  * to this class and they subsequently query for expression equality. 
Expression trees are
  * considered equal if for the same input(s), the same result is produced.
+ *
+ * Please note that `EquivalentExpressions` is mainly used in subexpression 
elimination where common
+ * non-leaf expression subtrees are calculated, but there is one special 
use case in
+ * `PhysicalAggregation` where `EquivalentExpressions` is used as a mutable 
set of non-deterministic
+ * expressions. For that special use case we have the `allowLeafExpressions` 
config.
  */
 class EquivalentExpressions(
-    skipForShortcutEnable: Boolean = 
SQLConf.get.subexpressionEliminationSkipForShotcutExpr) {
+    skipForShortcutEnable: Boolean = 
SQLConf.get.subexpressionEliminationSkipForShotcutExpr,
+    minConditionalCount: Option[Double] =
+      
Some(SQLConf.get.subexpressionEliminationMinExpectedConditionalEvaluationCount)
+        .filter(_ >= 0d),
+    allowLeafExpressions: Boolean = false) {
+
+  // The subexpressions are stored by height to speed up certain calculations.
+  private val maps = mutable.ArrayBuffer[mutable.Map[ExpressionEquals, 
ExpressionStats]]()
 
-  // For each expression, the set of equivalent expressions.
-  private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, 
ExpressionStats]
+  // `EquivalentExpressions` has 2 states internally, it can be either 
inflated or not.
+  // The inflated state means that all added expressions have been traversed 
recursively and their
+  // subexpressions are also added to `maps`. The idea behind these 2 states 
is that when an
+  // expression tree is added we don't need to traverse/record its 
subexpressions immediately.
+  // The typical use case of this data structure is that multiple expression 
trees are added and
+  // then we want to see the common subexpressions. It might be the case that 
the same expression
+  // trees or partly overlapping expressions trees are added multiple times. 
With this approach we
+  // just need to record how many times an expression tree is explicitly added 
and later, when
+  // `getExprState()` or `getCommonSubexpressions()` is called we inflate the 
data structure (do the
+  // recursive traversal and record the subexpressions in `inflate()`) if 
needed.
+  private var inflated: Boolean = true
 
   /**
-   * Adds each expression to this data structure, grouping them with existing 
equivalent
-   * expressions. Non-recursive.
-   * Returns true if there was already a matching expression.
+   * Adds each expression to this data structure and returns true if there was 
already a matching
+   * expression.
    */
   def addExpr(expr: Expression): Boolean = {
-    if (supportedExpression(expr)) {
-      updateExprInMap(expr, equivalenceMap)
+    if (supportedExpression(expr) && expr.deterministic) {
+      updateWithExpr(expr, 1, 0d)
     } else {
       false
     }
   }
 
   /**
-   * Adds or removes an expression to/from the map and updates `useCount`.
-   * Returns true
-   * - if there was a matching expression in the map before add or
-   * - if there remained a matching expression in the map after remove 
(`useCount` remained > 0)
-   * to indicate there is no need to recurse in `updateExprTree`.
+   * Adds the expression to this data structure, including its children 
recursively.
    */
-  private def updateExprInMap(
-      expr: Expression,
-      map: mutable.HashMap[ExpressionEquals, ExpressionStats],
-      useCount: Int = 1): Boolean = {
-    if (expr.deterministic) {
-      val wrapper = ExpressionEquals(expr)
-      map.get(wrapper) match {
-        case Some(stats) =>
-          stats.useCount += useCount
-          if (stats.useCount > 0) {
-            true
-          } else if (stats.useCount == 0) {
-            map -= wrapper
-            false
-          } else {
-            // Should not happen
-            throw new IllegalStateException(
-              s"Cannot update expression: $expr in map: $map with use count: 
$useCount")
-          }
-        case _ =>
-          if (useCount > 0) {
-            map.put(wrapper, ExpressionStats(expr)(useCount))
-          } else {
-            // Should not happen
-            throw new IllegalStateException(
-              s"Cannot update expression: $expr in map: $map with use count: 
$useCount")
-          }
-          false
-      }
-    } else {
-      false
+  def addExprTree(expr: Expression): Unit = {
+    if (supportedExpression(expr)) {
+      updateWithExpr(expr, 1, 0d)
     }
   }
 
-  /**
-   * Adds or removes only expressions which are common in each of given 
expressions, in a recursive
-   * way.
-   * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + 
(c + 1)))`, the common
-   * expression `(c + 1)` will be added into `equivalenceMap`.
-   *
-   * Note that as we don't know in advance if any child node of an expression 
will be common across
-   * all given expressions, we compute local equivalence maps for all given 
expressions and filter
-   * only the common nodes.
-   * Those common nodes are then removed from the local map and added to the 
final map of
-   * expressions.
-   */
-  private def updateCommonExprs(
-      exprs: Seq[Expression],
-      map: mutable.HashMap[ExpressionEquals, ExpressionStats],
-      useCount: Int): Unit = {
-    assert(exprs.length > 1)
-    var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, 
ExpressionStats]
-    updateExprTree(exprs.head, localEquivalenceMap)
-
-    exprs.tail.foreach { expr =>
-      val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, 
ExpressionStats]
-      updateExprTree(expr, otherLocalEquivalenceMap)
-      localEquivalenceMap = localEquivalenceMap.filter { case (key, _) =>
-        otherLocalEquivalenceMap.contains(key)
-      }
-    }
+  private def supportedExpression(expr: Expression) = {
+    !expr.exists {
+      // `LambdaVariable` is usually used as a loop variable, which can't be 
evaluated ahead of the
+      // loop. So we can't evaluate sub-expressions containing 
`LambdaVariable` at the beginning.
+      case _: LambdaVariable => true
 
-    // Start with the highest expression, remove it from `localEquivalenceMap` 
and add it to `map`.
-    // The remaining highest expression in `localEquivalenceMap` is also 
common expression so loop
-    // until `localEquivalenceMap` is not empty.
-    var statsOption = 
Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
-    while (statsOption.nonEmpty) {
-      val stats = statsOption.get
-      updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount)
-      updateExprTree(stats.expr, map, useCount)
+      // `PlanExpression` wraps query plan. To compare query plans of 
`PlanExpression` on executor,
+      // can cause error like NPE.
+      case _: PlanExpression[_] => Utils.isInRunningSparkTask
 
-      statsOption = 
Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
+      case _ => false
     }
   }
 
-  private def skipForShortcut(expr: Expression): Expression = {
-    if (skipForShortcutEnable) {
-      // The subexpression may not need to eval even if it appears more than 
once.
-      // e.g., `if(or(a, and(b, b)))`, the expression `b` would be skipped if 
`a` is true.
-      expr match {
-        case and: And => and.left
-        case or: Or => or.left
-        case other => other
-      }
-    } else {
-      expr
+  private def updateWithExpr(
+      expr: Expression,
+      evalCount: Int,
+      condEvalCount: Double): Boolean = {
+    require(evalCount >= 0 && condEvalCount >= 0d)
+
+    inflated = false
+    val wrapper = ExpressionEquals(expr)
+    val map = getMapByHeight(wrapper.height)
+    map.get(wrapper) match {
+      case Some(es) =>
+        es.evalCount += evalCount
+        es.condEvalCount += condEvalCount
+        true
+      case _ =>
+        map(wrapper) = ExpressionStats(expr)(evalCount, condEvalCount, 0, 0d)
+        false
     }
   }
 
-  // There are some special expressions that we should not recurse into all of 
its children.
-  //   1. CodegenFallback: it's children will not be used to generate code 
(call eval() instead)
-  //   2. ConditionalExpression: use its children that will always be 
evaluated.
-  private def childrenToRecurse(expr: Expression): Seq[Expression] = expr 
match {
-    case _: CodegenFallback => Nil
-    case c: ConditionalExpression => 
c.alwaysEvaluatedInputs.map(skipForShortcut)
-    case other => skipForShortcut(other).children
+  private def getMapByHeight(height: Int) = {
+    val index = height - 1
+    while (maps.size <= index) {
+      maps += mutable.Map.empty
+    }
+    maps(index)
   }
 
-  // For some special expressions we cannot just recurse into all of its 
children, but we can
-  // recursively add the common expressions shared between all of its children.
-  private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] 
= expr match {
-    case _: CodegenFallback => Nil
-    case c: ConditionalExpression => c.branchGroups
-    case _ => Nil
+  // Iterate expressions from parents to children and fill inflated 
`realEvalCount`s and
+  // `realCondEvalCount`s from explicitly added `evalCount`s and 
`condEvalCount`s.
+  private def inflate() = {
+    if (!inflated) {
+      maps.reverse.foreach { map =>
+        map.foreach {
+          case (_, es) => inflateExprState(es)
+          case _ =>
+        }
+      }
+      inflated = true
+    }
   }
 
-  private def supportedExpression(e: Expression) = {
-    !e.exists {
-      // `LambdaVariable` is usually used as a loop variable, which can't be 
evaluated ahead of the
-      // loop. So we can't evaluate sub-expressions containing 
`LambdaVariable` at the beginning.
-      case _: LambdaVariable => true
+  private def inflateExprState(exprStats: ExpressionStats): Unit = {
+    val expr = exprStats.expr
+    if (!expr.isInstanceOf[LeafExpression] || allowLeafExpressions) {

Review Comment:
   What will go wrong if we always allow leaf expressions?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to