GitHub user rxin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/7813#discussion_r36057881
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala ---
    @@ -0,0 +1,490 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.aggregate
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.catalyst.expressions._
    +import org.apache.spark.sql.catalyst.expressions.aggregate._
    +import org.apache.spark.unsafe.KVIterator
    +
    +import scala.collection.mutable.ArrayBuffer
    +
    +/**
    + * The base class of [[SortBasedAggregationIterator]] and 
[[UnsafeHybridAggregationIterator]].
    + * It mainly contains two parts:
    + * 1. It initializes aggregate functions.
    + * 2. It creates two functions, `processRow` and `generateOutput` based on 
[[AggregateMode]] of
    + *    its aggregate functions. `processRow` is the function to handle an 
input. `generateOutput`
    + *    is used to generate result.
    + */
    +abstract class AggregationIterator(
    +    groupingKeyAttributes: Seq[Attribute],
    +    valueAttributes: Seq[Attribute],
    +    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
    +    nonCompleteAggregateAttributes: Seq[Attribute],
    +    completeAggregateExpressions: Seq[AggregateExpression2],
    +    completeAggregateAttributes: Seq[Attribute],
    +    initialInputBufferOffset: Int,
    +    resultExpressions: Seq[NamedExpression],
    +    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => 
MutableProjection),
    +    outputsUnsafeRows: Boolean)
    +  extends Iterator[InternalRow] with Logging {
    +
    +  
///////////////////////////////////////////////////////////////////////////
    +  // Initializing functions.
    +  
///////////////////////////////////////////////////////////////////////////
    +
    +  // An Seq of all AggregateExpressions.
    +  // It is important that all AggregateExpressions with the mode Partial, 
PartialMerge or Final
    +  // are at the beginning of the allAggregateExpressions.
    +  protected val allAggregateExpressions =
    +    nonCompleteAggregateExpressions ++ completeAggregateExpressions
    +
    +  require(
    +    allAggregateExpressions.map(_.mode).distinct.length <= 2,
    +    s"$allAggregateExpressions are not supported becuase they have more 
than 2 distinct modes.")
    +
    +  /**
    +   * The distinct modes of AggregateExpressions. Right now, we can handle 
the following mode:
    +   *  - Partial-only: all AggregateExpressions have the mode of Partial;
    +   *  - PartialMerge-only: all AggregateExpressions have the mode of 
PartialMerge);
    +   *  - Final-only: all AggregateExpressions have the mode of Final;
    +   *  - Final-Complete: some AggregateExpressions have the mode of Final 
and
    +   *    others have the mode of Complete;
    +   *  - Complete-only: nonCompleteAggregateExpressions is empty and we 
have AggregateExpressions
    +   *    with mode Complete in completeAggregateExpressions; and
    +   *  - Grouping-only: there is no AggregateExpression.
    +   */
    +  protected val aggregationMode: (Option[AggregateMode], 
Option[AggregateMode]) =
    +    nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
    +      completeAggregateExpressions.map(_.mode).distinct.headOption
    +
    +  // Initialize all AggregateFunctions by binding references if necessary,
    +  // and set inputBufferOffset and mutableBufferOffset.
    +  protected val allAggregateFunctions: Array[AggregateFunction2] = {
    +    var mutableBufferOffset = 0
    +    var inputBufferOffset: Int = initialInputBufferOffset
    +    val functions = new 
Array[AggregateFunction2](allAggregateExpressions.length)
    +    var i = 0
    +    while (i < allAggregateExpressions.length) {
    +      val func = allAggregateExpressions(i).aggregateFunction
    +      val funcWithBoundReferences = allAggregateExpressions(i).mode match {
    +        case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] 
=>
    +          // We need to create BoundReferences if the function is not an
    +          // AlgebraicAggregate (it does not support code-gen) and the 
mode of
    +          // this function is Partial or Complete because we will call 
eval of this
    +          // function's children in the update method of this aggregate 
function.
    +          // Those eval calls require BoundReferences to work.
    +          BindReferences.bindReference(func, valueAttributes)
    +        case _ =>
    +          // We only need to set inputBufferOffset for aggregate functions 
with mode
    +          // PartialMerge and Final.
    +          func.withNewInputBufferOffset(inputBufferOffset)
    +          inputBufferOffset += func.bufferSchema.length
    +          func
    +      }
    +      // Set mutableBufferOffset for this function. It is important that 
setting
    +      // mutableBufferOffset happens after all potential bindReference 
operations
    +      // because bindReference will create a new instance of the function.
    +      
funcWithBoundReferences.withNewMutableBufferOffset(mutableBufferOffset)
    +      mutableBufferOffset += funcWithBoundReferences.bufferSchema.length
    +      functions(i) = funcWithBoundReferences
    +      i += 1
    +    }
    +    functions
    +  }
    +
    +  // Positions of those non-algebraic aggregate functions in 
allAggregateFunctions.
    +  // For example, we have func1, func2, func3, func4 in 
aggregateFunctions, and
    +  // func2 and func3 are non-algebraic aggregate functions.
    +  // nonAlgebraicAggregateFunctionPositions will be [1, 2].
    +  private[this] val allNonAlgebraicAggregateFunctionPositions: Array[Int] 
= {
    +    val positions = new ArrayBuffer[Int]()
    +    var i = 0
    +    while (i < allAggregateFunctions.length) {
    +      allAggregateFunctions(i) match {
    +        case agg: AlgebraicAggregate =>
    +        case _ => positions += i
    +      }
    +      i += 1
    +    }
    +    positions.toArray
    +  }
    +
    +  // All AggregateFunctions functions with mode Partial, PartialMerge, or 
Final.
    +  private[this] val nonCompleteAggregateFunctions: 
Array[AggregateFunction2] =
    +    allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
    +
    +  // All non-algebraic aggregate functions with mode Partial, 
PartialMerge, or Final.
    +  private[this] val nonCompleteNonAlgebraicAggregateFunctions: 
Array[AggregateFunction2] =
    +    nonCompleteAggregateFunctions.collect {
    +      case func: AggregateFunction2 if 
!func.isInstanceOf[AlgebraicAggregate] => func
    +    }
    +
    +  // The projection used to initialize buffer values for all 
AlgebraicAggregates.
    +  private[this] val algebraicInitialProjection = {
    +    val initExpressions = allAggregateFunctions.flatMap {
    +      case ae: AlgebraicAggregate => ae.initialValues
    +      case agg: AggregateFunction2 => 
Seq.fill(agg.bufferAttributes.length)(NoOp)
    +    }
    +    newMutableProjection(initExpressions, Nil)()
    +  }
    +
    +  // All non-Algebraic AggregateFunctions.
    +  private[this] val allNonAlgebraicAggregateFunctions =
    +    allNonAlgebraicAggregateFunctionPositions.map(allAggregateFunctions)
    +
    +  
///////////////////////////////////////////////////////////////////////////
    +  // Methods and fields used by sub-classes.
    +  
///////////////////////////////////////////////////////////////////////////
    +
    +  // Initializing functions used to process a row.
    +  protected val processRow: (MutableRow, InternalRow) => Unit = {
    +    val rowToBeProcessed = new JoinedRow
    +    val aggregationBufferSchema = 
allAggregateFunctions.flatMap(_.bufferAttributes)
    +    aggregationMode match {
    +      // Partial-only
    +      case (Some(Partial), None) =>
    +        val updateExpressions = nonCompleteAggregateFunctions.flatMap {
    +          case ae: AlgebraicAggregate => ae.updateExpressions
    +          case agg: AggregateFunction2 => 
Seq.fill(agg.bufferAttributes.length)(NoOp)
    +        }
    +        val algebraicUpdateProjection =
    +          newMutableProjection(updateExpressions, aggregationBufferSchema 
++ valueAttributes)()
    +
    +        (currentBuffer: MutableRow, row: InternalRow) => {
    +          algebraicUpdateProjection.target(currentBuffer)
    +          // Process all algebraic aggregate functions.
    +          algebraicUpdateProjection(rowToBeProcessed(currentBuffer, row))
    +          // Process all non-algebraic aggregate functions.
    +          var i = 0
    +          while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
    +            
nonCompleteNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
    +            i += 1
    +          }
    +        }
    +
    +      // PartialMerge-only or Final-only
    +      case (Some(PartialMerge), None) | (Some(Final), None) =>
    +        val inputAggregationBufferSchema = if (initialInputBufferOffset == 
0) {
    +          // If initialInputBufferOffset, the input value does not contain
    +          // grouping keys.
    +          // This part is pretty hacky.
    +          allAggregateFunctions.flatMap(_.cloneBufferAttributes).toSeq
    +        } else {
    +          groupingKeyAttributes ++ 
allAggregateFunctions.flatMap(_.cloneBufferAttributes)
    +        }
    +        // val inputAggregationBufferSchema =
    +        //  groupingKeyAttributes ++
    +        //    allAggregateFunctions.flatMap(_.cloneBufferAttributes)
    +        val mergeExpressions = nonCompleteAggregateFunctions.flatMap {
    +          case ae: AlgebraicAggregate => ae.mergeExpressions
    +          case agg: AggregateFunction2 => 
Seq.fill(agg.bufferAttributes.length)(NoOp)
    +        }
    +        // This projection is used to merge buffer values for all 
AlgebraicAggregates.
    +        val algebraicMergeProjection =
    +          newMutableProjection(
    +            mergeExpressions,
    +            aggregationBufferSchema ++ inputAggregationBufferSchema)()
    +
    +        (currentBuffer: MutableRow, row: InternalRow) => {
    +          // Process all algebraic aggregate functions.
    +          
algebraicMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, 
row))
    +          // Process all non-algebraic aggregate functions.
    +          var i = 0
    +          while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
    +            
nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row)
    +            i += 1
    +          }
    +        }
    +
    +      // Final-Complete
    +      case (Some(Final), Some(Complete)) =>
    +        val completeAggregateFunctions: Array[AggregateFunction2] =
    +          
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
    +        // All non-algebraic aggregate functions with mode Complete.
    +        val completeNonAlgebraicAggregateFunctions: 
