This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit df021354008135292c7c04785396cb0f0b0870ac
Author: noorall <863485...@qq.com>
AuthorDate: Thu Dec 26 12:11:07 2024 +0800

    [FLINK-36629][table-planner] Introduce the AdaptiveSkewedJoinOptimizationStrategy
---
 .../AdaptiveSkewedJoinOptimizationStrategy.java    | 311 +++++++++++++++++++++
 1 file changed, 311 insertions(+)

diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/AdaptiveSkewedJoinOptimizationStrategy.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/AdaptiveSkewedJoinOptimizationStrategy.java
new file mode 100644
index 00000000000..e9c628e20bd
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/strategy/AdaptiveSkewedJoinOptimizationStrategy.java
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.strategy;
+
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import 
org.apache.flink.runtime.scheduler.adaptivebatch.AllToAllBlockingResultInfo;
+import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingResultInfo;
+import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished;
+import org.apache.flink.streaming.api.graph.StreamGraphContext;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode;
+import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
+import 
org.apache.flink.streaming.runtime.partitioner.ForwardForConsecutiveHashPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static 
org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.computeSkewThreshold;
+import static 
org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.median;
+import static 
org.apache.flink.table.runtime.strategy.AdaptiveJoinOptimizationUtils.filterEdges;
+import static 
org.apache.flink.table.runtime.strategy.AdaptiveJoinOptimizationUtils.isBroadcastJoin;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** The stream graph optimization strategy of adaptive skewed join. */
+public class AdaptiveSkewedJoinOptimizationStrategy
+        extends BaseAdaptiveJoinOperatorOptimizationStrategy {
+    private static final Logger LOG =
+            
LoggerFactory.getLogger(AdaptiveSkewedJoinOptimizationStrategy.class);
+
+    private static final int LEFT_INPUT_TYPE_NUMBER = 1;
+    private static final int RIGHT_INPUT_TYPE_NUMBER = 2;
+
+    private Map<Integer, Map<Integer, long[]>> 
aggregatedProducedBytesByTypeNumberAndNodeId;
+
+    private OptimizerConfigOptions.AdaptiveSkewedJoinOptimizationStrategy
+            adaptiveSkewedJoinOptimizationStrategy;
+    private long skewedThresholdInBytes;
+    private double skewedFactor;
+
+    @Override
+    public void initialize(StreamGraphContext context) {
+        ReadableConfig config = context.getStreamGraph().getConfiguration();
+        aggregatedProducedBytesByTypeNumberAndNodeId = new HashMap<>();
+        adaptiveSkewedJoinOptimizationStrategy =
+                config.get(
+                        OptimizerConfigOptions
+                                
.TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_STRATEGY);
+        skewedFactor =
+                config.get(
+                        OptimizerConfigOptions
+                                
.TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_SKEWED_FACTOR);
+        skewedThresholdInBytes =
+                config.get(
+                                OptimizerConfigOptions
+                                        
.TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_SKEWED_THRESHOLD)
+                        .getBytes();
+    }
+
+    @Override
+    public boolean onOperatorsFinished(
+            OperatorsFinished operatorsFinished, StreamGraphContext context) 
throws Exception {
+        visitDownstreamAdaptiveJoinNode(operatorsFinished, context);
+
+        return true;
+    }
+
+    @Override
+    void tryOptimizeAdaptiveJoin(
+            OperatorsFinished operatorsFinished,
+            StreamGraphContext context,
+            ImmutableStreamNode adaptiveJoinNode,
+            List<ImmutableStreamEdge> upstreamStreamEdges,
+            AdaptiveJoin adaptiveJoin) {
+        if (!canPerformOptimization(adaptiveJoinNode)) {
+            return;
+        }
+        for (ImmutableStreamEdge edge : upstreamStreamEdges) {
+            BlockingResultInfo resultInfo = 
getBlockingResultInfo(operatorsFinished, context, edge);
+            checkState(resultInfo instanceof AllToAllBlockingResultInfo);
+            aggregatedProducedBytesByTypeNumber(
+                    adaptiveJoinNode,
+                    edge.getTypeNumber(),
+                    ((AllToAllBlockingResultInfo) 
resultInfo).getAggregatedSubpartitionBytes());
+        }
+        if (context.areAllUpstreamNodesFinished(adaptiveJoinNode)) {
+            applyAdaptiveSkewedJoinOptimization(
+                    context, adaptiveJoinNode, adaptiveJoin.getJoinType());
+            freeNodeStatistic(adaptiveJoinNode.getId());
+        }
+    }
+
+    private boolean canPerformOptimization(ImmutableStreamNode 
adaptiveJoinNode) {
+        // For broadcast joins, especially those generated by
+        // AdaptiveBroadcastJoinOptimizationStrategy, skip perform 
optimization to
+        // avoid unexpected problems.
+        if (isBroadcastJoin(adaptiveJoinNode)) {
+            return false;
+        }
+        if (adaptiveSkewedJoinOptimizationStrategy
+                == 
OptimizerConfigOptions.AdaptiveSkewedJoinOptimizationStrategy.AUTO) {
+            return !existExactForwardOutEdge(adaptiveJoinNode.getOutEdges())
+                    && 
!existForwardForConsecutiveHashOutEdge(adaptiveJoinNode.getOutEdges());
+        } else if (adaptiveSkewedJoinOptimizationStrategy
+                == 
OptimizerConfigOptions.AdaptiveSkewedJoinOptimizationStrategy.FORCED) {
+            return !existExactForwardOutEdge(adaptiveJoinNode.getOutEdges());
+        } else {
+            return false;
+        }
+    }
+
+    private static BlockingResultInfo getBlockingResultInfo(
+            OperatorsFinished operatorsFinished,
+            StreamGraphContext context,
+            ImmutableStreamEdge edge) {
+        List<BlockingResultInfo> resultInfos =
+                operatorsFinished.getResultInfoMap().get(edge.getSourceId());
+        IntermediateDataSetID intermediateDataSetId =
+                context.getConsumedIntermediateDataSetId(edge.getEdgeId());
+        for (BlockingResultInfo result : resultInfos) {
+            if (result.getResultId().equals(intermediateDataSetId)) {
+                return result;
+            }
+        }
+        throw new IllegalStateException(
+                "No matching BlockingResultInfo found for edge ID: " + 
edge.getEdgeId());
+    }
+
+    private void aggregatedProducedBytesByTypeNumber(
+            ImmutableStreamNode adaptiveJoinNode, int typeNumber, List<Long> 
subpartitionBytes) {
+        Integer streamNodeId = adaptiveJoinNode.getId();
+        long[] aggregatedSubpartitionBytes =
+                aggregatedProducedBytesByTypeNumberAndNodeId
+                        .computeIfAbsent(streamNodeId, k -> new HashMap<>())
+                        .computeIfAbsent(
+                                typeNumber, (ignore) -> new 
long[subpartitionBytes.size()]);
+        checkState(subpartitionBytes.size() == 
aggregatedSubpartitionBytes.length);
+        for (int i = 0; i < subpartitionBytes.size(); i++) {
+            aggregatedSubpartitionBytes[i] += subpartitionBytes.get(i);
+        }
+    }
+
+    private void applyAdaptiveSkewedJoinOptimization(
+            StreamGraphContext context,
+            ImmutableStreamNode adaptiveJoinNode,
+            FlinkJoinType joinType) {
+        long[] leftInputSize =
+                aggregatedProducedBytesByTypeNumberAndNodeId
+                        .get(adaptiveJoinNode.getId())
+                        .get(LEFT_INPUT_TYPE_NUMBER);
+        checkState(
+                leftInputSize != null,
+                "Left input bytes of adaptive join [%s] is unknown, which is 
unexpected.",
+                adaptiveJoinNode.getId());
+        long[] rightInputSize =
+                aggregatedProducedBytesByTypeNumberAndNodeId
+                        .get(adaptiveJoinNode.getId())
+                        .get(RIGHT_INPUT_TYPE_NUMBER);
+        checkState(
+                rightInputSize != null,
+                "Right input bytes of adaptive join [%s] is unknown, which is 
unexpected.",
+                adaptiveJoinNode.getId());
+
+        long leftSkewedThreshold =
+                computeSkewThreshold(median(leftInputSize), skewedFactor, 
skewedThresholdInBytes);
+        long rightSkewedThreshold =
+                computeSkewThreshold(median(rightInputSize), skewedFactor, 
skewedThresholdInBytes);
+
+        boolean isLeftOptimizable = false;
+        boolean isRightOptimizable = false;
+        switch (joinType) {
+            case RIGHT:
+                isRightOptimizable = true;
+                break;
+            case INNER:
+                isLeftOptimizable = true;
+                isRightOptimizable = true;
+                break;
+            case LEFT:
+            case SEMI:
+            case ANTI:
+                isLeftOptimizable = true;
+                break;
+            case FULL:
+            default:
+                throw new IllegalStateException(
+                        String.format("Unexpected join type %s.", joinType));
+        }
+
+        isLeftOptimizable =
+                isLeftOptimizable
+                        & existBytesLargerThanThreshold(leftInputSize, 
leftSkewedThreshold);
+        isRightOptimizable =
+                isRightOptimizable
+                        & existBytesLargerThanThreshold(rightInputSize, 
rightSkewedThreshold);
+
+        if (isLeftOptimizable) {
+            boolean isModificationSucceed =
+                    tryModifyInputAndOutputEdges(context, adaptiveJoinNode, 
LEFT_INPUT_TYPE_NUMBER);
+            LOG.info(
+                    "Apply skewed join optimization {} for left input of node 
{}.",
+                    isModificationSucceed ? "succeeded" : "failed",
+                    adaptiveJoinNode.getId());
+        }
+        if (isRightOptimizable) {
+            boolean isModificationSucceed =
+                    tryModifyInputAndOutputEdges(
+                            context, adaptiveJoinNode, 
RIGHT_INPUT_TYPE_NUMBER);
+            LOG.info(
+                    "Apply skewed join optimization {} for right input of node 
{}.",
+                    isModificationSucceed ? "succeeded" : "failed",
+                    adaptiveJoinNode.getId());
+        }
+    }
+
+    private static boolean tryModifyInputAndOutputEdges(
+            StreamGraphContext context, ImmutableStreamNode adaptiveJoinNode, 
int typeNumber) {
+        List<StreamEdgeUpdateRequestInfo> modifiedRequests = new ArrayList<>();
+        // Modify the IntraInputKeyCorrelation of all input edges with the 
specified typeNumber to
+        // false.
+        modifiedRequests.addAll(
+                generateCorrelationModificationRequestInfos(
+                        filterEdges(adaptiveJoinNode.getInEdges(), 
typeNumber)));
+        // Modify ForwardForConsecutiveHashPartitioner of the output edges to 
HashPartitioner
+        modifiedRequests.addAll(
+                generateForwardPartitionerModificationRequestInfos(
+                        adaptiveJoinNode.getOutEdges(), context));
+        return context.modifyStreamEdge(modifiedRequests);
+    }
+
+    private static List<StreamEdgeUpdateRequestInfo> 
generateCorrelationModificationRequestInfos(
+            List<ImmutableStreamEdge> streamEdges) {
+        List<StreamEdgeUpdateRequestInfo> streamEdgeUpdateRequestInfos = new 
ArrayList<>();
+        for (ImmutableStreamEdge edge : streamEdges) {
+            streamEdgeUpdateRequestInfos.add(
+                    new StreamEdgeUpdateRequestInfo(
+                                    edge.getEdgeId(), edge.getSourceId(), 
edge.getTargetId())
+                            .withIntraInputKeyCorrelated(false));
+        }
+        return streamEdgeUpdateRequestInfos;
+    }
+
+    private static List<StreamEdgeUpdateRequestInfo>
+            generateForwardPartitionerModificationRequestInfos(
+                    List<ImmutableStreamEdge> streamEdges, StreamGraphContext 
context) {
+        List<StreamEdgeUpdateRequestInfo> streamEdgeUpdateRequestInfos = new 
ArrayList<>();
+        for (ImmutableStreamEdge edge : streamEdges) {
+            if (edge.isForwardForConsecutiveHashEdge()) {
+                StreamPartitioner<?> partitioner =
+                        checkNotNull(
+                                context.getOutputPartitioner(
+                                        edge.getEdgeId(), edge.getSourceId(), 
edge.getTargetId()));
+                StreamPartitioner<?> newPartitioner =
+                        ((ForwardForConsecutiveHashPartitioner<?>) partitioner)
+                                .getHashPartitioner();
+                streamEdgeUpdateRequestInfos.add(
+                        new StreamEdgeUpdateRequestInfo(
+                                        edge.getEdgeId(), edge.getSourceId(), 
edge.getTargetId())
+                                .withOutputPartitioner(newPartitioner));
+            }
+        }
+        return streamEdgeUpdateRequestInfos;
+    }
+
+    private void freeNodeStatistic(Integer nodeId) {
+        aggregatedProducedBytesByTypeNumberAndNodeId.remove(nodeId);
+    }
+
+    private static boolean existBytesLargerThanThreshold(long[] inputBytes, 
long threshold) {
+        for (long byteSize : inputBytes) {
+            if (byteSize > threshold) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    private static boolean existExactForwardOutEdge(List<ImmutableStreamEdge> 
edges) {
+        return 
edges.stream().anyMatch(ImmutableStreamEdge::isExactForwardEdge);
+    }
+
+    private static boolean 
existForwardForConsecutiveHashOutEdge(List<ImmutableStreamEdge> edges) {
+        return 
edges.stream().anyMatch(ImmutableStreamEdge::isForwardForConsecutiveHashEdge);
+    }
+}

Reply via email to