This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch release-1.18
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.18 by this push:
     new c78f440533e [FLINK-32905][table-runtime] Fix the bug that broadcast hash 
join doesn't support spilling to disk when operator fusion codegen is enabled
c78f440533e is described below

commit c78f440533efe68687ab1e6e04d3407e40303de7
Author: Ron <[email protected]>
AuthorDate: Thu Aug 24 19:22:17 2023 +0800

    [FLINK-32905][table-runtime] Fix the bug that broadcast hash join doesn't 
support spilling to disk when operator fusion codegen is enabled
    
    This closes #23274
---
 .../plan/fusion/spec/CalcFusionCodegenSpec.scala   |  1 +
 .../fusion/spec/HashJoinFusionCodegenSpec.scala    | 81 ++++++++--------------
 .../batch/sql/OperatorFusionCodegenITCase.scala    |  9 +++
 .../runtime/hashtable/BaseHybridHashTable.java     | 10 +--
 .../table/runtime/hashtable/BinaryHashTable.java   |  9 +--
 .../runtime/hashtable/LongHybridHashTable.java     | 35 +---------
 6 files changed, 41 insertions(+), 104 deletions(-)

diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/fusion/spec/CalcFusionCodegenSpec.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/fusion/spec/CalcFusionCodegenSpec.scala
index 00fbe0697c5..1007fb5bcf0 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/fusion/spec/CalcFusionCodegenSpec.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/fusion/spec/CalcFusionCodegenSpec.scala
@@ -64,6 +64,7 @@ class CalcFusionCodegenSpec(
     } else if (condition.isEmpty) { // only projection
       val projectionExprs = 
projection.map(getExprCodeGenerator.generateExpression)
       s"""
+         |${opCodegenCtx.reuseLocalVariableCode()}
          |${evaluateRequiredVariables(toScala(inputVars), 
projectionUsedColumns)}
          |${fusionContext.processConsume(toJava(projectionExprs))}
          |""".stripMargin
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/fusion/spec/HashJoinFusionCodegenSpec.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/fusion/spec/HashJoinFusionCodegenSpec.scala
index c2fa0733eb7..618460c5c76 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/fusion/spec/HashJoinFusionCodegenSpec.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/fusion/spec/HashJoinFusionCodegenSpec.scala
@@ -25,7 +25,7 @@ import 
org.apache.flink.table.planner.codegen.LongHashJoinGenerator.{genGetLongK
 import org.apache.flink.table.planner.plan.fusion.{OpFusionCodegenSpecBase, 
OpFusionContext}
 import 
org.apache.flink.table.planner.plan.fusion.FusionCodegenUtil.{constructDoConsumeCode,
 constructDoConsumeFunction, evaluateRequiredVariables, extractRefInputFields}
 import org.apache.flink.table.planner.plan.nodes.exec.spec.JoinSpec
-import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.{toJava, 
toScala}
+import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
 import org.apache.flink.table.runtime.hashtable.LongHybridHashTable
 import org.apache.flink.table.runtime.operators.join.{FlinkJoinType, 
HashJoinType}
 import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer
@@ -137,11 +137,7 @@ class HashJoinFusionCodegenSpec(
       inputVars: Seq[GeneratedExpression],
       row: GeneratedExpression): String = {
     // initialize hash table related code
-    if (isBroadcast) {
-      codegenHashTable(false)
-    } else {
-      codegenHashTable(true)
-    }
+    codegenHashTable()
 
     val (nullCheckBuildCode, nullCheckBuildTerm) = {
       genAnyNullsInKeys(buildKeys, inputVars)
@@ -166,7 +162,13 @@ class HashJoinFusionCodegenSpec(
        |// find matches from hash table
        |${classOf[RowIterator[_]].getCanonicalName} $buildIterTerm = $anyNull ?
        |  null : $hashTableTerm.get(${keyEv.resultTerm});
-       |${wrapProbeWithSpilledCode(anyNull, buildIterTerm, row, processCode)}
+       |
+       |if(!$anyNull && $buildIterTerm == null) {
+       |   ${row.code}
+       |   $hashTableTerm.insertIntoProbeBuffer(${row.resultTerm});
+       |} else {
+       |  $processCode
+       |}
            """.stripMargin
   }
 
@@ -179,15 +181,11 @@ class HashJoinFusionCodegenSpec(
          |$hashTableTerm.endBuild();
        """.stripMargin
     } else {
-      if (isBroadcast) {
-        fusionContext.endInputConsume()
-      } else {
-        s"""
-           |// Process the spilled partitions first
-           |$codegenEndInputCode
-           |${fusionContext.endInputConsume()}
-           |""".stripMargin
-      }
+      s"""
+         |// Process the spilled partitions first
+         |$codegenEndInputCode
+         |${fusionContext.endInputConsume()}
+         |""".stripMargin
     }
   }
 
@@ -214,6 +212,10 @@ class HashJoinFusionCodegenSpec(
        |  $processCode
        |}
        |LOG.info("Finish rebuild phase.");
+       |
+       |if(!$hashTableTerm.getPartitionsPendingForSMJ().isEmpty()) {
+       |  throw new UnsupportedOperationException("Currently doesn't support 
fallback to sort merge join for hash join when 
'table.exec.operator-fusion-codegen.enabled' is true.");
+       |}
        |""".stripMargin
   }
 
@@ -511,45 +513,19 @@ class HashJoinFusionCodegenSpec(
     classOf[BinaryRowData]
   }
 
-  private def wrapProbeWithSpilledCode(
-      anyNull: String,
-      buildIterTerm: String,
-      row: GeneratedExpression,
-      processCode: String): String = {
-    // Broadcast HashJoin doesn't support spill to disk.
-    if (isBroadcast) {
-      processCode
-    } else {
-      // If join key is not null and $buildIterTerm is null indicate the build 
partition
-      // corresponding to the probe row has spilled to disk, so also spill it 
to disk.
-      s"""
-         |if(!$anyNull && $buildIterTerm == null) {
-         |   ${row.code}
-         |   $hashTableTerm.insertIntoProbeBuffer(${row.resultTerm});
-         |} else {
-         |  $processCode
-         |}
-         |""".stripMargin
-    }
-  }
-
   private def codegenConsumeCode(resultVars: Seq[GeneratedExpression]): String 
= {
-    if (isBroadcast) {
-      fusionContext.processConsume(toJava(resultVars))
-    } else {
-      // Here need to cache to avoid generating the consume code multiple time
-      if (consumeFunctionName == null) {
-        consumeFunctionName = constructDoConsumeFunction(
-          variablePrefix,
-          opCodegenCtx,
-          fusionContext,
-          fusionContext.getOutputType)
-      }
-      constructDoConsumeCode(consumeFunctionName, resultVars)
+    // Here need to cache to avoid generating the consume code multiple time
+    if (consumeFunctionName == null) {
+      consumeFunctionName = constructDoConsumeFunction(
+        variablePrefix,
+        opCodegenCtx,
+        fusionContext,
+        fusionContext.getOutputType)
     }
+    constructDoConsumeCode(consumeFunctionName, resultVars)
   }
 
-  private def codegenHashTable(spillEnabled: Boolean): Unit = {
+  private def codegenHashTable(): Unit = {
     val buildSer = new BinaryRowDataSerializer(buildType.getFieldCount)
     val buildSerTerm = opCodegenCtx.addReusableObject(buildSer, "buildSer")
     val probeSer = new BinaryRowDataSerializer(probeType.getFieldCount)
@@ -591,8 +567,7 @@ class HashJoinFusionCodegenSpec(
          |      memorySize,
          |      getContainingTask().getEnvironment().getIOManager(),
          |      $buildRowSize,
-         |      ${buildRowCount}L / 
getRuntimeContext().getNumberOfParallelSubtasks(),
-         |      $spillEnabled);
+         |      ${buildRowCount}L / 
getRuntimeContext().getNumberOfParallelSubtasks());
          |  }
          |
          |  @Override
diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/OperatorFusionCodegenITCase.scala
 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/OperatorFusionCodegenITCase.scala
index 9d909804ef6..1211bb025b0 100644
--- 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/OperatorFusionCodegenITCase.scala
+++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/OperatorFusionCodegenITCase.scala
@@ -156,6 +156,15 @@ class OperatorFusionCodegenITCase extends BatchTestBase {
     )
   }
 
+  @TestTemplate
+  def testHashJoinWithOnlyProjection(): Unit = {
+    checkOpFusionCodegenResult("""
+                                 |SELECT * FROM (SELECT a, nx + ny AS nt FROM x
+                                 |  JOIN y ON x.a = y.ny) t
+                                 |JOIN z ON t.a = z.nz WHERE t.nt -10 > z.nz
+                                 |""".stripMargin)
+  }
+
   @TestTemplate
   def testHashJoinWithDeadlockCausedByExchangeInAncestor(): Unit = {
     checkOpFusionCodegenResult(
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
index 32b32cad0ad..cd71bd9142b 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
@@ -99,12 +99,6 @@ public abstract class BaseHybridHashTable implements 
MemorySegmentPool {
     /** Try to make the buildSide rows distinct. */
     public final boolean tryDistinctBuildRow;
 
-    /**
-     * In operator fusion codegen case, we don't support spill to disk for 
broadcast hashjoin, so
-     * this flag is introduced.
-     */
-    protected final boolean spillEnabled;
-
     /**
      * The recursion depth of the partition that is currently processed. The 
initial table has a
      * recursion depth of 0. Partitions spilled from a table that is built for 
a partition with
@@ -147,8 +141,7 @@ public abstract class BaseHybridHashTable implements 
MemorySegmentPool {
             IOManager ioManager,
             int avgRecordLen,
             long buildRowCount,
-            boolean tryDistinctBuildRow,
-            boolean spillEnabled) {
+            boolean tryDistinctBuildRow) {
         this.compressionEnabled = compressionEnabled;
         this.compressionCodecFactory =
                 this.compressionEnabled
@@ -174,7 +167,6 @@ public abstract class BaseHybridHashTable implements 
MemorySegmentPool {
 
         this.segmentSizeBits = MathUtils.log2strict(segmentSize);
         this.segmentSizeMask = segmentSize - 1;
-        this.spillEnabled = spillEnabled;
 
         // open builds the initial table by consuming the build-side input
         this.currentRecursionDepth = 0;
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
index 101e3290e5a..b21a00f803b 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
@@ -41,7 +41,6 @@ import org.apache.flink.table.runtime.util.RowIterator;
 import org.apache.flink.util.MathUtils;
 
 import java.io.IOException;
-import java.io.UnsupportedEncodingException;
 import java.util.ArrayList;
 import java.util.List;
 
@@ -170,8 +169,7 @@ public class BinaryHashTable extends BaseHybridHashTable {
                 ioManager,
                 avgRecordLen,
                 buildRowCount,
-                !type.buildLeftSemiOrAnti() && tryDistinctBuildRow,
-                true);
+                !type.buildLeftSemiOrAnti() && tryDistinctBuildRow);
         // assign the members
         this.originBuildSideSerializer = buildSideSerializer;
         this.binaryBuildSideSerializer =
@@ -670,11 +668,6 @@ public class BinaryHashTable extends BaseHybridHashTable {
      */
     @Override
     protected int spillPartition() throws IOException {
-        if (!spillEnabled) {
-            throw new UnsupportedEncodingException(
-                    "Currently doesn't support spill to disk for grace hash 
join "
-                            + "when broadcast hash join strategy is chosen and 
operator fusion codegen is enabled simultaneously.");
-        }
         // find the largest partition
         int largestNumBlocks = 0;
         int largestPartNum = -1;
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
index cb0e81b296b..ff6b3047a70 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
@@ -37,7 +37,6 @@ import javax.annotation.Nullable;
 
 import java.io.EOFException;
 import java.io.IOException;
-import java.io.UnsupportedEncodingException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
@@ -84,32 +83,6 @@ public abstract class LongHybridHashTable extends 
BaseHybridHashTable {
             IOManager ioManager,
             int avgRecordLen,
             long buildRowCount) {
-        this(
-                owner,
-                compressionEnabled,
-                compressionBlockSize,
-                buildSideSerializer,
-                probeSideSerializer,
-                memManager,
-                reservedMemorySize,
-                ioManager,
-                avgRecordLen,
-                buildRowCount,
-                true);
-    }
-
-    public LongHybridHashTable(
-            Object owner,
-            boolean compressionEnabled,
-            int compressionBlockSize,
-            BinaryRowDataSerializer buildSideSerializer,
-            BinaryRowDataSerializer probeSideSerializer,
-            MemoryManager memManager,
-            long reservedMemorySize,
-            IOManager ioManager,
-            int avgRecordLen,
-            long buildRowCount,
-            boolean spillEnabled) {
         super(
                 owner,
                 compressionEnabled,
@@ -119,8 +92,7 @@ public abstract class LongHybridHashTable extends 
BaseHybridHashTable {
                 ioManager,
                 avgRecordLen,
                 buildRowCount,
-                false,
-                spillEnabled);
+                false);
         this.buildSideSerializer = buildSideSerializer;
         this.probeSideSerializer = probeSideSerializer;
 
@@ -621,11 +593,6 @@ public abstract class LongHybridHashTable extends 
BaseHybridHashTable {
 
     @Override
     public int spillPartition() throws IOException {
-        if (!spillEnabled) {
-            throw new UnsupportedEncodingException(
-                    "Currently doesn't support spill to disk for grace hash 
join "
-                            + "when broadcast hash join strategy is chosen and 
operator fusion codegen is enabled simultaneously.");
-        }
         // find the largest partition
         int largestNumBlocks = 0;
         int largestPartNum = -1;

Reply via email to