lsyldliu commented on code in PR #22734:
URL: https://github.com/apache/flink/pull/22734#discussion_r1242113622


##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/fusion/OpFusionCodegenSpec.java:
##########
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.fusion;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
+import org.apache.flink.table.planner.codegen.ExprCodeGenerator;
+import org.apache.flink.table.planner.codegen.GeneratedExpression;
+
+import java.util.List;
+import java.util.Set;
+
+/** An interface for those physical operators that support operator fusion codegen. */
+@Internal
+public interface OpFusionCodegenSpec {
+
+    /**
+     * Initializes the operator spec. Sets access to the context. This method must be called before
+     * doProduce and doConsume related methods.
+     */
+    void setup(OpFusionContext opFusionContext);
+
+    /** Prefix used in the current operator's variable names. */
+    String variablePrefix();

Review Comment:
   Using class names as the prefix would work in principle, but it would make the generated variable names long. So it makes sense to provide this method and let each operator choose a short prefix (for example, `bhj` or `shj` for the hash joins).
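   To make the trade-off concrete, here is a minimal sketch of prefix-based name generation. `NameSketch` and its `prefix$base$N` format are hypothetical and only loosely modeled on a unique-name helper such as `CodeGenUtils.newName`; Flink's actual naming scheme may differ.

```java
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical illustration only; not Flink's CodeGenUtils.newName. */
final class NameSketch {
    // Shared counter so every generated name is unique within one generated class.
    private static final AtomicInteger COUNTER = new AtomicInteger();

    private NameSketch() {}

    /** Builds a unique variable name of the (assumed) form prefix$base$N. */
    static String newName(String prefix, String base) {
        return prefix + "$" + base + "$" + COUNTER.getAndIncrement();
    }

    /** Resets the counter, e.g. before generating a new class. */
    static void reset() {
        COUNTER.set(0);
    }
}
```

   With a short prefix, `newName("bhj", "hashTable")` yields `bhj$hashTable$0`, whereas a class-name prefix yields `HashJoinFusionCodegenSpec$hashTable$1`. Across the many variables a fused pipeline generates, the shorter form keeps the generated source noticeably more readable.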



##########
flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala:
##########
@@ -449,14 +449,17 @@ object GenerateUtils {
    *   whether the input is nullable
    * @param deepCopy
    *   whether to copy the accessed field (usually needed when buffered)
+   * @param fusionCodegen
+   *   whether fusion codegen is enabled; if so, the generated code does not need to be hidden

Review Comment:
   Currently, in single-operator codegen, we materialize the code that reads each input field up front, so that a field is not evaluated multiple times inside the operator.
   In multi-operator fusion codegen, however, the computation flow is reversed, and one of our goals is to compute as lazily as possible, i.e., to delay materialization. So we obtain the code that accesses each field's value in advance, pass it to the downstream operator, and materialize it only where it is actually needed instead of evaluating it eagerly. For example, suppose the probe side and build side of a Join operator contribute 5 fields in total, and the downstream Project operator keeps only 2 of them. We never need to read the 3 pruned fields at all. Eliminating this kind of useless computation is one of the goals of OFCG.
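   The difference between eager and lazy field access can be sketched as a plain-Java runtime analogy. This is hypothetical and not Flink code: the real generator passes `GeneratedExpression` code snippets downstream, whereas here a `java.util.function.Supplier` stands in for a not-yet-evaluated field access, and `CountingRow` is an invented row type that counts reads. The scenario mirrors the Join-then-Project case above: 5 joined fields, 2 kept.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

/** Hypothetical runtime analogy of eager vs. lazy field materialization. */
final class LazyMaterializationSketch {
    private LazyMaterializationSketch() {}

    /** Simulated joined row whose field reads are counted. */
    static final class CountingRow {
        private final long[] values;
        int reads = 0;

        CountingRow(long... values) {
            this.values = values;
        }

        int arity() {
            return values.length;
        }

        long getLong(int pos) {
            reads++;
            return values[pos];
        }
    }

    /** Single-operator style: every field is read up front, then projected. */
    static List<Long> eagerProject(CountingRow joined, int[] keep) {
        List<Long> all = new ArrayList<>();
        for (int i = 0; i < joined.arity(); i++) {
            all.add(joined.getLong(i));
        }
        List<Long> out = new ArrayList<>();
        for (int k : keep) {
            out.add(all.get(k));
        }
        return out;
    }

    /** Fusion style: accessors are handed downstream; only kept fields are ever read. */
    static List<Long> lazyProject(CountingRow joined, int[] keep) {
        List<Supplier<Long>> accessors = new ArrayList<>();
        for (int i = 0; i < joined.arity(); i++) {
            final int pos = i;
            accessors.add(() -> joined.getLong(pos)); // nothing is read yet
        }
        List<Long> out = new ArrayList<>();
        for (int k : keep) {
            out.add(accessors.get(k).get()); // materialize only where needed
        }
        return out;
    }
}
```

   With `keep = {0, 3}` on a 5-field row, both versions return the same two values, but the eager version performs 5 reads and the lazy version only 2.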



##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashJoin.java:
##########
@@ -293,4 +299,61 @@ private long getLargeManagedMemory(FlinkJoinType joinType, ExecNodeConfig config
         // large one
         return Math.max(hashJoinManagedMemory, sortMergeJoinManagedMemory);
     }
+
+    @Override
+    public boolean supportFusionCodegen() {
+        RowType leftType = (RowType) getInputEdges().get(0).getOutputType();
+        LogicalType[] keyFieldTypes =
+                IntStream.of(joinSpec.getLeftKeys())
+                        .mapToObj(leftType::getTypeAt)
+                        .toArray(LogicalType[]::new);
+        RowType keyType = RowType.of(keyFieldTypes);
+        FlinkJoinType joinType = joinSpec.getJoinType();
+        HashJoinType hashJoinType =
+                HashJoinType.of(
+                        leftIsBuild,
+                        joinType.isLeftOuter(),
+                        joinType.isRightOuter(),
+                        joinType == FlinkJoinType.SEMI,
+                        joinType == FlinkJoinType.ANTI);
+        // TODO decimal and multiKeys support and all HashJoinType support.
+        return LongHashJoinGenerator.support(hashJoinType, keyType, joinSpec.getFilterNulls());

Review Comment:
   Currently, the hash join is code-generated only when there is a single join key of long type; in all other cases the hand-written (non-codegen) operator is used. The first version of OFCG follows the same behavior. Supporting multiple keys is separate work that will be done in another PR.
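   A minimal sketch of the kind of support check described above, under stated assumptions: the enums `HashJoinKind` and `KeyTypeRoot` are hypothetical stand-ins, the four accepted join kinds are the ones the quoted `codegenProbe` handles, and the accepted key types are an illustrative guess at types whose values fit in a Java `long`. This is not the real `LongHashJoinGenerator.support`, whose exact rules (for example, around null filtering or further key types) may differ.

```java
import java.util.EnumSet;
import java.util.Set;

/** Hypothetical, simplified fusion-support check; not Flink's LongHashJoinGenerator.support. */
final class FusionSupportSketch {
    enum HashJoinKind { INNER, BUILD_OUTER, PROBE_OUTER, FULL_OUTER, SEMI, ANTI }

    enum KeyTypeRoot { BIGINT, INTEGER, SMALLINT, TINYINT, VARCHAR, DECIMAL }

    // Join shapes the fused probe code handles: inner, probe-outer, semi, anti.
    private static final Set<HashJoinKind> SUPPORTED_JOINS =
            EnumSet.of(HashJoinKind.INNER, HashJoinKind.PROBE_OUTER, HashJoinKind.SEMI, HashJoinKind.ANTI);

    // Assumption: key types whose values can be held in a Java long.
    private static final Set<KeyTypeRoot> LONG_KEY_TYPES =
            EnumSet.of(KeyTypeRoot.BIGINT, KeyTypeRoot.INTEGER, KeyTypeRoot.SMALLINT, KeyTypeRoot.TINYINT);

    private FusionSupportSketch() {}

    /** Returns true only for a supported join kind with exactly one long-compatible key. */
    static boolean support(HashJoinKind joinType, KeyTypeRoot... keyTypes) {
        return SUPPORTED_JOINS.contains(joinType)
                && keyTypes.length == 1 // multi-key joins fall back to the non-codegen operator
                && LONG_KEY_TYPES.contains(keyTypes[0]);
    }
}
```

   Anything rejected here would keep using the existing operator, so extending multi-key or decimal support later only widens this predicate without changing the fallback path.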



##########
flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/fusion/spec/HashJoinFusionCodegenSpec.scala:
##########
@@ -0,0 +1,542 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.plan.fusion.spec
+
+import org.apache.flink.table.data.RowData
+import org.apache.flink.table.data.binary.BinaryRowData
+import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression, GenerateUtils}
+import org.apache.flink.table.planner.codegen.CodeGenUtils.{fieldIndices, newName, newNames, primitiveDefaultValue, primitiveTypeTermForType, BINARY_ROW, ROW_DATA}
+import org.apache.flink.table.planner.codegen.LongHashJoinGenerator.{genGetLongKey, genProjection}
+import org.apache.flink.table.planner.plan.fusion.{OpFusionCodegenSpecBase, OpFusionCodegenSpecGenerator, OpFusionContext}
+import org.apache.flink.table.planner.plan.nodes.exec.spec.JoinSpec
+import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.{toJava, toScala}
+import org.apache.flink.table.runtime.hashtable.LongHybridHashTable
+import org.apache.flink.table.runtime.operators.join.{FlinkJoinType, HashJoinType}
+import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer
+import org.apache.flink.table.runtime.util.RowIterator
+import org.apache.flink.table.types.logical.{LogicalType, RowType}
+
+import java.util
+
+/** Base operator fusion codegen spec for HashJoin. */
+class HashJoinFusionCodegenSpec(
+    operatorCtx: CodeGeneratorContext,
+    isBroadcast: Boolean,
+    leftIsBuild: Boolean,
+    joinSpec: JoinSpec,
+    estimatedLeftAvgRowSize: Int,
+    estimatedRightAvgRowSize: Int,
+    estimatedLeftRowCount: Long,
+    estimatedRightRowCount: Long,
+    compressionEnabled: Boolean,
+    compressionBlockSize: Int)
+  extends OpFusionCodegenSpecBase(operatorCtx) {
+
+  private lazy val joinType: FlinkJoinType = joinSpec.getJoinType
+  private lazy val hashJoinType: HashJoinType = HashJoinType.of(
+    leftIsBuild,
+    joinType.isLeftOuter,
+    joinType.isRightOuter,
+    joinType == FlinkJoinType.SEMI,
+    joinType == FlinkJoinType.ANTI)
+  private lazy val (buildKeys, probeKeys) = if (leftIsBuild) {
+    (joinSpec.getLeftKeys, joinSpec.getRightKeys)
+  } else {
+    (joinSpec.getRightKeys, joinSpec.getLeftKeys)
+  }
+  private lazy val (buildRowSize, buildRowCount) = if (leftIsBuild) {
+    (estimatedLeftAvgRowSize, estimatedLeftRowCount)
+  } else {
+    (estimatedRightAvgRowSize, estimatedRightRowCount)
+  }
+  private lazy val buildInputId = if (leftIsBuild) {
+    1
+  } else {
+    2
+  }
+
+  private lazy val Seq(buildToBinaryRow, probeToBinaryRow) =
+    newNames("buildToBinaryRow", "probeToBinaryRow")
+
+  private lazy val hashTableTerm: String = newName("hashTable")
+
+  private var buildInput: OpFusionCodegenSpecGenerator = _
+  private var probeInput: OpFusionCodegenSpecGenerator = _
+  private var buildType: RowType = _
+  private var probeType: RowType = _
+  private var keyType: RowType = _
+
+  override def setup(opFusionContext: OpFusionContext): Unit = {
+    super.setup(opFusionContext)
+    val inputs = toScala(fusionContext.getInputs)
+    assert(inputs.size == 2)
+    if (leftIsBuild) {
+      buildInput = inputs.head
+      probeInput = inputs(1)
+    } else {
+      buildInput = inputs(1)
+      probeInput = inputs.head
+    }
+
+    buildType = buildInput.getOutputType
+    probeType = probeInput.getOutputType
+    if (leftIsBuild) {
+      keyType = RowType.of(joinSpec.getLeftKeys.map(idx => buildType.getTypeAt(idx)): _*)
+    } else {
+      keyType = RowType.of(joinSpec.getLeftKeys.map(idx => probeType.getTypeAt(idx)): _*)
+    }
+  }
+
+  override def variablePrefix: String = if (isBroadcast) { "bhj" }
+  else { "shj" }
+
+  override protected def doProcessProduce(fusionCtx: CodeGeneratorContext): Unit = {
+    // call build side first, then call probe side
+    buildInput.processProduce(fusionCtx)
+    probeInput.processProduce(fusionCtx)
+  }
+
+  override protected def doEndInputProduce(fusionCtx: CodeGeneratorContext): Unit = {
+    // call build side first, then call probe side
+    buildInput.endInputProduce(fusionCtx)
+    probeInput.endInputProduce(fusionCtx)
+  }
+
+  override def doProcessConsume(
+      inputId: Int,
+      inputVars: util.List[GeneratedExpression],
+      row: GeneratedExpression): String = {
+    // only probe side will call the consumeProcess method to consume the output record
+    if (inputId == buildInputId) {
+      codegenBuild(toScala(inputVars), row)
+    } else {
+      codegenProbe(inputVars)
+    }
+  }
+
+  private def codegenBuild(
+      inputVars: Seq[GeneratedExpression],
+      row: GeneratedExpression): String = {
+    // initialize hash table related code
+    if (isBroadcast) {
+      codegenHashTable(false)
+    } else {
+      // TODO Shuffled HashJoin support build side spill to disk
+      codegenHashTable(true)
+    }
+
+    val (nullCheckBuildCode, nullCheckBuildTerm) = {
+      genAnyNullsInKeys(buildKeys, inputVars)
+    }
+    s"""
+       |$nullCheckBuildCode
+       |if (!$nullCheckBuildTerm) {
+       |  ${row.getCode}
+       |  $hashTableTerm.putBuildRow(($BINARY_ROW) ${row.resultTerm});
+       |}
+       """.stripMargin
+  }
+
+  private def codegenProbe(inputVars: util.List[GeneratedExpression]): String = {
+    hashJoinType match {
+      case HashJoinType.INNER =>
+        codegenInnerProbe(inputVars)
+      case HashJoinType.PROBE_OUTER => codegenProbeOuterProbe(inputVars)
+      case HashJoinType.SEMI => codegenSemiProbe(inputVars)
+      case HashJoinType.ANTI => codegenAntiProbe(inputVars)
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"Operator fusion codegen doesn't support $hashJoinType now.")
+    }
+  }
+
+  private def codegenInnerProbe(inputVars: util.List[GeneratedExpression]): String = {
+    val (keyEv, anyNull) = genStreamSideJoinKey(probeKeys, inputVars)
+    val keyCode = keyEv.getCode
+    val (matched, checkCondition, buildLocalVars, buildVars) = getJoinCondition(buildType)
+    val resultVars = if (leftIsBuild) {
+      buildVars ++ toScala(inputVars)
+    } else {
+      toScala(inputVars) ++ buildVars
+    }
+    val buildIterTerm = newName("buildIter")
+    s"""
+       |// generate join key for probe side
+       |$keyCode
+       |// find matches from hash table
+       |${classOf[RowIterator[_]].getCanonicalName} $buildIterTerm = $anyNull ?
+       |  null : $hashTableTerm.get(${keyEv.resultTerm});
+       |if ($buildIterTerm != null ) {
+       |  $buildLocalVars
+       |  while ($buildIterTerm.advanceNext()) {
+       |    $ROW_DATA $matched = $buildIterTerm.getRow();
+       |    $checkCondition {
+       |      ${fusionContext.processConsume(toJava(resultVars))}
+       |    }
+       |  }
+       |}
+           """.stripMargin
+  }
+
+  private def codegenProbeOuterProbe(inputVars: util.List[GeneratedExpression]): String = {
+    val (keyEv, anyNull) = genStreamSideJoinKey(probeKeys, inputVars)
+    val keyCode = keyEv.getCode
+    val matched = newName("buildRow")
+    // start new local variable
+    getOperatorCtx.startNewLocalVariableStatement(matched)
+    val buildVars = genInputVars(matched, buildType)
+
+    // filter the output via condition
+    val conditionPassed = newName("conditionPassed")
+    val checkCondition = if (joinSpec.getNonEquiCondition.isPresent) {
+      // here need bind the buildRow before generate build condition
+      if (buildInputId == 1) {
+        getExprCodeGenerator.bindInput(buildType, matched)
+      } else {
+        getExprCodeGenerator.bindSecondInput(buildType, matched)
+      }
+      // TODO evaluate the variables from probe and build side that used by condition in advance
+      // generate the expr code
+      val expr = getExprCodeGenerator.generateExpression(joinSpec.getNonEquiCondition.get)
+      s"""
+         |boolean $conditionPassed = true;
+         |if ($matched != null) {
+         |  ${expr.getCode}
+         |  $conditionPassed = !${expr.nullTerm} && ${expr.resultTerm};
+         |}
+       """.stripMargin
+    } else {
+      s"final boolean $conditionPassed = true;"
+    }
+
+    // generate the final result vars that need to consider the null for outer join
+    val buildResultVars = genProbeOuterBuildVars(matched, buildVars)
+    val resultVars = if (leftIsBuild) {
+      buildResultVars ++ toScala(inputVars)
+    } else {
+      toScala(inputVars) ++ buildResultVars
+    }
+    val buildIterTerm = newName("buildIter")
+    val found = newName("found")
+    val hasNext = newName("hasNext")
+    s"""
+       |// generate join key for probe side
+       |$keyCode
+       |
+       |boolean $found = false;
+       |boolean $hasNext = false;
+       |// find matches from hash table
+       |${classOf[RowIterator[_]].getCanonicalName} $buildIterTerm = $anyNull ?
+       |  null : $hashTableTerm.get(${keyEv.resultTerm});
+       |${getOperatorCtx.reuseLocalVariableCode(matched)}
+       |while (($buildIterTerm != null && ($hasNext = $buildIterTerm.advanceNext())) || !$found) {
+       |  $ROW_DATA $matched = $buildIterTerm != null && $hasNext ? $buildIterTerm.getRow() : null;
+       |  ${checkCondition.trim}
+       |  if ($conditionPassed) {
+       |    $found = true;
+       |    ${fusionContext.processConsume(toJava(resultVars))}
+       |  }
+       |}
+           """.stripMargin
+  }
+
+  private def codegenSemiProbe(inputVars: util.List[GeneratedExpression]): String = {
+    val (keyEv, anyNull) = genStreamSideJoinKey(probeKeys, inputVars)
+    val keyCode = keyEv.getCode
+    val (matched, checkCondition, buildLocalVars, _) = getJoinCondition(buildType)
+
+    val buildIterTerm = newName("buildIter")
+    s"""
+       |// generate join key for probe side
+       |$keyCode
+       |// find matches from hash table
+       |${classOf[RowIterator[_]].getCanonicalName} $buildIterTerm = $anyNull ?
+       |  null : $hashTableTerm.get(${keyEv.resultTerm});
+       |if ($buildIterTerm != null ) {
+       |  $buildLocalVars
+       |  while ($buildIterTerm.advanceNext()) {
+       |    $ROW_DATA $matched = $buildIterTerm.getRow();
+       |    $checkCondition {
+       |      ${fusionContext.processConsume(inputVars)}
+       |      break;
+       |    }
+       |  }
+       |}
+           """.stripMargin
+  }
+
+  private def codegenAntiProbe(inputVars: util.List[GeneratedExpression]): String = {
+    val (keyEv, anyNull) = genStreamSideJoinKey(probeKeys, inputVars)
+    val keyCode = keyEv.getCode
+    val (matched, checkCondition, buildLocalVars, _) = getJoinCondition(buildType)
+
+    val buildIterTerm = newName("buildIter")
+    val found = newName("found")
+
+    s"""
+       |// generate join key for probe side
+       |$keyCode
+       |boolean $found = false;
+       |// find matches from hash table
+       |${classOf[RowIterator[_]].getCanonicalName} $buildIterTerm = $anyNull ?
+       |  null : $hashTableTerm.get(${keyEv.resultTerm});
+       |if ($buildIterTerm != null ) {
+       |  $buildLocalVars
+       |  while ($buildIterTerm.advanceNext()) {
+       |    $ROW_DATA $matched = $buildIterTerm.getRow();
+       |    $checkCondition {
+       |      $found = true;
+       |      break;
+       |    }
+       |  }
+       |}
+       |
+       |if (!$found) {
+       |  ${fusionContext.processConsume(inputVars)}
+       |}
+           """.stripMargin
+  }
+
+  override def doEndInputConsume(inputId: Int): String = {
+    // If the hash table spill to disk during runtime, the probe endInput also need to
+    // consumeProcess to consume the spilled record
+    if (inputId == buildInputId) {
+      s"""
+         |LOG.info("Finish build phase.");
+         |$hashTableTerm.endBuild();
+       """.stripMargin
+    } else {
+      fusionContext.endInputConsume()
+    }
+  }
+
+  /**
+   * Returns the code for generating join key for stream side, and expression of whether the key has
+   * any null in it or not.
+   */
+  protected def genStreamSideJoinKey(
+      probeKeyMapping: Array[Int],
+      inputVars: util.List[GeneratedExpression]): (GeneratedExpression, String) = {
+    // current only support one join key which is long type
+    if (probeKeyMapping.length == 1) {
+      // generate the join key as Long
+      val ev = inputVars.get(probeKeyMapping(0))
+      (ev, ev.nullTerm)
+    } else {
+      // generate the join key as BinaryRowData
+      throw new UnsupportedOperationException(
+        s"Operator fusion codegen doesn't support multiple join keys now.")
+    }
+  }
+
+  protected def genAnyNullsInKeys(
+      keyMapping: Array[Int],
+      input: Seq[GeneratedExpression]): (String, String) = {
+    val builder = new StringBuilder
+    val codeBuilder = new StringBuilder
+    val anyNullTerm = newName("anyNull")
+
+    keyMapping.foreach(
+      key => {
+        codeBuilder.append(input(key).getCode + "\n")
+        builder.append(s"$anyNullTerm |= ${input(key).nullTerm};")
+      })
+    (
+      s"""
+         |boolean $anyNullTerm = false;
+         |$codeBuilder
+         |$builder
+     """.stripMargin,
+      anyNullTerm)
+  }
+
+  protected def getJoinCondition(
+      buildType: RowType): (String, String, String, Seq[GeneratedExpression]) = {
+    val buildRow = newName("buildRow")
+    // here need bind the buildRow before generate build condition
+    if (buildInputId == 1) {
+      getExprCodeGenerator.bindInput(buildType, buildRow)
+    } else {
+      getExprCodeGenerator.bindSecondInput(buildType, buildRow)
+    }
+
+    getOperatorCtx.startNewLocalVariableStatement(buildRow)
+    val buildVars = genInputVars(buildRow, buildType)
+    val checkCondition = if (joinSpec.getNonEquiCondition.isPresent) {
+      // bind the build row name again
+      val expr = getExprCodeGenerator.generateExpression(joinSpec.getNonEquiCondition.get)
+      val skipRow = s"${expr.nullTerm} || !${expr.resultTerm}"
+      s"""
+         |// generate join condition
+         |${expr.getCode}
+         |if (!($skipRow))
+       """.stripMargin
+    } else {
+      ""
+    }
+    val buildLocalVars =
+      if (hashJoinType.leftSemiOrAnti() && !joinSpec.getNonEquiCondition.isPresent) {
+        ""
+      } else {
+        getOperatorCtx.reuseLocalVariableCode(buildRow)
+      }
+
+    (buildRow, checkCondition, buildLocalVars, buildVars)
+  }
+
+  /** Generates build side expr for outer join. */
+  protected def genProbeOuterBuildVars(
+      buildRow: String,
+      buildVars: Seq[GeneratedExpression]): Seq[GeneratedExpression] = {
+    buildVars.zipWithIndex.map {
+      case (expr, i) =>
+        val fieldType = buildType.getTypeAt(i)
+        val resultTypeTerm = primitiveTypeTermForType(fieldType)
+        val defaultValue = primitiveDefaultValue(fieldType)
+        val Seq(fieldTerm, nullTerm) =
+          getOperatorCtx.addReusableLocalVariables((resultTypeTerm, "field"), ("boolean", "isNull"))
+        val code = s"""
+                      |$nullTerm = true;
+                      |$fieldTerm = $defaultValue;
+                      |if ($buildRow != null) {
+                      |  ${expr.getCode}
+                      |  $nullTerm = ${expr.nullTerm};
+                      |  $fieldTerm = ${expr.resultTerm};
+                      |}
+          """.stripMargin
+        GeneratedExpression(fieldTerm, nullTerm, code, fieldType)
+    }
+  }
+
+  /** Generates the input row variables expr. */
+  def genInputVars(inputRowTerm: String, inputType: RowType): Seq[GeneratedExpression] = {
+    val indices = fieldIndices(inputType)
+    val buildExprs = indices
+      .map(
+        index => GenerateUtils.generateFieldAccess(getOperatorCtx, inputType, inputRowTerm, index))
+      .toSeq
+    indices.foreach(
+      index =>
+        getOperatorCtx
+          .addReusableInputUnboxingExprs(inputRowTerm, index, buildExprs(index)))
+
+    buildExprs
+  }
+
+  override def usedInputVars(inputId: Int): util.Set[Integer] = {
+    if (inputId == buildInputId) {
+      val set: util.Set[Integer] = new util.HashSet[Integer]()
+      buildKeys.foreach(key => set.add(key))
+      set
+    } else {
+      super.usedInputVars(inputId)
+    }
+  }
+
+  override def getInputRowDataClass(inputId: Int): Class[_ <: RowData] = {
+    if (inputId == buildInputId) {
+      // To build side, we wrap it BinaryRowData
+      classOf[BinaryRowData]
+    } else {
+      super.getInputRowDataClass(inputId)
+    }
+  }
+
+  private def codegenHashTable(spillEnabled: Boolean): Unit = {

Review Comment:
   Yes. Because OFCG and single-operator codegen generate code in quite different ways, it is hard to reuse code in this first step. In addition, with the 1.18 code freeze approaching, there is not much time to work on code reuse. I believe we will address it gradually in subsequent versions; after all, maintaining two sets of code is costly.
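   For readers following the quoted `HashJoinFusionCodegenSpec`, the control flow of its four generated probe loops (inner, probe-outer, semi, anti) and the null-key filtering on the build side can be summarized as a plain-Java runtime analogue. This is a hypothetical sketch, not Flink code: a `java.util.HashMap` stands in for `LongHybridHashTable`, rows are `long[]`, the join key is passed separately as a nullable `Long`, `cond` plays the role of the optional non-equi condition, and output pairs are always ordered `{build, probe}` rather than depending on `leftIsBuild`.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiPredicate;

/** Hypothetical runtime analogue of the generated hash join probe loops. */
final class ProbeSemanticsSketch {
    enum Kind { INNER, PROBE_OUTER, SEMI, ANTI }

    // Stand-in for the hash table: join key -> build rows with that key.
    private final Map<Long, List<long[]>> table = new HashMap<>();

    /** Build phase: rows with a null key are skipped, like the anyNull check on the build side. */
    void putBuildRow(Long key, long[] row) {
        if (key == null) {
            return;
        }
        table.computeIfAbsent(key, k -> new ArrayList<>()).add(row);
    }

    /** Probes one row; a null probe key finds no matches. */
    List<long[][]> probe(Kind kind, Long key, long[] probeRow, BiPredicate<long[], long[]> cond) {
        List<long[][]> out = new ArrayList<>();
        List<long[]> matches = key == null ? null : table.get(key);
        boolean found = false;
        if (matches != null) {
            for (long[] buildRow : matches) {
                if (!cond.test(buildRow, probeRow)) {
                    continue; // non-equi condition failed for this build row
                }
                found = true;
                if (kind == Kind.INNER || kind == Kind.PROBE_OUTER) {
                    out.add(new long[][] {buildRow, probeRow});
                } else {
                    break; // SEMI and ANTI only need to know that one match exists
                }
            }
        }
        if (kind == Kind.SEMI && found) {
            out.add(new long[][] {probeRow});
        }
        if (kind == Kind.ANTI && !found) {
            out.add(new long[][] {probeRow});
        }
        if (kind == Kind.PROBE_OUTER && !found) {
            out.add(new long[][] {null, probeRow}); // build side padded with null
        }
        return out;
    }
}
```

   The generated code achieves the same effect without collecting results: each emitted row is passed straight to `fusionContext.processConsume`, and the probe-outer loop folds the "no match" case into its `while` condition via the `found` flag.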



-- 
This is an automated message from the Apache Git Service.