cloud-fan commented on code in PR #41677:
URL: https://github.com/apache/spark/pull/41677#discussion_r1384971935
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala:
##########
@@ -30,211 +30,382 @@ import org.apache.spark.util.Utils
* This class is used to compute equality of (sub)expression trees.
Expressions can be added
* to this class and they subsequently query for expression equality.
Expression trees are
* considered equal if for the same input(s), the same result is produced.
+ *
+ * Please note that `EquivalentExpressions` is mainly used in subexpression
elimination where common
+ * non-leaf expression subtrees are calculated, but there is one special
use case in
+ * `PhysicalAggregation` where `EquivalentExpressions` is used as a mutable
set of non-deterministic
+ * expressions. For that special use case we have the `allowLeafExpressions`
config.
*/
class EquivalentExpressions(
- skipForShortcutEnable: Boolean =
SQLConf.get.subexpressionEliminationSkipForShotcutExpr) {
+ skipForShortcutEnable: Boolean =
SQLConf.get.subexpressionEliminationSkipForShotcutExpr,
+ minConditionalCount: Option[Double] =
+
Some(SQLConf.get.subexpressionEliminationMinExpectedConditionalEvaluationCount)
+ .filter(_ >= 0d),
+ allowLeafExpressions: Boolean = false) {
+
+ // The subexpressions are stored by height to speed up certain calculations.
+ private val maps = mutable.ArrayBuffer[mutable.Map[ExpressionEquals,
ExpressionStats]]()
- // For each expression, the set of equivalent expressions.
- private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals,
ExpressionStats]
+ // `EquivalentExpressions` has 2 states internally, it can be either
inflated or not.
+ // The inflated state means that all added expressions have been traversed
recursively and their
+ // subexpressions are also added to `maps`. The idea behind these 2 states
is that when an
+ // expression tree is added we don't need to traverse/record its
subexpressions immediately.
+ // The typical use case of this data structure is that multiple expression
trees are added and
+ // then we want to see the common subexpressions. It might be the case that
the same expression
+ // trees or partly overlapping expressions trees are added multiple times.
With this approach we
+ // just need to record how many times an expression tree is explicitly added,
and later when
+ // `getExprState()` or `getCommonSubexpressions()` is called we inflate the
data structure (do the
+ // recursive traversal and record the subexpressions in `inflate()`) if
needed.
+ private var inflated: Boolean = true
/**
- * Adds each expression to this data structure, grouping them with existing
equivalent
- * expressions. Non-recursive.
- * Returns true if there was already a matching expression.
+ * Adds each expression to this data structure and returns true if there was
already a matching
+ * expression.
*/
def addExpr(expr: Expression): Boolean = {
- if (supportedExpression(expr)) {
- updateExprInMap(expr, equivalenceMap)
+ if (supportedExpression(expr) && expr.deterministic) {
+ updateWithExpr(expr, 1, 0d)
} else {
false
}
}
/**
- * Adds or removes an expression to/from the map and updates `useCount`.
- * Returns true
- * - if there was a matching expression in the map before add or
- * - if there remained a matching expression in the map after remove
(`useCount` remained > 0)
- * to indicate there is no need to recurse in `updateExprTree`.
+ * Adds the expression to this data structure, including its children
recursively.
*/
- private def updateExprInMap(
- expr: Expression,
- map: mutable.HashMap[ExpressionEquals, ExpressionStats],
- useCount: Int = 1): Boolean = {
- if (expr.deterministic) {
- val wrapper = ExpressionEquals(expr)
- map.get(wrapper) match {
- case Some(stats) =>
- stats.useCount += useCount
- if (stats.useCount > 0) {
- true
- } else if (stats.useCount == 0) {
- map -= wrapper
- false
- } else {
- // Should not happen
- throw new IllegalStateException(
- s"Cannot update expression: $expr in map: $map with use count:
$useCount")
- }
- case _ =>
- if (useCount > 0) {
- map.put(wrapper, ExpressionStats(expr)(useCount))
- } else {
- // Should not happen
- throw new IllegalStateException(
- s"Cannot update expression: $expr in map: $map with use count:
$useCount")
- }
- false
- }
- } else {
- false
+ def addExprTree(expr: Expression): Unit = {
+ if (supportedExpression(expr)) {
+ updateWithExpr(expr, 1, 0d)
}
}
- /**
- * Adds or removes only expressions which are common in each of given
expressions, in a recursive
- * way.
- * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e +
(c + 1)))`, the common
- * expression `(c + 1)` will be added into `equivalenceMap`.
- *
- * Note that as we don't know in advance if any child node of an expression
will be common across
- * all given expressions, we compute local equivalence maps for all given
expressions and filter
- * only the common nodes.
- * Those common nodes are then removed from the local map and added to the
final map of
- * expressions.
- */
- private def updateCommonExprs(
- exprs: Seq[Expression],
- map: mutable.HashMap[ExpressionEquals, ExpressionStats],
- useCount: Int): Unit = {
- assert(exprs.length > 1)
- var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals,
ExpressionStats]
- updateExprTree(exprs.head, localEquivalenceMap)
-
- exprs.tail.foreach { expr =>
- val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals,
ExpressionStats]
- updateExprTree(expr, otherLocalEquivalenceMap)
- localEquivalenceMap = localEquivalenceMap.filter { case (key, _) =>
- otherLocalEquivalenceMap.contains(key)
- }
- }
+ private def supportedExpression(expr: Expression) = {
+ !expr.exists {
+ // `LambdaVariable` is usually used as a loop variable, which can't be
evaluated ahead of the
+ // loop. So we can't evaluate sub-expressions containing
`LambdaVariable` at the beginning.
+ case _: LambdaVariable => true
- // Start with the highest expression, remove it from `localEquivalenceMap`
and add it to `map`.
- // The remaining highest expression in `localEquivalenceMap` is also
common expression so loop
- // until `localEquivalenceMap` is not empty.
- var statsOption =
Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
- while (statsOption.nonEmpty) {
- val stats = statsOption.get
- updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount)
- updateExprTree(stats.expr, map, useCount)
+ // `PlanExpression` wraps query plan. To compare query plans of
`PlanExpression` on executor,
+ // can cause error like NPE.
+ case _: PlanExpression[_] => Utils.isInRunningSparkTask
- statsOption =
Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
+ case _ => false
}
}
- private def skipForShortcut(expr: Expression): Expression = {
- if (skipForShortcutEnable) {
- // The subexpression may not need to eval even if it appears more than
once.
- // e.g., `if(or(a, and(b, b)))`, the expression `b` would be skipped if
`a` is true.
- expr match {
- case and: And => and.left
- case or: Or => or.left
- case other => other
- }
- } else {
- expr
+ private def updateWithExpr(
+ expr: Expression,
+ evalCount: Int,
+ condEvalCount: Double): Boolean = {
+ require(evalCount >= 0 && condEvalCount >= 0d)
+
+ inflated = false
+ val wrapper = ExpressionEquals(expr)
+ val map = getMapByHeight(wrapper.height)
+ map.get(wrapper) match {
+ case Some(es) =>
+ es.evalCount += evalCount
+ es.condEvalCount += condEvalCount
+ true
+ case _ =>
+ map(wrapper) = ExpressionStats(expr)(evalCount, condEvalCount, 0, 0d)
+ false
}
}
- // There are some special expressions that we should not recurse into all of
its children.
- // 1. CodegenFallback: it's children will not be used to generate code
(call eval() instead)
- // 2. ConditionalExpression: use its children that will always be
evaluated.
- private def childrenToRecurse(expr: Expression): Seq[Expression] = expr
match {
- case _: CodegenFallback => Nil
- case c: ConditionalExpression =>
c.alwaysEvaluatedInputs.map(skipForShortcut)
- case other => skipForShortcut(other).children
+ private def getMapByHeight(height: Int) = {
+ val index = height - 1
+ while (maps.size <= index) {
+ maps += mutable.Map.empty
+ }
+ maps(index)
}
- // For some special expressions we cannot just recurse into all of its
children, but we can
- // recursively add the common expressions shared between all of its children.
- private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]]
= expr match {
- case _: CodegenFallback => Nil
- case c: ConditionalExpression => c.branchGroups
- case _ => Nil
+ // Iterate expressions from parents to children and fill inflated
`realEvalCount`s and
+ // `realCondEvalCount`s from explicitly added `evalCount`s and
`condEvalCount`s.
+ private def inflate() = {
+ if (!inflated) {
+ maps.reverse.foreach { map =>
+ map.foreach {
+ case (_, es) => inflateExprState(es)
+ case _ =>
+ }
+ }
+ inflated = true
+ }
}
- private def supportedExpression(e: Expression) = {
- !e.exists {
- // `LambdaVariable` is usually used as a loop variable, which can't be
evaluated ahead of the
- // loop. So we can't evaluate sub-expressions containing
`LambdaVariable` at the beginning.
- case _: LambdaVariable => true
+ private def inflateExprState(exprStats: ExpressionStats): Unit = {
+ val expr = exprStats.expr
+ if (!expr.isInstanceOf[LeafExpression] || allowLeafExpressions) {
Review Comment:
what will go wrong if we always allow leaf expressions?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]