Github user sameeragarwal commented on the pull request:
https://github.com/apache/spark/pull/12345#issuecomment-209190135
The generated code for a query of the form `sqlContext.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect()` looks like:
```java
/* 009 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 010 */   private Object[] references;
/* 011 */   private boolean agg_initAgg;
/* 012 */   private agg_GeneratedAggregateHashMap agg_aggregateHashMap;
/* 013 */   private java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row> agg_genMapIter;
/* 014 */   private org.apache.spark.sql.execution.aggregate.TungstenAggregate agg_plan;
/* 015 */   private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap;
/* 016 */   private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter;
/* 017 */   private org.apache.spark.unsafe.KVIterator agg_mapIter;
/* 018 */   private org.apache.spark.sql.execution.metric.LongSQLMetric range_numOutputRows;
/* 019 */   private org.apache.spark.sql.execution.metric.LongSQLMetricValue range_metricValue;
/* 020 */   private boolean range_initRange;
/* 021 */   private long range_partitionEnd;
/* 022 */   private long range_number;
/* 023 */   private boolean range_overflow;
/* 024 */   private scala.collection.Iterator range_input;
/* 025 */   private UnsafeRow range_result;
/* 026 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
/* 027 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter;
/* 028 */   private UnsafeRow project_result;
/* 029 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder;
/* 030 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter;
/* 031 */   private UnsafeRow agg_result;
/* 032 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
/* 033 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
/* 034 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowJoiner agg_unsafeRowJoiner;
/* 035 */   private org.apache.spark.sql.execution.metric.LongSQLMetric wholestagecodegen_numOutputRows;
/* 036 */   private org.apache.spark.sql.execution.metric.LongSQLMetricValue wholestagecodegen_metricValue;
/* 037 */   private UnsafeRow wholestagecodegen_result;
/* 038 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder wholestagecodegen_holder;
/* 039 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter wholestagecodegen_rowWriter;
/* 040 */
/* 041 */   public GeneratedIterator(Object[] references) {
/* 042 */     this.references = references;
/* 043 */   }
/* 044 */
/* 045 */   public void init(int index, scala.collection.Iterator inputs[]) {
/* 046 */     partitionIndex = index;
/* 047 */     agg_initAgg = false;
/* 048 */     agg_aggregateHashMap = new agg_GeneratedAggregateHashMap();
/* 049 */
/* 050 */     this.agg_plan = (org.apache.spark.sql.execution.aggregate.TungstenAggregate) references[0];
/* 051 */
/* 052 */     this.range_numOutputRows = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[1];
/* 053 */     range_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) range_numOutputRows.localValue();
/* 054 */     range_initRange = false;
/* 055 */     range_partitionEnd = 0L;
/* 056 */     range_number = 0L;
/* 057 */     range_overflow = false;
/* 058 */     range_input = inputs[0];
/* 059 */     range_result = new UnsafeRow(1);
/* 060 */     this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
/* 061 */     this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1);
/* 062 */     project_result = new UnsafeRow(1);
/* 063 */     this.project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 0);
/* 064 */     this.project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 1);
/* 065 */     agg_result = new UnsafeRow(1);
/* 066 */     this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
/* 067 */     this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
/* 068 */     agg_unsafeRowJoiner = agg_plan.createUnsafeJoiner();
/* 069 */     this.wholestagecodegen_numOutputRows = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[2];
/* 070 */     wholestagecodegen_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) wholestagecodegen_numOutputRows.localValue();
/* 071 */     wholestagecodegen_result = new UnsafeRow(2);
/* 072 */     this.wholestagecodegen_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(wholestagecodegen_result, 0);
/* 073 */     this.wholestagecodegen_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(wholestagecodegen_holder, 2);
/* 074 */   }
/* 075 */
/* 076 */   public class agg_GeneratedAggregateHashMap {
/* 077 */     public org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
/* 078 */     public org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch;
/* 079 */     private int[] buckets;
/* 080 */     private int numBuckets;
/* 081 */     private int maxSteps;
/* 082 */     private int numRows = 0;
/* 083 */     private org.apache.spark.sql.types.StructType schema =
/* 084 */       new org.apache.spark.sql.types.StructType()
/* 085 */         .add("k", org.apache.spark.sql.types.DataTypes.LongType)
/* 086 */         .add("sum", org.apache.spark.sql.types.DataTypes.LongType);
/* 087 */
/* 088 */     private org.apache.spark.sql.types.StructType aggregateBufferSchema =
/* 089 */
/* 090 */       new org.apache.spark.sql.types.StructType()
/* 091 */         .add("sum", org.apache.spark.sql.types.DataTypes.LongType);
/* 092 */
/* 093 */     public agg_GeneratedAggregateHashMap() {
/* 094 */       // TODO: These should be generated based on the schema
/* 095 */       int DEFAULT_CAPACITY = 1 << 16;
/* 096 */       double DEFAULT_LOAD_FACTOR = 0.25;
/* 097 */       int DEFAULT_MAX_STEPS = 2;
/* 098 */       assert (DEFAULT_CAPACITY > 0 && ((DEFAULT_CAPACITY & (DEFAULT_CAPACITY - 1)) == 0));
/* 099 */       this.maxSteps = DEFAULT_MAX_STEPS;
/* 100 */       numBuckets = (int) (DEFAULT_CAPACITY / DEFAULT_LOAD_FACTOR);
/* 101 */       batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema,
/* 102 */         org.apache.spark.memory.MemoryMode.ON_HEAP, DEFAULT_CAPACITY);
/* 103 */       aggregateBufferBatch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(
/* 104 */         aggregateBufferSchema, org.apache.spark.memory.MemoryMode.ON_HEAP, DEFAULT_CAPACITY);
/* 105 */       for (int i = 0 ; i < aggregateBufferBatch.numCols(); i++) {
/* 106 */         aggregateBufferBatch.setColumn(i, batch.column(i+1));
/* 107 */       }
/* 108 */       buckets = new int[numBuckets];
/* 109 */       java.util.Arrays.fill(buckets, -1);
/* 110 */     }
/* 111 */
/* 112 */     public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(long agg_key) {
/* 113 */       long h = hash(agg_key);
/* 114 */       int step = 0;
/* 115 */       int idx = (int) h & (numBuckets - 1);
/* 116 */       while (step < maxSteps) {
/* 117 */         // Return bucket index if it's either an empty slot or already contains the key
/* 118 */         if (buckets[idx] == -1) {
/* 119 */           batch.column(0).putLong(numRows, agg_key);
/* 120 */           batch.column(1).putLong(numRows, 0);
/* 121 */           buckets[idx] = numRows++;
/* 122 */           batch.setNumRows(numRows);
/* 123 */           return aggregateBufferBatch.getRow(buckets[idx]);
/* 124 */         } else if (equals(idx, agg_key)) {
/* 125 */           return aggregateBufferBatch.getRow(buckets[idx]);
/* 126 */         }
/* 127 */         idx = (idx + 1) & (numBuckets - 1);
/* 128 */         step++;
/* 129 */       }
/* 130 */       // Didn't find it
/* 131 */       return null;
/* 132 */     }
/* 133 */
/* 134 */     private boolean equals(int idx, long agg_key) {
/* 135 */       return batch.column(0).getLong(buckets[idx]) == agg_key;
/* 136 */     }
/* 137 */
/* 138 */     // TODO: Improve this hash function
/* 139 */     private long hash(long agg_key) {
/* 140 */       return agg_key;
/* 141 */     }
/* 142 */
/* 143 */   }
/* 144 */
/* 145 */   private void agg_doAggregateWithKeys() throws java.io.IOException {
/* 146 */     agg_hashMap = agg_plan.createHashMap();
/* 147 */
/* 148 */     /*** PRODUCE: Project [(id#224L & 65535) AS k#227L] */
/* 149 */
/* 150 */     /*** PRODUCE: Range 0, 1, 1, 20971520, [id#224L] */
/* 151 */
/* 152 */     // initialize Range
/* 153 */     if (!range_initRange) {
/* 154 */       range_initRange = true;
/* 155 */       initRange(partitionIndex);
/* 156 */     }
/* 157 */
/* 158 */     while (!range_overflow && range_number < range_partitionEnd) {
/* 159 */       long range_value = range_number;
/* 160 */       range_number += 1L;
/* 161 */       if (range_number < range_value ^ 1L < 0) {
/* 162 */         range_overflow = true;
/* 163 */       }
/* 164 */
/* 165 */       /*** CONSUME: Project [(id#224L & 65535) AS k#227L] */
/* 166 */
/* 167 */       /*** CONSUME: TungstenAggregate(key=[k#227L], functions=[(sum(k#227L),mode=Partial,isDistinct=false)], output=[k#227L,sum#235L]) */
/* 168 */       /* (input[0, bigint] & 65535) */
/* 169 */       long project_value = -1L;
/* 170 */       project_value = range_value & 65535L;
/* 171 */
/* 172 */       UnsafeRow agg_aggBuffer = null;
/* 173 */       org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row agg_aggregateRow = null;
/* 174 */
/* 175 */       agg_aggregateRow =
/* 176 */         agg_aggregateHashMap.findOrInsert(project_value);
/* 177 */
/* 178 */       if (agg_aggregateRow == null) {
/* 179 */         // generate grouping key
/* 180 */         agg_rowWriter.write(0, project_value);
/* 181 */         /* hash(input[0, bigint], 42) */
/* 182 */         int agg_value3 = 42;
/* 183 */
/* 184 */         agg_value3 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(project_value, agg_value3);
/* 185 */         if (true) {
/* 186 */           // try to get the buffer from hash map
/* 187 */           agg_aggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value3);
/* 188 */         }
/* 189 */         if (agg_aggBuffer == null) {
/* 190 */           if (agg_sorter == null) {
/* 191 */             agg_sorter = agg_hashMap.destructAndCreateExternalSorter();
/* 192 */           } else {
/* 193 */             agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter());
/* 194 */           }
/* 195 */
/* 196 */           // the hash map had be spilled, it should have enough memory now,
/* 197 */           // try to allocate buffer again.
/* 198 */           agg_aggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value3);
/* 199 */           if (agg_aggBuffer == null) {
/* 200 */             // failed to allocate the first page
/* 201 */             throw new OutOfMemoryError("No enough memory for aggregation");
/* 202 */           }
/* 203 */         }
/* 204 */       }
/* 205 */
/* 206 */       if (agg_aggregateRow != null) {
/* 207 */         // evaluate aggregate function
/* 208 */         /* (coalesce(input[0, bigint], cast(0 as bigint)) + cast(input[1, bigint] as bigint)) */
/* 209 */         /* coalesce(input[0, bigint], cast(0 as bigint)) */
/* 210 */         /* input[0, bigint] */
/* 211 */         boolean agg_isNull6 = agg_aggregateRow.isNullAt(0);
/* 212 */         long agg_value7 = agg_isNull6 ? -1L : (agg_aggregateRow.getLong(0));
/* 213 */         boolean agg_isNull5 = agg_isNull6;
/* 214 */         long agg_value6 = agg_value7;
/* 215 */
/* 216 */         if (agg_isNull5) {
/* 217 */           /* cast(0 as bigint) */
/* 218 */           boolean agg_isNull7 = false;
/* 219 */           long agg_value8 = -1L;
/* 220 */           if (!false) {
/* 221 */             agg_value8 = (long) 0;
/* 222 */           }
/* 223 */           if (!agg_isNull7) {
/* 224 */             agg_isNull5 = false;
/* 225 */             agg_value6 = agg_value8;
/* 226 */           }
/* 227 */         }
/* 228 */         /* cast(input[1, bigint] as bigint) */
/* 229 */         boolean agg_isNull9 = false;
/* 230 */         long agg_value10 = -1L;
/* 231 */         if (!false) {
/* 232 */           agg_value10 = project_value;
/* 233 */         }
/* 234 */         long agg_value5 = -1L;
/* 235 */         agg_value5 = agg_value6 + agg_value10;
/* 236 */         // update aggregate row
/* 237 */         agg_aggregateRow.setLong(0, agg_value5);
/* 238 */       } else {
/* 239 */         // evaluate aggregate function
/* 240 */         /* (coalesce(input[0, bigint], cast(0 as bigint)) + cast(input[1, bigint] as bigint)) */
/* 241 */         /* coalesce(input[0, bigint], cast(0 as bigint)) */
/* 242 */         /* input[0, bigint] */
/* 243 */         boolean agg_isNull13 = agg_aggBuffer.isNullAt(0);
/* 244 */         long agg_value14 = agg_isNull13 ? -1L : (agg_aggBuffer.getLong(0));
/* 245 */         boolean agg_isNull12 = agg_isNull13;
/* 246 */         long agg_value13 = agg_value14;
/* 247 */
/* 248 */         if (agg_isNull12) {
/* 249 */           /* cast(0 as bigint) */
/* 250 */           boolean agg_isNull14 = false;
/* 251 */           long agg_value15 = -1L;
/* 252 */           if (!false) {
/* 253 */             agg_value15 = (long) 0;
/* 254 */           }
/* 255 */           if (!agg_isNull14) {
/* 256 */             agg_isNull12 = false;
/* 257 */             agg_value13 = agg_value15;
/* 258 */           }
/* 259 */         }
/* 260 */         /* cast(input[1, bigint] as bigint) */
/* 261 */         boolean agg_isNull16 = false;
/* 262 */         long agg_value17 = -1L;
/* 263 */         if (!false) {
/* 264 */           agg_value17 = project_value;
/* 265 */         }
/* 266 */         long agg_value12 = -1L;
/* 267 */         agg_value12 = agg_value13 + agg_value17;
/* 268 */         // update aggregate buffer
/* 269 */         agg_aggBuffer.setLong(0, agg_value12);
/* 270 */       }
/* 271 */
/* 272 */       if (shouldStop()) return;
/* 273 */     }
/* 274 */
/* 275 */     agg_genMapIter = agg_aggregateHashMap.batch.rowIterator();
/* 276 */
/* 277 */     agg_mapIter = agg_plan.finishAggregate(agg_hashMap, agg_sorter);
/* 278 */   }
/* 279 */
/* 280 */   private void initRange(int idx) {
/* 281 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 282 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(1L);
/* 283 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(20971520L);
/* 284 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 285 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 286 */
/* 287 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 288 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 289 */       range_number = Long.MAX_VALUE;
/* 290 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 291 */       range_number = Long.MIN_VALUE;
/* 292 */     } else {
/* 293 */       range_number = st.longValue();
/* 294 */     }
/* 295 */
/* 296 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 297 */       .multiply(step).add(start);
/* 298 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 299 */       range_partitionEnd = Long.MAX_VALUE;
/* 300 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 301 */       range_partitionEnd = Long.MIN_VALUE;
/* 302 */     } else {
/* 303 */       range_partitionEnd = end.longValue();
/* 304 */     }
/* 305 */
/* 306 */     range_metricValue.add((range_partitionEnd - range_number) / 1L);
/* 307 */   }
/* 308 */
/* 309 */   protected void processNext() throws java.io.IOException {
/* 310 */     /*** PRODUCE: TungstenAggregate(key=[k#227L], functions=[(sum(k#227L),mode=Partial,isDistinct=false)], output=[k#227L,sum#235L]) */
/* 311 */
/* 312 */     if (!agg_initAgg) {
/* 313 */       agg_initAgg = true;
/* 314 */       agg_doAggregateWithKeys();
/* 315 */     }
/* 316 */
/* 317 */     // output the result
/* 318 */
/* 319 */     while (agg_genMapIter.hasNext()) {
/* 320 */       wholestagecodegen_metricValue.add(1);
/* 321 */       org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row wholestagecodegen_aggregateHashMapRow =
/* 322 */         (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
/* 323 */         agg_genMapIter.next();
/* 324 */
/* 325 */       wholestagecodegen_rowWriter.zeroOutNullBytes();
/* 326 */
/* 327 */       /* input[0, bigint] */
/* 328 */       long wholestagecodegen_value = wholestagecodegen_aggregateHashMapRow.getLong(0);
/* 329 */       wholestagecodegen_rowWriter.write(0, wholestagecodegen_value);
/* 330 */
/* 331 */       /* input[1, bigint] */
/* 332 */       boolean wholestagecodegen_isNull1 = wholestagecodegen_aggregateHashMapRow.isNullAt(1);
/* 333 */       long wholestagecodegen_value1 = wholestagecodegen_isNull1 ? -1L : (wholestagecodegen_aggregateHashMapRow.getLong(1));
/* 334 */       if (wholestagecodegen_isNull1) {
/* 335 */         wholestagecodegen_rowWriter.setNullAt(1);
/* 336 */       } else {
/* 337 */         wholestagecodegen_rowWriter.write(1, wholestagecodegen_value1);
/* 338 */       }
/* 339 */
/* 340 */       /*** CONSUME: WholeStageCodegen */
/* 341 */
/* 342 */       append(wholestagecodegen_result);
/* 343 */
/* 344 */       if (shouldStop()) return;
/* 345 */     }
/* 346 */
/* 347 */     agg_aggregateHashMap.batch.close();
/* 348 */
/* 349 */     while (agg_mapIter.next()) {
/* 350 */       wholestagecodegen_metricValue.add(1);
/* 351 */       UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey();
/* 352 */       UnsafeRow agg_aggBuffer1 = (UnsafeRow) agg_mapIter.getValue();
/* 353 */
/* 354 */       UnsafeRow agg_resultRow = agg_unsafeRowJoiner.join(agg_aggKey, agg_aggBuffer1);
/* 355 */
/* 356 */       /*** CONSUME: WholeStageCodegen */
/* 357 */
/* 358 */       append(agg_resultRow);
/* 359 */
/* 360 */       if (shouldStop()) return;
/* 361 */     }
/* 362 */
/* 363 */     agg_mapIter.close();
/* 364 */     if (agg_sorter == null) {
/* 365 */       agg_hashMap.free();
/* 366 */     }
/* 367 */   }
/* 368 */ }
```
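For reference, here is a minimal sketch of how one could dump this kind of whole-stage generated code for the same query shape. It assumes a Spark 2.x-era build where the `org.apache.spark.sql.execution.debug` helpers (including `debugCodegen()`) are available; it is illustrative only and not part of this patch:

```scala
// Sketch: print the Java source produced by whole-stage codegen for the
// aggregation above. Assumes the debug helpers in
// org.apache.spark.sql.execution.debug exist in this build.
import org.apache.spark.sql.execution.debug._

val df = sqlContext.range(20971520L)      // same N as the Range operator in the plan above
  .selectExpr("(id & 65535) as k")
  .groupBy("k")
  .sum()

df.debugCodegen()                         // dumps the GeneratedIterator source per codegen stage
df.collect()
```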