KurtYoung commented on a change in pull request #8099: [FLINK-12081][table-planner-blink] Introduce aggregation operator code generator to blink batch URL: https://github.com/apache/flink/pull/8099#discussion_r272000977
########## File path: flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/agg/batch/AggCodeGenHelper.scala ########## @@ -0,0 +1,732 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.codegen.agg.batch + +import org.apache.flink.api.common.ExecutionConfig +import org.apache.flink.runtime.util.SingleElementIterator +import org.apache.flink.streaming.api.operators.OneInputStreamOperator +import org.apache.flink.table.`type`.TypeConverters.{createInternalTypeFromTypeInfo, createInternalTypeInfoFromInternalType} +import org.apache.flink.table.`type`.{ArrayType, InternalType, MapType, RowType, StringType} +import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForExternalType, genToExternal, genToInternal, newName, primitiveTypeTermForType} +import org.apache.flink.table.codegen.OperatorCodeGenerator.STREAM_RECORD +import org.apache.flink.table.codegen.{CodeGenUtils, CodeGeneratorContext, ExprCodeGenerator, GenerateUtils, GeneratedExpression, OperatorCodeGenerator} +import org.apache.flink.table.dataformat.{BaseRow, GenericRow} +import org.apache.flink.table.expressions.{CallExpression, Expression, ExpressionVisitor, FieldReferenceExpression, ResolvedAggInputReference, ResolvedAggLocalReference, RexNodeConverter, SymbolExpression, TypeLiteralExpression, UnresolvedFieldReferenceExpression, ValueLiteralExpression} +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getAccumulatorTypeOfAggregateFunction, getAggUserDefinedInputTypes, getResultTypeOfAggregateFunction} +import org.apache.flink.table.functions.{AggregateFunction, DeclarativeAggregateFunction, UserDefinedFunction} +import org.apache.flink.table.generated.{GeneratedAggsHandleFunction, GeneratedOperator} +import org.apache.flink.table.runtime.context.ExecutionContextImpl + +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rex.RexNode +import org.apache.calcite.tools.RelBuilder + +import scala.collection.JavaConverters._ + +/** + * Batch aggregate code generate helper. + */ +object AggCodeGenHelper { + + def getAggBufferNames( + auxGrouping: Array[Int], aggregates: Seq[UserDefinedFunction]): Array[Array[String]] = { + auxGrouping.zipWithIndex.map { + case (_, index) => Array(s"aux_group$index") + } ++ aggregates.zipWithIndex.toArray.map { + case (a: DeclarativeAggregateFunction, index) => + val idx = auxGrouping.length + index + a.aggBufferAttributes.map(attr => s"agg${idx}_${attr.getName}") + case (_: AggregateFunction[_, _], index) => + val idx = auxGrouping.length + index + Array(s"agg$idx") + } + } + + def getAggBufferTypes( + inputType: RowType, auxGrouping: Array[Int], aggregates: Seq[UserDefinedFunction]) + : Array[Array[InternalType]] = { + auxGrouping.map { index => + Array(inputType.getFieldTypes()(index)) + } ++ aggregates.map { + case a: DeclarativeAggregateFunction => a.getAggBufferTypes + case a: AggregateFunction[_, _] => + Array(createInternalTypeFromTypeInfo(getAccumulatorTypeOfAggregateFunction(a))) + }.toArray[Array[InternalType]] + } + + def getUdaggs( + aggregates: Seq[UserDefinedFunction]): Map[AggregateFunction[_, _], String] = { + aggregates + .filter(a => a.isInstanceOf[AggregateFunction[_, _]]) + .map(a => a -> CodeGenUtils.udfFieldName(a)).toMap + .asInstanceOf[Map[AggregateFunction[_, _], String]] + } + + def projectRowType( + rowType: RowType, mapping: Array[Int]): RowType = { + new RowType(mapping.map(rowType.getTypeAt), mapping.map(rowType.getFieldNames()(_))) + } + + /** + * Add agg handler to class member and open it. + */ + private[flink] def addAggsHandler( + aggsHandler: GeneratedAggsHandleFunction, + ctx: CodeGeneratorContext, + aggsHandlerCtx: CodeGeneratorContext): String = { + ctx.addReusableInnerClass(aggsHandler.getClassName, aggsHandler.getCode) + val handler = CodeGenUtils.newName("handler") + ctx.addReusableMember(s"${aggsHandler.getClassName} $handler = null;") + val aggRefers = ctx.addReusableObject(aggsHandlerCtx.references.toArray, "Object[]") + ctx.addReusableOpenStatement( + s""" + |$handler = new ${aggsHandler.getClassName}($aggRefers); + |$handler.open(new ${classOf[ExecutionContextImpl].getCanonicalName}( + | this, getRuntimeContext())); + """.stripMargin) + ctx.addReusableCloseStatement(s"$handler.close();") + handler + } + + private[flink] def projectRowType( + mapping: Array[Int], + inputT: RowType): RowType = + new RowType(mapping.map(inputT.getTypeAt), mapping.map(inputT.getFieldNames()(_))) + + /** + * The generated codes only supports the comparison of the key terms + * in the form of binary row with only one memory segment. + */ + private[flink] def genGroupKeyChangedCheckCode( + currentKeyTerm: String, + lastKeyTerm: String): String = { + s""" + |$currentKeyTerm.getSizeInBytes() != $lastKeyTerm.getSizeInBytes() || + | !(org.apache.flink.table.dataformat.util.BinaryRowUtil.byteArrayEquals( + | $currentKeyTerm.getSegments()[0].getHeapMemory(), + | $lastKeyTerm.getSegments()[0].getHeapMemory(), + | $currentKeyTerm.getSizeInBytes())) + """.stripMargin.trim + } + + def genSortAggCodes( + isMerge: Boolean, + isFinal: Boolean, + ctx: CodeGeneratorContext, + config: TableConfig, + builder: RelBuilder, + grouping: Array[Int], + auxGrouping: Array[Int], + aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)], + aggArgs: Array[Array[Int]], + aggregates: Seq[UserDefinedFunction], + udaggs: Map[AggregateFunction[_, _], String], + inputTerm: String, + inputType: RowType, + aggBufferNames: Array[Array[String]], + aggBufferTypes: Array[Array[InternalType]], + outputType: RowType, + forHashAgg: Boolean = false): (String, String, GeneratedExpression) = { + // gen code to apply aggregate functions to grouping elements + val argsMapping = buildAggregateArgsMapping( + isMerge, grouping.length, inputType, auxGrouping, aggArgs, aggBufferTypes) + + val aggBufferExprs = genFlatAggBufferExprs( + isMerge, + ctx, + config, + builder, + auxGrouping, + aggregates, + argsMapping, + aggBufferNames, + aggBufferTypes) + + val initAggBufferCode = genInitFlatAggregateBuffer( + ctx, + config, + builder, + inputType, + inputTerm, + grouping, + auxGrouping, + aggregates, + udaggs, + aggBufferExprs, + forHashAgg) + + val doAggregateCode = genAggregateByFlatAggregateBuffer( + isMerge, + ctx, + config, + builder, + inputType, + inputTerm, + auxGrouping, + aggCallToAggFunction, + aggregates, + udaggs, + argsMapping, + aggBufferNames, + aggBufferTypes, + aggBufferExprs) + + val aggOutputExpr = genSortAggOutputExpr( + isMerge, + isFinal, + ctx, + config, + builder, + grouping, + auxGrouping, + aggregates, + udaggs, + argsMapping, + aggBufferNames, + aggBufferTypes, + aggBufferExprs, + outputType) + + (initAggBufferCode, doAggregateCode, aggOutputExpr) + } + + /** + * Build an arg mapping for reference binding. The mapping will be a 2-dimension array. + * The first dimension represents the aggregate index, the order is same with agg calls in plan. + * The second dimension information represents input count of the aggregate. The meaning will + * be different depends on whether we should do merge. + * + * In non-merge case, aggregate functions will treat inputs as operands. In merge case, the + * input is local aggregation's buffer, we need to merge with our local aggregate buffers. + */ + private[flink] def buildAggregateArgsMapping( + isMerge: Boolean, + aggBufferOffset: Int, + inputType: RowType, + auxGrouping: Array[Int], + aggArgs: Array[Array[Int]], + aggBufferTypes: Array[Array[InternalType]]): Array[Array[(Int, InternalType)]] = { + + val auxGroupingMapping = auxGrouping.indices.map { + i => Array[(Int, InternalType)]((i, aggBufferTypes(i)(0))) + }.toArray + + val aggCallMapping = if (isMerge) { + var offset = aggBufferOffset + auxGrouping.length + aggBufferTypes.slice(auxGrouping.length, aggBufferTypes.length).map { types => + val baseOffset = offset + offset = offset + types.length + types.indices.map(index => (baseOffset + index, types(index))).toArray + } + } else { + aggArgs.map(args => args.map(i => (i, inputType.getTypeAt(i)))) + } + + auxGroupingMapping ++ aggCallMapping + } + + def newLocalReference( + ctx: CodeGeneratorContext, + resultTerm: String, + resultType: InternalType): ResolvedAggLocalReference = { + val nullTerm = resultTerm + "IsNull" + ctx.addReusableMember(s"${primitiveTypeTermForType(resultType)} $resultTerm;") + ctx.addReusableMember(s"boolean $nullTerm;") + new ResolvedAggLocalReference(resultTerm, nullTerm, resultType) + } + + /** + * Resolves the given expression to a resolved Expression. + * + * @param isMerge this is called from merge() method + */ + private case class ResolveReference( + ctx: CodeGeneratorContext, + isMerge: Boolean, + agg: DeclarativeAggregateFunction, + aggIndex: Int, + argsMapping: Array[Array[(Int, InternalType)]], + aggBufferTypes: Array[Array[InternalType]]) extends ExpressionVisitor[Expression] { + + override def visitCall(call: CallExpression): Expression = { + new CallExpression( + call.getFunctionDefinition, + call.getChildren.asScala.map(_.accept(this)).asJava) + } + + override def visitSymbol(symbolExpression: SymbolExpression): Expression = { + symbolExpression + } + + override def visitValueLiteral(valueLiteralExpression: ValueLiteralExpression): Expression = { + valueLiteralExpression + } + + override def visitFieldReference(input: FieldReferenceExpression): Expression = { + input + } + + override def visitTypeLiteral(typeLiteral: TypeLiteralExpression): Expression = { + typeLiteral + } + + private def visitUnresolvedFieldReference( + input: UnresolvedFieldReferenceExpression): Expression = { + agg.aggBufferAttributes.indexOf(input) match { + case -1 => + // We always use UnresolvedFieldReference to represent reference of input field. + // In non-merge case, the input is operand of the aggregate function. But in merge + // case, the input is aggregate buffers which sent by local aggregate. + val localIndex = if (isMerge) { + agg.mergeOperands.indexOf(input) + } else { + agg.operands.indexOf(input) + } + val (inputIndex, inputType) = argsMapping(aggIndex)(localIndex) + new ResolvedAggInputReference(input.getName, inputIndex, inputType) + case localIndex => + val variableName = s"agg${aggIndex}_${input.getName}" + newLocalReference( + ctx, variableName, aggBufferTypes(aggIndex)(localIndex)) + } + } + + override def visit(other: Expression): Expression = { + other match { + case u : UnresolvedFieldReferenceExpression => visitUnresolvedFieldReference(u) + case _ => other + } + } + } + + /** + * Declare all aggregate buffer variables, store these variables in class members + */ + private[flink] def genFlatAggBufferExprs( + isMerge: Boolean, + ctx: CodeGeneratorContext, + config: TableConfig, + builder: RelBuilder, + auxGrouping: Array[Int], + aggregates: Seq[UserDefinedFunction], + argsMapping: Array[Array[(Int, InternalType)]], + aggBufferNames: Array[Array[String]], + aggBufferTypes: Array[Array[InternalType]]): Seq[GeneratedExpression] = { + val exprCodegen = new ExprCodeGenerator(ctx, false) + val converter = new RexNodeConverter(builder) + + val accessAuxGroupingExprs = auxGrouping.indices.map { + idx => newLocalReference(ctx, aggBufferNames(idx)(0), aggBufferTypes(idx)(0)) + }.map(_.accept(converter)).map(exprCodegen.generateExpression) + + val aggCallExprs = aggregates.zipWithIndex.flatMap { + case (agg: DeclarativeAggregateFunction, aggIndex: Int) => + val idx = auxGrouping.length + aggIndex + agg.aggBufferAttributes.map(_.accept( + ResolveReference(ctx, isMerge, agg, idx, argsMapping, aggBufferTypes))) + case (_: AggregateFunction[_, _], aggIndex: Int) => + val idx = auxGrouping.length + aggIndex + val variableName = aggBufferNames(idx)(0) + Some(newLocalReference(ctx, variableName, aggBufferTypes(idx)(0))) + }.map(_.accept(converter)).map(exprCodegen.generateExpression) + + accessAuxGroupingExprs ++ aggCallExprs + } + + /** + * Generate codes which will init the aggregate buffer. + */ + private[flink] def genInitFlatAggregateBuffer( + ctx: CodeGeneratorContext, + config: TableConfig, + builder: RelBuilder, + inputType: RowType, + inputTerm: String, + grouping: Array[Int], + auxGrouping: Array[Int], + aggregates: Seq[UserDefinedFunction], + udaggs: Map[AggregateFunction[_, _], String], + aggBufferExprs: Seq[GeneratedExpression], + forHashAgg: Boolean = false): String = { + val exprCodegen = new ExprCodeGenerator(ctx, false) + .bindInput(inputType, inputTerm = inputTerm, inputFieldMapping = Some(auxGrouping)) + + val initAuxGroupingExprs = { + if (forHashAgg) { + // access fallbackInput + auxGrouping.indices.map(idx => idx + grouping.length).toArray + } else { + // access input + auxGrouping + } + }.map { idx => + GenerateUtils.generateFieldAccess(ctx, inputType, inputTerm, idx) + } + + val initAggCallBufferExprs = aggregates.flatMap { + case (agg: DeclarativeAggregateFunction) => + agg.initialValuesExpressions + case (agg: AggregateFunction[_, _]) => + Some(agg) + }.map { + case (expr: Expression) => expr.accept(new RexNodeConverter(builder)) + case t@_ => t + }.map { + case (rex: RexNode) => exprCodegen.generateExpression(rex) + case (agg: AggregateFunction[_, _]) => + val resultTerm = s"${udaggs(agg)}.createAccumulator()" + val nullTerm = "false" + val resultType = getAccumulatorTypeOfAggregateFunction(agg) + GeneratedExpression( + genToInternal(ctx, resultType, resultTerm), + nullTerm, + "", + createInternalTypeFromTypeInfo(resultType)) + } + + val initAggBufferExprs = initAuxGroupingExprs ++ initAggCallBufferExprs + require(aggBufferExprs.length == initAggBufferExprs.length) + + aggBufferExprs.zip(initAggBufferExprs).map { + case (aggBufVar, initExpr) => + val resultCode = aggBufVar.resultType match { + case _: StringType | _: RowType | _: ArrayType | _: MapType => + val serializer = createInternalTypeInfoFromInternalType(aggBufVar.resultType) + .createSerializer(new ExecutionConfig) + val term = ctx.addReusableObject( + serializer, "serializer", serializer.getClass.getCanonicalName) + s"$term.copy(${initExpr.resultTerm})" + case _ => initExpr.resultTerm + } + s""" + |${initExpr.code} + |${aggBufVar.nullTerm} = ${initExpr.nullTerm}; + |${aggBufVar.resultTerm} = $resultCode; + """.stripMargin.trim + } mkString "\n" + } + + private[flink] def genAggregateByFlatAggregateBuffer( + isMerge: Boolean, + ctx: CodeGeneratorContext, + config: TableConfig, + builder: RelBuilder, + inputType: RowType, + inputTerm: String, + auxGrouping: Array[Int], + aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)], + aggregates: Seq[UserDefinedFunction], + udaggs: Map[AggregateFunction[_, _], String], + argsMapping: Array[Array[(Int, InternalType)]], + aggBufferNames: Array[Array[String]], + aggBufferTypes: Array[Array[InternalType]], + aggBufferExprs: Seq[GeneratedExpression]): String = { + if (isMerge) { + genMergeFlatAggregateBuffer( + ctx, + config, + builder, + inputTerm, + inputType, + auxGrouping, + aggregates, + udaggs, + argsMapping, + aggBufferNames, + aggBufferTypes, + aggBufferExprs) + } else { + genAccumulateFlatAggregateBuffer( + ctx, + config, + builder, + inputTerm, + inputType, + auxGrouping, + aggCallToAggFunction, + udaggs, + argsMapping, + aggBufferNames, + aggBufferTypes, + aggBufferExprs) + } + } + + def genSortAggOutputExpr( + isMerge: Boolean, + isFinal: Boolean, + ctx: CodeGeneratorContext, + config: TableConfig, + builder: RelBuilder, + grouping: Array[Int], + auxGrouping: Array[Int], + aggregates: Seq[UserDefinedFunction], + udaggs: Map[AggregateFunction[_, _], String], + argsMapping: Array[Array[(Int, InternalType)]], + aggBufferNames: Array[Array[String]], + aggBufferTypes: Array[Array[InternalType]], + aggBufferExprs: Seq[GeneratedExpression], + outputType: RowType): GeneratedExpression = { + val valueRow = CodeGenUtils.newName("valueRow") + val resultCodegen = new ExprCodeGenerator(ctx, false) + if (isFinal) { + val getValueExprs = genGetValueFromFlatAggregateBuffer( + isMerge, + ctx, + config, + builder, + auxGrouping, + aggregates, + udaggs, + argsMapping, + aggBufferNames, + aggBufferTypes, + outputType) + val valueRowType = new RowType(getValueExprs.map(_.resultType): _*) + resultCodegen.generateResultExpression( + getValueExprs, valueRowType, classOf[GenericRow], valueRow) + } else { + val valueRowType = new RowType(aggBufferExprs.map(_.resultType): _*) + resultCodegen.generateResultExpression( + aggBufferExprs, valueRowType, classOf[GenericRow], valueRow) + } + } + + /** + * Generate expressions which will get final aggregate value from aggregate buffers. + */ + private[flink] def genGetValueFromFlatAggregateBuffer( + isMerge: Boolean, + ctx: CodeGeneratorContext, + config: TableConfig, + builder: RelBuilder, + auxGrouping: Array[Int], + aggregates: Seq[UserDefinedFunction], + udaggs: Map[AggregateFunction[_, _], String], + argsMapping: Array[Array[(Int, InternalType)]], + aggBufferNames: Array[Array[String]], + aggBufferTypes: Array[Array[InternalType]], + outputType: RowType): Seq[GeneratedExpression] = { + + val exprCodegen = new ExprCodeGenerator(ctx, false) + + val auxGroupingExprs = auxGrouping.indices.map { idx => + val resultTerm = aggBufferNames(idx)(0) + val nullTerm = s"${resultTerm}IsNull" + GeneratedExpression(resultTerm, nullTerm, "", aggBufferTypes(idx)(0)) + } + + val aggExprs = aggregates.zipWithIndex.map { + case (agg: DeclarativeAggregateFunction, aggIndex) => + val idx = auxGrouping.length + aggIndex + agg.getValueExpression.accept(ResolveReference( + ctx, isMerge, agg, idx, argsMapping, aggBufferTypes)) + case (agg: AggregateFunction[_, _], aggIndex) => + val idx = auxGrouping.length + aggIndex + (agg, idx) + }.map { + case (expr: Expression) => expr.accept(new RexNodeConverter(builder)) + case t@_ => t + }.map { + case (rex: RexNode) => exprCodegen.generateExpression(rex) + case (agg: AggregateFunction[_, _], aggIndex: Int) => + val resultType = getResultTypeOfAggregateFunction(agg) + val accType = getAccumulatorTypeOfAggregateFunction(agg) + val resultTerm = genToInternal(ctx, resultType, + s"${udaggs(agg)}.getValue(${genToExternal(ctx, accType, aggBufferNames(aggIndex)(0))})") + val nullTerm = s"${aggBufferNames(aggIndex)(0)}IsNull" + GeneratedExpression(resultTerm, nullTerm, "", createInternalTypeFromTypeInfo(resultType)) + } + + auxGroupingExprs ++ aggExprs + } + + /** + * Generate codes which will read input and merge the aggregate buffers. + */ + private[flink] def genMergeFlatAggregateBuffer( + ctx: CodeGeneratorContext, + config: TableConfig, + builder: RelBuilder, + inputTerm: String, + inputType: RowType, + auxGrouping: Array[Int], + aggregates: Seq[UserDefinedFunction], + udaggs: Map[AggregateFunction[_, _], String], + argsMapping: Array[Array[(Int, InternalType)]], + aggBufferNames: Array[Array[String]], + aggBufferTypes: Array[Array[InternalType]], + aggBufferExprs: Seq[GeneratedExpression]): String = { + Review comment: delete blank line ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
