This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5e3873d [SPARK-37564][SQL] Add code-gen for sort aggregate without
grouping keys
5e3873d is described below
commit 5e3873d6adca0b93efd4f30a5efc80603d0c6c5e
Author: Cheng Su <[email protected]>
AuthorDate: Fri Dec 17 20:18:37 2021 +0800
[SPARK-37564][SQL] Add code-gen for sort aggregate without grouping keys
### What changes were proposed in this pull request?
We can support `SortAggregateExec` code-gen without grouping keys. When
there's no grouping key, sort aggregate should share the same execution logic with
hash aggregate.
At a high level, this PR does
* `spark.sql.codegen.aggregate.sortAggregate.enabled`: a new config
(user-facing) is introduced to allow users to disable sort aggregate in case
anything goes wrong, and it's enabled by default.
* `spark.sql.aggregate.forceApplySortAggregate`: an internal config
(non-user-facing) is introduced to allow developers test sort aggregate, to
improve unit test coverage of it. We already had a similar config for shuffled
hash join.
* `AggregateCodegenSupport.scala`: The base class to have implementation of
code-gen without grouping keys and other code-gen related boilerplate code. The
subclass is required to implement `doProduceWithKeys()` and
`doConsumeWithKeys()` to handle aggregate with grouping keys. The
implementation in this class is copied literally from `HashAggregateExec`.
* `HashAggregateExec.scala`: only keeps its original `doProduceWithKeys()`
and `doConsumeWithKeys()`.
* `SortAggregateExec.scala`: extends `AggregateCodegenSupport.scala` to
support code-gen without grouping keys.
The implementation of `SortAggregateExec` code-gen with grouping keys will
be added in follow-up PR.
### Why are the changes needed?
To enable code-gen for sort aggregate, so we can have better query
performance when sort aggregate is used.
### Does this PR introduce _any_ user-facing change?
The config of `spark.sql.codegen.aggregate.sortAggregate.enabled` to
enable/disable sort aggregate code-gen.
### How was this patch tested?
Added unit test in `WholeStageCodegenSuite.scala`.
Closes #34826 from c21/sort-agg-codegen.
Authored-by: Cheng Su <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../org/apache/spark/sql/internal/SQLConf.scala | 8 +
.../sql/execution/WholeStageCodegenExec.scala | 5 +-
.../spark/sql/execution/aggregate/AggUtils.scala | 17 +-
.../aggregate/AggregateCodegenSupport.scala | 340 +++++++++++++++++++++
.../execution/aggregate/HashAggregateExec.scala | 281 +----------------
.../execution/aggregate/SortAggregateExec.scala | 30 +-
.../sql/execution/WholeStageCodegenSuite.scala | 17 +-
7 files changed, 410 insertions(+), 288 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6c6fb40..868340b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1781,6 +1781,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val ENABLE_SORT_AGGREGATE_CODEGEN =
+ buildConf("spark.sql.codegen.aggregate.sortAggregate.enabled")
+ .internal()
+ .doc("When true, enable code-gen for sort aggregate.")
+ .version("3.3.0")
+ .booleanConf
+ .createWithDefault(true)
+
val ENABLE_FULL_OUTER_SHUFFLED_HASH_JOIN_CODEGEN =
buildConf("spark.sql.codegen.join.fullOuterShuffledHashJoin.enabled")
.internal()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index e80ad89..dde976c 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec,
SortAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec,
BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -47,7 +47,8 @@ trait CodegenSupport extends SparkPlan {
/** Prefix used in the current operator's variable names. */
private def variablePrefix: String = this match {
- case _: HashAggregateExec => "agg"
+ case _: HashAggregateExec => "hashAgg"
+ case _: SortAggregateExec => "sortAgg"
case _: BroadcastHashJoinExec => "bhj"
case _: ShuffledHashJoinExec => "shj"
case _: SortMergeJoinExec => "smj"
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 0f239b4..32db622 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.Utils
/**
* Utility functions used by the query planner to convert our plan to new
aggregation code path.
@@ -53,7 +55,9 @@ object AggUtils {
child: SparkPlan): SparkPlan = {
val useHash = HashAggregateExec.supportsAggregate(
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
- if (useHash) {
+ val forceSortAggregate = forceApplySortAggregate(child.conf)
+
+ if (useHash && !forceSortAggregate) {
HashAggregateExec(
requiredChildDistributionExpressions =
requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
@@ -66,7 +70,7 @@ object AggUtils {
val objectHashEnabled = child.conf.useObjectHashAggregation
val useObjectHash =
ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)
- if (objectHashEnabled && useObjectHash) {
+ if (objectHashEnabled && useObjectHash && !forceSortAggregate) {
ObjectHashAggregateExec(
requiredChildDistributionExpressions =
requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
@@ -529,4 +533,13 @@ object AggUtils {
case None => partialAggregate
}
}
+
+ /**
+ * Returns whether a sort aggregate should be force applied.
+ * The config key is hard-coded because it's testing only and should not be
exposed.
+ */
+ private def forceApplySortAggregate(conf: SQLConf): Boolean = {
+ Utils.isTesting &&
+ conf.getConfString("spark.sql.test.forceApplySortAggregate", "false") ==
"true"
+ }
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala
new file mode 100644
index 0000000..6304363
--- /dev/null
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala
@@ -0,0 +1,340 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet,
Expression, ExpressionEquals, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
+import org.apache.spark.sql.execution.{BlockingOperatorWithCodegen,
CodegenSupport, GeneratePredicateHelper}
+import org.apache.spark.util.Utils
+
+/**
+ * An interface for those aggregate physical operators that support codegen.
+ */
+trait AggregateCodegenSupport
+ extends BaseAggregateExec
+ with BlockingOperatorWithCodegen
+ with GeneratePredicateHelper {
+
+ /**
+ * All the modes of aggregate expressions.
+ */
+ protected val modes: Seq[AggregateMode] =
aggregateExpressions.map(_.mode).distinct
+
+ /**
+ * The variables are used as aggregation buffers and each aggregate function
has one or more
+ * ExprCode to initialize its buffer slots. Only used for aggregation
without keys.
+ */
+ private var bufVars: Seq[Seq[ExprCode]] = _
+
+ /**
+ * The generated code for `doProduce` call when aggregate has grouping keys.
+ */
+ protected def doProduceWithKeys(ctx: CodegenContext): String
+
+ /**
+ * The generated code for `doConsume` call when aggregate has grouping keys.
+ */
+ protected def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]):
String
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ if (groupingExpressions.isEmpty) {
+ doProduceWithoutKeys(ctx)
+ } else {
+ doProduceWithKeys(ctx)
+ }
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row:
ExprCode): String = {
+ if (groupingExpressions.isEmpty) {
+ doConsumeWithoutKeys(ctx, input)
+ } else {
+ doConsumeWithKeys(ctx, input)
+ }
+ }
+
+ override def supportCodegen: Boolean = {
+ val isMutableAggBuffer = aggregateBufferAttributes.forall(a =>
UnsafeRow.isMutable(a.dataType))
+ // ImperativeAggregate are not supported right now
+ isMutableAggBuffer &&
+
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
+ }
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].inputRDDs()
+ }
+
+ override def usedInputs: AttributeSet = inputSet
+
+ /**
+ * The generated code for `doProduce` call when aggregate does not have
grouping keys.
+ */
+ private def doProduceWithoutKeys(ctx: CodegenContext): String = {
+ val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
+ // The generated function doesn't have input row in the code context.
+ ctx.INPUT_ROW = null
+
+ // generate variables for aggregation buffer
+ val functions =
aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+ val initExpr = functions.map(f => f.initialValues)
+ bufVars = initExpr.map { exprs =>
+ exprs.map { e =>
+ val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN,
"bufIsNull")
+ val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType),
"bufValue")
+ // The initial expression should not access any column
+ val ev = e.genCode(ctx)
+ val initVars =
+ code"""
+ |$isNull = ${ev.isNull};
+ |$value = ${ev.value};
+ """.stripMargin
+ ExprCode(
+ ev.code + initVars,
+ JavaCode.isNullGlobal(isNull),
+ JavaCode.global(value, e.dataType))
+ }
+ }
+ val flatBufVars = bufVars.flatten
+ val initBufVar = evaluateVariables(flatBufVars)
+
+ // generate variables for output
+ val (resultVars, genResult) = if (modes.contains(Final) ||
modes.contains(Complete)) {
+ // evaluate aggregate results
+ ctx.currentVars = flatBufVars
+ val aggResults = bindReferences(
+ functions.map(_.evaluateExpression),
+ aggregateBufferAttributes).map(_.genCode(ctx))
+ val evaluateAggResults = evaluateVariables(aggResults)
+ // evaluate result expressions
+ ctx.currentVars = aggResults
+ val resultVars = bindReferences(resultExpressions,
aggregateAttributes).map(_.genCode(ctx))
+ (resultVars,
+ s"""
+ |$evaluateAggResults
+ |${evaluateVariables(resultVars)}
+ """.stripMargin)
+ } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+ // output the aggregate buffer directly
+ (flatBufVars, "")
+ } else {
+ // no aggregate function, the result should be literals
+ val resultVars = resultExpressions.map(_.genCode(ctx))
+ (resultVars, evaluateVariables(resultVars))
+ }
+
+ val doAgg = ctx.freshName("doAggregateWithoutKey")
+ val doAggFuncName = ctx.addNewFunction(doAgg,
+ s"""
+ |private void $doAgg() throws java.io.IOException {
+ | // initialize aggregation buffer
+ | $initBufVar
+ |
+ | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+ |}
+ """.stripMargin)
+
+ val numOutput = metricTerm(ctx, "numOutputRows")
+ val aggTime = metricTerm(ctx, "aggTime")
+ val beforeAgg = ctx.freshName("beforeAgg")
+ s"""
+ |while (!$initAgg) {
+ | $initAgg = true;
+ | long $beforeAgg = System.nanoTime();
+ | $doAggFuncName();
+ | $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS);
+ |
+ | // output the result
+ | ${genResult.trim}
+ |
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars).trim}
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * The generated code for `doConsume` call when aggregate does not have
grouping keys.
+ */
+ private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ // only have DeclarativeAggregate
+ val functions =
aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+ val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++
inputAttributes
+ // To individually generate code for each aggregate function, an element
in `updateExprs` holds
+ // all the expressions for the buffer of an aggregation function.
+ val updateExprs = aggregateExpressions.map { e =>
+ e.mode match {
+ case Partial | Complete =>
+
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
+ case PartialMerge | Final =>
+
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
+ }
+ }
+ ctx.currentVars = bufVars.flatten ++ input
+ val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
+ bindReferences(updateExprsForOneFunc, inputAttrs)
+ }
+ val subExprs =
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
+ val effectiveCodes =
ctx.evaluateSubExprEliminationState(subExprs.states.values)
+ val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
+ ctx.withSubExprEliminationExprs(subExprs.states) {
+ boundUpdateExprsForOneFunc.map(_.genCode(ctx))
+ }
+ }
+
+ val aggNames = functions.map(_.prettyName)
+ val aggCodeBlocks = bufferEvals.zipWithIndex.map { case
(bufferEvalsForOneFunc, i) =>
+ val bufVarsForOneFunc = bufVars(i)
+ // All the update code for aggregation buffers should be placed in the
end
+ // of each aggregation function code.
+ val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case
(ev, bufVar) =>
+ s"""
+ |${bufVar.isNull} = ${ev.isNull};
+ |${bufVar.value} = ${ev.value};
+ """.stripMargin
+ }
+ code"""
+ |${ctx.registerComment(s"do aggregate for ${aggNames(i)}")}
+ |${ctx.registerComment("evaluate aggregate function")}
+ |${evaluateVariables(bufferEvalsForOneFunc)}
+ |${ctx.registerComment("update aggregation buffers")}
+ |${updates.mkString("\n").trim}
+ """.stripMargin
+ }
+
+ val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+ ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks,
subExprs)
+ s"""
+ |// do aggregate
+ |// common sub-expressions
+ |$effectiveCodes
+ |// evaluate aggregate functions and update aggregation buffers
+ |$codeToEvalAggFuncs
+ """.stripMargin
+ }
+
+ /**
+ * The generated code to evaluate aggregate functions.
+ */
+ protected def generateEvalCodeForAggFuncs(
+ ctx: CodegenContext,
+ input: Seq[ExprCode],
+ inputAttrs: Seq[Attribute],
+ boundUpdateExprs: Seq[Seq[Expression]],
+ aggNames: Seq[String],
+ aggCodeBlocks: Seq[Block],
+ subExprs: SubExprCodes): String = {
+ val aggCodes = if (conf.codegenSplitAggregateFunc &&
+ aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
+ val maybeSplitCodes = splitAggregateExpressions(
+ ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
+
+ maybeSplitCodes.getOrElse(aggCodeBlocks.map(_.code))
+ } else {
+ aggCodeBlocks.map(_.code)
+ }
+
+ aggCodes.zip(aggregateExpressions.map(ae => (ae.mode, ae.filter))).map {
+ case (aggCode, (Partial | Complete, Some(condition))) =>
+ // Note: wrap in "do { } while(false);", so the generated checks can
jump out
+ // with "continue;"
+ s"""
+ |do {
+ | ${generatePredicateCode(ctx, condition, inputAttrs, input)}
+ | $aggCode
+ |} while(false);
+ """.stripMargin
+ case (aggCode, _) =>
+ aggCode
+ }.mkString("\n")
+ }
+
+ /**
+ * Splits aggregate code into small functions because the most of JVM
implementations
+ * can not compile too long functions. Returns None if we are not able to
split the given code.
+ *
+ * Note: The difference from `CodeGenerator.splitExpressions` is that we
define an individual
+ * function for each aggregation function (e.g., SUM and AVG). For example,
in a query
+ * `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions
+ * for `SUM(a)` and `AVG(a)`.
+ */
+ private def splitAggregateExpressions(
+ ctx: CodegenContext,
+ aggNames: Seq[String],
+ aggBufferUpdatingExprs: Seq[Seq[Expression]],
+ aggCodeBlocks: Seq[Block],
+ subExprs: Map[ExpressionEquals, SubExprEliminationState]):
Option[Seq[String]] = {
+ val exprValsInSubExprs = subExprs.flatMap { case (_, s) =>
+ s.eval.value :: s.eval.isNull :: Nil
+ }
+ if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
+ // `SimpleExprValue`s cannot be used as an input variable for split
functions, so
+ // we give up splitting functions if it exists in `subExprs`.
+ None
+ } else {
+ val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
+ val inputVarsForOneFunc = aggExprsForOneFunc.map(
+ CodeGenerator.getLocalInputVariableValues(ctx, _,
subExprs)._1).reduce(_ ++ _).toSeq
+ val paramLength =
CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
+
+ // Checks if a parameter length for the `aggExprsForOneFunc` does not
go over the JVM limit
+ if (CodeGenerator.isValidParamLength(paramLength)) {
+ Some(inputVarsForOneFunc)
+ } else {
+ None
+ }
+ }
+
+ // Checks if all the aggregate code can be split into pieces.
+    // If the parameter length of at least one `aggExprsForOneFunc` goes
over the limit,
+ // we totally give up splitting aggregate code.
+ if (inputVars.forall(_.isDefined)) {
+ val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) =>
+ val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}")
+ val argList = args.map { v =>
+ s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}"
+ }.mkString(", ")
+ val doAggFuncName = ctx.addNewFunction(doAggFunc,
+ s"""
+ |private void $doAggFunc($argList) throws java.io.IOException {
+ | ${aggCodeBlocks(i)}
+ |}
+ """.stripMargin)
+
+ val inputVariables = args.map(_.variableName).mkString(", ")
+ s"$doAggFuncName($inputVariables);"
+ }
+ Some(splitCodes)
+ } else {
+ val errMsg = "Failed to split aggregate code into small functions
because the parameter " +
+ "length of at least one split function went over the JVM limit: " +
+ CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
+ if (Utils.isTesting) {
+ throw new IllegalStateException(errMsg)
+ } else {
+ logInfo(errMsg)
+ None
+ }
+ }
+ }
+ }
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 8545154..d4a4502 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -51,9 +51,7 @@ case class HashAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
- extends BaseAggregateExec
- with BlockingOperatorWithCodegen
- with GeneratePredicateHelper {
+ extends AggregateCodegenSupport {
require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
@@ -130,279 +128,6 @@ case class HashAggregateExec(
}
}
- // all the mode of aggregate expressions
- private val modes = aggregateExpressions.map(_.mode).distinct
-
- override def usedInputs: AttributeSet = inputSet
-
- override def supportCodegen: Boolean = {
- // ImperativeAggregate are not supported right now
-
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
- }
-
- override def inputRDDs(): Seq[RDD[InternalRow]] = {
- child.asInstanceOf[CodegenSupport].inputRDDs()
- }
-
- protected override def doProduce(ctx: CodegenContext): String = {
- if (groupingExpressions.isEmpty) {
- doProduceWithoutKeys(ctx)
- } else {
- doProduceWithKeys(ctx)
- }
- }
-
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row:
ExprCode): String = {
- if (groupingExpressions.isEmpty) {
- doConsumeWithoutKeys(ctx, input)
- } else {
- doConsumeWithKeys(ctx, input)
- }
- }
-
- // The variables are used as aggregation buffers and each aggregate function
has one or more
- // ExprCode to initialize its buffer slots. Only used for aggregation
without keys.
- private var bufVars: Seq[Seq[ExprCode]] = _
-
- private def doProduceWithoutKeys(ctx: CodegenContext): String = {
- val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
- // The generated function doesn't have input row in the code context.
- ctx.INPUT_ROW = null
-
- // generate variables for aggregation buffer
- val functions =
aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
- val initExpr = functions.map(f => f.initialValues)
- bufVars = initExpr.map { exprs =>
- exprs.map { e =>
- val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN,
"bufIsNull")
- val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType),
"bufValue")
- // The initial expression should not access any column
- val ev = e.genCode(ctx)
- val initVars = code"""
- |$isNull = ${ev.isNull};
- |$value = ${ev.value};
- """.stripMargin
- ExprCode(
- ev.code + initVars,
- JavaCode.isNullGlobal(isNull),
- JavaCode.global(value, e.dataType))
- }
- }
- val flatBufVars = bufVars.flatten
- val initBufVar = evaluateVariables(flatBufVars)
-
- // generate variables for output
- val (resultVars, genResult) = if (modes.contains(Final) ||
modes.contains(Complete)) {
- // evaluate aggregate results
- ctx.currentVars = flatBufVars
- val aggResults = bindReferences(
- functions.map(_.evaluateExpression),
- aggregateBufferAttributes).map(_.genCode(ctx))
- val evaluateAggResults = evaluateVariables(aggResults)
- // evaluate result expressions
- ctx.currentVars = aggResults
- val resultVars = bindReferences(resultExpressions,
aggregateAttributes).map(_.genCode(ctx))
- (resultVars, s"""
- |$evaluateAggResults
- |${evaluateVariables(resultVars)}
- """.stripMargin)
- } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
- // output the aggregate buffer directly
- (flatBufVars, "")
- } else {
- // no aggregate function, the result should be literals
- val resultVars = resultExpressions.map(_.genCode(ctx))
- (resultVars, evaluateVariables(resultVars))
- }
-
- val doAgg = ctx.freshName("doAggregateWithoutKey")
- val doAggFuncName = ctx.addNewFunction(doAgg,
- s"""
- |private void $doAgg() throws java.io.IOException {
- | // initialize aggregation buffer
- | $initBufVar
- |
- | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
- |}
- """.stripMargin)
-
- val numOutput = metricTerm(ctx, "numOutputRows")
- val aggTime = metricTerm(ctx, "aggTime")
- val beforeAgg = ctx.freshName("beforeAgg")
- s"""
- |while (!$initAgg) {
- | $initAgg = true;
- | long $beforeAgg = System.nanoTime();
- | $doAggFuncName();
- | $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS);
- |
- | // output the result
- | ${genResult.trim}
- |
- | $numOutput.add(1);
- | ${consume(ctx, resultVars).trim}
- |}
- """.stripMargin
- }
-
- // Splits aggregate code into small functions because the most of JVM
implementations
- // can not compile too long functions. Returns None if we are not able to
split the given code.
- //
- // Note: The difference from `CodeGenerator.splitExpressions` is that we
define an individual
- // function for each aggregation function (e.g., SUM and AVG). For example,
in a query
- // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions
- // for `SUM(a)` and `AVG(a)`.
- private def splitAggregateExpressions(
- ctx: CodegenContext,
- aggNames: Seq[String],
- aggBufferUpdatingExprs: Seq[Seq[Expression]],
- aggCodeBlocks: Seq[Block],
- subExprs: Map[ExpressionEquals, SubExprEliminationState]):
Option[Seq[String]] = {
- val exprValsInSubExprs = subExprs.flatMap { case (_, s) =>
- s.eval.value :: s.eval.isNull :: Nil
- }
- if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
- // `SimpleExprValue`s cannot be used as an input variable for split
functions, so
- // we give up splitting functions if it exists in `subExprs`.
- None
- } else {
- val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
- val inputVarsForOneFunc = aggExprsForOneFunc.map(
- CodeGenerator.getLocalInputVariableValues(ctx, _,
subExprs)._1).reduce(_ ++ _).toSeq
- val paramLength =
CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
-
- // Checks if a parameter length for the `aggExprsForOneFunc` does not
go over the JVM limit
- if (CodeGenerator.isValidParamLength(paramLength)) {
- Some(inputVarsForOneFunc)
- } else {
- None
- }
- }
-
- // Checks if all the aggregate code can be split into pieces.
- // If the parameter length of at lease one `aggExprsForOneFunc` goes
over the limit,
- // we totally give up splitting aggregate code.
- if (inputVars.forall(_.isDefined)) {
- val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) =>
- val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}")
- val argList = args.map { v =>
- s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}"
- }.mkString(", ")
- val doAggFuncName = ctx.addNewFunction(doAggFunc,
- s"""
- |private void $doAggFunc($argList) throws java.io.IOException {
- | ${aggCodeBlocks(i)}
- |}
- """.stripMargin)
-
- val inputVariables = args.map(_.variableName).mkString(", ")
- s"$doAggFuncName($inputVariables);"
- }
- Some(splitCodes)
- } else {
- val errMsg = "Failed to split aggregate code into small functions
because the parameter " +
- "length of at least one split function went over the JVM limit: " +
- CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
- if (Utils.isTesting) {
- throw new IllegalStateException(errMsg)
- } else {
- logInfo(errMsg)
- None
- }
- }
- }
- }
-
- private def generateEvalCodeForAggFuncs(
- ctx: CodegenContext,
- input: Seq[ExprCode],
- inputAttrs: Seq[Attribute],
- boundUpdateExprs: Seq[Seq[Expression]],
- aggNames: Seq[String],
- aggCodeBlocks: Seq[Block],
- subExprs: SubExprCodes): String = {
- val aggCodes = if (conf.codegenSplitAggregateFunc &&
- aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
- val maybeSplitCodes = splitAggregateExpressions(
- ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
- maybeSplitCodes.getOrElse(aggCodeBlocks.map(_.code))
- } else {
- aggCodeBlocks.map(_.code)
- }
-
- aggCodes.zip(aggregateExpressions.map(ae => (ae.mode, ae.filter))).map {
- case (aggCode, (Partial | Complete, Some(condition))) =>
- // Note: wrap in "do { } while(false);", so the generated checks can
jump out
- // with "continue;"
- s"""
- |do {
- | ${generatePredicateCode(ctx, condition, inputAttrs, input)}
- | $aggCode
- |} while(false);
- """.stripMargin
- case (aggCode, _) =>
- aggCode
- }.mkString("\n")
- }
-
- private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
- // only have DeclarativeAggregate
- val functions =
aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
- val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++
inputAttributes
- // To individually generate code for each aggregate function, an element
in `updateExprs` holds
- // all the expressions for the buffer of an aggregation function.
- val updateExprs = aggregateExpressions.map { e =>
- e.mode match {
- case Partial | Complete =>
-
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
- case PartialMerge | Final =>
-
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
- }
- }
- ctx.currentVars = bufVars.flatten ++ input
- val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
- bindReferences(updateExprsForOneFunc, inputAttrs)
- }
- val subExprs =
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
- val effectiveCodes =
ctx.evaluateSubExprEliminationState(subExprs.states.values)
- val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
- ctx.withSubExprEliminationExprs(subExprs.states) {
- boundUpdateExprsForOneFunc.map(_.genCode(ctx))
- }
- }
-
- val aggNames = functions.map(_.prettyName)
- val aggCodeBlocks = bufferEvals.zipWithIndex.map { case
(bufferEvalsForOneFunc, i) =>
- val bufVarsForOneFunc = bufVars(i)
- // All the update code for aggregation buffers should be placed in the
end
- // of each aggregation function code.
- val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case
(ev, bufVar) =>
- s"""
- |${bufVar.isNull} = ${ev.isNull};
- |${bufVar.value} = ${ev.value};
- """.stripMargin
- }
- code"""
- |${ctx.registerComment(s"do aggregate for ${aggNames(i)}")}
- |${ctx.registerComment("evaluate aggregate function")}
- |${evaluateVariables(bufferEvalsForOneFunc)}
- |${ctx.registerComment("update aggregation buffers")}
- |${updates.mkString("\n").trim}
- """.stripMargin
- }
-
- val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
- ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks,
subExprs)
- s"""
- |// do aggregate
- |// common sub-expressions
- |$effectiveCodes
- |// evaluate aggregate functions and update aggregation buffers
- |$codeToEvalAggFuncs
- """.stripMargin
- }
-
private val groupingAttributes = groupingExpressions.map(_.toAttribute)
private val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
@@ -692,7 +417,7 @@ case class HashAggregateExec(
}
}
- private def doProduceWithKeys(ctx: CodegenContext): String = {
+ protected override def doProduceWithKeys(ctx: CodegenContext): String = {
val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
if (conf.enableTwoLevelAggMap) {
enableTwoLevelHashMap(ctx)
@@ -891,7 +616,7 @@ case class HashAggregateExec(
""".stripMargin
}
- private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ protected override def doConsumeWithKeys(ctx: CodegenContext, input:
Seq[ExprCode]): String = {
// create grouping key
val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
ctx, bindReferences[Expression](groupingExpressions, child.output))
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 4fb0f44..f5462d2 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -17,13 +17,17 @@
package org.apache.spark.sql.execution.aggregate
+import java.util.concurrent.TimeUnit.NANOSECONDS
+
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
ExprCode}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
/**
* Sort-based aggregate operator.
@@ -36,11 +40,12 @@ case class SortAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
- extends BaseAggregateExec
+ extends AggregateCodegenSupport
with AliasAwareOutputOrdering {
override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output
rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output
rows"),
+ "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in
aggregation build"))
override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
@@ -52,11 +57,14 @@ case class SortAggregateExec(
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
+ val aggTime = longMetric("aggTime")
+
child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
+ val beforeAgg = System.nanoTime()
// Because the constructor of an aggregation iterator will read at least
the first row,
// we need to get the value of iter.hasNext first.
val hasInput = iter.hasNext
- if (!hasInput && groupingExpressions.nonEmpty) {
+ val res = if (!hasInput && groupingExpressions.nonEmpty) {
// This is a grouped aggregate and the input iterator is empty,
// so return an empty iterator.
Iterator[UnsafeRow]()
@@ -82,9 +90,25 @@ case class SortAggregateExec(
outputIter
}
}
+ aggTime += NANOSECONDS.toMillis(System.nanoTime() - beforeAgg)
+ res
}
}
+ override def supportCodegen: Boolean = {
+ // TODO(SPARK-32750): Support sort aggregate code-gen with grouping keys
+ super.supportCodegen &&
conf.getConf(SQLConf.ENABLE_SORT_AGGREGATE_CODEGEN) &&
+ groupingExpressions.isEmpty
+ }
+
+ protected def doProduceWithKeys(ctx: CodegenContext): String = {
+ throw new UnsupportedOperationException("SortAggregate code-gen does not
support grouping keys")
+ }
+
+ protected def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ throw new UnsupportedOperationException("SortAggregate code-gen does not
support grouping keys")
+ }
+
override def simpleString(maxFields: Int): String = toString(verbose =
false, maxFields)
override def verboseString(maxFields: Int): String = toString(verbose =
true, maxFields)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 55ca1e8..7332d49 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats,
CodeAndComment, CodeGenerator}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec,
SortAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec,
BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
@@ -41,7 +41,7 @@ class WholeStageCodegenSuite extends QueryTest with
SharedSparkSession
assert(df.collect() === Array(Row(2)))
}
- test("Aggregate should be included in WholeStageCodegen") {
+ test("HashAggregate should be included in WholeStageCodegen") {
val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
val plan = df.queryExecution.executedPlan
assert(plan.find(p =>
@@ -50,6 +50,17 @@ class WholeStageCodegenSuite extends QueryTest with
SharedSparkSession
assert(df.collect() === Array(Row(9, 4.5)))
}
+ test("SortAggregate should be included in WholeStageCodegen") {
+ val df = spark.range(10).agg(max(col("id")), avg(col("id")))
+ withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
+ val plan = df.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegenExec] &&
+
p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]).isDefined)
+ assert(df.collect() === Array(Row(9, 4.5)))
+ }
+ }
+
testWithWholeStageCodegenOnAndOff("GenerateExec should be" +
" included in WholeStageCodegen") { codegenEnabled =>
import testImplicits._
@@ -129,7 +140,7 @@ class WholeStageCodegenSuite extends QueryTest with
SharedSparkSession
Map("hair" -> "black", "eye" -> "brown"), "eye", "brown")))
}
- test("Aggregate with grouping keys should be included in WholeStageCodegen")
{
+ test("HashAggregate with grouping keys should be included in
WholeStageCodegen") {
val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") *
2)
val plan = df.queryExecution.executedPlan
assert(plan.find(p =>
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]