maropu commented on a change in pull request #33142:
URL: https://github.com/apache/spark/pull/33142#discussion_r661911895



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
##########
@@ -170,65 +171,75 @@ class EquivalentExpressions {
       // can cause error like NPE.
       (expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null)
 
-    if (!skip && !addFunc(expr)) {
-      childrenToRecurse(expr).foreach(addExprTree(_, addFunc))
-      
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, 
addFunc))
+    if (!skip && !addExprToMap(expr, map)) {
+      val height = childrenToRecurse(expr).map(addExprTree0(_, map))
+        .reduceOption(_ max _).map(_ + 1).getOrElse(0)
+      map(ExpressionEquals(expr)).height = height
+      // `commonChildrenToRecurse` are some additional children to find common 
subexpression, and
+      // we should only use `childrenToRecurse` to calculate the height.
+      
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, map))
+      height
+    } else {
+      0
     }
   }
 
   /**
-   * Returns all of the expression trees that are equivalent to `e`. Returns
-   * an empty collection if there are none.
+   * Returns the state of the given expression in the `equivalenceMap`. 
Returns None if there is no
+   * equivalent expressions.
    */
-  def getEquivalentExprs(e: Expression): Seq[Expression] = {
-    equivalenceMap.getOrElse(Expr(e), Seq.empty).toSeq
+  def getExprState(e: Expression): Option[ExpressionStats] = {
+    equivalenceMap.get(ExpressionEquals(e))
+  }
+
+  // Exposed for testing.
+  private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = {
+    equivalenceMap.values.filter(_.useCount > count).toSeq.sortBy(_.height)
   }
 
   /**
-   * Returns all the equivalent sets of expressions which appear more than 
given `repeatTimes`
-   * times.
+   * Returns a sequence of expressions that more than one equivalent 
expressions.
    */
-  def getAllEquivalentExprs(repeatTimes: Int = 0): Seq[Seq[Expression]] = {
-    equivalenceMap.values.map(_.toSeq).filter(_.size > repeatTimes).toSeq
-      .sortBy(_.head)(new ExpressionContainmentOrdering)
+  def getCommonSubexpressions: Seq[Expression] = {
+    getAllExprStates(1).map(_.expr)
   }
 
   /**
    * Returns the state of the data structure as a string. If `all` is false, 
skips sets of
    * equivalent expressions with cardinality 1.
    */
   def debugString(all: Boolean = false): String = {
-    val sb: mutable.StringBuilder = new StringBuilder()
+    val sb = new java.lang.StringBuilder()
     sb.append("Equivalent expressions:\n")
-    equivalenceMap.foreach { case (k, v) =>
-      if (all || v.length > 1) {
-        sb.append("  " + v.mkString(", ")).append("\n")
-      }
+    equivalenceMap.values.filter(stats => all || stats.useCount > 1).foreach { 
stats =>
+      sb.append("  ").append(s"${stats.expr}: useCount = 
${stats.useCount}").append('\n')
     }
     sb.toString()
   }
 }
 
 /**
- * Orders `Expression` by parent/child relations. The child expression is 
smaller
- * than parent expression. If there is child-parent relationships among the 
subexpressions,
- * we want the child expressions come first than parent expressions, so we can 
replace
- * child expressions in parent expressions with subexpression evaluation. Note 
that
- * this is not for general expression ordering. For example, two irrelevant or 
semantically-equal
- * expressions will be considered as equal by this ordering. But for the usage 
here, the order of
- * irrelevant expressions does not matter.
+ * Wrapper around an Expression that provides semantic equality.
  */
-class ExpressionContainmentOrdering extends Ordering[Expression] {
-  override def compare(x: Expression, y: Expression): Int = {
-    if (x.find(_.semanticEquals(y)).isDefined) {
-      // `y` is child expression of `x`.
-      1
-    } else if (y.find(_.semanticEquals(x)).isDefined) {
-      // `x` is child expression of `y`.
-      -1
-    } else {
-      // Irrelevant or semantically-equal expressions
-      0
-    }
+case class ExpressionEquals(e: Expression) {
+  override def equals(o: Any): Boolean = o match {
+    case other: ExpressionEquals => e.semanticEquals(other.e)
+    case _ => false
   }
+
+  override def hashCode: Int = e.semanticHash()
 }
+
+/**
+ * A wrapper in place of using Seq[Expression] to record a group of equivalent 
expressions.
+ *
+ * This saves a lot of memory when there are a lot of expressions in a same 
equivalence group.
+ * Instead of appending to a mutable list/buffer of Expressions, just update 
the "flattened"
+ * useCount in this wrapper in-place.
+ *
+ * This also tracks the "height" of the expression, so that we can return 
expressions with smaller

Review comment:
       `the "height" of the expression` -> `track the "height" of common subexpressions`?

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
##########
@@ -135,33 +125,47 @@ class EquivalentExpressions {
   // For some special expressions we cannot just recurse into all of its 
children, but we can
   // recursively add the common expressions shared between all of its children.
   private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] 
= expr match {
+    case _: CodegenFallback => Nil

Review comment:
       Would it be better to backport this part to branch-3.1/3.0 as well?

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
##########
@@ -170,65 +171,75 @@ class EquivalentExpressions {
       // can cause error like NPE.
       (expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null)
 
-    if (!skip && !addFunc(expr)) {
-      childrenToRecurse(expr).foreach(addExprTree(_, addFunc))
-      
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, 
addFunc))
+    if (!skip && !addExprToMap(expr, map)) {
+      val height = childrenToRecurse(expr).map(addExprTree0(_, map))
+        .reduceOption(_ max _).map(_ + 1).getOrElse(0)
+      map(ExpressionEquals(expr)).height = height

Review comment:
       (My comment is the same as [@viirya's](https://github.com/apache/spark/pull/33142#issuecomment-871634042).)
Can we always tell whether one expr is a parent of another from this `height`?
This height seems to depend on the state of `map`, so couldn't the true height
change after it is assigned here? For this purpose, can't we simply use the
height of the expression tree itself instead?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to