GitHub user nongli commented on a diff in the pull request:
https://github.com/apache/spark/pull/10855#discussion_r50918156
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
---
@@ -137,61 +157,297 @@ case class TungstenAggregate(
bufVars = initExpr.map { e =>
val isNull = ctx.freshName("bufIsNull")
val value = ctx.freshName("bufValue")
+ ctx.addMutableState("boolean", isNull, "")
+ ctx.addMutableState(ctx.javaType(e.dataType), value, "")
// The initial expression should not access any column
val ev = e.gen(ctx)
- val initVars = s"""
- | boolean $isNull = ${ev.isNull};
- | ${ctx.javaType(e.dataType)} $value = ${ev.value};
- """.stripMargin
+ val initVars =
+ s"""
+ $isNull = ${ev.isNull};
+ $value = ${ev.value};
+ """
ExprCode(ev.code + initVars, isNull, value)
}
- val (rdd, childSource) =
child.asInstanceOf[CodegenSupport].produce(ctx, this)
- val source =
+ // generate variables for output
+ val (resultVars, genResult) = if (modes.contains(Final)
|modes.contains(Complete)) {
+ // evaluate aggregate results
+ ctx.currentVars = bufVars
+ val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
+ val aggResults = functions.map(_.evaluateExpression).map { e =>
+ BindReferences.bindReference(e, bufferAttrs).gen(ctx)
+ }
+ // evaluate result expressions
+ ctx.currentVars = aggResults
+ val resultVars = resultExpressions.map { e =>
+ BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
+ }
+ (resultVars, s"""
+ ${aggResults.map(_.code).mkString("\n")}
+ ${resultVars.map(_.code).mkString("\n")}
+ """)
+ } else {
+ // output the aggregate buffer directly
+ (bufVars, "")
+ }
+
+ val doAgg = ctx.freshName("doAgg")
+ ctx.addNewFunction(doAgg,
s"""
- | if (!$initAgg) {
- | $initAgg = true;
- |
- | // initialize aggregation buffer
- | ${bufVars.map(_.code).mkString("\n")}
- |
- | $childSource
- |
- | // output the result
- | ${consume(ctx, bufVars)}
- | }
- """.stripMargin
-
- (rdd, source)
+ private void $doAgg() throws java.io.IOException {
+ // initialize aggregation buffer
+ ${bufVars.map(_.code).mkString("\n")}
+
+ ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+ }
+ """)
+
+ s"""
+ if (!$initAgg) {
+ $initAgg = true;
+ $doAgg();
+
+ // output the result
+ $genResult
+
+ ${consume(ctx, resultVars)}
+ }
+ """
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input:
Seq[ExprCode]): String = {
+ private def doConsumeWithoutKeys(
+ ctx: CodegenContext,
+ child: SparkPlan,
+ input: Seq[ExprCode]): String = {
// only have DeclarativeAggregate
val functions =
aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
- // the mode could be only Partial or PartialMerge
- val updateExpr = if (modes.contains(Partial)) {
- functions.flatMap(_.updateExpressions)
+ val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++
child.output
+ val updateExpr = aggregateExpressions.flatMap { e =>
+ e.mode match {
+ case Partial | Complete =>
+
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
+ case PartialMerge | Final =>
+
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
+ }
+ }
+ ctx.currentVars = bufVars ++ input
+ // TODO: support subexpression elimination
+ val updates = updateExpr.zipWithIndex.map { case (e, i) =>
+ val ev = BindReferences.bindReference[Expression](e,
inputAttrs).gen(ctx)
+ s"""
+ ${ev.code}
+ ${bufVars(i).isNull} = ${ev.isNull};
+ ${bufVars(i).value} = ${ev.value};
+ """
+ }
+
+ s"""
+ // do aggregate and update aggregation buffer
+ ${updates.mkString("")}
+ """
+ }
+
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
+ val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+ val declFunctions = aggregateExpressions.map(_.aggregateFunction)
+ .filter(_.isInstanceOf[DeclarativeAggregate])
+ .map(_.asInstanceOf[DeclarativeAggregate])
+ val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
+ val bufferSchema = StructType.fromAttributes(bufferAttributes)
+
+ // The name for HashMap
+ var hashMapTerm: String = _
+
+ def createHashMap(): UnsafeFixedWidthAggregationMap = {
+ // create initialized aggregate buffer
+ val initExpr = declFunctions.flatMap(f => f.initialValues)
+ val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow)
+
+ // create hashMap
+ new UnsafeFixedWidthAggregationMap(
+ initialBuffer,
+ bufferSchema,
+ groupingKeySchema,
+ TaskContext.get().taskMemoryManager(),
+ 1024 * 16, // initial capacity
+ TaskContext.get().taskMemoryManager().pageSizeBytes,
+ false // disable tracking of performance metrics
+ )
+ }
+
+ def createUnsafeJoiner(): UnsafeRowJoiner = {
--- End diff --
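Can you consistently use "private def" for these helper methods (e.g. `createHashMap` and `createUnsafeJoiner`), as is already done for `doConsumeWithoutKeys`?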
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it enabled, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]