GitHub user kiszk opened a pull request:

    https://github.com/apache/spark/pull/17122

    [SPARK-19786][SQL] Facilitate loop optimizations in a JIT compiler regarding range()

    ## What changes were proposed in this pull request?
    
    This PR improves the performance of operations on `range()` by changing the Java code generated by Catalyst. It is inspired by the [blog article](https://databricks.com/blog/2017/02/16/processing-trillion-rows-per-second-single-machine-can-nested-loop-joins-fast.html).
    
    This PR changes the generated code in the following two ways:
    1. Replace a while-loop over long instance variables with a for-loop over an int local variable.
    2. Suppress generation of the `shouldStop()` check when it is unnecessary (e.g. when `append()` is not generated inside the loop).
    
    These changes make it easier for the JIT compiler to optimize the loop, because it receives simpler Java code. In the benchmark below, the best time improves by 7.6x (from 1349 ms to 177 ms).
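    The two changes can be illustrated outside Spark with a minimal sketch. The class `LoopShapes`, its methods, and the always-false `shouldStop()` stand-in are hypothetical. The loops only mirror the shape of the generated code: a `long`, field-based while-loop with a per-row `shouldStop()` check, versus an `int`-indexed for-loop with no check. Whether a given JIT actually unrolls or otherwise optimizes the second form is not guaranteed.

    ```java
    // Hypothetical model of the two loop shapes, not Spark's generated code.
    public class LoopShapes {
        // Old shape: loop state lives in long instance fields, like
        // range_number / range_batchEnd in the generated code.
        private long number;
        private long batchEnd;
        private long sum;

        // Stand-in for BufferedRowIterator.shouldStop(). Nothing is ever
        // buffered here, so it always returns false.
        private boolean shouldStop() {
            return false;
        }

        // while-loop over long fields, with a shouldStop() check per row.
        public long sumWhileLong(long start, long end) {
            number = start;
            batchEnd = end;
            sum = 0L;
            while (number != batchEnd) {
                long value = number;
                number += 1L;
                sum += value;
                if (shouldStop()) return sum;
            }
            return sum;
        }

        // New shape: an int local trip count, a long value derived from the
        // index, and no shouldStop() call inside the loop body.
        public long sumForInt(long start, long end) {
            long acc = 0L;
            int localEnd = (int) (end - start); // assumes end - start fits in an int
            for (int localIdx = 0; localIdx < localEnd; localIdx++) {
                long value = (long) localIdx + start;
                acc += value;
            }
            return acc;
        }

        public static void main(String[] args) {
            LoopShapes s = new LoopShapes();
            System.out.println(s.sumWhileLong(0L, 1000L)); // 499500
            System.out.println(s.sumForInt(0L, 1000L));    // 499500
        }
    }
    ```

    Both methods compute the same result. The design point is that HotSpot's C2 (as of JDK 8) applies its counted-loop optimizations, such as unrolling and range-check elimination, to loops with an `int` induction variable and a loop-invariant bound. A `long` index kept in fields generally does not qualify.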
    
    Benchmark program:
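    ```scala
    // Assumes an active SparkSession named `sparkSession`; Benchmark is Spark 2.x's
    // internal benchmark utility (org.apache.spark.util.Benchmark).
    import org.apache.spark.util.Benchmark

    val N = 1 << 29
    val iters = 2
    val benchmark = new Benchmark("range.count", N * iters)
    benchmark.addCase(s"with this PR") { i =>
      var n = 0
      var len = 0
      while (n < iters) {
        len += sparkSession.range(N).selectExpr("count(id)").collect.length
        n += 1
      }
    }
    benchmark.run()
    ```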
    
    Performance result without this PR:
    ```
    OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
    Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
    range.count:                             Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------
    w/o this PR                                   1349 / 1356        796.2           1.3       1.0X
    ```
    
    Performance result with this PR:
    ```
    OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
    Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
    range.count:                             Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------
    with this PR                                   177 /  271       6065.3           0.2       1.0X
    ```
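    As a sanity check, the derived columns can be recomputed from the raw inputs. This sketch assumes the rate is rows per run (N * iters = 2^30) divided by the best time. Small differences from the tables, such as 796.0 vs. 796.2 M/s, come from the best times being rounded to whole milliseconds. The class and method names are hypothetical.

    ```java
    // Recompute Rate(M/s), Per Row(ns), and the speedup from the best times above.
    public class RangeCountRates {
        // N = 1 << 29 rows per query, iters = 2 queries per measured run.
        public static final long ROWS_PER_RUN = (1L << 29) * 2;

        // Millions of rows per second for a best time given in milliseconds.
        public static double rateMPerSec(double bestMs) {
            return ROWS_PER_RUN / (bestMs / 1000.0) / 1e6;
        }

        // Nanoseconds spent per row for a best time given in milliseconds.
        public static double perRowNs(double bestMs) {
            return bestMs * 1e6 / ROWS_PER_RUN;
        }

        public static void main(String[] args) {
            double before = 1349.0; // best time w/o this PR (ms)
            double after = 177.0;   // best time with this PR (ms)
            System.out.printf("w/o this PR : %.1f M/s, %.1f ns/row%n", rateMPerSec(before), perRowNs(before));
            System.out.printf("with this PR: %.1f M/s, %.1f ns/row%n", rateMPerSec(after), perRowNs(after));
            System.out.printf("speedup     : %.1fx%n", before / after);
        }
    }
    ```

    This prints about 796.0 M/s and 1.3 ns/row without the PR, 6066.3 M/s and 0.2 ns/row with it, and a 7.6x speedup.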
    
    Here is a comparison between the generated code without and with this PR. Only the method `agg_doAggregateWithoutKey` is changed.
    
    Generated code without this PR
    ```java
    
    /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
    /* 006 */   private Object[] references;
    /* 007 */   private scala.collection.Iterator[] inputs;
    /* 008 */   private boolean agg_initAgg;
    /* 009 */   private boolean agg_bufIsNull;
    /* 010 */   private long agg_bufValue;
    /* 011 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows;
    /* 012 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows;
    /* 013 */   private boolean range_initRange;
    /* 014 */   private long range_number;
    /* 015 */   private TaskContext range_taskContext;
    /* 016 */   private InputMetrics range_inputMetrics;
    /* 017 */   private long range_batchEnd;
    /* 018 */   private long range_numElementsTodo;
    /* 019 */   private scala.collection.Iterator range_input;
    /* 020 */   private UnsafeRow range_result;
    /* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
    /* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter;
    /* 023 */   private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows;
    /* 024 */   private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime;
    /* 025 */   private UnsafeRow agg_result;
    /* 026 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
    /* 027 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
    /* 028 */
    /* 029 */   public GeneratedIterator(Object[] references) {
    /* 030 */     this.references = references;
    /* 031 */   }
    /* 032 */
    /* 033 */   public void init(int index, scala.collection.Iterator[] inputs) {
    /* 034 */     partitionIndex = index;
    /* 035 */     this.inputs = inputs;
    /* 036 */     agg_initAgg = false;
    /* 037 */
    /* 038 */     this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
    /* 039 */     this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
    /* 040 */     range_initRange = false;
    /* 041 */     range_number = 0L;
    /* 042 */     range_taskContext = TaskContext.get();
    /* 043 */     range_inputMetrics = range_taskContext.taskMetrics().inputMetrics();
    /* 044 */     range_batchEnd = 0;
    /* 045 */     range_numElementsTodo = 0L;
    /* 046 */     range_input = inputs[0];
    /* 047 */     range_result = new UnsafeRow(1);
    /* 048 */     this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
    /* 049 */     this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1);
    /* 050 */     this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
    /* 051 */     this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
    /* 052 */     agg_result = new UnsafeRow(1);
    /* 053 */     this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
    /* 054 */     this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
    /* 055 */
    /* 056 */   }
    /* 057 */
    /* 058 */   private void agg_doAggregateWithoutKey() throws java.io.IOException {
    /* 059 */     // initialize aggregation buffer
    /* 060 */     agg_bufIsNull = false;
    /* 061 */     agg_bufValue = 0L;
    /* 062 */
    /* 063 */     // initialize Range
    /* 064 */     if (!range_initRange) {
    /* 065 */       range_initRange = true;
    /* 066 */       initRange(partitionIndex);
    /* 067 */     }
    /* 068 */
    /* 069 */     while (true) {
    /* 070 */       while (range_number != range_batchEnd) {
    /* 071 */         long range_value = range_number;
    /* 072 */         range_number += 1L;
    /* 073 */
    /* 074 */         // do aggregate
    /* 075 */         // common sub-expressions
    /* 076 */
    /* 077 */         // evaluate aggregate function
    /* 078 */         boolean agg_isNull1 = false;
    /* 079 */
    /* 080 */         long agg_value1 = -1L;
    /* 081 */         agg_value1 = agg_bufValue + 1L;
    /* 082 */         // update aggregation buffer
    /* 083 */         agg_bufIsNull = false;
    /* 084 */         agg_bufValue = agg_value1;
    /* 085 */
    /* 086 */         if (shouldStop()) return;
    /* 087 */       }
    /* 088 */
    /* 089 */       if (range_taskContext.isInterrupted()) {
    /* 090 */         throw new TaskKilledException();
    /* 091 */       }
    /* 092 */
    /* 093 */       long range_nextBatchTodo;
    /* 094 */       if (range_numElementsTodo > 1000L) {
    /* 095 */         range_nextBatchTodo = 1000L;
    /* 096 */         range_numElementsTodo -= 1000L;
    /* 097 */       } else {
    /* 098 */         range_nextBatchTodo = range_numElementsTodo;
    /* 099 */         range_numElementsTodo = 0;
    /* 100 */         if (range_nextBatchTodo == 0) break;
    /* 101 */       }
    /* 102 */       range_numOutputRows.add(range_nextBatchTodo);
    /* 103 */       range_inputMetrics.incRecordsRead(range_nextBatchTodo);
    /* 104 */
    /* 105 */       range_batchEnd += range_nextBatchTodo * 1L;
    /* 106 */     }
    /* 107 */
    /* 108 */   }
    /* 109 */
    /* 110 */   private void initRange(int idx) {
    /* 111 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
    /* 112 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
    /* 113 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L);
    /* 114 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
    /* 115 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
    /* 116 */     long partitionEnd;
    /* 117 */
    /* 118 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
    /* 119 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
    /* 120 */       range_number = Long.MAX_VALUE;
    /* 121 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
    /* 122 */       range_number = Long.MIN_VALUE;
    /* 123 */     } else {
    /* 124 */       range_number = st.longValue();
    /* 125 */     }
    /* 126 */     range_batchEnd = range_number;
    /* 127 */
    /* 128 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
    /* 129 */     .multiply(step).add(start);
    /* 130 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
    /* 131 */       partitionEnd = Long.MAX_VALUE;
    /* 132 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
    /* 133 */       partitionEnd = Long.MIN_VALUE;
    /* 134 */     } else {
    /* 135 */       partitionEnd = end.longValue();
    /* 136 */     }
    /* 137 */
    /* 138 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
    /* 139 */       java.math.BigInteger.valueOf(range_number));
    /* 140 */     range_numElementsTodo  = startToEnd.divide(step).longValue();
    /* 141 */     if (range_numElementsTodo < 0) {
    /* 142 */       range_numElementsTodo = 0;
    /* 143 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
    /* 144 */       range_numElementsTodo++;
    /* 145 */     }
    /* 146 */   }
    /* 147 */
    /* 148 */   protected void processNext() throws java.io.IOException {
    /* 149 */     while (!agg_initAgg) {
    /* 150 */       agg_initAgg = true;
    /* 151 */       long agg_beforeAgg = System.nanoTime();
    /* 152 */       agg_doAggregateWithoutKey();
    /* 153 */       agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000);
    /* 154 */
    /* 155 */       // output the result
    /* 156 */
    /* 157 */       agg_numOutputRows.add(1);
    /* 158 */       agg_rowWriter.zeroOutNullBytes();
    /* 159 */
    /* 160 */       if (agg_bufIsNull) {
    /* 161 */         agg_rowWriter.setNullAt(0);
    /* 162 */       } else {
    /* 163 */         agg_rowWriter.write(0, agg_bufValue);
    /* 164 */       }
    /* 165 */       append(agg_result);
    /* 166 */     }
    /* 167 */   }
    /* 168 */ }
    ```
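    `initRange()`, which is the same in both versions, splits the range across partitions using `BigInteger` arithmetic so that intermediate products cannot overflow. The standalone sketch below reproduces that arithmetic for the parameters in the generated code (`numSlice = 2`, `numElement = 10000`, `step = 1`, `start = 0`). The class and method names are hypothetical, and the sketch covers only the partition bounds and the element count, not metrics or batching.

    ```java
    import java.math.BigInteger;

    // Hypothetical re-implementation of the bounds arithmetic in the generated initRange().
    public class RangeSlices {
        // Clamp a BigInteger into the long range, as the generated code does.
        static long clampToLong(BigInteger v) {
            if (v.compareTo(BigInteger.valueOf(Long.MAX_VALUE)) > 0) return Long.MAX_VALUE;
            if (v.compareTo(BigInteger.valueOf(Long.MIN_VALUE)) < 0) return Long.MIN_VALUE;
            return v.longValue();
        }

        /** Returns {partitionStart, partitionEnd, numElementsTodo} for slice idx. */
        public static long[] bounds(int idx, long numSlices, long numElements, long step, long start) {
            BigInteger index = BigInteger.valueOf(idx);
            BigInteger slices = BigInteger.valueOf(numSlices);
            BigInteger elems = BigInteger.valueOf(numElements);
            BigInteger st = BigInteger.valueOf(step);
            BigInteger s0 = BigInteger.valueOf(start);

            // Slice idx covers [idx * n / k * step + start, (idx + 1) * n / k * step + start).
            long partStart = clampToLong(index.multiply(elems).divide(slices).multiply(st).add(s0));
            long partEnd = clampToLong(
                index.add(BigInteger.ONE).multiply(elems).divide(slices).multiply(st).add(s0));

            // Count of values partStart, partStart + step, ... before partEnd (ceiling division).
            BigInteger startToEnd = BigInteger.valueOf(partEnd).subtract(BigInteger.valueOf(partStart));
            long todo = startToEnd.divide(st).longValue();
            if (todo < 0) {
                todo = 0;
            } else if (startToEnd.remainder(st).signum() != 0) {
                todo++;
            }
            return new long[] {partStart, partEnd, todo};
        }

        public static void main(String[] args) {
            for (int i = 0; i < 2; i++) {
                long[] b = bounds(i, 2L, 10000L, 1L, 0L);
                System.out.println("slice " + i + ": [" + b[0] + ", " + b[1] + "), " + b[2] + " elements");
            }
        }
    }
    ```

    With these parameters, slice 0 covers [0, 5000) and slice 1 covers [5000, 10000), with 5000 elements each.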
    
    Generated code with this PR
    ```java
    /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
    /* 006 */   private Object[] references;
    /* 007 */   private scala.collection.Iterator[] inputs;
    /* 008 */   private boolean agg_initAgg;
    /* 009 */   private boolean agg_bufIsNull;
    /* 010 */   private long agg_bufValue;
    /* 011 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows;
    /* 012 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows;
    /* 013 */   private boolean range_initRange;
    /* 014 */   private long range_number;
    /* 015 */   private TaskContext range_taskContext;
    /* 016 */   private InputMetrics range_inputMetrics;
    /* 017 */   private long range_batchEnd;
    /* 018 */   private long range_numElementsTodo;
    /* 019 */   private scala.collection.Iterator range_input;
    /* 020 */   private UnsafeRow range_result;
    /* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
    /* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter;
    /* 023 */   private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows;
    /* 024 */   private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime;
    /* 025 */   private UnsafeRow agg_result;
    /* 026 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
    /* 027 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
    /* 028 */
    /* 029 */   public GeneratedIterator(Object[] references) {
    /* 030 */     this.references = references;
    /* 031 */   }
    /* 032 */
    /* 033 */   public void init(int index, scala.collection.Iterator[] inputs) {
    /* 034 */     partitionIndex = index;
    /* 035 */     this.inputs = inputs;
    /* 036 */     agg_initAgg = false;
    /* 037 */
    /* 038 */     this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
    /* 039 */     this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
    /* 040 */     range_initRange = false;
    /* 041 */     range_number = 0L;
    /* 042 */     range_taskContext = TaskContext.get();
    /* 043 */     range_inputMetrics = range_taskContext.taskMetrics().inputMetrics();
    /* 044 */     range_batchEnd = 0;
    /* 045 */     range_numElementsTodo = 0L;
    /* 046 */     range_input = inputs[0];
    /* 047 */     range_result = new UnsafeRow(1);
    /* 048 */     this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
    /* 049 */     this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1);
    /* 050 */     this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
    /* 051 */     this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
    /* 052 */     agg_result = new UnsafeRow(1);
    /* 053 */     this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
    /* 054 */     this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
    /* 055 */
    /* 056 */   }
    /* 057 */
    /* 058 */   private void agg_doAggregateWithoutKey() throws java.io.IOException {
    /* 059 */     // initialize aggregation buffer
    /* 060 */     agg_bufIsNull = false;
    /* 061 */     agg_bufValue = 0L;
    /* 062 */
    /* 063 */     // initialize Range
    /* 064 */     if (!range_initRange) {
    /* 065 */       range_initRange = true;
    /* 066 */       initRange(partitionIndex);
    /* 067 */     }
    /* 068 */
    /* 069 */     while (true) {
    /* 070 */       long range_range = range_batchEnd - range_number;
    /* 071 */       if (range_range != 0L) {
    /* 072 */         int range_localEnd = (int)(range_range / 1L);
    /* 073 */         for (int range_localIdx = 0; range_localIdx < range_localEnd; range_localIdx++) {
    /* 074 */           long range_value = ((long)range_localIdx * 1L) + range_number;
    /* 075 */
    /* 076 */           // do aggregate
    /* 077 */           // common sub-expressions
    /* 078 */
    /* 079 */           // evaluate aggregate function
    /* 080 */           boolean agg_isNull1 = false;
    /* 081 */
    /* 082 */           long agg_value1 = -1L;
    /* 083 */           agg_value1 = agg_bufValue + 1L;
    /* 084 */           // update aggregation buffer
    /* 085 */           agg_bufIsNull = false;
    /* 086 */           agg_bufValue = agg_value1;
    /* 087 */
    /* 088 */           // shouldStop check is eliminated
    /* 089 */         }
    /* 090 */         range_number = range_batchEnd;
    /* 091 */       }
    /* 092 */
    /* 093 */       if (range_taskContext.isInterrupted()) {
    /* 094 */         throw new TaskKilledException();
    /* 095 */       }
    /* 096 */
    /* 097 */       long range_nextBatchTodo;
    /* 098 */       if (range_numElementsTodo > 1000L) {
    /* 099 */         range_nextBatchTodo = 1000L;
    /* 100 */         range_numElementsTodo -= 1000L;
    /* 101 */       } else {
    /* 102 */         range_nextBatchTodo = range_numElementsTodo;
    /* 103 */         range_numElementsTodo = 0;
    /* 104 */         if (range_nextBatchTodo == 0) break;
    /* 105 */       }
    /* 106 */       range_numOutputRows.add(range_nextBatchTodo);
    /* 107 */       range_inputMetrics.incRecordsRead(range_nextBatchTodo);
    /* 108 */
    /* 109 */       range_batchEnd += range_nextBatchTodo * 1L;
    /* 110 */     }
    /* 111 */
    /* 112 */   }
    /* 113 */
    /* 114 */   private void initRange(int idx) {
    /* 115 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
    /* 116 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
    /* 117 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L);
    /* 118 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
    /* 119 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
    /* 120 */     long partitionEnd;
    /* 121 */
    /* 122 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
    /* 123 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
    /* 124 */       range_number = Long.MAX_VALUE;
    /* 125 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
    /* 126 */       range_number = Long.MIN_VALUE;
    /* 127 */     } else {
    /* 128 */       range_number = st.longValue();
    /* 129 */     }
    /* 130 */     range_batchEnd = range_number;
    /* 131 */
    /* 132 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
    /* 133 */     .multiply(step).add(start);
    /* 134 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
    /* 135 */       partitionEnd = Long.MAX_VALUE;
    /* 136 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
    /* 137 */       partitionEnd = Long.MIN_VALUE;
    /* 138 */     } else {
    /* 139 */       partitionEnd = end.longValue();
    /* 140 */     }
    /* 141 */
    /* 142 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
    /* 143 */       java.math.BigInteger.valueOf(range_number));
    /* 144 */     range_numElementsTodo  = startToEnd.divide(step).longValue();
    /* 145 */     if (range_numElementsTodo < 0) {
    /* 146 */       range_numElementsTodo = 0;
    /* 147 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
    /* 148 */       range_numElementsTodo++;
    /* 149 */     }
    /* 150 */   }
    /* 151 */
    /* 152 */   protected void processNext() throws java.io.IOException {
    /* 153 */     while (!agg_initAgg) {
    /* 154 */       agg_initAgg = true;
    /* 155 */       long agg_beforeAgg = System.nanoTime();
    /* 156 */       agg_doAggregateWithoutKey();
    /* 157 */       agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000);
    /* 158 */
    /* 159 */       // output the result
    /* 160 */
    /* 161 */       agg_numOutputRows.add(1);
    /* 162 */       agg_rowWriter.zeroOutNullBytes();
    /* 163 */
    /* 164 */       if (agg_bufIsNull) {
    /* 165 */         agg_rowWriter.setNullAt(0);
    /* 166 */       } else {
    /* 167 */         agg_rowWriter.write(0, agg_bufValue);
    /* 168 */       }
    /* 169 */       append(agg_result);
    /* 170 */     }
    /* 171 */   }
    /* 172 */ }
    ```
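    The control flow of the new `agg_doAggregateWithoutKey()` can be modeled without Spark. Values are released in batches of at most 1000. Each batch is consumed by an int-indexed for-loop whose trip count is `(batchEnd - number) / step`, and `number` is written back once per batch. The sketch below is hypothetical: the class name, the explicit `step` parameter, and returning both the count and the sum are illustration choices. It omits the metrics updates and the `TaskKilledException` check.

    ```java
    // Hypothetical model of the batched range loop; not Spark code.
    public class BatchedRangeLoop {
        static final long BATCH = 1000L; // batch size used by the generated code

        /** Returns {count, sum} over numElements values start, start + step, ... */
        public static long[] countAndSum(long start, long step, long numElements) {
            long number = start;      // like range_number: first value not yet consumed
            long batchEnd = start;    // like range_batchEnd: end of the released batch
            long todo = numElements;  // like range_numElementsTodo
            long count = 0L;          // like agg_bufValue for count(id)
            long sum = 0L;

            while (true) {
                long range = batchEnd - number;
                if (range != 0L) {
                    // At most BATCH iterations, so the trip count fits in an int.
                    int localEnd = (int) (range / step);
                    for (int localIdx = 0; localIdx < localEnd; localIdx++) {
                        long value = (long) localIdx * step + number;
                        count += 1L;
                        sum += value;
                    }
                    number = batchEnd; // written back once per batch, not once per row
                }

                // Release the next batch, or stop when nothing is left.
                long nextBatchTodo;
                if (todo > BATCH) {
                    nextBatchTodo = BATCH;
                    todo -= BATCH;
                } else {
                    nextBatchTodo = todo;
                    todo = 0L;
                    if (nextBatchTodo == 0L) break;
                }
                batchEnd += nextBatchTodo * step;
            }
            return new long[] {count, sum};
        }

        public static void main(String[] args) {
            long[] r = countAndSum(0L, 1L, 2500L);
            System.out.println(r[0] + " " + r[1]); // 2500 3123750
        }
    }
    ```

    Compared with the old while-loop, the per-row body no longer updates the `range_number` field or calls `shouldStop()`. Both of those now happen at most once per batch.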
    
    Part of the `shouldStop()` suppression was originally developed by @inouehrs.
    
    ## How was this patch tested?
    
    Added new tests to `DataFrameRangeSuite`.

You can merge this pull request into a Git repository by running:

    $ git pull https://github.com/kiszk/spark SPARK-19786

Alternatively you can review and apply these changes as the patch at:

    https://github.com/apache/spark/pull/17122.patch

To close this pull request, make a commit to your master/trunk branch
with (at least) the following in the commit message:

    This closes #17122
    
----
commit 47f405c32ffac9b0356050c0d6bbb8c0ea5e0f51
Author: Kazuaki Ishizaki
Date:   2017-03-01T14:01:43Z

    initial commit

----

