This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5e3873d  [SPARK-37564][SQL] Add code-gen for sort aggregate without 
grouping keys
5e3873d is described below

commit 5e3873d6adca0b93efd4f30a5efc80603d0c6c5e
Author: Cheng Su <[email protected]>
AuthorDate: Fri Dec 17 20:18:37 2021 +0800

    [SPARK-37564][SQL] Add code-gen for sort aggregate without grouping keys
    
    ### What changes were proposed in this pull request?
    
    We can support `SortAggregateExec` code-gen without grouping keys. When
there's no grouping key, sort aggregate should share the same execution logic
with hash aggregate.
    
    At a high level, this PR does
    
    * `spark.sql.codegen.aggregate.sortAggregate.enabled`: a new config
(intended for users, although it is declared `.internal()` in `SQLConf`) is
introduced to allow users to disable sort aggregate code-gen in case anything
goes wrong; it is enabled by default.
    * `spark.sql.test.forceApplySortAggregate`: a testing-only config
(non-user-facing; its key is hard-coded in `AggUtils` and is only honored when
`Utils.isTesting` is true) is introduced to allow developers to force sort
aggregate to be planned, to improve its unit test coverage. We already had a
similar config for shuffled hash join.
    * `AggregateCodegenSupport.scala`: the base trait that implements code-gen
without grouping keys, along with other code-gen related boilerplate. Subclasses
are required to implement `doProduceWithKeys()` and `doConsumeWithKeys()` to
handle aggregates with grouping keys. The implementation in this trait is copied
almost verbatim from `HashAggregateExec` (`supportCodegen` additionally requires
all aggregation buffer types to be mutable).
    * `HashAggregateExec.scala`: only keeps its original `doProduceWithKeys()` 
and `doConsumeWithKeys()`.
    * `SortAggregateExec.scala`: extends `AggregateCodegenSupport` to
support code-gen without grouping keys.
    
    The implementation of `SortAggregateExec` code-gen with grouping keys will 
be added in follow-up PR.
    
    ### Why are the changes needed?
    
    To enable code-gen for sort aggregate, so we can have better query 
performance when sort aggregate is used.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. A new config, `spark.sql.codegen.aggregate.sortAggregate.enabled`, is
added to enable/disable sort aggregate code-gen.
    
    ### How was this patch tested?
    
    Added unit test in `WholeStageCodegenSuite.scala`.
    
    Closes #34826 from c21/sort-agg-codegen.
    
    Authored-by: Cheng Su <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |   8 +
 .../sql/execution/WholeStageCodegenExec.scala      |   5 +-
 .../spark/sql/execution/aggregate/AggUtils.scala   |  17 +-
 .../aggregate/AggregateCodegenSupport.scala        | 340 +++++++++++++++++++++
 .../execution/aggregate/HashAggregateExec.scala    | 281 +----------------
 .../execution/aggregate/SortAggregateExec.scala    |  30 +-
 .../sql/execution/WholeStageCodegenSuite.scala     |  17 +-
 7 files changed, 410 insertions(+), 288 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6c6fb40..868340b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1781,6 +1781,14 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val ENABLE_SORT_AGGREGATE_CODEGEN =
