[SPARK-14255][SQL] Streaming Aggregation This PR adds the ability to perform aggregations inside a `ContinuousQuery`. In order to implement this feature, the planning of aggregation has been augmented with a new `StatefulAggregationStrategy`. Unlike batch aggregation, stateful aggregation uses the `StateStore` (introduced in #11645) to persist the results of partial aggregation across different invocations. The resulting physical plan performs the aggregation using the following progression: - Partial Aggregation - Shuffle - Partial Merge (now there is at most 1 tuple per group) - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous) - Partial Merge (now there is at most 1 tuple per group) - StateStoreSave (saves the tuple for the next batch) - Complete (output the current result of the aggregation)
The following refactoring was also performed to allow us to plug into existing code: - The get/put implementation is taken from #12013 - The logic for breaking down and de-duping the physical execution of aggregation has been moved into a new pattern `PhysicalAggregation` - The `AttributeReference` used to identify the result of an `AggregateFunction` has been moved into the `AggregateExpression` container. This change moves the reference into the same object as the other intermediate references used in aggregation and eliminates the need to pass around a `Map[(AggregateFunction, Boolean), Attribute]`. Further cleanup (using a different aggregation container for logical/physical plans) is deferred to a follow-up. - Some planning logic is moved from the `SessionState` into the `QueryExecution` to make it easier to override in the streaming case. - The ability to write a `StreamTest` that checks only the output of the last batch has been added to simulate the future addition of output modes. Author: Michael Armbrust <[email protected]> Closes #12048 from marmbrus/statefulAgg. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0fc4aaa7 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0fc4aaa7 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0fc4aaa7 Branch: refs/heads/master Commit: 0fc4aaa71c5b4531b3a7c8ac71d62ea8e66b6f0c Parents: 0b7d496 Author: Michael Armbrust <[email protected]> Authored: Fri Apr 1 15:15:16 2016 -0700 Committer: Michael Armbrust <[email protected]> Committed: Fri Apr 1 15:15:16 2016 -0700 ---------------------------------------------------------------------- .../spark/sql/catalyst/analysis/Analyzer.scala | 9 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../spark/sql/catalyst/errors/package.scala | 7 +- .../expressions/aggregate/interfaces.scala | 37 ++++- .../catalyst/expressions/namedExpressions.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 14 +- .../spark/sql/catalyst/planning/patterns.scala | 73 +++++++++ .../spark/sql/catalyst/plans/PlanTest.scala | 3 + .../spark/sql/execution/QueryExecution.scala | 24 ++- .../apache/spark/sql/execution/SparkPlan.scala | 7 + .../spark/sql/execution/SparkPlanner.scala | 4 +- .../spark/sql/execution/SparkStrategies.scala | 92 ++++------- .../org/apache/spark/sql/execution/Window.scala | 2 +- .../aggregate/TungstenAggregationIterator.scala | 4 +- .../spark/sql/execution/aggregate/utils.scala | 121 ++++++++++++--- .../streaming/IncrementalExecution.scala | 72 +++++++++ .../execution/streaming/StatefulAggregate.scala | 119 +++++++++++++++ .../execution/streaming/StreamExecution.scala | 12 +- .../spark/sql/execution/streaming/memory.scala | 4 +- .../state/HDFSBackedStateStoreProvider.scala | 36 +++-- .../execution/streaming/state/StateStore.scala | 19 ++- .../streaming/state/StateStoreConf.scala | 4 +- .../streaming/state/StateStoreRDD.scala | 17 +-- .../sql/execution/streaming/state/package.scala | 21 ++- .../apache/spark/sql/execution/subquery.scala | 11 +- 
.../spark/sql/expressions/Aggregator.scala | 4 +- .../spark/sql/internal/SessionState.scala | 16 +- .../scala/org/apache/spark/sql/StreamTest.scala | 36 +++-- .../spark/sql/execution/SparkPlanTest.scala | 10 +- .../streaming/state/StateStoreRDDSuite.scala | 152 +++++++++++-------- .../streaming/state/StateStoreSuite.scala | 61 ++++---- .../streaming/StreamingAggregationSuite.scala | 132 ++++++++++++++++ .../spark/sql/hive/HiveSessionState.scala | 5 +- 33 files changed, 827 insertions(+), 305 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d82ee3a..05e2b9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -336,6 +336,11 @@ class Analyzer( Last(ifExpr(expr), Literal(true)) case a: AggregateFunction => a.withNewChildren(a.children.map(ifExpr)) + }.transform { + // We are duplicating aggregates that are now computing a different value for each + // pivot value. + // TODO: Don't construct the physical container until after analysis. 
+ case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) } if (filteredAggregate.fastEquals(aggregate)) { throw new AnalysisException( @@ -1153,11 +1158,11 @@ class Analyzer( // Extract Windowed AggregateExpression case we @ WindowExpression( - AggregateExpression(function, mode, isDistinct), + ae @ AggregateExpression(function, _, _, _), spec: WindowSpecDefinition) => val newChildren = function.children.map(extractExpr) val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] - val newAgg = AggregateExpression(newFunction, mode, isDistinct) + val newAgg = ae.copy(aggregateFunction = newFunction) seenWindowAggregates += newAgg WindowExpression(newAgg, spec) http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 1d1e892..4880502 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -76,7 +76,7 @@ trait CheckAnalysis { case g: GroupingID => failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup") - case w @ WindowExpression(AggregateExpression(_, _, true), _) => + case w @ WindowExpression(AggregateExpression(_, _, true, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order, http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index 0d44d1d..0420b4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -25,15 +25,18 @@ import org.apache.spark.sql.catalyst.trees.TreeNode package object errors { class TreeNodeException[TreeType <: TreeNode[_]]( - tree: TreeType, msg: String, cause: Throwable) + @transient val tree: TreeType, + msg: String, + cause: Throwable) extends Exception(msg, cause) { + val treeString = tree.toString + // Yes, this is the same as a default parameter, but... those don't seem to work with SBT // external project dependencies for some reason. def this(tree: TreeType, msg: String) = this(tree, msg, null) override def getMessage: String = { - val treeString = tree.toString s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree" } } http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index ff3064a..d31ccf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ @@ -66,6 +67,19 @@ private[sql] case object NoOp extends Expression with Unevaluable { override def children: Seq[Expression] = Nil } +object AggregateExpression { + def apply( + aggregateFunction: AggregateFunction, + mode: AggregateMode, + isDistinct: Boolean): AggregateExpression = { + AggregateExpression( + aggregateFunction, + mode, + isDistinct, + NamedExpression.newExprId) + } +} + /** * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. @@ -73,10 +87,31 @@ private[sql] case object NoOp extends Expression with Unevaluable { private[sql] case class AggregateExpression( aggregateFunction: AggregateFunction, mode: AggregateMode, - isDistinct: Boolean) + isDistinct: Boolean, + resultId: ExprId) extends Expression with Unevaluable { + lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) { + AttributeReference( + aggregateFunction.toString, + aggregateFunction.dataType, + aggregateFunction.nullable)(exprId = resultId) + } else { + // This is a bit of a hack. Really we should not be constructing this container and reasoning + // about datatypes / aggregation mode until after we have finished analysis and made it to + // planning. + UnresolvedAttribute(aggregateFunction.toString) + } + + // We compute the same thing regardless of our final result. 
+ override lazy val canonicalized: Expression = + AggregateExpression( + aggregateFunction.canonicalized.asInstanceOf[AggregateFunction], + mode, + isDistinct, + ExprId(0)) + override def children: Seq[Expression] = aggregateFunction :: Nil override def dataType: DataType = aggregateFunction.dataType override def foldable: Boolean = false http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 262582c..2307122 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -329,7 +329,7 @@ case class PrettyAttribute( override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def qualifier: Option[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException - override def nullable: Boolean = throw new UnsupportedOperationException + override def nullable: Boolean = true } object VirtualColumn { http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a7a948e..326933e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -534,7 +534,7 @@ object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) => + case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) => Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) @@ -547,9 +547,9 @@ object NullPropagation extends Rule[LogicalPlan] { Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) => + case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. - AggregateExpression(Count(Literal(1)), mode, isDistinct = false) + ae.copy(aggregateFunction = Count(Literal(1))) // For Coalesce, remove null literals. 
case e @ Coalesce(children) => @@ -1225,13 +1225,13 @@ object DecimalAggregates extends Rule[LogicalPlan] { private val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _) if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale) + MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) - case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _) if prec + 4 <= MAX_DOUBLE_DIGITS => - val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct) + val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 9c92707..28d2c44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import 
org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType @@ -216,3 +217,75 @@ object IntegerIndex { case _ => None } } + +/** + * An extractor used when planning the physical execution of an aggregation. Compared with a logical + * aggregation, the following transformations are performed: + * - Unnamed grouping expressions are named so that they can be referred to across phases of + * aggregation + * - Aggregations that appear multiple times are deduplicated. + * - The compution of the aggregations themselves is separated from the final result. For example, + * the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final + * computation that computes `count.resultAttribute + 1`. + */ +object PhysicalAggregation { + // groupingExpressions, aggregateExpressions, resultExpressions, child + type ReturnType = + (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan) + + def unapply(a: Any): Option[ReturnType] = a match { + case logical.Aggregate(groupingExpressions, resultExpressions, child) => + // A single aggregate expression might appear multiple times in resultExpressions. + // In order to avoid evaluating an individual aggregate function multiple times, we'll + // build a set of the distinct aggregate expressions and build a function which can + // be used to re-write expressions so that they reference the single copy of the + // aggregate function which actually gets computed. + val aggregateExpressions = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => agg + } + }.distinct + + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. 
+ case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + + // The original `resultExpressions` are a set of expressions which may reference + // aggregate expressions, grouping column values, and constants. When aggregate operator + // emits output rows, we will use `resultExpressions` to generate an output projection + // which takes the grouping columns and final aggregate result buffer as input. + // Thus, we must re-write the result expressions so that their attributes match up with + // the attributes of the final result projection's input row: + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case ae: AggregateExpression => + // The final aggregation buffer's attributes will be `finalAggregationAttributes`, + // so replace each aggregate expression by its corresponding attribute in the set: + ae.resultAttribute + case expression => + // Since we're using `namedGroupingAttributes` to extract the grouping key + // columns, we need to replace grouping key expressions with their corresponding + // attributes. We do not rely on the equality check at here since attributes may + // differ cosmetically. Instead, we use semanticEquals. 
+ groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + Some(( + namedGroupingExpressions.map(_._2), + aggregateExpressions, + rewrittenResultExpressions, + child)) + + case _ => None + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index aa5d433..7191936 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample} import org.apache.spark.sql.catalyst.util._ @@ -38,6 +39,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => Alias(a.child, a.name)(exprId = ExprId(0)) + case ae: AggregateExpression => + ae.copy(resultId = ExprId(0)) } } http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 912b84a..4843553 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -21,6 +21,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} /** * The primary workflow for executing relational queries using Spark. Designed to allow easy @@ -31,6 +33,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} */ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { + // TODO: Move the planner an optimizer into here from SessionState. + protected def planner = sqlContext.sessionState.planner + def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch { case e: AnalysisException => val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed)) @@ -49,16 +54,31 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { lazy val sparkPlan: SparkPlan = { SQLContext.setActive(sqlContext) - sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next() + planner.plan(ReturnAnswer(optimizedPlan)).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan) + lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() + /** + * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal + * row format conversions as needed. 
+ */ + protected def prepareForExecution(plan: SparkPlan): SparkPlan = { + preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } + } + + /** A sequence of rules that will be applied in order to the physical plan before execution. */ + protected def preparations: Seq[Rule[SparkPlan]] = Seq( + PlanSubqueries(sqlContext), + EnsureRequirements(sqlContext.conf), + CollapseCodegenStages(sqlContext.conf), + ReuseExchange(sqlContext.conf)) + protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 010ed7f..b1b3d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -379,6 +379,13 @@ private[sql] trait LeafNode extends SparkPlan { override def producedAttributes: AttributeSet = outputSet } +object UnaryNode { + def unapply(a: Any): Option[(SparkPlan, SparkPlan)] = a match { + case s: SparkPlan if s.children.size == 1 => Some((s, s.children.head)) + case _ => None + } +} + private[sql] trait UnaryNode extends SparkPlan { def child: SparkPlan http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 9da2c74..ac8072f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -26,13 +26,13 @@ import org.apache.spark.sql.internal.SQLConf class SparkPlanner( val sparkContext: SparkContext, val conf: SQLConf, - val experimentalMethods: ExperimentalMethods) + val extraStrategies: Seq[Strategy]) extends SparkStrategies { def numPartitions: Int = conf.numShufflePartitions def strategies: Seq[Strategy] = - experimentalMethods.extraStrategies ++ ( + extraStrategies ++ ( FileSourceStrategy :: DataSourceStrategy :: DDLStrategy :: http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7a2e2b7..5bcc172 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -204,28 +203,32 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** + * Used to plan aggregation queries that are computed incrementally as part of a + * [[org.apache.spark.sql.ContinuousQuery]]. 
Currently this rule is injected into the planner + * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]] + */ + object StatefulAggregationStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalAggregation( + namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => + + aggregate.Utils.planStreamingAggregation( + namedGroupingExpressions, + aggregateExpressions, + rewrittenResultExpressions, + planLater(child)) + + case _ => Nil + } + } + + /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Aggregate(groupingExpressions, resultExpressions, child) => - // A single aggregate expression might appear multiple times in resultExpressions. - // In order to avoid evaluating an individual aggregate function multiple times, we'll - // build a set of the distinct aggregate expressions and build a function which can - // be used to re-write expressions so that they reference the single copy of the - // aggregate function which actually gets computed. - val aggregateExpressions = resultExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression => agg - } - }.distinct - // For those distinct aggregate expressions, we create a map from the - // aggregate function to the corresponding attribute of the function. 
- val aggregateFunctionToAttribute = aggregateExpressions.map { agg => - val aggregateFunction = agg.aggregateFunction - val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute - (aggregateFunction, agg.isDistinct) -> attribute - }.toMap + case PhysicalAggregation( + groupingExpressions, aggregateExpressions, resultExpressions, child) => val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) @@ -233,41 +236,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // This is a sanity check. We should not reach here when we have multiple distinct // column sets. Our MultipleDistinctRewriter should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + - "Spark user mailing list.") - } - - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - - // The original `resultExpressions` are a set of expressions which may reference - // aggregate expressions, grouping column values, and constants. When aggregate operator - // emits output rows, we will use `resultExpressions` to generate an output projection - // which takes the grouping columns and final aggregate result buffer as input. 
- // Thus, we must re-write the result expressions so that their attributes match up with - // the attributes of the final result projection's input row: - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case AggregateExpression(aggregateFunction, _, isDistinct) => - // The final aggregation buffer's attributes will be `finalAggregationAttributes`, - // so replace each aggregate expression by its corresponding attribute in the set: - aggregateFunctionToAttribute(aggregateFunction, isDistinct) - case expression => - // Since we're using `namedGroupingAttributes` to extract the grouping key - // columns, we need to replace grouping key expressions with their corresponding - // attributes. We do not rely on the equality check at here since attributes may - // differ cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] + "Spark user mailing list.") } val aggregateOperator = @@ -277,26 +246,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "aggregate functions which don't support partial aggregation.") } else { aggregate.Utils.planAggregateWithoutPartial( - namedGroupingExpressions.map(_._2), + groupingExpressions, aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, + resultExpressions, planLater(child)) } } else if (functionsWithDistinct.isEmpty) { aggregate.Utils.planAggregateWithoutDistinct( - namedGroupingExpressions.map(_._2), + groupingExpressions, aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, + resultExpressions, planLater(child)) } else { aggregate.Utils.planAggregateWithOneDistinct( - namedGroupingExpressions.map(_._2), + groupingExpressions, functionsWithDistinct, functionsWithoutDistinct, - aggregateFunctionToAttribute, - rewrittenResultExpressions, + 
resultExpressions, planLater(child)) } http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 270c09a..7acf020 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -177,7 +177,7 @@ case class Window( case e @ WindowExpression(function, spec) => val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] function match { - case AggregateExpression(f, _, _) => collect("AGGREGATE", frame, e, f) + case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) case f => sys.error(s"Unsupported window function: $f") http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 213bca9..ce504e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -242,9 +242,9 @@ class TungstenAggregationIterator( // Basically the value of the KVIterator returned by externalSorter // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it. 
val newExpressions = aggregateExpressions.map { - case agg @ AggregateExpression(_, Partial, _) => + case agg @ AggregateExpression(_, Partial, _, _) => agg.copy(mode = PartialMerge) - case agg @ AggregateExpression(_, Complete, _) => + case agg @ AggregateExpression(_, Complete, _, _) => agg.copy(mode = Final) case other => other } http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 1e113cc..4682949 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.{StateStoreRestore, StateStoreSave} /** * Utility functions used by the query planner to convert our plan to new aggregation code path. 
@@ -29,15 +30,11 @@ object Utils { def planAggregateWithoutPartial( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], - aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) - val completeAggregateAttributes = completeAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } - + val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute) SortBasedAggregate( requiredChildDistributionExpressions = Some(groupingExpressions), groupingExpressions = groupingExpressions, @@ -83,7 +80,6 @@ object Utils { def planAggregateWithoutDistinct( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], - aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. 
@@ -111,9 +107,7 @@ object Utils { val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) val finalAggregate = createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), @@ -131,7 +125,6 @@ object Utils { groupingExpressions: Seq[NamedExpression], functionsWithDistinct: Seq[AggregateExpression], functionsWithoutDistinct: Seq[AggregateExpression], - aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { @@ -151,9 +144,7 @@ object Utils { // 1. Create an Aggregate Operator for partial aggregations. val partialAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - val aggregateAttributes = aggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) // We will group by the original grouping expression, plus an additional expression for the // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. @@ -169,9 +160,7 @@ object Utils { // 2. Create an Aggregate Operator for partial merge aggregations. 
val partialMergeAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - val aggregateAttributes = aggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes), @@ -190,7 +179,7 @@ object Utils { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. - case agg @ AggregateExpression(aggregateFunction, mode, true) => + case agg @ AggregateExpression(aggregateFunction, mode, true, _) => aggregateFunction.transformDown(distinctColumnAttributeLookup) .asInstanceOf[AggregateFunction] } @@ -199,9 +188,7 @@ object Utils { val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val mergeAggregateAttributes = mergeAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute) val (distinctAggregateExpressions, distinctAggregateAttributes) = rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => // We rewrite the aggregate function to a non-distinct aggregation because @@ -211,7 +198,7 @@ object Utils { val expr = AggregateExpression(func, Partial, isDistinct = true) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute - val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + val attr = functionsWithDistinct(i).resultAttribute (expr, attr) }.unzip @@ -232,9 +219,7 @@ object Utils { val finalAggregateExpressions = 
functionsWithoutDistinct.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) val (distinctAggregateExpressions, distinctAggregateAttributes) = rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => @@ -245,7 +230,7 @@ object Utils { val expr = AggregateExpression(func, Final, isDistinct = true) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute - val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + val attr = functionsWithDistinct(i).resultAttribute (expr, attr) }.unzip @@ -261,4 +246,90 @@ object Utils { finalAndCompleteAggregate :: Nil } + + /** + * Plans a streaming aggregation using the following progression: + * - Partial Aggregation + * - Shuffle + * - Partial Merge (now there is at most 1 tuple per group) + * - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous) + * - PartialMerge (now there is at most 1 tuple per group) + * - StateStoreSave (saves the tuple for the next batch) + * - Complete (output the current result of the aggregation) + */ + def planStreamingAggregation( + groupingExpressions: Seq[NamedExpression], + functionsWithoutDistinct: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + + val partialAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + // We will group by the original grouping expression, plus an additional expression for the + 
// DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping + // expressions will be [key, value]. + createAggregate( + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) + } + + val partialMerged1: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate) + } + + val restored = StateStoreRestore(groupingAttributes, None, partialMerged1) + + val partialMerged2: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = restored) + } + + val saved = StateStoreSave(groupingAttributes, None, partialMerged2) + + val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final 
aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = saved) + } + + finalAndCompleteAggregate :: Nil + } } http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala new file mode 100644 index 0000000..aaced49 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -0,0 +1,72 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryNode} + +/** + * A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]] + * plan incrementally. Possibly preserving state in between each execution. + */ +class IncrementalExecution( + ctx: SQLContext, + logicalPlan: LogicalPlan, + checkpointLocation: String, + currentBatchId: Long) extends QueryExecution(ctx, logicalPlan) { + + // TODO: make this always part of planning. + val stateStrategy = sqlContext.sessionState.planner.StatefulAggregationStrategy :: Nil + + // Modified planner with stateful operations. + override def planner: SparkPlanner = + new SparkPlanner( + sqlContext.sparkContext, + sqlContext.conf, + stateStrategy) + + /** + * Records the current id for a given stateful operator in the query plan as the `state` + * preperation walks the query plan. + */ + private var operatorId = 0 + + /** Locates save/restore pairs surrounding aggregation. 
*/ + val state = new Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan transform { + case StateStoreSave(keys, None, + UnaryNode(agg, + StateStoreRestore(keys2, None, child))) => + val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId - 1) + operatorId += 1 + + StateStoreSave( + keys, + Some(stateId), + agg.withNewChildren( + StateStoreRestore( + keys, + Some(stateId), + child) :: Nil)) + } + } + + override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations +} http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala new file mode 100644 index 0000000..5957747 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.execution.SparkPlan + +/** Used to identify the state store for a given operator. */ +case class OperatorStateId( + checkpointLocation: String, + operatorId: Long, + batchId: Long) + +/** + * An operator that saves or restores state from the [[StateStore]]. The [[OperatorStateId]] should + * be filled in by `prepareForExecution` in [[IncrementalExecution]]. + */ +trait StatefulOperator extends SparkPlan { + def stateId: Option[OperatorStateId] + + protected def getStateId: OperatorStateId = attachTree(this) { + stateId.getOrElse { + throw new IllegalStateException("State location not present for execution") + } + } +} + +/** + * For each input tuple, the key is calculated and the value from the [[StateStore]] is added + * to the stream (in addition to the input tuple) if present. 
+ */ +case class StateStoreRestore( + keyExpressions: Seq[Attribute], + stateId: Option[OperatorStateId], + child: SparkPlan) extends execution.UnaryNode with StatefulOperator { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + new StateStoreConf(sqlContext.conf), + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + iter.flatMap { row => + val key = getKey(row) + val savedState = store.get(key) + row +: savedState.toSeq + } + } + } + override def output: Seq[Attribute] = child.output +} + +/** + * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]]. + */ +case class StateStoreSave( + keyExpressions: Seq[Attribute], + stateId: Option[OperatorStateId], + child: SparkPlan) extends execution.UnaryNode with StatefulOperator { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + new StateStoreConf(sqlContext.conf), + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => + new Iterator[InternalRow] { + private[this] val baseIterator = iter + private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + + override def hasNext: Boolean = { + if (!baseIterator.hasNext) { + store.commit() + false + } else { + true + } + } + + override def next(): InternalRow = { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + row + } + } + } + } + + override def output: Seq[Attribute] = 
child.output +} http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index c4e410d..511e30c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ @@ -272,6 +273,8 @@ class StreamExecution( private def runBatch(): Unit = { val startTime = System.nanoTime() + // TODO: Move this to IncrementalExecution. + // Request unprocessed data from all sources. 
val newData = availableOffsets.flatMap { case (source, available) if committedOffsets.get(source).map(_ < available).getOrElse(true) => @@ -305,13 +308,14 @@ class StreamExecution( } val optimizerStart = System.nanoTime() - - lastExecution = new QueryExecution(sqlContext, newPlan) - val executedPlan = lastExecution.executedPlan + lastExecution = + new IncrementalExecution(sqlContext, newPlan, checkpointFile("state"), currentBatchId) + lastExecution.executedPlan val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 logDebug(s"Optimized batch in ${optimizerTime}ms") - val nextBatch = Dataset.ofRows(sqlContext, newPlan) + val nextBatch = + new Dataset(sqlContext, lastExecution, RowEncoder(lastExecution.analyzed.schema)) sink.addBatch(currentBatchId - 1, nextBatch) awaitBatchLock.synchronized { http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 0f91e59..7d97f81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -108,7 +108,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(schema: StructType) extends Sink with Logging { +class MemorySink(val schema: StructType) extends Sink with Logging { /** An order list of batches that have been written to this [[Sink]]. 
*/ private val batches = new ArrayBuffer[Array[Row]]() @@ -117,6 +117,8 @@ class MemorySink(schema: StructType) extends Sink with Logging { batches.flatten } + def lastBatch: Seq[Row] = batches.last + def toDebugString: String = synchronized { batches.zipWithIndex.map { case (b, i) => val dataStr = try b.mkString(" ") catch { http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index ee015ba..998eb82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -81,7 +81,7 @@ private[state] class HDFSBackedStateStoreProvider( trait STATE case object UPDATING extends STATE case object COMMITTED extends STATE - case object CANCELLED extends STATE + case object ABORTED extends STATE private val newVersion = version + 1 private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") @@ -94,15 +94,14 @@ private[state] class HDFSBackedStateStoreProvider( override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id - /** - * Update the value of a key using the value generated by the update function. - * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous - * versions of the store data. 
- */ - override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = { - verify(state == UPDATING, "Cannot update after already committed or cancelled") - val oldValueOption = Option(mapToUpdate.get(key)) - val value = updateFunc(oldValueOption) + override def get(key: UnsafeRow): Option[UnsafeRow] = { + Option(mapToUpdate.get(key)) + } + + override def put(key: UnsafeRow, value: UnsafeRow): Unit = { + verify(state == UPDATING, "Cannot remove after already committed or cancelled") + + val isNewKey = !mapToUpdate.containsKey(key) mapToUpdate.put(key, value) Option(allUpdates.get(key)) match { @@ -115,8 +114,7 @@ private[state] class HDFSBackedStateStoreProvider( case None => // There was no prior update, so mark this as added or updated according to its presence // in previous version. - val update = - if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value) + val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value) allUpdates.put(key, update) } writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value)) @@ -148,7 +146,7 @@ private[state] class HDFSBackedStateStoreProvider( /** Commit all the updates that have been made to the store, and return the new version. */ override def commit(): Long = { - verify(state == UPDATING, "Cannot commit again after already committed or cancelled") + verify(state == UPDATING, "Cannot commit after already committed or cancelled") try { finalizeDeltaFile(tempDeltaFileStream) @@ -164,8 +162,8 @@ private[state] class HDFSBackedStateStoreProvider( } /** Cancel all the updates made on this store. This store will not be usable any more. */ - override def cancel(): Unit = { - state = CANCELLED + override def abort(): Unit = { + state = ABORTED if (tempDeltaFileStream != null) { tempDeltaFileStream.close() } @@ -176,8 +174,8 @@ private[state] class HDFSBackedStateStoreProvider( } /** - * Get an iterator of all the store data. 
This can be called only after committing the - * updates. + * Get an iterator of all the store data. + * This can be called only after committing all the updates made in the current thread. */ override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { verify(state == COMMITTED, "Cannot get iterator of store data before comitting") @@ -186,7 +184,7 @@ private[state] class HDFSBackedStateStoreProvider( /** * Get an iterator of all the updates made to the store in the current version. - * This can be called only after committing the updates. + * This can be called only after committing all the updates made in the current thread. */ override def updates(): Iterator[StoreUpdate] = { verify(state == COMMITTED, "Cannot get iterator of updates before committing") @@ -196,7 +194,7 @@ private[state] class HDFSBackedStateStoreProvider( /** * Whether all updates have been committed */ - override def hasCommitted: Boolean = { + override private[state] def hasCommitted: Boolean = { state == COMMITTED } } http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index ca5c864..d60e618 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -47,12 +47,11 @@ trait StateStore { /** Version of the data in this store before committing updates. */ def version: Long - /** - * Update the value of a key using the value generated by the update function. - * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous - * versions of the store data. 
- */ - def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit + /** Get the current value of a key. */ + def get(key: UnsafeRow): Option[UnsafeRow] + + /** Put a new value for a key. */ + def put(key: UnsafeRow, value: UnsafeRow) /** * Remove keys that match the following condition. @@ -65,24 +64,24 @@ trait StateStore { def commit(): Long /** Cancel all the updates that have been made to the store. */ - def cancel(): Unit + def abort(): Unit /** * Iterator of store data after a set of updates have been committed. - * This can be called only after commitUpdates() has been called in the current thread. + * This can be called only after committing all the updates made in the current thread. */ def iterator(): Iterator[(UnsafeRow, UnsafeRow)] /** * Iterator of the updates that have been committed. - * This can be called only after commitUpdates() has been called in the current thread. + * This can be called only after committing all the updates made in the current thread. 
*/ def updates(): Iterator[StoreUpdate] /** * Whether all updates have been committed */ - def hasCommitted: Boolean + private[state] def hasCommitted: Boolean } http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index cca22a0..f0f1f3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.internal.SQLConf /** A class that contains configuration parameters for [[StateStore]]s. */ -private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { +private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { def this() = this(new SQLConf) @@ -31,7 +31,7 @@ private[state] class StateStoreConf(@transient private val conf: SQLConf) extend val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN) } -private[state] object StateStoreConf { +private[streaming] object StateStoreConf { val empty = new StateStoreConf() } http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 3318660..df3d82c 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -54,17 +54,10 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { var store: StateStore = null - - Utils.tryWithSafeFinally { - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) - store = StateStore.get( - storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) - val inputIter = dataRDD.iterator(partition, ctxt) - val outputIter = storeUpdateFunction(store, inputIter) - assert(store.hasCommitted) - outputIter - } { - if (store != null) store.cancel() - } + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + store = StateStore.get( + storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) + val inputIter = dataRDD.iterator(partition, ctxt) + storeUpdateFunction(store, inputIter) } } http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index b249e37..9b6d091 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -28,37 +28,36 @@ package object state { implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { /** Map each partition of a RDD along with data in a [[StateStore]]. 
*/ - def mapPartitionWithStateStore[U: ClassTag]( - storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + def mapPartitionsWithStateStore[U: ClassTag]( + sqlContext: SQLContext, checkpointLocation: String, operatorId: Long, storeVersion: Long, keySchema: StructType, - valueSchema: StructType - )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = { + valueSchema: StructType)( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { - mapPartitionWithStateStore( - storeUpdateFunction, + mapPartitionsWithStateStore( checkpointLocation, operatorId, storeVersion, keySchema, valueSchema, new StateStoreConf(sqlContext.conf), - Some(sqlContext.streams.stateStoreCoordinator)) + Some(sqlContext.streams.stateStoreCoordinator))( + storeUpdateFunction) } /** Map each partition of a RDD along with data in a [[StateStore]]. */ - private[state] def mapPartitionWithStateStore[U: ClassTag]( - storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( checkpointLocation: String, operatorId: Long, storeVersion: Long, keySchema: StructType, valueSchema: StructType, storeConf: StateStoreConf, - storeCoordinator: Option[StateStoreCoordinatorRef] - ): StateStoreRDD[T, U] = { + storeCoordinator: Option[StateStoreCoordinatorRef])( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) new StateStoreRDD( dataRDD, http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 0d58070..4b3091b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.DataType /** @@ -60,14 +60,13 @@ case class ScalarSubquery( } /** - * Convert the subquery from logical plan into executed plan. + * Plans scalar subqueries that are present in the given [[SparkPlan]]. */ -case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] { +case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressions { case subquery: expressions.ScalarSubquery => - val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next() - val executedPlan = sessionState.prepareForExecution.execute(sparkPlan) + val executedPlan = new QueryExecution(sqlContext, subquery.plan).executedPlan ScalarSubquery(executedPlan, subquery.exprId) } } http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 844f305..9cb356f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -84,10 +84,10 @@ abstract class Aggregator[-I, B, O] extends Serializable { implicit bEncoder: Encoder[B], cEncoder: Encoder[O]): TypedColumn[I, O] = { val expr = - new AggregateExpression( + AggregateExpression( TypedAggregateExpression(this), Complete, - false) + isDistinct = false) new TypedColumn[I, O](expr, encoderFor[O]) } http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index f7fdfac..cd3d254 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -86,20 +86,8 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Planner that converts optimized logical plans to physical plans. */ - lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) - - /** - * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal - * row format conversions as needed. 
- */ - lazy val prepareForExecution = new RuleExecutor[SparkPlan] { - override val batches: Seq[Batch] = Seq( - Batch("Subquery", Once, PlanSubqueries(SessionState.this)), - Batch("Add exchange", Once, EnsureRequirements(conf)), - Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)), - Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf)) - ) - } + def planner: SparkPlanner = + new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies) /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index b5be7ef..550c3c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -116,15 +116,30 @@ trait StreamTest extends QueryTest with Timeouts { def apply[A : Encoder](data: A*): CheckAnswerRows = { val encoder = encoderFor[A] val toExternalRow = RowEncoder(encoder.schema) - CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d)))) + CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false) } - def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows) + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false) } - case class CheckAnswerRows(expectedAnswer: Seq[Row]) + /** + * Checks to make sure that the last batch written to the sink matches the `expectedAnswer`. + * This operation automatically blocks until all added data has been processed. 
+ */ + object CheckLastBatch { + def apply[A : Encoder](data: A*): CheckAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema) + CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true) + } + + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true) + } + + case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}" + override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" + private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" } /** Stops the stream. It must currently be running. */ @@ -224,11 +239,8 @@ trait StreamTest extends QueryTest with Timeouts { """.stripMargin def verify(condition: => Boolean, message: String): Unit = { - try { - Assertions.assert(condition) - } catch { - case NonFatal(e) => - failTest(message, e) + if (!condition) { + failTest(message) } } @@ -351,7 +363,7 @@ trait StreamTest extends QueryTest with Timeouts { case a: AddData => awaiting.put(a.source, a.addData()) - case CheckAnswerRows(expectedAnswer) => + case CheckAnswerRows(expectedAnswer, lastOnly) => verify(currentStream != null, "stream not running") // Block until all data added has been processed @@ -361,12 +373,12 @@ trait StreamTest extends QueryTest with Timeouts { } } - val allData = try sink.allData catch { + val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch { case e: Exception => failTest("Exception while getting data from sink", e) } - QueryTest.sameRows(expectedAnswer, allData).foreach { + QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { error => failTest(error) } } http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index ed0d3f5..3831874 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -231,10 +231,8 @@ object SparkPlanTest { } private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { - // A very simple resolver to make writing tests easier. In contrast to the real resolver - // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute( - outputPlan transform { + val execution = new QueryExecution(sqlContext, null) { + override lazy val sparkPlan: SparkPlan = outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap plan transformExpressions { @@ -243,8 +241,8 @@ object SparkPlanTest { sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) } } - ) - resolvedPlan.executeCollectPublic().toSeq + } + execution.executedPlan.executeCollectPublic().toSeq } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
