lsyldliu commented on code in PR #22734:
URL: https://github.com/apache/flink/pull/22734#discussion_r1242113622
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/fusion/OpFusionCodegenSpec.java:
##########

@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.fusion;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
+import org.apache.flink.table.planner.codegen.ExprCodeGenerator;
+import org.apache.flink.table.planner.codegen.GeneratedExpression;
+
+import java.util.List;
+import java.util.Set;
+
+/** An interface for those physical operators that support operator fusion codegen. */
+@Internal
+public interface OpFusionCodegenSpec {
+
+    /**
+     * Initializes the operator spec. Sets access to the context. This method must be called before
+     * doProduce and doConsume related methods.
+     */
+    void setup(OpFusionContext opFusionContext);
+
+    /** Prefix used in the current operator's variable names. */
+    String variablePrefix();

Review Comment:
   Intuitively, it's fine to use class names. But class names would result in long variable names (e.g. a `hashJoinFusionCodegenSpec_` prefix instead of the short `bhj`/`shj` prefixes used below), so it makes sense to provide this interface.

##########
flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala:
##########

@@ -449,14 +449,17 @@ object GenerateUtils {
    *   whether the input is nullable
    * @param deepCopy
    *   whether to copy the accessed field (usually needed when buffered)
+   * @param fusionCodegen
+   *   if fusion codegen is enabled, the generated access code does not need to be hidden

Review Comment:
   Currently, in the single-operator codegen case, we materialize the access code for every input field up front, which avoids evaluating the same field multiple times inside the operator. In the multi-operator fusion codegen case the logic is reversed: one of the objectives is to compute as lazily as possible, i.e. to delay materialization. So we generate the code that accesses each field's value up front, but pass it to the downstream operator unevaluated and materialize it only where it is actually needed, instead of computing it in advance. For example, suppose a Join operator produces 5 fields from the probe and build sides together and passes them to a downstream Project operator that keeps only 2 of them; the 3 fields that are projected away never need to be evaluated at all. Eliminating such invalid computation is one of the goals of OFCG (operator fusion codegen).
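   To make this concrete, below is a hand-written sketch of the difference (this is not Flink's generated code; `java.util.function.Supplier` merely stands in for the generated per-field access expressions that fusion codegen passes downstream):

   ```java
   import java.util.function.Supplier;

   public class LazyFieldAccessSketch {

       // Single-operator codegen today: every input field is materialized up front,
       // whether or not the operator body ends up using it.
       static long eagerProject(long[] row) {
           long f0 = row[0];
           long f1 = row[1];
           long f2 = row[2]; // evaluated even though it is never used below
           return f0 + f1;
       }

       // Fusion codegen goal: pass the field access code downstream unevaluated,
       // so a field is computed only at its first real use.
       static long lazyProject(Supplier<Long> f0, Supplier<Long> f1, Supplier<Long> f2) {
           return f0.get() + f1.get(); // f2 is projected away and never evaluated
       }

       public static void main(String[] args) {
           long[] row = {1L, 2L, 3L};
           System.out.println(eagerProject(row)); // 3
           System.out.println(lazyProject(() -> row[0], () -> row[1], () -> row[2])); // 3
       }
   }
   ```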
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashJoin.java:
##########

@@ -293,4 +299,61 @@ private long getLargeManagedMemory(FlinkJoinType joinType, ExecNodeConfig config
         // large one
         return Math.max(hashJoinManagedMemory, sortMergeJoinManagedMemory);
     }
+
+    @Override
+    public boolean supportFusionCodegen() {
+        RowType leftType = (RowType) getInputEdges().get(0).getOutputType();
+        LogicalType[] keyFieldTypes =
+                IntStream.of(joinSpec.getLeftKeys())
+                        .mapToObj(leftType::getTypeAt)
+                        .toArray(LogicalType[]::new);
+        RowType keyType = RowType.of(keyFieldTypes);
+        FlinkJoinType joinType = joinSpec.getJoinType();
+        HashJoinType hashJoinType =
+                HashJoinType.of(
+                        leftIsBuild,
+                        joinType.isLeftOuter(),
+                        joinType.isRightOuter(),
+                        joinType == FlinkJoinType.SEMI,
+                        joinType == FlinkJoinType.ANTI);
+        // TODO support decimal and multiple keys, and all HashJoinTypes.
+        return LongHashJoinGenerator.support(hashJoinType, keyType, joinSpec.getFilterNulls());

Review Comment:
   Currently, we only support codegen when the join key is a single key of long type; everything else still goes through the hard-coded operator, so this first version follows that behavior. Supporting multi-key scenarios needs another PR; they are two separate pieces of work. (A simplified sketch of this check appears at the end of this message.)

##########
flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/fusion/spec/HashJoinFusionCodegenSpec.scala:
##########

@@ -0,0 +1,542 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.plan.fusion.spec
+
+import org.apache.flink.table.data.RowData
+import org.apache.flink.table.data.binary.BinaryRowData
+import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression, GenerateUtils}
+import org.apache.flink.table.planner.codegen.CodeGenUtils.{fieldIndices, newName, newNames, primitiveDefaultValue, primitiveTypeTermForType, BINARY_ROW, ROW_DATA}
+import org.apache.flink.table.planner.codegen.LongHashJoinGenerator.{genGetLongKey, genProjection}
+import org.apache.flink.table.planner.plan.fusion.{OpFusionCodegenSpecBase, OpFusionCodegenSpecGenerator, OpFusionContext}
+import org.apache.flink.table.planner.plan.nodes.exec.spec.JoinSpec
+import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.{toJava, toScala}
+import org.apache.flink.table.runtime.hashtable.LongHybridHashTable
+import org.apache.flink.table.runtime.operators.join.{FlinkJoinType, HashJoinType}
+import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer
+import org.apache.flink.table.runtime.util.RowIterator
+import org.apache.flink.table.types.logical.{LogicalType, RowType}
+
+import java.util
+
+/** Base operator fusion codegen spec for HashJoin. */
+class HashJoinFusionCodegenSpec(
+    operatorCtx: CodeGeneratorContext,
+    isBroadcast: Boolean,
+    leftIsBuild: Boolean,
+    joinSpec: JoinSpec,
+    estimatedLeftAvgRowSize: Int,
+    estimatedRightAvgRowSize: Int,
+    estimatedLeftRowCount: Long,
+    estimatedRightRowCount: Long,
+    compressionEnabled: Boolean,
+    compressionBlockSize: Int)
+  extends OpFusionCodegenSpecBase(operatorCtx) {
+
+  private lazy val joinType: FlinkJoinType = joinSpec.getJoinType
+  private lazy val hashJoinType: HashJoinType = HashJoinType.of(
+    leftIsBuild,
+    joinType.isLeftOuter,
+    joinType.isRightOuter,
+    joinType == FlinkJoinType.SEMI,
+    joinType == FlinkJoinType.ANTI)
+  private lazy val (buildKeys, probeKeys) = if (leftIsBuild) {
+    (joinSpec.getLeftKeys, joinSpec.getRightKeys)
+  } else {
+    (joinSpec.getRightKeys, joinSpec.getLeftKeys)
+  }
+  private lazy val (buildRowSize, buildRowCount) = if (leftIsBuild) {
+    (estimatedLeftAvgRowSize, estimatedLeftRowCount)
+  } else {
+    (estimatedRightAvgRowSize, estimatedRightRowCount)
+  }
+  private lazy val buildInputId = if (leftIsBuild) {
+    1
+  } else {
+    2
+  }
+
+  private lazy val Seq(buildToBinaryRow, probeToBinaryRow) =
+    newNames("buildToBinaryRow", "probeToBinaryRow")
+
+  private lazy val hashTableTerm: String = newName("hashTable")
+
+  private var buildInput: OpFusionCodegenSpecGenerator = _
+  private var probeInput: OpFusionCodegenSpecGenerator = _
+  private var buildType: RowType = _
+  private var probeType: RowType = _
+  private var keyType: RowType = _
+
+  override def setup(opFusionContext: OpFusionContext): Unit = {
+    super.setup(opFusionContext)
+    val inputs = toScala(fusionContext.getInputs)
+    assert(inputs.size == 2)
+    if (leftIsBuild) {
+      buildInput = inputs.head
+      probeInput = inputs(1)
+    } else {
+      buildInput = inputs(1)
+      probeInput = inputs.head
+    }
+
+    buildType = buildInput.getOutputType
+    probeType = probeInput.getOutputType
+    if (leftIsBuild) {
+      keyType = RowType.of(joinSpec.getLeftKeys.map(idx => buildType.getTypeAt(idx)): _*)
+    } else {
+      keyType = RowType.of(joinSpec.getLeftKeys.map(idx => probeType.getTypeAt(idx)): _*)
+    }
+  }
+
+  override def variablePrefix: String = if (isBroadcast) { "bhj" }
+  else { "shj" }
+
+  override protected def doProcessProduce(fusionCtx: CodeGeneratorContext): Unit = {
+    // call build side first, then call probe side
+    buildInput.processProduce(fusionCtx)
+    probeInput.processProduce(fusionCtx)
+  }
+
+  override protected def doEndInputProduce(fusionCtx: CodeGeneratorContext): Unit = {
+    // call build side first, then call probe side
+    buildInput.endInputProduce(fusionCtx)
+    probeInput.endInputProduce(fusionCtx)
+  }
+
+  override def doProcessConsume(
+      inputId: Int,
+      inputVars: util.List[GeneratedExpression],
+      row: GeneratedExpression): String = {
+    // only probe side will call the consumeProcess method to consume the output record
+    if (inputId == buildInputId) {
+      codegenBuild(toScala(inputVars), row)
+    } else {
+      codegenProbe(inputVars)
+    }
+  }
+
+  private def codegenBuild(
+      inputVars: Seq[GeneratedExpression],
+      row: GeneratedExpression): String = {
+    // initialize hash table related code
+    if (isBroadcast) {
+      codegenHashTable(false)
+    } else {
+      // TODO Shuffled HashJoin support build side spill to disk
+      codegenHashTable(true)
+    }
+
+    val (nullCheckBuildCode, nullCheckBuildTerm) = {
+      genAnyNullsInKeys(buildKeys, inputVars)
+    }
+    s"""
+       |$nullCheckBuildCode
+       |if (!$nullCheckBuildTerm) {
+       |  ${row.getCode}
+       |  $hashTableTerm.putBuildRow(($BINARY_ROW) ${row.resultTerm});
+       |}
+     """.stripMargin
+  }
+
+  private def codegenProbe(inputVars: util.List[GeneratedExpression]): String = {
+    hashJoinType match {
+      case HashJoinType.INNER =>
+        codegenInnerProbe(inputVars)
+      case HashJoinType.PROBE_OUTER => codegenProbeOuterProbe(inputVars)
+      case HashJoinType.SEMI => codegenSemiProbe(inputVars)
+      case HashJoinType.ANTI => codegenAntiProbe(inputVars)
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"Operator fusion codegen doesn't support $hashJoinType now.")
+    }
+  }
+
+  private def codegenInnerProbe(inputVars: util.List[GeneratedExpression]): String = {
+    val (keyEv, anyNull) = genStreamSideJoinKey(probeKeys, inputVars)
+    val keyCode = keyEv.getCode
+    val (matched, checkCondition, buildLocalVars, buildVars) = getJoinCondition(buildType)
+    val resultVars = if (leftIsBuild) {
+      buildVars ++ toScala(inputVars)
+    } else {
+      toScala(inputVars) ++ buildVars
+    }
+    val buildIterTerm = newName("buildIter")
+    s"""
+       |// generate join key for probe side
+       |$keyCode
+       |// find matches from hash table
+       |${classOf[RowIterator[_]].getCanonicalName} $buildIterTerm = $anyNull ?
+       |  null : $hashTableTerm.get(${keyEv.resultTerm});
+       |if ($buildIterTerm != null) {
+       |  $buildLocalVars
+       |  while ($buildIterTerm.advanceNext()) {
+       |    $ROW_DATA $matched = $buildIterTerm.getRow();
+       |    $checkCondition {
+       |      ${fusionContext.processConsume(toJava(resultVars))}
+       |    }
+       |  }
+       |}
+     """.stripMargin
+  }
+
+  private def codegenProbeOuterProbe(inputVars: util.List[GeneratedExpression]): String = {
+    val (keyEv, anyNull) = genStreamSideJoinKey(probeKeys, inputVars)
+    val keyCode = keyEv.getCode
+    val matched = newName("buildRow")
+    // start new local variable statement
+    getOperatorCtx.startNewLocalVariableStatement(matched)
+    val buildVars = genInputVars(matched, buildType)
+
+    // filter the output via the condition
+    val conditionPassed = newName("conditionPassed")
+    val checkCondition = if (joinSpec.getNonEquiCondition.isPresent) {
+      // here we need to bind the buildRow before generating the build condition
+      if (buildInputId == 1) {
+        getExprCodeGenerator.bindInput(buildType, matched)
+      } else {
+        getExprCodeGenerator.bindSecondInput(buildType, matched)
+      }
+      // TODO evaluate the variables from the probe and build side that are used by the condition in advance
+      // generate the expr code
+      val expr = getExprCodeGenerator.generateExpression(joinSpec.getNonEquiCondition.get)
+      s"""
+         |boolean $conditionPassed = true;
+         |if ($matched != null) {
+         |  ${expr.getCode}
+         |  $conditionPassed = !${expr.nullTerm} && ${expr.resultTerm};
+         |}
+       """.stripMargin
+    } else {
+      s"final boolean $conditionPassed = true;"
+    }
+
+    // generate the final result vars, which need to consider nulls for outer join
+    val buildResultVars = genProbeOuterBuildVars(matched, buildVars)
+    val resultVars = if (leftIsBuild) {
+      buildResultVars ++ toScala(inputVars)
+    } else {
+      toScala(inputVars) ++ buildResultVars
+    }
+    val buildIterTerm = newName("buildIter")
+    val found = newName("found")
+    val hasNext = newName("hasNext")
+    s"""
+       |// generate join key for probe side
+       |$keyCode
+       |
+       |boolean $found = false;
+       |boolean $hasNext = false;
+       |// find matches from hash table
+       |${classOf[RowIterator[_]].getCanonicalName} $buildIterTerm = $anyNull ?
+       |  null : $hashTableTerm.get(${keyEv.resultTerm});
+       |${getOperatorCtx.reuseLocalVariableCode(matched)}
+       |while (($buildIterTerm != null && ($hasNext = $buildIterTerm.advanceNext())) || !$found) {
+       |  $ROW_DATA $matched = $buildIterTerm != null && $hasNext ? $buildIterTerm.getRow() : null;
+       |  ${checkCondition.trim}
+       |  if ($conditionPassed) {
+       |    $found = true;
+       |    ${fusionContext.processConsume(toJava(resultVars))}
+       |  }
+       |}
+     """.stripMargin
+  }
+
+  private def codegenSemiProbe(inputVars: util.List[GeneratedExpression]): String = {
+    val (keyEv, anyNull) = genStreamSideJoinKey(probeKeys, inputVars)
+    val keyCode = keyEv.getCode
+    val (matched, checkCondition, buildLocalVars, _) = getJoinCondition(buildType)
+
+    val buildIterTerm = newName("buildIter")
+    s"""
+       |// generate join key for probe side
+       |$keyCode
+       |// find matches from hash table
+       |${classOf[RowIterator[_]].getCanonicalName} $buildIterTerm = $anyNull ?
+       |  null : $hashTableTerm.get(${keyEv.resultTerm});
+       |if ($buildIterTerm != null) {
+       |  $buildLocalVars
+       |  while ($buildIterTerm.advanceNext()) {
+       |    $ROW_DATA $matched = $buildIterTerm.getRow();
+       |    $checkCondition {
+       |      ${fusionContext.processConsume(inputVars)}
+       |      break;
+       |    }
+       |  }
+       |}
+     """.stripMargin
+  }
+
+  private def codegenAntiProbe(inputVars: util.List[GeneratedExpression]): String = {
+    val (keyEv, anyNull) = genStreamSideJoinKey(probeKeys, inputVars)
+    val keyCode = keyEv.getCode
+    val (matched, checkCondition, buildLocalVars, _) = getJoinCondition(buildType)
+
+    val buildIterTerm = newName("buildIter")
+    val found = newName("found")
+
+    s"""
+       |// generate join key for probe side
+       |$keyCode
+       |boolean $found = false;
+       |// find matches from hash table
+       |${classOf[RowIterator[_]].getCanonicalName} $buildIterTerm = $anyNull ?
+       |  null : $hashTableTerm.get(${keyEv.resultTerm});
+       |if ($buildIterTerm != null) {
+       |  $buildLocalVars
+       |  while ($buildIterTerm.advanceNext()) {
+       |    $ROW_DATA $matched = $buildIterTerm.getRow();
+       |    $checkCondition {
+       |      $found = true;
+       |      break;
+       |    }
+       |  }
+       |}
+       |
+       |if (!$found) {
+       |  ${fusionContext.processConsume(inputVars)}
+       |}
+     """.stripMargin
+  }
+
+  override def doEndInputConsume(inputId: Int): String = {
+    // If the hash table spilled to disk at runtime, the probe endInput also needs to
+    // consumeProcess to consume the spilled records
+    if (inputId == buildInputId) {
+      s"""
+         |LOG.info("Finish build phase.");
+         |$hashTableTerm.endBuild();
+       """.stripMargin
+    } else {
+      fusionContext.endInputConsume()
+    }
+  }
+
+  /**
+   * Returns the code for generating the join key for the stream side, and an expression of whether
+   * the key has any null in it or not.
+   */
+  protected def genStreamSideJoinKey(
+      probeKeyMapping: Array[Int],
+      inputVars: util.List[GeneratedExpression]): (GeneratedExpression, String) = {
+    // currently we only support a single join key of long type
+    if (probeKeyMapping.length == 1) {
+      // generate the join key as Long
+      val ev = inputVars.get(probeKeyMapping(0))
+      (ev, ev.nullTerm)
+    } else {
+      // generate the join key as BinaryRowData
+      throw new UnsupportedOperationException(
+        s"Operator fusion codegen doesn't support multiple join keys now.")
+    }
+  }
+
+  protected def genAnyNullsInKeys(
+      keyMapping: Array[Int],
+      input: Seq[GeneratedExpression]): (String, String) = {
+    val builder = new StringBuilder
+    val codeBuilder = new StringBuilder
+    val anyNullTerm = newName("anyNull")
+
+    keyMapping.foreach(
+      key => {
+        codeBuilder.append(input(key).getCode + "\n")
+        builder.append(s"$anyNullTerm |= ${input(key).nullTerm};")
+      })
+    (
+      s"""
+         |boolean $anyNullTerm = false;
+         |$codeBuilder
+         |$builder
+       """.stripMargin,
+      anyNullTerm)
+  }
+
+  protected def getJoinCondition(
+      buildType: RowType): (String, String, String, Seq[GeneratedExpression]) = {
+    val buildRow = newName("buildRow")
+    // here we need to bind the buildRow before generating the build condition
+    if (buildInputId == 1) {
+      getExprCodeGenerator.bindInput(buildType, buildRow)
+    } else {
+      getExprCodeGenerator.bindSecondInput(buildType, buildRow)
+    }
+
+    getOperatorCtx.startNewLocalVariableStatement(buildRow)
+    val buildVars = genInputVars(buildRow, buildType)
+    val checkCondition = if (joinSpec.getNonEquiCondition.isPresent) {
+      // bind the build row name again
+      val expr = getExprCodeGenerator.generateExpression(joinSpec.getNonEquiCondition.get)
+      val skipRow = s"${expr.nullTerm} || !${expr.resultTerm}"
+      s"""
+         |// generate join condition
+         |${expr.getCode}
+         |if (!($skipRow))
+       """.stripMargin
+    } else {
+      ""
+    }
+    val buildLocalVars =
+      if (hashJoinType.leftSemiOrAnti() && !joinSpec.getNonEquiCondition.isPresent) {
+        ""
+      } else {
+        getOperatorCtx.reuseLocalVariableCode(buildRow)
+      }
+
+    (buildRow, checkCondition, buildLocalVars, buildVars)
+  }
+
+  /** Generates build side exprs for outer join. */
+  protected def genProbeOuterBuildVars(
+      buildRow: String,
+      buildVars: Seq[GeneratedExpression]): Seq[GeneratedExpression] = {
+    buildVars.zipWithIndex.map {
+      case (expr, i) =>
+        val fieldType = buildType.getTypeAt(i)
+        val resultTypeTerm = primitiveTypeTermForType(fieldType)
+        val defaultValue = primitiveDefaultValue(fieldType)
+        val Seq(fieldTerm, nullTerm) =
+          getOperatorCtx.addReusableLocalVariables((resultTypeTerm, "field"), ("boolean", "isNull"))
+        val code =
+          s"""
+             |$nullTerm = true;
+             |$fieldTerm = $defaultValue;
+             |if ($buildRow != null) {
+             |  ${expr.getCode}
+             |  $nullTerm = ${expr.nullTerm};
+             |  $fieldTerm = ${expr.resultTerm};
+             |}
+           """.stripMargin
+        GeneratedExpression(fieldTerm, nullTerm, code, fieldType)
+    }
+  }
+
+  /** Generates the input row variable exprs. */
+  def genInputVars(inputRowTerm: String, inputType: RowType): Seq[GeneratedExpression] = {
+    val indices = fieldIndices(inputType)
+    val buildExprs = indices
+      .map(
+        index => GenerateUtils.generateFieldAccess(getOperatorCtx, inputType, inputRowTerm, index))
+      .toSeq
+    indices.foreach(
+      index =>
+        getOperatorCtx
+          .addReusableInputUnboxingExprs(inputRowTerm, index, buildExprs(index)))
+
+    buildExprs
+  }
+
+  override def usedInputVars(inputId: Int): util.Set[Integer] = {
+    if (inputId == buildInputId) {
+      val set: util.Set[Integer] = new util.HashSet[Integer]()
+      buildKeys.foreach(key => set.add(key))
+      set
+    } else {
+      super.usedInputVars(inputId)
+    }
+  }
+
+  override def getInputRowDataClass(inputId: Int): Class[_ <: RowData] = {
+    if (inputId == buildInputId) {
+      // For the build side, we wrap it as BinaryRowData
+      classOf[BinaryRowData]
+    } else {
+      super.getInputRowDataClass(inputId)
+    }
+  }
+
+  private def codegenHashTable(spillEnabled: Boolean): Unit = {

Review Comment:
   Yes. Since there is a relatively large difference between OFCG and single-operator codegen in how the code generation is done, code reuse is not practical as a first step. Also, due to the 1.18 code freeze there is not much time to work on code reuse now, but I believe we will get to it gradually in subsequent versions; after all, the cost of maintaining two sets of code is huge.
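   Returning to the `BatchExecHashJoin` comment above: here is a minimal, hypothetical sketch of the kind of "single long-type join key" check described there. The authoritative check is `LongHashJoinGenerator.support`, which also takes the hash join type and the filter-null flags into account; the method name and the exact set of accepted key types below are illustrative only:

   ```java
   import org.apache.flink.table.types.logical.LogicalTypeRoot;
   import org.apache.flink.table.types.logical.RowType;

   public class LongKeyCheckSketch {

       // Illustrative only: fusion codegen applies when the join key is exactly one
       // field whose type the long hash table can hold as a primitive long, e.g. BIGINT.
       static boolean supportsLongKeyCodegen(RowType keyType) {
           return keyType.getFieldCount() == 1
                   && keyType.getTypeAt(0).getTypeRoot() == LogicalTypeRoot.BIGINT;
       }
   }
   ```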