+    buildConf("spark.sql.codegen.aggregate.sortAggregate.enabled")
+      .internal()
+      .doc("When true, enable code-gen for sort aggregate.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val ENABLE_FULL_OUTER_SHUFFLED_HASH_JOIN_CODEGEN =
     buildConf("spark.sql.codegen.join.fullOuterShuffledHashJoin.enabled")
       .internal()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index e80ad89..dde976c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, 
SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, 
BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -47,7 +47,8 @@ trait CodegenSupport extends SparkPlan {
 
   /** Prefix used in the current operator's variable names. */
   private def variablePrefix: String = this match {
-    case _: HashAggregateExec => "agg"
+    case _: HashAggregateExec => "hashAgg"
+    case _: SortAggregateExec => "sortAgg"
     case _: BroadcastHashJoinExec => "bhj"
     case _: ShuffledHashJoinExec => "shj"
     case _: SortMergeJoinExec => "smj"
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 0f239b4..32db622 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.Utils
 
 /**
  * Utility functions used by the query planner to convert our plan to new 
aggregation code path.
@@ -53,7 +55,9 @@ object AggUtils {
       child: SparkPlan): SparkPlan = {
     val useHash = HashAggregateExec.supportsAggregate(
       aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
-    if (useHash) {
+    val forceSortAggregate = forceApplySortAggregate(child.conf)
+
+    if (useHash && !forceSortAggregate) {
       HashAggregateExec(
         requiredChildDistributionExpressions = 
requiredChildDistributionExpressions,
         groupingExpressions = groupingExpressions,
@@ -66,7 +70,7 @@ object AggUtils {
       val objectHashEnabled = child.conf.useObjectHashAggregation
       val useObjectHash = 
ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)
 
-      if (objectHashEnabled && useObjectHash) {
+      if (objectHashEnabled && useObjectHash && !forceSortAggregate) {
         ObjectHashAggregateExec(
           requiredChildDistributionExpressions = 
requiredChildDistributionExpressions,
           groupingExpressions = groupingExpressions,
@@ -529,4 +533,13 @@ object AggUtils {
       case None => partialAggregate
     }
   }
+
+  /**
+   * Returns whether a sort aggregate should be force applied.
+   * The config key is hard-coded because it's testing only and should not be 
exposed.
+   */
+  private def forceApplySortAggregate(conf: SQLConf): Boolean = {
+    Utils.isTesting &&
+      conf.getConfString("spark.sql.test.forceApplySortAggregate", "false") == 
"true"
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala
new file mode 100644
index 0000000..6304363
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala
@@ -0,0 +1,340 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, 
Expression, ExpressionEquals, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
+import org.apache.spark.sql.execution.{BlockingOperatorWithCodegen, 
CodegenSupport, GeneratePredicateHelper}
+import org.apache.spark.util.Utils
+
+/**
+ * An interface for those aggregate physical operators that support codegen.
+ */
+trait AggregateCodegenSupport
+  extends BaseAggregateExec
+  with BlockingOperatorWithCodegen
+  with GeneratePredicateHelper {
+
+  /**
+   * All the modes of aggregate expressions.
+   */
+  protected val modes: Seq[AggregateMode] = 
aggregateExpressions.map(_.mode).distinct
+
+  /**
+   * The variables are used as aggregation buffers and each aggregate function 
has one or more
+   * ExprCode to initialize its buffer slots. Only used for aggregation 
without keys.
+   */
+  private var bufVars: Seq[Seq[ExprCode]] = _
+
+  /**
+   * The generated code for `doProduce` call when aggregate has grouping keys.
+   */
+  protected def doProduceWithKeys(ctx: CodegenContext): String
+
+  /**
+   * The generated code for `doConsume` call when aggregate has grouping keys.
+   */
+  protected def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): 
String
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    if (groupingExpressions.isEmpty) {
+      doProduceWithoutKeys(ctx)
+    } else {
+      doProduceWithKeys(ctx)
+    }
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
+    if (groupingExpressions.isEmpty) {
+      doConsumeWithoutKeys(ctx, input)
+    } else {
+      doConsumeWithKeys(ctx, input)
+    }
+  }
+
+  override def supportCodegen: Boolean = {
+    val isMutableAggBuffer = aggregateBufferAttributes.forall(a => 
UnsafeRow.isMutable(a.dataType))
+    // ImperativeAggregate are not supported right now
+    isMutableAggBuffer &&
+      
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
+  }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  override def usedInputs: AttributeSet = inputSet
+
+  /**
+   * The generated code for `doProduce` call when aggregate does not have 
grouping keys.
+   */
+  private def doProduceWithoutKeys(ctx: CodegenContext): String = {
+    val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
+    // The generated function doesn't have input row in the code context.
+    ctx.INPUT_ROW = null
+
+    // generate variables for aggregation buffer
+    val functions = 
aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    val initExpr = functions.map(f => f.initialValues)
+    bufVars = initExpr.map { exprs =>
+      exprs.map { e =>
+        val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, 
"bufIsNull")
+        val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), 
"bufValue")
+        // The initial expression should not access any column
+        val ev = e.genCode(ctx)
+        val initVars =
+          code"""
+                |$isNull = ${ev.isNull};
+                |$value = ${ev.value};
+              """.stripMargin
+        ExprCode(
+          ev.code + initVars,
+          JavaCode.isNullGlobal(isNull),
+          JavaCode.global(value, e.dataType))
+      }
+    }
+    val flatBufVars = bufVars.flatten
+    val initBufVar = evaluateVariables(flatBufVars)
+
+    // generate variables for output
+    val (resultVars, genResult) = if (modes.contains(Final) || 
modes.contains(Complete)) {
+      // evaluate aggregate results
+      ctx.currentVars = flatBufVars
+      val aggResults = bindReferences(
+        functions.map(_.evaluateExpression),
+        aggregateBufferAttributes).map(_.genCode(ctx))
+      val evaluateAggResults = evaluateVariables(aggResults)
+      // evaluate result expressions
+      ctx.currentVars = aggResults
+      val resultVars = bindReferences(resultExpressions, 
aggregateAttributes).map(_.genCode(ctx))
+      (resultVars,
+        s"""
+           |$evaluateAggResults
+           |${evaluateVariables(resultVars)}
+         """.stripMargin)
+    } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+      // output the aggregate buffer directly
+      (flatBufVars, "")
+    } else {
+      // no aggregate function, the result should be literals
+      val resultVars = resultExpressions.map(_.genCode(ctx))
+      (resultVars, evaluateVariables(resultVars))
+    }
+
+    val doAgg = ctx.freshName("doAggregateWithoutKey")
+    val doAggFuncName = ctx.addNewFunction(doAgg,
+      s"""
+         |private void $doAgg() throws java.io.IOException {
+         |  // initialize aggregation buffer
+         |  $initBufVar
+         |
+         |  ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+         |}
+       """.stripMargin)
+
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    val aggTime = metricTerm(ctx, "aggTime")
+    val beforeAgg = ctx.freshName("beforeAgg")
+    s"""
+       |while (!$initAgg) {
+       |  $initAgg = true;
+       |  long $beforeAgg = System.nanoTime();
+       |  $doAggFuncName();
+       |  $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS);
+       |
+       |  // output the result
+       |  ${genResult.trim}
+       |
+       |  $numOutput.add(1);
+       |  ${consume(ctx, resultVars).trim}
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * The generated code for `doConsume` call when aggregate does not have 
grouping keys.
+   */
+  private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): 
String = {
+    // only have DeclarativeAggregate
+    val functions = 
aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ 
inputAttributes
+    // To individually generate code for each aggregate function, an element 
in `updateExprs` holds
+    // all the expressions for the buffer of an aggregation function.
+    val updateExprs = aggregateExpressions.map { e =>
+      e.mode match {
+        case Partial | Complete =>
+          
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
+        case PartialMerge | Final =>
+          
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
+      }
+    }
+    ctx.currentVars = bufVars.flatten ++ input
+    val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
+      bindReferences(updateExprsForOneFunc, inputAttrs)
+    }
+    val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
+    val effectiveCodes = 
ctx.evaluateSubExprEliminationState(subExprs.states.values)
+    val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
+      ctx.withSubExprEliminationExprs(subExprs.states) {
+        boundUpdateExprsForOneFunc.map(_.genCode(ctx))
+      }
+    }
+
+    val aggNames = functions.map(_.prettyName)
+    val aggCodeBlocks = bufferEvals.zipWithIndex.map { case 
(bufferEvalsForOneFunc, i) =>
+      val bufVarsForOneFunc = bufVars(i)
+      // All the update code for aggregation buffers should be placed in the 
end
+      // of each aggregation function code.
+      val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case 
(ev, bufVar) =>
+        s"""
+           |${bufVar.isNull} = ${ev.isNull};
+           |${bufVar.value} = ${ev.value};
+         """.stripMargin
+      }
+      code"""
+            |${ctx.registerComment(s"do aggregate for ${aggNames(i)}")}
+            |${ctx.registerComment("evaluate aggregate function")}
+            |${evaluateVariables(bufferEvalsForOneFunc)}
+            |${ctx.registerComment("update aggregation buffers")}
+            |${updates.mkString("\n").trim}
+       """.stripMargin
+    }
+
+    val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+      ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, 
subExprs)
+    s"""
+       |// do aggregate
+       |// common sub-expressions
+       |$effectiveCodes
+       |// evaluate aggregate functions and update aggregation buffers
+       |$codeToEvalAggFuncs
+     """.stripMargin
+  }
+
+  /**
+   * The generated code to evaluate aggregate functions.
+   */
+  protected def generateEvalCodeForAggFuncs(
+      ctx: CodegenContext,
+      input: Seq[ExprCode],
+      inputAttrs: Seq[Attribute],
+      boundUpdateExprs: Seq[Seq[Expression]],
+      aggNames: Seq[String],
+      aggCodeBlocks: Seq[Block],
+      subExprs: SubExprCodes): String = {
+    val aggCodes = if (conf.codegenSplitAggregateFunc &&
+      aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
+      val maybeSplitCodes = splitAggregateExpressions(
+        ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
+
+      maybeSplitCodes.getOrElse(aggCodeBlocks.map(_.code))
+    } else {
+      aggCodeBlocks.map(_.code)
+    }
+
+    aggCodes.zip(aggregateExpressions.map(ae => (ae.mode, ae.filter))).map {
+      case (aggCode, (Partial | Complete, Some(condition))) =>
+        // Note: wrap in "do { } while(false);", so the generated checks can 
jump out
+        // with "continue;"
+        s"""
+           |do {
+           |  ${generatePredicateCode(ctx, condition, inputAttrs, input)}
+           |  $aggCode
+           |} while(false);
+         """.stripMargin
+      case (aggCode, _) =>
+        aggCode
+    }.mkString("\n")
+  }
+
+  /**
+   * Splits aggregate code into small functions because the most of JVM 
implementations
+   * can not compile too long functions. Returns None if we are not able to 
split the given code.
+   *
+   * Note: The difference from `CodeGenerator.splitExpressions` is that we 
define an individual
+   * function for each aggregation function (e.g., SUM and AVG). For example, 
in a query
+   * `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions
+   * for `SUM(a)` and `AVG(a)`.
+   */
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggNames: Seq[String],
+      aggBufferUpdatingExprs: Seq[Seq[Expression]],
+      aggCodeBlocks: Seq[Block],
+      subExprs: Map[ExpressionEquals, SubExprEliminationState]): 
Option[Seq[String]] = {
+    val exprValsInSubExprs = subExprs.flatMap { case (_, s) =>
+      s.eval.value :: s.eval.isNull :: Nil
+    }
+    if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
+      // `SimpleExprValue`s cannot be used as an input variable for split 
functions, so
+      // we give up splitting functions if it exists in `subExprs`.
+      None
+    } else {
+      val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
+        val inputVarsForOneFunc = aggExprsForOneFunc.map(
+          CodeGenerator.getLocalInputVariableValues(ctx, _, 
subExprs)._1).reduce(_ ++ _).toSeq
+        val paramLength = 
CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
+
+        // Checks if a parameter length for the `aggExprsForOneFunc` does not 
go over the JVM limit
+        if (CodeGenerator.isValidParamLength(paramLength)) {
+          Some(inputVarsForOneFunc)
+        } else {
+          None
+        }
+      }
+
+      // Checks if all the aggregate code can be split into pieces.
+      // If the parameter length of at lease one `aggExprsForOneFunc` goes 
over the limit,
+      // we totally give up splitting aggregate code.
+      if (inputVars.forall(_.isDefined)) {
+        val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) =>
+          val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}")
+          val argList = args.map { v =>
+            s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}"
+          }.mkString(", ")
+          val doAggFuncName = ctx.addNewFunction(doAggFunc,
+            s"""
+               |private void $doAggFunc($argList) throws java.io.IOException {
+               |  ${aggCodeBlocks(i)}
+               |}
+             """.stripMargin)
+
+          val inputVariables = args.map(_.variableName).mkString(", ")
+          s"$doAggFuncName($inputVariables);"
+        }
+        Some(splitCodes)
+      } else {
+        val errMsg = "Failed to split aggregate code into small functions 
because the parameter " +
+          "length of at least one split function went over the JVM limit: " +
+          CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
+        if (Utils.isTesting) {
+          throw new IllegalStateException(errMsg)
+        } else {
+          logInfo(errMsg)
+          None
+        }
+      }
+    }
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 8545154..d4a4502 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -51,9 +51,7 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends BaseAggregateExec
-  with BlockingOperatorWithCodegen
-  with GeneratePredicateHelper {
+  extends AggregateCodegenSupport {
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
@@ -130,279 +128,6 @@ case class HashAggregateExec(
     }
   }
 
-  // all the mode of aggregate expressions
-  private val modes = aggregateExpressions.map(_.mode).distinct
-
-  override def usedInputs: AttributeSet = inputSet
-
-  override def supportCodegen: Boolean = {
-    // ImperativeAggregate are not supported right now
-    
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
-  }
-
-  override def inputRDDs(): Seq[RDD[InternalRow]] = {
-    child.asInstanceOf[CodegenSupport].inputRDDs()
-  }
-
-  protected override def doProduce(ctx: CodegenContext): String = {
-    if (groupingExpressions.isEmpty) {
-      doProduceWithoutKeys(ctx)
-    } else {
-      doProduceWithKeys(ctx)
-    }
-  }
-
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
-    if (groupingExpressions.isEmpty) {
-      doConsumeWithoutKeys(ctx, input)
-    } else {
-      doConsumeWithKeys(ctx, input)
-    }
-  }
-
-  // The variables are used as aggregation buffers and each aggregate function 
has one or more
-  // ExprCode to initialize its buffer slots. Only used for aggregation 
without keys.
-  private var bufVars: Seq[Seq[ExprCode]] = _
-
-  private def doProduceWithoutKeys(ctx: CodegenContext): String = {
-    val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
-    // The generated function doesn't have input row in the code context.
-    ctx.INPUT_ROW = null
-
-    // generate variables for aggregation buffer
-    val functions = 
aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
-    val initExpr = functions.map(f => f.initialValues)
-    bufVars = initExpr.map { exprs =>
-      exprs.map { e =>
-        val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, 
"bufIsNull")
-        val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), 
"bufValue")
-        // The initial expression should not access any column
-        val ev = e.genCode(ctx)
-        val initVars = code"""
-           |$isNull = ${ev.isNull};
-           |$value = ${ev.value};
-         """.stripMargin
-        ExprCode(
-          ev.code + initVars,
-          JavaCode.isNullGlobal(isNull),
-          JavaCode.global(value, e.dataType))
-      }
-    }
-    val flatBufVars = bufVars.flatten
-    val initBufVar = evaluateVariables(flatBufVars)
-
-    // generate variables for output
-    val (resultVars, genResult) = if (modes.contains(Final) || 
modes.contains(Complete)) {
-      // evaluate aggregate results
-      ctx.currentVars = flatBufVars
-      val aggResults = bindReferences(
-        functions.map(_.evaluateExpression),
-        aggregateBufferAttributes).map(_.genCode(ctx))
-      val evaluateAggResults = evaluateVariables(aggResults)
-      // evaluate result expressions
-      ctx.currentVars = aggResults
-      val resultVars = bindReferences(resultExpressions, 
aggregateAttributes).map(_.genCode(ctx))
-      (resultVars, s"""
-        |$evaluateAggResults
-        |${evaluateVariables(resultVars)}
-       """.stripMargin)
-    } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
-      // output the aggregate buffer directly
-      (flatBufVars, "")
-    } else {
-      // no aggregate function, the result should be literals
-      val resultVars = resultExpressions.map(_.genCode(ctx))
-      (resultVars, evaluateVariables(resultVars))
-    }
-
-    val doAgg = ctx.freshName("doAggregateWithoutKey")
-    val doAggFuncName = ctx.addNewFunction(doAgg,
-      s"""
-         |private void $doAgg() throws java.io.IOException {
-         |  // initialize aggregation buffer
-         |  $initBufVar
-         |
-         |  ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
-         |}
-       """.stripMargin)
-
-    val numOutput = metricTerm(ctx, "numOutputRows")
-    val aggTime = metricTerm(ctx, "aggTime")
-    val beforeAgg = ctx.freshName("beforeAgg")
-    s"""
-       |while (!$initAgg) {
-       |  $initAgg = true;
-       |  long $beforeAgg = System.nanoTime();
-       |  $doAggFuncName();
-       |  $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS);
-       |
-       |  // output the result
-       |  ${genResult.trim}
-       |
-       |  $numOutput.add(1);
-       |  ${consume(ctx, resultVars).trim}
-       |}
-     """.stripMargin
-  }
-
-  // Splits aggregate code into small functions because the most of JVM 
implementations
-  // can not compile too long functions. Returns None if we are not able to 
split the given code.
-  //
-  // Note: The difference from `CodeGenerator.splitExpressions` is that we 
define an individual
-  // function for each aggregation function (e.g., SUM and AVG). For example, 
in a query
-  // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions
-  // for `SUM(a)` and `AVG(a)`.
-  private def splitAggregateExpressions(
-      ctx: CodegenContext,
-      aggNames: Seq[String],
-      aggBufferUpdatingExprs: Seq[Seq[Expression]],
-      aggCodeBlocks: Seq[Block],
-      subExprs: Map[ExpressionEquals, SubExprEliminationState]): 
Option[Seq[String]] = {
-    val exprValsInSubExprs = subExprs.flatMap { case (_, s) =>
-      s.eval.value :: s.eval.isNull :: Nil
-    }
-    if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
-      // `SimpleExprValue`s cannot be used as an input variable for split 
functions, so
-      // we give up splitting functions if it exists in `subExprs`.
-      None
-    } else {
-      val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
-        val inputVarsForOneFunc = aggExprsForOneFunc.map(
-          CodeGenerator.getLocalInputVariableValues(ctx, _, 
subExprs)._1).reduce(_ ++ _).toSeq
-        val paramLength = 
CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
-
-        // Checks if a parameter length for the `aggExprsForOneFunc` does not 
go over the JVM limit
-        if (CodeGenerator.isValidParamLength(paramLength)) {
-          Some(inputVarsForOneFunc)
-        } else {
-          None
-        }
-      }
-
-      // Checks if all the aggregate code can be split into pieces.
-      // If the parameter length of at lease one `aggExprsForOneFunc` goes 
over the limit,
-      // we totally give up splitting aggregate code.
-      if (inputVars.forall(_.isDefined)) {
-        val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) =>
-          val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}")
-          val argList = args.map { v =>
-            s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}"
-          }.mkString(", ")
-          val doAggFuncName = ctx.addNewFunction(doAggFunc,
-            s"""
-               |private void $doAggFunc($argList) throws java.io.IOException {
-               |  ${aggCodeBlocks(i)}
-               |}
-             """.stripMargin)
-
-          val inputVariables = args.map(_.variableName).mkString(", ")
-          s"$doAggFuncName($inputVariables);"
-        }
-        Some(splitCodes)
-      } else {
-        val errMsg = "Failed to split aggregate code into small functions 
because the parameter " +
-          "length of at least one split function went over the JVM limit: " +
-          CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
-        if (Utils.isTesting) {
-          throw new IllegalStateException(errMsg)
-        } else {
-          logInfo(errMsg)
-          None
-        }
-      }
-    }
-  }
-
-  private def generateEvalCodeForAggFuncs(
-      ctx: CodegenContext,
-      input: Seq[ExprCode],
-      inputAttrs: Seq[Attribute],
-      boundUpdateExprs: Seq[Seq[Expression]],
-      aggNames: Seq[String],
-      aggCodeBlocks: Seq[Block],
-      subExprs: SubExprCodes): String = {
-    val aggCodes = if (conf.codegenSplitAggregateFunc &&
-      aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
-      val maybeSplitCodes = splitAggregateExpressions(
-        ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
-      maybeSplitCodes.getOrElse(aggCodeBlocks.map(_.code))
-    } else {
-      aggCodeBlocks.map(_.code)
-    }
-
-    aggCodes.zip(aggregateExpressions.map(ae => (ae.mode, ae.filter))).map {
-      case (aggCode, (Partial | Complete, Some(condition))) =>
-        // Note: wrap in "do { } while(false);", so the generated checks can 
jump out
-        // with "continue;"
-        s"""
-           |do {
-           |  ${generatePredicateCode(ctx, condition, inputAttrs, input)}
-           |  $aggCode
-           |} while(false);
-         """.stripMargin
-      case (aggCode, _) =>
-        aggCode
-    }.mkString("\n")
-  }
-
-  private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): 
String = {
-    // only have DeclarativeAggregate
-    val functions = 
aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
-    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ 
inputAttributes
-    // To individually generate code for each aggregate function, an element 
in `updateExprs` holds
-    // all the expressions for the buffer of an aggregation function.
-    val updateExprs = aggregateExpressions.map { e =>
-      e.mode match {
-        case Partial | Complete =>
-          
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
-        case PartialMerge | Final =>
-          
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
-      }
-    }
-    ctx.currentVars = bufVars.flatten ++ input
-    val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
-      bindReferences(updateExprsForOneFunc, inputAttrs)
-    }
-    val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
-    val effectiveCodes = 
ctx.evaluateSubExprEliminationState(subExprs.states.values)
-    val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
-      ctx.withSubExprEliminationExprs(subExprs.states) {
-        boundUpdateExprsForOneFunc.map(_.genCode(ctx))
-      }
-    }
-
-    val aggNames = functions.map(_.prettyName)
-    val aggCodeBlocks = bufferEvals.zipWithIndex.map { case 
(bufferEvalsForOneFunc, i) =>
-      val bufVarsForOneFunc = bufVars(i)
-      // All the update code for aggregation buffers should be placed in the 
end
-      // of each aggregation function code.
-      val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case 
(ev, bufVar) =>
-        s"""
-           |${bufVar.isNull} = ${ev.isNull};
-           |${bufVar.value} = ${ev.value};
-         """.stripMargin
-      }
-      code"""
-         |${ctx.registerComment(s"do aggregate for ${aggNames(i)}")}
-         |${ctx.registerComment("evaluate aggregate function")}
-         |${evaluateVariables(bufferEvalsForOneFunc)}
-         |${ctx.registerComment("update aggregation buffers")}
-         |${updates.mkString("\n").trim}
-       """.stripMargin
-    }
-
-    val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
-      ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, 
subExprs)
-    s"""
-       |// do aggregate
-       |// common sub-expressions
-       |$effectiveCodes
-       |// evaluate aggregate functions and update aggregation buffers
-       |$codeToEvalAggFuncs
-     """.stripMargin
-  }
-
   private val groupingAttributes = groupingExpressions.map(_.toAttribute)
   private val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
   private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
@@ -692,7 +417,7 @@ case class HashAggregateExec(
     }
   }
 
-  private def doProduceWithKeys(ctx: CodegenContext): String = {
+  protected override def doProduceWithKeys(ctx: CodegenContext): String = {
     val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
     if (conf.enableTwoLevelAggMap) {
       enableTwoLevelHashMap(ctx)
@@ -891,7 +616,7 @@ case class HashAggregateExec(
      """.stripMargin
   }
 
