wuchong commented on a change in pull request #8001: [FLINK-11949][table-planner-blink] Introduce DeclarativeAggregateFunction and AggsHandlerCodeGenerator to blink planner URL: https://github.com/apache/flink/pull/8001#discussion_r266791171
########## File path: flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/agg/DistinctAggCodeGen.scala ########## @@ -0,0 +1,918 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.codegen.agg + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeInfo +import org.apache.flink.table.`type`.{InternalType, RowType, TypeConverters} +import org.apache.flink.table.api.TableException +import org.apache.flink.table.api.dataview.MapView +import org.apache.flink.table.codegen.CodeGenUtils.{BASE_ROW, newName, _} +import org.apache.flink.table.codegen.GenerateUtils.{generateFieldAccess, generateInputAccess} +import org.apache.flink.table.codegen.GeneratedExpression._ +import org.apache.flink.table.codegen.agg.AggsHandlerCodeGenerator._ +import org.apache.flink.table.codegen.{CodeGenUtils, CodeGeneratorContext, ExprCodeGenerator, GenerateUtils, GeneratedExpression} +import org.apache.flink.table.dataformat.GenericRow +import org.apache.flink.table.expressions.{Expression, RexNodeGenExpressionVisitor} +import org.apache.flink.table.plan.util.DistinctInfo +import 
org.apache.flink.util.Preconditions +import org.apache.flink.util.Preconditions.checkArgument + +import org.apache.calcite.tools.RelBuilder + +import java.lang.{Long => JLong} + +/** + * It is for code generate distinct aggregate. The distinct aggregate buffer is a MapView which + * is used to store the unique keys and the frequency of appearance. When a key is been seen the + * first time, we will trigger the inner aggregate function's accumulate() function. + * + * @param ctx the code gen context + * @param distinctInfo the distinct information + * @param distinctIndex the index of this distinct in all distincts + * @param innerAggCodeGens the code generator of inner aggregate + * @param filterExpressions filter argument access expression, none if no filter + * @param mergedAccOffset the mergedAcc may come from local aggregate, + * this is the first buffer offset in the row + * @param aggBufferOffset the offset in the buffers of this aggregate + * @param aggBufferSize the total size of aggregate buffers + * @param hasNamespace whether the accumulators state has namespace + * @param mergedAccOnHeap whether the merged accumulator is on heap, otherwise is on state + * @param consumeRetraction whether the distinct consumes retraction + * @param inputFieldCopy copy input field element if true (only mutable type will be copied) + * @param relBuilder the rel builder to translate expressions to calcite rex nodes + */ +class DistinctAggCodeGen( + ctx: CodeGeneratorContext, + distinctInfo: DistinctInfo, + distinctIndex: Int, + innerAggCodeGens: Array[AggCodeGen], + filterExpressions: Array[Option[Expression]], + mergedAccOffset: Int, + aggBufferOffset: Int, + aggBufferSize: Int, + hasNamespace: Boolean, + needMerge: Boolean, + mergedAccOnHeap: Boolean, + consumeRetraction: Boolean, + inputFieldCopy: Boolean, + relBuilder: RelBuilder) extends AggCodeGen { + + val MAP_VIEW: String = className[MapView[_, _]] + val MAP_ENTRY: String = className[java.util.Map.Entry[_, _]] + val 
ITERABLE: String = className[java.lang.Iterable[_]] + + val aggCount: Int = innerAggCodeGens.length + val externalAccType: TypeInformation[_] = distinctInfo.accType + val internalAccType: InternalType = createInternalTypeFromTypeInfo(externalAccType) + val keyType: TypeInformation[_] = distinctInfo.keyType + val internalKeyType: InternalType = createInternalTypeFromTypeInfo(keyType) + val keyTypeTerm: String = keyType.getTypeClass.getCanonicalName + val distinctAccTerm: String = s"distinct_view_$distinctIndex" + val distinctBackupAccTerm: String = s"distinct_backup_view_$distinctIndex" + + val isValueChangedTerm: String = s"is_distinct_value_changed_$distinctIndex" + val isValueEmptyTerm: String = s"is_distinct_value_empty_$distinctIndex" + val valueGenerator: DistinctValueGenerator = createDistinctValueGenerator() + private val rexNodeGen = new RexNodeGenExpressionVisitor(relBuilder) + + addReusableDistinctAccumulator() + + /** + * Add the distinct accumulator to the member variable and open close methods. 
+ */ + private def addReusableDistinctAccumulator(): Unit = { + // sanity check + if (distinctInfo.excludeAcc) { + // it only works in incremental mode when the distinct acc is excluded + // the distinct mapview must works on state mode when incremental mode + Preconditions.checkState(distinctInfo.dataViewSpec.nonEmpty) + } + + val enableBackupDataView = needMerge && !mergedAccOnHeap + + // add state mapview to member field + addReusableStateDataViews( + ctx, + distinctInfo.dataViewSpec.toArray, + hasNamespace, + enableBackupDataView) + + + // add distinctAccTerm to member field + ctx.addReusableMember(s"private $MAP_VIEW $distinctAccTerm;") + if (enableBackupDataView) { + ctx.addReusableMember(s"private $MAP_VIEW $distinctBackupAccTerm;") + } + + // when dataview works on state, assign the stateDataView to accTerm in open method + distinctInfo.dataViewSpec match { + case Some(spec) => + val dataviewTerm = createDataViewTerm(spec) + ctx.addReusableOpenStatement(s"$distinctAccTerm = $dataviewTerm;") + if (enableBackupDataView) { + val dataviewBackupTerm = createDataViewBackupTerm(spec) + ctx.addReusableOpenStatement(s"$distinctBackupAccTerm = $dataviewBackupTerm;") + } + case None => // do nothing + } + } + + override def createAccumulator(generator: ExprCodeGenerator): Seq[GeneratedExpression] = { + if (distinctInfo.excludeAcc) { + // when the distinct acc is excluded, no need to create distinct accumulator + Seq() + } else { + val accTerm = newName("distinct_acc") + val code = s"$MAP_VIEW $accTerm = new $MAP_VIEW();" + Seq(GeneratedExpression(accTerm, NEVER_NULL, code, internalAccType)) + } + } + + override def setAccumulator(generator: ExprCodeGenerator): String = { + generateAccumulatorAccess( + ctx, + generator.input1Type, + generator.input1Term, + aggBufferOffset, + useStateDataView = true, + useBackupDataView = false) + // return empty because the access code is set in ctx's ReusableInputUnboxingExprs + "" + } + + override def resetAccumulator(generator: 
ExprCodeGenerator): String = { + if (distinctInfo.excludeAcc) { + "" + } else { + s"$distinctAccTerm.clear();" + } + } + + override def getAccumulator(generator: ExprCodeGenerator): Seq[GeneratedExpression] = { + if (distinctInfo.excludeAcc) { + // when the distinct acc is excluded, the accumulator result shouldn't include distinct acc + Seq() + } else { + Seq(GeneratedExpression( + distinctAccTerm, + NEVER_NULL, + NO_CODE, + internalAccType)) + } + } + + override def accumulate(generator: ExprCodeGenerator): String = { + val keyExpr = generateKeyExpression(ctx, generator) + val key = keyExpr.resultTerm + val accumulateCode = innerAggCodeGens.map(_.accumulate(generator)) + val valueTerm = newName("value") + val valueTypeTerm = valueGenerator.valueTypeTerm + val filterResults = filterExpressions.map { + case None => None + case Some(f) => Some(generator.generateExpression(f.accept(rexNodeGen)).resultTerm) + } + + val head = + s""" + |${keyExpr.code} + |$valueTypeTerm $valueTerm = ($valueTypeTerm) $distinctAccTerm.get($key); + |if ($valueTerm == null) { + | $valueTerm = ${valueGenerator.initialValue}; + |} + """.stripMargin + + val body = if (consumeRetraction) { + // input contains retraction, due to local/global, the value might be empty, and need remove + s""" + |$head + |boolean $isValueEmptyTerm = true; + |${valueGenerator.foreachAccumulate(valueTerm, accumulateCode, filterResults)} + |if ($isValueEmptyTerm) { + | $distinctAccTerm.remove($key); + |} else { + | $distinctAccTerm.put($key, $valueTerm); + |} + """.stripMargin + } else { + // input contains only append messages, update value only when value changed + s""" + |$head + |boolean $isValueChangedTerm = false; + |${valueGenerator.foreachAccumulate(valueTerm, accumulateCode, filterResults)} + |if ($isValueChangedTerm) { + | $distinctAccTerm.put($key, $valueTerm); + |} + """.stripMargin + } + + if (filterResults.exists(_.isDefined)) { + val condition = filterResults.flatten.mkString(" || ") + s""" + |if 
($condition) { + | $body + |} + """.stripMargin + } else { + body + } + } + + override def retract(generator: ExprCodeGenerator): String = { + if (!consumeRetraction) { + throw new TableException("This should never happen, please file a issue.") + } + val keyExpr = generateKeyExpression(ctx, generator) + val key = keyExpr.resultTerm + val retractCodes = innerAggCodeGens.map(_.retract(generator)) + val valueTerm = newName("value") + val valueTypeTerm = valueGenerator.valueTypeTerm + val filterResults = filterExpressions.map { + case None => None + case Some(f) => Some(generator.generateExpression(f.accept(rexNodeGen)).resultTerm) + } + + val head = + s""" + |${keyExpr.code} + |$valueTypeTerm $valueTerm = ($valueTypeTerm) $distinctAccTerm.get($key); + |if ($valueTerm == null) { + | $valueTerm = ${valueGenerator.initialValue}; + |} + """.stripMargin + + val body = + s""" + |$head + |boolean $isValueEmptyTerm = true; + |${valueGenerator.foreachRetract(valueTerm, retractCodes, filterResults)} + |if ($isValueEmptyTerm) { + | $distinctAccTerm.remove($key); + |} else { + | $distinctAccTerm.put($key, $valueTerm); + |} + """.stripMargin + + if (filterResults.exists(_.isDefined)) { + val condition = filterResults.flatten.mkString(" || ") + s""" + |if ($condition) { + | $body + |} + """.stripMargin + } else { + body + } + } + + override def merge(generator: ExprCodeGenerator): String = { + // generate other MapView acc field + val otherAccExpr = generateAccumulatorAccess( + ctx, + generator.input1Type, + generator.input1Term, + mergedAccOffset + aggBufferOffset, + useStateDataView = !mergedAccOnHeap, + useBackupDataView = true) + + val keyTerm = newName(DISTINCT_KEY_TERM) + val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL) + .bindInput(internalKeyType, inputTerm = keyTerm) + val accumulateCodes = innerAggCodeGens.map(_.accumulate(exprGenerator)) + val retractCodes = if (consumeRetraction) { + innerAggCodeGens.map(_.retract(exprGenerator)) + } else { + 
innerAggCodeGens.map(_ => + "throw new RuntimeException(\"This distinct aggregate do not consume" + + " retractions, " + + "but received retract message, which should never happen.\");") + } + + val otherAccTerm = otherAccExpr.resultTerm + val otherEntries = newName("otherEntries") + val valueTypeTerm = valueGenerator.valueTypeTerm + val thisValue = "thisValue" + val otherValue = "otherValue" + + s""" + |$ITERABLE<$MAP_ENTRY> $otherEntries = ($ITERABLE<$MAP_ENTRY>) $otherAccTerm.entries(); + |if ($otherEntries != null) { + | for ($MAP_ENTRY entry: $otherEntries) { + | $keyTypeTerm $keyTerm = ($keyTypeTerm) entry.getKey(); + | ${ctx.reuseInputUnboxingCode(keyTerm)} + | $valueTypeTerm $otherValue = ($valueTypeTerm) entry.getValue(); + | $valueTypeTerm $thisValue = ($valueTypeTerm) $distinctAccTerm.get($keyTerm); + | if ($thisValue == null) { + | $thisValue = ${valueGenerator.initialValue}; + | } + | boolean $isValueChangedTerm = false; + | boolean $isValueEmptyTerm = false; + | ${valueGenerator.foreachMerge(thisValue, otherValue, accumulateCodes, retractCodes)} + | if ($isValueEmptyTerm) { + | $distinctAccTerm.remove($keyTerm); + | } else if ($isValueChangedTerm) { // value is not empty and is changed, do update + | $distinctAccTerm.put($keyTerm, $thisValue); + | } + | } // end foreach + |} // end otherEntries != null + """.stripMargin + } + + override def getValue(generator: ExprCodeGenerator): GeneratedExpression = { + throw new TableException( + "Distinct shouldn't return result value, this is a bug, please file a issue.") + } + + override def checkNeededMethods( + needAccumulate: Boolean, + needRetract: Boolean, + needMerge: Boolean, + needReset: Boolean): Unit = { + if (needMerge) { + // see merge method for more information + innerAggCodeGens + .foreach(_.checkNeededMethods(needAccumulate = true, needRetract = consumeRetraction)) + } else { + innerAggCodeGens + .foreach(_.checkNeededMethods(needAccumulate, needRetract, needMerge, needReset)) + } + } + + private 
def generateKeyExpression( + ctx: CodeGeneratorContext, + generator: ExprCodeGenerator): GeneratedExpression = { + var fieldExprs = distinctInfo.argIndexes.map(generateInputAccess( Review comment: var -> val ? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services
