Github user yhuai commented on a diff in the pull request:
https://github.com/apache/spark/pull/9038#discussion_r41661916
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala ---
@@ -134,19 +137,73 @@ class TungstenAggregationIterator(
       completeAggregateExpressions.map(_.mode).distinct.headOption
   }
-  // All aggregate functions. TungstenAggregationIterator only handles expression-based aggregate.
-  // If there is any functions that is an ImperativeAggregateFunction, we throw an
-  // IllegalStateException.
-  private[this] val allAggregateFunctions: Array[DeclarativeAggregate] = {
-    if (!allAggregateExpressions.forall(
-        _.aggregateFunction.isInstanceOf[DeclarativeAggregate])) {
-      throw new IllegalStateException(
-        "Only ExpressionAggregateFunctions should be passed in TungstenAggregationIterator.")
+  // Initialize all AggregateFunctions by binding references, if necessary,
+  // and setting inputBufferOffset and mutableBufferOffset.
+  private def initializeAllAggregateFunctions(
+      startingInputBufferOffset: Int): Array[AggregateFunction2] = {
+    var mutableBufferOffset = 0
+    var inputBufferOffset: Int = startingInputBufferOffset
+    val functions = new Array[AggregateFunction2](allAggregateExpressions.length)
+    var i = 0
+    while (i < allAggregateExpressions.length) {
+      val func = allAggregateExpressions(i).aggregateFunction
+      val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length
+      // We need to use this mode instead of func.mode in order to handle aggregation mode switching
+      // when switching to sort-based aggregation:
+      val mode =
+        if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2
+      val funcWithBoundReferences = mode match {
+        case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] =>
+          // We need to create BoundReferences if the function is not an
+          // expression-based aggregate function (it does not support code-gen) and the mode of
+          // this function is Partial or Complete because we will call eval of this
+          // function's children in the update method of this aggregate function.
+          // Those eval calls require BoundReferences to work.
+          BindReferences.bindReference(func, originalInputAttributes)
+        case _ =>
+          // We only need to set inputBufferOffset for aggregate functions with mode
+          // PartialMerge and Final.
+          func match {
+            case function: ImperativeAggregate =>
+              function.withNewInputAggBufferOffset(inputBufferOffset)
+            case _ =>
+          }
+          inputBufferOffset += func.aggBufferSchema.length
+          func
+      }
+      // Set mutableBufferOffset for this function. It is important that setting
+      // mutableBufferOffset happens after all potential bindReference operations
+      // because bindReference will create a new instance of the function.
+      funcWithBoundReferences match {
+        case function: ImperativeAggregate =>
+          function.withNewMutableAggBufferOffset(mutableBufferOffset)
+        case _ =>
+      }
+      mutableBufferOffset += funcWithBoundReferences.aggBufferSchema.length
+      functions(i) = funcWithBoundReferences
+      i += 1
     }
+    functions
+  }
+
+  private[this] var allAggregateFunctions: Array[AggregateFunction2] =
--- End diff ---
Since it will be a `var`, let's double-check the places where we use it and
make sure they are updated after we reassign it. From a quick search, it seems
that `allImperativeAggregateFunctionPositions` and `expressionAggInitialProjection`
are the only two `val`s that will not be updated after we change
`allAggregateFunctions`. However, I think we are fine, because neither of them is
affected by a change of `startingInputBufferOffset`.
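To make the reasoning concrete, here is a minimal, hypothetical Scala sketch (not Spark code). `Func`, `IteratorModel`, `switchTo`, `imperativePositions`, and `staleFirstInputOffset` are illustrative names, and the Partial/Complete vs. PartialMerge/Final mode handling is deliberately omitted. It assumes offsets are assigned as running sums of each function's buffer length, as in the loop above, and shows that a `val` computed from the `var` at construction keeps its original value after the `var` is reassigned. That is harmless for a value that does not depend on the starting input offset, and stale for one that does.

```scala
// Hypothetical, simplified model of the review point; these are NOT Spark APIs,
// and aggregation-mode handling is intentionally left out.
final case class Func(
    name: String,
    imperative: Boolean,
    bufferLength: Int,
    inputOffset: Int = -1,
    mutableOffset: Int = -1)

class IteratorModel(specs: Seq[Func]) {
  // Like initializeAllAggregateFunctions: assign running offsets to each function.
  private def initialize(startingInputBufferOffset: Int): Array[Func] = {
    var inputOffset = startingInputBufferOffset
    var mutableOffset = 0
    specs.map { f =>
      val withOffsets = f.copy(inputOffset = inputOffset, mutableOffset = mutableOffset)
      inputOffset += f.bufferLength
      mutableOffset += f.bufferLength
      withOffsets
    }.toArray
  }

  // The reassignable field under discussion.
  var functions: Array[Func] = initialize(0)

  // Computed once at construction and never recomputed when `functions` changes.
  // Safe: which positions hold imperative functions does not depend on the offset.
  val imperativePositions: Seq[Int] = functions.indices.filter(i => functions(i).imperative)

  // Also computed once, but it DOES depend on the starting offset, so it goes stale.
  val staleFirstInputOffset: Int = functions.head.inputOffset

  // Rebuild the functions with a new starting input offset (hypothetical helper).
  def switchTo(startingInputBufferOffset: Int): Unit =
    functions = initialize(startingInputBufferOffset)
}

object IteratorModelDemo {
  def main(args: Array[String]): Unit = {
    val model = new IteratorModel(Seq(
      Func("sum", imperative = false, bufferLength = 1),
      Func("hll", imperative = true, bufferLength = 2)))
    model.switchTo(3)
    println(model.imperativePositions)        // still correct after the switch
    println(model.functions.head.inputOffset) // 3
    println(model.staleFirstInputOffset)      // 0, computed before the switch
  }
}
```

The check the comment asks for amounts to sorting each such `val` into one of these two buckets.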