Repository: spark Updated Branches: refs/heads/master 501b71119 -> fcb68e0f5
[SPARK-19786][SQL] Facilitate loop optimizations in a JIT compiler regarding range() ## What changes were proposed in this pull request? This PR improves performance of operations with `range()` by changing Java code generated by Catalyst. This PR is inspired by the [blog article](https://databricks.com/blog/2017/02/16/processing-trillion-rows-per-second-single-machine-can-nested-loop-joins-fast.html). This PR changes generated code in the following two points. 1. Replace a while-loop that uses long instance variables with a for-loop that uses int local variables 2. Suppress generation of the `shouldStop()` call in the loop when it is unnecessary (e.g. `append()` is not generated). These points facilitate compiler optimizations in a JIT compiler by feeding the simplified Java code into the JIT compiler. The performance is improved by 7.6x. Benchmark program: ```java val N = 1 << 29 val iters = 2 val benchmark = new Benchmark("range.count", N * iters) benchmark.addCase(s"with this PR") { i => var n = 0 var len = 0 while (n < iters) { len += sparkSession.range(N).selectExpr("count(id)").collect.length n += 1 } } benchmark.run ``` Performance result without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz range.count: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ w/o this PR 1349 / 1356 796.2 1.3 1.0X ``` Performance result with this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz range.count: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ with this PR 177 / 271 6065.3 0.2 1.0X ``` Here is a comparison between generated code w/o and with this PR. Only the method ```agg_doAggregateWithoutKey``` is changed. 
Generated code without this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private boolean agg_initAgg; /* 009 */ private boolean agg_bufIsNull; /* 010 */ private long agg_bufValue; /* 011 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows; /* 012 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows; /* 013 */ private boolean range_initRange; /* 014 */ private long range_number; /* 015 */ private TaskContext range_taskContext; /* 016 */ private InputMetrics range_inputMetrics; /* 017 */ private long range_batchEnd; /* 018 */ private long range_numElementsTodo; /* 019 */ private scala.collection.Iterator range_input; /* 020 */ private UnsafeRow range_result; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter; /* 023 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows; /* 024 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime; /* 025 */ private UnsafeRow agg_result; /* 026 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder; /* 027 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; /* 028 */ /* 029 */ public GeneratedIterator(Object[] references) { /* 030 */ this.references = references; /* 031 */ } /* 032 */ /* 033 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 034 */ partitionIndex = index; /* 035 */ this.inputs = inputs; /* 036 */ agg_initAgg = false; /* 037 */ /* 038 */ this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0]; /* 039 */ this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) 
references[1]; /* 040 */ range_initRange = false; /* 041 */ range_number = 0L; /* 042 */ range_taskContext = TaskContext.get(); /* 043 */ range_inputMetrics = range_taskContext.taskMetrics().inputMetrics(); /* 044 */ range_batchEnd = 0; /* 045 */ range_numElementsTodo = 0L; /* 046 */ range_input = inputs[0]; /* 047 */ range_result = new UnsafeRow(1); /* 048 */ this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0); /* 049 */ this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1); /* 050 */ this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2]; /* 051 */ this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3]; /* 052 */ agg_result = new UnsafeRow(1); /* 053 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0); /* 054 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1); /* 055 */ /* 056 */ } /* 057 */ /* 058 */ private void agg_doAggregateWithoutKey() throws java.io.IOException { /* 059 */ // initialize aggregation buffer /* 060 */ agg_bufIsNull = false; /* 061 */ agg_bufValue = 0L; /* 062 */ /* 063 */ // initialize Range /* 064 */ if (!range_initRange) { /* 065 */ range_initRange = true; /* 066 */ initRange(partitionIndex); /* 067 */ } /* 068 */ /* 069 */ while (true) { /* 070 */ while (range_number != range_batchEnd) { /* 071 */ long range_value = range_number; /* 072 */ range_number += 1L; /* 073 */ /* 074 */ // do aggregate /* 075 */ // common sub-expressions /* 076 */ /* 077 */ // evaluate aggregate function /* 078 */ boolean agg_isNull1 = false; /* 079 */ /* 080 */ long agg_value1 = -1L; /* 081 */ agg_value1 = agg_bufValue + 1L; /* 082 */ // update aggregation buffer /* 083 */ agg_bufIsNull = false; /* 084 */ agg_bufValue = agg_value1; /* 085 */ /* 086 */ if (shouldStop()) return; /* 087 */ } 
/* 088 */ /* 089 */ if (range_taskContext.isInterrupted()) { /* 090 */ throw new TaskKilledException(); /* 091 */ } /* 092 */ /* 093 */ long range_nextBatchTodo; /* 094 */ if (range_numElementsTodo > 1000L) { /* 095 */ range_nextBatchTodo = 1000L; /* 096 */ range_numElementsTodo -= 1000L; /* 097 */ } else { /* 098 */ range_nextBatchTodo = range_numElementsTodo; /* 099 */ range_numElementsTodo = 0; /* 100 */ if (range_nextBatchTodo == 0) break; /* 101 */ } /* 102 */ range_numOutputRows.add(range_nextBatchTodo); /* 103 */ range_inputMetrics.incRecordsRead(range_nextBatchTodo); /* 104 */ /* 105 */ range_batchEnd += range_nextBatchTodo * 1L; /* 106 */ } /* 107 */ /* 108 */ } /* 109 */ /* 110 */ private void initRange(int idx) { /* 111 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx); /* 112 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L); /* 113 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L); /* 114 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L); /* 115 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L); /* 117 */ /* 118 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); /* 119 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 120 */ range_number = Long.MAX_VALUE; /* 121 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 122 */ range_number = Long.MIN_VALUE; /* 123 */ } else { /* 124 */ range_number = st.longValue(); /* 125 */ } /* 126 */ range_batchEnd = range_number; /* 127 */ /* 128 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice) /* 129 */ .multiply(step).add(start); /* 130 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 131 */ partitionEnd = Long.MAX_VALUE; /* 132 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 133 */ partitionEnd = Long.MIN_VALUE; 
/* 134 */ } else { /* 135 */ partitionEnd = end.longValue(); /* 136 */ } /* 137 */ /* 138 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract( /* 139 */ java.math.BigInteger.valueOf(range_number)); /* 140 */ range_numElementsTodo = startToEnd.divide(step).longValue(); /* 141 */ if (range_numElementsTodo < 0) { /* 142 */ range_numElementsTodo = 0; /* 143 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) { /* 144 */ range_numElementsTodo++; /* 145 */ } /* 146 */ } /* 147 */ /* 148 */ protected void processNext() throws java.io.IOException { /* 149 */ while (!agg_initAgg) { /* 150 */ agg_initAgg = true; /* 151 */ long agg_beforeAgg = System.nanoTime(); /* 152 */ agg_doAggregateWithoutKey(); /* 153 */ agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000); /* 154 */ /* 155 */ // output the result /* 156 */ /* 157 */ agg_numOutputRows.add(1); /* 158 */ agg_rowWriter.zeroOutNullBytes(); /* 159 */ /* 160 */ if (agg_bufIsNull) { /* 161 */ agg_rowWriter.setNullAt(0); /* 162 */ } else { /* 163 */ agg_rowWriter.write(0, agg_bufValue); /* 164 */ } /* 165 */ append(agg_result); /* 166 */ } /* 167 */ } /* 168 */ } ``` Generated code with this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private boolean agg_initAgg; /* 009 */ private boolean agg_bufIsNull; /* 010 */ private long agg_bufValue; /* 011 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows; /* 012 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows; /* 013 */ private boolean range_initRange; /* 014 */ private long range_number; /* 015 */ private TaskContext range_taskContext; /* 016 */ private InputMetrics range_inputMetrics; /* 017 */ private long range_batchEnd; /* 018 */ private long range_numElementsTodo; /* 
019 */ private scala.collection.Iterator range_input; /* 020 */ private UnsafeRow range_result; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter; /* 023 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows; /* 024 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime; /* 025 */ private UnsafeRow agg_result; /* 026 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder; /* 027 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; /* 028 */ /* 029 */ public GeneratedIterator(Object[] references) { /* 030 */ this.references = references; /* 031 */ } /* 032 */ /* 033 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 034 */ partitionIndex = index; /* 035 */ this.inputs = inputs; /* 036 */ agg_initAgg = false; /* 037 */ /* 038 */ this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0]; /* 039 */ this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1]; /* 040 */ range_initRange = false; /* 041 */ range_number = 0L; /* 042 */ range_taskContext = TaskContext.get(); /* 043 */ range_inputMetrics = range_taskContext.taskMetrics().inputMetrics(); /* 044 */ range_batchEnd = 0; /* 045 */ range_numElementsTodo = 0L; /* 046 */ range_input = inputs[0]; /* 047 */ range_result = new UnsafeRow(1); /* 048 */ this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0); /* 049 */ this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1); /* 050 */ this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2]; /* 051 */ this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3]; /* 052 */ agg_result = 
new UnsafeRow(1); /* 053 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0); /* 054 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1); /* 055 */ /* 056 */ } /* 057 */ /* 058 */ private void agg_doAggregateWithoutKey() throws java.io.IOException { /* 059 */ // initialize aggregation buffer /* 060 */ agg_bufIsNull = false; /* 061 */ agg_bufValue = 0L; /* 062 */ /* 063 */ // initialize Range /* 064 */ if (!range_initRange) { /* 065 */ range_initRange = true; /* 066 */ initRange(partitionIndex); /* 067 */ } /* 068 */ /* 069 */ while (true) { /* 070 */ long range_range = range_batchEnd - range_number; /* 071 */ if (range_range != 0L) { /* 072 */ int range_localEnd = (int)(range_range / 1L); /* 073 */ for (int range_localIdx = 0; range_localIdx < range_localEnd; range_localIdx++) { /* 074 */ long range_value = ((long)range_localIdx * 1L) + range_number; /* 075 */ /* 076 */ // do aggregate /* 077 */ // common sub-expressions /* 078 */ /* 079 */ // evaluate aggregate function /* 080 */ boolean agg_isNull1 = false; /* 081 */ /* 082 */ long agg_value1 = -1L; /* 083 */ agg_value1 = agg_bufValue + 1L; /* 084 */ // update aggregation buffer /* 085 */ agg_bufIsNull = false; /* 086 */ agg_bufValue = agg_value1; /* 087 */ /* 088 */ // shouldStop check is eliminated /* 089 */ } /* 090 */ range_number = range_batchEnd; /* 091 */ } /* 092 */ /* 093 */ if (range_taskContext.isInterrupted()) { /* 094 */ throw new TaskKilledException(); /* 095 */ } /* 096 */ /* 097 */ long range_nextBatchTodo; /* 098 */ if (range_numElementsTodo > 1000L) { /* 099 */ range_nextBatchTodo = 1000L; /* 100 */ range_numElementsTodo -= 1000L; /* 101 */ } else { /* 102 */ range_nextBatchTodo = range_numElementsTodo; /* 103 */ range_numElementsTodo = 0; /* 104 */ if (range_nextBatchTodo == 0) break; /* 105 */ } /* 106 */ range_numOutputRows.add(range_nextBatchTodo); /* 107 */ 
range_inputMetrics.incRecordsRead(range_nextBatchTodo); /* 108 */ /* 109 */ range_batchEnd += range_nextBatchTodo * 1L; /* 110 */ } /* 111 */ /* 112 */ } /* 113 */ /* 114 */ private void initRange(int idx) { /* 115 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx); /* 116 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L); /* 117 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L); /* 118 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L); /* 119 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L); /* 120 */ long partitionEnd; /* 121 */ /* 122 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); /* 123 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 124 */ range_number = Long.MAX_VALUE; /* 125 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 126 */ range_number = Long.MIN_VALUE; /* 127 */ } else { /* 128 */ range_number = st.longValue(); /* 129 */ } /* 130 */ range_batchEnd = range_number; /* 131 */ /* 132 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice) /* 133 */ .multiply(step).add(start); /* 134 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 135 */ partitionEnd = Long.MAX_VALUE; /* 136 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 137 */ partitionEnd = Long.MIN_VALUE; /* 138 */ } else { /* 139 */ partitionEnd = end.longValue(); /* 140 */ } /* 141 */ /* 142 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract( /* 143 */ java.math.BigInteger.valueOf(range_number)); /* 144 */ range_numElementsTodo = startToEnd.divide(step).longValue(); /* 145 */ if (range_numElementsTodo < 0) { /* 146 */ range_numElementsTodo = 0; /* 147 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) { /* 148 */ 
range_numElementsTodo++; /* 149 */ } /* 150 */ } /* 151 */ /* 152 */ protected void processNext() throws java.io.IOException { /* 153 */ while (!agg_initAgg) { /* 154 */ agg_initAgg = true; /* 155 */ long agg_beforeAgg = System.nanoTime(); /* 156 */ agg_doAggregateWithoutKey(); /* 157 */ agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000); /* 158 */ /* 159 */ // output the result /* 160 */ /* 161 */ agg_numOutputRows.add(1); /* 162 */ agg_rowWriter.zeroOutNullBytes(); /* 163 */ /* 164 */ if (agg_bufIsNull) { /* 165 */ agg_rowWriter.setNullAt(0); /* 166 */ } else { /* 167 */ agg_rowWriter.write(0, agg_bufValue); /* 168 */ } /* 169 */ append(agg_result); /* 170 */ } /* 171 */ } /* 172 */ } ``` A part of suppressing `shouldStop()` was originally developed by inouehrs ## How was this patch tested? Add new tests into `DataFrameRangeSuite` Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com> Closes #17122 from kiszk/SPARK-19786. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fcb68e0f Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fcb68e0f Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fcb68e0f Branch: refs/heads/master Commit: fcb68e0f5d49234ac4527109887ff08cd4e1c29f Parents: 501b711 Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com> Authored: Fri Mar 10 18:04:37 2017 +0100 Committer: Herman van Hovell <hvanhov...@databricks.com> Committed: Fri Mar 10 18:04:37 2017 +0100 ---------------------------------------------------------------------- .../apache/spark/sql/execution/SortExec.scala | 2 ++ .../sql/execution/WholeStageCodegenExec.scala | 15 +++++++++++ .../execution/aggregate/HashAggregateExec.scala | 2 ++ .../sql/execution/basicPhysicalOperators.scala | 27 +++++++++++++++----- .../apache/spark/sql/DataFrameRangeSuite.scala | 16 ++++++++++++ 5 files changed, 55 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/spark/blob/fcb68e0f/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index cc576bb..f98ae82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -177,6 +177,8 @@ case class SortExec( """.stripMargin.trim } + protected override val shouldStopRequired = false + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { s""" |${row.code} http://git-wip-us.apache.org/repos/asf/spark/blob/fcb68e0f/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index c58474e..c31fd92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -206,6 +206,21 @@ trait CodegenSupport extends SparkPlan { def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { throw new UnsupportedOperationException } + + /** + * For optimization to suppress shouldStop() in a loop of WholeStageCodegen. + * Returning true means we need to insert shouldStop() into the loop producing rows, if any. 
+ */ + def isShouldStopRequired: Boolean = { + return shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired) + } + + /** + * Set to false if this plan consumes all rows produced by children but doesn't output row + * to buffer by calling append(), so the children don't require shouldStop() + * in the loop of producing rows. + */ + protected def shouldStopRequired: Boolean = true } http://git-wip-us.apache.org/repos/asf/spark/blob/fcb68e0f/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 4529ed0..68c8e6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -238,6 +238,8 @@ case class HashAggregateExec( """.stripMargin } + protected override val shouldStopRequired = false + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) http://git-wip-us.apache.org/repos/asf/spark/blob/fcb68e0f/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 87e90ed..d876688 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -387,8 
+387,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") - // The default size of a batch. - val batchSize = 1000L + // The default size of a batch, which must be positive integer + val batchSize = 1000 ctx.addNewFunction("initRange", s""" @@ -434,6 +434,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val input = ctx.freshName("input") // Right now, Range is only used when there is one upstream. ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + + val localIdx = ctx.freshName("localIdx") + val localEnd = ctx.freshName("localEnd") + val range = ctx.freshName("range") + val shouldStop = if (isShouldStopRequired) { + s"if (shouldStop()) { $number = $value + ${step}L; return; }" + } else { + "// shouldStop check is eliminated" + } s""" | // initialize Range | if (!$initTerm) { @@ -442,11 +451,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | } | | while (true) { - | while ($number != $batchEnd) { - | long $value = $number; - | $number += ${step}L; - | ${consume(ctx, Seq(ev))} - | if (shouldStop()) return; + | long $range = $batchEnd - $number; + | if ($range != 0L) { + | int $localEnd = (int)($range / ${step}L); + | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { + | long $value = ((long)$localIdx * ${step}L) + $number; + | ${consume(ctx, Seq(ev))} + | $shouldStop + | } + | $number = $batchEnd; | } | | if ($taskContext.isInterrupted()) { http://git-wip-us.apache.org/repos/asf/spark/blob/fcb68e0f/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index acf393a..5e323c0 
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -89,6 +89,22 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000 val res13 = spark.range(-n, n, n / 9).select("id") assert(res13.count == 18) + + // range with non aggregation operation + val res14 = spark.range(0, 100, 2).toDF.filter("50 <= id") + val len14 = res14.collect.length + assert(len14 == 25) + + val res15 = spark.range(100, -100, -2).toDF.filter("id <= 0") + val len15 = res15.collect.length + assert(len15 == 50) + + val res16 = spark.range(-1500, 1500, 3).toDF.filter("0 <= id") + val len16 = res16.collect.length + assert(len16 == 500) + + val res17 = spark.range(10, 0, -1, 1).toDF.sortWithinPartitions("id") + assert(res17.collect === (1 to 10).map(i => Row(i)).toArray) } test("Range with randomized parameters") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org