Repository: spark
Updated Branches:
  refs/heads/master 46fe40838 -> e3133f4ab


[SPARK-25497][SQL] Limit operation within whole stage codegen should not 
consume all the inputs

## What changes were proposed in this pull request?

This PR is inspired by https://github.com/apache/spark/pull/22524, but proposes 
a safer fix.

The current whole-stage codegen support for limit has two problems:
1. It's only applied to `InputAdapter`, so many leaf nodes can't stop early 
w.r.t. the limit.
2. It relies on overriding a single method (`stopEarly()`), which breaks if 
there is more than one limit in the whole stage.

The first problem is easy to fix: figure out which nodes can stop early 
w.r.t. the limit, and update them. This PR updates `RangeExec`, 
`ColumnarBatchScan`, `SortExec`, and `HashAggregateExec`.

The second problem is hard to fix. This PR proposes to propagate the limit 
counter variable name upstream, so that the upstream leaf/blocking nodes can 
check the limit counter and exit their data-producing loop early, as sketched 
below.
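
A minimal sketch of the propagation idea, using simplified, hypothetical names; 
the real change lives in `CodegenSupport` (`WholeStageCodegenExec.scala`) and is 
shown in the diff below:

```scala
// Hypothetical sketch only; the real trait is CodegenSupport.
trait CodegenSketch {
  def parent: CodegenSketch

  // Each downstream limit contributes a check such as "_limit_counter_0 < 10".
  // By default an operator simply forwards whatever its parent reports.
  def limitNotReachedChecks: Seq[String] = parent.limitNotReachedChecks

  // Leaf and blocking nodes splice the checks into their data-producing loop,
  // e.g. `while (<limitNotReachedCond> input.hasNext()) { ... }`.
  final def limitNotReachedCond: String =
    if (limitNotReachedChecks.isEmpty) ""
    else limitNotReachedChecks.mkString("", " && ", " &&")
}
```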

For better performance, the implementation follows the same pattern as 
`CodegenSupport.needStopCheck`, so that we codegen the check only if the query 
contains a limit. For batch-producing nodes like Range and the columnar scans, 
we check the limit counter per batch instead of per row, to keep the inner loop 
tight and fast; a sketch of the loop shape follows.
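
As an illustration, a batch-based node can place the checks in the outer 
(per-batch) loop condition and leave the per-row inner loop untouched. A rough, 
hypothetical sketch of such a codegen template, mirroring the 
`ColumnarBatchScan` change in the diff below:

```scala
// Hypothetical template, for illustration only. `cond` is either empty or a
// string like "_limit_counter_0 < 10 &&", as produced by limitNotReachedCond.
object BatchLoopSketch {
  def batchLoop(cond: String, batch: String, numRows: String): String =
    s"""
       |while ($cond $batch != null) {
       |  int $numRows = $batch.numRows();
       |  // tight per-row inner loop: no limit check in here
       |}
     """.stripMargin
}
```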

Why is this safer?
1. The leaf/blocking nodes are not required to check the limit counter and stop 
early; the check is a performance optimization only. (This is the same as 
before.)
2. Blocking operators can stop propagating the limit counter name, because the 
counter of a limit placed after a blocking operator will never increase until 
that blocking operator has consumed all the data from its upstream operators. 
So the upstream operators don't care about limits placed after blocking 
operators. This is also a performance optimization only; it's OK if we forget 
to do it for some new blocking operator (see the sketch after this list).
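
A sketch of the blocking-operator side, extending the simplified trait above 
(the real code is the new `BlockingOperatorWithCodegen` trait in the diff 
below):

```scala
// Hypothetical sketch only; see BlockingOperatorWithCodegen in the diff.
trait BlockingSketch extends CodegenSketch {
  // A limit placed after this operator cannot tick until this operator has
  // consumed all of its input, so its "counter < limit" check stays true (and
  // therefore useless) while upstream operators are still producing. Dropping
  // the check here just avoids generating that dead condition upstream.
  override def limitNotReachedChecks: Seq[String] = Nil
}
```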

## How was this patch tested?

A new test in `SQLMetricsSuite`, paraphrased below.
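
For reference, the new test exercises roughly the following single-partition 
query (a paraphrase of the `SPARK-25497` test added to `SQLMetricsSuite` in the 
diff; it assumes an active `SparkSession` named `spark`):

```scala
import spark.implicits._  // for the 'id column syntax

// One partition and no shuffle, so the whole query is a single codegen stage.
val df = spark.range(0, 1500, 1, 1)
  .limit(10)
  .groupBy('id).count()
  .limit(1)
  .filter('id >= 0)
df.collect()
// Expected metrics: Range produces only its first batch (1000 rows), the
// partial aggregate outputs 10 rows, and the final aggregate and the Filter
// each output 1 row.
```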

Closes #22630 from cloud-fan/limit.

Authored-by: Wenchen Fan <wenc...@databricks.com>
Signed-off-by: Kazuaki Ishizaki <ishiz...@jp.ibm.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e3133f4a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e3133f4a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e3133f4a

Branch: refs/heads/master
Commit: e3133f4abf1cd5667abe5f0d05fa0af0df3033ae
Parents: 46fe408
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Oct 9 16:46:23 2018 +0900
Committer: Kazuaki Ishizaki <ishiz...@jp.ibm.com>
Committed: Tue Oct 9 16:46:23 2018 +0900

----------------------------------------------------------------------
 .../sql/execution/BufferedRowIterator.java      |  10 --
 .../spark/sql/execution/ColumnarBatchScan.scala |   4 +-
 .../apache/spark/sql/execution/SortExec.scala   |  12 +-
 .../sql/execution/WholeStageCodegenExec.scala   |  59 +++++++++-
 .../execution/aggregate/HashAggregateExec.scala |  22 +---
 .../sql/execution/basicPhysicalOperators.scala  |  91 +++++++++------
 .../org/apache/spark/sql/execution/limit.scala  |  31 ++++--
 .../sql/execution/metric/SQLMetricsSuite.scala  | 111 ++++++++++++-------
 8 files changed, 215 insertions(+), 125 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e3133f4a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
 
b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
index 74c9c05..3d0511b 100644
--- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -74,16 +74,6 @@ public abstract class BufferedRowIterator {
   }
 
   /**
-   * Returns whether this iterator should stop fetching next row from 
[[CodegenSupport#inputRDDs]].
-   *
-   * If it returns true, the caller should exit the loop that [[InputAdapter]] 
generates.
-   * This interface is mainly used to limit the number of input rows.
-   */
-  public boolean stopEarly() {
-    return false;
-  }
-
-  /**
    * Returns whether `processNext()` should stop processing next row from 
`input` or not.
    *
    * If it returns true, the caller should exit the loop (return from 
processNext()).

http://git-wip-us.apache.org/repos/asf/spark/blob/e3133f4a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 48abad9..9f6b593 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -136,7 +136,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport 
{
        |if ($batch == null) {
        |  $nextBatchFuncName();
        |}
-       |while ($batch != null) {
+       |while ($limitNotReachedCond $batch != null) {
        |  int $numRows = $batch.numRows();
        |  int $localEnd = $numRows - $idx;
        |  for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
@@ -166,7 +166,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport 
{
     }
     val inputRow = if (needsUnsafeRowConversion) null else row
     s"""
-       |while ($input.hasNext()) {
+       |while ($limitNotReachedCond $input.hasNext()) {
        |  InternalRow $row = (InternalRow) $input.next();
        |  $numOutputRows.add(1);
        |  ${consume(ctx, outputVars, inputRow).trim}

http://git-wip-us.apache.org/repos/asf/spark/blob/e3133f4a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 0dc16ba..f1470e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -39,7 +39,7 @@ case class SortExec(
     global: Boolean,
     child: SparkPlan,
     testSpillFrequency: Int = 0)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with BlockingOperatorWithCodegen {
 
   override def output: Seq[Attribute] = child.output
 
@@ -124,14 +124,6 @@ case class SortExec(
   // Name of sorter variable used in codegen.
   private var sorterVariable: String = _
 
-  // The result rows come from the sort buffer, so this operator doesn't need 
to copy its result
-  // even if its child does.
-  override def needCopyResult: Boolean = false
-
-  // Sort operator always consumes all the input rows before outputting any 
result, so we don't need
-  // a stop check before sorting.
-  override def needStopCheck: Boolean = false
-
   override protected def doProduce(ctx: CodegenContext): String = {
     val needToSort =
       ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v 
= true;")
@@ -172,7 +164,7 @@ case class SortExec(
        |   $needToSort = false;
        | }
        |
-       | while ($sortedIterator.hasNext()) {
+       | while ($limitNotReachedCond $sortedIterator.hasNext()) {
        |   UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
        |   ${consume(ctx, null, outputRow)}
        |   if (shouldStop()) return;

http://git-wip-us.apache.org/repos/asf/spark/blob/e3133f4a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 1fc4de9..f5aee62 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -345,6 +345,61 @@ trait CodegenSupport extends SparkPlan {
    * don't require shouldStop() in the loop of producing rows.
    */
   def needStopCheck: Boolean = parent.needStopCheck
+
+  /**
+   * A sequence of checks which evaluate to true if the downstream Limit 
operators have not received
+   * enough records and reached the limit. If current node is a data producing 
node, it can leverage
+   * this information to stop producing data and complete the data flow 
earlier. Common data
+   * producing nodes are leaf nodes like Range and Scan, and blocking nodes 
like Sort and Aggregate.
+   * These checks should be put into the loop condition of the data producing 
loop.
+   */
+  def limitNotReachedChecks: Seq[String] = parent.limitNotReachedChecks
+
+  /**
+   * A helper method to generate the data producing loop condition according 
to the
+   * limit-not-reached checks.
+   */
+  final def limitNotReachedCond: String = {
+    // InputAdapter is also a leaf node.
+    val isLeafNode = children.isEmpty || this.isInstanceOf[InputAdapter]
+    if (!isLeafNode && !this.isInstanceOf[BlockingOperatorWithCodegen]) {
+      val errMsg = "Only leaf nodes and blocking nodes need to call 
'limitNotReachedCond' " +
+        "in its data producing loop."
+      if (Utils.isTesting) {
+        throw new IllegalStateException(errMsg)
+      } else {
+        logWarning(s"[BUG] $errMsg Please open a JIRA ticket to report it.")
+      }
+    }
+    if (parent.limitNotReachedChecks.isEmpty) {
+      ""
+    } else {
+      parent.limitNotReachedChecks.mkString("", " && ", " &&")
+    }
+  }
+}
+
+/**
+ * A special kind of operators which support whole stage codegen. Blocking 
means these operators
+ * will consume all the inputs first, before producing output. Typical 
blocking operators are
+ * sort and aggregate.
+ */
+trait BlockingOperatorWithCodegen extends CodegenSupport {
+
+  // Blocking operators usually have some kind of buffer to keep the data 
before producing them, so
+  // they don't need to copy their result even if their child does.
+  override def needCopyResult: Boolean = false
+
+  // Blocking operators always consume all the input first, so its upstream 
operators don't need a
+  // stop check.
+  override def needStopCheck: Boolean = false
+
+  // Blocking operators need to consume all the inputs before producing any 
output. This means,
+  // Limit operator after this blocking operator will never reach its limit 
during the execution of
+  // this blocking operator's upstream operators. Here we override this method 
to return Nil, so
+  // that upstream operators will not generate useless conditions (which are 
always evaluated to
+  // false) for the Limit operators after this blocking operator.
+  override def limitNotReachedChecks: Seq[String] = Nil
 }
 
 
@@ -381,7 +436,7 @@ case class InputAdapter(child: SparkPlan) extends 
UnaryExecNode with CodegenSupp
       forceInline = true)
     val row = ctx.freshName("row")
     s"""
-       | while ($input.hasNext() && !stopEarly()) {
+       | while ($limitNotReachedCond $input.hasNext()) {
        |   InternalRow $row = (InternalRow) $input.next();
        |   ${consume(ctx, null, row).trim}
        |   if (shouldStop()) return;
@@ -677,6 +732,8 @@ case class WholeStageCodegenExec(child: SparkPlan)(val 
codegenStageId: Int)
 
   override def needStopCheck: Boolean = true
 
+  override def limitNotReachedChecks: Seq[String] = Nil
+
   override protected def otherCopyArgs: Seq[AnyRef] = 
Seq(codegenStageId.asInstanceOf[Integer])
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e3133f4a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 98adba5..6155ec9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -45,7 +45,7 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with BlockingOperatorWithCodegen {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -151,14 +151,6 @@ case class HashAggregateExec(
     child.asInstanceOf[CodegenSupport].inputRDDs()
   }
 
-  // The result rows come from the aggregate buffer, or a single row(no 
grouping keys), so this
-  // operator doesn't need to copy its result even if its child does.
-  override def needCopyResult: Boolean = false
-
-  // Aggregate operator always consumes all the input rows before outputting 
any result, so we
-  // don't need a stop check before aggregating.
-  override def needStopCheck: Boolean = false
-
   protected override def doProduce(ctx: CodegenContext): String = {
     if (groupingExpressions.isEmpty) {
       doProduceWithoutKeys(ctx)
@@ -705,13 +697,16 @@ case class HashAggregateExec(
 
     def outputFromRegularHashMap: String = {
       s"""
-         |while ($iterTerm.next()) {
+         |while ($limitNotReachedCond $iterTerm.next()) {
          |  UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
          |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
          |  $outputFunc($keyTerm, $bufferTerm);
-         |
          |  if (shouldStop()) return;
          |}
+         |$iterTerm.close();
+         |if ($sorterTerm == null) {
+         |  $hashMapTerm.free();
+         |}
        """.stripMargin
     }
 
@@ -728,11 +723,6 @@ case class HashAggregateExec(
      // output the result
      $outputFromFastHashMap
      $outputFromRegularHashMap
-
-     $iterTerm.close();
-     if ($sorterTerm == null) {
-       $hashMapTerm.free();
-     }
      """
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e3133f4a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 222a1b8..4cd2e78 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -378,7 +378,7 @@ case class RangeExec(range: 
org.apache.spark.sql.catalyst.plans.logical.Range)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
-    val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number")
+    val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")
 
     val value = ctx.freshName("value")
     val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))
@@ -397,7 +397,7 @@ case class RangeExec(range: 
org.apache.spark.sql.catalyst.plans.logical.Range)
     // within a batch, while the code in the outer loop is setting batch 
parameters and updating
     // the metrics.
 
-    // Once number == batchEnd, it's time to progress to the next batch.
+    // Once nextIndex == batchEnd, it's time to progress to the next batch.
     val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")
 
     // How many values should still be generated by this range operator.
@@ -421,13 +421,13 @@ case class RangeExec(range: 
org.apache.spark.sql.catalyst.plans.logical.Range)
         |
         |   $BigInt st = 
index.multiply(numElement).divide(numSlice).multiply(step).add(start);
         |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-        |     $number = Long.MAX_VALUE;
+        |     $nextIndex = Long.MAX_VALUE;
         |   } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-        |     $number = Long.MIN_VALUE;
+        |     $nextIndex = Long.MIN_VALUE;
         |   } else {
-        |     $number = st.longValue();
+        |     $nextIndex = st.longValue();
         |   }
-        |   $batchEnd = $number;
+        |   $batchEnd = $nextIndex;
         |
         |   $BigInt end = 
index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
         |     .multiply(step).add(start);
@@ -440,7 +440,7 @@ case class RangeExec(range: 
org.apache.spark.sql.catalyst.plans.logical.Range)
         |   }
         |
         |   $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
-        |     $BigInt.valueOf($number));
+        |     $BigInt.valueOf($nextIndex));
         |   $numElementsTodo  = startToEnd.divide(step).longValue();
         |   if ($numElementsTodo < 0) {
         |     $numElementsTodo = 0;
@@ -452,12 +452,42 @@ case class RangeExec(range: 
org.apache.spark.sql.catalyst.plans.logical.Range)
 
     val localIdx = ctx.freshName("localIdx")
     val localEnd = ctx.freshName("localEnd")
-    val range = ctx.freshName("range")
     val shouldStop = if (parent.needStopCheck) {
-      s"if (shouldStop()) { $number = $value + ${step}L; return; }"
+      s"if (shouldStop()) { $nextIndex = $value + ${step}L; return; }"
     } else {
       "// shouldStop check is eliminated"
     }
+    val loopCondition = if (limitNotReachedChecks.isEmpty) {
+      "true"
+    } else {
+      limitNotReachedChecks.mkString(" && ")
+    }
+
+    // An overview of the Range processing.
+    //
+    // For each partition, the Range task needs to produce records from 
partition start(inclusive)
+    // to end(exclusive). For better performance, we separate the partition 
range into batches, and
+    // use 2 loops to produce data. The outer while loop is used to iterate 
batches, and the inner
+    // for loop is used to iterate records inside a batch.
+    //
+    // `nextIndex` tracks the index of the next record that is going to be 
consumed, initialized
+    // with partition start. `batchEnd` tracks the end index of the current 
batch, initialized
+    // with `nextIndex`. In the outer loop, we first check if `nextIndex == 
batchEnd`. If it's true,
+    // it means the current batch is fully consumed, and we will update 
`batchEnd` to process the
+    // next batch. If `batchEnd` reaches partition end, exit the outer loop. 
Finally we enter the
+    // inner loop. Note that, when we enter inner loop, `nextIndex` must be 
different from
+    // `batchEnd`, otherwise we already exit the outer loop.
+    //
+    // The inner loop iterates from 0 to `localEnd`, which is calculated by
+    // `(batchEnd - nextIndex) / step`. Since `batchEnd` is increased by 
`nextBatchTodo * step` in
+    // the outer loop, and initialized with `nextIndex`, so `batchEnd - 
nextIndex` is always
+    // divisible by `step`. The `nextIndex` is increased by `step` during each 
iteration, and ends
+    // up being equal to `batchEnd` when the inner loop finishes.
+    //
+    // The inner loop can be interrupted, if the query has produced at least 
one result row, so that
+    // we don't buffer too many result rows and waste memory. It's ok to 
interrupt the inner loop,
+    // because `nextIndex` will be updated before interrupting.
+
     s"""
       | // initialize Range
       | if (!$initTerm) {
@@ -465,33 +495,30 @@ case class RangeExec(range: 
org.apache.spark.sql.catalyst.plans.logical.Range)
       |   $initRangeFuncName(partitionIndex);
       | }
       |
-      | while (true) {
-      |   long $range = $batchEnd - $number;
-      |   if ($range != 0L) {
-      |     int $localEnd = (int)($range / ${step}L);
-      |     for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
-      |       long $value = ((long)$localIdx * ${step}L) + $number;
-      |       ${consume(ctx, Seq(ev))}
-      |       $shouldStop
+      | while ($loopCondition) {
+      |   if ($nextIndex == $batchEnd) {
+      |     long $nextBatchTodo;
+      |     if ($numElementsTodo > ${batchSize}L) {
+      |       $nextBatchTodo = ${batchSize}L;
+      |       $numElementsTodo -= ${batchSize}L;
+      |     } else {
+      |       $nextBatchTodo = $numElementsTodo;
+      |       $numElementsTodo = 0;
+      |       if ($nextBatchTodo == 0) break;
       |     }
-      |     $number = $batchEnd;
+      |     $numOutput.add($nextBatchTodo);
+      |     $inputMetrics.incRecordsRead($nextBatchTodo);
+      |     $batchEnd += $nextBatchTodo * ${step}L;
       |   }
       |
-      |   $taskContext.killTaskIfInterrupted();
-      |
-      |   long $nextBatchTodo;
-      |   if ($numElementsTodo > ${batchSize}L) {
-      |     $nextBatchTodo = ${batchSize}L;
-      |     $numElementsTodo -= ${batchSize}L;
-      |   } else {
-      |     $nextBatchTodo = $numElementsTodo;
-      |     $numElementsTodo = 0;
-      |     if ($nextBatchTodo == 0) break;
+      |   int $localEnd = (int)(($batchEnd - $nextIndex) / ${step}L);
+      |   for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
+      |     long $value = ((long)$localIdx * ${step}L) + $nextIndex;
+      |     ${consume(ctx, Seq(ev))}
+      |     $shouldStop
       |   }
-      |   $numOutput.add($nextBatchTodo);
-      |   $inputMetrics.incRecordsRead($nextBatchTodo);
-      |
-      |   $batchEnd += $nextBatchTodo * ${step}L;
+      |   $nextIndex = $batchEnd;
+      |   $taskContext.killTaskIfInterrupted();
       | }
      """.stripMargin
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e3133f4a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 66bcda8..9bfe1a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -46,6 +46,15 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) 
extends UnaryExecNode
   }
 }
 
+object BaseLimitExec {
+  private val curId = new java.util.concurrent.atomic.AtomicInteger()
+
+  def newLimitCountTerm(): String = {
+    val id = curId.getAndIncrement()
+    s"_limit_counter_$id"
+  }
+}
+
 /**
  * Helper trait which defines methods that are shared by both
  * [[LocalLimitExec]] and [[GlobalLimitExec]].
@@ -66,27 +75,25 @@ trait BaseLimitExec extends UnaryExecNode with 
CodegenSupport {
   // to the parent operator.
   override def usedInputs: AttributeSet = AttributeSet.empty
 
+  private lazy val countTerm = BaseLimitExec.newLimitCountTerm()
+
+  override lazy val limitNotReachedChecks: Seq[String] = {
+    s"$countTerm < $limit" +: super.limitNotReachedChecks
+  }
+
   protected override def doProduce(ctx: CodegenContext): String = {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
-    val stopEarly =
-      ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as 
stopEarly = false
-
-    ctx.addNewFunction("stopEarly", s"""
-      @Override
-      protected boolean stopEarly() {
-        return $stopEarly;
-      }
-    """, inlineToOuterClass = true)
-    val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // 
init as count = 0
+    // The counter name is already obtained by the upstream operators via 
`limitNotReachedChecks`.
+    // Here we have to inline it to not change its name. This is fine as we 
won't have many limit
+    // operators in one query.
+    ctx.addMutableState(CodeGenerator.JAVA_INT, countTerm, forceInline = true, 
useFreshName = false)
     s"""
        | if ($countTerm < $limit) {
        |   $countTerm += 1;
        |   ${consume(ctx, input)}
-       | } else {
-       |   $stopEarly = true;
        | }
      """.stripMargin
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e3133f4a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 085a445..81db3e1 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -19,12 +19,15 @@ package org.apache.spark.sql.execution.metric
 
 import java.io.File
 
+import scala.reflect.{classTag, ClassTag}
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, 
WholeStageCodegenExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -518,56 +521,80 @@ class SQLMetricsSuite extends SparkFunSuite with 
SQLMetricsTestUtils with Shared
     testMetricsDynamicPartition("parquet", "parquet", "t1")
   }
 
+  private def collectNodeWithinWholeStage[T <: SparkPlan : ClassTag](plan: 
SparkPlan): Seq[T] = {
+    val stages = plan.collect {
+      case w: WholeStageCodegenExec => w
+    }
+    assert(stages.length == 1, "The query plan should have one and only one 
whole-stage.")
+
+    val cls = classTag[T].runtimeClass
+    stages.head.collect {
+      case n if n.getClass == cls => n.asInstanceOf[T]
+    }
+  }
+
   test("SPARK-25602: SparkPlan.getByteArrayRdd should not consume the input 
when not necessary") {
     def checkFilterAndRangeMetrics(
         df: DataFrame,
         filterNumOutputs: Int,
         rangeNumOutputs: Int): Unit = {
-      var filter: FilterExec = null
-      var range: RangeExec = null
-      val collectFilterAndRange: SparkPlan => Unit = {
-        case f: FilterExec =>
-          assert(filter == null, "the query should only have one Filter")
-          filter = f
-        case r: RangeExec =>
-          assert(range == null, "the query should only have one Range")
-          range = r
-        case _ =>
-      }
-      if (SQLConf.get.wholeStageEnabled) {
-        df.queryExecution.executedPlan.foreach {
-          case w: WholeStageCodegenExec =>
-            w.child.foreach(collectFilterAndRange)
-          case _ =>
-        }
-      } else {
-        df.queryExecution.executedPlan.foreach(collectFilterAndRange)
-      }
+      val plan = df.queryExecution.executedPlan
 
-      assert(filter != null && range != null, "the query doesn't have Filter 
and Range")
-      assert(filter.metrics("numOutputRows").value == filterNumOutputs)
-      assert(range.metrics("numOutputRows").value == rangeNumOutputs)
+      val filters = collectNodeWithinWholeStage[FilterExec](plan)
+      assert(filters.length == 1, "The query plan should have one and only one 
Filter")
+      assert(filters.head.metrics("numOutputRows").value == filterNumOutputs)
+
+      val ranges = collectNodeWithinWholeStage[RangeExec](plan)
+      assert(ranges.length == 1, "The query plan should have one and only one 
Range")
+      assert(ranges.head.metrics("numOutputRows").value == rangeNumOutputs)
     }
 
-    val df = spark.range(0, 3000, 1, 2).toDF().filter('id % 3 === 0)
-    val df2 = df.limit(2)
-    Seq(true, false).foreach { wholeStageEnabled =>
-      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> 
wholeStageEnabled.toString) {
-        df.collect()
-        checkFilterAndRangeMetrics(df, filterNumOutputs = 1000, 
rangeNumOutputs = 3000)
-
-        df.queryExecution.executedPlan.foreach(_.resetMetrics())
-        // For each partition, we get 2 rows. Then the Filter should produce 2 
rows per-partition,
-        // and Range should produce 1000 rows (one batch) per-partition. 
Totally Filter produces
-        // 4 rows, and Range produces 2000 rows.
-        df.queryExecution.toRdd.mapPartitions(_.take(2)).collect()
-        checkFilterAndRangeMetrics(df, filterNumOutputs = 4, rangeNumOutputs = 
2000)
-
-        // Top-most limit will call `CollectLimitExec.executeCollect`, which 
will only run the first
-        // task, so totally the Filter produces 2 rows, and Range produces 
1000 rows (one batch).
-        df2.collect()
-        checkFilterAndRangeMetrics(df2, filterNumOutputs = 2, rangeNumOutputs 
= 1000)
-      }
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
+      val df = spark.range(0, 3000, 1, 2).toDF().filter('id % 3 === 0)
+      df.collect()
+      checkFilterAndRangeMetrics(df, filterNumOutputs = 1000, rangeNumOutputs 
= 3000)
+
+      df.queryExecution.executedPlan.foreach(_.resetMetrics())
+      // For each partition, we get 2 rows. Then the Filter should produce 2 
rows per-partition,
+      // and Range should produce 1000 rows (one batch) per-partition. Totally 
Filter produces
+      // 4 rows, and Range produces 2000 rows.
+      df.queryExecution.toRdd.mapPartitions(_.take(2)).collect()
+      checkFilterAndRangeMetrics(df, filterNumOutputs = 4, rangeNumOutputs = 
2000)
+
+      // Top-most limit will call `CollectLimitExec.executeCollect`, which 
will only run the first
+      // task, so totally the Filter produces 2 rows, and Range produces 1000 
rows (one batch).
+      val df2 = df.limit(2)
+      df2.collect()
+      checkFilterAndRangeMetrics(df2, filterNumOutputs = 2, rangeNumOutputs = 
1000)
+    }
+  }
+
+  test("SPARK-25497: LIMIT within whole stage codegen should not consume all 
the inputs") {
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
+      // A special query that only has one partition, so there is no shuffle 
and the entire query
+      // can be whole-stage-codegened.
+      val df = spark.range(0, 1500, 1, 
1).limit(10).groupBy('id).count().limit(1).filter('id >= 0)
+      df.collect()
+      val plan = df.queryExecution.executedPlan
+
+      val ranges = collectNodeWithinWholeStage[RangeExec](plan)
+      assert(ranges.length == 1, "The query plan should have one and only one 
Range")
+      // The Range should only produce the first batch, i.e. 1000 rows.
+      assert(ranges.head.metrics("numOutputRows").value == 1000)
+
+      val aggs = collectNodeWithinWholeStage[HashAggregateExec](plan)
+      assert(aggs.length == 2, "The query plan should have two and only two 
Aggregate")
+      val partialAgg = aggs.filter(_.aggregateExpressions.head.mode == 
Partial).head
+      // The partial aggregate should output 10 rows, because its input is 10 
rows.
+      assert(partialAgg.metrics("numOutputRows").value == 10)
+      val finalAgg = aggs.filter(_.aggregateExpressions.head.mode == 
Final).head
+      // The final aggregate should only produce 1 row, because the upstream 
limit only needs 1 row.
+      assert(finalAgg.metrics("numOutputRows").value == 1)
+
+      val filters = collectNodeWithinWholeStage[FilterExec](plan)
+      assert(filters.length == 1, "The query plan should have one and only one 
Filter")
+      // The final Filter should produce 1 row, because the input is just one 
row.
+      assert(filters.head.metrics("numOutputRows").value == 1)
     }
   }
 }