Array[AggregateFunction2] =
    +          completeAggregateFunctions.collect {
    +            case func: AggregateFunction2 if 
!func.isInstanceOf[AlgebraicAggregate] => func
    +          }
    +
    +        // The first initialInputBufferOffset values of the input 
aggregation buffer is
    +        // for grouping expressions and distinct columns.
    +        val groupingAttributesAndDistinctColumns = 
valueAttributes.take(initialInputBufferOffset)
    +
    +        val completeOffsetExpressions =
    +          
Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
    +        // We do not touch buffer values of aggregate functions with the 
Final mode.
    +        val finalOffsetExpressions =
    +          
Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
    +
    +        val mergeInputSchema =
    +          aggregationBufferSchema ++
    +            groupingAttributesAndDistinctColumns ++
    +            nonCompleteAggregateFunctions.flatMap(_.cloneBufferAttributes)
    +        val mergeExpressions =
    +          nonCompleteAggregateFunctions.flatMap {
    +            case ae: AlgebraicAggregate => ae.mergeExpressions
    +            case agg: AggregateFunction2 => 
Seq.fill(agg.bufferAttributes.length)(NoOp)
    +          } ++ completeOffsetExpressions
    +        val finalAlgebraicMergeProjection =
    +          newMutableProjection(mergeExpressions, mergeInputSchema)()
    +
    +        val updateExpressions =
    +          finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
    +            case ae: AlgebraicAggregate => ae.updateExpressions
    +            case agg: AggregateFunction2 => 
