This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 80a79bb5f6320a97e4051cba6e22c151dc9603cf
Author: sunxia <xingbe...@gmail.com>
AuthorDate: Fri Jan 3 16:43:46 2025 +0800

    [FLINK-36608][table-runtime] Add AdaptiveBroadcastJoinOptimizationStrategy to support adaptive broadcast join.
---
 .../generated/optimizer_config_configuration.html |   2 +-
 .../table/api/config/OptimizerConfigOptions.java  |   2 +-
 .../table/planner/delegation/BatchPlanner.scala   |  18 ++
 .../sql/adaptive/AdaptiveBroadcastJoinITCase.scala | 189 ++++++++++++++++++
 .../runtime/batch/sql/join/JoinITCase.scala       |   4 +
 .../flink/table/planner/utils/TableTestBase.scala |   6 +-
 .../AdaptiveBroadcastJoinOptimizationStrategy.java | 212 +++++++++++++++++++++
 ...seAdaptiveJoinOperatorOptimizationStrategy.java |  71 +++++++
 .../strategy/PostProcessAdaptiveJoinStrategy.java  | 117 ++++++++++++
 9 files changed, 618 insertions(+), 3 deletions(-)

diff --git a/docs/layouts/shortcodes/generated/optimizer_config_configuration.html b/docs/layouts/shortcodes/generated/optimizer_config_configuration.html
index 40cf4b735e7..0ab126d602e 100644
--- a/docs/layouts/shortcodes/generated/optimizer_config_configuration.html
+++ b/docs/layouts/shortcodes/generated/optimizer_config_configuration.html
@@ -10,7 +10,7 @@
 <tbody>
 <tr>
 <td><h5>table.optimizer.adaptive-broadcast-join.strategy</h5><br> <span class="label label-primary">Batch</span></td>
- <td style="word-wrap: break-word;">none</td>
+ <td style="word-wrap: break-word;">auto</td>
 <td><p>Enum</p></td>
 <td>Flink will perform broadcast hash join optimization when the runtime statistics on one side of a join operator is less than the threshold `table.optimizer.join.broadcast-threshold`. The value of this configuration option decides when Flink should perform this optimization. AUTO means Flink will automatically choose the timing for optimization, RUNTIME_ONLY means broadcast hash join optimization is only performed at runtime, and NONE means the optimization is only carried [...]
</tr> diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java index 00478b487d4..ccda21813f2 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java @@ -166,7 +166,7 @@ public class OptimizerConfigOptions { TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY = key("table.optimizer.adaptive-broadcast-join.strategy") .enumType(AdaptiveBroadcastJoinStrategy.class) - .defaultValue(AdaptiveBroadcastJoinStrategy.NONE) + .defaultValue(AdaptiveBroadcastJoinStrategy.AUTO) .withDescription( "Flink will perform broadcast hash join optimization when the runtime " + "statistics on one side of a join operator is less than the " diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala index 6e8febadec5..6686de333ee 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala @@ -20,6 +20,7 @@ package org.apache.flink.table.planner.delegation import org.apache.flink.api.common.RuntimeExecutionMode import org.apache.flink.api.dag.Transformation import org.apache.flink.configuration.ExecutionOptions +import org.apache.flink.runtime.scheduler.adaptivebatch.StreamGraphOptimizationStrategy import org.apache.flink.table.api._ import org.apache.flink.table.api.config.OptimizerConfigOptions import org.apache.flink.table.catalog.{CatalogManager, FunctionCatalog} @@ -34,6 +35,7 @@ import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodePlanDumper import org.apache.flink.table.planner.plan.optimize.{BatchCommonSubGraphBasedOptimizer, Optimizer} import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil import org.apache.flink.table.planner.utils.DummyStreamExecutionEnvironment +import org.apache.flink.table.runtime.strategy.{AdaptiveBroadcastJoinOptimizationStrategy, PostProcessAdaptiveJoinStrategy} import org.apache.calcite.plan.{ConventionTraitDef, RelTrait, RelTraitDef} import org.apache.calcite.rel.RelCollationTraitDef @@ -100,6 +102,22 @@ class BatchPlanner( transformations ++ planner.extraTransformations } + override def afterTranslation(): Unit = { + super.afterTranslation() + val configuration = getTableConfig + val optimizationStrategies = new util.ArrayList[String]() + if ( + configuration.get(OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY) + != OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE + ) { + optimizationStrategies.add(classOf[AdaptiveBroadcastJoinOptimizationStrategy].getName) + optimizationStrategies.add(classOf[PostProcessAdaptiveJoinStrategy].getName) + } + configuration.set( + StreamGraphOptimizationStrategy.STREAM_GRAPH_OPTIMIZATION_STRATEGY, + optimizationStrategies) + } + override def explain( operations: util.List[Operation], format: ExplainFormat, diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/adaptive/AdaptiveBroadcastJoinITCase.scala 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/adaptive/AdaptiveBroadcastJoinITCase.scala new file mode 100644 index 00000000000..47c95fbc3ff --- /dev/null +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/adaptive/AdaptiveBroadcastJoinITCase.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.planner.runtime.batch.sql.adaptive + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{LONG_TYPE_INFO, STRING_TYPE_INFO} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions} +import org.apache.flink.table.planner.runtime.utils.BatchTestBase +import org.apache.flink.types.Row + +import org.junit.jupiter.api.{BeforeEach, Test} + +import scala.collection.JavaConversions._ +import scala.util.Random + +/** IT cases for adaptive broadcast join. */ +class AdaptiveBroadcastJoinITCase extends BatchTestBase { + + @BeforeEach + override def before(): Unit = { + super.before() + + registerCollection( + "T", + AdaptiveBroadcastJoinITCase.dataT, + AdaptiveBroadcastJoinITCase.rowType, + "a, b, c, d", + AdaptiveBroadcastJoinITCase.nullables) + registerCollection( + "T1", + AdaptiveBroadcastJoinITCase.dataT1, + AdaptiveBroadcastJoinITCase.rowType, + "a1, b1, c1, d1", + AdaptiveBroadcastJoinITCase.nullables) + registerCollection( + "T2", + AdaptiveBroadcastJoinITCase.dataT2, + AdaptiveBroadcastJoinITCase.rowType, + "a2, b2, c2, d2", + AdaptiveBroadcastJoinITCase.nullables) + registerCollection( + "T3", + AdaptiveBroadcastJoinITCase.dataT3, + AdaptiveBroadcastJoinITCase.rowType, + "a3, b3, c3, d3", + AdaptiveBroadcastJoinITCase.nullables) + } + + @Test + def testWithShuffleHashJoin(): Unit = { + tEnv.getConfig + .set(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin") + testSimpleJoin() + } + + @Test + def testWithShuffleMergeJoin(): Unit = { + tEnv.getConfig + .set(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,ShuffleHashJoin") + testSimpleJoin() + } + + @Test + def testWithBroadcastJoin(): Unit = { + tEnv.getConfig.set( + ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, + "SortMergeJoin,NestedLoopJoin") + tEnv.getConfig.set( + OptimizerConfigOptions.TABLE_OPTIMIZER_BROADCAST_JOIN_THRESHOLD, + Long.box(Long.MaxValue)) + testSimpleJoin() + } + + @Test + def testShuffleJoinWithForwardForConsecutiveHash(): Unit = { + tEnv.getConfig.set( + OptimizerConfigOptions.TABLE_OPTIMIZER_MULTIPLE_INPUT_ENABLED, + Boolean.box(false)) + val sql = + """ + |WITH + | r AS (SELECT * FROM T1, T2, T3 WHERE a1 = a2 and a1 = a3) + |SELECT sum(b1) FROM r group by a1 + 
|""".stripMargin + checkResult(sql) + } + + @Test + def testJoinWithUnionInput(): Unit = { + val sql = + """ + |SELECT * FROM + | (SELECT a FROM (SELECT a1 as a FROM T1) UNION ALL (SELECT a2 as a FROM T2)) Y + | LEFT JOIN T ON T.a = Y.a + |""".stripMargin + checkResult(sql) + } + + @Test + def testJoinWithMultipleInput(): Unit = { + val sql = + """ + |SELECT * FROM + | (SELECT a FROM T1 JOIN T ON a = a1) t1 + | INNER JOIN + | (SELECT d2 FROM T JOIN T2 ON d2 = a) t2 + |ON t1.a = t2.d2 + |""".stripMargin + checkResult(sql) + } + + def testSimpleJoin(): Unit = { + // inner join + val sql1 = "SELECT * FROM T1, T2 WHERE a1 = a2" + checkResult(sql1) + + // left join + val sql2 = "SELECT * FROM T1 LEFT JOIN T2 on a1 = a2" + checkResult(sql2) + + // right join + val sql3 = "SELECT * FROM T1 RIGHT JOIN T2 on a1 = a2" + checkResult(sql3) + + // semi join + val sql4 = "SELECT * FROM T1 WHERE a1 IN (SELECT a2 FROM T2)" + checkResult(sql4) + + // anti join + val sql5 = "SELECT * FROM T1 WHERE a1 NOT IN (SELECT a2 FROM T2 where a2 = a1)" + checkResult(sql5) + } + + def checkResult(sql: String): Unit = { + tEnv.getConfig + .set( + OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY, + OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE) + val expected = executeQuery(sql) + tEnv.getConfig + .set( + OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY, + OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.AUTO) + checkResult(sql, expected) + tEnv.getConfig + .set( + OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY, + OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.RUNTIME_ONLY) + checkResult(sql, expected) + } +} + +object AdaptiveBroadcastJoinITCase { + + def generateRandomData(): Seq[Row] = { + val data = new java.util.ArrayList[Row]() + val numRows = Random.nextInt(30) + lazy val strs = Seq("adaptive", "join", "itcase") + for (x <- 0 until numRows) { + data.add( + BatchTestBase.row(x.toLong, Random.nextLong(), strs(Random.nextInt(3)), Random.nextLong())) + } + data + } + + lazy val rowType = + new RowTypeInfo(LONG_TYPE_INFO, LONG_TYPE_INFO, STRING_TYPE_INFO, LONG_TYPE_INFO) + lazy val nullables: Array[Boolean] = Array(true, true, true, true) + + lazy val dataT: Seq[Row] = generateRandomData() + lazy val dataT1: Seq[Row] = generateRandomData() + lazy val dataT2: Seq[Row] = generateRandomData() + lazy val dataT3: Seq[Row] = generateRandomData() +} diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala index 2997ee062e0..794cb558c74 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala @@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeutils.TypeComparator import org.apache.flink.api.dag.Transformation import org.apache.flink.api.java.typeutils.GenericTypeInfo import org.apache.flink.streaming.api.transformations.{LegacySinkTransformation, OneInputTransformation, TwoInputTransformation} +import org.apache.flink.table.api.config.OptimizerConfigOptions import org.apache.flink.table.api.internal.{StatementSetImpl, TableEnvironmentInternal} import org.apache.flink.table.planner.delegation.PlannerBase import 
org.apache.flink.table.planner.expressions.utils.FuncWithOpen @@ -203,6 +204,9 @@ class JoinITCase extends BatchTestBase { @TestTemplate def testLongHashJoinGenerator(): Unit = { if (expectedJoinType == HashJoin) { + tEnv.getConfig.set( + OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY, + OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE) val sink = (new CollectRowTableSink).configure(Array("c"), Array(Types.STRING)) tEnv.asInstanceOf[TableEnvironmentInternal].registerTableSinkInternal("outputTable", sink) val stmtSet = tEnv.createStatementSet() diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala index 06391a710de..4da7813a81e 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala @@ -36,7 +36,8 @@ import org.apache.flink.streaming.api.legacy.io.CollectionInputFormat import org.apache.flink.table.api._ import org.apache.flink.table.api.bridge.java.{StreamTableEnvironment => JavaStreamTableEnv} import org.apache.flink.table.api.bridge.scala.{StreamTableEnvironment => ScalaStreamTableEnv} -import org.apache.flink.table.api.config.ExecutionConfigOptions +import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions} +import org.apache.flink.table.api.config.OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE import org.apache.flink.table.api.internal.{StatementSetImpl, TableEnvironmentImpl, TableEnvironmentInternal, TableImpl} import org.apache.flink.table.api.typeutils.CaseClassTypeInfo import org.apache.flink.table.catalog._ @@ -1329,6 +1330,9 @@ abstract class TableTestUtil( tableEnv.getConfig.set( BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_ENABLED, Boolean.box(false)) + tableEnv.getConfig.set( + OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY, + NONE) private val env: StreamExecutionEnvironment = getPlanner.getExecEnv diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/AdaptiveBroadcastJoinOptimizationStrategy.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/AdaptiveBroadcastJoinOptimizationStrategy.java new file mode 100644 index 00000000000..1c33ed9fb60 --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/AdaptiveBroadcastJoinOptimizationStrategy.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.strategy; + +import org.apache.flink.configuration.ReadableConfig; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingResultInfo; +import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished; +import org.apache.flink.streaming.api.graph.StreamGraphContext; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.runtime.operators.join.FlinkJoinType; +import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkState; + +/** The stream graph optimization strategy of adaptive broadcast join. */ +public class AdaptiveBroadcastJoinOptimizationStrategy + extends BaseAdaptiveJoinOperatorOptimizationStrategy { + private static final Logger LOG = + LoggerFactory.getLogger(AdaptiveBroadcastJoinOptimizationStrategy.class); + + private Long broadcastThreshold; + + private Map<Integer, Map<Integer, Long>> aggregatedInputBytesByTypeNumberAndNodeId; + + @Override + public void initialize(StreamGraphContext context) { + ReadableConfig config = context.getStreamGraph().getConfiguration(); + broadcastThreshold = + config.get(OptimizerConfigOptions.TABLE_OPTIMIZER_BROADCAST_JOIN_THRESHOLD); + aggregatedInputBytesByTypeNumberAndNodeId = new HashMap<>(); + } + + @Override + public boolean onOperatorsFinished( + OperatorsFinished operatorsFinished, StreamGraphContext context) { + visitDownstreamAdaptiveJoinNode(operatorsFinished, context); + + return true; + } + + @Override + protected void tryOptimizeAdaptiveJoin( + OperatorsFinished operatorsFinished, + StreamGraphContext context, + ImmutableStreamNode adaptiveJoinNode, + List<ImmutableStreamEdge> upstreamStreamEdges, + AdaptiveJoin adaptiveJoin) { + for (ImmutableStreamEdge upstreamEdge : upstreamStreamEdges) { + IntermediateDataSetID relatedDataSetId = + context.getConsumedIntermediateDataSetId(upstreamEdge.getEdgeId()); + long producedBytes = + operatorsFinished.getResultInfoMap().get(upstreamEdge.getSourceId()).stream() + .filter( + blockingResultInfo -> + relatedDataSetId.equals( + blockingResultInfo.getResultId())) + .mapToLong(BlockingResultInfo::getNumBytesProduced) + .sum(); + aggregatedInputBytesByTypeNumber( + adaptiveJoinNode, upstreamEdge.getTypeNumber(), producedBytes); + } + + // If all upstream nodes have finished, we attempt to optimize the AdaptiveJoin node. 
+ if (context.areAllUpstreamNodesFinished(adaptiveJoinNode)) { + Long leftInputSize = + aggregatedInputBytesByTypeNumberAndNodeId.get(adaptiveJoinNode.getId()).get(1); + checkState( + leftInputSize != null, + "Left input bytes of adaptive join [%s] is unknown, which is unexpected.", + adaptiveJoinNode.getId()); + Long rightInputSize = + aggregatedInputBytesByTypeNumberAndNodeId.get(adaptiveJoinNode.getId()).get(2); + checkState( + rightInputSize != null, + "Right input bytes of adaptive join [%s] is unknown, which is unexpected.", + adaptiveJoinNode.getId()); + + boolean leftSizeSmallerThanThreshold = leftInputSize <= broadcastThreshold; + boolean rightSizeSmallerThanThreshold = rightInputSize <= broadcastThreshold; + boolean leftSmallerThanRight = leftInputSize < rightInputSize; + FlinkJoinType joinType = adaptiveJoin.getJoinType(); + boolean canBeBroadcast; + boolean leftIsBuild; + switch (joinType) { + case RIGHT: + // For a right outer join, if the left side can be broadcast, then the left side + // is + // always the build side; otherwise, the smaller side is the build side. + canBeBroadcast = leftSizeSmallerThanThreshold; + leftIsBuild = true; + break; + case INNER: + canBeBroadcast = leftSizeSmallerThanThreshold || rightSizeSmallerThanThreshold; + leftIsBuild = leftSmallerThanRight; + break; + case LEFT: + case SEMI: + case ANTI: + // For left outer / semi / anti join, if the right side can be broadcast, then + // the + // right side is always the build side; otherwise, the smaller side is the build + // side. + canBeBroadcast = rightSizeSmallerThanThreshold; + leftIsBuild = false; + break; + case FULL: + default: + throw new RuntimeException(String.format("Unexpected join type %s.", joinType)); + } + + boolean isBroadcast = false; + if (canBeBroadcast) { + isBroadcast = + tryModifyStreamEdgesForBroadcastJoin( + adaptiveJoinNode.getInEdges(), context, leftIsBuild); + + if (isBroadcast) { + LOG.info( + "The {} input data size of the join node [{}] is small enough, " + + "adaptively convert it to a broadcast hash join. Broadcast " + + "threshold bytes: {}, left input bytes: {}, right input bytes: {}.", + leftIsBuild ? "left" : "right", + adaptiveJoinNode.getId(), + broadcastThreshold, + leftInputSize, + rightInputSize); + } + } + adaptiveJoin.markAsBroadcastJoin( + isBroadcast, isBroadcast ? 
leftIsBuild : leftSmallerThanRight); + + aggregatedInputBytesByTypeNumberAndNodeId.remove(adaptiveJoinNode.getId()); + } + } + + private void aggregatedInputBytesByTypeNumber( + ImmutableStreamNode adaptiveJoinNode, int typeNumber, long producedBytes) { + Integer streamNodeId = adaptiveJoinNode.getId(); + + aggregatedInputBytesByTypeNumberAndNodeId + .computeIfAbsent(streamNodeId, k -> new HashMap<>()) + .merge(typeNumber, producedBytes, Long::sum); + } + + private List<ImmutableStreamEdge> filterEdges( + List<ImmutableStreamEdge> inEdges, int typeNumber) { + return inEdges.stream() + .filter(e -> e.getTypeNumber() == typeNumber) + .collect(Collectors.toList()); + } + + private List<StreamEdgeUpdateRequestInfo> generateStreamEdgeUpdateRequestInfos( + List<ImmutableStreamEdge> modifiedEdges, StreamPartitioner<?> outputPartitioner) { + List<StreamEdgeUpdateRequestInfo> streamEdgeUpdateRequestInfos = new ArrayList<>(); + for (ImmutableStreamEdge streamEdge : modifiedEdges) { + StreamEdgeUpdateRequestInfo streamEdgeUpdateRequestInfo = + new StreamEdgeUpdateRequestInfo( + streamEdge.getEdgeId(), + streamEdge.getSourceId(), + streamEdge.getTargetId()) + .withOutputPartitioner(outputPartitioner); + streamEdgeUpdateRequestInfos.add(streamEdgeUpdateRequestInfo); + } + + return streamEdgeUpdateRequestInfos; + } + + private boolean tryModifyStreamEdgesForBroadcastJoin( + List<ImmutableStreamEdge> inEdges, StreamGraphContext context, boolean leftIsBuild) { + List<StreamEdgeUpdateRequestInfo> modifiedBuildSideEdges = + generateStreamEdgeUpdateRequestInfos( + filterEdges(inEdges, leftIsBuild ? 1 : 2), new BroadcastPartitioner<>()); + List<StreamEdgeUpdateRequestInfo> modifiedProbeSideEdges = + generateStreamEdgeUpdateRequestInfos( + filterEdges(inEdges, leftIsBuild ? 2 : 1), new ForwardPartitioner<>()); + modifiedBuildSideEdges.addAll(modifiedProbeSideEdges); + + return context.modifyStreamEdge(modifiedBuildSideEdges); + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/BaseAdaptiveJoinOperatorOptimizationStrategy.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/BaseAdaptiveJoinOperatorOptimizationStrategy.java new file mode 100644 index 00000000000..847ce5935dd --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/BaseAdaptiveJoinOperatorOptimizationStrategy.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.strategy; + +import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished; +import org.apache.flink.runtime.scheduler.adaptivebatch.StreamGraphOptimizationStrategy; +import org.apache.flink.streaming.api.graph.StreamGraphContext; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode; +import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** The base stream graph optimization strategy class for adaptive join operator. */ +public abstract class BaseAdaptiveJoinOperatorOptimizationStrategy + implements StreamGraphOptimizationStrategy { + + protected void visitDownstreamAdaptiveJoinNode( + OperatorsFinished operatorsFinished, StreamGraphContext context) { + ImmutableStreamGraph streamGraph = context.getStreamGraph(); + List<Integer> finishedStreamNodeIds = operatorsFinished.getFinishedStreamNodeIds(); + Map<ImmutableStreamNode, List<ImmutableStreamEdge>> joinNodesWithInEdges = new HashMap<>(); + for (Integer finishedStreamNodeId : finishedStreamNodeIds) { + for (ImmutableStreamEdge streamEdge : + streamGraph.getStreamNode(finishedStreamNodeId).getOutEdges()) { + ImmutableStreamNode downstreamNode = + streamGraph.getStreamNode(streamEdge.getTargetId()); + if (downstreamNode.getOperatorFactory() instanceof AdaptiveJoin) { + joinNodesWithInEdges + .computeIfAbsent(downstreamNode, k -> new ArrayList<>()) + .add(streamEdge); + } + } + } + for (ImmutableStreamNode joinNode : joinNodesWithInEdges.keySet()) { + tryOptimizeAdaptiveJoin( + operatorsFinished, + context, + joinNode, + joinNodesWithInEdges.get(joinNode), + (AdaptiveJoin) joinNode.getOperatorFactory()); + } + } + + abstract void tryOptimizeAdaptiveJoin( + OperatorsFinished operatorsFinished, + StreamGraphContext context, + ImmutableStreamNode adaptiveJoinNode, + List<ImmutableStreamEdge> upstreamStreamEdges, + AdaptiveJoin adaptiveJoin); +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/PostProcessAdaptiveJoinStrategy.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/PostProcessAdaptiveJoinStrategy.java new file mode 100644 index 00000000000..6e898bc9718 --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/PostProcessAdaptiveJoinStrategy.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.strategy; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.configuration.ReadableConfig; +import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished; +import org.apache.flink.streaming.api.graph.StreamGraphContext; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.api.graph.util.StreamNodeUpdateRequestInfo; +import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayList; +import java.util.List; + +/** + * The post-processing phase of adaptive join optimization, which must be placed at the end of all + * adaptive join optimization strategies. This is necessary because certain operations, like + * 'reorder inputs', can influence how adaptive broadcast join or skewed join determine the left and + * right sides. + */ +public class PostProcessAdaptiveJoinStrategy extends BaseAdaptiveJoinOperatorOptimizationStrategy { + + @Override + public boolean onOperatorsFinished( + OperatorsFinished operatorsFinished, StreamGraphContext context) { + visitDownstreamAdaptiveJoinNode(operatorsFinished, context); + + return true; + } + + @Override + protected void tryOptimizeAdaptiveJoin( + OperatorsFinished operatorsFinished, + StreamGraphContext context, + ImmutableStreamNode adaptiveJoinNode, + List<ImmutableStreamEdge> upstreamStreamEdges, + AdaptiveJoin adaptiveJoin) { + if (context.areAllUpstreamNodesFinished(adaptiveJoinNode)) { + // For hash join, reorder the join node inputs so the build side is read first. + if (adaptiveJoin.shouldReorderInputs()) { + if (!context.modifyStreamEdge( + generateStreamEdgeUpdateRequestInfosForInputsReordered( + adaptiveJoinNode)) + || !context.modifyStreamNode( + generateStreamNodeUpdateRequestInfosForInputsReordered( + adaptiveJoinNode))) { + throw new RuntimeException( + "Unexpected error occurs while reordering the inputs " + + "of the adaptive join node, potentially leading to data inaccuracies. " + + "Exceptions will be thrown."); + } + } + + // Generate OperatorFactory for adaptive join operator after inputs are reordered. + ReadableConfig config = context.getStreamGraph().getConfiguration(); + ClassLoader userClassLoader = context.getStreamGraph().getUserClassLoader(); + adaptiveJoin.genOperatorFactory(userClassLoader, config); + } + } + + private static List<StreamEdgeUpdateRequestInfo> + generateStreamEdgeUpdateRequestInfosForInputsReordered( + ImmutableStreamNode adaptiveJoinNode) { + List<StreamEdgeUpdateRequestInfo> streamEdgeUpdateRequestInfos = new ArrayList<>(); + for (ImmutableStreamEdge inEdge : adaptiveJoinNode.getInEdges()) { + StreamEdgeUpdateRequestInfo streamEdgeUpdateRequestInfo = + new StreamEdgeUpdateRequestInfo( + inEdge.getEdgeId(), inEdge.getSourceId(), inEdge.getTargetId()); + streamEdgeUpdateRequestInfo.withTypeNumber(inEdge.getTypeNumber() == 1 ? 
2 : 1); + streamEdgeUpdateRequestInfos.add(streamEdgeUpdateRequestInfo); + } + return streamEdgeUpdateRequestInfos; + } + + private List<StreamNodeUpdateRequestInfo> + generateStreamNodeUpdateRequestInfosForInputsReordered( + ImmutableStreamNode modifiedNode) { + List<StreamNodeUpdateRequestInfo> streamEdgeUpdateRequestInfos = new ArrayList<>(); + + TypeSerializer<?>[] typeSerializers = modifiedNode.getTypeSerializersIn(); + Preconditions.checkState( + typeSerializers.length == 2, + String.format( + "Adaptive join currently only supports two " + + "inputs, but the join node [%s] has received %s inputs.", + modifiedNode.getId(), typeSerializers.length)); + TypeSerializer<?>[] swappedTypeSerializers = new TypeSerializer<?>[2]; + swappedTypeSerializers[0] = typeSerializers[1]; + swappedTypeSerializers[1] = typeSerializers[0]; + StreamNodeUpdateRequestInfo requestInfo = + new StreamNodeUpdateRequestInfo(modifiedNode.getId()) + .withTypeSerializersIn(swappedTypeSerializers); + streamEdgeUpdateRequestInfos.add(requestInfo); + + return streamEdgeUpdateRequestInfos; + } +}
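Note on trying this out: with this commit the default of table.optimizer.adaptive-broadcast-join.strategy becomes AUTO, and BatchPlanner only registers AdaptiveBroadcastJoinOptimizationStrategy and PostProcessAdaptiveJoinStrategy when the value is not NONE. The sketch below shows how the option (together with the existing table.optimizer.join.broadcast-threshold) could be set from the Table API; only the two config options come from this commit, while the table definitions and the query are made-up examples.

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;
    import org.apache.flink.table.api.config.OptimizerConfigOptions;

    public class AdaptiveBroadcastJoinExample {
        public static void main(String[] args) {
            // The option is a batch-only optimizer option, so use a batch table environment.
            TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inBatchMode());

            // AUTO is the new default; RUNTIME_ONLY restricts the optimization to runtime,
            // NONE switches the adaptive broadcast join off.
            tEnv.getConfig().set(
                    OptimizerConfigOptions.TABLE_OPTIMIZER_ADAPTIVE_BROADCAST_JOIN_STRATEGY,
                    OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.RUNTIME_ONLY);

            // The runtime decision compares the actually produced bytes of each join input
            // against this existing threshold.
            tEnv.getConfig().set(
                    OptimizerConfigOptions.TABLE_OPTIMIZER_BROADCAST_JOIN_THRESHOLD,
                    10L * 1024 * 1024);

            // Made-up bounded sources, just to have a join the strategy could rewrite.
            tEnv.executeSql(
                    "CREATE TABLE T1 (a1 BIGINT, b1 BIGINT) WITH "
                            + "('connector' = 'datagen', 'number-of-rows' = '1000000')");
            tEnv.executeSql(
                    "CREATE TABLE T2 (a2 BIGINT, b2 BIGINT) WITH "
                            + "('connector' = 'datagen', 'number-of-rows' = '100')");
            tEnv.executeSql("SELECT * FROM T1 JOIN T2 ON a1 = a2").print();
        }
    }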
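The heart of AdaptiveBroadcastJoinOptimizationStrategy is the per-join-type decision of whether a broadcast conversion is allowed and which input becomes the build side. Below is a standalone restatement of that switch, for illustration only: the Decision class and method names are invented, and the real code drives AdaptiveJoin.markAsBroadcastJoin and the stream edge updates rather than returning a value.

    import org.apache.flink.table.runtime.operators.join.FlinkJoinType;

    /** Illustration-only restatement of the broadcast/build-side decision. */
    final class BroadcastDecisionSketch {

        /**
         * When canBroadcast is true, leftIsBuild marks the side that is broadcast and built;
         * otherwise the smaller side is used as the build side (mirroring markAsBroadcastJoin).
         */
        static final class Decision {
            final boolean canBroadcast;
            final boolean leftIsBuild;

            Decision(boolean canBroadcast, boolean leftIsBuild) {
                this.canBroadcast = canBroadcast;
                this.leftIsBuild = leftIsBuild;
            }
        }

        static Decision decide(
                FlinkJoinType joinType, long leftBytes, long rightBytes, long threshold) {
            boolean leftFits = leftBytes <= threshold;
            boolean rightFits = rightBytes <= threshold;
            boolean leftSmaller = leftBytes < rightBytes;
            switch (joinType) {
                case RIGHT:
                    // Only the left side may be broadcast; if so it is always the build side.
                    return new Decision(leftFits, true);
                case INNER:
                    // Either side may be broadcast; the smaller side becomes the build side.
                    return new Decision(leftFits || rightFits, leftSmaller);
                case LEFT:
                case SEMI:
                case ANTI:
                    // Only the right side may be broadcast; if so it is always the build side.
                    return new Decision(rightFits, false);
                default:
                    // FULL and other join types are rejected by the strategy.
                    throw new IllegalStateException("Unexpected join type " + joinType);
            }
        }
    }

For example, decide(FlinkJoinType.LEFT, 200 << 20, 5 << 20, 10 << 20) reports that the 5 MB right side can be broadcast and built, matching the leftIsBuild = false branch above.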
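PostProcessAdaptiveJoinStrategy's "reorder inputs" step amounts to swapping the join node's two inputs: every in-edge's type number flips between 1 and 2, and the two input serializers are exchanged before the operator factory is generated. A minimal sketch of that swap on plain arrays, with the StreamGraphContext request-info plumbing omitted:

    import org.apache.flink.api.common.typeutils.TypeSerializer;

    /** Illustration-only sketch of the input swap used when the hash join build side must be read first. */
    final class InputReorderSketch {

        /** Input 1 becomes input 2 and vice versa, applied to each in-edge of the join node. */
        static int swappedTypeNumber(int typeNumber) {
            return typeNumber == 1 ? 2 : 1;
        }

        /** Exchange the two input serializers of the join node. */
        static TypeSerializer<?>[] swapSerializers(TypeSerializer<?>[] typeSerializersIn) {
            if (typeSerializersIn.length != 2) {
                throw new IllegalStateException(
                        "Adaptive join expects exactly two inputs, got " + typeSerializersIn.length);
            }
            return new TypeSerializer<?>[] {typeSerializersIn[1], typeSerializersIn[0]};
        }
    }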