maropu commented on a change in pull request #20965: [SPARK-21870][SQL] Split
aggregation code into small functions
URL: https://github.com/apache/spark/pull/20965#discussion_r319702479
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -824,59 +936,158 @@ case class HashAggregateExec(
// generating input columns, we use `currentVars`.
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++
input
+ val aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName)
+ // Computes start offsets for each aggregation function code
+ // in the underlying buffer row.
+ val bufferStartOffsets = {
+ val offsets = mutable.ArrayBuffer[Int]()
+ var curOffset = 0
+ updateExprs.foreach { exprsForOneFunc =>
+ offsets += curOffset
+ curOffset += exprsForOneFunc.length
+ }
+ offsets.toArray
+ }
+
val updateRowInRegularHashMap: String = {
ctx.INPUT_ROW = unsafeRowBuffer
- val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
- val subExprs =
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+ val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
+ bindReferences(updateExprsForOneFunc, inputAttr)
+ }
+ val subExprs =
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = subExprs.codes.mkString("\n")
- val unsafeRowBufferEvals =
ctx.withSubExprEliminationExprs(subExprs.states) {
- boundUpdateExpr.map(_.genCode(ctx))
+ val unsafeRowBufferEvals = boundUpdateExprs.map {
boundUpdateExprsForOneFunc =>
+ ctx.withSubExprEliminationExprs(subExprs.states) {
Review comment:
Yeah, right. The current PR actually handles this case, e.g.:
```
scala> sql("SELECT SUM(a + b), AVG(a + b) FROM VALUES((1, 1)) t(a,
b)").debugCodegen
/* 109 */ private void agg_doConsume_0(InternalRow localtablescan_row_0,
int agg_expr_0_0, int agg_expr_1_0) throws java.io.IOException {
/* 110 */ // do aggregate
/* 111 */ // common sub-expressions
/* 112 */ int agg_value_5 = -1;
/* 113 */
/* 114 */ agg_value_5 = agg_expr_0_0 + agg_expr_1_0;
/* 115 */ boolean agg_isNull_4 = false;
/* 116 */ long agg_value_4 = -1L;
/* 117 */ if (!false) {
/* 118 */ agg_value_4 = (long) agg_value_5;
/* 119 */ }
/* 120 */ // evaluate aggregate functions and update aggregation buffers
/* 121 */ agg_doAggregate_sum_0(agg_value_4, agg_isNull_4);
/* 122 */ agg_doAggregate_avg_0(agg_value_4, agg_isNull_4);
/* 123 */
/* 124 */ }
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]