Repository: spark
Updated Branches:
refs/heads/branch-2.0 53cd99f65 -> 080ac37fb
[SPARK-18528][SQL] Fix a bug to initialise an iterator of aggregation buffer
## What changes were proposed in this pull request?
This PR fixes a `NullPointerException` caused by the following `limit
+ aggregate` query:
```
scala> val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value")
scala> df.limit(2).groupBy("id").count().show
WARN TaskSetManager: Lost task 0.0 in stage 9.0 (TID 8204,
lvsp20hdn012.stubprod.com): java.lang.NullPointerException
at
org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.agg_doAggregateWithKeys$(Unknown
Source)
at
org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown
Source)
```
The root cause is that
[`$doAgg()`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L596)
skips the initialization of [the buffer
iterator](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L603):
`BaseLimitExec` sets `stopEarly=true`, so `$doAgg()` exits partway through
without performing the initialization.
## How was this patch tested?
Added a test to `DataFrameAggregateSuite.scala` that checks that no exception
is thrown for limit + aggregate queries.
Author: Takeshi YAMAMURO <[email protected]>
Closes #15980 from maropu/SPARK-18528.
(cherry picked from commit b41ec997786e2be42a8a2a182212a610d08b221b)
Signed-off-by: Herman van Hovell <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/080ac37f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/080ac37f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/080ac37f
Branch: refs/heads/branch-2.0
Commit: 080ac37fb5ff5f6b8781863866e8099eb9be4dba
Parents: 53cd99f
Author: Takeshi YAMAMURO <[email protected]>
Authored: Thu Dec 22 01:53:33 2016 +0100
Committer: Herman van Hovell <[email protected]>
Committed: Thu Dec 22 01:54:00 2016 +0100
----------------------------------------------------------------------
.../apache/spark/sql/execution/BufferedRowIterator.java | 10 ++++++++++
.../spark/sql/execution/WholeStageCodegenExec.scala | 2 +-
.../main/scala/org/apache/spark/sql/execution/limit.scala | 6 +++---
.../org/apache/spark/sql/DataFrameAggregateSuite.scala | 8 ++++++++
4 files changed, 22 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/080ac37f/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
index 086547c..730a4ae 100644
---
a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
+++
b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -70,6 +70,16 @@ public abstract class BufferedRowIterator {
}
/**
+ * Returns whether this iterator should stop fetching next row from
[[CodegenSupport#inputRDDs]].
+ *
+ * If it returns true, the caller should exit the loop that [[InputAdapter]]
generates.
+ * This interface is mainly used to limit the number of input rows.
+ */
+ protected boolean stopEarly() {
+ return false;
+ }
+
+ /**
* Returns whether `processNext()` should stop processing next row from
`input` or not.
*
* If it returns true, the caller should exit the loop (return from
processNext()).
http://git-wip-us.apache.org/repos/asf/spark/blob/080ac37f/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index fb57ed7..697db39 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -239,7 +239,7 @@ case class InputAdapter(child: SparkPlan) extends
UnaryExecNode with CodegenSupp
ctx.addMutableState("scala.collection.Iterator", input, s"$input =
inputs[0];")
val row = ctx.freshName("row")
s"""
- | while ($input.hasNext()) {
+ | while ($input.hasNext() && !stopEarly()) {
| InternalRow $row = (InternalRow) $input.next();
| ${consume(ctx, null, row).trim}
| if (shouldStop()) return;
http://git-wip-us.apache.org/repos/asf/spark/blob/080ac37f/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 86a8770..14c6b6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -72,10 +72,10 @@ trait BaseLimitExec extends UnaryExecNode with
CodegenSupport {
val stopEarly = ctx.freshName("stopEarly")
ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")
- ctx.addNewFunction("shouldStop", s"""
+ ctx.addNewFunction("stopEarly", s"""
@Override
- protected boolean shouldStop() {
- return !currentRows.isEmpty() || $stopEarly;
+ protected boolean stopEarly() {
+ return $stopEarly;
}
""")
val countTerm = ctx.freshName("count")
http://git-wip-us.apache.org/repos/asf/spark/blob/080ac37f/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 3454caf..94da29c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -505,4 +505,12 @@ class DataFrameAggregateSuite extends QueryTest with
SharedSQLContext {
df.groupBy($"x").agg(countDistinct($"y"),
sort_array(collect_list($"z"))),
Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d"))))
}
+
+ test("SPARK-18004 limit + aggregates") {
+ val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value")
+ val limit2Df = df.limit(2)
+ checkAnswer(
+ limit2Df.groupBy("id").count().select($"id"),
+ limit2Df.select($"id"))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]