maropu commented on a change in pull request #32980:
URL: https://github.com/apache/spark/pull/32980#discussion_r659400977



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
##########
@@ -1049,17 +1095,25 @@ class CodegenContext extends Logging {
     // elimination.
     val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
 
-    val nonSplitExprCode = {
+    val nonSplitCode = {
+      val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState]
       commonExprs.map { exprs =>
-        val eval = withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
+        withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
           val eval = exprs.head.genCode(this)
-          // Generate the code for this expression tree.
-          val state = SubExprEliminationState(eval.isNull, eval.value)
+          // Collects other subexpressions from the children.
+          val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
+          exprs.head.foreach {
+            case e if subExprEliminationExprs.contains(e) =>
+              childrenSubExprs += subExprEliminationExprs(e)
+            case _ =>
+          }
+          val state = SubExprEliminationState(eval, childrenSubExprs.toSeq.reverse)

Review comment:
       BTW, how about moving `.reverse` into `SubExprEliminationState` itself (e.g., a companion
`apply`), if we always need the children in reversed order? For example:
   ```
   object SubExprEliminationState {
     def apply(eval: ExprCode, children: Seq[SubExprEliminationState]): 
SubExprEliminationState = {
       new SubExprEliminationState(eval, children.reverse)
     }
   }
   ```

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
##########
@@ -1030,11 +1032,55 @@ class CodegenContext extends Logging {
   }
 
   /**
-   * Checks and sets up the state and codegen for subexpression elimination. This finds the
-   * common subexpressions, generates the code snippets that evaluate those expressions and
-   * populates the mapping of common subexpressions to the generated code snippets. The generated
-   * code snippets will be returned and should be inserted into generated codes before these
-   * common subexpressions actually are used first time.
+   * Evaluates a sequence of `SubExprEliminationState` which represent subexpressions. After
+   * evaluating a subexpression, this method will clean up the code block to avoid duplicate
+   * evaluation.
+   */
+  def evaluateSubExprEliminationState(subExprStates: Iterable[SubExprEliminationState]): String = {
+    val code = new StringBuilder()
+
+    subExprStates.foreach { state =>
+      val currentCode = evaluateSubExprEliminationState(state.children) + "\n" + state.eval.code
+      code.append(currentCode + "\n")
+      state.eval.code = EmptyBlock
+    }
+
+    code.toString()
+  }
+
+  /**
+   * Checks and sets up the state and codegen for subexpression elimination in whole-stage codegen.
+   *
+   * This finds the common subexpressions, generates the code snippets that evaluate those
+   * expressions and populates the mapping of common subexpressions to the generated code snippets.
+   *
+   * The generated code snippet for subexpression is wrapped in `SubExprEliminationState`, which
+   * contains a `ExprCode` and the children `SubExprEliminationState` if any. The `ExprCode`

Review comment:
       nit: `a ExprCode` -> `an ExprCode`

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
##########
@@ -1049,17 +1095,25 @@ class CodegenContext extends Logging {
     // elimination.
     val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
 
-    val nonSplitExprCode = {
+    val nonSplitCode = {
+      val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState]
       commonExprs.map { exprs =>
-        val eval = withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
+        withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
           val eval = exprs.head.genCode(this)
-          // Generate the code for this expression tree.
-          val state = SubExprEliminationState(eval.isNull, eval.value)
+          // Collects other subexpressions from the children.
+          val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
+          exprs.head.foreach {
+            case e if subExprEliminationExprs.contains(e) =>
+              childrenSubExprs += subExprEliminationExprs(e)
+            case _ =>
+          }
+          val state = SubExprEliminationState(eval, childrenSubExprs.toSeq.reverse)

Review comment:
       `childrenSubExprs.toSeq.reverse` -> `childrenSubExprs.reverse`?

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
##########
@@ -1049,17 +1095,25 @@ class CodegenContext extends Logging {
     // elimination.
     val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
 
-    val nonSplitExprCode = {
+    val nonSplitCode = {
+      val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState]
       commonExprs.map { exprs =>
-        val eval = withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
+        withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
           val eval = exprs.head.genCode(this)
-          // Generate the code for this expression tree.
-          val state = SubExprEliminationState(eval.isNull, eval.value)
+          // Collects other subexpressions from the children.
+          val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
+          exprs.head.foreach {
+            case e if subExprEliminationExprs.contains(e) =>
+              childrenSubExprs += subExprEliminationExprs(e)

Review comment:
       Q: Would it be difficult to add some tests for this new behaviour (collecting the child subexpression states)?

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
##########
@@ -1030,11 +1032,55 @@ class CodegenContext extends Logging {
   }
 
   /**
-   * Checks and sets up the state and codegen for subexpression elimination. This finds the
-   * common subexpressions, generates the code snippets that evaluate those expressions and
-   * populates the mapping of common subexpressions to the generated code snippets. The generated
-   * code snippets will be returned and should be inserted into generated codes before these
-   * common subexpressions actually are used first time.
+   * Evaluates a sequence of `SubExprEliminationState` which represent subexpressions. After
+   * evaluating a subexpression, this method will clean up the code block to avoid duplicate
+   * evaluation.
+   */
+  def evaluateSubExprEliminationState(subExprStates: Iterable[SubExprEliminationState]): String = {

Review comment:
       nit: `Iterable` -> `Seq`?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to