GitHub user nongli commented on a diff in the pull request:
https://github.com/apache/spark/pull/10855#discussion_r50918676
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
---
@@ -137,61 +157,297 @@ case class TungstenAggregate(
bufVars = initExpr.map { e =>
val isNull = ctx.freshName("bufIsNull")
val value = ctx.freshName("bufValue")
+ ctx.addMutableState("boolean", isNull, "")
+ ctx.addMutableState(ctx.javaType(e.dataType), value, "")
// The initial expression should not access any column
val ev = e.gen(ctx)
- val initVars = s"""
- | boolean $isNull = ${ev.isNull};
- | ${ctx.javaType(e.dataType)} $value = ${ev.value};
- """.stripMargin
+ val initVars =
+ s"""
+ $isNull = ${ev.isNull};
+ $value = ${ev.value};
+ """
ExprCode(ev.code + initVars, isNull, value)
}
- val (rdd, childSource) =
child.asInstanceOf[CodegenSupport].produce(ctx, this)
- val source =
+ // generate variables for output
+ val (resultVars, genResult) = if (modes.contains(Final)
|modes.contains(Complete)) {
+ // evaluate aggregate results
+ ctx.currentVars = bufVars
+ val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
+ val aggResults = functions.map(_.evaluateExpression).map { e =>
+ BindReferences.bindReference(e, bufferAttrs).gen(ctx)
+ }
+ // evaluate result expressions
+ ctx.currentVars = aggResults
+ val resultVars = resultExpressions.map { e =>
+ BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
+ }
+ (resultVars, s"""
+ ${aggResults.map(_.code).mkString("\n")}
+ ${resultVars.map(_.code).mkString("\n")}
+ """)
+ } else {
+ // output the aggregate buffer directly
+ (bufVars, "")
+ }
+
+ val doAgg = ctx.freshName("doAgg")
+ ctx.addNewFunction(doAgg,
s"""
- | if (!$initAgg) {
- | $initAgg = true;
- |
- | // initialize aggregation buffer
- | ${bufVars.map(_.code).mkString("\n")}
- |
- | $childSource
- |
- | // output the result
- | ${consume(ctx, bufVars)}
- | }
- """.stripMargin
-
- (rdd, source)
+ private void $doAgg() throws java.io.IOException {
+ // initialize aggregation buffer
+ ${bufVars.map(_.code).mkString("\n")}
+
+ ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+ }
+ """)
+
+ s"""
+ if (!$initAgg) {
+ $initAgg = true;
+ $doAgg();
+
+ // output the result
+ $genResult
+
+ ${consume(ctx, resultVars)}
+ }
+ """
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input:
Seq[ExprCode]): String = {
+ private def doConsumeWithoutKeys(
+ ctx: CodegenContext,
+ child: SparkPlan,
+ input: Seq[ExprCode]): String = {
// only have DeclarativeAggregate
val functions =
aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
- // the mode could be only Partial or PartialMerge
- val updateExpr = if (modes.contains(Partial)) {
- functions.flatMap(_.updateExpressions)
+ val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++
child.output
+ val updateExpr = aggregateExpressions.flatMap { e =>
+ e.mode match {
+ case Partial | Complete =>
+
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
+ case PartialMerge | Final =>
+
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
+ }
+ }
+ ctx.currentVars = bufVars ++ input
+ // TODO: support subexpression elimination
+ val updates = updateExpr.zipWithIndex.map { case (e, i) =>
+ val ev = BindReferences.bindReference[Expression](e,
inputAttrs).gen(ctx)
+ s"""
+ ${ev.code}
+ ${bufVars(i).isNull} = ${ev.isNull};
+ ${bufVars(i).value} = ${ev.value};
+ """
+ }
+
+ s"""
+ // do aggregate and update aggregation buffer
+ ${updates.mkString("")}
+ """
+ }
+
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
+ val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+ val declFunctions = aggregateExpressions.map(_.aggregateFunction)
+ .filter(_.isInstanceOf[DeclarativeAggregate])
+ .map(_.asInstanceOf[DeclarativeAggregate])
+ val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
+ val bufferSchema = StructType.fromAttributes(bufferAttributes)
+
+ // The name for HashMap
+ var hashMapTerm: String = _
+
+ def createHashMap(): UnsafeFixedWidthAggregationMap = {
+ // create initialized aggregate buffer
+ val initExpr = declFunctions.flatMap(f => f.initialValues)
+ val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow)
+
+ // create hashMap
+ new UnsafeFixedWidthAggregationMap(
+ initialBuffer,
+ bufferSchema,
+ groupingKeySchema,
+ TaskContext.get().taskMemoryManager(),
+ 1024 * 16, // initial capacity
+ TaskContext.get().taskMemoryManager().pageSizeBytes,
+ false // disable tracking of performance metrics
+ )
+ }
+
+ def createUnsafeJoiner(): UnsafeRowJoiner = {
+ GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
+ }
+
+ private def doProduceWithKeys(ctx: CodegenContext): String = {
+ val initAgg = ctx.freshName("initAgg")
+ ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+ // create hashMap
+ val thisPlan = ctx.addReferenceObj("tungstenAggregate", this)
+ hashMapTerm = ctx.freshName("hashMap")
+ val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
+ ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm =
$thisPlan.createHashMap();")
+
+ // Create a name for iterator from HashMap
+ val iterTerm = ctx.freshName("mapIter")
+ ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName,
iterTerm, "")
+
+ // generate code for output
+ val keyTerm = ctx.freshName("aggKey")
+ val bufferTerm = ctx.freshName("aggBuffer")
+ val outputCode = if (modes.contains(Final) |modes.contains(Complete)) {
+ // generate output using resultExpressions
+ ctx.currentVars = null
+ ctx.INPUT_ROW = keyTerm
+ val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
+ BoundReference(i, e.dataType, e.nullable).gen(ctx)
+ }
+ ctx.INPUT_ROW = bufferTerm
+ val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
+ BoundReference(i, e.dataType, e.nullable).gen(ctx)
+ }
+ // evaluate the aggregation result
+ ctx.currentVars = bufferVars
+ val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
+ BindReferences.bindReference(e, bufferAttributes).gen(ctx)
+ }
+ // generate the final result
+ ctx.currentVars = keyVars ++ aggResults
+ val inputAttrs = groupingAttributes ++ aggregateAttributes
+ val resultVars = resultExpressions.map { e =>
+ BindReferences.bindReference(e, inputAttrs).gen(ctx)
+ }
+ s"""
+ ${keyVars.map(_.code).mkString("\n")}
+ ${bufferVars.map(_.code).mkString("\n")}
+ ${aggResults.map(_.code).mkString("\n")}
+ ${resultVars.map(_.code).mkString("\n")}
+
+ ${consume(ctx, resultVars)}
+ """
+
+ } else if (modes.contains(Partial) |modes.contains(PartialMerge)) {
+ // This should be the last operator in a stage, we should output
UnsafeRow directly
+ val joinerTerm = ctx.freshName("unsafeRowJoiner")
+ ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
+ s"$joinerTerm = $thisPlan.createUnsafeJoiner();")
+ val resultRow = ctx.freshName("resultRow")
+ s"""
+ UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
+ ${consume(ctx, null, resultRow)}
+ """
+
} else {
- functions.flatMap(_.mergeExpressions)
+ // only grouping key
+ ctx.INPUT_ROW = keyTerm
+ ctx.currentVars = null
+ val eval = resultExpressions.map{ e =>
+ BindReferences.bindReference(e, groupingAttributes).gen(ctx)
+ }
+ s"""
+ ${eval.map(_.code).mkString("\n")}
+ ${consume(ctx, eval)}
+ """
+ }
+
+ val doAgg = ctx.freshName("doAgg")
+ ctx.addNewFunction(doAgg,
+ s"""
+ private void $doAgg() throws java.io.IOException {
+ ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+
+ $iterTerm = $hashMapTerm.iterator();
+ }
+ """)
+
+ s"""
+ if (!$initAgg) {
+ $initAgg = true;
+ $doAgg();
+ }
+
+ // output the result
+ while ($iterTerm.next()) {
+ UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
--- End diff --
Do we need this row at all? This is the start of a new pipeline, right? (Not
something to address in this patch.)
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it enabled, or if the feature is enabled but not working,
please contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]