rednaxelafx commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
URL: https://github.com/apache/spark/pull/20965#discussion_r316427809
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -267,29 +302,81 @@ case class HashAggregateExec(
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
}
}
- ctx.currentVars = bufVars ++ input
- val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
- val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
- val effectiveCodes = subExprs.codes.mkString("\n")
- val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
- boundUpdateExpr.map(_.genCode(ctx))
- }
- // aggregate buffer should be updated atomic
- val updates = aggVals.zipWithIndex.map { case (ev, i) =>
+
+ if (!conf.codegenSplitAggregateFunc) {
+ ctx.currentVars = bufVars ++ input
+ val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
+ val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+ val effectiveCodes = subExprs.codes.mkString("\n")
+ val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
+ boundUpdateExpr.map(_.genCode(ctx))
+ }
+ // aggregate buffer should be updated atomic
+ val updates = aggVals.zipWithIndex.map { case (ev, i) =>
+ s"""
+ | ${bufVars(i).isNull} = ${ev.isNull};
+ | ${bufVars(i).value} = ${ev.value};
+ """.stripMargin
+ }
+ s"""
+ | // do aggregate
+ | // common sub-expressions
+ | $effectiveCodes
+ | // evaluate aggregate function
+ | ${evaluateVariables(aggVals)}
+ | // update aggregation buffer
+ | ${updates.mkString("\n").trim}
+ """.stripMargin
+ } else {
+ // We need to copy the aggregation buffer to local variables first because each aggregate
Review comment:
Could you please elaborate on why copying the aggregate buffer values to
local variables is required here?
In the current data layout design, `DeclarativeAggregate`s never share their
buffer slots across aggregate expressions, so it has never made sense to me why
the old code has to update all aggregate buffer slots at the end, after all
update expressions have been evaluated.
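To make the argument concrete, here is a minimal Scala model. It assumes simplifications that are not Spark's actual codegen: the aggregation buffer is a plain `Array[Long]` without null flags, and each hypothetical `ToyAgg` owns a disjoint set of slots and reads only those slots plus the input value. Under these assumptions, the old "evaluate every update expression, then write every slot" ordering and a per-aggregate "evaluate, then write" ordering leave the buffer in the same state.

```scala
// Hypothetical model, not Spark's API: an aggregate owns the buffer slots in
// `slots` and computes their new values from the buffer plus one input value.
final case class ToyAgg(slots: Seq[Int], update: (Array[Long], Long) => Seq[Long])

object UpdateOrder {
  // Old ordering: evaluate every update expression first, then write all slots.
  def batched(buf: Array[Long], aggs: Seq[ToyAgg], input: Long): Array[Long] = {
    val out = buf.clone()
    val newVals = aggs.map(a => a.update(out, input)) // no writes happen here
    aggs.zip(newVals).foreach { case (a, vs) =>
      a.slots.zip(vs).foreach { case (s, v) => out(s) = v }
    }
    out
  }

  // Per-aggregate ordering: evaluate one aggregate, write its slots, move on.
  def perAggregate(buf: Array[Long], aggs: Seq[ToyAgg], input: Long): Array[Long] = {
    val out = buf.clone()
    aggs.foreach { a =>
      val vs = a.update(out, input)
      a.slots.zip(vs).foreach { case (s, v) => out(s) = v }
    }
    out
  }

  // sum(x) in slot 0; avg(x) as (sum, count) in slots 1 and 2: disjoint slots.
  val aggs: Seq[ToyAgg] = Seq(
    ToyAgg(Seq(0), (b, x) => Seq(b(0) + x)),
    ToyAgg(Seq(1, 2), (b, x) => Seq(b(1) + x, b(2) + 1))
  )

  def main(args: Array[String]): Unit = {
    val start = Array(10L, 10L, 2L)
    println(batched(start, aggs, 5L).mkString(","))      // prints 15,15,3
    println(perAggregate(start, aggs, 5L).mkString(",")) // prints 15,15,3
  }
}
```

The equivalence rests entirely on disjointness: if a second aggregate read a slot that the first one writes, `perAggregate` would see the new value while `batched` would see the old one.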
The flow of:
`Load aggregate buffer -> evaluate update expressions -> store back to aggregate buffer`
should be strictly confined to each aggregate expression.
The 3-step flow logically looks like this:
```
// step 1: load from aggregate buffer
val localAggBufSlot1_isNull = aggBufSlot1_isNull;
val localAggBufSlot1_value = aggBufSlot1_value;
// step 2: evaluate and materialize update/merge expressions
val isNull2 = false;
val value2 = (localAggBufSlot1_isNull ? 0 : localAggBufSlot1_value) + input1;
// step 3: write back to aggregate buffer
aggBufSlot1_isNull = isNull2;
aggBufSlot1_value = value2;
```
As you can see, evaluating the update/merge expressions has no side effects on
the aggregate buffers: the results are evaluated into local variables first and
only then stored back to the aggregate buffers. So even if I simplify that to:
```
// NO step 1: load from aggregate buffer
// val localAggBufSlot1_isNull = aggBufSlot1_isNull;
// val localAggBufSlot1_value = aggBufSlot1_value;
// step 2: evaluate and materialize update/merge expressions, directly loading from aggregate buffers
val isNull2 = false;
val value2 = (aggBufSlot1_isNull ? 0 : aggBufSlot1_value) + input1;
// step 3: write back to aggregate buffer
aggBufSlot1_isNull = isNull2;
aggBufSlot1_value = value2;
```
It'll be just as safe.
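The same point can be sketched in Scala for a single slot, again as a hypothetical model rather than Spark's generated Java: `Slot` stands in for the `aggBufSlot1_isNull`/`aggBufSlot1_value` pair, and the update mirrors the null-safe sum in the snippets above. Because step 2 writes only to locals, reading the slot directly gives the same result as reading a copy of it.

```scala
// Hypothetical stand-in for one aggregate buffer slot (null flag + value);
// not a Spark class.
final class Slot(var isNull: Boolean, var value: Long)

object LoadOrder {
  // Steps 1-3: copy the slot into locals, evaluate, then write back.
  def withLocalCopy(slot: Slot, input1: Long): Unit = {
    val localIsNull = slot.isNull // step 1: load
    val localValue = slot.value
    val isNull2 = false           // step 2: evaluate into locals
    val value2 = (if (localIsNull) 0L else localValue) + input1
    slot.isNull = isNull2         // step 3: write back
    slot.value = value2
  }

  // Steps 2-3 only: read the slot directly while evaluating. This is safe
  // because nothing writes to `slot` until both results are computed.
  def direct(slot: Slot, input1: Long): Unit = {
    val isNull2 = false
    val value2 = (if (slot.isNull) 0L else slot.value) + input1
    slot.isNull = isNull2
    slot.value = value2
  }

  def main(args: Array[String]): Unit = {
    // Compare both variants from a null slot and from a non-null slot.
    for ((isNull, value) <- Seq((true, 0L), (false, 3L))) {
      val a = new Slot(isNull, value)
      val b = new Slot(isNull, value)
      withLocalCopy(a, 4L)
      direct(b, 4L)
      println(s"${a.isNull},${a.value} ${b.isNull},${b.value}")
    }
  }
}
```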