[SPARK-9630] [SQL] Clean up new aggregate operators (SPARK-9240 follow up) This is a follow-up to https://github.com/apache/spark/pull/7813. It renames `UnsafeHybridAggregationIterator` to `TungstenAggregationIterator` and makes it only work with `UnsafeRow`. Also, I add a `TungstenAggregate` that uses `TungstenAggregationIterator` and make `SortBasedAggregate` (renamed from `Aggregate`) only work with `SafeRow`.
Author: Yin Huai <[email protected]> Closes #7954 from yhuai/agg-followUp and squashes the following commits: 4d2f4fc [Yin Huai] Add comments and free map. 0d7ddb9 [Yin Huai] Add TungstenAggregationQueryWithControlledFallbackSuite to test fall back process. 91d69c2 [Yin Huai] Rename UnsafeHybridAggregationIterator to TungstenAggregateIteraotr and make it only work with UnsafeRow. (cherry picked from commit 3504bf3aa9f7b75c0985f04ce2944833d8c5b5bd) Signed-off-by: Reynold Xin <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/272e8834 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/272e8834 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/272e8834 Branch: refs/heads/branch-1.5 Commit: 272e88342540328a24702f07a730b156657bd3be Parents: 9806872 Author: Yin Huai <[email protected]> Authored: Thu Aug 6 15:04:44 2015 -0700 Committer: Reynold Xin <[email protected]> Committed: Thu Aug 6 15:04:53 2015 -0700 ---------------------------------------------------------------------- .../expressions/aggregate/functions.scala | 14 +- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../sql/execution/UnsafeRowSerializer.scala | 20 +- .../sql/execution/aggregate/Aggregate.scala | 182 ----- .../aggregate/SortBasedAggregate.scala | 103 +++ .../SortBasedAggregationIterator.scala | 26 - .../execution/aggregate/TungstenAggregate.scala | 102 +++ .../aggregate/TungstenAggregationIterator.scala | 667 +++++++++++++++++++ .../UnsafeHybridAggregationIterator.scala | 372 ----------- .../spark/sql/execution/aggregate/utils.scala | 260 ++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../hive/execution/AggregationQuerySuite.scala | 104 ++- 12 files changed, 1192 insertions(+), 663 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 88fb516..a73024d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -31,8 +31,11 @@ case class Average(child: Expression) extends AlgebraicAggregate { override def dataType: DataType = resultType // Expected input data type. - // TODO: Once we remove the old code path, we can use our analyzer to cast NullType - // to the default data type of the NumericType. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select avg(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) private val resultType = child.dataType match { @@ -256,12 +259,19 @@ case class Sum(child: Expression) extends AlgebraicAggregate { override def dataType: DataType = resultType // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select sum(null)`. 
+ // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) private val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) + // TODO: Remove this line once we remove the NullType from inputTypes. + case NullType => IntegerType case _ => child.dataType } http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index a730ffb..c5aaebe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -191,8 +191,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // aggregate function to the corresponding attribute of the function. 
val aggregateFunctionMap = aggregateExpressions.map { agg => val aggregateFunction = agg.aggregateFunction + val attribtue = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute (aggregateFunction, agg.isDistinct) -> - Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + (aggregateFunction -> attribtue) }.toMap val (functionsWithDistinct, functionsWithoutDistinct) = http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 16498da..39f8f99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream} +import java.io._ import java.nio.ByteBuffer import scala.reflect.ClassTag @@ -58,11 +58,26 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst */ override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) + // When `out` is backed by ChainedBufferOutputStream, we will get an + // UnsupportedOperationException when we call dOut.writeInt because it internally calls + // ChainedBufferOutputStream's write(b: Int), which is not supported. + // To workaround this issue, we create an array for sorting the int value. + // To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and + // run SparkSqlSerializer2SortMergeShuffleSuite. 
+ private[this] var intBuffer: Array[Byte] = new Array[Byte](4) private[this] val dOut: DataOutputStream = new DataOutputStream(out) override def writeValue[T: ClassTag](value: T): SerializationStream = { val row = value.asInstanceOf[UnsafeRow] - dOut.writeInt(row.getSizeInBytes) + val size = row.getSizeInBytes + // This part is based on DataOutputStream's writeInt. + // It is for dOut.writeInt(row.getSizeInBytes). + intBuffer(0) = ((size >>> 24) & 0xFF).toByte + intBuffer(1) = ((size >>> 16) & 0xFF).toByte + intBuffer(2) = ((size >>> 8) & 0xFF).toByte + intBuffer(3) = ((size >>> 0) & 0xFF).toByte + dOut.write(intBuffer, 0, 4) + row.writeToStream(out, writeBuffer) this } @@ -90,6 +105,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def close(): Unit = { writeBuffer = null + intBuffer = null dOut.writeInt(EOF) dOut.close() } http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala deleted file mode 100644 index cf568dc..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} -import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} -import org.apache.spark.sql.types.StructType - -/** - * An Aggregate Operator used to evaluate [[AggregateFunction2]]. Based on the data types - * of the grouping expressions and aggregate functions, it determines if it uses - * sort-based aggregation and hybrid (hash-based with sort-based as the fallback) to - * process input rows. 
- */ -case class Aggregate( - requiredChildDistributionExpressions: Option[Seq[Expression]], - groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - - private[this] val allAggregateExpressions = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - private[this] val hasNonAlgebricAggregateFunctions = - !allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) - - // Use the hybrid iterator if (1) unsafe is enabled, (2) the schemata of - // grouping key and aggregation buffer is supported; and (3) all - // aggregate functions are algebraic. - private[this] val supportsHybridIterator: Boolean = { - val aggregationBufferSchema: StructType = - StructType.fromAttributes( - allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) - val groupKeySchema: StructType = - StructType.fromAttributes(groupingExpressions.map(_.toAttribute)) - - val schemaSupportsUnsafe: Boolean = - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeProjection.canSupport(groupKeySchema) - - // TODO: Use the hybrid iterator for non-algebric aggregate functions. - sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && !hasNonAlgebricAggregateFunctions - } - - // We need to use sorted input if we have grouping expressions, and - // we cannot use the hybrid iterator or the hybrid is disabled. 
- private[this] val requiresSortedInput: Boolean = { - groupingExpressions.nonEmpty && !supportsHybridIterator - } - - override def canProcessUnsafeRows: Boolean = !hasNonAlgebricAggregateFunctions - - // If result expressions' data types are all fixed length, we generate unsafe rows - // (We have this requirement instead of check the result of UnsafeProjection.canSupport - // is because we use a mutable projection to generate the result). - override def outputsUnsafeRows: Boolean = { - // resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength) - // TODO: Supports generating UnsafeRows. We can just re-enable the line above and fix - // any issue we get. - false - } - - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { - if (requiresSortedInput) { - // TODO: We should not sort the input rows if they are just in reversed order. - groupingExpressions.map(SortOrder(_, Ascending)) :: Nil - } else { - Seq.fill(children.size)(Nil) - } - } - - override def outputOrdering: Seq[SortOrder] = { - if (requiresSortedInput) { - // It is possible that the child.outputOrdering starts with the required - // ordering expressions (e.g. we require [a] as the sort expression and the - // child's outputOrdering is [a, b]). We can only guarantee the output rows - // are sorted by values of groupingExpressions. 
- groupingExpressions.map(SortOrder(_, Ascending)) - } else { - Nil - } - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - // Because the constructor of an aggregation iterator will read at least the first row, - // we need to get the value of iter.hasNext first. - val hasInput = iter.hasNext - val useHybridIterator = - hasInput && - supportsHybridIterator && - groupingExpressions.nonEmpty - if (useHybridIterator) { - UnsafeHybridAggregationIterator.createFromInputIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection _, - child.output, - iter, - outputsUnsafeRows) - } else { - if (!hasInput && groupingExpressions.nonEmpty) { - // This is a grouped aggregate and the input iterator is empty, - // so return an empty iterator. - Iterator[InternalRow]() - } else { - val outputIter = SortBasedAggregationIterator.createFromInputIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection _ , - newProjection _, - child.output, - iter, - outputsUnsafeRows) - if (!hasInput && groupingExpressions.isEmpty) { - // There is no input and there is no grouping expressions. - // We need to output a single row as the output. 
- Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) - } else { - outputIter - } - } - } - } - } - - override def simpleString: String = { - val iterator = if (supportsHybridIterator && groupingExpressions.nonEmpty) { - classOf[UnsafeHybridAggregationIterator].getSimpleName - } else { - classOf[SortBasedAggregationIterator].getSimpleName - } - - s"""NewAggregate with $iterator ${groupingExpressions} ${allAggregateExpressions}""" - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala new file mode 100644 index 0000000..ad428ad --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} +import org.apache.spark.sql.types.StructType + +case class SortBasedAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def outputsUnsafeRows: Boolean = false + + override def canProcessUnsafeRows: Boolean = false + + override def canProcessSafeRows: Boolean = true + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + } + + override def outputOrdering: Seq[SortOrder] = { + groupingExpressions.map(SortOrder(_, Ascending)) + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + // Because the constructor of an 
aggregation iterator will read at least the first row, + // we need to get the value of iter.hasNext first. + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator[InternalRow]() + } else { + val outputIter = SortBasedAggregationIterator.createFromInputIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection _, + newProjection _, + child.output, + iter, + outputsUnsafeRows) + if (!hasInput && groupingExpressions.isEmpty) { + // There is no input and there is no grouping expressions. + // We need to output a single row as the output. + Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + } else { + outputIter + } + } + } + } + + override def simpleString: String = { + val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}""" + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 40f6bff..67ebafd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -204,31 +204,5 @@ object SortBasedAggregationIterator { newMutableProjection, outputsUnsafeRows) } - - def 
createFromKVIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[InternalRow, InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean): SortBasedAggregationIterator = { - new SortBasedAggregationIterator( - groupingKeyAttributes, - valueAttributes, - inputKVIterator, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) - } // scalastyle:on } http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala new file mode 100644 index 0000000..5a0b4d4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} + +case class TungstenAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression2], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def outputsUnsafeRows: Boolean = true + + override def canProcessUnsafeRows: Boolean = true + + override def canProcessSafeRows: Boolean = false + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + // This is for testing. 
We force TungstenAggregationIterator to fall back to sort-based + // aggregation once it has processed a given number of input rows. + private val testFallbackStartsAt: Option[Int] = { + sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match { + case null | "" => None + case fallbackStartsAt => Some(fallbackStartsAt.toInt) + } + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator.empty.asInstanceOf[Iterator[UnsafeRow]] + } else { + val aggregationIterator = + new TungstenAggregationIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + completeAggregateExpressions, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + child.output, + iter.asInstanceOf[Iterator[UnsafeRow]], + testFallbackStartsAt) + + if (!hasInput && groupingExpressions.isEmpty) { + Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) + } else { + aggregationIterator + } + } + } + } + + override def simpleString: String = { + val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + + testFallbackStartsAt match { + case None => s"TungstenAggregate ${groupingExpressions} ${allAggregateExpressions}" + case Some(fallbackStartsAt) => + s"TungstenAggregateWithControlledFallback ${groupingExpressions} " + + s"${allAggregateExpressions} fallbackStartsAt=$fallbackStartsAt" + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala new file mode 100644 index 0000000..b9d44aa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -0,0 +1,667 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.unsafe.KVIterator +import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.types.StructType + +/** + * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. + * + * This iterator first uses hash-based aggregation to process input rows. It uses + * a hash map to store groups and their corresponding aggregation buffers. 
If + * this map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]], + * it switches to sort-based aggregation. The process of the switch has the following steps: + * - Step 1: Sort all entries of the hash map based on values of grouping expressions and + * spill them to disk. + * - Step 2: Create an external sorter based on the spilled sorted map entries. + * - Step 3: Redirect all input rows to the external sorter. + * - Step 4: Get a sorted [[KVIterator]] from the external sorter. + * - Step 5: Initialize sort-based aggregation. + * Then, this iterator works in the way of sort-based aggregation. + * + * The code of this class is organized as follows: + * - Part 1: Initializing aggregate functions. + * - Part 2: Methods and fields used by setting aggregation buffer values, + * processing input rows from inputIter, and generating output + * rows. + * - Part 3: Methods and fields used by hash-based aggregation. + * - Part 4: The function used to switch this iterator from hash-based + * aggregation to sort-based aggregation. + * - Part 5: Methods and fields used by sort-based aggregation. + * - Part 6: Loads input and processes input rows. + * - Part 7: Public methods of this iterator. + * - Part 8: A utility function used to generate a result when there is no + * input and there is no grouping expression. + * + * @param groupingExpressions + * expressions for grouping keys + * @param nonCompleteAggregateExpressions + * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]], + * [[PartialMerge]], or [[Final]]. + * @param completeAggregateExpressions + * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]]. + * @param initialInputBufferOffset + * If this iterator is used to handle functions with mode [[PartialMerge]] or [[Final]], + * the input rows have the format of `grouping keys + aggregation buffer`. 
+ * This offset indicates the starting position of aggregation buffer in a input row. + * @param resultExpressions + * expressions for generating output rows. + * @param newMutableProjection + * the function used to create mutable projections. + * @param originalInputAttributes + * attributes of representing input rows from `inputIter`. + * @param inputIter + * the iterator containing input [[UnsafeRow]]s. + */ +class TungstenAggregationIterator( + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression2], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + originalInputAttributes: Seq[Attribute], + inputIter: Iterator[UnsafeRow], + testFallbackStartsAt: Option[Int]) + extends Iterator[UnsafeRow] with Logging { + + /////////////////////////////////////////////////////////////////////////// + // Part 1: Initializing aggregate functions. + /////////////////////////////////////////////////////////////////////////// + + // A Seq containing all AggregateExpressions. + // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final + // are at the beginning of the allAggregateExpressions. + private[this] val allAggregateExpressions: Seq[AggregateExpression2] = + nonCompleteAggregateExpressions ++ completeAggregateExpressions + + // Check to make sure we do not have more than three modes in our AggregateExpressions. + // If we have, users are hitting a bug and we throw an IllegalStateException. + if (allAggregateExpressions.map(_.mode).distinct.length > 2) { + throw new IllegalStateException( + s"$allAggregateExpressions should have no more than 2 kinds of modes.") + } + + // + // The modes of AggregateExpressions. 
Right now, we can handle the following modes: + // - Partial-only: + // All AggregateExpressions have the mode of Partial. + // For this case, aggregationMode is (Some(Partial), None). + // - PartialMerge-only: + // All AggregateExpressions have the mode of PartialMerge. + // For this case, aggregationMode is (Some(PartialMerge), None). + // - Final-only: + // All AggregateExpressions have the mode of Final. + // For this case, aggregationMode is (Some(Final), None). + // - Final-Complete: + // Some AggregateExpressions have the mode of Final and + // others have the mode of Complete. For this case, + // aggregationMode is (Some(Final), Some(Complete)). + // - Complete-only: + // nonCompleteAggregateExpressions is empty and we have AggregateExpressions + // with mode Complete in completeAggregateExpressions. For this case, + // aggregationMode is (None, Some(Complete)). + // - Grouping-only: + // There is no AggregateExpression. For this case, aggregationMode is (None, None). + // + private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = { + nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> + completeAggregateExpressions.map(_.mode).distinct.headOption + } + + // All aggregate functions. TungstenAggregationIterator only handles AlgebraicAggregates. + // If there is any function that is not an AlgebraicAggregate, we throw an + // IllegalStateException.
+ private[this] val allAggregateFunctions: Array[AlgebraicAggregate] = { + if (!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])) { + throw new IllegalStateException( + "Only AlgebraicAggregates should be passed in TungstenAggregationIterator.") + } + + allAggregateExpressions + .map(_.aggregateFunction.asInstanceOf[AlgebraicAggregate]) + .toArray + } + + /////////////////////////////////////////////////////////////////////////// + // Part 2: Methods and fields used by setting aggregation buffer values, + // processing input rows from inputIter, and generating output + // rows. + /////////////////////////////////////////////////////////////////////////// + + // The projection used to initialize buffer values. + private[this] val algebraicInitialProjection: MutableProjection = { + val initExpressions = allAggregateFunctions.flatMap(_.initialValues) + newMutableProjection(initExpressions, Nil)() + } + + // Creates a new aggregation buffer and initializes buffer values. + // This functions should be only called at most three times (when we create the hash map, + // when we switch to sort-based aggregation, and when we create the re-used buffer for + // sort-based aggregation). + private def createNewAggregationBuffer(): UnsafeRow = { + val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferRowSize: Int = bufferSchema.length + + val genericMutableBuffer = new GenericMutableRow(bufferRowSize) + val unsafeProjection = + UnsafeProjection.create(bufferSchema.map(_.dataType)) + val buffer = unsafeProjection.apply(genericMutableBuffer) + algebraicInitialProjection.target(buffer)(EmptyRow) + buffer + } + + // Creates a function used to process a row based on the given inputAttributes. 
+ private def generateProcessRow( + inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = { + + val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) + val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes) + val inputSchema = StructType.fromAttributes(inputAttributes) + val unsafeRowJoiner = + GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema) + + aggregationMode match { + // Partial-only + case (Some(Partial), None) => + val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions) + val algebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + algebraicUpdateProjection.target(currentBuffer) + algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + } + + // PartialMerge-only or Final-only + case (Some(PartialMerge), None) | (Some(Final), None) => + val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions) + // This projection is used to merge buffer values for all AlgebraicAggregates. + val algebraicMergeProjection = + newMutableProjection( + mergeExpressions, + aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + // Process all algebraic aggregate functions. 
+ algebraicMergeProjection.target(currentBuffer) + algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row)) + } + + // Final-Complete + case (Some(Final), Some(Complete)) => + val nonCompleteAggregateFunctions: Array[AlgebraicAggregate] = + allAggregateFunctions.take(nonCompleteAggregateExpressions.length) + val completeAggregateFunctions: Array[AlgebraicAggregate] = + allAggregateFunctions.takeRight(completeAggregateExpressions.length) + + val completeOffsetExpressions = + Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + val mergeExpressions = + nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions + val finalAlgebraicMergeProjection = + newMutableProjection( + mergeExpressions, + aggregationBufferAttributes ++ inputAttributes)() + + // We do not touch buffer values of aggregate functions with the Final mode. + val finalOffsetExpressions = + Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + val updateExpressions = + finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions) + val completeAlgebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + val input = unsafeRowJoiner.join(currentBuffer, row) + // For all aggregate functions with mode Complete, update the given currentBuffer. + completeAlgebraicUpdateProjection.target(currentBuffer)(input) + + // For all aggregate functions with mode Final, merge buffer values in row to + // currentBuffer. 
+ finalAlgebraicMergeProjection.target(currentBuffer)(input) + } + + // Complete-only + case (None, Some(Complete)) => + val completeAggregateFunctions: Array[AlgebraicAggregate] = + allAggregateFunctions.takeRight(completeAggregateExpressions.length) + + val updateExpressions = + completeAggregateFunctions.flatMap(_.updateExpressions) + val completeAlgebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + completeAlgebraicUpdateProjection.target(currentBuffer) + // For all aggregate functions with mode Complete, update the given currentBuffer. + completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + } + + // Grouping only. + case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {} + + case other => + throw new IllegalStateException( + s"${aggregationMode} should not be passed into TungstenAggregationIterator.") + } + } + + // Creates a function used to generate output rows. + private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = { + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + + aggregationMode match { + // Partial-only or PartialMerge-only: every output row is basically the values of + // the grouping expressions and the corresponding aggregation buffer. + case (Some(Partial), None) | (Some(PartialMerge), None) => + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { + unsafeRowJoiner.join(currentGroupingKey, currentBuffer) + } + + // Final-only, Complete-only and Final-Complete: a output row is generated based on + // resultExpressions. 
+ case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { + resultProjection(unsafeRowJoiner.join(currentGroupingKey, currentBuffer)) + } + + // Grouping-only: a output row is generated from values of grouping expressions. + case (None, None) => + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { + resultProjection(currentGroupingKey) + } + + case other => + throw new IllegalStateException( + s"${aggregationMode} should not be passed into TungstenAggregationIterator.") + } + } + + // An UnsafeProjection used to extract grouping keys from the input rows. + private[this] val groupProjection = + UnsafeProjection.create(groupingExpressions, originalInputAttributes) + + // A function used to process a input row. Its first argument is the aggregation buffer + // and the second argument is the input row. + private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit = + generateProcessRow(originalInputAttributes) + + // A function used to generate output rows based on the grouping keys (first argument) + // and the corresponding aggregation buffer (second argument). + private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow = + generateResultProjection() + + // An aggregation buffer containing initial buffer values. It is used to + // initialize other aggregation buffers. + private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + + /////////////////////////////////////////////////////////////////////////// + // Part 3: Methods and fields used by hash-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // This is the hash map used for hash-based aggregation. 
It is backed by an + // UnsafeFixedWidthAggregationMap and it is used to store + // all groups and their corresponding aggregation buffers for hash-based aggregation. + private[this] val hashMap = new UnsafeFixedWidthAggregationMap( + initialAggregationBuffer, + StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)), + StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), + TaskContext.get.taskMemoryManager(), + SparkEnv.get.shuffleMemoryManager, + 1024 * 16, // initial capacity + SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"), + false // disable tracking of performance metrics + ) + + // The function used to read and process input rows. When processing input rows, + // it first uses hash-based aggregation by putting groups and their buffers in + // hashMap. If we could not allocate more memory for the map, we switch to + // sort-based aggregation (by calling switchToSortBasedAggregation). + private def processInputs(): Unit = { + while (!sortBased && inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey) + if (buffer == null) { + // buffer == null means that we could not allocate more memory. + // Now, we need to spill the map and switch to sort-based aggregation. + switchToSortBasedAggregation(groupingKey, newInput) + } else { + processRow(buffer, newInput) + } + } + } + + // This function is only used for testing. It is basically the same as processInputs except + // that it switches to sort-based aggregation after `fallbackStartsAt` input rows have + // been processed.
+ private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = { + var i = 0 + while (!sortBased && inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + val buffer: UnsafeRow = if (i < fallbackStartsAt) { + hashMap.getAggregationBuffer(groupingKey) + } else { + null + } + if (buffer == null) { + // buffer == null means that we could not allocate more memory. + // Now, we need to spill the map and switch to sort-based aggregation. + switchToSortBasedAggregation(groupingKey, newInput) + } else { + processRow(buffer, newInput) + } + i += 1 + } + } + + // The iterator created from hashMap. It is used to generate output rows when we + // are using hash-based aggregation. + private[this] var aggregationBufferMapIterator: KVIterator[UnsafeRow, UnsafeRow] = null + + // Indicates if aggregationBufferMapIterator still has key-value pairs. + private[this] var mapIteratorHasNext: Boolean = false + + /////////////////////////////////////////////////////////////////////////// + // Part 4: The function used to switch this iterator from hash-based + // aggregation to sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = { + logInfo("falling back to sort based aggregation.") + // Step 1: Get the ExternalSorter containing sorted entries of the map. + val externalSorter: UnsafeKVExternalSorter = hashMap.destructAndCreateExternalSorter() + + // Step 2: Free the memory used by the map. + hashMap.free() + + // Step 3: If we have aggregate function with mode Partial or Complete, + // we need to process input rows to get aggregation buffer. + // So, later in the sort-based aggregation iterator, we can do merge. + // If aggregate functions are with mode Final and PartialMerge, + // we just need to project the aggregation buffer from an input row. 
+ val needsProcess = aggregationMode match { + case (Some(Partial), None) => true + case (None, Some(Complete)) => true + case (Some(Final), Some(Complete)) => true + case _ => false + } + + if (needsProcess) { + // First, we create a buffer. + val buffer = createNewAggregationBuffer() + + // Process firstKey and firstInput. + // Initialize buffer. + buffer.copyFrom(initialAggregationBuffer) + processRow(buffer, firstInput) + externalSorter.insertKV(firstKey, buffer) + + // Process the rest of input rows. + while (inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + buffer.copyFrom(initialAggregationBuffer) + processRow(buffer, newInput) + externalSorter.insertKV(groupingKey, buffer) + } + } else { + // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer. + // We need to project the aggregation buffer part from an input row. + val buffer = createNewAggregationBuffer() + // The originalInputAttributes are using cloneBufferAttributes. So, we need to use + // allAggregateFunctions.flatMap(_.cloneBufferAttributes). + val bufferExtractor = newMutableProjection( + allAggregateFunctions.flatMap(_.cloneBufferAttributes), + originalInputAttributes)() + bufferExtractor.target(buffer) + + // Insert firstKey and its buffer. + bufferExtractor(firstInput) + externalSorter.insertKV(firstKey, buffer) + + // Insert the rest of input rows. + while (inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + bufferExtractor(newInput) + externalSorter.insertKV(groupingKey, buffer) + } + } + + // Set aggregationMode, processRow, and generateOutput for sort-based aggregation. 
+ val newAggregationMode = aggregationMode match { + case (Some(Partial), None) => (Some(PartialMerge), None) + case (None, Some(Complete)) => (Some(Final), None) + case (Some(Final), Some(Complete)) => (Some(Final), None) + case other => other + } + aggregationMode = newAggregationMode + + // Basically the value of the KVIterator returned by externalSorter + // will just aggregation buffer. At here, we use cloneBufferAttributes. + val newInputAttributes: Seq[Attribute] = + allAggregateFunctions.flatMap(_.cloneBufferAttributes) + + // Set up new processRow and generateOutput. + processRow = generateProcessRow(newInputAttributes) + generateOutput = generateResultProjection() + + // Step 5: Get the sorted iterator from the externalSorter. + sortedKVIterator = externalSorter.sortedIterator() + + // Step 6: Pre-load the first key-value pair from the sorted iterator to make + // hasNext idempotent. + sortedInputHasNewGroup = sortedKVIterator.next() + + // Copy the first key and value (aggregation buffer). + if (sortedInputHasNewGroup) { + val key = sortedKVIterator.getKey + val value = sortedKVIterator.getValue + nextGroupingKey = key.copy() + currentGroupingKey = key.copy() + firstRowInNextGroup = value.copy() + } + + // Step 7: set sortBased to true. + sortBased = true + } + + /////////////////////////////////////////////////////////////////////////// + // Part 5: Methods and fields used by sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // Indicates if we are using sort-based aggregation. Because we first try to use + // hash-based aggregation, its initial value is false. + private[this] var sortBased: Boolean = false + + // The KVIterator containing input rows for the sort-based aggregation. It will be + // set in switchToSortBasedAggregation when we switch to sort-based aggregation. 
+ private[this] var sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = null + + // The grouping key of the current group. + private[this] var currentGroupingKey: UnsafeRow = null + + // The grouping key of the next group. + private[this] var nextGroupingKey: UnsafeRow = null + + // The first row of the next group. + private[this] var firstRowInNextGroup: UnsafeRow = null + + // Indicates if we have a new group of rows from the sorted input iterator. + private[this] var sortedInputHasNewGroup: Boolean = false + + // The aggregation buffer used by the sort-based aggregation. + private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + + // Processes rows in the current group. It will stop when it finds a new group. + private def processCurrentSortedGroup(): Unit = { + // First, we need to copy nextGroupingKey to currentGroupingKey. + currentGroupingKey.copyFrom(nextGroupingKey) + // Now, we will start to find all rows belonging to this group. + // We create a variable to track if we see the next group. + var findNextPartition = false + // firstRowInNextGroup is the first row of this group. We first process it. + processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + + // The search will stop when we see the next group or there is no + // input row left in the iter. + // Pre-load the first key-value pair so that the condition of the while loop + // has no side effect (we do not trigger loading a new key-value pair + // when we evaluate the condition). + var hasNext = sortedKVIterator.next() + while (!findNextPartition && hasNext) { + // Get the grouping key and value (aggregation buffer). + val groupingKey = sortedKVIterator.getKey + val inputAggregationBuffer = sortedKVIterator.getValue + + // Check if the current row belongs to the current group. + if (currentGroupingKey.equals(groupingKey)) { + processRow(sortBasedAggregationBuffer, inputAggregationBuffer) + + hasNext = sortedKVIterator.next() + } else { + // We find a new group.
+ findNextPartition = true + // copyFrom will fail when + nextGroupingKey.copyFrom(groupingKey) // = groupingKey.copy() + firstRowInNextGroup.copyFrom(inputAggregationBuffer) // = inputAggregationBuffer.copy() + + } + } + // We have not seen a new group. It means that there is no new row in the input + // iter. The current group is the last group of the sortedKVIterator. + if (!findNextPartition) { + sortedInputHasNewGroup = false + sortedKVIterator.close() + } + } + + /////////////////////////////////////////////////////////////////////////// + // Part 6: Loads input rows and setup aggregationBufferMapIterator if we + // have not switched to sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // Starts to process input rows. + testFallbackStartsAt match { + case None => + processInputs() + case Some(fallbackStartsAt) => + // This is the testing path. processInputsWithControlledFallback is same as processInputs + // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows + // have been processed. + processInputsWithControlledFallback(fallbackStartsAt) + } + + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + if (!sortBased) { + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. + mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!mapIteratorHasNext) { + hashMap.free() + } + } + + /////////////////////////////////////////////////////////////////////////// + // Par 7: Iterator's public methods. 
+ /////////////////////////////////////////////////////////////////////////// + + override final def hasNext: Boolean = { + (sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext) + } + + override final def next(): UnsafeRow = { + if (hasNext) { + if (sortBased) { + // Process the current group. + processCurrentSortedGroup() + // Generate output row for the current group. + val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer) + // Initialize buffer values for the next group. + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + + outputRow + } else { + // We did not fall back to sort-based aggregation. + val result = + generateOutput( + aggregationBufferMapIterator.getKey, + aggregationBufferMapIterator.getValue) + + // Pre-load next key-value pair form aggregationBufferMapIterator to make hasNext + // idempotent. + mapIteratorHasNext = aggregationBufferMapIterator.next() + + if (!mapIteratorHasNext) { + // If there is no input from aggregationBufferMapIterator, we copy current result. + val resultCopy = result.copy() + // Then, we free the map. + hashMap.free() + + resultCopy + } else { + result + } + } + } else { + // no more result + throw new NoSuchElementException + } + } + + /////////////////////////////////////////////////////////////////////////// + // Part 8: A utility function used to generate a output row when there is no + // input and there is no grouping expression. + /////////////////////////////////////////////////////////////////////////// + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { + if (groupingExpressions.isEmpty) { + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + // We create a output row and copy it. So, we can free the map. 
+ val resultCopy = + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() + hashMap.free() + resultCopy + } else { + throw new IllegalStateException( + "This method should not be called when groupingExpressions is not empty.") + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala deleted file mode 100644 index b465787..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} -import org.apache.spark.sql.types.StructType - -/** - * An iterator used to evaluate [[AggregateFunction2]]. - * It first tries to use in-memory hash-based aggregation. If we cannot allocate more - * space for the hash map, we spill the sorted map entries, free the map, and then - * switch to sort-based aggregation. - */ -class UnsafeHybridAggregationIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[UnsafeRow, InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) - extends AggregationIterator( - groupingKeyAttributes, - valueAttributes, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) { - - require(groupingKeyAttributes.nonEmpty) - - /////////////////////////////////////////////////////////////////////////// - // Unsafe Aggregation buffers - /////////////////////////////////////////////////////////////////////////// - - // This is the Unsafe Aggregation Map used to store all buffers. 
- private[this] val buffers = new UnsafeFixedWidthAggregationMap( - newBuffer, - StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)), - StructType.fromAttributes(groupingKeyAttributes), - TaskContext.get.taskMemoryManager(), - SparkEnv.get.shuffleMemoryManager, - 1024 * 16, // initial capacity - SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"), - false // disable tracking of performance metrics - ) - - override protected def newBuffer: UnsafeRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) - val bufferRowSize: Int = bufferSchema.length - - val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val unsafeProjection = - UnsafeProjection.create(bufferSchema.map(_.dataType)) - val buffer = unsafeProjection.apply(genericMutableBuffer) - initializeBuffer(buffer) - buffer - } - - /////////////////////////////////////////////////////////////////////////// - // Methods and variables related to switching to sort-based aggregation - /////////////////////////////////////////////////////////////////////////// - private[this] var sortBased = false - - private[this] var sortBasedAggregationIterator: SortBasedAggregationIterator = _ - - // The value part of the input KV iterator is used to store original input values of - // aggregate functions, we need to convert them to aggregation buffers. 
- private def processOriginalInput( - firstKey: UnsafeRow, - firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { - new KVIterator[UnsafeRow, UnsafeRow] { - private[this] var isFirstRow = true - - private[this] var groupingKey: UnsafeRow = _ - - private[this] val buffer: UnsafeRow = newBuffer - - override def next(): Boolean = { - initializeBuffer(buffer) - if (isFirstRow) { - isFirstRow = false - groupingKey = firstKey - processRow(buffer, firstValue) - - true - } else if (inputKVIterator.next()) { - groupingKey = inputKVIterator.getKey() - val value = inputKVIterator.getValue() - processRow(buffer, value) - - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { - groupingKey - } - - override def getValue(): UnsafeRow = { - buffer - } - - override def close(): Unit = { - // Do nothing. - } - } - } - - // The value of the input KV Iterator has the format of groupingExprs + aggregation buffer. - // We need to project the aggregation buffer out. - private def projectInputBufferToUnsafe( - firstKey: UnsafeRow, - firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { - new KVIterator[UnsafeRow, UnsafeRow] { - private[this] var isFirstRow = true - - private[this] var groupingKey: UnsafeRow = _ - - private[this] val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) - - private[this] val value: UnsafeRow = { - val genericMutableRow = new GenericMutableRow(bufferSchema.length) - UnsafeProjection.create(bufferSchema.map(_.dataType)).apply(genericMutableRow) - } - - private[this] val projectInputBuffer = { - newMutableProjection(bufferSchema, valueAttributes)().target(value) - } - - override def next(): Boolean = { - if (isFirstRow) { - isFirstRow = false - groupingKey = firstKey - projectInputBuffer(firstValue) - - true - } else if (inputKVIterator.next()) { - groupingKey = inputKVIterator.getKey() - projectInputBuffer(inputKVIterator.getValue()) - - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { 
- groupingKey - } - - override def getValue(): UnsafeRow = { - value - } - - override def close(): Unit = { - // Do nothing. - } - } - } - - /** - * We need to fall back to sort based aggregation because we do not have enough memory - * for our in-memory hash map (i.e. `buffers`). - */ - private def switchToSortBasedAggregation( - currentGroupingKey: UnsafeRow, - currentRow: InternalRow): Unit = { - logInfo("falling back to sort based aggregation.") - - // Step 1: Get the ExternalSorter containing entries of the map. - val externalSorter = buffers.destructAndCreateExternalSorter() - - // Step 2: Free the memory used by the map. - buffers.free() - - // Step 3: If we have aggregate function with mode Partial or Complete, - // we need to process them to get aggregation buffer. - // So, later in the sort-based aggregation iterator, we can do merge. - // If aggregate functions are with mode Final and PartialMerge, - // we just need to project the aggregation buffer from the input. - val needsProcess = aggregationMode match { - case (Some(Partial), None) => true - case (None, Some(Complete)) => true - case (Some(Final), Some(Complete)) => true - case _ => false - } - - val processedIterator = if (needsProcess) { - processOriginalInput(currentGroupingKey, currentRow) - } else { - // The input value's format is groupingExprs + buffer. - // We need to project the buffer part out. - projectInputBufferToUnsafe(currentGroupingKey, currentRow) - } - - // Step 4: Redirect processedIterator to externalSorter. - while (processedIterator.next()) { - externalSorter.insertKV(processedIterator.getKey(), processedIterator.getValue()) - } - - // Step 5: Get the sorted iterator from the externalSorter. - val sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = externalSorter.sortedIterator() - - // Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator. 
- // For a aggregate function with mode Partial, its mode in the SortBasedAggregationIterator - // will be PartialMerge. For a aggregate function with mode Complete, - // its mode in the SortBasedAggregationIterator will be Final. - val newNonCompleteAggregateExpressions = allAggregateExpressions.map { - case AggregateExpression2(func, Partial, isDistinct) => - AggregateExpression2(func, PartialMerge, isDistinct) - case AggregateExpression2(func, Complete, isDistinct) => - AggregateExpression2(func, Final, isDistinct) - case other => other - } - val newNonCompleteAggregateAttributes = - nonCompleteAggregateAttributes ++ completeAggregateAttributes - - val newValueAttributes = - allAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) - - sortBasedAggregationIterator = SortBasedAggregationIterator.createFromKVIterator( - groupingKeyAttributes = groupingKeyAttributes, - valueAttributes = newValueAttributes, - inputKVIterator = sortedKVIterator.asInstanceOf[KVIterator[InternalRow, InternalRow]], - nonCompleteAggregateExpressions = newNonCompleteAggregateExpressions, - nonCompleteAggregateAttributes = newNonCompleteAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = resultExpressions, - newMutableProjection = newMutableProjection, - outputsUnsafeRows = outputsUnsafeRows) - } - - /////////////////////////////////////////////////////////////////////////// - // Methods used to initialize this iterator. - /////////////////////////////////////////////////////////////////////////// - - /** Starts to read input rows and falls back to sort-based aggregation if necessary. 
*/ - protected def initialize(): Unit = { - var hasNext = inputKVIterator.next() - while (!sortBased && hasNext) { - val groupingKey = inputKVIterator.getKey() - val currentRow = inputKVIterator.getValue() - val buffer = buffers.getAggregationBuffer(groupingKey) - if (buffer == null) { - // buffer == null means that we could not allocate more memory. - // Now, we need to spill the map and switch to sort-based aggregation. - switchToSortBasedAggregation(groupingKey, currentRow) - sortBased = true - } else { - processRow(buffer, currentRow) - hasNext = inputKVIterator.next() - } - } - } - - // This is the starting point of this iterator. - initialize() - - // Creates the iterator for the Hash Aggregation Map after we have populated - // contents of that map. - private[this] val aggregationBufferMapIterator = buffers.iterator() - - private[this] var _mapIteratorHasNext = false - - // Pre-load the first key-value pair from the map to make hasNext idempotent. - if (!sortBased) { - _mapIteratorHasNext = aggregationBufferMapIterator.next() - // If the map is empty, we just free it. - if (!_mapIteratorHasNext) { - buffers.free() - } - } - - /////////////////////////////////////////////////////////////////////////// - // Iterator's public methods - /////////////////////////////////////////////////////////////////////////// - - override final def hasNext: Boolean = { - (sortBased && sortBasedAggregationIterator.hasNext) || (!sortBased && _mapIteratorHasNext) - } - - - override final def next(): InternalRow = { - if (hasNext) { - if (sortBased) { - sortBasedAggregationIterator.next() - } else { - // We did not fall back to the sort-based aggregation. - val result = - generateOutput( - aggregationBufferMapIterator.getKey, - aggregationBufferMapIterator.getValue) - // Pre-load next key-value pair from aggregationBufferMapIterator. 
- _mapIteratorHasNext = aggregationBufferMapIterator.next() - - if (!_mapIteratorHasNext) { - val resultCopy = result.copy() - buffers.free() - resultCopy - } else { - result - } - } - } else { - // no more result - throw new NoSuchElementException - } - } -} - -object UnsafeHybridAggregationIterator { - // scalastyle:off - def createFromInputIterator( - groupingExprs: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], - outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = { - new UnsafeHybridAggregationIterator( - groupingExprs.map(_.toAttribute), - inputAttributes, - AggregationIterator.unsafeKVIterator(groupingExprs, inputAttributes, inputIter), - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) - } - // scalastyle:on -} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