-  private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): 
String = {
+  protected override def doConsumeWithKeys(ctx: CodegenContext, input: 
Seq[ExprCode]): String = {
     // create grouping key
     val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
       ctx, bindReferences[Expression](groupingExpressions, child.output))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 4fb0f44..f5462d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -17,13 +17,17 @@
 
 package org.apache.spark.sql.execution.aggregate
 
+import java.util.concurrent.TimeUnit.NANOSECONDS
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Sort-based aggregate operator.
@@ -36,11 +40,12 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends BaseAggregateExec
+  extends AggregateCodegenSupport
   with AliasAwareOutputOrdering {
 
   override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"),
+    "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in 
aggregation build"))
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
@@ -52,11 +57,14 @@ case class SortAggregateExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val aggTime = longMetric("aggTime")
+
     child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
+      val beforeAgg = System.nanoTime()
       // Because the constructor of an aggregation iterator will read at least 
the first row,
       // we need to get the value of iter.hasNext first.
       val hasInput = iter.hasNext
-      if (!hasInput && groupingExpressions.nonEmpty) {
+      val res = if (!hasInput && groupingExpressions.nonEmpty) {
         // This is a grouped aggregate and the input iterator is empty,
         // so return an empty iterator.
         Iterator[UnsafeRow]()
@@ -82,9 +90,25 @@ case class SortAggregateExec(
           outputIter
         }
       }
+      aggTime += NANOSECONDS.toMillis(System.nanoTime() - beforeAgg)
+      res
     }
   }
 
+  override def supportCodegen: Boolean = {
+    // Code-gen for sort aggregate requires, in this order: whatever the parent's
+    // supportCodegen requires, the ENABLE_SORT_AGGREGATE_CODEGEN config to be on, and no
+    // grouping keys (only the keyless code-gen path exists for this operator so far).
+    // TODO(SPARK-32750): Support sort aggregate code-gen with grouping keys
+    if (!super.supportCodegen) {
+      false
+    } else {
+      conf.getConf(SQLConf.ENABLE_SORT_AGGREGATE_CODEGEN) && groupingExpressions.isEmpty
+    }
+  }
+
+  /**
+   * Produce-side code-gen for aggregation with grouping keys. Not implemented for sort
+   * aggregate yet; `supportCodegen` returns false whenever grouping keys are present, so this
+   * hook is presumably never reached through whole-stage code-gen.
+   *
+   * @throws UnsupportedOperationException always
+   */
+  protected override def doProduceWithKeys(ctx: CodegenContext): String = {
+    throw new UnsupportedOperationException("SortAggregate code-gen does not support grouping keys")
+  }
+
+  /**
+   * Consume-side code-gen for aggregation with grouping keys. Not implemented for sort
+   * aggregate yet; `supportCodegen` returns false whenever grouping keys are present, so this
+   * hook is presumably never reached through whole-stage code-gen.
+   *
+   * @throws UnsupportedOperationException always
+   */
+  protected override def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    throw new UnsupportedOperationException("SortAggregate code-gen does not support grouping keys")
+  }
+
   override def simpleString(maxFields: Int): String = toString(verbose = 
false, maxFields)
 
   override def verboseString(maxFields: Int): String = toString(verbose = 
true, maxFields)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 55ca1e8..7332d49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, 
CodeAndComment, CodeGenerator}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, 
SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, 
BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
@@ -41,7 +41,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     assert(df.collect() === Array(Row(2)))
   }
 
