Github user yhuai commented on a diff in the pull request:

    https://github.com/apache/spark/pull/10228#discussion_r47440000
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala ---
    @@ -97,378 +89,62 @@ class TungstenAggregationIterator(
         numOutputRows: LongSQLMetric,
         dataSize: LongSQLMetric,
         spillSize: LongSQLMetric)
    -  extends Iterator[UnsafeRow] with Logging {
    +  extends AggregationIterator(
    +    groupingExpressions,
    +    originalInputAttributes,
    +    aggregateExpressions,
    +    aggregateAttributes,
    +    initialInputBufferOffset,
    +    resultExpressions,
    +    newMutableProjection) with Logging {
     
       
///////////////////////////////////////////////////////////////////////////
       // Part 1: Initializing aggregate functions.
       
///////////////////////////////////////////////////////////////////////////
     
    -  // A Seq containing all AggregateExpressions.
    -  // It is important that all AggregateExpressions with the mode Partial, 
PartialMerge or Final
    -  // are at the beginning of the allAggregateExpressions.
    -  private[this] val allAggregateExpressions: Seq[AggregateExpression] =
    -    nonCompleteAggregateExpressions ++ completeAggregateExpressions
    -
    -  // Check to make sure we do not have more than three modes in our 
AggregateExpressions.
    -  // If we have, users are hitting a bug and we throw an 
IllegalStateException.
    -  if (allAggregateExpressions.map(_.mode).distinct.length > 2) {
    -    throw new IllegalStateException(
    -      s"$allAggregateExpressions should have no more than 2 kinds of 
modes.")
    -  }
    -
       // Remember spill data size of this task before execute this operator so 
that we can
       // figure out how many bytes we spilled for this operator.
       private val spillSizeBefore = 
TaskContext.get().taskMetrics().memoryBytesSpilled
     
    -  //
    -  // The modes of AggregateExpressions. Right now, we can handle the 
following mode:
    -  //  - Partial-only:
    -  //      All AggregateExpressions have the mode of Partial.
    -  //      For this case, aggregationMode is (Some(Partial), None).
    -  //  - PartialMerge-only:
    -  //      All AggregateExpressions have the mode of PartialMerge).
    -  //      For this case, aggregationMode is (Some(PartialMerge), None).
    -  //  - Final-only:
    -  //      All AggregateExpressions have the mode of Final.
    -  //      For this case, aggregationMode is (Some(Final), None).
    -  //  - Final-Complete:
    -  //      Some AggregateExpressions have the mode of Final and
    -  //      others have the mode of Complete. For this case,
    -  //      aggregationMode is (Some(Final), Some(Complete)).
    -  //  - Complete-only:
    -  //      nonCompleteAggregateExpressions is empty and we have 
AggregateExpressions
    -  //      with mode Complete in completeAggregateExpressions. For this 
case,
    -  //      aggregationMode is (None, Some(Complete)).
    -  //  - Grouping-only:
    -  //      There is no AggregateExpression. For this case, AggregationMode 
is (None,None).
    -  //
    -  private[this] var aggregationMode: (Option[AggregateMode], 
Option[AggregateMode]) = {
    -    nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
    -      completeAggregateExpressions.map(_.mode).distinct.headOption
    -  }
    -
    -  // Initialize all AggregateFunctions by binding references, if necessary,
    -  // and setting inputBufferOffset and mutableBufferOffset.
    -  private def initializeAllAggregateFunctions(
    -      startingInputBufferOffset: Int): Array[AggregateFunction] = {
    -    var mutableBufferOffset = 0
    -    var inputBufferOffset: Int = startingInputBufferOffset
    -    val functions = new 
Array[AggregateFunction](allAggregateExpressions.length)
    -    var i = 0
    -    while (i < allAggregateExpressions.length) {
    -      val func = allAggregateExpressions(i).aggregateFunction
    -      val aggregateExpressionIsNonComplete = i < 
nonCompleteAggregateExpressions.length
    -      // We need to use this mode instead of func.mode in order to handle 
aggregation mode switching
    -      // when switching to sort-based aggregation:
    -      val mode = if (aggregateExpressionIsNonComplete) aggregationMode._1 
else aggregationMode._2
    -      val funcWithBoundReferences = mode match {
    -        case Some(Partial) | Some(Complete) if 
func.isInstanceOf[ImperativeAggregate] =>
    -          // We need to create BoundReferences if the function is not an
    -          // expression-based aggregate function (it does not support 
code-gen) and the mode of
    -          // this function is Partial or Complete because we will call 
eval of this
    -          // function's children in the update method of this aggregate 
function.
    -          // Those eval calls require BoundReferences to work.
    -          BindReferences.bindReference(func, originalInputAttributes)
    -        case _ =>
    -          // We only need to set inputBufferOffset for aggregate functions 
with mode
    -          // PartialMerge and Final.
    -          val updatedFunc = func match {
    -            case function: ImperativeAggregate =>
    -              function.withNewInputAggBufferOffset(inputBufferOffset)
    -            case function => function
    -          }
    -          inputBufferOffset += func.aggBufferSchema.length
    -          updatedFunc
    -      }
    -      val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
    -        case function: ImperativeAggregate =>
    -          // Set mutableBufferOffset for this function. It is important 
that setting
    -          // mutableBufferOffset happens after all potential bindReference 
operations
    -          // because bindReference will create a new instance of the 
function.
    -          function.withNewMutableAggBufferOffset(mutableBufferOffset)
    -        case function => function
    -      }
    -      mutableBufferOffset += 
funcWithUpdatedAggBufferOffset.aggBufferSchema.length
    -      functions(i) = funcWithUpdatedAggBufferOffset
    -      i += 1
    -    }
    -    functions
    -  }
    -
    -  private[this] var allAggregateFunctions: Array[AggregateFunction] =
    -    initializeAllAggregateFunctions(initialInputBufferOffset)
    -
    -  // Positions of those imperative aggregate functions in 
allAggregateFunctions.
    -  // For example, say that we have func1, func2, func3, func4 in 
aggregateFunctions, and
    -  // func2 and func3 are imperative aggregate functions. Then
    -  // allImperativeAggregateFunctionPositions will be [1, 2]. Note that 
this does not need to be
    -  // updated when falling back to sort-based aggregation because the 
positions of the aggregate
    -  // functions do not change in that case.
    -  private[this] val allImperativeAggregateFunctionPositions: Array[Int] = {
    -    val positions = new ArrayBuffer[Int]()
    -    var i = 0
    -    while (i < allAggregateFunctions.length) {
    -      allAggregateFunctions(i) match {
    -        case agg: DeclarativeAggregate =>
    -        case _ => positions += i
    -      }
    -      i += 1
    -    }
    -    positions.toArray
    -  }
    -
       
///////////////////////////////////////////////////////////////////////////
       // Part 2: Methods and fields used by setting aggregation buffer values,
       //         processing input rows from inputIter, and generating output
       //         rows.
       
///////////////////////////////////////////////////////////////////////////
     
    -  // The projection used to initialize buffer values for all 
expression-based aggregates.
    -  // Note that this projection does not need to be updated when switching 
to sort-based aggregation
    -  // because the schema of empty aggregation buffers does not change in 
that case.
    -  private[this] val expressionAggInitialProjection: MutableProjection = {
    -    val initExpressions = allAggregateFunctions.flatMap {
    -      case ae: DeclarativeAggregate => ae.initialValues
    -      // For the positions corresponding to imperative aggregate 
functions, we'll use special
    -      // no-op expressions which are ignored during projection 
code-generation.
    -      case i: ImperativeAggregate => 
Seq.fill(i.aggBufferAttributes.length)(NoOp)
    -    }
    -    newMutableProjection(initExpressions, Nil)()
    -  }
    -
       // Creates a new aggregation buffer and initializes buffer values.
    -  // This function should be only called at most three times (when we 
create the hash map,
    -  // when we switch to sort-based aggregation, and when we create the 
re-used buffer for
    -  // sort-based aggregation).
    +  // This function should be only called at most two times (when we create 
the hash map,
    +  // and when we create the re-used buffer for sort-based aggregation).
       private def createNewAggregationBuffer(): UnsafeRow = {
    -    val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
    +    val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes)
         val buffer: UnsafeRow = 
UnsafeProjection.create(bufferSchema.map(_.dataType))
           .apply(new GenericMutableRow(bufferSchema.length))
         // Initialize declarative aggregates' buffer values
         expressionAggInitialProjection.target(buffer)(EmptyRow)
         // Initialize imperative aggregates' buffer values
    -    allAggregateFunctions.collect { case f: ImperativeAggregate => f 
}.foreach(_.initialize(buffer))
    +    aggregateFunctions.collect { case f: ImperativeAggregate => f 
}.foreach(_.initialize(buffer))
         buffer
       }
     
    -  // Creates a function used to process a row based on the given 
inputAttributes.
    -  private def generateProcessRow(
    -      inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = 
{
    -
    -    val aggregationBufferAttributes = 
allAggregateFunctions.flatMap(_.aggBufferAttributes)
    -    val joinedRow = new JoinedRow()
    -
    -    aggregationMode match {
    -      // Partial-only
    -      case (Some(Partial), None) =>
    -        val updateExpressions = allAggregateFunctions.flatMap {
    -          case ae: DeclarativeAggregate => ae.updateExpressions
    -          case agg: AggregateFunction => 
Seq.fill(agg.aggBufferAttributes.length)(NoOp)
    -        }
    -        val imperativeAggregateFunctions: Array[ImperativeAggregate] =
    -          allAggregateFunctions.collect { case func: ImperativeAggregate 
=> func}
    -        val expressionAggUpdateProjection =
    -          newMutableProjection(updateExpressions, 
aggregationBufferAttributes ++ inputAttributes)()
    -
    -        (currentBuffer: UnsafeRow, row: InternalRow) => {
    -          expressionAggUpdateProjection.target(currentBuffer)
    -          // Process all expression-based aggregate functions.
    -          expressionAggUpdateProjection(joinedRow(currentBuffer, row))
    -          // Process all imperative aggregate functions
    -          var i = 0
    -          while (i < imperativeAggregateFunctions.length) {
    -            imperativeAggregateFunctions(i).update(currentBuffer, row)
    -            i += 1
    -          }
    -        }
    -
    -      // PartialMerge-only or Final-only
    -      case (Some(PartialMerge), None) | (Some(Final), None) =>
    -        val mergeExpressions = allAggregateFunctions.flatMap {
    -          case ae: DeclarativeAggregate => ae.mergeExpressions
    -          case agg: AggregateFunction => 
Seq.fill(agg.aggBufferAttributes.length)(NoOp)
    -        }
    -        val imperativeAggregateFunctions: Array[ImperativeAggregate] =
    -          allAggregateFunctions.collect { case func: ImperativeAggregate 
=> func}
    -        // This projection is used to merge buffer values for all 
expression-based aggregates.
    -        val expressionAggMergeProjection =
    -          newMutableProjection(mergeExpressions, 
aggregationBufferAttributes ++ inputAttributes)()
    -
    -        (currentBuffer: UnsafeRow, row: InternalRow) => {
    -          // Process all expression-based aggregate functions.
    -          
expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, 
row))
    -          // Process all imperative aggregate functions.
    -          var i = 0
    -          while (i < imperativeAggregateFunctions.length) {
    -            imperativeAggregateFunctions(i).merge(currentBuffer, row)
    -            i += 1
    -          }
    -        }
    -
    -      // Final-Complete
    -      case (Some(Final), Some(Complete)) =>
    -        val completeAggregateFunctions: Array[AggregateFunction] =
    -          
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
    -        val completeImperativeAggregateFunctions: 
Array[ImperativeAggregate] =
    -          completeAggregateFunctions.collect { case func: 
ImperativeAggregate => func }
    -        val nonCompleteAggregateFunctions: Array[AggregateFunction] =
    -          
allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
    -        val nonCompleteImperativeAggregateFunctions: 
Array[ImperativeAggregate] =
    -          nonCompleteAggregateFunctions.collect { case func: 
ImperativeAggregate => func }
    -
    -        val completeOffsetExpressions =
    -          
Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
    -        val mergeExpressions =
    -          nonCompleteAggregateFunctions.flatMap {
    -            case ae: DeclarativeAggregate => ae.mergeExpressions
    -            case agg: AggregateFunction => 
Seq.fill(agg.aggBufferAttributes.length)(NoOp)
    -          } ++ completeOffsetExpressions
    -        val finalMergeProjection =
    -          newMutableProjection(mergeExpressions, 
aggregationBufferAttributes ++ inputAttributes)()
    -
    -        // We do not touch buffer values of aggregate functions with the 
Final mode.
    -        val finalOffsetExpressions =
    -          
Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
    -        val updateExpressions = finalOffsetExpressions ++ 
completeAggregateFunctions.flatMap {
    -          case ae: DeclarativeAggregate => ae.updateExpressions
    -          case agg: AggregateFunction => 
Seq.fill(agg.aggBufferAttributes.length)(NoOp)
    -        }
    -        val completeUpdateProjection =
    -          newMutableProjection(updateExpressions, 
aggregationBufferAttributes ++ inputAttributes)()
    -
    -        (currentBuffer: UnsafeRow, row: InternalRow) => {
    -          val input = joinedRow(currentBuffer, row)
    -          // For all aggregate functions with mode Complete, update 
buffers.
    -          completeUpdateProjection.target(currentBuffer)(input)
    -          var i = 0
    -          while (i < completeImperativeAggregateFunctions.length) {
    -            completeImperativeAggregateFunctions(i).update(currentBuffer, 
row)
    -            i += 1
    -          }
    -
    -          // For all aggregate functions with mode Final, merge buffer 
values in row to
    -          // currentBuffer.
    -          finalMergeProjection.target(currentBuffer)(input)
    -          i = 0
    -          while (i < nonCompleteImperativeAggregateFunctions.length) {
    -            
nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row)
    -            i += 1
    -          }
    -        }
    -
    -      // Complete-only
    -      case (None, Some(Complete)) =>
    -        val completeAggregateFunctions: Array[AggregateFunction] =
    -          
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
    -        // All imperative aggregate functions with mode Complete.
    -        val completeImperativeAggregateFunctions: 
Array[ImperativeAggregate] =
    -          completeAggregateFunctions.collect { case func: 
ImperativeAggregate => func }
    -
    -        val updateExpressions = completeAggregateFunctions.flatMap {
    -          case ae: DeclarativeAggregate => ae.updateExpressions
    -          case agg: AggregateFunction => 
Seq.fill(agg.aggBufferAttributes.length)(NoOp)
    -        }
    -        val completeExpressionAggUpdateProjection =
    -          newMutableProjection(updateExpressions, 
aggregationBufferAttributes ++ inputAttributes)()
    -
    -        (currentBuffer: UnsafeRow, row: InternalRow) => {
    -          // For all aggregate functions with mode Complete, update 
buffers.
    -          
completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer,
 row))
    -          var i = 0
    -          while (i < completeImperativeAggregateFunctions.length) {
    -            completeImperativeAggregateFunctions(i).update(currentBuffer, 
row)
    -            i += 1
    -          }
    -        }
    -
    -      // Grouping only.
    -      case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => 
{}
    -
    -      case other =>
    -        throw new IllegalStateException(
    -          s"${aggregationMode} should not be passed into 
TungstenAggregationIterator.")
    -    }
    -  }
    -
       // Creates a function used to generate output rows.
    -  private def generateResultProjection(): (UnsafeRow, UnsafeRow) => 
UnsafeRow = {
    -
    -    val groupingAttributes = groupingExpressions.map(_.toAttribute)
    -    val bufferAttributes = 
allAggregateFunctions.flatMap(_.aggBufferAttributes)
    -
    -    aggregationMode match {
    -      // Partial-only or PartialMerge-only: every output row is basically 
the values of
    -      // the grouping expressions and the corresponding aggregation buffer.
    -      case (Some(Partial), None) | (Some(PartialMerge), None) =>
    -        val groupingKeySchema = 
StructType.fromAttributes(groupingAttributes)
    -        val bufferSchema = StructType.fromAttributes(bufferAttributes)
    -        val unsafeRowJoiner = 
GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
    -
    -        (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
    -          unsafeRowJoiner.join(currentGroupingKey, currentBuffer)
    -        }
    -
    -      // Final-only, Complete-only and Final-Complete: a output row is 
generated based on
    -      // resultExpressions.
    -      case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
    -        val joinedRow = new JoinedRow()
    -        val evalExpressions = allAggregateFunctions.map {
    -          case ae: DeclarativeAggregate => ae.evaluateExpression
    -          case agg: AggregateFunction => NoOp
    -        }
    -        val expressionAggEvalProjection = 
newMutableProjection(evalExpressions, bufferAttributes)()
    -        // These are the attributes of the row produced by 
`expressionAggEvalProjection`
    -        val aggregateResultSchema = nonCompleteAggregateAttributes ++ 
completeAggregateAttributes
    -        val aggregateResult = new 
SpecificMutableRow(aggregateResultSchema.map(_.dataType))
    -        expressionAggEvalProjection.target(aggregateResult)
    -        val resultProjection =
    -          UnsafeProjection.create(resultExpressions, groupingAttributes ++ 
aggregateResultSchema)
    -
    -        val allImperativeAggregateFunctions: Array[ImperativeAggregate] =
    -          allAggregateFunctions.collect { case func: ImperativeAggregate 
=> func}
    -
    -        (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
    -          // Generate results for all expression-based aggregate functions.
    -          expressionAggEvalProjection(currentBuffer)
    -          // Generate results for all imperative aggregate functions.
    -          var i = 0
    -          while (i < allImperativeAggregateFunctions.length) {
    -            aggregateResult.update(
    -              allImperativeAggregateFunctionPositions(i),
    -              allImperativeAggregateFunctions(i).eval(currentBuffer))
    -            i += 1
    -          }
    -          resultProjection(joinedRow(currentGroupingKey, aggregateResult))
    -        }
    -
    -      // Grouping-only: a output row is generated from values of grouping 
expressions.
    -      case (None, None) =>
    -        val resultProjection = UnsafeProjection.create(resultExpressions, 
groupingAttributes)
    -
    -        (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
    -          resultProjection(currentGroupingKey)
    -        }
    -
    -      case other =>
    -        throw new IllegalStateException(
    -          s"${aggregationMode} should not be passed into 
TungstenAggregationIterator.")
    +  override def generateResultProjection(): (UnsafeRow, MutableRow) => 
UnsafeRow = {
    --- End diff --
    
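    This should be `override protected def generateResultProjection()` rather than a bare `override def`. It was `private` before this refactoring, so moving it into `AggregationIterator` should not make it public.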


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to