Github user maropu commented on the issue:

    https://github.com/apache/spark/pull/17164
  
    ```
    import org.apache.spark.sql.execution.debug._
    spark.conf.set("spark.sql.aggregate.preferSortAggregate", "true")
    val df = spark.range(10).selectExpr("id % 2 AS key", "rand() AS value")
    df.groupBy().count.debugCodegen
    
    Found 2 WholeStageCodegen subtrees.
    == Subtree 1 / 2 ==
    *SortAggregate(key=[], functions=[partial_count(1)], output=[count#51L])
    +- *Project
       +- *Range (0, 10, step=1, splits=Some(4))
    
    Generated code:
    /* 001 */ public Object generate(Object[] references) {
    /* 002 */   return new GeneratedIterator(references);
    /* 003 */ }
    /* 004 */
    /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
    /* 006 */   private Object[] references;
    /* 007 */   private scala.collection.Iterator[] inputs;
    /* 008 */   private boolean sagg_initAgg;
    /* 009 */   private boolean sagg_bufIsNull;
    /* 010 */   private long sagg_bufValue;
    /* 011 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows;
    /* 012 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows;
    /* 013 */   private boolean range_initRange;
    /* 014 */   private long range_number;
    /* 015 */   private TaskContext range_taskContext;
    /* 016 */   private InputMetrics range_inputMetrics;
    /* 017 */   private long range_batchEnd;
    /* 018 */   private long range_numElementsTodo;
    /* 019 */   private scala.collection.Iterator range_input;
    /* 020 */   private UnsafeRow range_result;
    /* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
    /* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter;
    /* 023 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_numOutputRows;
    /* 024 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_aggTime;
    /* 025 */   private UnsafeRow sagg_result;
    /* 026 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder;
    /* 027 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter;
    /* 028 */
    /* 029 */   public GeneratedIterator(Object[] references) {
    /* 030 */     this.references = references;
    /* 031 */   }
    /* 032 */
    /* 033 */   public void init(int index, scala.collection.Iterator[] inputs) {
    /* 034 */     partitionIndex = index;
    /* 035 */     this.inputs = inputs;
    /* 036 */     sagg_initAgg = false;
    /* 037 */
    /* 038 */     this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
    /* 039 */     this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
    /* 040 */     range_initRange = false;
    /* 041 */     range_number = 0L;
    /* 042 */     range_taskContext = TaskContext.get();
    /* 043 */     range_inputMetrics = range_taskContext.taskMetrics().inputMetrics();
    /* 044 */     range_batchEnd = 0;
    /* 045 */     range_numElementsTodo = 0L;
    /* 046 */     range_input = inputs[0];
    /* 047 */     range_result = new UnsafeRow(1);
    /* 048 */     this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
    /* 049 */     this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1);
    /* 050 */     this.sagg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
    /* 051 */     this.sagg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
    /* 052 */     sagg_result = new UnsafeRow(1);
    /* 053 */     this.sagg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result, 0);
    /* 054 */     this.sagg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder, 1);
    /* 055 */
    /* 056 */   }
    /* 057 */
    /* 058 */   private void sagg_doAggregateWithoutKey() throws java.io.IOException {
    /* 059 */     // initialize aggregation buffer
    /* 060 */     sagg_bufIsNull = false;
    /* 061 */     sagg_bufValue = 0L;
    /* 062 */
    /* 063 */     // initialize Range
    /* 064 */     if (!range_initRange) {
    /* 065 */       range_initRange = true;
    /* 066 */       initRange(partitionIndex);
    /* 067 */     }
    /* 068 */
    /* 069 */     while (true) {
    /* 070 */       while (range_number != range_batchEnd) {
    /* 071 */         long range_value = range_number;
    /* 072 */         range_number += 1L;
    /* 073 */
    /* 074 */         // do aggregate
    /* 075 */         // common sub-expressions
    /* 076 */
    /* 077 */         // evaluate aggregate function
    /* 078 */         boolean sagg_isNull1 = false;
    /* 079 */
    /* 080 */         long sagg_value1 = -1L;
    /* 081 */         sagg_value1 = sagg_bufValue + 1L;
    /* 082 */         // update aggregation buffer
    /* 083 */         sagg_bufIsNull = false;
    /* 084 */         sagg_bufValue = sagg_value1;
    /* 085 */
    /* 086 */         if (shouldStop()) return;
    /* 087 */       }
    /* 088 */
    /* 089 */       if (range_taskContext.isInterrupted()) {
    /* 090 */         throw new TaskKilledException();
    /* 091 */       }
    /* 092 */
    /* 093 */       long range_nextBatchTodo;
    /* 094 */       if (range_numElementsTodo > 1000L) {
    /* 095 */         range_nextBatchTodo = 1000L;
    /* 096 */         range_numElementsTodo -= 1000L;
    /* 097 */       } else {
    /* 098 */         range_nextBatchTodo = range_numElementsTodo;
    /* 099 */         range_numElementsTodo = 0;
    /* 100 */         if (range_nextBatchTodo == 0) break;
    /* 101 */       }
    /* 102 */       range_numOutputRows.add(range_nextBatchTodo);
    /* 103 */       range_inputMetrics.incRecordsRead(range_nextBatchTodo);
    /* 104 */
    /* 105 */       range_batchEnd += range_nextBatchTodo * 1L;
    /* 106 */     }
    /* 107 */
    /* 108 */   }
    /* 109 */
    /* 110 */   private void initRange(int idx) {
    /* 111 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
    /* 112 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(4L);
    /* 113 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(10L);
    /* 114 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
    /* 115 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
    /* 116 */     long partitionEnd;
    /* 117 */
    /* 118 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
    /* 119 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
    /* 120 */       range_number = Long.MAX_VALUE;
    /* 121 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
    /* 122 */       range_number = Long.MIN_VALUE;
    /* 123 */     } else {
    /* 124 */       range_number = st.longValue();
    /* 125 */     }
    /* 126 */     range_batchEnd = range_number;
    /* 127 */
    /* 128 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
    /* 129 */     .multiply(step).add(start);
    /* 130 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
    /* 131 */       partitionEnd = Long.MAX_VALUE;
    /* 132 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
    /* 133 */       partitionEnd = Long.MIN_VALUE;
    /* 134 */     } else {
    /* 135 */       partitionEnd = end.longValue();
    /* 136 */     }
    /* 137 */
    /* 138 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
    /* 139 */       java.math.BigInteger.valueOf(range_number));
    /* 140 */     range_numElementsTodo  = startToEnd.divide(step).longValue();
    /* 141 */     if (range_numElementsTodo < 0) {
    /* 142 */       range_numElementsTodo = 0;
    /* 143 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
    /* 144 */       range_numElementsTodo++;
    /* 145 */     }
    /* 146 */   }
    /* 147 */
    /* 148 */   protected void processNext() throws java.io.IOException {
    /* 149 */     while (!sagg_initAgg) {
    /* 150 */       sagg_initAgg = true;
    /* 151 */       long sagg_beforeAgg = System.nanoTime();
    /* 152 */       sagg_doAggregateWithoutKey();
    /* 153 */       sagg_aggTime.add((System.nanoTime() - sagg_beforeAgg) / 1000000);
    /* 154 */
    /* 155 */       // output the result
    /* 156 */
    /* 157 */       sagg_numOutputRows.add(1);
    /* 158 */       sagg_rowWriter.zeroOutNullBytes();
    /* 159 */
    /* 160 */       if (sagg_bufIsNull) {
    /* 161 */         sagg_rowWriter.setNullAt(0);
    /* 162 */       } else {
    /* 163 */         sagg_rowWriter.write(0, sagg_bufValue);
    /* 164 */       }
    /* 165 */       append(sagg_result);
    /* 166 */     }
    /* 167 */   }
    /* 168 */ }
    
    == Subtree 2 / 2 ==
    *SortAggregate(key=[], functions=[count(1)], output=[count#47L])
    +- Exchange SinglePartition
       +- *SortAggregate(key=[], functions=[partial_count(1)], output=[count#51L])
          +- *Project
             +- *Range (0, 10, step=1, splits=Some(4))
    
    Generated code:
    /* 001 */ public Object generate(Object[] references) {
    /* 002 */   return new GeneratedIterator(references);
    /* 003 */ }
    /* 004 */
    /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
    /* 006 */   private Object[] references;
    /* 007 */   private scala.collection.Iterator[] inputs;
    /* 008 */   private boolean sagg_initAgg;
    /* 009 */   private boolean sagg_bufIsNull;
    /* 010 */   private long sagg_bufValue;
    /* 011 */   private scala.collection.Iterator inputadapter_input;
    /* 012 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_numOutputRows;
    /* 013 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_aggTime;
    /* 014 */   private UnsafeRow sagg_result;
    /* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder;
    /* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter;
    /* 017 */
    /* 018 */   public GeneratedIterator(Object[] references) {
    /* 019 */     this.references = references;
    /* 020 */   }
    /* 021 */
    /* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
    /* 023 */     partitionIndex = index;
    /* 024 */     this.inputs = inputs;
    /* 025 */     sagg_initAgg = false;
    /* 026 */
    /* 027 */     inputadapter_input = inputs[0];
    /* 028 */     this.sagg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
    /* 029 */     this.sagg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
    /* 030 */     sagg_result = new UnsafeRow(1);
    /* 031 */     this.sagg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result, 0);
    /* 032 */     this.sagg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder, 1);
    /* 033 */
    /* 034 */   }
    /* 035 */
    /* 036 */   private void sagg_doAggregateWithoutKey() throws java.io.IOException {
    /* 037 */     // initialize aggregation buffer
    /* 038 */     sagg_bufIsNull = false;
    /* 039 */     sagg_bufValue = 0L;
    /* 040 */
    /* 041 */     while (inputadapter_input.hasNext() && !stopEarly()) {
    /* 042 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
    /* 043 */       // do aggregate
    /* 044 */       // common sub-expressions
    /* 045 */
    /* 046 */       // evaluate aggregate function
    /* 047 */       boolean sagg_isNull3 = false;
    /* 048 */
    /* 049 */       long inputadapter_value = inputadapter_row.getLong(0);
    /* 050 */       long sagg_value3 = -1L;
    /* 051 */       sagg_value3 = sagg_bufValue + inputadapter_value;
    /* 052 */       // update aggregation buffer
    /* 053 */       sagg_bufIsNull = false;
    /* 054 */       sagg_bufValue = sagg_value3;
    /* 055 */       if (shouldStop()) return;
    /* 056 */     }
    /* 057 */
    /* 058 */   }
    /* 059 */
    /* 060 */   protected void processNext() throws java.io.IOException {
    /* 061 */     while (!sagg_initAgg) {
    /* 062 */       sagg_initAgg = true;
    /* 063 */       long sagg_beforeAgg = System.nanoTime();
    /* 064 */       sagg_doAggregateWithoutKey();
    /* 065 */       sagg_aggTime.add((System.nanoTime() - sagg_beforeAgg) / 1000000);
    /* 066 */
    /* 067 */       // output the result
    /* 068 */
    /* 069 */       sagg_numOutputRows.add(1);
    /* 070 */       sagg_rowWriter.zeroOutNullBytes();
    /* 071 */
    /* 072 */       if (sagg_bufIsNull) {
    /* 073 */         sagg_rowWriter.setNullAt(0);
    /* 074 */       } else {
    /* 075 */         sagg_rowWriter.write(0, sagg_bufValue);
    /* 076 */       }
    /* 077 */       append(sagg_result);
    /* 078 */     }
    /* 079 */   }
    /* 080 */ }
    ```
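
    For comparison, a minimal sketch of the same query with the flag turned
    off (assuming the `spark.sql.aggregate.preferSortAggregate` key proposed
    in this PR can be toggled per session; `df2` is just a fresh copy of the
    query above): the planner should then fall back to the default
    hash-based path, and the dump should show `*HashAggregate` subtrees
    instead:

    ```
    // hypothetical comparison run against the default code path
    spark.conf.set("spark.sql.aggregate.preferSortAggregate", "false")
    val df2 = spark.range(10).selectExpr("id % 2 AS key", "rand() AS value")
    df2.groupBy().count.debugCodegen  // expect *HashAggregate in the output
    ```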

