Github user yhuai commented on a diff in the pull request:
https://github.com/apache/spark/pull/7954#discussion_r36440415
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
---
@@ -0,0 +1,663 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.unsafe.KVIterator
+import org.apache.spark.{Logging, SparkEnv, TaskContext}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s.
+ *
+ * This iterator first uses hash-based aggregation to process input rows. It uses
+ * a hash map to store groups and their corresponding aggregation buffers. If
+ * this map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]],
+ * it switches to sort-based aggregation. The switch consists of the following steps:
+ * - Step 1: Sort all entries of the hash map based on values of grouping expressions and
+ * spill them to disk.
+ * - Step 2: Create an external sorter based on the spilled sorted map entries.
+ * - Step 3: Redirect all input rows to the external sorter.
+ * - Step 4: Get a sorted [[KVIterator]] from the external sorter.
+ * - Step 5: Initialize sort-based aggregation.
+ * From then on, this iterator performs sort-based aggregation.
+ *
+ * The code of this class is organized as follows:
+ * - Part 1: Initializing aggregate functions.
+ * - Part 2: Methods and fields used for setting aggregation buffer values,
+ * processing input rows from inputIter, and generating output
+ * rows.
+ * - Part 3: Methods and fields used by hash-based aggregation.
+ * - Part 4: The function used to switch this iterator from hash-based
+ * aggregation to sort-based aggregation.
+ * - Part 5: Methods and fields used by sort-based aggregation.
+ * - Part 6: Loads input and processes input rows.
+ * - Part 7: Public methods of this iterator.
+ * - Part 8: A utility function used to generate a result when there is no
+ * input and there is no grouping expression.
+ *
+ * @param groupingExpressions
+ * expressions for grouping keys
+ * @param nonCompleteAggregateExpressions
+ * [[AggregateExpression2]]s containing [[AggregateFunction2]]s with mode [[Partial]],
+ * [[PartialMerge]], or [[Final]].
+ * @param completeAggregateExpressions
+ * [[AggregateExpression2]]s containing [[AggregateFunction2]]s with mode [[Complete]].
+ * @param initialInputBufferOffset
+ * If this iterator is used to handle functions with mode [[PartialMerge]] or [[Final]],
+ * the input rows have the format of `grouping keys + aggregation buffer`.
+ * This offset indicates the starting position of the aggregation buffer in an input row.
+ * @param resultExpressions
+ * expressions for generating output rows.
+ * @param newMutableProjection
+ * the function used to create mutable projections.
+ * @param originalInputAttributes
+ * attributes representing input rows from `inputIter`.
+ * @param inputIter
+ * the iterator containing input [[UnsafeRow]]s.
+ */
+class TungstenAggregationIterator(
+ groupingExpressions: Seq[NamedExpression],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ originalInputAttributes: Seq[Attribute],
+ inputIter: Iterator[UnsafeRow],
+ testFallbackStartsAt: Option[Int])
+ extends Iterator[UnsafeRow] with Logging {
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 1: Initializing aggregate functions.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // A Seq containing all AggregateExpressions.
+ // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
+ // are at the beginning of allAggregateExpressions.
+ private[this] val allAggregateExpressions: Seq[AggregateExpression2] =
+ nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+ // Check to make sure we do not have more than two modes in our AggregateExpressions.
+ // If we do, users are hitting a bug and we throw an IllegalStateException.
+ if (allAggregateExpressions.map(_.mode).distinct.length > 2) {
+ throw new IllegalStateException(
+ s"$allAggregateExpressions should have no more than 2 kinds of modes.")
+ }
+
+ //
+ // The modes of AggregateExpressions. Right now, we can handle the following modes:
+ // - Partial-only:
+ // All AggregateExpressions have the mode of Partial.
+ // For this case, aggregationMode is (Some(Partial), None).
+ // - PartialMerge-only:
+ // All AggregateExpressions have the mode of PartialMerge.
+ // For this case, aggregationMode is (Some(PartialMerge), None).
+ // - Final-only:
+ // All AggregateExpressions have the mode of Final.
+ // For this case, aggregationMode is (Some(Final), None).
+ // - Final-Complete:
+ // Some AggregateExpressions have the mode of Final and
+ // others have the mode of Complete. For this case,
+ // aggregationMode is (Some(Final), Some(Complete)).
+ // - Complete-only:
+ // nonCompleteAggregateExpressions is empty and we have AggregateExpressions
+ // with mode Complete in completeAggregateExpressions. For this case,
+ // aggregationMode is (None, Some(Complete)).
+ // - Grouping-only:
+ // There is no AggregateExpression. For this case, aggregationMode is (None, None).
+ //
+ private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = {
+ nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
+ completeAggregateExpressions.map(_.mode).distinct.headOption
+ }
+
+ // All aggregate functions. TungstenAggregationIterator only handles AlgebraicAggregates.
+ // If there are any functions that are not AlgebraicAggregates, we throw an
+ // IllegalStateException.
+ private[this] val allAggregateFunctions: Array[AlgebraicAggregate] = {
+ if (!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])) {
+ throw new IllegalStateException(
+ "Only AlgebraicAggregates should be passed in TungstenAggregationIterator.")
+ }
+
+ allAggregateExpressions
+ .map(_.aggregateFunction.asInstanceOf[AlgebraicAggregate])
+ .toArray
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 2: Methods and fields used for setting aggregation buffer values,
+ // processing input rows from inputIter, and generating output
+ // rows.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // The projection used to initialize buffer values.
+ private[this] val algebraicInitialProjection: MutableProjection = {
+ val initExpressions = allAggregateFunctions.flatMap(_.initialValues)
+ newMutableProjection(initExpressions, Nil)()
+ }
+
+ // Creates a new aggregation buffer and initializes buffer values.
+ // This function should be called at most three times (when we create the hash map,
+ // when we switch to sort-based aggregation, and when we create the re-used buffer for
+ // sort-based aggregation).
+ private def createNewBuffer(): UnsafeRow = {
--- End diff --
Done.
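For readers following the review, the hash-to-sort fallback described in the Scaladoc (Steps 1-5) can be sketched in plain Scala. This is a minimal model under stated assumptions, not the implementation in this PR: `HashThenSortAggregation`, `sumByKey` and `maxHashEntries` are hypothetical names, keys are Strings, the aggregate is a Long sum, a fixed entry limit stands in for a refused allocation from ShuffleMemoryManager, and an in-memory buffer stands in for the spill files and UnsafeKVExternalSorter. Because a sum merges partial buffers and raw input values the same way, the sketch does not distinguish spilled buffers from redirected input rows, which the real iterator has to do.

```scala
import scala.collection.mutable

// Hypothetical, simplified model of the hash-to-sort fallback (not Spark code).
// A fixed entry limit stands in for a failed memory allocation, and an
// in-memory buffer stands in for spill files and the external sorter.
object HashThenSortAggregation {

  def sumByKey(input: Iterator[(String, Long)], maxHashEntries: Int): Seq[(String, Long)] = {
    val hashMap = mutable.HashMap.empty[String, Long]
    var fallbackRow: Option[(String, Long)] = None

    // Hash-based phase: update buffers in place while the map has room.
    while (fallbackRow.isEmpty && input.hasNext) {
      val (key, value) = input.next()
      if (hashMap.contains(key) || hashMap.size < maxHashEntries) {
        hashMap(key) = hashMap.getOrElse(key, 0L) + value
      } else {
        // A new group cannot be added: "allocation failed", so fall back.
        fallbackRow = Some((key, value))
      }
    }

    fallbackRow match {
      case None =>
        // No fallback happened: emit the hash map contents (unordered).
        hashMap.toSeq
      case Some(firstUnhandled) =>
        // Step 1: sort the hash map entries by grouping key ("spill" them).
        val spilled = hashMap.toSeq.sortBy(_._1)
        // Steps 2-3: build a sorter from the spilled entries and redirect the
        // row that did not fit, plus all remaining input rows, into it.
        val sorter = mutable.ArrayBuffer.empty[(String, Long)]
        sorter ++= spilled
        sorter += firstUnhandled
        sorter ++= input
        // Step 4: get a sorted iterator. A real external sorter would merge
        // sorted spill files instead of re-sorting everything in memory.
        val sortedRows = sorter.sortBy(_._1).iterator
        // Step 5: sort-based aggregation merges each run of equal keys.
        val out = mutable.ArrayBuffer.empty[(String, Long)]
        var currentKey: Option[String] = None
        var currentSum = 0L
        for ((key, value) <- sortedRows) {
          if (currentKey.contains(key)) {
            currentSum += value
          } else {
            currentKey.foreach(k => out += ((k, currentSum)))
            currentKey = Some(key)
            currentSum = value
          }
        }
        currentKey.foreach(k => out += ((k, currentSum)))
        out.toSeq
    }
  }
}
```

With `maxHashEntries = 2` and the keys `b, a, c, a, b`, the third distinct key (`c`) triggers the fallback, and the sorted pass still yields exactly one output row per key.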
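The mode table in the `aggregationMode` comment of the diff can likewise be checked against a small model of the same two expressions. The `Mode` hierarchy and the `AggregationModes` object below are hypothetical stand-ins for Spark's `AggregateMode` objects, not Spark's classes; the function copies the diff's logic: reject more than two distinct modes overall, then pair the first distinct non-Complete mode with the first distinct Complete mode.

```scala
// Hypothetical model of the aggregationMode computation (not Spark code).
object AggregationModes {
  // Stand-ins for Spark's AggregateMode objects.
  sealed trait Mode
  case object Partial extends Mode
  case object PartialMerge extends Mode
  case object Final extends Mode
  case object Complete extends Mode

  // Same logic as the diff: fail if more than two distinct modes appear overall,
  // otherwise pair the first distinct non-Complete mode with the first Complete mode.
  def aggregationMode(
      nonCompleteModes: Seq[Mode],
      completeModes: Seq[Mode]): (Option[Mode], Option[Mode]) = {
    val allModes = nonCompleteModes ++ completeModes
    if (allModes.distinct.length > 2) {
      throw new IllegalStateException(
        s"$allModes should have no more than 2 kinds of modes.")
    }
    nonCompleteModes.distinct.headOption -> completeModes.distinct.headOption
  }
}
```

For example, `aggregationMode(Seq(Final), Seq(Complete))` returns `(Some(Final), Some(Complete))`, matching the Final-Complete case, and two empty inputs give `(None, None)`, matching the Grouping-only case.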