GitHub user kiszk opened a pull request:
https://github.com/apache/spark/pull/17122
[SPARK-19786][SQL] Facilitate loop optimizations in a JIT compiler regarding range()
## What changes were proposed in this pull request?
This PR improves the performance of operations on `range()` by changing the Java
code generated by Catalyst. It is inspired by the [blog
article](https://databricks.com/blog/2017/02/16/processing-trillion-rows-per-second-single-machine-can-nested-loop-joins-fast.html).
This PR changes the generated code in the following two ways:
1. Replace a while-loop over long instance variables with a for-loop over int
local variables
2. Suppress generation of the `shouldStop()` check when it is unnecessary
(e.g. when no `append()` call is generated)
These changes feed simpler Java code to the JIT compiler and thereby facilitate
its loop optimizations. Performance improves by about 7.6x.
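Distilled, the transformation changes the hot loop from the first shape below to the second. This is a minimal, self-contained sketch (illustrative names only, not the exact generated code, which appears in full later in this description):
```java
// Minimal sketch of the two loop shapes (illustrative, not the exact codegen output).
public class LoopShapes {
  // Pre-PR shape: loop driven by long instance fields, with a per-row
  // shouldStop() call; the field loads and the call hinder JIT loop optimizations.
  static long number = 0L;
  static long batchEnd = 1000L;
  static long count = 0L;

  static boolean shouldStop() { return false; }

  static void whileLoopOverLongFields() {
    while (number != batchEnd) {
      long value = number;        // the row value
      number += 1L;
      count += 1L;                // consume the row
      if (shouldStop()) return;
    }
  }

  // Post-PR shape: a counted for-loop over int locals with a call-free body;
  // the shouldStop() check is suppressed when no append() is generated.
  static void forLoopOverIntLocals() {
    int localEnd = (int) (batchEnd - number);
    for (int localIdx = 0; localIdx < localEnd; localIdx++) {
      long value = (long) localIdx + number;  // the row value
      count += 1L;                            // consume the row
    }
    number = batchEnd;
  }

  public static void main(String[] args) {
    whileLoopOverLongFields();
    number = 0L;
    forLoopOverIntLocals();
    System.out.println(count);  // 2000: each shape processed 1000 rows
  }
}
```
The second shape is a canonical counted loop over an int induction variable, which is the form HotSpot's JIT recognizes for unrolling and other loop optimizations.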
Benchmark program:
```scala
val N = 1 << 29
val iters = 2
val benchmark = new Benchmark("range.count", N * iters)
benchmark.addCase(s"with this PR") { i =>
  var n = 0
  var len = 0
  while (n < iters) {
    len += sparkSession.range(N).selectExpr("count(id)").collect.length
    n += 1
  }
}
benchmark.run
```
Performance result without this PR:
```
OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz

range.count:                        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
w/o this PR                              1349 / 1356        796.2           1.3       1.0X
```
Performance result with this PR:
```
OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz

range.count:                        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
with this PR                              177 /  271       6065.3           0.2       1.0X
```
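As a sanity check on these tables: the benchmark processes N * iters = 2^30 rows per case, so the Rate, Per Row, and speedup figures follow directly from the best times (a quick arithmetic check; the class name is mine):
```java
// Quick sanity check of the reported throughput numbers.
public class RateCheck {
  public static void main(String[] args) {
    long rows = (1L << 29) * 2;    // N * iters = 2^30 rows per benchmark case
    double bestMsWith = 177.0;     // best time with the PR (ms)
    double bestMsWithout = 1349.0; // best time without the PR (ms)
    System.out.printf("with PR:    %.1f M rows/s, %.2f ns/row%n",
        rows / bestMsWith / 1000.0, bestMsWith * 1e6 / rows);
    System.out.printf("without PR: %.1f M rows/s, %.2f ns/row%n",
        rows / bestMsWithout / 1000.0, bestMsWithout * 1e6 / rows);
    System.out.printf("speedup: %.1fx%n", bestMsWithout / bestMsWith);
  }
}
```
This reproduces roughly 6065 M/s and 0.2 ns/row with the PR, 796 M/s and 1.3 ns/row without it, and the 7.6x speedup.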
Here is a comparison of the code generated without and with this PR. Only the
method `agg_doAggregateWithoutKey` changes.
Generated code without this PR
```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private boolean agg_initAgg;
/* 009 */   private boolean agg_bufIsNull;
/* 010 */   private long agg_bufValue;
/* 011 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows;
/* 012 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows;
/* 013 */   private boolean range_initRange;
/* 014 */   private long range_number;
/* 015 */   private TaskContext range_taskContext;
/* 016 */   private InputMetrics range_inputMetrics;
/* 017 */   private long range_batchEnd;
/* 018 */   private long range_numElementsTodo;
/* 019 */   private scala.collection.Iterator range_input;
/* 020 */   private UnsafeRow range_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter;
/* 023 */   private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows;
/* 024 */   private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime;
/* 025 */   private UnsafeRow agg_result;
/* 026 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
/* 027 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
/* 028 */
/* 029 */   public GeneratedIterator(Object[] references) {
/* 030 */     this.references = references;
/* 031 */   }
/* 032 */
/* 033 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 034 */     partitionIndex = index;
/* 035 */     this.inputs = inputs;
/* 036 */     agg_initAgg = false;
/* 037 */
/* 038 */     this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 039 */     this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 040 */     range_initRange = false;
/* 041 */     range_number = 0L;
/* 042 */     range_taskContext = TaskContext.get();
/* 043 */     range_inputMetrics = range_taskContext.taskMetrics().inputMetrics();
/* 044 */     range_batchEnd = 0;
/* 045 */     range_numElementsTodo = 0L;
/* 046 */     range_input = inputs[0];
/* 047 */     range_result = new UnsafeRow(1);
/* 048 */     this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
/* 049 */     this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1);
/* 050 */     this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
/* 051 */     this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
/* 052 */     agg_result = new UnsafeRow(1);
/* 053 */     this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
/* 054 */     this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
/* 055 */
/* 056 */   }
/* 057 */
/* 058 */   private void agg_doAggregateWithoutKey() throws java.io.IOException {
/* 059 */     // initialize aggregation buffer
/* 060 */     agg_bufIsNull = false;
/* 061 */     agg_bufValue = 0L;
/* 062 */
/* 063 */     // initialize Range
/* 064 */     if (!range_initRange) {
/* 065 */       range_initRange = true;
/* 066 */       initRange(partitionIndex);
/* 067 */     }
/* 068 */
/* 069 */     while (true) {
/* 070 */       while (range_number != range_batchEnd) {
/* 071 */         long range_value = range_number;
/* 072 */         range_number += 1L;
/* 073 */
/* 074 */         // do aggregate
/* 075 */         // common sub-expressions
/* 076 */
/* 077 */         // evaluate aggregate function
/* 078 */         boolean agg_isNull1 = false;
/* 079 */
/* 080 */         long agg_value1 = -1L;
/* 081 */         agg_value1 = agg_bufValue + 1L;
/* 082 */         // update aggregation buffer
/* 083 */         agg_bufIsNull = false;
/* 084 */         agg_bufValue = agg_value1;
/* 085 */
/* 086 */         if (shouldStop()) return;
/* 087 */       }
/* 088 */
/* 089 */       if (range_taskContext.isInterrupted()) {
/* 090 */         throw new TaskKilledException();
/* 091 */       }
/* 092 */
/* 093 */       long range_nextBatchTodo;
/* 094 */       if (range_numElementsTodo > 1000L) {
/* 095 */         range_nextBatchTodo = 1000L;
/* 096 */         range_numElementsTodo -= 1000L;
/* 097 */       } else {
/* 098 */         range_nextBatchTodo = range_numElementsTodo;
/* 099 */         range_numElementsTodo = 0;
/* 100 */         if (range_nextBatchTodo == 0) break;
/* 101 */       }
/* 102 */       range_numOutputRows.add(range_nextBatchTodo);
/* 103 */       range_inputMetrics.incRecordsRead(range_nextBatchTodo);
/* 104 */
/* 105 */       range_batchEnd += range_nextBatchTodo * 1L;
/* 106 */     }
/* 107 */
/* 108 */   }
/* 109 */
/* 110 */   private void initRange(int idx) {
/* 111 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 112 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 113 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L);
/* 114 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 115 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 116 */     long partitionEnd;
/* 117 */
/* 118 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 119 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 120 */       range_number = Long.MAX_VALUE;
/* 121 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 122 */       range_number = Long.MIN_VALUE;
/* 123 */     } else {
/* 124 */       range_number = st.longValue();
/* 125 */     }
/* 126 */     range_batchEnd = range_number;
/* 127 */
/* 128 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 129 */       .multiply(step).add(start);
/* 130 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 131 */       partitionEnd = Long.MAX_VALUE;
/* 132 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 133 */       partitionEnd = Long.MIN_VALUE;
/* 134 */     } else {
/* 135 */       partitionEnd = end.longValue();
/* 136 */     }
/* 137 */
/* 138 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 139 */       java.math.BigInteger.valueOf(range_number));
/* 140 */     range_numElementsTodo = startToEnd.divide(step).longValue();
/* 141 */     if (range_numElementsTodo < 0) {
/* 142 */       range_numElementsTodo = 0;
/* 143 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 144 */       range_numElementsTodo++;
/* 145 */     }
/* 146 */   }
/* 147 */
/* 148 */   protected void processNext() throws java.io.IOException {
/* 149 */     while (!agg_initAgg) {
/* 150 */       agg_initAgg = true;
/* 151 */       long agg_beforeAgg = System.nanoTime();
/* 152 */       agg_doAggregateWithoutKey();
/* 153 */       agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000);
/* 154 */
/* 155 */       // output the result
/* 156 */
/* 157 */       agg_numOutputRows.add(1);
/* 158 */       agg_rowWriter.zeroOutNullBytes();
/* 159 */
/* 160 */       if (agg_bufIsNull) {
/* 161 */         agg_rowWriter.setNullAt(0);
/* 162 */       } else {
/* 163 */         agg_rowWriter.write(0, agg_bufValue);
/* 164 */       }
/* 165 */       append(agg_result);
/* 166 */     }
/* 167 */   }
/* 168 */ }
```
Generated code with this PR
```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private boolean agg_initAgg;
/* 009 */   private boolean agg_bufIsNull;
/* 010 */   private long agg_bufValue;
/* 011 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows;
/* 012 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows;
/* 013 */   private boolean range_initRange;
/* 014 */   private long range_number;
/* 015 */   private TaskContext range_taskContext;
/* 016 */   private InputMetrics range_inputMetrics;
/* 017 */   private long range_batchEnd;
/* 018 */   private long range_numElementsTodo;
/* 019 */   private scala.collection.Iterator range_input;
/* 020 */   private UnsafeRow range_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter;
/* 023 */   private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows;
/* 024 */   private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime;
/* 025 */   private UnsafeRow agg_result;
/* 026 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
/* 027 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
/* 028 */
/* 029 */   public GeneratedIterator(Object[] references) {
/* 030 */     this.references = references;
/* 031 */   }
/* 032 */
/* 033 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 034 */     partitionIndex = index;
/* 035 */     this.inputs = inputs;
/* 036 */     agg_initAgg = false;
/* 037 */
/* 038 */     this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 039 */     this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 040 */     range_initRange = false;
/* 041 */     range_number = 0L;
/* 042 */     range_taskContext = TaskContext.get();
/* 043 */     range_inputMetrics = range_taskContext.taskMetrics().inputMetrics();
/* 044 */     range_batchEnd = 0;
/* 045 */     range_numElementsTodo = 0L;
/* 046 */     range_input = inputs[0];
/* 047 */     range_result = new UnsafeRow(1);
/* 048 */     this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
/* 049 */     this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1);
/* 050 */     this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
/* 051 */     this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
/* 052 */     agg_result = new UnsafeRow(1);
/* 053 */     this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
/* 054 */     this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
/* 055 */
/* 056 */   }
/* 057 */
/* 058 */   private void agg_doAggregateWithoutKey() throws java.io.IOException {
/* 059 */     // initialize aggregation buffer
/* 060 */     agg_bufIsNull = false;
/* 061 */     agg_bufValue = 0L;
/* 062 */
/* 063 */     // initialize Range
/* 064 */     if (!range_initRange) {
/* 065 */       range_initRange = true;
/* 066 */       initRange(partitionIndex);
/* 067 */     }
/* 068 */
/* 069 */     while (true) {
/* 070 */       long range_range = range_batchEnd - range_number;
/* 071 */       if (range_range != 0L) {
/* 072 */         int range_localEnd = (int)(range_range / 1L);
/* 073 */         for (int range_localIdx = 0; range_localIdx < range_localEnd; range_localIdx++) {
/* 074 */           long range_value = ((long)range_localIdx * 1L) + range_number;
/* 075 */
/* 076 */           // do aggregate
/* 077 */           // common sub-expressions
/* 078 */
/* 079 */           // evaluate aggregate function
/* 080 */           boolean agg_isNull1 = false;
/* 081 */
/* 082 */           long agg_value1 = -1L;
/* 083 */           agg_value1 = agg_bufValue + 1L;
/* 084 */           // update aggregation buffer
/* 085 */           agg_bufIsNull = false;
/* 086 */           agg_bufValue = agg_value1;
/* 087 */
/* 088 */           // shouldStop check is eliminated
/* 089 */         }
/* 090 */         range_number = range_batchEnd;
/* 091 */       }
/* 092 */
/* 093 */       if (range_taskContext.isInterrupted()) {
/* 094 */         throw new TaskKilledException();
/* 095 */       }
/* 096 */
/* 097 */       long range_nextBatchTodo;
/* 098 */       if (range_numElementsTodo > 1000L) {
/* 099 */         range_nextBatchTodo = 1000L;
/* 100 */         range_numElementsTodo -= 1000L;
/* 101 */       } else {
/* 102 */         range_nextBatchTodo = range_numElementsTodo;
/* 103 */         range_numElementsTodo = 0;
/* 104 */         if (range_nextBatchTodo == 0) break;
/* 105 */       }
/* 106 */       range_numOutputRows.add(range_nextBatchTodo);
/* 107 */       range_inputMetrics.incRecordsRead(range_nextBatchTodo);
/* 108 */
/* 109 */       range_batchEnd += range_nextBatchTodo * 1L;
/* 110 */     }
/* 111 */
/* 112 */   }
/* 113 */
/* 114 */   private void initRange(int idx) {
/* 115 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 116 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 117 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L);
/* 118 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 119 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 120 */     long partitionEnd;
/* 121 */
/* 122 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 123 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 124 */       range_number = Long.MAX_VALUE;
/* 125 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 126 */       range_number = Long.MIN_VALUE;
/* 127 */     } else {
/* 128 */       range_number = st.longValue();
/* 129 */     }
/* 130 */     range_batchEnd = range_number;
/* 131 */
/* 132 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 133 */       .multiply(step).add(start);
/* 134 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 135 */       partitionEnd = Long.MAX_VALUE;
/* 136 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 137 */       partitionEnd = Long.MIN_VALUE;
/* 138 */     } else {
/* 139 */       partitionEnd = end.longValue();
/* 140 */     }
/* 141 */
/* 142 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 143 */       java.math.BigInteger.valueOf(range_number));
/* 144 */     range_numElementsTodo = startToEnd.divide(step).longValue();
/* 145 */     if (range_numElementsTodo < 0) {
/* 146 */       range_numElementsTodo = 0;
/* 147 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 148 */       range_numElementsTodo++;
/* 149 */     }
/* 150 */   }
/* 151 */
/* 152 */   protected void processNext() throws java.io.IOException {
/* 153 */     while (!agg_initAgg) {
/* 154 */       agg_initAgg = true;
/* 155 */       long agg_beforeAgg = System.nanoTime();
/* 156 */       agg_doAggregateWithoutKey();
/* 157 */       agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000);
/* 158 */
/* 159 */       // output the result
/* 160 */
/* 161 */       agg_numOutputRows.add(1);
/* 162 */       agg_rowWriter.zeroOutNullBytes();
/* 163 */
/* 164 */       if (agg_bufIsNull) {
/* 165 */         agg_rowWriter.setNullAt(0);
/* 166 */       } else {
/* 167 */         agg_rowWriter.write(0, agg_bufValue);
/* 168 */       }
/* 169 */       append(agg_result);
/* 170 */     }
/* 171 */   }
/* 172 */ }
```
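Why the new shape is faster: the rewritten loop is a canonical counted loop (int induction variable, loop-invariant local bound, call-free body), which lets HotSpot eliminate per-iteration safepoint polls and apply loop optimizations such as unrolling. For `count(id)` the body only increments an accumulator, so the JIT can effectively collapse a whole batch to a single addition. A hedged sketch of that reduction (illustrative only, not actual emitted code):
```java
// Illustrative: roughly what the JIT can reduce the counted loop above to,
// because the body is call-free and the trip count is a loop-invariant int.
public class ReducedLoop {
  public static void main(String[] args) {
    long aggBufValue = 0L;
    long rangeNumber = 0L;
    long rangeBatchEnd = 1000L;

    // Shape of the generated loop (as in agg_doAggregateWithoutKey above):
    int localEnd = (int) (rangeBatchEnd - rangeNumber);
    for (int i = 0; i < localEnd; i++) {
      aggBufValue += 1L;
    }

    // Reduction the JIT can perform once the loop is in this canonical form:
    long reduced = (long) localEnd;

    System.out.println(aggBufValue == reduced);  // prints true
  }
}
```
The pre-PR while-loop cannot be reduced this way: its bound and counter are long instance fields and its body calls `shouldStop()`, which forces a field store, a poll, and a possible early exit on every row.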
Part of the `shouldStop()` suppression was originally developed by @inouehrs.
## How was this patch tested?
Added new tests to `DataFrameRangeSuite`.
You can merge this pull request into a Git repository by running:
$ git pull https://github.com/kiszk/spark SPARK-19786
Alternatively you can review and apply these changes as the patch at:
https://github.com/apache/spark/pull/17122.patch
To close this pull request, make a commit to your master/trunk branch
with (at least) the following in the commit message:
This closes #17122
----
commit 47f405c32ffac9b0356050c0d6bbb8c0ea5e0f51
Author: Kazuaki Ishizaki <[email protected]>
Date: 2017-03-01T14:01:43Z
initial commit
----