cfmcgrady commented on pull request #30565: URL: https://github.com/apache/spark/pull/30565#issuecomment-882415014
hi, @viirya Is it possible to eliminate subexpressions in stage scope? Currently, subexpression elimination eliminates expressions separately in Project and Filter. For example: (copied from JIRA ticket [SPARK-35882](https://issues.apache.org/jira/browse/SPARK-35882)) ```scala @tailrec def fib(index: Int, prev: Int, current: Int): Int = { if (index == 0) { current } else { fib(index - 1, prev = prev + current, current = prev) } } val getFibonacci: Int => Seq[Int] = (index: Int) => { Seq(fib(index, prev = 1, current = 0), fib(index + 1, prev = 1, current = 0), fib(index + 2, prev = 1, current = 0)) } import org.apache.spark.sql.functions._ val udfFib = udf(getFibonacci) val df = spark.range(1, 5) .withColumn("fib", udfFib(col("id"))) .withColumn("fib2", explode(col("fib"))) df.explain(true) ``` ``` == Parsed Logical Plan == 'Project [id#2L, fib#4, explode('fib) AS fib2#13] +- Project [id#2L, UDF(cast(id#2L as int)) AS fib#4] +- Range (1, 5, step=1, splits=Some(2)) == Analyzed Logical Plan == id: bigint, fib: array<int>, fib2: int Project [id#2L, fib#4, fib2#14] +- Generate explode(fib#4), false, [fib2#14] +- Project [id#2L, UDF(cast(id#2L as int)) AS fib#4] +- Range (1, 5, step=1, splits=Some(2)) == Optimized Logical Plan == Generate explode(fib#4), false, [fib2#14] +- Project [id#2L, UDF(cast(id#2L as int)) AS fib#4] +- Filter ((size(UDF(cast(id#2L as int)), true) > 0) AND isnotnull(UDF(cast(id#2L as int)))) +- Range (1, 5, step=1, splits=Some(2)) == Physical Plan == *(1) Generate explode(fib#4), [id#2L, fib#4], false, [fib2#14] +- *(1) Project [id#2L, UDF(cast(id#2L as int)) AS fib#4] +- *(1) Filter ((size(UDF(cast(id#2L as int)), true) > 0) AND isnotnull(UDF(cast(id#2L as int)))) +- *(1) Range (1, 5, step=1, splits=2) ``` ```java /* 005 */ // codegenStageId=1 /* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 035 */ /* 036 */ private void project_doConsume_0(long project_expr_0_0) throws 
java.io.IOException { /* 037 */ // common sub-expressions /* 038 */ /* 039 */ boolean project_isNull_3 = false; /* 040 */ int project_value_3 = -1; /* 041 */ if (!false) { /* 042 */ project_value_3 = (int) project_expr_0_0; /* 043 */ } /* 044 */ /* 045 */ Integer project_conv_0 = project_value_3; /* 046 */ Object project_arg_0 = project_isNull_3 ? null : project_conv_0; /* 047 */ /* 048 */ ArrayData project_result_0 = null; /* 049 */ try { /* 050 */ project_result_0 = (ArrayData)((scala.Function1[]) references[4] /* converters */)[1].apply(((scala.Function1) references[5] /* udf */).apply(project_arg_0)); /* 051 */ } catch (Throwable e) { /* 052 */ throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError( /* 053 */ "SimpleSuite$$Lambda$839/2056566350", "int", "array<int>", e); /* 054 */ } /* 055 */ /* 056 */ boolean project_isNull_2 = project_result_0 == null; /* 057 */ ArrayData project_value_2 = null; /* 058 */ if (!project_isNull_2) { /* 059 */ project_value_2 = project_result_0; /* 060 */ } /* 061 */ /* 062 */ generate_doConsume_0(project_expr_0_0, project_value_2, project_isNull_2); /* 063 */ /* 064 */ } /* 065 */ /* 110 */ protected void processNext() throws java.io.IOException { /* 111 */ // initialize Range /* 112 */ if (!range_initRange_0) { /* 113 */ range_initRange_0 = true; /* 114 */ initRange(partitionIndex); /* 115 */ } /* 116 */ /* 117 */ while (true) { /* 118 */ if (range_nextIndex_0 == range_batchEnd_0) { /* 119 */ long range_nextBatchTodo_0; /* 120 */ if (range_numElementsTodo_0 > 1000L) { /* 121 */ range_nextBatchTodo_0 = 1000L; /* 122 */ range_numElementsTodo_0 -= 1000L; /* 123 */ } else { /* 124 */ range_nextBatchTodo_0 = range_numElementsTodo_0; /* 125 */ range_numElementsTodo_0 = 0; /* 126 */ if (range_nextBatchTodo_0 == 0) break; /* 127 */ } /* 128 */ range_batchEnd_0 += range_nextBatchTodo_0 * 1L; /* 129 */ } /* 130 */ /* 131 */ int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L); /* 132 */ for (int 
range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) { /* 133 */ long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0; /* 134 */ /* 135 */ do { /* 136 */ // common subexpressions /* 137 */ /* 138 */ boolean filter_isNull_1 = false; /* 139 */ int filter_value_1 = -1; /* 140 */ if (!false) { /* 141 */ filter_value_1 = (int) range_value_0; /* 142 */ } /* 143 */ /* 144 */ Integer filter_conv_0 = filter_value_1; /* 145 */ Object filter_arg_0 = filter_isNull_1 ? null : filter_conv_0; /* 146 */ /* 147 */ ArrayData filter_result_0 = null; /* 148 */ try { /* 149 */ filter_result_0 = (ArrayData)((scala.Function1[]) references[2] /* converters */)[1].apply(((scala.Function1) references[3] /* udf */).apply(filter_arg_0)); /* 150 */ } catch (Throwable e) { /* 151 */ throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError( /* 152 */ "SimpleSuite$$Lambda$839/2056566350", "int", "array<int>", e); /* 153 */ } /* 154 */ /* 155 */ boolean filter_isNull_0 = filter_result_0 == null; /* 156 */ ArrayData filter_value_0 = null; /* 157 */ if (!filter_isNull_0) { /* 158 */ filter_value_0 = filter_result_0; /* 159 */ } /* 160 */ /* 161 */ // end of common subexpressions /* 162 */ /* 163 */ // RequiredVariables /* 164 */ /* 165 */ // end of RequiredVariables /* 166 */ // ev code /* 167 */ boolean filter_isNull_4 = false; /* 168 */ /* 169 */ int filter_value_4 = filter_isNull_0 ? 
-1 : /* 170 */ (filter_value_0).numElements(); /* 171 */ /* 172 */ boolean filter_value_3 = false; /* 173 */ filter_value_3 = filter_value_4 > 0; /* 174 */ // end of ev code /* 175 */ if (!filter_value_3) continue; /* 176 */ // common subexpressions /* 177 */ /* 178 */ // end of common subexpressions /* 179 */ /* 180 */ // RequiredVariables /* 181 */ /* 182 */ // end of RequiredVariables /* 183 */ // ev code /* 184 */ boolean filter_value_7 = !filter_isNull_0; /* 185 */ // end of ev code /* 186 */ if (!filter_value_7) continue; /* 187 */ /* 188 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* numOutputRows */).add(1); /* 189 */ /* 190 */ project_doConsume_0(range_value_0); /* 191 */ /* 192 */ } while(false); /* 193 */ /* 194 */ if (shouldStop()) { /* 195 */ range_nextIndex_0 = range_value_0 + 1L; /* 196 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1); /* 197 */ range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1); /* 198 */ return; /* 199 */ } /* 200 */ /* 201 */ } /* 202 */ range_nextIndex_0 = range_batchEnd_0; /* 203 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0); /* 204 */ range_inputMetrics_0.incRecordsRead(range_localEnd_0); /* 205 */ range_taskContext_0.killTaskIfInterrupted(); /* 206 */ } /* 207 */ } /* 208 */ /* 284 */ } ``` From the generated code, we found that `fib` executes twice (L50 and L149). If `fib` were executed only once in the `WholeStageCodegen`, it could improve performance when `fib` is an expensive UDF. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
