[SPARK-15728][SQL] Rename aggregate operators: HashAggregate and SortAggregate
## What changes were proposed in this pull request?

We currently have two physical aggregate operators: TungstenAggregate and SortBasedAggregate. These names don't make a lot of sense from an end-user point of view. This patch renames them HashAggregate and SortAggregate.

## How was this patch tested?

Updated test cases.

Author: Reynold Xin <[email protected]>

Closes #13465 from rxin/SPARK-15728.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8900c8d8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8900c8d8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8900c8d8

Branch: refs/heads/master
Commit: 8900c8d8ff1614b5ec5a2ce213832fa13462b4d4
Parents: 09b3c56
Author: Reynold Xin <[email protected]>
Authored: Thu Jun 2 12:34:51 2016 -0700
Committer: Reynold Xin <[email protected]>
Committed: Thu Jun 2 12:34:51 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   |   8 +-
 .../sql/execution/WholeStageCodegenExec.scala   |   4 +-
 .../sql/execution/aggregate/AggUtils.scala      | 337 ++++++++
 .../execution/aggregate/HashAggregateExec.scala | 788 +++++++++++++++++++
 .../execution/aggregate/SortAggregateExec.scala | 114 +++
 .../aggregate/SortBasedAggregateExec.scala      | 111 ---
 .../execution/aggregate/TungstenAggregate.scala | 785 ------------------
 .../spark/sql/execution/aggregate/utils.scala   | 337 --------
 .../org/apache/spark/sql/DataFrameSuite.scala   |   6 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    |   2 +-
 .../sql/execution/WholeStageCodegenSuite.scala  |   8 +-
 .../sql/execution/metric/SQLMetricsSuite.scala  |  14 +-
 12 files changed, 1258 insertions(+), 1256 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/8900c8d8/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7e3e45e..9610506 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -210,7 +210,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case PhysicalAggregation(
           namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) =>
-        aggregate.Utils.planStreamingAggregation(
+        aggregate.AggUtils.planStreamingAggregation(
           namedGroupingExpressions,
           aggregateExpressions,
           rewrittenResultExpressions,
@@ -243,20 +243,20 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             sys.error("Distinct columns cannot exist in Aggregate operator containing " +
               "aggregate functions which don't support partial aggregation.")
           } else {
-            aggregate.Utils.planAggregateWithoutPartial(
+            aggregate.AggUtils.planAggregateWithoutPartial(
               groupingExpressions,
               aggregateExpressions,
               resultExpressions,
               planLater(child))
           }
         } else if (functionsWithDistinct.isEmpty) {
-          aggregate.Utils.planAggregateWithoutDistinct(
+          aggregate.AggUtils.planAggregateWithoutDistinct(
             groupingExpressions,
             aggregateExpressions,
             resultExpressions,
             planLater(child))
         } else {
-          aggregate.Utils.planAggregateWithOneDistinct(
+          aggregate.AggUtils.planAggregateWithOneDistinct(
            groupingExpressions,
functionsWithDistinct, functionsWithoutDistinct, http://git-wip-us.apache.org/repos/asf/spark/blob/8900c8d8/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index cd9ba7c..d3e8d4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf @@ -37,7 +37,7 @@ trait CodegenSupport extends SparkPlan { /** Prefix used in the current operator's variable names. */ private def variablePrefix: String = this match { - case _: TungstenAggregate => "agg" + case _: HashAggregateExec => "agg" case _: BroadcastHashJoinExec => "bhj" case _: SortMergeJoinExec => "smj" case _: RDDScanExec => "rdd" http://git-wip-us.apache.org/repos/asf/spark/blob/8900c8d8/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala new file mode 100644 index 0000000..a9ec0c8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec} + +/** + * Utility functions used by the query planner to convert our plan to new aggregation code path. 
+ */
+object AggUtils {
+
+  def planAggregateWithoutPartial(
+      groupingExpressions: Seq[NamedExpression],
+      aggregateExpressions: Seq[AggregateExpression],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+
+    val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
+    val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
+    SortAggregateExec(
+      requiredChildDistributionExpressions = Some(groupingExpressions),
+      groupingExpressions = groupingExpressions,
+      aggregateExpressions = completeAggregateExpressions,
+      aggregateAttributes = completeAggregateAttributes,
+      initialInputBufferOffset = 0,
+      resultExpressions = resultExpressions,
+      child = child
+    ) :: Nil
+  }
+
+  private def createAggregate(
+      requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
+      groupingExpressions: Seq[NamedExpression] = Nil,
+      aggregateExpressions: Seq[AggregateExpression] = Nil,
+      aggregateAttributes: Seq[Attribute] = Nil,
+      initialInputBufferOffset: Int = 0,
+      resultExpressions: Seq[NamedExpression] = Nil,
+      child: SparkPlan): SparkPlan = {
+    val useHash = HashAggregateExec.supportsAggregate(
+      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
+    if (useHash) {
+      HashAggregateExec(
+        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+        groupingExpressions = groupingExpressions,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        initialInputBufferOffset = initialInputBufferOffset,
+        resultExpressions = resultExpressions,
+        child = child)
+    } else {
+      SortAggregateExec(
+        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+        groupingExpressions = groupingExpressions,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        initialInputBufferOffset = initialInputBufferOffset,
+        resultExpressions = resultExpressions,
+        child = child)
+    }
+  }
+
+  def planAggregateWithoutDistinct(
+      groupingExpressions: Seq[NamedExpression],
+      aggregateExpressions: Seq[AggregateExpression],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+    // createAggregate (see above) checks whether HashAggregateExec can be used.
+
+    // 1. Create an Aggregate Operator for partial aggregations.
+
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
+    val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
+    val partialAggregateAttributes =
+      partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+    val partialResultExpressions =
+      groupingAttributes ++
+        partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+
+    val partialAggregate = createAggregate(
+      requiredChildDistributionExpressions = None,
+      groupingExpressions = groupingExpressions,
+      aggregateExpressions = partialAggregateExpressions,
+      aggregateAttributes = partialAggregateAttributes,
+      initialInputBufferOffset = 0,
+      resultExpressions = partialResultExpressions,
+      child = child)
+
+    // 2. Create an Aggregate Operator for final aggregations.
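    // Illustrative sketch, not part of the patch: for a query like
    //   SELECT key, AVG(value) FROM t GROUP BY key
    // the partial aggregate above emits map-side buffers of the shape
    //   (key, sum(value), count(value))
    // and the final aggregate below merges those buffers after the shuffle
    // and evaluates sum / count per key.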
+ val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + val finalAggregate = createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingExpressions.length, + resultExpressions = resultExpressions, + child = partialAggregate) + + finalAggregate :: Nil + } + + def planAggregateWithOneDistinct( + groupingExpressions: Seq[NamedExpression], + functionsWithDistinct: Seq[AggregateExpression], + functionsWithoutDistinct: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + + // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one + // DISTINCT aggregate function, all of those functions will have the same column expressions. + // For example, it would be valid for functionsWithDistinct to be + // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is + // disallowed because those two distinct aggregates have different column expressions. + val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children + val namedDistinctExpressions = distinctExpressions.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val distinctAttributes = namedDistinctExpressions.map(_.toAttribute) + val groupingAttributes = groupingExpressions.map(_.toAttribute) + + // 1. Create an Aggregate Operator for partial aggregations. + val partialAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + // We will group by the original grouping expression, plus an additional expression for the + // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping + // expressions will be [key, value]. + createAggregate( + groupingExpressions = groupingExpressions ++ namedDistinctExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) + } + + // 2. Create an Aggregate Operator for partial merge aggregations. + val partialMergeAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes ++ distinctAttributes), + groupingExpressions = groupingAttributes ++ distinctAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate) + } + + // 3. 
Create an Aggregate operator for partial aggregation (for distinct)
+    val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap
+    val rewrittenDistinctFunctions = functionsWithDistinct.map {
+      // The children of an AggregateFunction with the DISTINCT keyword have
+      // already been evaluated. Here we need to replace the original children
+      // with AttributeReferences.
+      case agg @ AggregateExpression(aggregateFunction, mode, true, _) =>
+        aggregateFunction.transformDown(distinctColumnAttributeLookup)
+          .asInstanceOf[AggregateFunction]
+    }
+
+    val partialDistinctAggregate: SparkPlan = {
+      val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+      // The attributes of the final aggregation buffer, which is presented as input to the result
+      // projection:
+      val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute)
+      val (distinctAggregateExpressions, distinctAggregateAttributes) =
+        rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
+          // We rewrite the aggregate function to a non-distinct aggregation because
+          // its input will have distinct arguments.
+          // We just keep the isDistinct setting to true, so when users look at the query plan,
+          // they can still see distinct aggregations.
+          val expr = AggregateExpression(func, Partial, isDistinct = true)
+          // Use the original AggregateFunction to look up attributes, which is used to build
+          // aggregateFunctionToAttribute.
+          val attr = functionsWithDistinct(i).resultAttribute
+          (expr, attr)
+        }.unzip
+
+      val partialAggregateResult = groupingAttributes ++
+        mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++
+        distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+      createAggregate(
+        groupingExpressions = groupingAttributes,
+        aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions,
+        aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes,
+        initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
+        resultExpressions = partialAggregateResult,
+        child = partialMergeAggregate)
+    }
+
+    // 4. Create an Aggregate Operator for the final aggregation.
+    val finalAndCompleteAggregate: SparkPlan = {
+      val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
+      // The attributes of the final aggregation buffer, which is presented as input to the result
+      // projection:
+      val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
+
+      val (distinctAggregateExpressions, distinctAggregateAttributes) =
+        rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
+          // We rewrite the aggregate function to a non-distinct aggregation because
+          // its input will have distinct arguments.
+          // We just keep the isDistinct setting to true, so when users look at the query plan,
+          // they can still see distinct aggregations.
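          // Illustration, not part of the patch: for GROUP BY key with
          // COUNT(DISTINCT foo), the earlier stages grouped on (key, foo), so
          // each (key, foo) pair reaches this point exactly once and the
          // rewritten count(foo) can run as an ordinary non-distinct aggregate.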
+          val expr = AggregateExpression(func, Final, isDistinct = true)
+          // Use the original AggregateFunction to look up attributes, which is used to build
+          // aggregateFunctionToAttribute.
+          val attr = functionsWithDistinct(i).resultAttribute
+          (expr, attr)
+        }.unzip
+
+      createAggregate(
+        requiredChildDistributionExpressions = Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
+        aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions,
+        aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes,
+        initialInputBufferOffset = groupingAttributes.length,
+        resultExpressions = resultExpressions,
+        child = partialDistinctAggregate)
+    }
+
+    finalAndCompleteAggregate :: Nil
+  }
+
+  /**
+   * Plans a streaming aggregation using the following progression:
+   *  - Partial Aggregation
+   *  - Shuffle
+   *  - Partial Merge (now there is at most 1 tuple per group)
+   *  - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
+   *  - PartialMerge (now there is at most 1 tuple per group)
+   *  - StateStoreSave (saves the tuple for the next batch)
+   *  - Complete (output the current result of the aggregation)
+   */
+  def planStreamingAggregation(
+      groupingExpressions: Seq[NamedExpression],
+      functionsWithoutDistinct: Seq[AggregateExpression],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
+
+    val partialAggregate: SparkPlan = {
+      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+      // We will group by the original grouping expressions.
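      // Illustrative plan shape, not part of the patch, assembled by the code
      // below (top-down):
      //   Final aggregate
      //    +- StateStoreSaveExec
      //        +- PartialMerge aggregate
      //            +- StateStoreRestoreExec
      //                +- PartialMerge aggregate
      //                    +- Partial aggregate
      //                        +- child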
+ createAggregate( + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) + } + + val partialMerged1: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate) + } + + val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1) + + val partialMerged2: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = restored) + } + // Note: stateId and returnAllStates are filled in later with preparation rules + // in IncrementalExecution. + val saved = StateStoreSaveExec( + groupingAttributes, stateId = None, returnAllStates = None, partialMerged2) + + val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = saved) + } + + finalAndCompleteAggregate :: Nil + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/8900c8d8/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala new file mode 100644 index 0000000..fad81b5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -0,0 +1,788 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.types.{DecimalType, StringType, StructType} +import org.apache.spark.unsafe.KVIterator + +/** + * Hash-based aggregate operator that can also fallback to sorting when data exceeds memory size. + */ +case class HashAggregateExec( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryExecNode with CodegenSupport { + + private[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + + require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes)) + + override lazy val allAttributes: Seq[Attribute] = + child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + + override private[sql] lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), + "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"), + "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time")) + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash + // map and/or the sort-based aggregation once it has processed a given number of input rows. 
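  // Hypothetical illustration, not part of the patch: setting
  //   spark.sql.TungstenAggregate.testFallbackStartsAt=3,10
  // in a test would stop probing the vectorized hash map after 3 input rows
  // and force the sort-based fallback path after 10.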
+ private val testFallbackStartsAt: Option[(Int, Int)] = { + sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match { + case null | "" => None + case fallbackStartsAt => + val splits = fallbackStartsAt.split(",").map(_.trim) + Some((splits.head.toInt, splits.last.toInt)) + } + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numOutputRows = longMetric("numOutputRows") + val peakMemory = longMetric("peakMemory") + val spillSize = longMetric("spillSize") + + child.execute().mapPartitions { iter => + + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator.empty + } else { + val aggregationIterator = + new TungstenAggregationIterator( + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), + child.output, + iter, + testFallbackStartsAt, + numOutputRows, + peakMemory, + spillSize) + if (!hasInput && groupingExpressions.isEmpty) { + numOutputRows += 1 + Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) + } else { + aggregationIterator + } + } + } + } + + // all the mode of aggregate expressions + private val modes = aggregateExpressions.map(_.mode).distinct + + override def usedInputs: AttributeSet = inputSet + + override def supportCodegen: Boolean = { + // ImperativeAggregate is not supported right now + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() + } + + protected override def doProduce(ctx: CodegenContext): String = { + if (groupingExpressions.isEmpty) { + doProduceWithoutKeys(ctx) + } else { + doProduceWithKeys(ctx) + } + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + if (groupingExpressions.isEmpty) { + doConsumeWithoutKeys(ctx, input) + } else { + doConsumeWithKeys(ctx, input) + } + } + + // The variables used as aggregation buffer + private var bufVars: Seq[ExprCode] = _ + + private def doProduceWithoutKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + // generate variables for aggregation buffer + val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) + val initExpr = functions.flatMap(f => f.initialValues) + bufVars = initExpr.map { e => + val isNull = ctx.freshName("bufIsNull") + val value = ctx.freshName("bufValue") + ctx.addMutableState("boolean", isNull, "") + ctx.addMutableState(ctx.javaType(e.dataType), value, "") + // The initial expression should not access any column + val ev = e.genCode(ctx) + val initVars = s""" + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + """.stripMargin + ExprCode(ev.code + initVars, isNull, value) + } + val initBufVar = evaluateVariables(bufVars) + + // generate variables for output + val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) { + // evaluate aggregate results + ctx.currentVars = bufVars + val aggResults = functions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx) + } + val 
evaluateAggResults = evaluateVariables(aggResults) + // evaluate result expressions + ctx.currentVars = aggResults + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, aggregateAttributes).genCode(ctx) + } + (resultVars, s""" + |$evaluateAggResults + |${evaluateVariables(resultVars)} + """.stripMargin) + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + // output the aggregate buffer directly + (bufVars, "") + } else { + // no aggregate function, the result should be literals + val resultVars = resultExpressions.map(_.genCode(ctx)) + (resultVars, evaluateVariables(resultVars)) + } + + val doAgg = ctx.freshName("doAggregateWithoutKey") + ctx.addNewFunction(doAgg, + s""" + | private void $doAgg() throws java.io.IOException { + | // initialize aggregation buffer + | $initBufVar + | + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | } + """.stripMargin) + + val numOutput = metricTerm(ctx, "numOutputRows") + val aggTime = metricTerm(ctx, "aggTime") + val beforeAgg = ctx.freshName("beforeAgg") + s""" + | while (!$initAgg) { + | $initAgg = true; + | long $beforeAgg = System.nanoTime(); + | $doAgg(); + | $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); + | + | // output the result + | ${genResult.trim} + | + | $numOutput.add(1); + | ${consume(ctx, resultVars).trim} + | } + """.stripMargin + } + + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { + // only have DeclarativeAggregate + val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) + val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } + } + ctx.currentVars = bufVars ++ input + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + // aggregate buffer should be updated atomic + val updates = aggVals.zipWithIndex.map { case (ev, i) => + s""" + | ${bufVars(i).isNull} = ${ev.isNull}; + | ${bufVars(i).value} = ${ev.value}; + """.stripMargin + } + s""" + | // do aggregate + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate function + | ${evaluateVariables(aggVals)} + | // update aggregation buffer + | ${updates.mkString("\n").trim} + """.stripMargin + } + + private val groupingAttributes = groupingExpressions.map(_.toAttribute) + private val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + private val declFunctions = aggregateExpressions.map(_.aggregateFunction) + .filter(_.isInstanceOf[DeclarativeAggregate]) + .map(_.asInstanceOf[DeclarativeAggregate]) + private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + + // The name for Vectorized HashMap + private var vectorizedHashMapTerm: String = _ + private var isVectorizedHashMapEnabled: Boolean = _ + + // The name for UnsafeRow HashMap + private var hashMapTerm: String = _ + private var sorterTerm: String = _ + + /** + * This is called by generated Java class, should be public. 
+ */ + def createHashMap(): UnsafeFixedWidthAggregationMap = { + // create initialized aggregate buffer + val initExpr = declFunctions.flatMap(f => f.initialValues) + val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow) + + // create hashMap + new UnsafeFixedWidthAggregationMap( + initialBuffer, + bufferSchema, + groupingKeySchema, + TaskContext.get().taskMemoryManager(), + 1024 * 16, // initial capacity + TaskContext.get().taskMemoryManager().pageSizeBytes, + false // disable tracking of performance metrics + ) + } + + /** + * This is called by generated Java class, should be public. + */ + def createUnsafeJoiner(): UnsafeRowJoiner = { + GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + } + + /** + * Called by generated Java class to finish the aggregate and return a KVIterator. + */ + def finishAggregate( + hashMap: UnsafeFixedWidthAggregationMap, + sorter: UnsafeKVExternalSorter, + peakMemory: SQLMetric, + spillSize: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = { + + // update peak execution memory + val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val maxMemory = Math.max(mapMemory, sorterMemory) + val metrics = TaskContext.get().taskMetrics() + peakMemory.add(maxMemory) + metrics.incPeakExecutionMemory(maxMemory) + + if (sorter == null) { + // not spilled + return hashMap.iterator() + } + + // merge the final hashMap into sorter + sorter.merge(hashMap.destructAndCreateExternalSorter()) + hashMap.free() + val sortedIter = sorter.sortedIterator() + + // Create a KVIterator based on the sorted iterator. + new KVIterator[UnsafeRow, UnsafeRow] { + + // Create a MutableProjection to merge the rows of same key together + val mergeExpr = declFunctions.flatMap(_.mergeExpressions) + val mergeProjection = newMutableProjection( + mergeExpr, + aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes), + subexpressionEliminationEnabled) + val joinedRow = new JoinedRow() + + var currentKey: UnsafeRow = null + var currentRow: UnsafeRow = null + var nextKey: UnsafeRow = if (sortedIter.next()) { + sortedIter.getKey + } else { + null + } + + override def next(): Boolean = { + if (nextKey != null) { + currentKey = nextKey.copy() + currentRow = sortedIter.getValue.copy() + nextKey = null + // use the first row as aggregate buffer + mergeProjection.target(currentRow) + + // merge the following rows with same key together + var findNextGroup = false + while (!findNextGroup && sortedIter.next()) { + val key = sortedIter.getKey + if (currentKey.equals(key)) { + mergeProjection(joinedRow(currentRow, sortedIter.getValue)) + } else { + // We find a new group. + findNextGroup = true + nextKey = key + } + } + + true + } else { + spillSize.add(sorter.getSpillSize) + false + } + } + + override def getKey: UnsafeRow = currentKey + override def getValue: UnsafeRow = currentRow + override def close(): Unit = { + sortedIter.close() + } + } + } + + /** + * Generate the code for output. 
+ */
+  private def generateResultCode(
+      ctx: CodegenContext,
+      keyTerm: String,
+      bufferTerm: String,
+      plan: String): String = {
+    if (modes.contains(Final) || modes.contains(Complete)) {
+      // generate output using resultExpressions
+      ctx.currentVars = null
+      ctx.INPUT_ROW = keyTerm
+      val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).genCode(ctx)
+      }
+      val evaluateKeyVars = evaluateVariables(keyVars)
+      ctx.INPUT_ROW = bufferTerm
+      val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).genCode(ctx)
+      }
+      val evaluateBufferVars = evaluateVariables(bufferVars)
+      // evaluate the aggregation result
+      ctx.currentVars = bufferVars
+      val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
+        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
+      }
+      val evaluateAggResults = evaluateVariables(aggResults)
+      // generate the final result
+      ctx.currentVars = keyVars ++ aggResults
+      val inputAttrs = groupingAttributes ++ aggregateAttributes
+      val resultVars = resultExpressions.map { e =>
+        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
+      }
+      s"""
+       $evaluateKeyVars
+       $evaluateBufferVars
+       $evaluateAggResults
+       ${consume(ctx, resultVars)}
+       """
+
+    } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+      // This should be the last operator in a stage, so we should output UnsafeRow directly
+      val joinerTerm = ctx.freshName("unsafeRowJoiner")
+      ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
+        s"$joinerTerm = $plan.createUnsafeJoiner();")
+      val resultRow = ctx.freshName("resultRow")
+      s"""
+       UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
+       ${consume(ctx, null, resultRow)}
+       """
+
+    } else {
+      // generate result based on grouping key
+      ctx.INPUT_ROW = keyTerm
+      ctx.currentVars = null
+      val eval = resultExpressions.map { e =>
+        BindReferences.bindReference(e, groupingAttributes).genCode(ctx)
+      }
+      consume(ctx, eval)
+    }
+  }
+
+  /**
+   * Using the vectorized hash map in HashAggregate is currently supported for all primitive
+   * data types during partial aggregation. However, we currently only enable the hash map for a
+   * subset of cases that have been verified to show performance improvements on our benchmarks
+   * subject to an internal conf that sets an upper limit on the maximum length of the aggregate
+   * key/value schema.
+   *
+   * This list of supported use-cases should be expanded over time.
+   */
+  private def enableVectorizedHashMap(ctx: CodegenContext): Boolean = {
+    val schemaLength = (groupingKeySchema ++ bufferSchema).length
+    val isSupported =
+      (groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) ||
+        f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) &&
+        bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
+
+    // We do not support byte array based decimal type for aggregate values as
+    // ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
+    // updates. Due to this, appending the byte array in the vectorized hash map can turn out to be
+    // quite inefficient and can potentially OOM the executor.
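    // Reference note, not part of the patch: DecimalType.isByteArrayDecimalType
    // below matches decimals whose precision exceeds 18 digits
    // (Decimal.MAX_LONG_DIGITS), i.e. values that no longer fit in a long and
    // are stored as byte arrays.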
+ val isNotByteArrayDecimalType = bufferSchema.map(_.dataType).filter(_.isInstanceOf[DecimalType]) + .forall(!DecimalType.isByteArrayDecimalType(_)) + + isSupported && isNotByteArrayDecimalType && + schemaLength <= sqlContext.conf.vectorizedAggregateMapMaxColumns + } + + private def doProduceWithKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + isVectorizedHashMapEnabled = enableVectorizedHashMap(ctx) + vectorizedHashMapTerm = ctx.freshName("vectorizedHashMap") + val vectorizedHashMapClassName = ctx.freshName("VectorizedHashMap") + val vectorizedHashMapGenerator = new VectorizedHashMapGenerator(ctx, aggregateExpressions, + vectorizedHashMapClassName, groupingKeySchema, bufferSchema) + // Create a name for iterator from vectorized HashMap + val iterTermForVectorizedHashMap = ctx.freshName("vectorizedHashMapIter") + if (isVectorizedHashMapEnabled) { + ctx.addMutableState(vectorizedHashMapClassName, vectorizedHashMapTerm, + s"$vectorizedHashMapTerm = new $vectorizedHashMapClassName();") + ctx.addMutableState( + "java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row>", + iterTermForVectorizedHashMap, "") + } + + // create hashMap + val thisPlan = ctx.addReferenceObj("plan", this) + hashMapTerm = ctx.freshName("hashMap") + val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName + ctx.addMutableState(hashMapClassName, hashMapTerm, "") + sorterTerm = ctx.freshName("sorter") + ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") + + // Create a name for iterator from HashMap + val iterTerm = ctx.freshName("mapIter") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + + val doAgg = ctx.freshName("doAggregateWithKeys") + val peakMemory = metricTerm(ctx, "peakMemory") + val spillSize = metricTerm(ctx, "spillSize") + ctx.addNewFunction(doAgg, + s""" + ${if (isVectorizedHashMapEnabled) vectorizedHashMapGenerator.generate() else ""} + private void $doAgg() throws java.io.IOException { + $hashMapTerm = $thisPlan.createHashMap(); + ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + + ${if (isVectorizedHashMapEnabled) { + s"$iterTermForVectorizedHashMap = $vectorizedHashMapTerm.rowIterator();"} else ""} + + $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize); + } + """) + + // generate code for output + val keyTerm = ctx.freshName("aggKey") + val bufferTerm = ctx.freshName("aggBuffer") + val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan) + val numOutput = metricTerm(ctx, "numOutputRows") + + // The child could change `copyResult` to true, but we had already consumed all the rows, + // so `copyResult` should be reset to `false`. 
+ ctx.copyResult = false + + // Iterate over the aggregate rows and convert them from ColumnarBatch.Row to UnsafeRow + def outputFromGeneratedMap: Option[String] = { + if (isVectorizedHashMapEnabled) { + val row = ctx.freshName("vectorizedHashMapRow") + ctx.currentVars = null + ctx.INPUT_ROW = row + var schema: StructType = groupingKeySchema + bufferSchema.foreach(i => schema = schema.add(i)) + val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex + .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }) + Option( + s""" + | while ($iterTermForVectorizedHashMap.hasNext()) { + | $numOutput.add(1); + | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row = + | (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row) + | $iterTermForVectorizedHashMap.next(); + | ${generateRow.code} + | ${consume(ctx, Seq.empty, {generateRow.value})} + | + | if (shouldStop()) return; + | } + | + | $vectorizedHashMapTerm.close(); + """.stripMargin) + } else None + } + + val aggTime = metricTerm(ctx, "aggTime") + val beforeAgg = ctx.freshName("beforeAgg") + s""" + if (!$initAgg) { + $initAgg = true; + long $beforeAgg = System.nanoTime(); + $doAgg(); + $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); + } + + // output the result + ${outputFromGeneratedMap.getOrElse("")} + + while ($iterTerm.next()) { + $numOutput.add(1); + UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); + UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); + $outputCode + + if (shouldStop()) return; + } + + $iterTerm.close(); + if ($sorterTerm == null) { + $hashMapTerm.free(); + } + """ + } + + private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { + + // create grouping key + ctx.currentVars = input + // make sure that the generated code will not be splitted as multiple functions + ctx.INPUT_ROW = null + val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( + ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + val vectorizedRowKeys = ctx.generateExpressions( + groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + val unsafeRowKeys = unsafeRowKeyCode.value + val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer") + val vectorizedRowBuffer = ctx.freshName("vectorizedAggBuffer") + + // only have DeclarativeAggregate + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } + } + + // generate hash code for key + val hashExpr = Murmur3Hash(groupingExpressions, 42) + ctx.currentVars = input + val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx) + + val inputAttr = aggregateBufferAttributes ++ child.output + ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input + + val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, + incCounter) = if (testFallbackStartsAt.isDefined) { + val countTerm = ctx.freshName("fallbackCounter") + ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + (s"$countTerm < ${testFallbackStartsAt.get._1}", + s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") + } else { + ("true", "true", "", "") + } + + // We first generate code to probe and update the vectorized hash map. 
If the probe is + // successful the corresponding vectorized row buffer will hold the mutable row + val findOrInsertInVectorizedHashMap: Option[String] = { + if (isVectorizedHashMapEnabled) { + Option( + s""" + |if ($checkFallbackForGeneratedHashMap) { + | ${vectorizedRowKeys.map(_.code).mkString("\n")} + | if (${vectorizedRowKeys.map("!" + _.isNull).mkString(" && ")}) { + | $vectorizedRowBuffer = $vectorizedHashMapTerm.findOrInsert( + | ${vectorizedRowKeys.map(_.value).mkString(", ")}); + | } + |} + """.stripMargin) + } else { + None + } + } + + val updateRowInVectorizedHashMap: Option[String] = { + if (isVectorizedHashMapEnabled) { + ctx.INPUT_ROW = vectorizedRowBuffer + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val vectorizedRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable, + isVectorized = true) + } + Option( + s""" + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate function + |${evaluateVariables(vectorizedRowEvals)} + |// update vectorized row + |${updateVectorizedRow.mkString("\n").trim} + """.stripMargin) + } else None + } + + // Next, we generate code to probe and update the unsafe row hash map. + val findOrInsertInUnsafeRowMap: String = { + s""" + | if ($vectorizedRowBuffer == null) { + | // generate grouping key + | ${unsafeRowKeyCode.code.trim} + | ${hashEval.code.trim} + | if ($checkFallbackForBytesToBytesMap) { + | // try to get the buffer from hash map + | $unsafeRowBuffer = + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value}); + | } + | if ($unsafeRowBuffer == null) { + | if ($sorterTerm == null) { + | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); + | } else { + | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); + | } + | $resetCounter + | // the hash map had be spilled, it should have enough memory now, + | // try to allocate buffer again. + | $unsafeRowBuffer = + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value}); + | if ($unsafeRowBuffer == null) { + | // failed to allocate the first page + | throw new OutOfMemoryError("No enough memory for aggregation"); + | } + | } + | } + """.stripMargin + } + + val updateRowInUnsafeRowMap: String = { + ctx.INPUT_ROW = unsafeRowBuffer + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + } + s""" + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate function + |${evaluateVariables(unsafeRowBufferEvals)} + |// update unsafe row buffer + |${updateUnsafeRowBuffer.mkString("\n").trim} + """.stripMargin + } + + + // We try to do hash map based in-memory aggregation first. 
If there is not enough memory (the + // hash map will return null for new key), we spill the hash map to disk to free memory, then + // continue to do in-memory aggregation and spilling until all the rows had been processed. + // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. + s""" + UnsafeRow $unsafeRowBuffer = null; + org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $vectorizedRowBuffer = null; + + ${findOrInsertInVectorizedHashMap.getOrElse("")} + + $findOrInsertInUnsafeRowMap + + $incCounter + + if ($vectorizedRowBuffer != null) { + // update vectorized row + ${updateRowInVectorizedHashMap.getOrElse("")} + } else { + // update unsafe row + $updateRowInUnsafeRowMap + } + """ + } + + override def simpleString: String = { + val allAggregateExpressions = aggregateExpressions + + testFallbackStartsAt match { + case None => + val keyString = groupingExpressions.mkString("[", ",", "]") + val functionString = allAggregateExpressions.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + s"HashAggregate(key=$keyString, functions=$functionString, output=$outputString)" + case Some(fallbackStartsAt) => + s"HashAggregateWithControlledFallback $groupingExpressions " + + s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt" + } + } +} + +object HashAggregateExec { + def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = { + val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/8900c8d8/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala new file mode 100644 index 0000000..9e48ff8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.metric.SQLMetrics + +/** + * Sort-based aggregate operator. + */ +case class SortAggregateExec( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryExecNode { + + private[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) + + override private[sql] lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + } + + override def outputOrdering: Seq[SortOrder] = { + groupingExpressions.map(SortOrder(_, Ascending)) + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitionsInternal { iter => + // Because the constructor of an aggregation iterator will read at least the first row, + // we need to get the value of iter.hasNext first. + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator[UnsafeRow]() + } else { + val outputIter = new SortBasedAggregationIterator( + groupingExpressions, + child.output, + iter, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), + numOutputRows) + if (!hasInput && groupingExpressions.isEmpty) { + // There is no input and there is no grouping expressions. + // We need to output a single row as the output. 
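          // Illustration, not part of the patch: e.g. SELECT COUNT(*) FROM t
          // on an empty t still returns a single row containing 0.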
+ numOutputRows += 1 + Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + } else { + outputIter + } + } + } + } + + override def simpleString: String = { + val allAggregateExpressions = aggregateExpressions + + val keyString = groupingExpressions.mkString("[", ",", "]") + val functionString = allAggregateExpressions.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + s"SortAggregate(key=$keyString, functions=$functionString, output=$outputString)" + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/8900c8d8/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala deleted file mode 100644 index af1fb4c..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.metric.SQLMetrics - -case class SortBasedAggregateExec( - requiredChildDistributionExpressions: Option[Seq[Expression]], - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression], - aggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryExecNode { - - private[this] val aggregateBufferAttributes = { - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - } - - override def producedAttributes: AttributeSet = - AttributeSet(aggregateAttributes) ++ - AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ - AttributeSet(aggregateBufferAttributes) - - override private[sql] lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { - groupingExpressions.map(SortOrder(_, Ascending)) :: Nil - } - - override def outputOrdering: Seq[SortOrder] = { - groupingExpressions.map(SortOrder(_, Ascending)) - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitionsInternal { iter => - // Because the constructor of an aggregation iterator will read at least the first row, - // we need to get the value of iter.hasNext first. - val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { - // This is a grouped aggregate and the input iterator is empty, - // so return an empty iterator. - Iterator[UnsafeRow]() - } else { - val outputIter = new SortBasedAggregationIterator( - groupingExpressions, - child.output, - iter, - aggregateExpressions, - aggregateAttributes, - initialInputBufferOffset, - resultExpressions, - (expressions, inputSchema) => - newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), - numOutputRows) - if (!hasInput && groupingExpressions.isEmpty) { - // There is no input and there is no grouping expressions. - // We need to output a single row as the output. 
-          numOutputRows += 1
-          Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
-        } else {
-          outputIter
-        }
-      }
-    }
-  }
-
-  override def simpleString: String = {
-    val allAggregateExpressions = aggregateExpressions
-
-    val keyString = groupingExpressions.mkString("[", ",", "]")
-    val functionString = allAggregateExpressions.mkString("[", ",", "]")
-    val outputString = output.mkString("[", ",", "]")
-    s"SortAggregate(key=$keyString, functions=$functionString, output=$outputString)"
-  }
-}
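The rename is directly user-visible in EXPLAIN output and in tests that pattern-match on physical operators. A minimal sketch of both, against a build containing this patch (the query, the `spark` session, and the abridged plan text are illustrative, not taken from the commit):

    import org.apache.spark.sql.execution.aggregate.HashAggregateExec

    // A hash-friendly aggregation now reports "HashAggregate" in its plan.
    val df = spark.range(100).selectExpr("id % 10 AS k").groupBy("k").count()
    df.explain()
    // == Physical Plan == (abridged)
    // *HashAggregate(key=[k#...], functions=[count(1)], ...)
    // +- Exchange hashpartitioning(k#..., 200)
    //    +- *HashAggregate(key=[k#...], functions=[partial_count(1)], ...)
    //       +- *Project [(id#... % 10) AS k#...]
    //          +- *Range (0, 100, splits=...)

    // The updated suites in the diffstat match on the new class name, e.g.:
    assert(df.queryExecution.executedPlan.collect {
      case agg: HashAggregateExec => agg
    }.nonEmpty)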
