This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 80a79bb5f6320a97e4051cba6e22c151dc9603cf
Author: sunxia <xingbe...@gmail.com>
AuthorDate: Fri Jan 3 16:43:46 2025 +0800

    [FLINK-36608][table-runtime] Add AdaptiveBroadcastJoinOptimizationStrategy to support adaptive broadcast join.
---
 .../generated/optimizer_config_configuration.html  |   2 +-
 .../table/api/config/OptimizerConfigOptions.java   |   2 +-
 .../table/planner/delegation/BatchPlanner.scala    |  18 ++
 .../sql/adaptive/AdaptiveBroadcastJoinITCase.scala | 189 ++++++++++++++++++
 .../runtime/batch/sql/join/JoinITCase.scala        |   4 +
 .../flink/table/planner/utils/TableTestBase.scala  |   6 +-
 .../AdaptiveBroadcastJoinOptimizationStrategy.java | 212 +++++++++++++++++++++
 ...seAdaptiveJoinOperatorOptimizationStrategy.java |  71 +++++++
 .../strategy/PostProcessAdaptiveJoinStrategy.java  | 117 ++++++++++++
 9 files changed, 618 insertions(+), 3 deletions(-)

diff --git a/docs/layouts/shortcodes/generated/optimizer_config_configuration.html b/docs/layouts/shortcodes/generated/optimizer_config_configuration.html
index 40cf4b735e7..0ab126d602e 100644
--- a/docs/layouts/shortcodes/generated/optimizer_config_configuration.html
+++ b/docs/layouts/shortcodes/generated/optimizer_config_configuration.html
@@ -10,7 +10,7 @@
     <tbody>
         <tr>
            <td><h5>table.optimizer.adaptive-broadcast-join.strategy</h5><br> <span class="label label-primary">Batch</span></td>
-            <td style="word-wrap: break-word;">none</td>
+            <td style="word-wrap: break-word;">auto</td>
             <td><p>Enum</p></td>
            <td>Flink will perform broadcast hash join optimization when the runtime statistics on one side of a join operator is less than the threshold `table.optimizer.join.broadcast-threshold`. The value of this configuration option decides when Flink should perform this optimization. AUTO means Flink will automatically choose the timing for optimization, RUNTIME_ONLY means broadcast hash join optimization is only performed at runtime, and NONE means the optimization is only carried [...]
         </tr>
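
For reference, the new default is AUTO, so batch jobs pick this optimization up automatically. A minimal sketch of how the option can be tuned from user code, mirroring the test usage later in this commit (the environment setup here is illustrative, not part of the patch):

    import org.apache.flink.table.api.{EnvironmentSettings, TableEnvironment}
    import org.apache.flink.table.api.config.OptimizerConfigOptions

    val tEnv = TableEnvironment.create(EnvironmentSettings.inBatchMode())
    // Apply the broadcast conversion only at runtime, based on observed input sizes.
    tEnv.getConfig.set(
      OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY,
      OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.RUNTIME_ONLY)
    // Or set it to NONE to keep the previous compile-time-only behavior.
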
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java
index 00478b487d4..ccda21813f2 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java
@@ -166,7 +166,7 @@ public class OptimizerConfigOptions {
             TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY =
                     key("table.optimizer.adaptive-broadcast-join.strategy")
                             .enumType(AdaptiveBroadcastJoinStrategy.class)
-                            .defaultValue(AdaptiveBroadcastJoinStrategy.NONE)
+                            .defaultValue(AdaptiveBroadcastJoinStrategy.AUTO)
                             .withDescription(
                                     "Flink will perform broadcast hash join 
optimization when the runtime "
                                             + "statistics on one side of a 
join operator is less than the "
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala
index 6e8febadec5..6686de333ee 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala
@@ -20,6 +20,7 @@ package org.apache.flink.table.planner.delegation
 import org.apache.flink.api.common.RuntimeExecutionMode
 import org.apache.flink.api.dag.Transformation
 import org.apache.flink.configuration.ExecutionOptions
+import org.apache.flink.runtime.scheduler.adaptivebatch.StreamGraphOptimizationStrategy
 import org.apache.flink.table.api._
 import org.apache.flink.table.api.config.OptimizerConfigOptions
 import org.apache.flink.table.catalog.{CatalogManager, FunctionCatalog}
@@ -34,6 +35,7 @@ import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodePlanDumper
 import org.apache.flink.table.planner.plan.optimize.{BatchCommonSubGraphBasedOptimizer, Optimizer}
 import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil
 import org.apache.flink.table.planner.utils.DummyStreamExecutionEnvironment
+import org.apache.flink.table.runtime.strategy.{AdaptiveBroadcastJoinOptimizationStrategy, PostProcessAdaptiveJoinStrategy}
 
 import org.apache.calcite.plan.{ConventionTraitDef, RelTrait, RelTraitDef}
 import org.apache.calcite.rel.RelCollationTraitDef
@@ -100,6 +102,22 @@ class BatchPlanner(
     transformations ++ planner.extraTransformations
   }
 
+  override def afterTranslation(): Unit = {
+    super.afterTranslation()
+    val configuration = getTableConfig
+    val optimizationStrategies = new util.ArrayList[String]()
+    if (
+      configuration.get(OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY)
+        != OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE
+    ) {
+      optimizationStrategies.add(classOf[AdaptiveBroadcastJoinOptimizationStrategy].getName)
+      optimizationStrategies.add(classOf[PostProcessAdaptiveJoinStrategy].getName)
+    }
+    configuration.set(
+      StreamGraphOptimizationStrategy.STREAM_GRAPH_OPTIMIZATION_STRATEGY,
+      optimizationStrategies)
+  }
+
   override def explain(
       operations: util.List[Operation],
       format: ExplainFormat,
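
The afterTranslation() hook above registers the two new strategies with the adaptive batch scheduler through StreamGraphOptimizationStrategy.STREAM_GRAPH_OPTIMIZATION_STRATEGY. Roughly, the resulting configuration is equivalent to the following sketch (fully qualified class names taken from this commit; the standalone Configuration here is only for illustration):

    import java.util
    import org.apache.flink.configuration.Configuration
    import org.apache.flink.runtime.scheduler.adaptivebatch.StreamGraphOptimizationStrategy

    val conf = new Configuration()
    // The scheduler later instantiates these strategy classes by name.
    conf.set(
      StreamGraphOptimizationStrategy.STREAM_GRAPH_OPTIMIZATION_STRATEGY,
      util.Arrays.asList(
        "org.apache.flink.table.runtime.strategy.AdaptiveBroadcastJoinOptimizationStrategy",
        "org.apache.flink.table.runtime.strategy.PostProcessAdaptiveJoinStrategy"))
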
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/adaptive/AdaptiveBroadcastJoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/adaptive/AdaptiveBroadcastJoinITCase.scala
new file mode 100644
index 00000000000..47c95fbc3ff
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/adaptive/AdaptiveBroadcastJoinITCase.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.runtime.batch.sql.adaptive
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{LONG_TYPE_INFO, STRING_TYPE_INFO}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions}
+import org.apache.flink.table.planner.runtime.utils.BatchTestBase
+import org.apache.flink.types.Row
+
+import org.junit.jupiter.api.{BeforeEach, Test}
+
+import scala.collection.JavaConversions._
+import scala.util.Random
+
+/** IT cases for adaptive broadcast join. */
+class AdaptiveBroadcastJoinITCase extends BatchTestBase {
+
+  @BeforeEach
+  override def before(): Unit = {
+    super.before()
+
+    registerCollection(
+      "T",
+      AdaptiveBroadcastJoinITCase.dataT,
+      AdaptiveBroadcastJoinITCase.rowType,
+      "a, b, c, d",
+      AdaptiveBroadcastJoinITCase.nullables)
+    registerCollection(
+      "T1",
+      AdaptiveBroadcastJoinITCase.dataT1,
+      AdaptiveBroadcastJoinITCase.rowType,
+      "a1, b1, c1, d1",
+      AdaptiveBroadcastJoinITCase.nullables)
+    registerCollection(
+      "T2",
+      AdaptiveBroadcastJoinITCase.dataT2,
+      AdaptiveBroadcastJoinITCase.rowType,
+      "a2, b2, c2, d2",
+      AdaptiveBroadcastJoinITCase.nullables)
+    registerCollection(
+      "T3",
+      AdaptiveBroadcastJoinITCase.dataT3,
+      AdaptiveBroadcastJoinITCase.rowType,
+      "a3, b3, c3, d3",
+      AdaptiveBroadcastJoinITCase.nullables)
+  }
+
+  @Test
+  def testWithShuffleHashJoin(): Unit = {
+    tEnv.getConfig
+      .set(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin")
+    testSimpleJoin()
+  }
+
+  @Test
+  def testWithShuffleMergeJoin(): Unit = {
+    tEnv.getConfig
+      .set(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,ShuffleHashJoin")
+    testSimpleJoin()
+  }
+
+  @Test
+  def testWithBroadcastJoin(): Unit = {
+    tEnv.getConfig.set(
+      ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS,
+      "SortMergeJoin,NestedLoopJoin")
+    tEnv.getConfig.set(
+      OptimizerConfigOptions.TABLE_OPTIMIZER_BROADCAST_JOIN_THRESHOLD,
+      Long.box(Long.MaxValue))
+    testSimpleJoin()
+  }
+
+  @Test
+  def testShuffleJoinWithForwardForConsecutiveHash(): Unit = {
+    tEnv.getConfig.set(
+      OptimizerConfigOptions.TABLE_OPTIMIZER_MULTIPLE_INPUT_ENABLED,
+      Boolean.box(false))
+    val sql =
+      """
+        |WITH
+        |  r AS (SELECT * FROM T1, T2, T3 WHERE a1 = a2 and a1 = a3)
+        |SELECT sum(b1) FROM r group by a1
+        |""".stripMargin
+    checkResult(sql)
+  }
+
+  @Test
+  def testJoinWithUnionInput(): Unit = {
+    val sql =
+      """
+        |SELECT * FROM
+        |  (SELECT a FROM (SELECT a1 as a FROM T1) UNION ALL (SELECT a2 as a FROM T2)) Y
+        |  LEFT JOIN T ON T.a = Y.a
+        |""".stripMargin
+    checkResult(sql)
+  }
+
+  @Test
+  def testJoinWithMultipleInput(): Unit = {
+    val sql =
+      """
+        |SELECT * FROM
+        |  (SELECT a FROM T1 JOIN T ON a = a1) t1
+        |  INNER JOIN
+        |  (SELECT d2 FROM T JOIN T2 ON d2 = a) t2
+        |ON t1.a = t2.d2
+        |""".stripMargin
+    checkResult(sql)
+  }
+
+  def testSimpleJoin(): Unit = {
+    // inner join
+    val sql1 = "SELECT * FROM T1, T2 WHERE a1 = a2"
+    checkResult(sql1)
+
+    // left join
+    val sql2 = "SELECT * FROM T1 LEFT JOIN T2 on a1 = a2"
+    checkResult(sql2)
+
+    // right join
+    val sql3 = "SELECT * FROM T1 RIGHT JOIN T2 on a1 = a2"
+    checkResult(sql3)
+
+    // semi join
+    val sql4 = "SELECT * FROM T1 WHERE a1 IN (SELECT a2 FROM T2)"
+    checkResult(sql4)
+
+    // anti join
+    val sql5 = "SELECT * FROM T1 WHERE a1 NOT IN (SELECT a2 FROM T2 where a2 = a1)"
+    checkResult(sql5)
+  }
+
+  def checkResult(sql: String): Unit = {
+    tEnv.getConfig
+      .set(
+        OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY,
+        OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE)
+    val expected = executeQuery(sql)
+    tEnv.getConfig
+      .set(
+        OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY,
+        OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.AUTO)
+    checkResult(sql, expected)
+    tEnv.getConfig
+      .set(
+        OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY,
+        OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.RUNTIME_ONLY)
+    checkResult(sql, expected)
+  }
+}
+
+object AdaptiveBroadcastJoinITCase {
+
+  def generateRandomData(): Seq[Row] = {
+    val data = new java.util.ArrayList[Row]()
+    val numRows = Random.nextInt(30)
+    lazy val strs = Seq("adaptive", "join", "itcase")
+    for (x <- 0 until numRows) {
+      data.add(
+        BatchTestBase.row(x.toLong, Random.nextLong(), strs(Random.nextInt(3)), Random.nextLong()))
+    }
+    data
+  }
+
+  lazy val rowType =
+    new RowTypeInfo(LONG_TYPE_INFO, LONG_TYPE_INFO, STRING_TYPE_INFO, LONG_TYPE_INFO)
+  lazy val nullables: Array[Boolean] = Array(true, true, true, true)
+
+  lazy val dataT: Seq[Row] = generateRandomData()
+  lazy val dataT1: Seq[Row] = generateRandomData()
+  lazy val dataT2: Seq[Row] = generateRandomData()
+  lazy val dataT3: Seq[Row] = generateRandomData()
+}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala
index 2997ee062e0..794cb558c74 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeutils.TypeComparator
 import org.apache.flink.api.dag.Transformation
 import org.apache.flink.api.java.typeutils.GenericTypeInfo
 import org.apache.flink.streaming.api.transformations.{LegacySinkTransformation, OneInputTransformation, TwoInputTransformation}
+import org.apache.flink.table.api.config.OptimizerConfigOptions
 import org.apache.flink.table.api.internal.{StatementSetImpl, TableEnvironmentInternal}
 import org.apache.flink.table.planner.delegation.PlannerBase
 import org.apache.flink.table.planner.expressions.utils.FuncWithOpen
@@ -203,6 +204,9 @@ class JoinITCase extends BatchTestBase {
   @TestTemplate
   def testLongHashJoinGenerator(): Unit = {
     if (expectedJoinType == HashJoin) {
+      tEnv.getConfig.set(
+        OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY,
+        OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE)
       val sink = (new CollectRowTableSink).configure(Array("c"), Array(Types.STRING))
       tEnv.asInstanceOf[TableEnvironmentInternal].registerTableSinkInternal("outputTable", sink)
       val stmtSet = tEnv.createStatementSet()
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
index 06391a710de..4da7813a81e 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
@@ -36,7 +36,8 @@ import org.apache.flink.streaming.api.legacy.io.CollectionInputFormat
 import org.apache.flink.table.api._
 import org.apache.flink.table.api.bridge.java.{StreamTableEnvironment => JavaStreamTableEnv}
 import org.apache.flink.table.api.bridge.scala.{StreamTableEnvironment => ScalaStreamTableEnv}
-import org.apache.flink.table.api.config.ExecutionConfigOptions
+import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions}
+import org.apache.flink.table.api.config.OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE
 import org.apache.flink.table.api.internal.{StatementSetImpl, TableEnvironmentImpl, TableEnvironmentInternal, TableImpl}
 import org.apache.flink.table.api.typeutils.CaseClassTypeInfo
 import org.apache.flink.table.catalog._
@@ -1329,6 +1330,9 @@ abstract class TableTestUtil(
   tableEnv.getConfig.set(
     BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_ENABLED,
     Boolean.box(false))
+  tableEnv.getConfig.set(
+    OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY,
+    NONE)
 
   private val env: StreamExecutionEnvironment = getPlanner.getExecEnv
 
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/AdaptiveBroadcastJoinOptimizationStrategy.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/AdaptiveBroadcastJoinOptimizationStrategy.java
new file mode 100644
index 00000000000..1c33ed9fb60
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/AdaptiveBroadcastJoinOptimizationStrategy.java
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.strategy;
+
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingResultInfo;
+import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished;
+import org.apache.flink.streaming.api.graph.StreamGraphContext;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode;
+import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
+import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** The stream graph optimization strategy of adaptive broadcast join. */
+public class AdaptiveBroadcastJoinOptimizationStrategy
+        extends BaseAdaptiveJoinOperatorOptimizationStrategy {
+    private static final Logger LOG =
+            LoggerFactory.getLogger(AdaptiveBroadcastJoinOptimizationStrategy.class);
+
+    private Long broadcastThreshold;
+
+    private Map<Integer, Map<Integer, Long>> aggregatedInputBytesByTypeNumberAndNodeId;
+
+    @Override
+    public void initialize(StreamGraphContext context) {
+        ReadableConfig config = context.getStreamGraph().getConfiguration();
+        broadcastThreshold =
+                config.get(OptimizerConfigOptions.TABLE_OPTIMIZER_BROADCAST_JOIN_THRESHOLD);
+        aggregatedInputBytesByTypeNumberAndNodeId = new HashMap<>();
+    }
+
+    @Override
+    public boolean onOperatorsFinished(
+            OperatorsFinished operatorsFinished, StreamGraphContext context) {
+        visitDownstreamAdaptiveJoinNode(operatorsFinished, context);
+
+        return true;
+    }
+
+    @Override
+    protected void tryOptimizeAdaptiveJoin(
+            OperatorsFinished operatorsFinished,
+            StreamGraphContext context,
+            ImmutableStreamNode adaptiveJoinNode,
+            List<ImmutableStreamEdge> upstreamStreamEdges,
+            AdaptiveJoin adaptiveJoin) {
+        for (ImmutableStreamEdge upstreamEdge : upstreamStreamEdges) {
+            IntermediateDataSetID relatedDataSetId =
+                    context.getConsumedIntermediateDataSetId(upstreamEdge.getEdgeId());
+            long producedBytes =
+                    operatorsFinished.getResultInfoMap().get(upstreamEdge.getSourceId()).stream()
+                            .filter(
+                                    blockingResultInfo ->
+                                            relatedDataSetId.equals(
+                                                    blockingResultInfo.getResultId()))
+                            .mapToLong(BlockingResultInfo::getNumBytesProduced)
+                            .sum();
+            aggregatedInputBytesByTypeNumber(
+                    adaptiveJoinNode, upstreamEdge.getTypeNumber(), producedBytes);
+        }
+
+        // If all upstream nodes have finished, we attempt to optimize the AdaptiveJoin node.
+        if (context.areAllUpstreamNodesFinished(adaptiveJoinNode)) {
+            Long leftInputSize =
+                    aggregatedInputBytesByTypeNumberAndNodeId.get(adaptiveJoinNode.getId()).get(1);
+            checkState(
+                    leftInputSize != null,
+                    "Left input bytes of adaptive join [%s] is unknown, which is unexpected.",
+                    adaptiveJoinNode.getId());
+            Long rightInputSize =
+                    aggregatedInputBytesByTypeNumberAndNodeId.get(adaptiveJoinNode.getId()).get(2);
+            checkState(
+                    rightInputSize != null,
+                    "Right input bytes of adaptive join [%s] is unknown, which is unexpected.",
+                    adaptiveJoinNode.getId());
+
+            boolean leftSizeSmallerThanThreshold = leftInputSize <= broadcastThreshold;
+            boolean rightSizeSmallerThanThreshold = rightInputSize <= broadcastThreshold;
+            boolean leftSmallerThanRight = leftInputSize < rightInputSize;
+            FlinkJoinType joinType = adaptiveJoin.getJoinType();
+            boolean canBeBroadcast;
+            boolean leftIsBuild;
+            switch (joinType) {
+                case RIGHT:
+                    // For a right outer join, if the left side can be broadcast, then the left side
+                    // is
+                    // always the build side; otherwise, the smaller side is the build side.
+                    canBeBroadcast = leftSizeSmallerThanThreshold;
+                    leftIsBuild = true;
+                    break;
+                case INNER:
+                    canBeBroadcast = leftSizeSmallerThanThreshold || rightSizeSmallerThanThreshold;
+                    leftIsBuild = leftSmallerThanRight;
+                    break;
+                case LEFT:
+                case SEMI:
+                case ANTI:
+                    // For left outer / semi / anti join, if the right side can be broadcast, then
+                    // the
+                    // right side is always the build side; otherwise, the smaller side is the build
+                    // side.
+                    canBeBroadcast = rightSizeSmallerThanThreshold;
+                    leftIsBuild = false;
+                    break;
+                case FULL:
+                default:
+                    throw new RuntimeException(String.format("Unexpected join type %s.", joinType));
+            }
+
+            boolean isBroadcast = false;
+            if (canBeBroadcast) {
+                isBroadcast =
+                        tryModifyStreamEdgesForBroadcastJoin(
+                                adaptiveJoinNode.getInEdges(), context, leftIsBuild);
+
+                if (isBroadcast) {
+                    LOG.info(
+                            "The {} input data size of the join node [{}] is small enough, "
+                                    + "adaptively convert it to a broadcast hash join. Broadcast "
+                                    + "threshold bytes: {}, left input bytes: {}, right input bytes: {}.",
+                            leftIsBuild ? "left" : "right",
+                            adaptiveJoinNode.getId(),
+                            broadcastThreshold,
+                            leftInputSize,
+                            rightInputSize);
+                }
+            }
+            adaptiveJoin.markAsBroadcastJoin(
+                    isBroadcast, isBroadcast ? leftIsBuild : leftSmallerThanRight);
+
+            aggregatedInputBytesByTypeNumberAndNodeId.remove(adaptiveJoinNode.getId());
+        }
+    }
+
+    private void aggregatedInputBytesByTypeNumber(
+            ImmutableStreamNode adaptiveJoinNode, int typeNumber, long producedBytes) {
+        Integer streamNodeId = adaptiveJoinNode.getId();
+
+        aggregatedInputBytesByTypeNumberAndNodeId
+                .computeIfAbsent(streamNodeId, k -> new HashMap<>())
+                .merge(typeNumber, producedBytes, Long::sum);
+    }
+
+    private List<ImmutableStreamEdge> filterEdges(
+            List<ImmutableStreamEdge> inEdges, int typeNumber) {
+        return inEdges.stream()
+                .filter(e -> e.getTypeNumber() == typeNumber)
+                .collect(Collectors.toList());
+    }
+
+    private List<StreamEdgeUpdateRequestInfo> generateStreamEdgeUpdateRequestInfos(
+            List<ImmutableStreamEdge> modifiedEdges, StreamPartitioner<?> outputPartitioner) {
+        List<StreamEdgeUpdateRequestInfo> streamEdgeUpdateRequestInfos = new ArrayList<>();
+        for (ImmutableStreamEdge streamEdge : modifiedEdges) {
+            StreamEdgeUpdateRequestInfo streamEdgeUpdateRequestInfo =
+                    new StreamEdgeUpdateRequestInfo(
+                                    streamEdge.getEdgeId(),
+                                    streamEdge.getSourceId(),
+                                    streamEdge.getTargetId())
+                            .withOutputPartitioner(outputPartitioner);
+            streamEdgeUpdateRequestInfos.add(streamEdgeUpdateRequestInfo);
+        }
+
+        return streamEdgeUpdateRequestInfos;
+    }
+
+    private boolean tryModifyStreamEdgesForBroadcastJoin(
+            List<ImmutableStreamEdge> inEdges, StreamGraphContext context, boolean leftIsBuild) {
+        List<StreamEdgeUpdateRequestInfo> modifiedBuildSideEdges =
+                generateStreamEdgeUpdateRequestInfos(
+                        filterEdges(inEdges, leftIsBuild ? 1 : 2), new BroadcastPartitioner<>());
+        List<StreamEdgeUpdateRequestInfo> modifiedProbeSideEdges =
+                generateStreamEdgeUpdateRequestInfos(
+                        filterEdges(inEdges, leftIsBuild ? 2 : 1), new ForwardPartitioner<>());
+        modifiedBuildSideEdges.addAll(modifiedProbeSideEdges);
+
+        return context.modifyStreamEdge(modifiedBuildSideEdges);
+    }
+}
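
The core decision in tryOptimizeAdaptiveJoin can be summarized as a small rule per join type. A simplified restatement (not part of the patch) that returns whether the join may become a broadcast join and which side would build, mirroring the switch above:

    import org.apache.flink.table.runtime.operators.join.FlinkJoinType

    // Returns (canBeBroadcast, leftIsBuild) for the join types the strategy supports.
    def broadcastDecision(
        joinType: FlinkJoinType,
        leftBytes: Long,
        rightBytes: Long,
        threshold: Long): (Boolean, Boolean) = joinType match {
      case FlinkJoinType.RIGHT =>
        // Only the left side may be broadcast; if so, it is always the build side.
        (leftBytes <= threshold, true)
      case FlinkJoinType.INNER =>
        // Either side may be broadcast; the smaller side becomes the build side.
        (leftBytes <= threshold || rightBytes <= threshold, leftBytes < rightBytes)
      case FlinkJoinType.LEFT | FlinkJoinType.SEMI | FlinkJoinType.ANTI =>
        // Only the right side may be broadcast; if so, it is always the build side.
        (rightBytes <= threshold, false)
      case other =>
        // FULL outer joins (and any other type) are not converted, as in the patch.
        throw new IllegalArgumentException(s"Unexpected join type $other.")
    }
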
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/BaseAdaptiveJoinOperatorOptimizationStrategy.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/BaseAdaptiveJoinOperatorOptimizationStrategy.java
new file mode 100644
index 00000000000..847ce5935dd
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/BaseAdaptiveJoinOperatorOptimizationStrategy.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.strategy;
+
+import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished;
+import org.apache.flink.runtime.scheduler.adaptivebatch.StreamGraphOptimizationStrategy;
+import org.apache.flink.streaming.api.graph.StreamGraphContext;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode;
+import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** The base stream graph optimization strategy class for adaptive join operator. */
+public abstract class BaseAdaptiveJoinOperatorOptimizationStrategy
+        implements StreamGraphOptimizationStrategy {
+
+    protected void visitDownstreamAdaptiveJoinNode(
+            OperatorsFinished operatorsFinished, StreamGraphContext context) {
+        ImmutableStreamGraph streamGraph = context.getStreamGraph();
+        List<Integer> finishedStreamNodeIds = operatorsFinished.getFinishedStreamNodeIds();
+        Map<ImmutableStreamNode, List<ImmutableStreamEdge>> joinNodesWithInEdges = new HashMap<>();
+        for (Integer finishedStreamNodeId : finishedStreamNodeIds) {
+            for (ImmutableStreamEdge streamEdge :
+                    streamGraph.getStreamNode(finishedStreamNodeId).getOutEdges()) {
+                ImmutableStreamNode downstreamNode =
+                        streamGraph.getStreamNode(streamEdge.getTargetId());
+                if (downstreamNode.getOperatorFactory() instanceof AdaptiveJoin) {
+                    joinNodesWithInEdges
+                            .computeIfAbsent(downstreamNode, k -> new ArrayList<>())
+                            .add(streamEdge);
+                }
+            }
+        }
+        for (ImmutableStreamNode joinNode : joinNodesWithInEdges.keySet()) {
+            tryOptimizeAdaptiveJoin(
+                    operatorsFinished,
+                    context,
+                    joinNode,
+                    joinNodesWithInEdges.get(joinNode),
+                    (AdaptiveJoin) joinNode.getOperatorFactory());
+        }
+    }
+
+    abstract void tryOptimizeAdaptiveJoin(
+            OperatorsFinished operatorsFinished,
+            StreamGraphContext context,
+            ImmutableStreamNode adaptiveJoinNode,
+            List<ImmutableStreamEdge> upstreamStreamEdges,
+            AdaptiveJoin adaptiveJoin);
+}
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/PostProcessAdaptiveJoinStrategy.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/PostProcessAdaptiveJoinStrategy.java
new file mode 100644
index 00000000000..6e898bc9718
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/PostProcessAdaptiveJoinStrategy.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.strategy;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished;
+import org.apache.flink.streaming.api.graph.StreamGraphContext;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode;
+import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
+import org.apache.flink.streaming.api.graph.util.StreamNodeUpdateRequestInfo;
+import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * The post-processing phase of adaptive join optimization, which must be placed at the end of all
+ * adaptive join optimization strategies. This is necessary because certain operations, like
+ * 'reorder inputs', can influence how adaptive broadcast join or skewed join determine the left and
+ * right sides.
+ */
+public class PostProcessAdaptiveJoinStrategy extends BaseAdaptiveJoinOperatorOptimizationStrategy {
+
+    @Override
+    public boolean onOperatorsFinished(
+            OperatorsFinished operatorsFinished, StreamGraphContext context) {
+        visitDownstreamAdaptiveJoinNode(operatorsFinished, context);
+
+        return true;
+    }
+
+    @Override
+    protected void tryOptimizeAdaptiveJoin(
+            OperatorsFinished operatorsFinished,
+            StreamGraphContext context,
+            ImmutableStreamNode adaptiveJoinNode,
+            List<ImmutableStreamEdge> upstreamStreamEdges,
+            AdaptiveJoin adaptiveJoin) {
+        if (context.areAllUpstreamNodesFinished(adaptiveJoinNode)) {
+            // For hash join, reorder the join node inputs so the build side is read first.
+            if (adaptiveJoin.shouldReorderInputs()) {
+                if (!context.modifyStreamEdge(
+                                generateStreamEdgeUpdateRequestInfosForInputsReordered(
+                                        adaptiveJoinNode))
+                        || !context.modifyStreamNode(
+                                generateStreamNodeUpdateRequestInfosForInputsReordered(
+                                        adaptiveJoinNode))) {
+                    throw new RuntimeException(
+                            "Unexpected error occurs while reordering the inputs "
+                                    + "of the adaptive join node, potentially leading to data inaccuracies. "
+                                    + "Exceptions will be thrown.");
+                }
+            }
+
+            // Generate OperatorFactory for adaptive join operator after inputs are reordered.
+            ReadableConfig config = context.getStreamGraph().getConfiguration();
+            ClassLoader userClassLoader = context.getStreamGraph().getUserClassLoader();
+            adaptiveJoin.genOperatorFactory(userClassLoader, config);
+        }
+    }
+
+    private static List<StreamEdgeUpdateRequestInfo>
+            generateStreamEdgeUpdateRequestInfosForInputsReordered(
+                    ImmutableStreamNode adaptiveJoinNode) {
+        List<StreamEdgeUpdateRequestInfo> streamEdgeUpdateRequestInfos = new ArrayList<>();
+        for (ImmutableStreamEdge inEdge : adaptiveJoinNode.getInEdges()) {
+            StreamEdgeUpdateRequestInfo streamEdgeUpdateRequestInfo =
+                    new StreamEdgeUpdateRequestInfo(
+                            inEdge.getEdgeId(), inEdge.getSourceId(), inEdge.getTargetId());
+            streamEdgeUpdateRequestInfo.withTypeNumber(inEdge.getTypeNumber() == 1 ? 2 : 1);
+            streamEdgeUpdateRequestInfos.add(streamEdgeUpdateRequestInfo);
+        }
+        return streamEdgeUpdateRequestInfos;
+    }
+
+    private List<StreamNodeUpdateRequestInfo>
+            generateStreamNodeUpdateRequestInfosForInputsReordered(
+                    ImmutableStreamNode modifiedNode) {
+        List<StreamNodeUpdateRequestInfo> streamEdgeUpdateRequestInfos = new ArrayList<>();
+
+        TypeSerializer<?>[] typeSerializers = modifiedNode.getTypeSerializersIn();
+        Preconditions.checkState(
+                typeSerializers.length == 2,
+                String.format(
+                        "Adaptive join currently only supports two "
+                                + "inputs, but the join node [%s] has received %s inputs.",
+                        modifiedNode.getId(), typeSerializers.length));
+        TypeSerializer<?>[] swappedTypeSerializers = new TypeSerializer<?>[2];
+        swappedTypeSerializers[0] = typeSerializers[1];
+        swappedTypeSerializers[1] = typeSerializers[0];
+        StreamNodeUpdateRequestInfo requestInfo =
+                new StreamNodeUpdateRequestInfo(modifiedNode.getId())
+                        .withTypeSerializersIn(swappedTypeSerializers);
+        streamEdgeUpdateRequestInfos.add(requestInfo);
+
+        return streamEdgeUpdateRequestInfos;
+    }
+}

