viirya commented on a change in pull request #33281:
URL: https://github.com/apache/spark/pull/33281#discussion_r667389005
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
##########
@@ -70,34 +96,39 @@ class EquivalentExpressions {
* For example, if `((a + b) + c)` and `(a + b)` are common expressions, we
only add
* `((a + b) + c)`.
*/
- private def addCommonExprs(
+ private def updateCommonExprs(
exprs: Seq[Expression],
- map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Unit = {
+ map: mutable.HashMap[ExpressionEquals, ExpressionStats],
+ useCount: Int): Unit = {
assert(exprs.length > 1)
var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals,
ExpressionStats]
- addExprTree(exprs.head, localEquivalenceMap)
+ updateExprTree(exprs.head, localEquivalenceMap)
exprs.tail.foreach { expr =>
val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals,
ExpressionStats]
- addExprTree(expr, otherLocalEquivalenceMap)
+ updateExprTree(expr, otherLocalEquivalenceMap)
localEquivalenceMap = localEquivalenceMap.filter { case (key, _) =>
otherLocalEquivalenceMap.contains(key)
}
}
- localEquivalenceMap.foreach { case (commonExpr, state) =>
- val possibleParents = localEquivalenceMap.filter { case (_, v) =>
v.height > state.height }
- val notChild = possibleParents.forall { case (k, _) =>
- k == commonExpr || k.e.find(_.semanticEquals(commonExpr.e)).isEmpty
- }
- if (notChild) {
- // If the `commonExpr` already appears in the equivalence map, calling
`addExprTree` will
- // increase the `useCount` and mark it as a common subexpression.
Otherwise, `addExprTree`
- // will recursively add `commonExpr` and its descendant to the
equivalence map, in case
- // they also appear in other places. For example, `If(a + b > 1, a + b
+ c, a + b + c)`,
- // `a + b` also appears in the condition and should be treated as
common subexpression.
- addExprTree(commonExpr.e, map)
- }
+ // Start with the highest common expression, update `map` with the
expression and remove it (and
+ // its children recursively if required) from `localEquivalenceMap`. The
remaining highest
+ // expression in `localEquivalenceMap` is also common expression.
+ var statsOption =
Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
+ while (statsOption.nonEmpty) {
+ val stats = statsOption.get
+ updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount)
Review comment:
I see. I'd rather consider this an improvement than a bug fix, since it
doesn't cause a query failure or a codegen failure; it only fails to identify
a common subexpression.
That said, we don't need to hurry on this for 3.2.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]