sarutak opened a new pull request #30859:
URL: https://github.com/apache/spark/pull/30859
### What changes were proposed in this pull request?
This PR fixes an issue where `EXPLAIN CODEGEN` does not show the generated
code for subqueries. In the example below, only the outer query's subtree is printed.
```
spark.conf.set("spark.sql.adaptive.enabled", "false")
val df = spark.range(1, 100)
df.createTempView("df")
scala> spark.sql("SELECT (SELECT min(id) AS v FROM df)").explain("CODEGEN")
Found 1 WholeStageCodegen subtrees.
== Subtree 1 / 1 (maxMethodCodeSize:55; maxConstantPoolSize:97(0.15% used); numInnerClasses:0) ==
*(1) Project [Subquery scalar-subquery#3, [id=#24] AS scalarsubquery()#5L]
:  +- Subquery scalar-subquery#3, [id=#24]
:     +- *(2) HashAggregate(keys=[], functions=[min(id#0L)], output=[v#2L])
:        +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#20]
:           +- *(1) HashAggregate(keys=[], functions=[partial_min(id#0L)], output=[min#8L])
:              +- *(1) Range (1, 100, step=1, splits=12)
+- *(1) Scan OneRowRelation[]
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */ private Object[] references;
/* 008 */ private scala.collection.Iterator[] inputs;
/* 009 */ private scala.collection.Iterator rdd_input_0;
/* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] project_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[1];
/* 011 */
/* 012 */ public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 013 */ this.references = references;
/* 014 */ }
/* 015 */
/* 016 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 017 */ partitionIndex = index;
/* 018 */ this.inputs = inputs;
/* 019 */ rdd_input_0 = inputs[0];
/* 020 */ project_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 021 */
/* 022 */ }
/* 023 */
/* 024 */ private void project_doConsume_0() throws java.io.IOException {
/* 025 */ // common sub-expressions
/* 026 */
/* 027 */ project_mutableStateArray_0[0].reset();
/* 028 */
/* 029 */ if (false) {
/* 030 */ project_mutableStateArray_0[0].setNullAt(0);
/* 031 */ } else {
/* 032 */ project_mutableStateArray_0[0].write(0, 1L);
/* 033 */ }
/* 034 */ append((project_mutableStateArray_0[0].getRow()));
/* 035 */
/* 036 */ }
/* 037 */
/* 038 */ protected void processNext() throws java.io.IOException {
/* 039 */ while ( rdd_input_0.hasNext()) {
/* 040 */ InternalRow rdd_row_0 = (InternalRow) rdd_input_0.next();
/* 041 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 042 */ project_doConsume_0();
/* 043 */ if (shouldStop()) return;
/* 044 */ }
/* 045 */ }
/* 046 */
/* 047 */ }
```
After this change, the generated code for subqueries is shown as well.
```
Found 3 WholeStageCodegen subtrees.
== Subtree 1 / 3 (maxMethodCodeSize:282; maxConstantPoolSize:206(0.31% used); numInnerClasses:0) ==
*(1) HashAggregate(keys=[], functions=[partial_min(id#0L)], output=[min#8L])
+- *(1) Range (1, 100, step=1, splits=12)
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */ private Object[] references;
/* 008 */ private scala.collection.Iterator[] inputs;
/* 009 */ private boolean agg_initAgg_0;
/* 010 */ private boolean agg_bufIsNull_0;
/* 011 */ private long agg_bufValue_0;
/* 012 */ private boolean range_initRange_0;
/* 013 */ private long range_nextIndex_0;
/* 014 */ private TaskContext range_taskContext_0;
/* 015 */ private InputMetrics range_inputMetrics_0;
/* 016 */ private long range_batchEnd_0;
/* 017 */ private long range_numElementsTodo_0;
/* 018 */ private boolean agg_agg_isNull_2_0;
/* 019 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
/* 020 */
/* 021 */ public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 022 */ this.references = references;
/* 023 */ }
/* 024 */
/* 025 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 026 */ partitionIndex = index;
/* 027 */ this.inputs = inputs;
/* 028 */
/* 029 */ range_taskContext_0 = TaskContext.get();
/* 030 */ range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 031 */ range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 032 */ range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 033 */ range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 034 */
/* 035 */ }
/* 036 */
/* 037 */ private void agg_doAggregateWithoutKey_0() throws java.io.IOException {
/* 038 */ // initialize aggregation buffer
/* 039 */ agg_bufIsNull_0 = true;
/* 040 */ agg_bufValue_0 = -1L;
/* 041 */
/* 042 */ // initialize Range
/* 043 */ if (!range_initRange_0) {
/* 044 */ range_initRange_0 = true;
/* 045 */ initRange(partitionIndex);
/* 046 */ }
/* 047 */
/* 048 */ while (true) {
/* 049 */ if (range_nextIndex_0 == range_batchEnd_0) {
/* 050 */ long range_nextBatchTodo_0;
/* 051 */ if (range_numElementsTodo_0 > 1000L) {
/* 052 */ range_nextBatchTodo_0 = 1000L;
/* 053 */ range_numElementsTodo_0 -= 1000L;
/* 054 */ } else {
/* 055 */ range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 056 */ range_numElementsTodo_0 = 0;
/* 057 */ if (range_nextBatchTodo_0 == 0) break;
/* 058 */ }
/* 059 */ range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 060 */ }
/* 061 */
/* 062 */ int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 063 */ for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 064 */ long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 065 */
/* 066 */ agg_doConsume_0(range_value_0);
/* 067 */
/* 068 */ // shouldStop check is eliminated
/* 069 */ }
/* 070 */ range_nextIndex_0 = range_batchEnd_0;
/* 071 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 072 */ range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 073 */ range_taskContext_0.killTaskIfInterrupted();
/* 074 */ }
/* 075 */
/* 076 */ }
/* 077 */
/* 078 */ private void initRange(int idx) {
/* 079 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 080 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(12L);
/* 081 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(99L);
/* 082 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 083 */ java.math.BigInteger start = java.math.BigInteger.valueOf(1L);
/* 084 */ long partitionEnd;
/* 085 */
/* 086 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 087 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 088 */ range_nextIndex_0 = Long.MAX_VALUE;
/* 089 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 090 */ range_nextIndex_0 = Long.MIN_VALUE;
/* 091 */ } else {
/* 092 */ range_nextIndex_0 = st.longValue();
/* 093 */ }
/* 094 */ range_batchEnd_0 = range_nextIndex_0;
/* 095 */
/* 096 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 097 */ .multiply(step).add(start);
/* 098 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 099 */ partitionEnd = Long.MAX_VALUE;
/* 100 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 101 */ partitionEnd = Long.MIN_VALUE;
/* 102 */ } else {
/* 103 */ partitionEnd = end.longValue();
/* 104 */ }
/* 105 */
/* 106 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 107 */ java.math.BigInteger.valueOf(range_nextIndex_0));
/* 108 */ range_numElementsTodo_0 = startToEnd.divide(step).longValue();
/* 109 */ if (range_numElementsTodo_0 < 0) {
/* 110 */ range_numElementsTodo_0 = 0;
/* 111 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 112 */ range_numElementsTodo_0++;
/* 113 */ }
/* 114 */ }
/* 115 */
/* 116 */ private void agg_doConsume_0(long agg_expr_0_0) throws java.io.IOException {
/* 117 */ // do aggregate
/* 118 */ // common sub-expressions
/* 119 */
/* 120 */ // evaluate aggregate functions and update aggregation buffers
/* 121 */
/* 122 */ agg_agg_isNull_2_0 = true;
/* 123 */ long agg_value_2 = -1L;
/* 124 */
/* 125 */ if (!agg_bufIsNull_0 && (agg_agg_isNull_2_0 ||
/* 126 */ agg_value_2 > agg_bufValue_0)) {
/* 127 */ agg_agg_isNull_2_0 = false;
/* 128 */ agg_value_2 = agg_bufValue_0;
/* 129 */ }
/* 130 */
/* 131 */ if (!false && (agg_agg_isNull_2_0 ||
/* 132 */ agg_value_2 > agg_expr_0_0)) {
/* 133 */ agg_agg_isNull_2_0 = false;
/* 134 */ agg_value_2 = agg_expr_0_0;
/* 135 */ }
/* 136 */
/* 137 */ agg_bufIsNull_0 = agg_agg_isNull_2_0;
/* 138 */ agg_bufValue_0 = agg_value_2;
/* 139 */
/* 140 */ }
/* 141 */
/* 142 */ protected void processNext() throws java.io.IOException {
/* 143 */ while (!agg_initAgg_0) {
/* 144 */ agg_initAgg_0 = true;
/* 145 */ long agg_beforeAgg_0 = System.nanoTime();
/* 146 */ agg_doAggregateWithoutKey_0();
/* 147 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* aggTime */).add((System.nanoTime() - agg_beforeAgg_0) / 1000000);
/* 148 */
/* 149 */ // output the result
/* 150 */
/* 151 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* numOutputRows */).add(1);
/* 152 */ range_mutableStateArray_0[2].reset();
/* 153 */
/* 154 */ range_mutableStateArray_0[2].zeroOutNullBytes();
/* 155 */
/* 156 */ if (agg_bufIsNull_0) {
/* 157 */ range_mutableStateArray_0[2].setNullAt(0);
/* 158 */ } else {
/* 159 */ range_mutableStateArray_0[2].write(0, agg_bufValue_0);
/* 160 */ }
/* 161 */ append((range_mutableStateArray_0[2].getRow()));
/* 162 */ }
/* 163 */ }
/* 164 */
/* 165 */ }
```
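The difference between the two outputs (1 subtree versus 3) can be pictured as a plan traversal that either skips or follows the subquery plans referenced from a node's expressions. The following is a minimal, self-contained sketch of that idea only. It is not Spark's implementation: `Node`, `codegenRoot`, `collect`, and `examplePlan` are hypothetical names invented for illustration, and the example plan merely mirrors the shape of the query above.

```java
import java.util.ArrayList;
import java.util.List;

public class CodegenSubtreeSketch {

    /** Hypothetical stand-in for a physical plan node; not Spark's SparkPlan. */
    static final class Node {
        final String name;
        final boolean codegenRoot;   // assumed flag: node heads a whole-stage codegen subtree
        final List<Node> children;   // ordinary child plans
        final List<Node> subqueries; // plans referenced from this node's expressions

        Node(String name, boolean codegenRoot, List<Node> children, List<Node> subqueries) {
            this.name = name;
            this.codegenRoot = codegenRoot;
            this.children = children;
            this.subqueries = subqueries;
        }
    }

    /** Collects the names of codegen subtree roots, optionally descending into subqueries. */
    static List<String> collect(Node root, boolean includeSubqueries) {
        List<String> found = new ArrayList<>();
        visit(root, includeSubqueries, found);
        return found;
    }

    private static void visit(Node node, boolean includeSubqueries, List<String> found) {
        if (node.codegenRoot) {
            found.add(node.name);
        }
        for (Node child : node.children) {
            visit(child, includeSubqueries, found);
        }
        if (includeSubqueries) {
            // Without this step, plans that are only reachable through
            // subquery expressions are never visited.
            for (Node sub : node.subqueries) {
                visit(sub, includeSubqueries, found);
            }
        }
    }

    /** Simplified model of the plan for SELECT (SELECT min(id) AS v FROM df). */
    static Node examplePlan() {
        Node range = new Node("Range", false, List.of(), List.of());
        Node partialAgg = new Node("HashAggregate(partial_min)", true, List.of(range), List.of());
        Node exchange = new Node("Exchange", false, List.of(partialAgg), List.of());
        Node finalAgg = new Node("HashAggregate(min)", true, List.of(exchange), List.of());
        Node scan = new Node("Scan OneRowRelation", false, List.of(), List.of());
        return new Node("Project", true, List.of(scan), List.of(finalAgg));
    }

    public static void main(String[] args) {
        Node plan = examplePlan();
        System.out.println(collect(plan, false).size()); // prints 1
        System.out.println(collect(plan, true).size());  // prints 3
    }
}
```

In this model the order of the collected subtrees is an artifact of the sketch and says nothing about the order in which `EXPLAIN CODEGEN` prints them.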
### Why are the changes needed?
For better debuggability.
### Does this PR introduce _any_ user-facing change?
Yes. After this change, users can see the generated code for subqueries with `EXPLAIN CODEGEN`.
### How was this patch tested?
New test.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.