cfmcgrady commented on pull request #30565:
URL: https://github.com/apache/spark/pull/30565#issuecomment-882415014


   hi, @viirya 
   
   Is it possible to eliminate subexpressions at whole-stage scope? Currently, subexpression elimination deduplicates expressions within the Project and within the Filter separately.
   
   For example (copied from the JIRA ticket [SPARK-35882](https://issues.apache.org/jira/browse/SPARK-35882)):
    ```scala
    import scala.annotation.tailrec
    import org.apache.spark.sql.functions._

    @tailrec
    def fib(index: Int, prev: Int, current: Int): Int = {
      if (index == 0) {
        current
      } else {
        fib(index - 1, prev = prev + current, current = prev)
      }
    }

    val getFibonacci: Int => Seq[Int] = (index: Int) => {
      Seq(fib(index, prev = 1, current = 0),
        fib(index + 1, prev = 1, current = 0),
        fib(index + 2, prev = 1, current = 0))
    }

    val udfFib = udf(getFibonacci)

    val df = spark.range(1, 5)
      .withColumn("fib", udfFib(col("id")))
      .withColumn("fib2", explode(col("fib")))
    df.explain(true)
    ```
   
   ```
   == Parsed Logical Plan ==
   'Project [id#2L, fib#4, explode('fib) AS fib2#13]
   +- Project [id#2L, UDF(cast(id#2L as int)) AS fib#4]
      +- Range (1, 5, step=1, splits=Some(2))
   
   == Analyzed Logical Plan ==
   id: bigint, fib: array<int>, fib2: int
   Project [id#2L, fib#4, fib2#14]
   +- Generate explode(fib#4), false, [fib2#14]
      +- Project [id#2L, UDF(cast(id#2L as int)) AS fib#4]
         +- Range (1, 5, step=1, splits=Some(2))
   
   == Optimized Logical Plan ==
   Generate explode(fib#4), false, [fib2#14]
   +- Project [id#2L, UDF(cast(id#2L as int)) AS fib#4]
       +- Filter ((size(UDF(cast(id#2L as int)), true) > 0) AND isnotnull(UDF(cast(id#2L as int))))
         +- Range (1, 5, step=1, splits=Some(2))
   
   == Physical Plan ==
   *(1) Generate explode(fib#4), [id#2L, fib#4], false, [fib2#14]
   +- *(1) Project [id#2L, UDF(cast(id#2L as int)) AS fib#4]
       +- *(1) Filter ((size(UDF(cast(id#2L as int)), true) > 0) AND isnotnull(UDF(cast(id#2L as int))))
         +- *(1) Range (1, 5, step=1, splits=2)
   ```
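   
   As a quick empirical check (a sketch of mine, not from the PR, reusing `getFibonacci` and the imports from the snippet above; `callCount` and `countingFib` are hypothetical names), an accumulator inside the UDF body confirms the double evaluation per input row:
   
    ```scala
    // Hypothetical instrumentation: increment a LongAccumulator on every UDF call.
    // (Counts can be inflated by task retries/speculation; good enough locally.)
    val callCount = spark.sparkContext.longAccumulator("fibCalls")
    val countingFib = udf { (index: Int) =>
      callCount.add(1)
      getFibonacci(index)
    }

    spark.range(1, 5)
      .withColumn("fib", countingFib(col("id")))
      .withColumn("fib2", explode(col("fib")))
      .count()

    // With 4 input rows and the plan above, this should print 8 (2 per row):
    // the Filter and the Project each evaluate the UDF once per row.
    println(callCount.value)
    ```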
   
   ```java
   /* 005 */ // codegenStageId=1
    /* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
    /* 035 */
    /* 036 */   private void project_doConsume_0(long project_expr_0_0) throws java.io.IOException {
   /* 037 */     // common sub-expressions
   /* 038 */
   /* 039 */     boolean project_isNull_3 = false;
   /* 040 */     int project_value_3 = -1;
   /* 041 */     if (!false) {
   /* 042 */       project_value_3 = (int) project_expr_0_0;
   /* 043 */     }
   /* 044 */
   /* 045 */     Integer project_conv_0 = project_value_3;
    /* 046 */     Object project_arg_0 = project_isNull_3 ? null : project_conv_0;
   /* 047 */
   /* 048 */     ArrayData project_result_0 = null;
   /* 049 */     try {
    /* 050 */       project_result_0 = (ArrayData)((scala.Function1[]) references[4] /* converters */)[1].apply(((scala.Function1) references[5] /* udf */).apply(project_arg_0));
    /* 051 */     } catch (Throwable e) {
    /* 052 */       throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
    /* 053 */         "SimpleSuite$$Lambda$839/2056566350", "int", "array<int>", e);
   /* 054 */     }
   /* 055 */
   /* 056 */     boolean project_isNull_2 = project_result_0 == null;
   /* 057 */     ArrayData project_value_2 = null;
   /* 058 */     if (!project_isNull_2) {
   /* 059 */       project_value_2 = project_result_0;
   /* 060 */     }
   /* 061 */
    /* 062 */     generate_doConsume_0(project_expr_0_0, project_value_2, project_isNull_2);
   /* 063 */
   /* 064 */   }
   /* 065 */
   /* 110 */   protected void processNext() throws java.io.IOException {
   /* 111 */     // initialize Range
   /* 112 */     if (!range_initRange_0) {
   /* 113 */       range_initRange_0 = true;
   /* 114 */       initRange(partitionIndex);
   /* 115 */     }
   /* 116 */
   /* 117 */     while (true) {
   /* 118 */       if (range_nextIndex_0 == range_batchEnd_0) {
   /* 119 */         long range_nextBatchTodo_0;
   /* 120 */         if (range_numElementsTodo_0 > 1000L) {
   /* 121 */           range_nextBatchTodo_0 = 1000L;
   /* 122 */           range_numElementsTodo_0 -= 1000L;
   /* 123 */         } else {
   /* 124 */           range_nextBatchTodo_0 = range_numElementsTodo_0;
   /* 125 */           range_numElementsTodo_0 = 0;
   /* 126 */           if (range_nextBatchTodo_0 == 0) break;
   /* 127 */         }
   /* 128 */         range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
   /* 129 */       }
   /* 130 */
    /* 131 */       int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
    /* 132 */       for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
    /* 133 */         long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
   /* 134 */
   /* 135 */         do {
   /* 136 */           // common subexpressions
   /* 137 */
   /* 138 */           boolean filter_isNull_1 = false;
   /* 139 */           int filter_value_1 = -1;
   /* 140 */           if (!false) {
   /* 141 */             filter_value_1 = (int) range_value_0;
   /* 142 */           }
   /* 143 */
   /* 144 */           Integer filter_conv_0 = filter_value_1;
    /* 145 */           Object filter_arg_0 = filter_isNull_1 ? null : filter_conv_0;
   /* 146 */
   /* 147 */           ArrayData filter_result_0 = null;
   /* 148 */           try {
    /* 149 */             filter_result_0 = (ArrayData)((scala.Function1[]) references[2] /* converters */)[1].apply(((scala.Function1) references[3] /* udf */).apply(filter_arg_0));
    /* 150 */           } catch (Throwable e) {
    /* 151 */             throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
    /* 152 */               "SimpleSuite$$Lambda$839/2056566350", "int", "array<int>", e);
   /* 153 */           }
   /* 154 */
   /* 155 */           boolean filter_isNull_0 = filter_result_0 == null;
   /* 156 */           ArrayData filter_value_0 = null;
   /* 157 */           if (!filter_isNull_0) {
   /* 158 */             filter_value_0 = filter_result_0;
   /* 159 */           }
   /* 160 */
   /* 161 */           // end of common subexpressions
   /* 162 */
   /* 163 */           // RequiredVariables
   /* 164 */
   /* 165 */           // end of RequiredVariables
   /* 166 */           // ev code
   /* 167 */           boolean filter_isNull_4 = false;
   /* 168 */
    /* 169 */           int filter_value_4 = filter_isNull_0 ? -1 :
    /* 170 */             (filter_value_0).numElements();
   /* 171 */
   /* 172 */           boolean filter_value_3 = false;
   /* 173 */           filter_value_3 = filter_value_4 > 0;
   /* 174 */           // end of ev code
   /* 175 */           if (!filter_value_3) continue;
   /* 176 */           // common subexpressions
   /* 177 */
   /* 178 */           // end of common subexpressions
   /* 179 */
   /* 180 */           // RequiredVariables
   /* 181 */
   /* 182 */           // end of RequiredVariables
   /* 183 */           // ev code
   /* 184 */           boolean filter_value_7 = !filter_isNull_0;
   /* 185 */           // end of ev code
   /* 186 */           if (!filter_value_7) continue;
   /* 187 */
    /* 188 */           ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* numOutputRows */).add(1);
   /* 189 */
   /* 190 */           project_doConsume_0(range_value_0);
   /* 191 */
   /* 192 */         } while(false);
   /* 193 */
   /* 194 */         if (shouldStop()) {
   /* 195 */           range_nextIndex_0 = range_value_0 + 1L;
    /* 196 */           ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
    /* 197 */           range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
   /* 198 */           return;
   /* 199 */         }
   /* 200 */
   /* 201 */       }
   /* 202 */       range_nextIndex_0 = range_batchEnd_0;
    /* 203 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
   /* 204 */       range_inputMetrics_0.incRecordsRead(range_localEnd_0);
   /* 205 */       range_taskContext_0.killTaskIfInterrupted();
   /* 206 */     }
   /* 207 */   }
   /* 208 */
   /* 284 */ }
   ```
   
   From the generated code we can see that `fib` executes twice per row: once in the Filter (`/* 149 */`) and once in the Project (`/* 050 */`). This happens because the filter inferred for `explode` is pushed below the Project, replacing the `fib#4` alias with the UDF expression, and subexpression elimination runs per operator, so it cannot deduplicate across the two. If `fib` only executed once within the `WholeStageCodegen` stage, it would improve performance when `fib` is an expensive UDF.
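   
   In the meantime, a possible user-side mitigation (only a sketch, not part of this PR) is to mark the UDF non-deterministic: predicates are not pushed through a Project whose output contains non-deterministic expressions, so the Filter should keep referencing `fib#4` instead of re-evaluating the UDF.
   
    ```scala
    // Workaround sketch, assuming the plan above: a non-deterministic UDF blocks
    // the predicate pushdown that duplicated UDF(...) into the Filter.
    // Trade-off: it also disables other optimizations that require determinism.
    val udfFibOnce = udf(getFibonacci).asNondeterministic()

    val df2 = spark.range(1, 5)
      .withColumn("fib", udfFibOnce(col("id")))
      .withColumn("fib2", explode(col("fib")))
    df2.explain(true)  // the Filter should now reference fib#4, not UDF(...)
    ```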
   