-  test("Aggregate should be included in WholeStageCodegen") {
+  test("HashAggregate should be included in WholeStageCodegen") {
     val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
@@ -50,6 +50,17 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     assert(df.collect() === Array(Row(9, 4.5)))
   }
 
+  test("SortAggregate should be included in WholeStageCodegen") {
+    withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
+      // Build the DataFrame inside the config scope so it is analyzed and planned with
+      // sort aggregate forced on, instead of relying on planning being deferred.
+      val df = spark.range(10).agg(max(col("id")), avg(col("id")))
+      val plan = df.queryExecution.executedPlan
+      // The global (keyless) SortAggregateExec must be fused into a whole-stage code-gen stage.
+      val fusedSortAgg = plan.find {
+        case WholeStageCodegenExec(_: SortAggregateExec) => true
+        case _ => false
+      }
+      assert(fusedSortAgg.isDefined)
+      assert(df.collect() === Array(Row(9, 4.5)))
+    }
+  }
+
   testWithWholeStageCodegenOnAndOff("GenerateExec should be" +
     " included in WholeStageCodegen") { codegenEnabled =>
     import testImplicits._
@@ -129,7 +140,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
           Map("hair" -> "black", "eye" -> "brown"), "eye", "brown")))
   }
 
-  test("Aggregate with grouping keys should be included in WholeStageCodegen") 
{
+  test("HashAggregate with grouping keys should be included in 
WholeStageCodegen") {
     val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 
2)
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to