peter-toth commented on code in PR #41677:
URL: https://github.com/apache/spark/pull/41677#discussion_r1385031698
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala:
##########
@@ -30,211 +30,382 @@ import org.apache.spark.util.Utils
* This class is used to compute equality of (sub)expression trees.
Expressions can be added
* to this class and they subsequently query for expression equality.
Expression trees are
* considered equal if for the same input(s), the same result is produced.
+ *
+ * Please note that `EquivalentExpressions` is mainly used in subexpression
elimination where common
+ * non-leaf expression subtrees are calculated, but there is one special
use case in
+ * `PhysicalAggregation` where `EquivalentExpressions` is used as a mutable
set of non-deterministic
+ * expressions. For that special use case we have the `allowLeafExpressions`
config.
*/
class EquivalentExpressions(
- skipForShortcutEnable: Boolean =
SQLConf.get.subexpressionEliminationSkipForShotcutExpr) {
+ skipForShortcutEnable: Boolean =
SQLConf.get.subexpressionEliminationSkipForShotcutExpr,
+ minConditionalCount: Option[Double] =
+
Some(SQLConf.get.subexpressionEliminationMinExpectedConditionalEvaluationCount)
+ .filter(_ >= 0d),
+ allowLeafExpressions: Boolean = false) {
+
+ // The subexpressions are stored by height to speed up certain calculations.
+ private val maps = mutable.ArrayBuffer[mutable.Map[ExpressionEquals,
ExpressionStats]]()
- // For each expression, the set of equivalent expressions.
- private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals,
ExpressionStats]
+ // `EquivalentExpressions` has 2 states internally, it can be either
inflated or not.
+ // The inflated state means that all added expressions have been traversed
recursively and their
+ // subexpressions are also added to `maps`. The idea behind these 2 states
is that when an
+ // expression tree is added we don't need to traverse/record its
subexpressions immediately.
+ // The typical use case of this data structure is that multiple expression
trees are added and
+ // then we want to see the common subexpressions. It might be the case that
the same expression
+ // trees or partly overlapping expression trees are added multiple times.
With this approach we
+ // just need to record how many times an expression tree is explicitly added
and later, when
+ // `getExprState()` or `getCommonSubexpressions()` is called, we inflate the
data structure (do the
+ // recursive traversal and record the subexpressions in `inflate()`) if
needed.
+ private var inflated: Boolean = true
/**
- * Adds each expression to this data structure, grouping them with existing
equivalent
- * expressions. Non-recursive.
- * Returns true if there was already a matching expression.
+ * Adds each expression to this data structure and returns true if there was
already a matching
+ * expression.
*/
def addExpr(expr: Expression): Boolean = {
- if (supportedExpression(expr)) {
- updateExprInMap(expr, equivalenceMap)
+ if (supportedExpression(expr) && expr.deterministic) {
+ updateWithExpr(expr, 1, 0d)
} else {
false
}
}
/**
- * Adds or removes an expression to/from the map and updates `useCount`.
- * Returns true
- * - if there was a matching expression in the map before add or
- * - if there remained a matching expression in the map after remove
(`useCount` remained > 0)
- * to indicate there is no need to recurse in `updateExprTree`.
+ * Adds the expression to this data structure, including its children
recursively.
*/
- private def updateExprInMap(
- expr: Expression,
- map: mutable.HashMap[ExpressionEquals, ExpressionStats],
- useCount: Int = 1): Boolean = {
- if (expr.deterministic) {
- val wrapper = ExpressionEquals(expr)
- map.get(wrapper) match {
- case Some(stats) =>
- stats.useCount += useCount
- if (stats.useCount > 0) {
- true
- } else if (stats.useCount == 0) {
- map -= wrapper
- false
- } else {
- // Should not happen
- throw new IllegalStateException(
- s"Cannot update expression: $expr in map: $map with use count:
$useCount")
- }
- case _ =>
- if (useCount > 0) {
- map.put(wrapper, ExpressionStats(expr)(useCount))
- } else {
- // Should not happen
- throw new IllegalStateException(
- s"Cannot update expression: $expr in map: $map with use count:
$useCount")
- }
- false
- }
- } else {
- false
+ def addExprTree(expr: Expression): Unit = {
+ if (supportedExpression(expr)) {
+ updateWithExpr(expr, 1, 0d)
}
}
- /**
- * Adds or removes only expressions which are common in each of given
expressions, in a recursive
- * way.
- * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e +
(c + 1)))`, the common
- * expression `(c + 1)` will be added into `equivalenceMap`.
- *
- * Note that as we don't know in advance if any child node of an expression
will be common across
- * all given expressions, we compute local equivalence maps for all given
expressions and filter
- * only the common nodes.
- * Those common nodes are then removed from the local map and added to the
final map of
- * expressions.
- */
- private def updateCommonExprs(
- exprs: Seq[Expression],
- map: mutable.HashMap[ExpressionEquals, ExpressionStats],
- useCount: Int): Unit = {
- assert(exprs.length > 1)
- var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals,
ExpressionStats]
- updateExprTree(exprs.head, localEquivalenceMap)
-
- exprs.tail.foreach { expr =>
- val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals,
ExpressionStats]
- updateExprTree(expr, otherLocalEquivalenceMap)
- localEquivalenceMap = localEquivalenceMap.filter { case (key, _) =>
- otherLocalEquivalenceMap.contains(key)
- }
- }
+ private def supportedExpression(expr: Expression) = {
+ !expr.exists {
+ // `LambdaVariable` is usually used as a loop variable, which can't be
evaluated ahead of the
+ // loop. So we can't evaluate sub-expressions containing
`LambdaVariable` at the beginning.
+ case _: LambdaVariable => true
- // Start with the highest expression, remove it from `localEquivalenceMap`
and add it to `map`.
- // The remaining highest expression in `localEquivalenceMap` is also
common expression so loop
- // until `localEquivalenceMap` is not empty.
- var statsOption =
Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
- while (statsOption.nonEmpty) {
- val stats = statsOption.get
- updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount)
- updateExprTree(stats.expr, map, useCount)
+ // `PlanExpression` wraps a query plan. Comparing query plans of
`PlanExpression` on executors
+ // can cause errors like NPE.
+ case _: PlanExpression[_] => Utils.isInRunningSparkTask
- statsOption =
Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
+ case _ => false
}
}
- private def skipForShortcut(expr: Expression): Expression = {
- if (skipForShortcutEnable) {
- // The subexpression may not need to eval even if it appears more than
once.
- // e.g., `if(or(a, and(b, b)))`, the expression `b` would be skipped if
`a` is true.
- expr match {
- case and: And => and.left
- case or: Or => or.left
- case other => other
- }
- } else {
- expr
+ private def updateWithExpr(
+ expr: Expression,
+ evalCount: Int,
+ condEvalCount: Double): Boolean = {
+ require(evalCount >= 0 && condEvalCount >= 0d)
+
+ inflated = false
+ val wrapper = ExpressionEquals(expr)
+ val map = getMapByHeight(wrapper.height)
+ map.get(wrapper) match {
+ case Some(es) =>
+ es.evalCount += evalCount
+ es.condEvalCount += condEvalCount
+ true
+ case _ =>
+ map(wrapper) = ExpressionStats(expr)(evalCount, condEvalCount, 0, 0d)
+ false
}
}
- // There are some special expressions that we should not recurse into all of
its children.
- // 1. CodegenFallback: it's children will not be used to generate code
(call eval() instead)
- // 2. ConditionalExpression: use its children that will always be
evaluated.
- private def childrenToRecurse(expr: Expression): Seq[Expression] = expr
match {
- case _: CodegenFallback => Nil
- case c: ConditionalExpression =>
c.alwaysEvaluatedInputs.map(skipForShortcut)
- case other => skipForShortcut(other).children
+ private def getMapByHeight(height: Int) = {
+ val index = height - 1
+ while (maps.size <= index) {
+ maps += mutable.Map.empty
+ }
+ maps(index)
}
- // For some special expressions we cannot just recurse into all of its
children, but we can
- // recursively add the common expressions shared between all of its children.
- private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]]
= expr match {
- case _: CodegenFallback => Nil
- case c: ConditionalExpression => c.branchGroups
- case _ => Nil
+ // Iterate expressions from parents to children and fill inflated
`realEvalCount`s and
+ // `realCondEvalCount`s from explicitly added `evalCount`s and
`condEvalCount`s.
+ private def inflate() = {
+ if (!inflated) {
+ maps.reverse.foreach { map =>
+ map.foreach {
+ case (_, es) => inflateExprState(es)
+ case _ =>
+ }
+ }
+ inflated = true
+ }
}
- private def supportedExpression(e: Expression) = {
- !e.exists {
- // `LambdaVariable` is usually used as a loop variable, which can't be
evaluated ahead of the
- // loop. So we can't evaluate sub-expressions containing
`LambdaVariable` at the beginning.
- case _: LambdaVariable => true
+ private def inflateExprState(exprStats: ExpressionStats): Unit = {
+ val expr = exprStats.expr
+ if (!expr.isInstanceOf[LeafExpression] || allowLeafExpressions) {
+ val evalCount = exprStats.evalCount
+ val condEvalCount = exprStats.condEvalCount
- // `PlanExpression` wraps query plan. To compare query plans of
`PlanExpression` on executor,
- // can cause error like NPE.
- case _: PlanExpression[_] => Utils.isInRunningSparkTask
+ exprStats.evalCount = 0
+ exprStats.condEvalCount = 0d
+ exprStats.realEvalCount += evalCount
+ exprStats.realCondEvalCount += condEvalCount
- case _ => false
+ expr match {
+ // CodegenFallback's children will not be used to generate code (call
eval() instead)
+ case _: CodegenFallback =>
+
+ case c: CaseWhen =>
+ // Let's consider `CaseWhen(Seq((w1, t1), (w2, t2), (w3, t3), ...
(wn, tn)), Some(e))`
+ // example and use `Wn`, `Tn` and `E` notations for the local
equivalence maps built from
+ // `wn`, `tn` and `e` expressions respectively.
+ //
+ // Let's try to build a local equivalence map of the above
`CaseWhen` example and then add
+ // that local map to `map`.
+ //
+ // We know that `w1` is surely evaluated so `W1` should be part of
the local map.
+ // We also know that based on the result of `w1` either `t1` or `w2`
is evaluated so the
+ // "intersection" between `T1` and `W2` should be also part of the
local map.
+ // Please note that "intersection" might not describe well the
operation that we need
+ // between `T1` and `W2`. It is an intersection in terms of surely
evaluated
+ // subexpressions between `T1` and `W2` but it is also kind of an
union between
+ // conditionally evaluated subexpressions. See the details in
`intersectWith()`.
+ // So the local map can be calculated as `W1 | (T1 & W2)` so far,
where `|` and `&` mean
+ // the "union" and "intersection" of equivalence maps.
+ // But we can continue the previous logic further because if `w2` is
evaluated, then based
+ // on the result of `w2` either `t2` or `w3` is also evaluated.
+ // So eventually the local equivalence map can be calculated as
+ // `W1 | (T1 & (W2 | (T2 & (W3 | (T3 & ... & (Wn | (Tn & E)))))))`.
+
+ // As `w1` is always evaluated, we can add it immediately to `map`
(instead of adding it
+ // to `localMap`).
+ updateWithExpr(c.branches.head._1, evalCount, condEvalCount)
+
+ val localMap = new EquivalentExpressions
+ if (c.elseValue.isDefined) {
+ localMap.updateWithExpr(c.branches.last._2, evalCount,
condEvalCount)
+ localMap.intersectWithExpr(c.elseValue.get, evalCount,
condEvalCount)
+ } else {
+ localMap.updateWithExpr(c.branches.last._2, 0, (evalCount +
condEvalCount) / 2)
+ }
+ if (c.branches.length > 1) {
+ c.branches.reverse.sliding(2).foreach { case Seq((w, _), (_,
prevt)) =>
+ localMap.updateWithExpr(w, evalCount, condEvalCount)
+ localMap.intersectWithExpr(prevt, evalCount, condEvalCount)
+ }
+ }
+
+ unionWith(localMap)
+
+ case i: If =>
+ updateWithExpr(i.predicate, evalCount, condEvalCount)
+
+ val localMap = new EquivalentExpressions
+ localMap.updateWithExpr(i.trueValue, evalCount, condEvalCount)
+ localMap.intersectWithExpr(i.falseValue, evalCount, condEvalCount)
+
+ unionWith(localMap)
+
+ case a: And if skipForShortcutEnable =>
+ updateWithExpr(a.left, evalCount, condEvalCount)
+ updateWithExpr(a.right, 0, (evalCount + condEvalCount) / 2)
+
+ case o: Or if skipForShortcutEnable =>
+ updateWithExpr(o.left, evalCount, condEvalCount)
+ updateWithExpr(o.right, 0, (evalCount + condEvalCount) / 2)
+
+ case n: NaNvl =>
+ updateWithExpr(n.left, evalCount, condEvalCount)
+ updateWithExpr(n.right, 0, (evalCount + condEvalCount) / 2)
+
+ case c: Coalesce =>
+ updateWithExpr(c.children.head, evalCount, condEvalCount)
+ var cec = evalCount + condEvalCount
+ c.children.tail.foreach {
+ cec /= 2
+ updateWithExpr(_, 0, cec)
+ }
+
+ case e => e.children.foreach(updateWithExpr(_, evalCount,
condEvalCount))
+ }
}
}
+ private def intersectWithExpr(
+ expr: Expression,
+ evalCount: Int,
+ condEvalCount: Double) = {
+ val localMap = new EquivalentExpressions
+ localMap.updateWithExpr(expr, evalCount, condEvalCount)
+ intersectWith(localMap)
+ }
+
/**
- * Adds the expression to this data structure recursively. Stops if a
matching expression
- * is found. That is, if `expr` has already been added, its children are not
added.
+ * This method can be used to compute the equivalence map if there is a
branching in expression
+ * evaluation.
+ * E.g. if we have `If(_, a, b)` expression and `A` and `B` are the
equivalence maps built from
+ * `a` and `b`, this method computes the equivalence map `C` whose
keys are the union of the
+ * expressions from `A` and `B`. The `evalCount` statistics of
expressions in `C` depend on
+ * whether the expression was present in both `A` and `B` or not.
+ * If an expression was present in both then the result `evalCount` of the
expression is the
+ * minimum of `evalCount`s from `A` and `B` (intersection of equivalence
maps).
+ * For the sake of simplicity branching is modelled with 0.5 / 0.5
probabilities so the
+ * `condEvalCount` statistics of expressions in `C` are calculated by
adjusting both
+ * `condEvalCount` from `A` and `B` by `0.5` and summing them. Also,
difference between
+ * `evalCount` of an expression from `A` and `B` becomes part of
`condEvalCount` and so adjusted
+ * by `0.5`.
+ *
+ * Please note that this method modifies both this instance and `other`, so `other` is no longer
safe to use after this
+ * method.
*/
- def addExprTree(
- expr: Expression,
- map: mutable.HashMap[ExpressionEquals, ExpressionStats] =
equivalenceMap): Unit = {
- if (supportedExpression(expr)) {
- updateExprTree(expr, map)
+ private def intersectWith(other: EquivalentExpressions) = {
+ inflate()
+ other.inflate()
+
+ val zippedMaps = maps.zip(other.maps)
+ zippedMaps.foreach { case (map, otherMap) =>
+ map.foreach { case (key, value) =>
+ otherMap.remove(key) match {
+ case Some(otherValue) =>
+ val (min, max) = if (value.realEvalCount <
otherValue.realEvalCount) {
+ (value.realEvalCount, otherValue.realEvalCount)
+ } else {
+ (otherValue.realEvalCount, value.realEvalCount)
+ }
+ value.realCondEvalCount += otherValue.realCondEvalCount + max - min
Review Comment:
`value.realCondEvalCount = (value.realCondEvalCount +
otherValue.realCondEvalCount + max - min) / 2` is the full calculation, but the
`value.realCondEvalCount /= 2` part is extracted a bit below.
The `(max - min) / 2` term is also needed. E.g. if we have `If(_, a + b, (a + b) + (a
+ b))`, then during the intersection of the `then` and `else` branches we have `a +
b -> 1 + 0` in `then` and `a + b -> 2 + 0` in `else`. The result should be `a
+ b -> 1 + 0.5`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]