Seq.fill(agg.bufferAttributes.length)(NoOp)
    +          }
    +        val completeAlgebraicUpdateProjection =
    +          newMutableProjection(updateExpressions, aggregationBufferSchema 
++ valueAttributes)()
    +
    +        (currentBuffer: MutableRow, row: InternalRow) => {
    +          val input = rowToBeProcessed(currentBuffer, row)
    +          // For all aggregate functions with mode Complete, update 
buffers.
    +          completeAlgebraicUpdateProjection.target(currentBuffer)(input)
    +          var i = 0
    +          while (i < completeNonAlgebraicAggregateFunctions.length) {
    +            
completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
    +            i += 1
    +          }
    +
    +          // For all aggregate functions with mode Final, merge buffers.
    +          finalAlgebraicMergeProjection.target(currentBuffer)(input)
    +          i = 0
    +          while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
    +            
nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row)
    +            i += 1
    +          }
    +        }
    +
    +      // Complete-only
    +      case (None, Some(Complete)) =>
    +        val completeAggregateFunctions: Array[AggregateFunction2] =
    +          
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
    +        // All non-algebraic aggregate functions with mode Complete.
    +        val completeNonAlgebraicAggregateFunctions: 
Array[AggregateFunction2] =
    +          completeAggregateFunctions.collect {
    +            case func: AggregateFunction2 if 
!func.isInstanceOf[AlgebraicAggregate] => func
    +          }
    +
    +        val updateExpressions =
    +          completeAggregateFunctions.flatMap {
    +            case ae: AlgebraicAggregate => ae.updateExpressions
    +            case agg: AggregateFunction2 => 
Seq.fill(agg.bufferAttributes.length)(NoOp)
    +          }
    +        val completeAlgebraicUpdateProjection =
    +          newMutableProjection(updateExpressions, aggregationBufferSchema 
++ valueAttributes)()
    +
    +        (currentBuffer: MutableRow, row: InternalRow) => {
    +          val input = rowToBeProcessed(currentBuffer, row)
    +          // For all aggregate functions with mode Complete, update 
buffers.
    +          completeAlgebraicUpdateProjection.target(currentBuffer)(input)
    +          var i = 0
    +          while (i < completeNonAlgebraicAggregateFunctions.length) {
    +            
completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
    +            i += 1
    +          }
    +        }
    +
    +      // Grouping only.
    +      case (None, None) => (currentBuffer: MutableRow, row: InternalRow) 
=> {}
    +
    +      case other =>
    +        sys.error(
    --- End diff --
    
    This is an error case -- we should throw an IllegalStateException and make it
clear that hitting this path indicates a bug.
    
    Right now, the error message makes it sound as if this operator simply cannot
handle a legitimate case.



---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to enable it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to