[SPARK-14255][SQL] Streaming Aggregation

This PR adds the ability to perform aggregations inside of a `ContinuousQuery`.
In order to implement this feature, the planning of aggregation has been
augmented with a new `StatefulAggregationStrategy`.  Unlike batch aggregation,
stateful aggregation uses the `StateStore` (introduced in #11645) to persist
the results of partial aggregation across different invocations.  The resulting
physical plan performs the aggregation using the following progression:
   - Partial Aggregation
   - Shuffle
   - Partial Merge (now there is at most 1 tuple per group)
   - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous batch)
   - Partial Merge (now there is at most 1 tuple per group)
   - StateStoreSave (saves the tuple for the next batch)
   - Complete (output the current result of the aggregation)
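The progression above can be modeled without Spark as a running count per group that is carried from one micro-batch to the next. The following is a minimal plain-Scala sketch, not Spark's implementation: an immutable `Map` stands in for the `StateStore`, the shuffle is implicit in merging all partitions, and `StatefulCountSketch`, `partialAggregate`, `merge`, and `runBatch` are hypothetical names invented for the illustration.

```scala
object StatefulCountSketch {
  // Hypothetical stand-in for the StateStore: group key -> running count.
  type State = Map[String, Long]

  // Partial Aggregation: each input partition counts its own keys.
  def partialAggregate(partition: Seq[String]): Map[String, Long] =
    partition.groupBy(identity).map { case (k, vs) => k -> vs.size.toLong }

  // Partial Merge: combine partial results so there is at most one value per key.
  def merge(parts: Seq[Map[String, Long]]): Map[String, Long] =
    parts.foldLeft(Map.empty[String, Long]) { (acc, part) =>
      part.foldLeft(acc) { case (m, (k, c)) => m.updated(k, m.getOrElse(k, 0L) + c) }
    }

  def runBatch(state: State, partitions: Seq[Seq[String]]): (State, Map[String, Long]) = {
    // Partial Aggregation, then (conceptually) Shuffle and Partial Merge.
    val thisBatch = merge(partitions.map(partialAggregate))
    // StateStoreRestore: bring in the previous value for keys seen in this batch.
    val previous = state.filter { case (k, _) => thisBatch.contains(k) }
    // Partial Merge again: at most one tuple per group.
    val merged = merge(Seq(thisBatch, previous))
    // StateStoreSave: persist the merged values for the next batch.
    val newState = state ++ merged
    // Complete: emit the current result for the groups touched by this batch.
    (newState, merged)
  }

  def main(args: Array[String]): Unit = {
    val (state1, out1) = runBatch(Map.empty, Seq(Seq("a", "b"), Seq("a")))
    println(s"batch 1: $out1")
    val (_, out2) = runBatch(state1, Seq(Seq("a"), Seq("c")))
    println(s"batch 2: $out2")
  }
}
```

Since the sketch emits only the groups touched by the current batch, the second call returns results for `a` and `c` but not `b`; that output choice is an assumption of the sketch.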

The following refactoring was also performed to allow us to plug into existing
code:
 - The get/put implementation is taken from #12013.
 - The logic for breaking down and de-duping the physical execution of
   aggregation has been moved into a new pattern, `PhysicalAggregation`.
 - The `AttributeReference` used to identify the result of an
   `AggregateFunction` has been moved into the `AggregateExpression` container.
   This change moves the reference into the same object as the other
   intermediate references used in aggregation and eliminates the need to pass
   around a `Map[(AggregateFunction, Boolean), Attribute]`.  Further cleanup
   (using a different aggregation container for logical/physical plans) is
   deferred to a follow-up.
 - Some planning logic is moved from the `SessionState` into the
   `QueryExecution` to make it easier to override in the streaming case.
 - The ability to write a `StreamTest` that checks only the output of the last
   batch has been added to simulate the future addition of output modes.

Author: Michael Armbrust <[email protected]>

Closes #12048 from marmbrus/statefulAgg.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0fc4aaa7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0fc4aaa7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0fc4aaa7

Branch: refs/heads/master
Commit: 0fc4aaa71c5b4531b3a7c8ac71d62ea8e66b6f0c
Parents: 0b7d496
Author: Michael Armbrust <[email protected]>
Authored: Fri Apr 1 15:15:16 2016 -0700
Committer: Michael Armbrust <[email protected]>
Committed: Fri Apr 1 15:15:16 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   9 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala   |   2 +-
 .../spark/sql/catalyst/errors/package.scala     |   7 +-
 .../expressions/aggregate/interfaces.scala      |  37 ++++-
 .../catalyst/expressions/namedExpressions.scala |   2 +-
 .../sql/catalyst/optimizer/Optimizer.scala      |  14 +-
 .../spark/sql/catalyst/planning/patterns.scala  |  73 +++++++++
 .../spark/sql/catalyst/plans/PlanTest.scala     |   3 +
 .../spark/sql/execution/QueryExecution.scala    |  24 ++-
 .../apache/spark/sql/execution/SparkPlan.scala  |   7 +
 .../spark/sql/execution/SparkPlanner.scala      |   4 +-
 .../spark/sql/execution/SparkStrategies.scala   |  92 ++++-------
 .../org/apache/spark/sql/execution/Window.scala |   2 +-
 .../aggregate/TungstenAggregationIterator.scala |   4 +-
 .../spark/sql/execution/aggregate/utils.scala   | 121 ++++++++++++---
 .../streaming/IncrementalExecution.scala        |  72 +++++++++
 .../execution/streaming/StatefulAggregate.scala | 119 +++++++++++++++
 .../execution/streaming/StreamExecution.scala   |  12 +-
 .../spark/sql/execution/streaming/memory.scala  |   4 +-
 .../state/HDFSBackedStateStoreProvider.scala    |  36 +++--
 .../execution/streaming/state/StateStore.scala  |  19 ++-
 .../streaming/state/StateStoreConf.scala        |   4 +-
 .../streaming/state/StateStoreRDD.scala         |  17 +--
 .../sql/execution/streaming/state/package.scala |  21 ++-
 .../apache/spark/sql/execution/subquery.scala   |  11 +-
 .../spark/sql/expressions/Aggregator.scala      |   4 +-
 .../spark/sql/internal/SessionState.scala       |  16 +-
 .../scala/org/apache/spark/sql/StreamTest.scala |  36 +++--
 .../spark/sql/execution/SparkPlanTest.scala     |  10 +-
 .../streaming/state/StateStoreRDDSuite.scala    | 152 +++++++++++--------
 .../streaming/state/StateStoreSuite.scala       |  61 ++++----
 .../streaming/StreamingAggregationSuite.scala   | 132 ++++++++++++++++
 .../spark/sql/hive/HiveSessionState.scala       |   5 +-
 33 files changed, 827 insertions(+), 305 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d82ee3a..05e2b9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -336,6 +336,11 @@ class Analyzer(
                 Last(ifExpr(expr), Literal(true))
               case a: AggregateFunction =>
                 a.withNewChildren(a.children.map(ifExpr))
+            }.transform {
+              // We are duplicating aggregates that are now computing a different value for each
+              // pivot value.
+              // TODO: Don't construct the physical container until after analysis.
+              case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
             }
             if (filteredAggregate.fastEquals(aggregate)) {
               throw new AnalysisException(
@@ -1153,11 +1158,11 @@ class Analyzer(
 
           // Extract Windowed AggregateExpression
           case we @ WindowExpression(
-              AggregateExpression(function, mode, isDistinct),
+              ae @ AggregateExpression(function, _, _, _),
               spec: WindowSpecDefinition) =>
             val newChildren = function.children.map(extractExpr)
             val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
-            val newAgg = AggregateExpression(newFunction, mode, isDistinct)
+            val newAgg = ae.copy(aggregateFunction = newFunction)
             seenWindowAggregates += newAgg
             WindowExpression(newAgg, spec)
 

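The pivot change above gives every duplicated aggregate its own `resultId`, so the copies made for different pivot values no longer share an output attribute. A hypothetical plain-Scala sketch of that idea follows; `PivotAgg`, `PivotSketch`, and the string-valued `filter` field are illustrative stand-ins, not Catalyst classes.

```scala
// Hypothetical stand-in for an aggregate container with a result id.
final case class PivotAgg(function: String, filter: String, resultId: Long)

object PivotSketch {
  private var lastId = 0L

  // Plays the role of NamedExpression.newExprId: every call returns a new id.
  def newId(): Long = { lastId += 1; lastId }

  // Specialise the template once per pivot value, minting a fresh result id each time.
  def expand(template: PivotAgg, pivotValues: Seq[String]): Seq[PivotAgg] =
    pivotValues.map(v => template.copy(filter = s"pivot = $v", resultId = newId()))
}
```

Without the fresh ids, every expanded copy would keep `template.resultId` and the pivoted columns would all refer to the same result attribute.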
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 1d1e892..4880502 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -76,7 +76,7 @@ trait CheckAnalysis {
           case g: GroupingID =>
             failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup")
 
-          case w @ WindowExpression(AggregateExpression(_, _, true), _) =>
+          case w @ WindowExpression(AggregateExpression(_, _, true, _), _) =>
             failAnalysis(s"Distinct window functions are not supported: $w")
 
           case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order,

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
index 0d44d1d..0420b4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
@@ -25,15 +25,18 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
 package object errors {
 
   class TreeNodeException[TreeType <: TreeNode[_]](
-      tree: TreeType, msg: String, cause: Throwable)
+      @transient val tree: TreeType,
+      msg: String,
+      cause: Throwable)
     extends Exception(msg, cause) {
 
+    val treeString = tree.toString
+
     // Yes, this is the same as a default parameter, but... those don't seem to work with SBT
     // external project dependencies for some reason.
     def this(tree: TreeType, msg: String) = this(tree, msg, null)
 
     override def getMessage: String = {
-      val treeString = tree.toString
       s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree"
     }
   }

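The `TreeNodeException` change marks the tree `@transient`, so Java serialization skips it, and computes `treeString` in the constructor, so a string form of the tree is captured while the tree is still available. The plain-Scala sketch below illustrates that pattern with hypothetical classes (`NotSerializableTree`, `SketchTreeException`); unlike the diff, whose `getMessage` still interpolates `$tree`, the sketch builds its message from the cached string so the effect is visible after a round trip.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical tree type that does not implement Serializable.
class NotSerializableTree(val label: String) {
  override def toString: String = s"Tree($label)"
}

// The tree is skipped by Java serialization; its string form is computed eagerly and kept.
class SketchTreeException(@transient val tree: NotSerializableTree, msg: String)
  extends Exception(msg) {
  val treeString: String = tree.toString
  override def getMessage: String = s"${super.getMessage}, tree: $treeString"
}

object TransientTreeSketch {
  // Serialize and deserialize an object in memory.
  def roundTrip[T](obj: T): T = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    try out.writeObject(obj) finally out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    try in.readObject().asInstanceOf[T] finally in.close()
  }

  def main(args: Array[String]): Unit = {
    val copy = roundTrip(new SketchTreeException(new NotSerializableTree("root"), "boom"))
    println(copy.tree)       // null: the transient field was not serialized
    println(copy.getMessage) // boom, tree: Tree(root)
  }
}
```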
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index ff3064a..d31ccf9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.types._
@@ -66,6 +67,19 @@ private[sql] case object NoOp extends Expression with Unevaluable {
   override def children: Seq[Expression] = Nil
 }
 
+object AggregateExpression {
+  def apply(
+      aggregateFunction: AggregateFunction,
+      mode: AggregateMode,
+      isDistinct: Boolean): AggregateExpression = {
+    AggregateExpression(
+      aggregateFunction,
+      mode,
+      isDistinct,
+      NamedExpression.newExprId)
+  }
+}
+
 /**
  * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field
  * (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
@@ -73,10 +87,31 @@ private[sql] case object NoOp extends Expression with Unevaluable {
 private[sql] case class AggregateExpression(
     aggregateFunction: AggregateFunction,
     mode: AggregateMode,
-    isDistinct: Boolean)
+    isDistinct: Boolean,
+    resultId: ExprId)
   extends Expression
   with Unevaluable {
 
+  lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) {
+    AttributeReference(
+      aggregateFunction.toString,
+      aggregateFunction.dataType,
+      aggregateFunction.nullable)(exprId = resultId)
+  } else {
+    // This is a bit of a hack.  Really we should not be constructing this container and reasoning
+    // about datatypes / aggregation mode until after we have finished analysis and made it to
+    // planning.
+    UnresolvedAttribute(aggregateFunction.toString)
+  }
+
+  // We compute the same thing regardless of our final result.
+  override lazy val canonicalized: Expression =
+    AggregateExpression(
+      aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+      mode,
+      isDistinct,
+      ExprId(0))
+
   override def children: Seq[Expression] = aggregateFunction :: Nil
   override def dataType: DataType = aggregateFunction.dataType
   override def foldable: Boolean = false

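The new `resultId` field and its companion `apply` mean that two separately built containers for the same aggregate get different ids, while `canonicalized` replaces the id with `ExprId(0)` so they still compare as computing the same thing (the `PlanTest` change in this commit normalizes `resultId` the same way). The sketch below mirrors that shape with hypothetical plain-Scala stand-ins (`SketchExprId`, `SketchAggExpr`, `semanticEquals` here); it is not Catalyst code.

```scala
import java.util.concurrent.atomic.AtomicLong

// Hypothetical stand-in for Catalyst's ExprId.
final case class SketchExprId(id: Long)

object SketchExprId {
  private val counter = new AtomicLong(1L)
  // Plays the role of NamedExpression.newExprId: a fresh id on every call.
  def fresh(): SketchExprId = SketchExprId(counter.getAndIncrement())
}

// Hypothetical, simplified AggregateExpression carrying its own result id.
final case class SketchAggExpr(function: String, isDistinct: Boolean, resultId: SketchExprId) {
  // Erase the result id so that containers computing the same aggregate compare equal.
  lazy val canonicalized: SketchAggExpr = copy(resultId = SketchExprId(0L))
  def semanticEquals(other: SketchAggExpr): Boolean = canonicalized == other.canonicalized
}

object SketchAggExpr {
  // Overload without an id, analogous to the new companion apply in the diff.
  def apply(function: String, isDistinct: Boolean): SketchAggExpr =
    SketchAggExpr(function, isDistinct, SketchExprId.fresh())
}
```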
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 262582c..2307122 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -329,7 +329,7 @@ case class PrettyAttribute(
   override def withName(newName: String): Attribute = throw new UnsupportedOperationException
   override def qualifier: Option[String] = throw new UnsupportedOperationException
   override def exprId: ExprId = throw new UnsupportedOperationException
-  override def nullable: Boolean = throw new UnsupportedOperationException
+  override def nullable: Boolean = true
 }
 
 object VirtualColumn {

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a7a948e..326933e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -534,7 +534,7 @@ object NullPropagation extends Rule[LogicalPlan] {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
-      case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
+      case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
         Cast(Literal(0L), e.dataType)
       case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
       case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -547,9 +547,9 @@ object NullPropagation extends Rule[LogicalPlan] {
         Literal.create(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
-      case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
+      case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
         // This rule should be only triggered when isDistinct field is false.
-        AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
+        ae.copy(aggregateFunction = Count(Literal(1)))
 
       // For Coalesce, remove null literals.
       case e @ Coalesce(children) =>
@@ -1225,13 +1225,13 @@ object DecimalAggregates extends Rule[LogicalPlan] {
   private val MAX_DOUBLE_DIGITS = 15
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+    case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
       if prec + 10 <= MAX_LONG_DIGITS =>
-      MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale)
+      MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
 
-    case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+    case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
       if prec + 4 <= MAX_DOUBLE_DIGITS =>
-      val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct)
+      val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
       Cast(
         Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
         DecimalType(prec + 4, scale + 4))

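The optimizer rules now rewrite an `AggregateExpression` with `ae.copy(aggregateFunction = ...)` instead of calling the constructor again. A likely motivation (an inference, not stated in the commit) is that `copy` keeps the existing `resultId`, so attributes that already refer to the aggregate's result stay valid, whereas the three-argument companion `apply` would mint a new id. The hypothetical sketch below shows the difference with plain case classes; `SketchAgg`, `rebuild`, and `rewrite` are illustrative names.

```scala
// Hypothetical, simplified aggregate container with a result id.
final case class SketchAgg(function: String, mode: String, isDistinct: Boolean, resultId: Long)

object CopyVsRebuildSketch {
  private var lastId = 0L
  private def freshId(): Long = { lastId += 1; lastId }

  // Like the removed lines: rebuilding from parts goes through a constructor that mints a new id.
  def rebuild(ae: SketchAgg, newFunction: String): SketchAgg =
    SketchAgg(newFunction, ae.mode, ae.isDistinct, freshId())

  // Like the added lines: copy replaces only the function and keeps resultId.
  def rewrite(ae: SketchAgg, newFunction: String): SketchAgg =
    ae.copy(function = newFunction)
}
```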
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 9c92707..28d2c44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.IntegerType
@@ -216,3 +217,75 @@ object IntegerIndex {
     case _ => None
   }
 }
+
+/**
+ * An extractor used when planning the physical execution of an aggregation. Compared with a logical
+ * aggregation, the following transformations are performed:
+ *  - Unnamed grouping expressions are named so that they can be referred to across phases of
+ *    aggregation
+ *  - Aggregations that appear multiple times are deduplicated.
+ *  - The compution of the aggregations themselves is separated from the final result. For example,
+ *    the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final
+ *    computation that computes `count.resultAttribute + 1`.
+ */
+object PhysicalAggregation {
+  // groupingExpressions, aggregateExpressions, resultExpressions, child
+  type ReturnType =
+    (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
+
+  def unapply(a: Any): Option[ReturnType] = a match {
+    case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+      // A single aggregate expression might appear multiple times in resultExpressions.
+      // In order to avoid evaluating an individual aggregate function multiple times, we'll
+      // build a set of the distinct aggregate expressions and build a function which can
+      // be used to re-write expressions so that they reference the single copy of the
+      // aggregate function which actually gets computed.
+      val aggregateExpressions = resultExpressions.flatMap { expr =>
+        expr.collect {
+          case agg: AggregateExpression => agg
+        }
+      }.distinct
+
+      val namedGroupingExpressions = groupingExpressions.map {
+        case ne: NamedExpression => ne -> ne
+        // If the expression is not a NamedExpressions, we add an alias.
+        // So, when we generate the result of the operator, the Aggregate Operator
+        // can directly get the Seq of attributes representing the grouping expressions.
+        case other =>
+          val withAlias = Alias(other, other.toString)()
+          other -> withAlias
+      }
+      val groupExpressionMap = namedGroupingExpressions.toMap
+
+      // The original `resultExpressions` are a set of expressions which may reference
+      // aggregate expressions, grouping column values, and constants. When aggregate operator
+      // emits output rows, we will use `resultExpressions` to generate an output projection
+      // which takes the grouping columns and final aggregate result buffer as input.
+      // Thus, we must re-write the result expressions so that their attributes match up with
+      // the attributes of the final result projection's input row:
+      val rewrittenResultExpressions = resultExpressions.map { expr =>
+        expr.transformDown {
+          case ae: AggregateExpression =>
+            // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
+            // so replace each aggregate expression by its corresponding attribute in the set:
+            ae.resultAttribute
+          case expression =>
+            // Since we're using `namedGroupingAttributes` to extract the grouping key
+            // columns, we need to replace grouping key expressions with their corresponding
+            // attributes. We do not rely on the equality check at here since attributes may
+            // differ cosmetically. Instead, we use semanticEquals.
+            groupExpressionMap.collectFirst {
+              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+            }.getOrElse(expression)
+        }.asInstanceOf[NamedExpression]
+      }
+
+      Some((
+        namedGroupingExpressions.map(_._2),
+        aggregateExpressions,
+        rewrittenResultExpressions,
+        child))
+
+    case _ => None
+  }
+}

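The second and third transformations listed in the `PhysicalAggregation` scaladoc can be illustrated with a toy expression tree. The sketch below is hypothetical plain Scala (`Expr`, `Attr`, `Lit`, `Add`, `Agg`, and `PhysicalAggregationSketch.split` are invented for the example): it collects the distinct aggregates from the result expressions and rewrites each aggregate into a reference to its `resultAttribute`. Grouping-expression aliasing is left out, and, as with `.distinct` in the diff, de-duplication is by value equality, which here includes the result id.

```scala
// Hypothetical toy expression tree.
sealed trait Expr
final case class Attr(name: String, id: Long) extends Expr
final case class Lit(value: Int) extends Expr
final case class Add(left: Expr, right: Expr) extends Expr
final case class Agg(function: String, arg: Expr, resultId: Long) extends Expr {
  // Attribute that downstream operators use to read this aggregate's result.
  def resultAttribute: Attr = Attr(function, resultId)
}

object PhysicalAggregationSketch {
  // Collect every aggregate that appears anywhere in an expression.
  private def collectAggs(e: Expr): Seq[Agg] = e match {
    case a: Agg    => Seq(a)
    case Add(l, r) => collectAggs(l) ++ collectAggs(r)
    case _         => Nil
  }

  // Replace each aggregate with a reference to its result attribute.
  private def rewrite(e: Expr): Expr = e match {
    case a: Agg    => a.resultAttribute
    case Add(l, r) => Add(rewrite(l), rewrite(r))
    case other     => other
  }

  // Returns (distinct aggregates to compute, rewritten result expressions).
  def split(resultExpressions: Seq[Expr]): (Seq[Agg], Seq[Expr]) =
    (resultExpressions.flatMap(collectAggs).distinct, resultExpressions.map(rewrite))
}
```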
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index aa5d433..7191936 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
 import org.apache.spark.sql.catalyst.util._
 
@@ -38,6 +39,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
         AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
       case a: Alias =>
         Alias(a.child, a.name)(exprId = ExprId(0))
+      case ae: AggregateExpression =>
+        ae.copy(resultId = ExprId(0))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 912b84a..4843553 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -21,6 +21,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
 
 /**
  * The primary workflow for executing relational queries using Spark.  Designed to allow easy
@@ -31,6 +33,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
  */
 class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
 
+  // TODO: Move the planner an optimizer into here from SessionState.
+  protected def planner = sqlContext.sessionState.planner
+
   def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch {
     case e: AnalysisException =>
       val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
@@ -49,16 +54,31 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
 
   lazy val sparkPlan: SparkPlan = {
     SQLContext.setActive(sqlContext)
-    sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next()
+    planner.plan(ReturnAnswer(optimizedPlan)).next()
   }
 
   // executedPlan should not be used to initialize any SparkPlan. It should be
   // only used for execution.
-  lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan)
+  lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
 
   /** Internal version of the RDD. Avoids copies and has no schema */
   lazy val toRdd: RDD[InternalRow] = executedPlan.execute()
 
+  /**
+   * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
+   * row format conversions as needed.
+   */
+  protected def prepareForExecution(plan: SparkPlan): SparkPlan = {
+    preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+  }
+
+  /** A sequence of rules that will be applied in order to the physical plan before execution. */
+  protected def preparations: Seq[Rule[SparkPlan]] = Seq(
+    PlanSubqueries(sqlContext),
+    EnsureRequirements(sqlContext.conf),
+    CollapseCodegenStages(sqlContext.conf),
+    ReuseExchange(sqlContext.conf))
+
   protected def stringOrError[A](f: => A): String =
     try f.toString catch { case e: Throwable => e.toString }
 

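Turning `prepareForExecution` into a fold over an overridable `preparations` sequence lets a subclass of `QueryExecution` change which physical rules run without touching `SessionState`. The hypothetical sketch below models a plan as a list of operator names and a rule as a function on it; `PreparationsSketch`, `Execution`, `CustomExecution`, and the `"ExtraStep"` rule are invented for illustration and are not Spark classes.

```scala
object PreparationsSketch {
  // Hypothetical model: a physical plan is a list of operator names.
  type Plan = List[String]
  type PrepRule = Plan => Plan

  class Execution {
    // Rules applied, in order, to the plan before it runs.
    protected def preparations: Seq[PrepRule] = Seq(
      (plan: Plan) => plan :+ "Exchange",         // loose stand-in for EnsureRequirements
      (plan: Plan) => plan :+ "WholeStageCodegen" // loose stand-in for CollapseCodegenStages
    )

    // Same fold as the diff: thread the plan through every rule.
    def prepareForExecution(plan: Plan): Plan =
      preparations.foldLeft(plan) { case (sp, rule) => rule(sp) }
  }

  // A subclass can prepend its own step without changing the base class.
  class CustomExecution extends Execution {
    override protected def preparations: Seq[PrepRule] =
      ((plan: Plan) => plan :+ "ExtraStep") +: super.preparations
  }
}
```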
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 010ed7f..b1b3d4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -379,6 +379,13 @@ private[sql] trait LeafNode extends SparkPlan {
   override def producedAttributes: AttributeSet = outputSet
 }
 
+object UnaryNode {
+  def unapply(a: Any): Option[(SparkPlan, SparkPlan)] = a match {
+    case s: SparkPlan if s.children.size == 1 => Some((s, s.children.head))
+    case _ => None
+  }
+}
+
 private[sql] trait UnaryNode extends SparkPlan {
   def child: SparkPlan
 

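The new `UnaryNode` extractor matches any `SparkPlan` with exactly one child and returns the node together with that child, so a rule can handle single-child operators without naming their concrete classes. The sketch below reproduces the same `unapply` shape over a hypothetical toy tree (`Node`, `Scan`, `Filter`, `Join`, `Unary`, and `UnaryNodeSketch` are illustrative, not Spark types).

```scala
// Hypothetical minimal plan tree; case classes supply productPrefix.
sealed abstract class Node extends Product { def children: Seq[Node] }
final case class Scan(table: String) extends Node { def children: Seq[Node] = Nil }
final case class Filter(condition: String, child: Node) extends Node { def children: Seq[Node] = Seq(child) }
final case class Join(left: Node, right: Node) extends Node { def children: Seq[Node] = Seq(left, right) }

// Same shape as the UnaryNode object in the diff: matches any node with exactly one child.
object Unary {
  def unapply(a: Any): Option[(Node, Node)] = a match {
    case n: Node if n.children.size == 1 => Some((n, n.children.head))
    case _                               => None
  }
}

object UnaryNodeSketch {
  // Describe the root of a plan, using the extractor rather than concrete classes.
  def describe(plan: Node): String = plan match {
    case Unary(parent, child) => s"${parent.productPrefix} over ${child.productPrefix}"
    case other                => s"${other.productPrefix} (not unary)"
  }
}
```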
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 9da2c74..ac8072f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -26,13 +26,13 @@ import org.apache.spark.sql.internal.SQLConf
 class SparkPlanner(
     val sparkContext: SparkContext,
     val conf: SQLConf,
-    val experimentalMethods: ExperimentalMethods)
+    val extraStrategies: Seq[Strategy])
   extends SparkStrategies {
 
   def numPartitions: Int = conf.numShufflePartitions
 
   def strategies: Seq[Strategy] =
-    experimentalMethods.extraStrategies ++ (
+      extraStrategies ++ (
       FileSourceStrategy ::
       DataSourceStrategy ::
       DDLStrategy ::

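Passing `extraStrategies: Seq[Strategy]` to `SparkPlanner`, instead of the whole `ExperimentalMethods`, lets each caller decide which additional strategies go in front of the built-in list; the `StatefulAggregationStrategy` scaladoc describes it as being injected on demand when planning a streaming query. The following hypothetical sketch models strategies as functions from a logical-plan string to candidate physical plans; `StrategySketch`, `Planner`, and the plan strings are invented for illustration.

```scala
object StrategySketch {
  // Hypothetical model: logical and physical plans are both plain strings.
  type Strategy = String => Seq[String]

  private val aggregation: Strategy =
    logical => if (logical.startsWith("Aggregate")) Seq("HashAggregate") else Nil

  // An extra strategy a caller might inject, in the spirit of StatefulAggregationStrategy.
  val statefulAggregation: Strategy =
    logical => if (logical.startsWith("Aggregate")) Seq("StateStoreSave(HashAggregate)") else Nil

  class Planner(extraStrategies: Seq[Strategy]) {
    // Extra strategies are consulted before the built-in ones, as in the diff.
    def strategies: Seq[Strategy] = extraStrategies :+ aggregation

    // Use the first strategy that produces a physical plan.
    def plan(logical: String): Option[String] =
      strategies.iterator.map(strategy => strategy(logical)).find(_.nonEmpty).map(_.head)
  }
}
```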
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7a2e2b7..5bcc172 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.Strategy
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -204,28 +203,32 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   }
 
   /**
+   * Used to plan aggregation queries that are computed incrementally as part of a
+   * [[org.apache.spark.sql.ContinuousQuery]]. Currently this rule is injected into the planner
+   * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]]
+   */
+  object StatefulAggregationStrategy extends Strategy {
+    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case PhysicalAggregation(
+        namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) =>
+
+        aggregate.Utils.planStreamingAggregation(
+          namedGroupingExpressions,
+          aggregateExpressions,
+          rewrittenResultExpressions,
+          planLater(child))
+
+      case _ => Nil
+    }
+  }
+
+  /**
    * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
    */
   object Aggregation extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
-        // A single aggregate expression might appear multiple times in resultExpressions.
-        // In order to avoid evaluating an individual aggregate function multiple times, we'll
-        // build a set of the distinct aggregate expressions and build a function which can
-        // be used to re-write expressions so that they reference the single copy of the
-        // aggregate function which actually gets computed.
-        val aggregateExpressions = resultExpressions.flatMap { expr =>
-          expr.collect {
-            case agg: AggregateExpression => agg
-          }
-        }.distinct
-        // For those distinct aggregate expressions, we create a map from the
-        // aggregate function to the corresponding attribute of the function.
-        val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
-          val aggregateFunction = agg.aggregateFunction
-          val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
-          (aggregateFunction, agg.isDistinct) -> attribute
-        }.toMap
+      case PhysicalAggregation(
+          groupingExpressions, aggregateExpressions, resultExpressions, child) =>
 
         val (functionsWithDistinct, functionsWithoutDistinct) =
           aggregateExpressions.partition(_.isDistinct)
@@ -233,41 +236,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           // This is a sanity check. We should not reach here when we have multiple distinct
           // column sets. Our MultipleDistinctRewriter should take care this case.
           sys.error("You hit a query analyzer bug. Please report your query to 
" +
-            "Spark user mailing list.")
-        }
-
-        val namedGroupingExpressions = groupingExpressions.map {
-          case ne: NamedExpression => ne -> ne
-          // If the expression is not a NamedExpressions, we add an alias.
-          // So, when we generate the result of the operator, the Aggregate Operator
-          // can directly get the Seq of attributes representing the grouping expressions.
-          case other =>
-            val withAlias = Alias(other, other.toString)()
-            other -> withAlias
-        }
-        val groupExpressionMap = namedGroupingExpressions.toMap
-
-        // The original `resultExpressions` are a set of expressions which may reference
-        // aggregate expressions, grouping column values, and constants. When aggregate operator
-        // emits output rows, we will use `resultExpressions` to generate an output projection
-        // which takes the grouping columns and final aggregate result buffer as input.
-        // Thus, we must re-write the result expressions so that their attributes match up with
-        // the attributes of the final result projection's input row:
-        val rewrittenResultExpressions = resultExpressions.map { expr =>
-          expr.transformDown {
-            case AggregateExpression(aggregateFunction, _, isDistinct) =>
-              // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
-              // so replace each aggregate expression by its corresponding attribute in the set:
-              aggregateFunctionToAttribute(aggregateFunction, isDistinct)
-            case expression =>
-              // Since we're using `namedGroupingAttributes` to extract the grouping key
-              // columns, we need to replace grouping key expressions with their corresponding
-              // attributes. We do not rely on the equality check at here since attributes may
-              // differ cosmetically. Instead, we use semanticEquals.
-              groupExpressionMap.collectFirst {
-                case (expr, ne) if expr semanticEquals expression => ne.toAttribute
-              }.getOrElse(expression)
-          }.asInstanceOf[NamedExpression]
+              "Spark user mailing list.")
         }
 
         val aggregateOperator =
@@ -277,26 +246,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
                 "aggregate functions which don't support partial aggregation.")
             } else {
               aggregate.Utils.planAggregateWithoutPartial(
-                namedGroupingExpressions.map(_._2),
+                groupingExpressions,
                 aggregateExpressions,
-                aggregateFunctionToAttribute,
-                rewrittenResultExpressions,
+                resultExpressions,
                 planLater(child))
             }
           } else if (functionsWithDistinct.isEmpty) {
             aggregate.Utils.planAggregateWithoutDistinct(
-              namedGroupingExpressions.map(_._2),
+              groupingExpressions,
               aggregateExpressions,
-              aggregateFunctionToAttribute,
-              rewrittenResultExpressions,
+              resultExpressions,
               planLater(child))
           } else {
             aggregate.Utils.planAggregateWithOneDistinct(
-              namedGroupingExpressions.map(_._2),
+              groupingExpressions,
               functionsWithDistinct,
               functionsWithoutDistinct,
-              aggregateFunctionToAttribute,
-              rewrittenResultExpressions,
+              resultExpressions,
               planLater(child))
           }
 

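The removed planner code, now factored into the `PhysicalAggregation` pattern, collects the distinct aggregate expressions referenced by the result expressions, assigns each one a single output attribute, and rewrites the result expressions to point at those attributes so that each aggregate is computed only once. A minimal sketch of that de-duplication over a toy expression tree follows. `Expr`, `Col`, `Lit`, `Add`, `Agg`, `AttrRef`, and `DedupeSketch` are hypothetical names rather than Catalyst classes, and plain structural equality stands in for Catalyst's `semanticEquals`.

```scala
// Toy expression tree (hypothetical stand-ins for Catalyst expressions).
sealed trait Expr
case class Col(name: String) extends Expr
case class Lit(value: Int) extends Expr
case class Add(left: Expr, right: Expr) extends Expr
// Keyed on (function, child, isDistinct), like the old (AggregateFunction, Boolean) map key.
case class Agg(function: String, child: Expr, isDistinct: Boolean) extends Expr
case class AttrRef(name: String) extends Expr

object DedupeSketch {
  // Collect every aggregate node in an expression, in pre-order.
  def collectAggs(e: Expr): Seq[Agg] = e match {
    case a: Agg    => Seq(a)
    case Add(l, r) => collectAggs(l) ++ collectAggs(r)
    case _         => Seq.empty
  }

  // Replace each aggregate node with the attribute assigned to it.
  def rewrite(e: Expr, attrs: Map[Agg, AttrRef]): Expr = e match {
    case a: Agg    => attrs(a)
    case Add(l, r) => Add(rewrite(l, attrs), rewrite(r, attrs))
    case other     => other
  }

  /** Returns the distinct aggregates and the result expressions rewritten to use them. */
  def dedupeAggregates(results: Seq[Expr]): (Seq[Agg], Seq[Expr]) = {
    val aggs = results.flatMap(collectAggs).distinct
    // One output attribute per distinct aggregate (attribute names are illustrative).
    val attrs = aggs.zipWithIndex.map { case (a, i) => a -> AttrRef(s"agg$i") }.toMap
    (aggs, results.map(rewrite(_, attrs)))
  }
}
```

With results `sum(x) + 1` and `sum(x)`, only one `sum(x)` is kept and both outputs refer to the same attribute.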
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 270c09a..7acf020 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -177,7 +177,7 @@ case class Window(
         case e @ WindowExpression(function, spec) =>
           val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
           function match {
-            case AggregateExpression(f, _, _) => collect("AGGREGATE", frame, e, f)
+            case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f)
             case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
             case f: OffsetWindowFunction => collect("OFFSET", frame, e, f)
             case f => sys.error(s"Unsupported window function: $f")

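Adding a field to `AggregateExpression` forces every positional extractor pattern, such as `AggregateExpression(f, _, _)`, to grow an extra wildcard, which is the whole change to `Window.scala`. A type pattern combined with field access does not depend on the field count. The sketch below illustrates the difference with a hypothetical four-field case class, `AggExpr`, whose field names are assumptions chosen to mirror the new container's shape.

```scala
// Hypothetical four-field container shaped like the new AggregateExpression.
case class AggExpr(function: String, mode: String, isDistinct: Boolean, resultId: Long)

object PatternSketch {
  // Positional pattern: lists every field, so it must change whenever a field is added.
  def positional(e: Any): Option[String] = e match {
    case AggExpr(f, _, _, _) => Some(f)
    case _                   => None
  }

  // Type pattern plus accessor: unaffected by later changes to the field count.
  def byType(e: Any): Option[String] = e match {
    case a: AggExpr => Some(a.function)
    case _          => None
  }
}
```

The positional form is more compact and is what the codebase uses here; the trade-off is exactly the churn visible in this diff.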
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 213bca9..ce504e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -242,9 +242,9 @@ class TungstenAggregationIterator(
     // Basically the value of the KVIterator returned by externalSorter
     // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it.
     val newExpressions = aggregateExpressions.map {
-      case agg @ AggregateExpression(_, Partial, _) =>
+      case agg @ AggregateExpression(_, Partial, _, _) =>
         agg.copy(mode = PartialMerge)
-      case agg @ AggregateExpression(_, Complete, _) =>
+      case agg @ AggregateExpression(_, Complete, _, _) =>
         agg.copy(mode = Final)
       case other => other
     }

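When `TungstenAggregationIterator` falls back to sorting, the sorted input already holds aggregation buffers, so `Partial` becomes `PartialMerge` and `Complete` becomes `Final`. Because the result attribute now lives inside the expression container, `copy(mode = ...)` carries the same id along, so attributes looked up after a mode change still match. A toy sketch of both points follows; `Mode`, `AggExprWithId`, and `ModeRewriteSketch` are hypothetical, and `resultAttribute` is modeled as a string derived from `resultId` rather than a real `AttributeReference`.

```scala
sealed trait Mode
case object Partial extends Mode
case object PartialMerge extends Mode
case object Final extends Mode
case object Complete extends Mode

// Hypothetical container: the result id travels with the expression through copy().
case class AggExprWithId(function: String, mode: Mode, isDistinct: Boolean, resultId: Long) {
  // Stand-in for an attribute reference derived from resultId.
  def resultAttribute: String = s"$function#$resultId"
}

object ModeRewriteSketch {
  // Inputs are already buffers, so modes that consume raw input become merge modes.
  def toBufferInputModes(exprs: Seq[AggExprWithId]): Seq[AggExprWithId] = exprs.map {
    case agg @ AggExprWithId(_, Partial, _, _)  => agg.copy(mode = PartialMerge)
    case agg @ AggExprWithId(_, Complete, _, _) => agg.copy(mode = Final)
    case other                                  => other
  }
}
```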
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 1e113cc..4682949 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.streaming.{StateStoreRestore, StateStoreSave}
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -29,15 +30,11 @@ object Utils {
   def planAggregateWithoutPartial(
       groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
-      aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
 
     val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
-    val completeAggregateAttributes = completeAggregateExpressions.map {
-      expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
-    }
-
+    val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
     SortBasedAggregate(
       requiredChildDistributionExpressions = Some(groupingExpressions),
       groupingExpressions = groupingExpressions,
@@ -83,7 +80,6 @@ object Utils {
   def planAggregateWithoutDistinct(
       groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
-      aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
     // Check if we can use TungstenAggregate.
@@ -111,9 +107,7 @@ object Utils {
     val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
     // The attributes of the final aggregation buffer, which is presented as input to the result
     // projection:
-    val finalAggregateAttributes = finalAggregateExpressions.map {
-      expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
-    }
+    val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
 
     val finalAggregate = createAggregate(
         requiredChildDistributionExpressions = Some(groupingAttributes),
@@ -131,7 +125,6 @@ object Utils {
       groupingExpressions: Seq[NamedExpression],
       functionsWithDistinct: Seq[AggregateExpression],
       functionsWithoutDistinct: Seq[AggregateExpression],
-      aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
 
@@ -151,9 +144,7 @@ object Utils {
     // 1. Create an Aggregate Operator for partial aggregations.
     val partialAggregate: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
-      val aggregateAttributes = aggregateExpressions.map {
-        expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
-      }
+      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
       // We will group by the original grouping expression, plus an additional expression for the
       // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
       // expressions will be [key, value].
       // expressions will be [key, value].
@@ -169,9 +160,7 @@ object Utils {
     // 2. Create an Aggregate Operator for partial merge aggregations.
     val partialMergeAggregate: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
-      val aggregateAttributes = aggregateExpressions.map {
-        expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
-      }
+      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
       createAggregate(
         requiredChildDistributionExpressions =
           Some(groupingAttributes ++ distinctAttributes),
@@ -190,7 +179,7 @@ object Utils {
       // Children of an AggregateFunction with DISTINCT keyword has already
       // been evaluated. At here, we need to replace original children
       // to AttributeReferences.
-      case agg @ AggregateExpression(aggregateFunction, mode, true) =>
+      case agg @ AggregateExpression(aggregateFunction, mode, true, _) =>
         aggregateFunction.transformDown(distinctColumnAttributeLookup)
           .asInstanceOf[AggregateFunction]
     }
@@ -199,9 +188,7 @@ object Utils {
       val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
       // The attributes of the final aggregation buffer, which is presented as input to the result
       // projection:
-      val mergeAggregateAttributes = mergeAggregateExpressions.map {
-        expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
-      }
+      val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute)
       val (distinctAggregateExpressions, distinctAggregateAttributes) =
         rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
           // We rewrite the aggregate function to a non-distinct aggregation because
@@ -211,7 +198,7 @@ object Utils {
           val expr = AggregateExpression(func, Partial, isDistinct = true)
           // Use original AggregationFunction to lookup attributes, which is used to build
           // aggregateFunctionToAttribute
-          val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true)
+          val attr = functionsWithDistinct(i).resultAttribute
           (expr, attr)
       }.unzip
 
@@ -232,9 +219,7 @@ object Utils {
       val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
       // The attributes of the final aggregation buffer, which is presented as input to the result
       // projection:
-      val finalAggregateAttributes = finalAggregateExpressions.map {
-        expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
-      }
+      val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
 
       val (distinctAggregateExpressions, distinctAggregateAttributes) =
         rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
@@ -245,7 +230,7 @@ object Utils {
           val expr = AggregateExpression(func, Final, isDistinct = true)
           // Use original AggregationFunction to lookup attributes, which is used to build
           // aggregateFunctionToAttribute
-          val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true)
+          val attr = functionsWithDistinct(i).resultAttribute
           (expr, attr)
       }.unzip
 
@@ -261,4 +246,90 @@ object Utils {
 
     finalAndCompleteAggregate :: Nil
   }
+
+  /**
+   * Plans a streaming aggregation using the following progression:
+   *  - Partial Aggregation
+   *  - Shuffle
+   *  - Partial Merge (now there is at most 1 tuple per group)
+   *  - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
+   *  - Partial Merge (now there is at most 1 tuple per group)
+   *  - StateStoreSave (saves the tuple for the next batch)
+   *  - Complete (output the current result of the aggregation)
+   */
+  def planStreamingAggregation(
+      groupingExpressions: Seq[NamedExpression],
+      functionsWithoutDistinct: Seq[AggregateExpression],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
+
+    val partialAggregate: SparkPlan = {
+      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+      // Group by the original grouping expressions only: this path receives just the
+      // non-DISTINCT aggregate functions, so no extra DISTINCT column is added.
+      createAggregate(
+        groupingExpressions = groupingExpressions,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        resultExpressions = groupingAttributes ++
+            aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+        child = child)
+    }
+
+    val partialMerged1: SparkPlan = {
+      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+      createAggregate(
+        requiredChildDistributionExpressions =
+            Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        initialInputBufferOffset = groupingAttributes.length,
+        resultExpressions = groupingAttributes ++
+            aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+        child = partialAggregate)
+    }
+
+    val restored = StateStoreRestore(groupingAttributes, None, partialMerged1)
+
+    val partialMerged2: SparkPlan = {
+      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+      createAggregate(
+        requiredChildDistributionExpressions =
+            Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        initialInputBufferOffset = groupingAttributes.length,
+        resultExpressions = groupingAttributes ++
+            aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+        child = restored)
+    }
+
+    val saved = StateStoreSave(groupingAttributes, None, partialMerged2)
+
+    val finalAndCompleteAggregate: SparkPlan = {
+      val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
+      // The attributes of the final aggregation buffer, which is presented as input to the result
+      // projection:
+      val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
+
+      createAggregate(
+        requiredChildDistributionExpressions = Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
+        aggregateExpressions = finalAggregateExpressions,
+        aggregateAttributes = finalAggregateAttributes,
+        initialInputBufferOffset = groupingAttributes.length,
+        resultExpressions = resultExpressions,
+        child = saved)
+    }
+
+    finalAndCompleteAggregate :: Nil
+  }
 }

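The scaladoc of `planStreamingAggregation` lists the operator progression. The sketch below simulates that progression for a running per-key sum using plain Scala collections, with an immutable `Map` standing in for the versioned `StateStore`. `StreamingAggSketch` and `runBatch` are hypothetical names; partitions are modeled as `Seq`s and the shuffle as a `groupBy` on the key. For a sum, the aggregation buffer and the final value coincide, which keeps the "Complete" step trivial here.

```scala
object StreamingAggSketch {
  type Key = String

  /**
   * Runs one micro-batch of a running sum per key.
   * @param partitions input rows per partition
   * @param state      per-key buffers saved by the previous batch
   * @return (output rows of this batch, state to save for the next batch)
   */
  def runBatch(
      partitions: Seq[Seq[(Key, Long)]],
      state: Map[Key, Long]): (Map[Key, Long], Map[Key, Long]) = {
    // 1. Partial aggregation: one buffer per key within each partition.
    val partial: Seq[(Key, Long)] = partitions.flatMap { rows =>
      rows.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }
    }
    // 2. Shuffle + 3. partial merge: now at most one buffer per key.
    val merged1: Map[Key, Long] =
      partial.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }
    // 4. Restore: emit this batch's buffer plus the saved buffer, if any.
    val restored: Seq[(Key, Long)] = merged1.toSeq.flatMap { case (k, v) =>
      (k -> v) +: state.get(k).map(old => k -> old).toSeq
    }
    // 5. Partial merge again: one buffer per key.
    val merged2: Map[Key, Long] =
      restored.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }
    // 6. Save: keys not seen in this batch keep their previous value.
    val newState = state ++ merged2
    // 7. Complete: only keys touched by this batch are emitted.
    (merged2, newState)
  }
}
```

Running two batches shows the state carrying a key's sum forward while untouched keys stay in the store.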
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
new file mode 100644
index 0000000..aaced49
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -0,0 +1,72 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryNode}
+
+/**
+ * A variant of [[QueryExecution]] that allows the given [[LogicalPlan]] to be executed
+ * incrementally, possibly preserving state between executions.
+ */
+class IncrementalExecution(
+    ctx: SQLContext,
+    logicalPlan: LogicalPlan,
+    checkpointLocation: String,
+    currentBatchId: Long) extends QueryExecution(ctx, logicalPlan) {
+
+  // TODO: make this always part of planning.
+  val stateStrategy = sqlContext.sessionState.planner.StatefulAggregationStrategy :: Nil
+
+  // Modified planner with stateful operations.
+  override def planner: SparkPlanner =
+    new SparkPlanner(
+      sqlContext.sparkContext,
+      sqlContext.conf,
+      stateStrategy)
+
+  /**
+   * Records the current id for a given stateful operator in the query plan as the `state`
+   * preparation walks the query plan.
+   */
+  private var operatorId = 0
+
+  /** Locates save/restore pairs surrounding aggregation. */
+  val state = new Rule[SparkPlan] {
+    override def apply(plan: SparkPlan): SparkPlan = plan transform {
+      case StateStoreSave(keys, None,
+             UnaryNode(agg,
+               StateStoreRestore(keys2, None, child))) =>
+        val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId - 1)
+        operatorId += 1
+
+        StateStoreSave(
+          keys,
+          Some(stateId),
+          agg.withNewChildren(
+            StateStoreRestore(
+              keys,
+              Some(stateId),
+              child) :: Nil))
+    }
+  }
+
+  override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations
+}

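The `state` rule in `IncrementalExecution` finds each `StateStoreSave -> unary aggregate -> StateStoreRestore` chain whose state id is still `None`, assigns both ends the same id with an increasing operator number, and uses `currentBatchId - 1` as the batch id, as the original does. A sketch of the same rewrite on a toy plan tree follows. `Plan`, `Scan`, `Merge`, `Restore`, `Save`, `StateId`, and `StateRuleSketch` are hypothetical; `Merge` stands in for whatever unary node sits between restore and save, and the recursion is a hand-written top-down traversal in place of Spark's `transform`.

```scala
// Hypothetical toy physical plan; not Spark's SparkPlan.
case class StateId(checkpointLocation: String, operatorId: Long, batchId: Long)

sealed trait Plan
case class Scan(name: String) extends Plan
case class Merge(child: Plan) extends Plan
case class Restore(stateId: Option[StateId], child: Plan) extends Plan
case class Save(stateId: Option[StateId], child: Plan) extends Plan

class StateRuleSketch(checkpointLocation: String, currentBatchId: Long) {
  private var operatorId = 0L

  // Top-down rewrite: fill in a shared id for each Save -> Merge -> Restore chain.
  def apply(plan: Plan): Plan = plan match {
    case Save(None, Merge(Restore(None, child))) =>
      val id = StateId(checkpointLocation, operatorId, currentBatchId - 1)
      operatorId += 1
      Save(Some(id), Merge(Restore(Some(id), apply(child))))
    case Merge(child)       => Merge(apply(child))
    case Restore(id, child) => Restore(id, apply(child))
    case Save(id, child)    => Save(id, apply(child))
    case leaf: Scan         => leaf
  }
}
```

Because the outer chain is matched before its children are visited, nested chains receive operator ids in top-down order.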
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
new file mode 100644
index 0000000..5957747
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.execution.SparkPlan
+
+/** Used to identify the state store for a given operator. */
+case class OperatorStateId(
+    checkpointLocation: String,
+    operatorId: Long,
+    batchId: Long)
+
+/**
+ * An operator that saves or restores state from the [[StateStore]].  The [[OperatorStateId]] should
+ * be filled in by `prepareForExecution` in [[IncrementalExecution]].
+ */
+trait StatefulOperator extends SparkPlan {
+  def stateId: Option[OperatorStateId]
+
+  protected def getStateId: OperatorStateId = attachTree(this) {
+    stateId.getOrElse {
+      throw new IllegalStateException("State location not present for execution")
+    }
+  }
+}
+
+/**
+ * For each input tuple, the key is calculated and the value from the [[StateStore]] is added
+ * to the stream (in addition to the input tuple) if present.
+ */
+case class StateStoreRestore(
+    keyExpressions: Seq[Attribute],
+    stateId: Option[OperatorStateId],
+    child: SparkPlan) extends execution.UnaryNode with StatefulOperator {
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsWithStateStore(
+      getStateId.checkpointLocation,
+      operatorId = getStateId.operatorId,
+      storeVersion = getStateId.batchId,
+      keyExpressions.toStructType,
+      child.output.toStructType,
+      new StateStoreConf(sqlContext.conf),
+      Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
+        val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+        iter.flatMap { row =>
+          val key = getKey(row)
+          val savedState = store.get(key)
+          row +: savedState.toSeq
+        }
+    }
+  }
+  override def output: Seq[Attribute] = child.output
+}
+
+/**
+ * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]].
+ */
+case class StateStoreSave(
+    keyExpressions: Seq[Attribute],
+    stateId: Option[OperatorStateId],
+    child: SparkPlan) extends execution.UnaryNode with StatefulOperator {
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsWithStateStore(
+      getStateId.checkpointLocation,
+      operatorId = getStateId.operatorId,
+      storeVersion = getStateId.batchId,
+      keyExpressions.toStructType,
+      child.output.toStructType,
+      new StateStoreConf(sqlContext.conf),
+      Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
+        new Iterator[InternalRow] {
+          private[this] val baseIterator = iter
+          private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+
+          override def hasNext: Boolean = {
+            if (!baseIterator.hasNext) {
+              store.commit()
+              false
+            } else {
+              true
+            }
+          }
+
+          override def next(): InternalRow = {
+            val row = baseIterator.next().asInstanceOf[UnsafeRow]
+            val key = getKey(row)
+            store.put(key.copy(), row.copy())
+            row
+          }
+        }
+    }
+  }
+
+  override def output: Seq[Attribute] = child.output
+}

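`StateStoreRestore` emits each input row followed by the saved row for its key, if one exists, and `StateStoreSave` passes rows through unchanged while `put`-ting them into the store, committing once its input is exhausted. The sketch below mirrors both behaviours on plain iterators. `ToyStore` and `StatefulOpsSketch` are hypothetical, keys and values are plain `String`/`Long` instead of `UnsafeRow`s, and the `committed` flag is an addition of this sketch so that calling `hasNext` again after exhaustion does not commit a second time.

```scala
import scala.collection.mutable

// Hypothetical minimal store; the real StateStore works on UnsafeRow keys and values.
class ToyStore {
  val data = mutable.Map.empty[String, Long]
  var commits = 0
  def get(key: String): Option[Long] = data.get(key)
  def put(key: String, value: Long): Unit = data(key) = value
  def commit(): Unit = commits += 1
}

object StatefulOpsSketch {
  // Restore: emit each input row, followed by the saved row for its key if present.
  def restore(store: ToyStore, iter: Iterator[(String, Long)]): Iterator[(String, Long)] =
    iter.flatMap { case row @ (key, _) =>
      row +: store.get(key).map(saved => key -> saved).toSeq
    }

  // Save: put each row as it passes through, and commit once the input is exhausted.
  def save(store: ToyStore, iter: Iterator[(String, Long)]): Iterator[(String, Long)] =
    new Iterator[(String, Long)] {
      private var committed = false

      override def hasNext: Boolean = {
        if (iter.hasNext) true
        else {
          // Commit only on the first call that observes exhaustion.
          if (!committed) { store.commit(); committed = true }
          false
        }
      }

      override def next(): (String, Long) = {
        val row = iter.next()
        store.put(row._1, row._2)
        row
      }
    }
}
```

Tying the commit to iterator exhaustion means the store is only committed after every row of the partition has been written.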
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index c4e410d..511e30c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.util._
@@ -272,6 +273,8 @@ class StreamExecution(
   private def runBatch(): Unit = {
     val startTime = System.nanoTime()
 
+    // TODO: Move this to IncrementalExecution.
+
     // Request unprocessed data from all sources.
     val newData = availableOffsets.flatMap {
       case (source, available) if committedOffsets.get(source).map(_ < available).getOrElse(true) =>
@@ -305,13 +308,14 @@ class StreamExecution(
     }
 
     val optimizerStart = System.nanoTime()
-
-    lastExecution = new QueryExecution(sqlContext, newPlan)
-    val executedPlan = lastExecution.executedPlan
+    lastExecution =
+        new IncrementalExecution(sqlContext, newPlan, checkpointFile("state"), currentBatchId)
+    lastExecution.executedPlan
     val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000
     logDebug(s"Optimized batch in ${optimizerTime}ms")
 
-    val nextBatch = Dataset.ofRows(sqlContext, newPlan)
+    val nextBatch =
+      new Dataset(sqlContext, lastExecution, RowEncoder(lastExecution.analyzed.schema))
     sink.addBatch(currentBatchId - 1, nextBatch)
 
     awaitBatchLock.synchronized {

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 0f91e59..7d97f81 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -108,7 +108,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
  * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
  * tests and does not provide durability.
  */
-class MemorySink(schema: StructType) extends Sink with Logging {
+class MemorySink(val schema: StructType) extends Sink with Logging {
   /** An ordered list of batches that have been written to this [[Sink]]. */
   private val batches = new ArrayBuffer[Array[Row]]()
 
@@ -117,6 +117,8 @@ class MemorySink(schema: StructType) extends Sink with Logging {
     batches.flatten
   }
 
+  def lastBatch: Seq[Row] = batches.last
+
   def toDebugString: String = synchronized {
     batches.zipWithIndex.map { case (b, i) =>
       val dataStr = try b.mkString(" ") catch {

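With a running aggregation, later batches re-emit updated values for keys that appeared earlier, so concatenating every batch yields stale rows alongside current ones, while the last batch holds only the most recent output. That is why a test may want to inspect `lastBatch` alone. A sketch with a hypothetical `ToyMemorySink` (rows simplified to strings) follows; unlike the `batches.last` call in the diff, this version returns an empty `Seq` before any batch arrives and synchronizes access, both choices of the sketch.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical in-memory sink modeled on MemorySink; rows are plain strings here.
class ToyMemorySink {
  private val batches = new ArrayBuffer[Seq[String]]()

  def addBatch(rows: Seq[String]): Unit = synchronized { batches += rows }

  // Every row from every batch, in arrival order.
  def allData: Seq[String] = synchronized { batches.flatten.toSeq }

  // Only the most recent batch; empty if nothing has been written yet.
  def lastBatch: Seq[String] = synchronized { batches.lastOption.getOrElse(Seq.empty) }
}
```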
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index ee015ba..998eb82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -81,7 +81,7 @@ private[state] class HDFSBackedStateStoreProvider(
     trait STATE
     case object UPDATING extends STATE
     case object COMMITTED extends STATE
-    case object CANCELLED extends STATE
+    case object ABORTED extends STATE
 
     private val newVersion = version + 1
     private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}")
@@ -94,15 +94,14 @@ private[state] class HDFSBackedStateStoreProvider(
 
     override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id
 
-    /**
-     * Update the value of a key using the value generated by the update function.
-     * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
-     *       versions of the store data.
-     */
-    override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = {
-      verify(state == UPDATING, "Cannot update after already committed or cancelled")
-      val oldValueOption = Option(mapToUpdate.get(key))
-      val value = updateFunc(oldValueOption)
+    override def get(key: UnsafeRow): Option[UnsafeRow] = {
+      Option(mapToUpdate.get(key))
+    }
+
+    override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
+      verify(state == UPDATING, "Cannot put after already committed or aborted")
+
+      val isNewKey = !mapToUpdate.containsKey(key)
       mapToUpdate.put(key, value)
 
       Option(allUpdates.get(key)) match {
@@ -115,8 +114,7 @@ private[state] class HDFSBackedStateStoreProvider(
         case None =>
           // There was no prior update, so mark this as added or updated according to its presence
           // in previous version.
-          val update =
-            if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value)
+          val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value)
           allUpdates.put(key, update)
       }
       writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value))
@@ -148,7 +146,7 @@ private[state] class HDFSBackedStateStoreProvider(
 
     /** Commit all the updates that have been made to the store, and return the new version. */
     override def commit(): Long = {
-      verify(state == UPDATING, "Cannot commit again after already committed or cancelled")
+      verify(state == UPDATING, "Cannot commit after already committed or aborted")
 
       try {
         finalizeDeltaFile(tempDeltaFileStream)
@@ -164,8 +162,8 @@ private[state] class HDFSBackedStateStoreProvider(
     }
 
     /** Cancel all the updates made on this store. This store will not be usable any more. */
-    override def cancel(): Unit = {
-      state = CANCELLED
+    override def abort(): Unit = {
+      state = ABORTED
       if (tempDeltaFileStream != null) {
         tempDeltaFileStream.close()
       }
@@ -176,8 +174,8 @@ private[state] class HDFSBackedStateStoreProvider(
     }
 
     /**
-     * Get an iterator of all the store data. This can be called only after committing the
-     * updates.
+     * Get an iterator of all the store data.
+     * This can be called only after committing all the updates made in the current thread.
      */
     override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
       verify(state == COMMITTED, "Cannot get iterator of store data before comitting")
@@ -186,7 +184,7 @@ private[state] class HDFSBackedStateStoreProvider(
 
     /**
      * Get an iterator of all the updates made to the store in the current version.
-     * This can be called only after committing the updates.
+     * This can be called only after committing all the updates made in the current thread.
      */
     override def updates(): Iterator[StoreUpdate] = {
       verify(state == COMMITTED, "Cannot get iterator of updates before committing")
@@ -196,7 +194,7 @@ private[state] class HDFSBackedStateStoreProvider(
     /**
      * Whether all updates have been committed
      */
-    override def hasCommitted: Boolean = {
+    override private[state] def hasCommitted: Boolean = {
       state == COMMITTED
     }
   }

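The hunks above replace `update`/`cancel` with `get`/`put`/`abort` and change how a write is classified: membership is checked before the write (`isNewKey`) instead of inspecting the old value. The following is a minimal sketch of that lifecycle, assuming a hypothetical in-memory `ToyStateStore` with String keys and Int values in place of `UnsafeRow`; it writes no delta files and is not the Spark implementation. The diff elides what happens when a key already has a recorded update in the current version, so that branch below is an explicitly labeled assumption.

```scala
import scala.collection.mutable

// Hypothetical, simplified model of the store lifecycle shown in the diff:
// String keys and Int values stand in for UnsafeRow; nothing is written to disk.
object ToyStateStore {
  sealed trait State
  case object Updating extends State
  case object Committed extends State
  case object Aborted extends State

  sealed trait StoreUpdate
  case class ValueAdded(key: String, value: Int) extends StoreUpdate
  case class ValueUpdated(key: String, value: Int) extends StoreUpdate

  class Store(val version: Long, previous: Map[String, Int]) {
    private var state: State = Updating
    private val mapToUpdate = mutable.HashMap[String, Int]() ++= previous
    private val allUpdates = mutable.LinkedHashMap[String, StoreUpdate]()

    private def verify(condition: Boolean, msg: String): Unit =
      if (!condition) throw new IllegalStateException(msg)

    def get(key: String): Option[Int] = mapToUpdate.get(key)

    def put(key: String, value: Int): Unit = {
      verify(state == Updating, "Cannot put after already committed or aborted")
      // Check membership before writing, as `isNewKey` does in the diff.
      val isNewKey = !mapToUpdate.contains(key)
      mapToUpdate.put(key, value)
      val update: StoreUpdate = allUpdates.get(key) match {
        case None =>
          if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value)
        // Assumption (this branch is elided in the diff): a key first added in
        // this version stays "added"; any other repeated write stays "updated".
        case Some(ValueAdded(_, _)) => ValueAdded(key, value)
        case Some(_) => ValueUpdated(key, value)
      }
      allUpdates.put(key, update)
    }

    def commit(): Long = {
      verify(state == Updating, "Cannot commit after already committed or aborted")
      state = Committed
      version + 1
    }

    // Like the diff's abort(): no state check, the store simply becomes unusable.
    def abort(): Unit = { state = Aborted }

    def iterator(): Iterator[(String, Int)] = {
      verify(state == Committed, "Cannot get iterator of store data before committing")
      mapToUpdate.iterator
    }

    def updates(): Iterator[StoreUpdate] = {
      verify(state == Committed, "Cannot get iterator of updates before committing")
      allUpdates.valuesIterator
    }
  }
}
```

In this sketch a rewrite of a key carried over from the previous version is reported as `ValueUpdated`, a brand-new key as `ValueAdded`, and both `iterator()` and `updates()` refuse to run until `commit()` has succeeded.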
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index ca5c864..d60e618 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -47,12 +47,11 @@ trait StateStore {
   /** Version of the data in this store before committing updates. */
   def version: Long
 
-  /**
-   * Update the value of a key using the value generated by the update function.
-   * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
-   *       versions of the store data.
-   */
-  def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit
+  /** Get the current value of a key. */
+  def get(key: UnsafeRow): Option[UnsafeRow]
+
+  /** Put a new value for a key. */
+  def put(key: UnsafeRow, value: UnsafeRow)
 
   /**
    * Remove keys that match the following condition.
@@ -65,24 +64,24 @@ trait StateStore {
   def commit(): Long
 
   /** Cancel all the updates that have been made to the store. */
-  def cancel(): Unit
+  def abort(): Unit
 
   /**
    * Iterator of store data after a set of updates have been committed.
-   * This can be called only after commitUpdates() has been called in the current thread.
+   * This can be called only after committing all the updates made in the current thread.
    */
   def iterator(): Iterator[(UnsafeRow, UnsafeRow)]
 
   /**
    * Iterator of the updates that have been committed.
-   * This can be called only after commitUpdates() has been called in the current thread.
+   * This can be called only after committing all the updates made in the current thread.
    */
   def updates(): Iterator[StoreUpdate]
 
   /**
    * Whether all updates have been committed
    */
-  def hasCommitted: Boolean
+  private[state] def hasCommitted: Boolean
 }
 
 

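The trait now exposes a plain `get`/`put` pair in place of `update(key, updateFunc)`. The old single-call behavior can be recovered as a read followed by a write, which is also the shape of the restore/merge/save steps described in the commit message. A minimal sketch, assuming a hypothetical `SimpleStore` with String keys and Long values standing in for `UnsafeRow`, and assuming each batch holds at most one partial count per key (as after the partial merge):

```scala
import scala.collection.mutable

// Hypothetical stand-in for the new get/put surface of StateStore.
trait SimpleStore {
  def get(key: String): Option[Long]
  def put(key: String, value: Long): Unit
}

class InMemoryStore extends SimpleStore {
  private val data = mutable.HashMap[String, Long]()
  def get(key: String): Option[Long] = data.get(key)
  def put(key: String, value: Long): Unit = data.put(key, value)
}

object StoreOps {
  // The removed single-call API, expressed as a read followed by a write.
  def update(store: SimpleStore, key: String)(updateFunc: Option[Long] => Long): Unit =
    store.put(key, updateFunc(store.get(key)))

  // For each key in the batch: restore the previous total, merge this batch's
  // partial count into it, save the result, and emit the current total.
  def mergeBatch(store: SimpleStore, batch: Seq[(String, Long)]): Map[String, Long] =
    batch.map { case (key, partial) =>
      update(store, key)(previous => previous.getOrElse(0L) + partial)
      key -> store.get(key).get
    }.toMap
}
```

Splitting the read from the write keeps the store interface minimal and leaves the combine logic entirely to the caller.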
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index cca22a0..f0f1f3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
 import org.apache.spark.sql.internal.SQLConf
 
 /** A class that contains configuration parameters for [[StateStore]]s. */
-private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
+private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
 
   def this() = this(new SQLConf)
 
@@ -31,7 +31,7 @@ private[state] class StateStoreConf(@transient private val conf: SQLConf) extend
   val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN)
 }
 
-private[state] object StateStoreConf {
+private[streaming] object StateStoreConf {
   val empty = new StateStoreConf()
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index 3318660..df3d82c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -54,17 +54,10 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
 
   override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
     var store: StateStore = null
-
-    Utils.tryWithSafeFinally {
-      val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
-      store = StateStore.get(
-        storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
-      val inputIter = dataRDD.iterator(partition, ctxt)
-      val outputIter = storeUpdateFunction(store, inputIter)
-      assert(store.hasCommitted)
-      outputIter
-    } {
-      if (store != null) store.cancel()
-    }
+    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+    store = StateStore.get(
+      storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
+    val inputIter = dataRDD.iterator(partition, ctxt)
+    storeUpdateFunction(store, inputIter)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index b249e37..9b6d091 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -28,37 +28,36 @@ package object state {
   implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) {
 
     /** Map each partition of a RDD along with data in a [[StateStore]]. */
-    def mapPartitionWithStateStore[U: ClassTag](
-        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+    def mapPartitionsWithStateStore[U: ClassTag](
+        sqlContext: SQLContext,
         checkpointLocation: String,
         operatorId: Long,
         storeVersion: Long,
         keySchema: StructType,
-        valueSchema: StructType
-      )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = {
+        valueSchema: StructType)(
+        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
 
-      mapPartitionWithStateStore(
-        storeUpdateFunction,
+      mapPartitionsWithStateStore(
         checkpointLocation,
         operatorId,
         storeVersion,
         keySchema,
         valueSchema,
         new StateStoreConf(sqlContext.conf),
-        Some(sqlContext.streams.stateStoreCoordinator))
+        Some(sqlContext.streams.stateStoreCoordinator))(
+        storeUpdateFunction)
     }
 
     /** Map each partition of a RDD along with data in a [[StateStore]]. */
-    private[state] def mapPartitionWithStateStore[U: ClassTag](
-        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+    private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
         checkpointLocation: String,
         operatorId: Long,
         storeVersion: Long,
         keySchema: StructType,
         valueSchema: StructType,
         storeConf: StateStoreConf,
-        storeCoordinator: Option[StateStoreCoordinatorRef]
-      ): StateStoreRDD[T, U] = {
+        storeCoordinator: Option[StateStoreCoordinatorRef])(
+        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
       val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
       new StateStoreRDD(
         dataRDD,

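The renamed `mapPartitionsWithStateStore` now takes its configuration in the first parameter list and `storeUpdateFunction` alone in a trailing one. The commit does not state why; a likely benefit is the standard Scala idiom that a function in its own last parameter list can be passed as a brace-delimited block at the call site. The sketch below shows that idiom with a hypothetical `mapPartitionsWithState` over plain `Seq`s (each inner `Seq` plays the role of one partition); it is not the Spark API.

```scala
// Hypothetical helper with the same parameter-list shape as the new
// mapPartitionsWithStateStore: settings first, the per-partition function last.
object CurriedDemo {
  // Each inner Seq plays the role of one partition.
  implicit class PartitionOps[T](partitions: Seq[Seq[T]]) {
    def mapPartitionsWithState[U](initialState: Long)(
        f: (Long, Iterator[T]) => Iterator[U]): Seq[List[U]] =
      partitions.zipWithIndex.map { case (part, index) =>
        // The hypothetical "state" is just initialState offset by the partition index.
        f(initialState + index, part.iterator).toList
      }
  }

  // Because the function sits in its own trailing parameter list, the caller
  // can pass it as a block, and U is inferred from the block's result.
  def example(): Seq[List[Int]] =
    Seq(Seq(1, 2), Seq(3)).mapPartitionsWithState(10L) { (state, rows) =>
      rows.map(_ + state.toInt)
    }
}
```

`CurriedDemo.example()` evaluates to `Seq(List(11, 12), List(14))`.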
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 0d58070..4b3091b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.internal.SessionState
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -60,14 +60,13 @@ case class ScalarSubquery(
 }
 
 /**
- * Convert the subquery from logical plan into executed plan.
+ * Plans scalar subqueries from that are present in the given [[SparkPlan]].
  */
-case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] {
+case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
     plan.transformAllExpressions {
       case subquery: expressions.ScalarSubquery =>
-        val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next()
-        val executedPlan = sessionState.prepareForExecution.execute(sparkPlan)
+        val executedPlan = new QueryExecution(sqlContext, subquery.plan).executedPlan
         ScalarSubquery(executedPlan, subquery.exprId)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 844f305..9cb356f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -84,10 +84,10 @@ abstract class Aggregator[-I, B, O] extends Serializable {
       implicit bEncoder: Encoder[B],
       cEncoder: Encoder[O]): TypedColumn[I, O] = {
     val expr =
-      new AggregateExpression(
+      AggregateExpression(
         TypedAggregateExpression(this),
         Complete,
-        false)
+        isDistinct = false)
 
     new TypedColumn[I, O](expr, encoderFor[O])
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index f7fdfac..cd3d254 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -86,20 +86,8 @@ private[sql] class SessionState(ctx: SQLContext) {
   /**
    * Planner that converts optimized logical plans to physical plans.
    */
-  lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods)
-
-  /**
-   * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
-   * row format conversions as needed.
-   */
-  lazy val prepareForExecution = new RuleExecutor[SparkPlan] {
-    override val batches: Seq[Batch] = Seq(
-      Batch("Subquery", Once, PlanSubqueries(SessionState.this)),
-      Batch("Add exchange", Once, EnsureRequirements(conf)),
-      Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)),
-      Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf))
-    )
-  }
+  def planner: SparkPlanner =
+    new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies)
 
   /**
    * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s

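`SessionState.planner` changes from a `lazy val` to a `def` and now receives `experimentalMethods.extraStrategies` directly. One consequence, under the assumption that that list can change after the session is created, is that a `lazy val` would freeze whatever strategies existed on first access, while a `def` rebuilds the planner with the current list each time. The sketch below illustrates only that difference; `ToyExperimentalMethods`, `ToyPlanner`, and `ToySessionState` are hypothetical stand-ins, not Spark classes.

```scala
// Hypothetical stand-ins: a mutable registry of extra strategies and a planner
// that snapshots the list it receives at construction time.
class ToyExperimentalMethods { @volatile var extraStrategies: Seq[String] = Nil }
class ToyPlanner(val strategies: Seq[String])

class ToySessionState(experimental: ToyExperimentalMethods) {
  // Built once on first access: later changes to extraStrategies are not seen.
  lazy val cachedPlanner: ToyPlanner = new ToyPlanner(experimental.extraStrategies)

  // Built on every access, mirroring the `def planner` introduced in the diff.
  def planner: ToyPlanner = new ToyPlanner(experimental.extraStrategies)
}
```

After registering a strategy, `planner.strategies` includes it while a previously forced `cachedPlanner` does not.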
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index b5be7ef..550c3c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -116,15 +116,30 @@ trait StreamTest extends QueryTest with Timeouts {
     def apply[A : Encoder](data: A*): CheckAnswerRows = {
       val encoder = encoderFor[A]
       val toExternalRow = RowEncoder(encoder.schema)
-      CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))))
+      CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false)
     }
 
-    def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows)
+    def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false)
   }
 
-  case class CheckAnswerRows(expectedAnswer: Seq[Row])
+  /**
+   * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`.
+   * This operation automatically blocks until all added data has been processed.
+   */
+  object CheckLastBatch {
+    def apply[A : Encoder](data: A*): CheckAnswerRows = {
+      val encoder = encoderFor[A]
+      val toExternalRow = RowEncoder(encoder.schema)
+      CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true)
+    }
+
+    def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true)
+  }
+
+  case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean)
       extends StreamAction with StreamMustBeRunning {
-    override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}"
+    override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}"
+    private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
   }
 
   /** Stops the stream.  It must currently be running. */
@@ -224,11 +239,8 @@ trait StreamTest extends QueryTest with Timeouts {
          """.stripMargin
 
     def verify(condition: => Boolean, message: String): Unit = {
-      try {
-        Assertions.assert(condition)
-      } catch {
-        case NonFatal(e) =>
-          failTest(message, e)
+      if (!condition) {
+        failTest(message)
       }
     }
 
@@ -351,7 +363,7 @@ trait StreamTest extends QueryTest with Timeouts {
           case a: AddData =>
             awaiting.put(a.source, a.addData())
 
-          case CheckAnswerRows(expectedAnswer) =>
+          case CheckAnswerRows(expectedAnswer, lastOnly) =>
             verify(currentStream != null, "stream not running")
 
             // Block until all data added has been processed
@@ -361,12 +373,12 @@ trait StreamTest extends QueryTest with Timeouts {
               }
             }
 
-            val allData = try sink.allData catch {
+            val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch {
               case e: Exception =>
                 failTest("Exception while getting data from sink", e)
             }
 
-            QueryTest.sameRows(expectedAnswer, allData).foreach {
+            QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach {
               error => failTest(error)
             }
         }

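`CheckLastBatch` compares the expected rows against `sink.lastBatch` rather than `sink.allData`. That matters for aggregation tests where every batch re-emits the current result for each group, so the accumulated sink contents also contain stale totals from earlier batches. A minimal sketch of the distinction, assuming a hypothetical `ToySink` that records each batch and an order-insensitive comparison in the spirit of `QueryTest.sameRows` (compared here by string form, which is a simplification):

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical toy sink that records every batch it receives.
class ToySink[T] {
  private val batches = ArrayBuffer[Seq[T]]()
  def addBatch(rows: Seq[T]): Unit = batches += rows
  def allData: Seq[T] = batches.flatten.toSeq
  def lastBatch: Seq[T] = batches.lastOption.getOrElse(Seq.empty)
}

object ToyCheck {
  // Order-insensitive comparison; returns an error message when the rows differ.
  def sameRows[T](expected: Seq[T], actual: Seq[T]): Option[String] =
    if (expected.map(_.toString).sorted == actual.map(_.toString).sorted) None
    else Some(s"Expected ${expected.mkString(",")} but got ${actual.mkString(",")}")

  // lastOnly = true behaves like CheckLastBatch, false like CheckAnswer.
  def checkAnswer[T](sink: ToySink[T], expected: Seq[T], lastOnly: Boolean): Option[String] = {
    val actual = if (lastOnly) sink.lastBatch else sink.allData
    sameRows(expected, actual)
  }
}
```

With running counts `Seq("a" -> 1L)` followed by `Seq("a" -> 2L, "b" -> 1L)`, only the last-batch check matches the current totals; the all-data check also sees the stale `("a", 1)`.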
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index ed0d3f5..3831874 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -231,10 +231,8 @@ object SparkPlanTest {
   }
 
   private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
-    // A very simple resolver to make writing tests easier. In contrast to the real resolver
-    // this is always case sensitive and does not try to handle scoping or complex type resolution.
-    val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute(
-      outputPlan transform {
+    val execution = new QueryExecution(sqlContext, null) {
+      override lazy val sparkPlan: SparkPlan = outputPlan transform {
         case plan: SparkPlan =>
           val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
           plan transformExpressions {
@@ -243,8 +241,8 @@ object SparkPlanTest {
                 sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
           }
       }
-    )
-    resolvedPlan.executeCollectPublic().toSeq
+    }
+    execution.executedPlan.executeCollectPublic().toSeq
   }
 }
 

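The new `SparkPlanTest.executePlan` builds an anonymous `QueryExecution` subclass that overrides `lazy val sparkPlan`, so `executedPlan` reuses the shared preparation logic while the analysis and planning stages are skipped (which is why `null` can be passed as the logical plan). The sketch below shows that pattern in isolation, assuming a hypothetical `ToyExecution` whose stages are chained lazy vals producing strings; it is not the real `QueryExecution`.

```scala
// Hypothetical pipeline shaped like QueryExecution: each stage is a lazy val
// derived from the previous one, so nothing runs until a later stage is forced.
class ToyExecution(input: String) {
  lazy val analyzed: String = s"analyzed($input)"
  lazy val sparkPlan: String = s"planned($analyzed)"
  lazy val executedPlan: String = s"prepared($sparkPlan)"
}

object ToyExecutionDemo {
  // Mirrors the test helper: override the middle stage with a ready-made plan.
  // `analyzed` is never forced, so the null input is never touched.
  def prepareOnly(plan: String): String = {
    val execution = new ToyExecution(null) {
      override lazy val sparkPlan: String = plan
    }
    execution.executedPlan
  }
}
```

`ToyExecutionDemo.prepareOnly("myPlan")` returns `"prepared(myPlan)"`, while an unmodified `ToyExecution("q")` runs every stage.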
