This is an automated email from the ASF dual-hosted git repository.

junrui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 95f9a1605321a55fc22bb81fa503d2a9eac9d433 Author: JunRuiLee <[email protected]> AuthorDate: Tue Dec 24 20:00:27 2024 +0800 [FLINK-36067][runtime] Support optimize stream graph based on input info. --- .../runtime/executiongraph/IntermediateResult.java | 25 ++++ .../executiongraph/IntermediateResultInfo.java | 16 ++- .../IntermediateResultPartition.java | 2 +- .../VertexInputInfoComputationUtils.java | 32 ++++-- .../runtime/jobgraph/IntermediateDataSet.java | 12 ++ .../adaptivebatch/AbstractBlockingResultInfo.java | 13 ++- .../adaptivebatch/AdaptiveBatchScheduler.java | 68 ++++++++++- .../AdaptiveExecutionHandlerFactory.java | 4 +- .../adaptivebatch/AllToAllBlockingResultInfo.java | 41 ++++++- .../adaptivebatch/BlockingResultInfo.java | 10 ++ .../DefaultAdaptiveExecutionHandler.java | 48 +++++++- ...faultVertexParallelismAndInputInfosDecider.java | 15 ++- .../adaptivebatch/PointwiseBlockingResultInfo.java | 20 +++- .../flink/runtime/shuffle/PartitionDescriptor.java | 2 +- .../streaming/api/graph/AdaptiveGraphManager.java | 7 +- .../api/graph/DefaultStreamGraphContext.java | 33 +++++- .../VertexInputInfoComputationUtilsTest.java | 94 +++++++++++---- ...AdaptiveExecutionPlanSchedulingContextTest.java | 16 ++- .../AllToAllBlockingResultInfoTest.java | 24 ++-- .../DefaultAdaptiveExecutionHandlerTest.java | 128 ++++++++++++++++++++- ...tVertexParallelismAndInputInfosDeciderTest.java | 97 ++++++++++++---- .../api/graph/DefaultStreamGraphContextTest.java | 6 +- .../scheduling/AdaptiveBatchSchedulerITCase.java | 112 ++++++++++++++++++ 23 files changed, 724 insertions(+), 101 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java index f00539b5307..c010057ef4f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java @@ -63,6 +63,7 @@ public class IntermediateResult { private final int numParallelProducers; private final ExecutionPlanSchedulingContext executionPlanSchedulingContext; + private final boolean singleSubpartitionContainsAllData; private int partitionsAssigned; @@ -102,6 +103,8 @@ public class IntermediateResult { this.shuffleDescriptorCache = new HashMap<>(); this.executionPlanSchedulingContext = checkNotNull(executionPlanSchedulingContext); + + this.singleSubpartitionContainsAllData = intermediateDataSet.isBroadcast(); } public boolean areAllConsumerVerticesCreated() { @@ -199,6 +202,16 @@ public class IntermediateResult { return intermediateDataSet.getDistributionPattern(); } + /** + * Determines whether the associated intermediate data set uses a broadcast distribution + * pattern. + * + * <p>A broadcast distribution pattern indicates that all data produced by this intermediate + * data set should be broadcast to every downstream consumer. + * + * @return true if the intermediate data set is using a broadcast distribution pattern; false + * otherwise. + */ public boolean isBroadcast() { return intermediateDataSet.isBroadcast(); } @@ -207,6 +220,18 @@ public class IntermediateResult { return intermediateDataSet.isForward(); } + /** + * Checks if a single subpartition contains all the produced data. This condition indicate that + * the data was intended to be broadcast to all consumers. 
If the decision to broadcast was made + * before the data production, this flag would likely be set accordingly. Conversely, if the + * broadcasting decision was made post-production, this flag will be false. + * + * @return true if a single subpartition contains all the data; false otherwise. + */ + public boolean isSingleSubpartitionContainsAllData() { + return singleSubpartitionContainsAllData; + } + public int getConnectionIndex() { return connectionIndex; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java index 26829893b5a..2c52d340afc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java @@ -29,9 +29,21 @@ public interface IntermediateResultInfo { IntermediateDataSetID getResultId(); /** - * Whether it is a broadcast result. + * Checks whether there is a single subpartition that contains all the produced data. * - * @return whether it is a broadcast result + * @return true if one subpartition that contains all the data; false otherwise. + */ + boolean isSingleSubpartitionContainsAllData(); + + /** + * Determines whether the associated intermediate data set uses a broadcast distribution + * pattern. + * + * <p>A broadcast distribution pattern indicates that all data produced by this intermediate + * data set should be broadcast to every downstream consumer. + * + * @return true if the intermediate data set is using a broadcast distribution pattern; false + * otherwise. */ boolean isBroadcast(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java index 19ce7753e3f..6962e6641a7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java @@ -151,7 +151,7 @@ public class IntermediateResultPartition { } private int computeNumberOfSubpartitionsForDynamicGraph() { - if (totalResult.isBroadcast() || totalResult.isForward()) { + if (totalResult.isSingleSubpartitionContainsAllData() || totalResult.isForward()) { // for dynamic graph and broadcast result, and forward result, we only produced one // subpartition, and all the downstream vertices should consume this subpartition. 
return 1; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java index 680a0bb1634..3c8dfc50e9b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java @@ -84,7 +84,8 @@ public class VertexInputInfoComputationUtils { parallelism, input::getNumSubpartitions, isDynamicGraph, - input.isBroadcast())); + input.isBroadcast(), + input.isSingleSubpartitionContainsAllData())); } } @@ -124,6 +125,7 @@ public class VertexInputInfoComputationUtils { 1, () -> numOfSubpartitionsRetriever.apply(start), isDynamicGraph, + false, false); executionVertexInputInfos.add( new ExecutionVertexInputInfo(index, partitionRange, subpartitionRange)); @@ -145,6 +147,7 @@ public class VertexInputInfoComputationUtils { numConsumers, () -> numOfSubpartitionsRetriever.apply(finalPartitionNum), isDynamicGraph, + false, false); executionVertexInputInfos.add( new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange)); @@ -165,6 +168,7 @@ public class VertexInputInfoComputationUtils { * @param numOfSubpartitionsRetriever a retriever to get the number of subpartitions * @param isDynamicGraph whether is dynamic graph * @param isBroadcast whether the edge is broadcast + * @param isSingleSubpartitionContainsAllData whether single subpartition contains all data * @return the computed {@link JobVertexInputInfo} */ static JobVertexInputInfo computeVertexInputInfoForAllToAll( @@ -172,7 +176,8 @@ public class VertexInputInfoComputationUtils { int targetCount, Function<Integer, Integer> numOfSubpartitionsRetriever, boolean isDynamicGraph, - boolean isBroadcast) { + boolean isBroadcast, + boolean isSingleSubpartitionContainsAllData) { final List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>(); IndexRange partitionRange = new IndexRange(0, sourceCount - 1); for (int i = 0; i < targetCount; ++i) { @@ -182,7 +187,8 @@ public class VertexInputInfoComputationUtils { targetCount, () -> numOfSubpartitionsRetriever.apply(0), isDynamicGraph, - isBroadcast); + isBroadcast, + isSingleSubpartitionContainsAllData); executionVertexInputInfos.add( new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange)); } @@ -199,6 +205,7 @@ public class VertexInputInfoComputationUtils { * @param numOfSubpartitionsSupplier a supplier to get the number of subpartitions * @param isDynamicGraph whether is dynamic graph * @param isBroadcast whether the edge is broadcast + * @param isSingleSubpartitionContainsAllData whether single subpartition contains all data * @return the computed subpartition range */ @VisibleForTesting @@ -207,16 +214,22 @@ public class VertexInputInfoComputationUtils { int numConsumers, Supplier<Integer> numOfSubpartitionsSupplier, boolean isDynamicGraph, - boolean isBroadcast) { + boolean isBroadcast, + boolean isSingleSubpartitionContainsAllData) { int consumerIndex = consumerSubtaskIndex % numConsumers; if (!isDynamicGraph) { return new IndexRange(consumerIndex, consumerIndex); } else { int numSubpartitions = numOfSubpartitionsSupplier.get(); if (isBroadcast) { - // broadcast results have only one subpartition, and be consumed multiple times. 
- checkArgument(numSubpartitions == 1); - return new IndexRange(0, 0); + if (isSingleSubpartitionContainsAllData) { + // early decided broadcast results have only one subpartition, and be consumed + // multiple times. + checkArgument(numSubpartitions == 1); + return new IndexRange(0, 0); + } else { + return new IndexRange(0, numSubpartitions - 1); + } } else { checkArgument(consumerIndex < numConsumers); checkArgument(numConsumers <= numSubpartitions); @@ -246,6 +259,11 @@ public class VertexInputInfoComputationUtils { return intermediateResult.isBroadcast(); } + @Override + public boolean isSingleSubpartitionContainsAllData() { + return intermediateResult.isSingleSubpartitionContainsAllData(); + } + @Override public boolean isPointwise() { return intermediateResult.getConsumingDistributionPattern() diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java index c5d1187d230..ec73f25e283 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java @@ -134,6 +134,18 @@ public class IntermediateDataSet implements java.io.Serializable { } } + public void updateOutputPattern( + DistributionPattern distributionPattern, boolean isBroadcast, boolean isForward) { + checkState(consumers.isEmpty(), "The output job edges have already been added."); + checkState( + numJobEdgesToCreate == 1, + "Modification is not allowed when the subscribing output is reused."); + + this.distributionPattern = distributionPattern; + this.isBroadcast = isBroadcast; + this.isForward = isForward; + } + public void increaseNumJobEdgesToCreate() { this.numJobEdgesToCreate++; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java index 33147bcdc16..3844c50d84a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java @@ -22,6 +22,7 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -44,11 +45,14 @@ abstract class AbstractBlockingResultInfo implements BlockingResultInfo { protected final Map<Integer, long[]> subpartitionBytesByPartitionIndex; AbstractBlockingResultInfo( - IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions) { + IntermediateDataSetID resultId, + int numOfPartitions, + int numOfSubpartitions, + Map<Integer, long[]> subpartitionBytesByPartitionIndex) { this.resultId = checkNotNull(resultId); this.numOfPartitions = numOfPartitions; this.numOfSubpartitions = numOfSubpartitions; - this.subpartitionBytesByPartitionIndex = new HashMap<>(); + this.subpartitionBytesByPartitionIndex = new HashMap<>(subpartitionBytesByPartitionIndex); } @Override @@ -72,4 +76,9 @@ abstract class AbstractBlockingResultInfo implements BlockingResultInfo { int getNumOfRecordedPartitions() { return subpartitionBytesByPartitionIndex.size(); } + + @Override + public Map<Integer, long[]> 
getSubpartitionBytesByPartitionIndex() { + return Collections.unmodifiableMap(subpartitionBytesByPartitionIndex); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java index a46a210446f..966c6eea998 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java @@ -274,9 +274,14 @@ public class AdaptiveBatchScheduler extends DefaultScheduler implements JobGraph // 4. update json plan getExecutionGraph().setJsonPlan(JsonPlanGenerator.generatePlan(getJobGraph())); - // 5. try aggregate subpartition bytes + // 5. In broadcast join optimization, results might be written first with a hash + // method and then read with a broadcast method. Therefore, we need to update the + // result info: + // 1. Update the DistributionPattern to reflect the optimized data distribution. + // 2. Aggregate subpartition bytes when possible for efficiency. for (JobVertex newVertex : newVertices) { for (JobEdge input : newVertex.getInputs()) { + tryUpdateResultInfo(input.getSourceId(), input.getDistributionPattern()); Optional.ofNullable(blockingResultInfos.get(input.getSourceId())) .ifPresent(this::maybeAggregateSubpartitionBytes); } @@ -490,7 +495,8 @@ public class AdaptiveBatchScheduler extends DefaultScheduler implements JobGraph result.getId(), (ignored, resultInfo) -> { if (resultInfo == null) { - resultInfo = createFromIntermediateResult(result); + resultInfo = + createFromIntermediateResult(result, new HashMap<>()); } resultInfo.recordPartitionInfo( partitionId.getPartitionNumber(), partitionBytes); @@ -500,6 +506,16 @@ public class AdaptiveBatchScheduler extends DefaultScheduler implements JobGraph }); } + /** + * Aggregates subpartition bytes if all conditions are met. This method checks whether the + * result info instance is of type {@link AllToAllBlockingResultInfo}, whether all consumer + * vertices are created, and whether all consumer vertices are initialized. If these conditions + * are satisfied, the fine-grained statistic info will not be required by consumer vertices, and + * then we could aggregate the subpartition bytes. + * + * @param resultInfo the BlockingResultInfo instance to potentially aggregate subpartition bytes + * for. + */ private void maybeAggregateSubpartitionBytes(BlockingResultInfo resultInfo) { IntermediateResult intermediateResult = getExecutionGraph().getAllIntermediateResults().get(resultInfo.getResultId()); @@ -937,7 +953,8 @@ public class AdaptiveBatchScheduler extends DefaultScheduler implements JobGraph } } - private static BlockingResultInfo createFromIntermediateResult(IntermediateResult result) { + private static BlockingResultInfo createFromIntermediateResult( + IntermediateResult result, Map<Integer, long[]> subpartitionBytesByPartitionIndex) { checkArgument(result != null); // Note that for dynamic graph, different partitions in the same result have the same number // of subpartitions. 
@@ -945,13 +962,15 @@ public class AdaptiveBatchScheduler extends DefaultScheduler implements JobGraph return new PointwiseBlockingResultInfo( result.getId(), result.getNumberOfAssignedPartitions(), - result.getPartitions()[0].getNumberOfSubpartitions()); + result.getPartitions()[0].getNumberOfSubpartitions(), + subpartitionBytesByPartitionIndex); } else { return new AllToAllBlockingResultInfo( result.getId(), result.getNumberOfAssignedPartitions(), result.getPartitions()[0].getNumberOfSubpartitions(), - result.isBroadcast()); + result.isSingleSubpartitionContainsAllData(), + subpartitionBytesByPartitionIndex); } } @@ -965,6 +984,45 @@ public class AdaptiveBatchScheduler extends DefaultScheduler implements JobGraph return speculativeExecutionHandler; } + /** + * Tries to update the result information for a given IntermediateDataSetID according to the + * specified DistributionPattern. This ensures consistency between the distribution pattern and + * the stored result information. + * + * <p>The result information is updated under the following conditions: + * + * <ul> + * <li>If the target pattern is ALL_TO_ALL and the current result info is POINTWISE, a new + * BlockingResultInfo is created and stored. + * <li>If the target pattern is POINTWISE and the current result info is ALL_TO_ALL, a + * conversion is similarly triggered. + * <li>Additionally, for ALL_TO_ALL patterns, the status of broadcast of the result info + * should be updated. + * </ul> + * + * @param id The ID of the intermediate dataset to update. + * @param targetPattern The target distribution pattern to apply. + */ + private void tryUpdateResultInfo(IntermediateDataSetID id, DistributionPattern targetPattern) { + if (blockingResultInfos.containsKey(id)) { + BlockingResultInfo resultInfo = blockingResultInfos.get(id); + IntermediateResult result = getExecutionGraph().getAllIntermediateResults().get(id); + + if ((targetPattern == DistributionPattern.ALL_TO_ALL && resultInfo.isPointwise()) + || (targetPattern == DistributionPattern.POINTWISE + && !resultInfo.isPointwise())) { + + BlockingResultInfo newInfo = + createFromIntermediateResult( + result, resultInfo.getSubpartitionBytesByPartitionIndex()); + + blockingResultInfos.put(id, newInfo); + } else if (resultInfo instanceof AllToAllBlockingResultInfo) { + ((AllToAllBlockingResultInfo) resultInfo).setBroadcast(result.isBroadcast()); + } + } + } + private class DefaultBatchJobRecoveryContext implements BatchJobRecoveryContext { private final FailoverStrategy restartStrategyOnResultConsumable = diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionHandlerFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionHandlerFactory.java index b6113012f00..2d7be76c729 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionHandlerFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionHandlerFactory.java @@ -21,6 +21,7 @@ package org.apache.flink.runtime.scheduler.adaptivebatch; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.streaming.api.graph.ExecutionPlan; import org.apache.flink.streaming.api.graph.StreamGraph; +import org.apache.flink.util.DynamicCodeLoadingException; import java.util.concurrent.Executor; @@ -46,7 +47,8 @@ public class AdaptiveExecutionHandlerFactory { public static AdaptiveExecutionHandler create( ExecutionPlan executionPlan, 
ClassLoader userClassLoader, - Executor serializationExecutor) { + Executor serializationExecutor) + throws DynamicCodeLoadingException { if (executionPlan instanceof JobGraph) { return new NonAdaptiveExecutionHandler((JobGraph) executionPlan); } else { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java index ed1e945912f..b7b4b8ef1cd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.scheduler.adaptivebatch; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.executiongraph.IndexRange; import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; @@ -26,7 +27,9 @@ import javax.annotation.Nullable; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; @@ -35,29 +38,57 @@ import static org.apache.flink.util.Preconditions.checkState; /** Information of All-To-All result. */ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo { - private final boolean isBroadcast; + private final boolean singleSubpartitionContainsAllData; + + private boolean isBroadcast; /** * Aggregated subpartition bytes, which aggregates the subpartition bytes with the same * subpartition index in different partitions. Note that We can aggregate them because they will * be consumed by the same downstream task. 
*/ - @Nullable private List<Long> aggregatedSubpartitionBytes; + @Nullable protected List<Long> aggregatedSubpartitionBytes; + @VisibleForTesting AllToAllBlockingResultInfo( IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions, - boolean isBroadcast) { - super(resultId, numOfPartitions, numOfSubpartitions); + boolean isBroadcast, + boolean singleSubpartitionContainsAllData) { + this( + resultId, + numOfPartitions, + numOfSubpartitions, + singleSubpartitionContainsAllData, + new HashMap<>()); this.isBroadcast = isBroadcast; } + AllToAllBlockingResultInfo( + IntermediateDataSetID resultId, + int numOfPartitions, + int numOfSubpartitions, + boolean singleSubpartitionContainsAllData, + Map<Integer, long[]> subpartitionBytesByPartitionIndex) { + super(resultId, numOfPartitions, numOfSubpartitions, subpartitionBytesByPartitionIndex); + this.singleSubpartitionContainsAllData = singleSubpartitionContainsAllData; + } + @Override public boolean isBroadcast() { return isBroadcast; } + @Override + public boolean isSingleSubpartitionContainsAllData() { + return singleSubpartitionContainsAllData; + } + + void setBroadcast(boolean isBroadcast) { + this.isBroadcast = isBroadcast; + } + @Override public boolean isPointwise() { return false; @@ -83,7 +114,7 @@ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo { List<Long> bytes = Optional.ofNullable(aggregatedSubpartitionBytes) .orElse(getAggregatedSubpartitionBytesInternal()); - if (isBroadcast) { + if (singleSubpartitionContainsAllData) { return bytes.get(0); } else { return bytes.stream().reduce(0L, Long::sum); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java index e836d993869..0669417cc95 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java @@ -22,6 +22,8 @@ import org.apache.flink.runtime.executiongraph.IndexRange; import org.apache.flink.runtime.executiongraph.IntermediateResultInfo; import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; +import java.util.Map; + /** * The blocking result info, which will be used to calculate the vertex parallelism and input infos. */ @@ -64,4 +66,12 @@ public interface BlockingResultInfo extends IntermediateResultInfo { * @param partitionIndex the intermediate result partition index */ void resetPartitionInfo(int partitionIndex); + + /** + * Gets subpartition bytes by partition index. + * + * @return a map with integer keys representing partition indices and long array values + * representing subpartition bytes. 
+ */ + Map<Integer, long[]> getSubpartitionBytesByPartitionIndex(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java index b365db8d0e0..c0942f65372 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.scheduler.adaptivebatch; import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -27,13 +28,17 @@ import org.apache.flink.runtime.jobmaster.event.ExecutionJobVertexFinishedEvent; import org.apache.flink.runtime.jobmaster.event.JobEvent; import org.apache.flink.streaming.api.graph.AdaptiveGraphManager; import org.apache.flink.streaming.api.graph.StreamGraph; +import org.apache.flink.util.DynamicCodeLoadingException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.Executor; +import java.util.stream.Collectors; import static org.apache.flink.util.Preconditions.checkArgument; @@ -52,10 +57,16 @@ public class DefaultAdaptiveExecutionHandler implements AdaptiveExecutionHandler private final AdaptiveGraphManager adaptiveGraphManager; + private final StreamGraphOptimizer streamGraphOptimizer; + public DefaultAdaptiveExecutionHandler( - ClassLoader userClassloader, StreamGraph streamGraph, Executor serializationExecutor) { + ClassLoader userClassloader, StreamGraph streamGraph, Executor serializationExecutor) + throws DynamicCodeLoadingException { this.adaptiveGraphManager = new AdaptiveGraphManager(userClassloader, streamGraph, serializationExecutor); + + this.streamGraphOptimizer = + new StreamGraphOptimizer(streamGraph.getJobConfiguration(), userClassloader); } @Override @@ -66,6 +77,7 @@ public class DefaultAdaptiveExecutionHandler implements AdaptiveExecutionHandler @Override public void handleJobEvent(JobEvent jobEvent) { try { + tryOptimizeStreamGraph(jobEvent); tryUpdateJobGraph(jobEvent); } catch (Exception e) { log.error("Failed to handle job event {}.", jobEvent, e); @@ -73,6 +85,40 @@ public class DefaultAdaptiveExecutionHandler implements AdaptiveExecutionHandler } } + private void tryOptimizeStreamGraph(JobEvent jobEvent) throws Exception { + if (jobEvent instanceof ExecutionJobVertexFinishedEvent) { + ExecutionJobVertexFinishedEvent event = (ExecutionJobVertexFinishedEvent) jobEvent; + + JobVertexID vertexId = event.getVertexId(); + Map<IntermediateDataSetID, BlockingResultInfo> resultInfo = event.getResultInfo(); + Map<Integer, List<BlockingResultInfo>> resultInfoMap = + resultInfo.entrySet().stream() + .collect( + Collectors.toMap( + entry -> + adaptiveGraphManager.getProducerStreamNodeId( + entry.getKey()), + entry -> + new ArrayList<>( + Collections.singletonList( + entry.getValue())), + (existing, replacement) -> { + existing.addAll(replacement); + return existing; + })); + + OperatorsFinished operatorsFinished = + new OperatorsFinished( + 
adaptiveGraphManager.getStreamNodeIdsByJobVertexId(vertexId), + resultInfoMap); + + streamGraphOptimizer.onOperatorsFinished( + operatorsFinished, adaptiveGraphManager.getStreamGraphContext()); + } else { + throw new IllegalArgumentException("Unsupported job event " + jobEvent); + } + } + private void tryUpdateJobGraph(JobEvent jobEvent) throws Exception { if (jobEvent instanceof ExecutionJobVertexFinishedEvent) { ExecutionJobVertexFinishedEvent event = (ExecutionJobVertexFinishedEvent) jobEvent; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java index eb78b9cd7a3..bcdaf0b5176 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java @@ -200,7 +200,8 @@ public class DefaultVertexParallelismAndInputInfosDecider } private static boolean areAllInputsBroadcast(List<BlockingResultInfo> consumedResults) { - return consumedResults.stream().allMatch(BlockingResultInfo::isBroadcast); + return consumedResults.stream() + .allMatch(BlockingResultInfo::isSingleSubpartitionContainsAllData); } /** @@ -468,7 +469,15 @@ public class DefaultVertexParallelismAndInputInfosDecider for (int i = 0; i < subpartitionRanges.size(); ++i) { IndexRange subpartitionRange; if (resultInfo.isBroadcast()) { - subpartitionRange = new IndexRange(0, 0); + if (resultInfo.isSingleSubpartitionContainsAllData()) { + subpartitionRange = new IndexRange(0, 0); + } else { + // The partitions of the all-to-all result have the same number of + // subpartitions. So we can use the first partition's subpartition + // number. + subpartitionRange = + new IndexRange(0, resultInfo.getNumSubpartitions(0) - 1); + } } else { subpartitionRange = subpartitionRanges.get(i); } @@ -546,7 +555,7 @@ public class DefaultVertexParallelismAndInputInfosDecider private static List<BlockingResultInfo> getNonBroadcastResultInfos( List<BlockingResultInfo> consumedResults) { return consumedResults.stream() - .filter(resultInfo -> !resultInfo.isBroadcast()) + .filter(resultInfo -> !resultInfo.isSingleSubpartitionContainsAllData()) .collect(Collectors.toList()); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java index ed993af9d81..9a3f3b418aa 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java @@ -18,18 +18,31 @@ package org.apache.flink.runtime.scheduler.adaptivebatch; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.executiongraph.IndexRange; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import static org.apache.flink.util.Preconditions.checkState; /** Information of Pointwise result. 
*/ public class PointwiseBlockingResultInfo extends AbstractBlockingResultInfo { + + @VisibleForTesting PointwiseBlockingResultInfo( IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions) { - super(resultId, numOfPartitions, numOfSubpartitions); + this(resultId, numOfPartitions, numOfSubpartitions, new HashMap<>()); + } + + PointwiseBlockingResultInfo( + IntermediateDataSetID resultId, + int numOfPartitions, + int numOfSubpartitions, + Map<Integer, long[]> subpartitionBytesByPartitionIndex) { + super(resultId, numOfPartitions, numOfSubpartitions, subpartitionBytesByPartitionIndex); } @Override @@ -37,6 +50,11 @@ public class PointwiseBlockingResultInfo extends AbstractBlockingResultInfo { return false; } + @Override + public boolean isSingleSubpartitionContainsAllData() { + return false; + } + @Override public boolean isPointwise() { return true; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/PartitionDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/PartitionDescriptor.java index c3277932ae2..40296d27cb6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/PartitionDescriptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/PartitionDescriptor.java @@ -151,7 +151,7 @@ public class PartitionDescriptor implements Serializable { result.getResultType(), partition.getNumberOfSubpartitions(), result.getConnectionIndex(), - result.isBroadcast(), + result.isSingleSubpartitionContainsAllData(), result.getConsumingDistributionPattern() == DistributionPattern.ALL_TO_ALL, partition.isNumberOfPartitionConsumersUndefined()); } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java index bec248898b1..f683f586d4e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java @@ -107,6 +107,9 @@ public class AdaptiveGraphManager implements AdaptiveGraphGenerator { private final Map<IntermediateDataSetID, List<StreamEdge>> intermediateDataSetIdToOutputEdgesMap; + private final Map<String, IntermediateDataSet> consumerEdgeIdToIntermediateDataSetMap = + new HashMap<>(); + // Records the ids of stream nodes in the StreamNodeForwardGroup. // When stream edge's partitioner is modified to forward, we need get forward groups by source // and target node id. @@ -167,7 +170,8 @@ public class AdaptiveGraphManager implements AdaptiveGraphGenerator { streamGraph, steamNodeIdToForwardGroupMap, frozenNodeToStartNodeMap, - intermediateOutputsCaches); + intermediateOutputsCaches, + consumerEdgeIdToIntermediateDataSetMap); this.jobGraph = createAndInitializeJobGraph(streamGraph, streamGraph.getJobID()); @@ -382,6 +386,7 @@ public class AdaptiveGraphManager implements AdaptiveGraphGenerator { intermediateDataSetIdToOutputEdgesMap .computeIfAbsent(dataSet.getId(), ignored -> new ArrayList<>()) .add(edge); + consumerEdgeIdToIntermediateDataSetMap.put(edge.getEdgeId(), dataSet); // we cache the output here for downstream vertex to create jobEdge. 
intermediateOutputsCaches .computeIfAbsent(edge.getSourceId(), k -> new HashMap<>()) diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java index 07d8631bf92..324f3466e7c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java @@ -19,6 +19,8 @@ package org.apache.flink.streaming.api.graph; import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateDataSet; import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup; import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; @@ -36,6 +38,7 @@ import javax.annotation.Nullable; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -69,16 +72,21 @@ public class DefaultStreamGraphContext implements StreamGraphContext { // as they reuse some attributes. private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches; + private final Map<String, IntermediateDataSet> consumerEdgeIdToIntermediateDataSetMap; + public DefaultStreamGraphContext( StreamGraph streamGraph, Map<Integer, StreamNodeForwardGroup> steamNodeIdToForwardGroupMap, Map<Integer, Integer> frozenNodeToStartNodeMap, - Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches) { + Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches, + Map<String, IntermediateDataSet> consumerEdgeIdToIntermediateDataSetMap) { this.streamGraph = checkNotNull(streamGraph); this.steamNodeIdToForwardGroupMap = checkNotNull(steamNodeIdToForwardGroupMap); this.frozenNodeToStartNodeMap = checkNotNull(frozenNodeToStartNodeMap); this.opIntermediateOutputsCaches = checkNotNull(opIntermediateOutputsCaches); this.immutableStreamGraph = new ImmutableStreamGraph(this.streamGraph); + this.consumerEdgeIdToIntermediateDataSetMap = + checkNotNull(consumerEdgeIdToIntermediateDataSetMap); } @Override @@ -188,9 +196,9 @@ public class DefaultStreamGraphContext implements StreamGraphContext { tryConvertForwardPartitionerAndMergeForwardGroup(targetEdge); } - // The partitioner in NonChainedOutput derived from the consumer edge, so we need to ensure - // that any modifications to the partitioner of consumer edge are synchronized with - // NonChainedOutput. + // The partitioner in NonChainedOutput and IntermediateDataSet derived from the consumer + // edge, so we need to ensure that any modifications to the partitioner of consumer edge are + // synchronized with NonChainedOutput and IntermediateDataSet. Map<StreamEdge, NonChainedOutput> opIntermediateOutputs = opIntermediateOutputsCaches.get(targetEdge.getSourceId()); NonChainedOutput output = @@ -198,6 +206,23 @@ public class DefaultStreamGraphContext implements StreamGraphContext { if (output != null) { output.setPartitioner(targetEdge.getPartitioner()); } + + Optional.ofNullable(consumerEdgeIdToIntermediateDataSetMap.get(targetEdge.getEdgeId())) + .ifPresent( + dataSet -> { + DistributionPattern distributionPattern = + targetEdge.getPartitioner().isPointwise() + ? 
DistributionPattern.POINTWISE + : DistributionPattern.ALL_TO_ALL; + dataSet.updateOutputPattern( + distributionPattern, + targetEdge.getPartitioner().isBroadcast(), + targetEdge + .getPartitioner() + .getClass() + .equals(ForwardPartitioner.class)); + }); + LOG.info( "The original partitioner of the edge {} is: {} , requested change to: {} , and finally modified to: {}.", targetEdge, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java index e0f4d6e2fad..72de899aaf5 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java @@ -19,10 +19,13 @@ package org.apache.flink.runtime.executiongraph; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import static org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils.computeVertexInputInfoForAllToAll; import static org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils.computeVertexInputInfoForPointwise; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** Test for {@link VertexInputInfoComputationUtils}. */ class VertexInputInfoComputationUtilsTest { @@ -57,34 +60,55 @@ class VertexInputInfoComputationUtilsTest { assertThat(range4).isEqualTo(new IndexRange(4, 5)); } - @Test - void testComputeBroadcastConsumedSubpartitionRange() { - final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, 1, true, true); - assertThat(range1).isEqualTo(new IndexRange(0, 0)); - - final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, 1, true, true); - assertThat(range2).isEqualTo(new IndexRange(0, 0)); - - final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, 1, true, true); - assertThat(range3).isEqualTo(new IndexRange(0, 0)); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testComputeBroadcastConsumedSubpartitionRange(boolean singleSubpartitionContainsAllData) { + int numSubpartitions = singleSubpartitionContainsAllData ? 
1 : 3; + final IndexRange range1 = + computeConsumedSubpartitionRange( + 0, 3, numSubpartitions, true, true, singleSubpartitionContainsAllData); + assertThat(range1).isEqualTo(new IndexRange(0, numSubpartitions - 1)); + + final IndexRange range2 = + computeConsumedSubpartitionRange( + 1, 3, numSubpartitions, true, true, singleSubpartitionContainsAllData); + assertThat(range2).isEqualTo(new IndexRange(0, numSubpartitions - 1)); + + final IndexRange range3 = + computeConsumedSubpartitionRange( + 2, 3, numSubpartitions, true, true, singleSubpartitionContainsAllData); + assertThat(range3).isEqualTo(new IndexRange(0, numSubpartitions - 1)); + + if (singleSubpartitionContainsAllData) { + assertThatThrownBy( + () -> + computeConsumedSubpartitionRange( + 2, + 3, + numSubpartitions + 1, + true, + true, + singleSubpartitionContainsAllData)) + .isInstanceOf(IllegalArgumentException.class); + } } @Test void testComputeConsumedSubpartitionRangeForNonDynamicGraph() { - final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, -1, false, false); + final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, -1, false, false, false); assertThat(range1).isEqualTo(new IndexRange(0, 0)); - final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, -1, false, false); + final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, -1, false, false, false); assertThat(range2).isEqualTo(new IndexRange(1, 1)); - final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, -1, false, false); + final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, -1, false, false, false); assertThat(range3).isEqualTo(new IndexRange(2, 2)); } @Test void testComputeVertexInputInfoForAllToAllWithNonDynamicGraph() { final JobVertexInputInfo nonBroadcast = - computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, false); + computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, false, false); assertThat(nonBroadcast.getExecutionVertexInputInfos()) .containsExactlyInAnyOrder( new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 0)), @@ -93,7 +117,7 @@ class VertexInputInfoComputationUtilsTest { 2, new IndexRange(0, 1), new IndexRange(2, 2))); final JobVertexInputInfo broadcast = - computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, true); + computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, true, false); assertThat(broadcast.getExecutionVertexInputInfos()) .containsExactlyInAnyOrder( new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 0)), @@ -102,10 +126,13 @@ class VertexInputInfoComputationUtilsTest { 2, new IndexRange(0, 1), new IndexRange(2, 2))); } - @Test - void testComputeVertexInputInfoForAllToAllWithDynamicGraph() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testComputeVertexInputInfoForAllToAllWithDynamicGraph( + boolean singleSubpartitionContainsAllData) { final JobVertexInputInfo nonBroadcast = - computeVertexInputInfoForAllToAll(2, 3, ignored -> 10, true, false); + computeVertexInputInfoForAllToAll( + 2, 3, ignored -> 10, true, false, singleSubpartitionContainsAllData); assertThat(nonBroadcast.getExecutionVertexInputInfos()) .containsExactlyInAnyOrder( new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 2)), @@ -114,13 +141,30 @@ class VertexInputInfoComputationUtilsTest { 2, new IndexRange(0, 1), new IndexRange(6, 9))); final JobVertexInputInfo broadcast = - computeVertexInputInfoForAllToAll(2, 3, ignored -> 1, true, true); + computeVertexInputInfoForAllToAll( + 2, 3, ignored 
-> 1, true, true, singleSubpartitionContainsAllData); assertThat(broadcast.getExecutionVertexInputInfos()) .containsExactlyInAnyOrder( new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 0)), new ExecutionVertexInputInfo(1, new IndexRange(0, 1), new IndexRange(0, 0)), new ExecutionVertexInputInfo( 2, new IndexRange(0, 1), new IndexRange(0, 0))); + + if (!singleSubpartitionContainsAllData) { + final JobVertexInputInfo broadcastAndNotSingleSubpartitionContainsAllData = + computeVertexInputInfoForAllToAll( + 2, 3, ignored -> 4, true, true, singleSubpartitionContainsAllData); + assertThat( + broadcastAndNotSingleSubpartitionContainsAllData + .getExecutionVertexInputInfos()) + .containsExactlyInAnyOrder( + new ExecutionVertexInputInfo( + 0, new IndexRange(0, 1), new IndexRange(0, 3)), + new ExecutionVertexInputInfo( + 1, new IndexRange(0, 1), new IndexRange(0, 3)), + new ExecutionVertexInputInfo( + 2, new IndexRange(0, 1), new IndexRange(0, 3))); + } } @Test @@ -150,7 +194,7 @@ class VertexInputInfoComputationUtilsTest { private static IndexRange computeConsumedSubpartitionRange( int consumerIndex, int numConsumers, int numSubpartitions) { return computeConsumedSubpartitionRange( - consumerIndex, numConsumers, numSubpartitions, true, false); + consumerIndex, numConsumers, numSubpartitions, true, false, false); } private static IndexRange computeConsumedSubpartitionRange( @@ -158,8 +202,14 @@ class VertexInputInfoComputationUtilsTest { int numConsumers, int numSubpartitions, boolean isDynamicGraph, - boolean isBroadcast) { + boolean isBroadcast, + boolean broadcast) { return VertexInputInfoComputationUtils.computeConsumedSubpartitionRange( - consumerIndex, numConsumers, () -> numSubpartitions, isDynamicGraph, isBroadcast); + consumerIndex, + numConsumers, + () -> numSubpartitions, + isDynamicGraph, + isBroadcast, + broadcast); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionPlanSchedulingContextTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionPlanSchedulingContextTest.java index 500d65b67ff..b1dc2260ae3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionPlanSchedulingContextTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionPlanSchedulingContextTest.java @@ -27,6 +27,7 @@ import org.apache.flink.streaming.api.graph.StreamGraph; import org.apache.flink.streaming.api.graph.StreamNode; import org.apache.flink.testutils.TestingUtils; import org.apache.flink.testutils.executor.TestExecutorExtension; +import org.apache.flink.util.DynamicCodeLoadingException; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; @@ -45,7 +46,7 @@ class AdaptiveExecutionPlanSchedulingContextTest { TestingUtils.defaultExecutorExtension(); @Test - void testGetParallelismAndMaxParallelism() { + void testGetParallelismAndMaxParallelism() throws DynamicCodeLoadingException { int sinkParallelism = 4; int sinkMaxParallelism = 5; @@ -73,7 +74,8 @@ class AdaptiveExecutionPlanSchedulingContextTest { } @Test - void testGetDefaultMaxParallelismWhenParallelismGreaterThanZero() { + void testGetDefaultMaxParallelismWhenParallelismGreaterThanZero() + throws DynamicCodeLoadingException { int sinkParallelism = 4; int sinkMaxParallelism = -1; int defaultMaxParallelism = 100; @@ -94,7 +96,8 @@ class AdaptiveExecutionPlanSchedulingContextTest { } @Test 
- void testGetDefaultMaxParallelismWhenParallelismLessThanZero() { + void testGetDefaultMaxParallelismWhenParallelismLessThanZero() + throws DynamicCodeLoadingException { int sinkParallelism = -1; int sinkMaxParallelism = -1; int defaultMaxParallelism = 100; @@ -115,7 +118,7 @@ class AdaptiveExecutionPlanSchedulingContextTest { } @Test - public void testGetPendingOperatorCount() { + public void testGetPendingOperatorCount() throws DynamicCodeLoadingException { DefaultAdaptiveExecutionHandler adaptiveExecutionHandler = getDefaultAdaptiveExecutionHandler(); ExecutionPlanSchedulingContext schedulingContext = @@ -131,12 +134,13 @@ class AdaptiveExecutionPlanSchedulingContextTest { assertThat(schedulingContext.getPendingOperatorCount()).isEqualTo(0); } - private static DefaultAdaptiveExecutionHandler getDefaultAdaptiveExecutionHandler() { + private static DefaultAdaptiveExecutionHandler getDefaultAdaptiveExecutionHandler() + throws DynamicCodeLoadingException { return getDefaultAdaptiveExecutionHandler(2, 2); } private static DefaultAdaptiveExecutionHandler getDefaultAdaptiveExecutionHandler( - int sinkParallelism, int sinkMaxParallelism) { + int sinkParallelism, int sinkMaxParallelism) throws DynamicCodeLoadingException { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.fromSequence(0L, 1L).disableChaining().print(); StreamGraph streamGraph = env.getStreamGraph(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java index e298b4a065a..6930940fd7c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java @@ -32,18 +32,19 @@ class AllToAllBlockingResultInfoTest { @Test void testGetNumBytesProducedForNonBroadcast() { - testGetNumBytesProduced(false, 192L); + testGetNumBytesProduced(false, false, 192L); } @Test void testGetNumBytesProducedForBroadcast() { - testGetNumBytesProduced(true, 96L); + testGetNumBytesProduced(true, true, 96L); + testGetNumBytesProduced(true, false, 192L); } @Test void testGetNumBytesProducedWithIndexRange() { AllToAllBlockingResultInfo resultInfo = - new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false); + new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false, false); resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L})); resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L})); @@ -57,7 +58,7 @@ class AllToAllBlockingResultInfoTest { @Test void testGetAggregatedSubpartitionBytes() { AllToAllBlockingResultInfo resultInfo = - new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false); + new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false, false); resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L})); resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L})); @@ -67,8 +68,9 @@ class AllToAllBlockingResultInfoTest { @Test void testGetBytesWithPartialPartitionInfos() { AllToAllBlockingResultInfo resultInfo = - new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false); + new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false, false); 
resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L})); + resultInfo.aggregateSubpartitionBytes(); assertThatThrownBy(resultInfo::getNumBytesProduced) .isInstanceOf(IllegalStateException.class); @@ -79,7 +81,7 @@ class AllToAllBlockingResultInfoTest { @Test void testRecordPartitionInfoMultiTimes() { AllToAllBlockingResultInfo resultInfo = - new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false); + new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false, false); ResultPartitionBytes partitionBytes1 = new ResultPartitionBytes(new long[] {32L, 64L}); ResultPartitionBytes partitionBytes2 = new ResultPartitionBytes(new long[] {64L, 128L}); @@ -115,9 +117,15 @@ class AllToAllBlockingResultInfoTest { assertThat(resultInfo.getNumOfRecordedPartitions()).isZero(); } - private void testGetNumBytesProduced(boolean isBroadcast, long expectedBytes) { + private void testGetNumBytesProduced( + boolean isBroadcast, boolean singleSubpartitionContainsAllData, long expectedBytes) { AllToAllBlockingResultInfo resultInfo = - new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, isBroadcast); + new AllToAllBlockingResultInfo( + new IntermediateDataSetID(), + 2, + 2, + isBroadcast, + singleSubpartitionContainsAllData); resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 32L})); resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {64L, 64L})); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java index 6049006d245..3ad14047b0e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java @@ -24,16 +24,27 @@ import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobmaster.event.ExecutionJobVertexFinishedEvent; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.graph.StreamGraph; +import org.apache.flink.streaming.api.graph.StreamGraphContext; +import org.apache.flink.streaming.api.graph.StreamNode; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; import org.apache.flink.testutils.TestingUtils; import org.apache.flink.testutils.executor.TestExecutorExtension; +import org.apache.flink.util.DynamicCodeLoadingException; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Random; +import java.util.Set; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicInteger; @@ -47,7 +58,7 @@ class DefaultAdaptiveExecutionHandlerTest { TestingUtils.defaultExecutorExtension(); @Test - void testGetJobGraph() { + void testGetJobGraph() throws DynamicCodeLoadingException { JobGraph jobGraph = 
createAdaptiveExecutionHandler().getJobGraph(); assertThat(jobGraph).isNotNull(); @@ -56,7 +67,7 @@ class DefaultAdaptiveExecutionHandlerTest { } @Test - void testHandleJobEvent() { + void testHandleJobEvent() throws DynamicCodeLoadingException { List<JobVertex> newAddedJobVertices = new ArrayList<>(); AtomicInteger pendingOperators = new AtomicInteger(); @@ -95,7 +106,82 @@ class DefaultAdaptiveExecutionHandlerTest { } @Test - void testGetInitialParallelismAndNotifyJobVertexParallelismDecided() { + void testOptimizeStreamGraph() throws DynamicCodeLoadingException { + StreamGraph streamGraph = createStreamGraph(); + StreamNode source = + streamGraph.getStreamNodes().stream() + .filter(node -> node.getOperatorName().contains("Source")) + .findFirst() + .get(); + StreamNode map = + streamGraph.getStreamNodes().stream() + .filter(node -> node.getOperatorName().contains("Map")) + .findFirst() + .get(); + + assertThat(source.getOutEdges().get(0).getPartitioner()) + .isInstanceOf(ForwardPartitioner.class); + assertThat(map.getOutEdges().get(0).getPartitioner()) + .isInstanceOf(RescalePartitioner.class); + + streamGraph + .getJobConfiguration() + .set( + StreamGraphOptimizationStrategy.STREAM_GRAPH_OPTIMIZATION_STRATEGY, + Collections.singletonList( + TestingStreamGraphOptimizerStrategy.class.getName())); + TestingStreamGraphOptimizerStrategy.convertToReBalanceEdgeIds.add( + source.getOutEdges().get(0).getEdgeId()); + TestingStreamGraphOptimizerStrategy.convertToReBalanceEdgeIds.add( + map.getOutEdges().get(0).getEdgeId()); + + DefaultAdaptiveExecutionHandler handler = + createAdaptiveExecutionHandler( + (newVertices, pendingOperatorsCount) -> {}, streamGraph); + + JobGraph jobGraph = handler.getJobGraph(); + JobVertex sourceVertex = jobGraph.getVertices().iterator().next(); + + // notify Source node is finished + ExecutionJobVertexFinishedEvent event1 = + new ExecutionJobVertexFinishedEvent(sourceVertex.getID(), Collections.emptyMap()); + handler.handleJobEvent(event1); + + // verify that the source output edge is not updated because the original edge is forward. + assertThat(sourceVertex.getProducedDataSets().get(0).getConsumers()).hasSize(1); + assertThat( + sourceVertex + .getProducedDataSets() + .get(0) + .getConsumers() + .get(0) + .getShipStrategyName()) + .isEqualToIgnoringCase("forward"); + + // notify Map node is finished + Iterator<JobVertex> jobVertexIterator = jobGraph.getVertices().iterator(); + jobVertexIterator.next(); + JobVertex mapVertex = jobVertexIterator.next(); + + ExecutionJobVertexFinishedEvent event2 = + new ExecutionJobVertexFinishedEvent(mapVertex.getID(), Collections.emptyMap()); + handler.handleJobEvent(event2); + + // verify that the map output edge is updated to reBalance. 
+ assertThat(mapVertex.getProducedDataSets().get(0).getConsumers()).hasSize(1); + assertThat( + mapVertex + .getProducedDataSets() + .get(0) + .getConsumers() + .get(0) + .getShipStrategyName()) + .isEqualToIgnoringCase("rebalance"); + } + + @Test + void testGetInitialParallelismAndNotifyJobVertexParallelismDecided() + throws DynamicCodeLoadingException { StreamGraph streamGraph = createStreamGraph(); DefaultAdaptiveExecutionHandler handler = createAdaptiveExecutionHandler( @@ -123,7 +209,8 @@ class DefaultAdaptiveExecutionHandlerTest { assertThat(handler.getInitialParallelism(map.getID())).isEqualTo(parallelism); } - private DefaultAdaptiveExecutionHandler createAdaptiveExecutionHandler() { + private DefaultAdaptiveExecutionHandler createAdaptiveExecutionHandler() + throws DynamicCodeLoadingException { return createAdaptiveExecutionHandler( (newVertices, pendingOperatorsCount) -> {}, createStreamGraph()); } @@ -159,7 +246,8 @@ class DefaultAdaptiveExecutionHandlerTest { * and a given {@link StreamGraph}. */ private DefaultAdaptiveExecutionHandler createAdaptiveExecutionHandler( - JobGraphUpdateListener listener, StreamGraph streamGraph) { + JobGraphUpdateListener listener, StreamGraph streamGraph) + throws DynamicCodeLoadingException { DefaultAdaptiveExecutionHandler handler = new DefaultAdaptiveExecutionHandler( getClass().getClassLoader(), streamGraph, EXECUTOR_RESOURCE.getExecutor()); @@ -167,4 +255,34 @@ class DefaultAdaptiveExecutionHandlerTest { return handler; } + + public static final class TestingStreamGraphOptimizerStrategy + implements StreamGraphOptimizationStrategy { + + private static final Set<String> convertToReBalanceEdgeIds = new HashSet<>(); + + @Override + public boolean onOperatorsFinished( + OperatorsFinished operatorsFinished, StreamGraphContext context) { + List<Integer> finishedStreamNodeIds = operatorsFinished.getFinishedStreamNodeIds(); + List<StreamEdgeUpdateRequestInfo> requestInfos = new ArrayList<>(); + for (Integer finishedStreamNodeId : finishedStreamNodeIds) { + for (ImmutableStreamEdge outEdge : + context.getStreamGraph() + .getStreamNode(finishedStreamNodeId) + .getOutEdges()) { + if (convertToReBalanceEdgeIds.contains(outEdge.getEdgeId())) { + StreamEdgeUpdateRequestInfo requestInfo = + new StreamEdgeUpdateRequestInfo( + outEdge.getEdgeId(), + outEdge.getSourceId(), + outEdge.getTargetId()); + requestInfo.outputPartitioner(new RebalancePartitioner<>()); + requestInfos.add(requestInfo); + } + } + } + return context.modifyStreamEdge(requestInfos); + } + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java index 23c70f317bd..c1ea24e43aa 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java @@ -32,11 +32,14 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.shaded.guava32.com.google.common.collect.Iterables; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import 
java.util.Set; import java.util.stream.Collectors; @@ -104,8 +107,9 @@ class DefaultVertexParallelismAndInputInfosDeciderTest { @Test void testDecideParallelismWithMaxSubpartitionLimitation() { - BlockingResultInfo resultInfo1 = new TestingBlockingResultInfo(false, 1L, 1024, 1024); - BlockingResultInfo resultInfo2 = new TestingBlockingResultInfo(false, 1L, 512, 512); + BlockingResultInfo resultInfo1 = + new TestingBlockingResultInfo(false, false, 1L, 1024, 1024); + BlockingResultInfo resultInfo2 = new TestingBlockingResultInfo(false, false, 1L, 512, 512); int parallelism = createDeciderAndDecideParallelism( @@ -206,13 +210,19 @@ class DefaultVertexParallelismAndInputInfosDeciderTest { new IndexRange(8, 9))); } - @Test - void testAllEdgesAllToAllAndOneIsBroadcast() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testAllEdgesAllToAllAndOneIsBroadcast(boolean singleSubpartitionContainsAllData) { AllToAllBlockingResultInfo resultInfo1 = createAllToAllBlockingResultInfo( - new long[] {10L, 15L, 13L, 12L, 1L, 10L, 8L, 20L, 12L, 17L}); + new long[] {10L, 15L, 13L, 12L, 1L, 10L, 8L, 20L, 12L, 17L}, false, false); AllToAllBlockingResultInfo resultInfo2 = - createAllToAllBlockingResultInfo(new long[] {10L}, true); + createAllToAllBlockingResultInfo( + singleSubpartitionContainsAllData + ? new long[] {10L} + : new long[] {1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L}, + true, + singleSubpartitionContainsAllData); ParallelismAndInputInfos parallelismAndInputInfos = createDeciderAndDecideParallelismAndInputInfos( @@ -224,17 +234,30 @@ class DefaultVertexParallelismAndInputInfosDeciderTest { checkAllToAllJobVertexInputInfo( parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()), Arrays.asList(new IndexRange(0, 4), new IndexRange(5, 8), new IndexRange(9, 9))); - checkAllToAllJobVertexInputInfo( - parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()), - Arrays.asList(new IndexRange(0, 0), new IndexRange(0, 0), new IndexRange(0, 0))); + if (singleSubpartitionContainsAllData) { + checkAllToAllJobVertexInputInfo( + parallelismAndInputInfos + .getJobVertexInputInfos() + .get(resultInfo2.getResultId()), + Arrays.asList( + new IndexRange(0, 0), new IndexRange(0, 0), new IndexRange(0, 0))); + } else { + checkAllToAllJobVertexInputInfo( + parallelismAndInputInfos + .getJobVertexInputInfos() + .get(resultInfo2.getResultId()), + Arrays.asList( + new IndexRange(0, 9), new IndexRange(0, 9), new IndexRange(0, 9))); + } } @Test void testAllEdgesBroadcast() { - AllToAllBlockingResultInfo resultInfo1 = - createAllToAllBlockingResultInfo(new long[] {10L}, true); - AllToAllBlockingResultInfo resultInfo2 = - createAllToAllBlockingResultInfo(new long[] {10L}, true); + AllToAllBlockingResultInfo resultInfo1; + AllToAllBlockingResultInfo resultInfo2; + resultInfo1 = createAllToAllBlockingResultInfo(new long[] {10L}, true, false); + resultInfo2 = createAllToAllBlockingResultInfo(new long[] {10L}, true, false); + ParallelismAndInputInfos parallelismAndInputInfos = createDeciderAndDecideParallelismAndInputInfos( 1, 10, 60L, Arrays.asList(resultInfo1, resultInfo2)); @@ -242,12 +265,15 @@ class DefaultVertexParallelismAndInputInfosDeciderTest { assertThat(parallelismAndInputInfos.getParallelism()).isOne(); assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(2); + List<IndexRange> expectedSubpartitionRanges = + Collections.singletonList(new IndexRange(0, 0)); + checkAllToAllJobVertexInputInfo( 
parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()), - Collections.singletonList(new IndexRange(0, 0))); + expectedSubpartitionRanges); checkAllToAllJobVertexInputInfo( parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()), - Collections.singletonList(new IndexRange(0, 0))); + expectedSubpartitionRanges); } @Test @@ -359,7 +385,8 @@ class DefaultVertexParallelismAndInputInfosDeciderTest { long[] subpartitionBytes = new long[1024]; Arrays.fill(subpartitionBytes, 1L); AllToAllBlockingResultInfo resultInfo = - new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 1024, 1024, false); + new AllToAllBlockingResultInfo( + new IntermediateDataSetID(), 1024, 1024, false, false); for (int i = 0; i < 1024; ++i) { resultInfo.recordPartitionInfo(i, new ResultPartitionBytes(subpartitionBytes)); } @@ -507,11 +534,13 @@ class DefaultVertexParallelismAndInputInfosDeciderTest { private AllToAllBlockingResultInfo createAllToAllBlockingResultInfo( long[] aggregatedSubpartitionBytes) { - return createAllToAllBlockingResultInfo(aggregatedSubpartitionBytes, false); + return createAllToAllBlockingResultInfo(aggregatedSubpartitionBytes, false, false); } private AllToAllBlockingResultInfo createAllToAllBlockingResultInfo( - long[] aggregatedSubpartitionBytes, boolean isBroadcast) { + long[] aggregatedSubpartitionBytes, + boolean isBroadcast, + boolean isSingleSubpartitionContainsAllData) { // For simplicity, we configure only one partition here, so the aggregatedSubpartitionBytes // is equivalent to the subpartition bytes of partition0 AllToAllBlockingResultInfo resultInfo = @@ -519,7 +548,8 @@ class DefaultVertexParallelismAndInputInfosDeciderTest { new IntermediateDataSetID(), 1, aggregatedSubpartitionBytes.length, - isBroadcast); + isBroadcast, + isSingleSubpartitionContainsAllData); resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(aggregatedSubpartitionBytes)); return resultInfo; } @@ -552,17 +582,31 @@ class DefaultVertexParallelismAndInputInfosDeciderTest { private static class TestingBlockingResultInfo implements BlockingResultInfo { private final boolean isBroadcast; + private final boolean singleSubpartitionContainsAllData; private final long producedBytes; private final int numPartitions; private final int numSubpartitions; - private TestingBlockingResultInfo(boolean isBroadcast, long producedBytes) { - this(isBroadcast, producedBytes, MAX_PARALLELISM, MAX_PARALLELISM); + private TestingBlockingResultInfo( + boolean isBroadcast, + boolean singleSubpartitionContainsAllData, + long producedBytes) { + this( + isBroadcast, + singleSubpartitionContainsAllData, + producedBytes, + MAX_PARALLELISM, + MAX_PARALLELISM); } private TestingBlockingResultInfo( - boolean isBroadcast, long producedBytes, int numPartitions, int numSubpartitions) { + boolean isBroadcast, + boolean singleSubpartitionContainsAllData, + long producedBytes, + int numPartitions, + int numSubpartitions) { this.isBroadcast = isBroadcast; + this.singleSubpartitionContainsAllData = singleSubpartitionContainsAllData; this.producedBytes = producedBytes; this.numPartitions = numPartitions; this.numSubpartitions = numSubpartitions; @@ -578,6 +622,11 @@ class DefaultVertexParallelismAndInputInfosDeciderTest { return isBroadcast; } + @Override + public boolean isSingleSubpartitionContainsAllData() { + return singleSubpartitionContainsAllData; + } + @Override public boolean isPointwise() { return false; @@ -617,10 +666,10 @@ class DefaultVertexParallelismAndInputInfosDeciderTest { } 
private static BlockingResultInfo createFromBroadcastResult(long producedBytes) { - return new TestingBlockingResultInfo(true, producedBytes); + return new TestingBlockingResultInfo(true, true, producedBytes); } private static BlockingResultInfo createFromNonBroadcastResult(long producedBytes) { - return new TestingBlockingResultInfo(false, producedBytes); + return new TestingBlockingResultInfo(false, false, producedBytes); } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java index 468f6883b73..0be8c246893 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java @@ -52,7 +52,8 @@ class DefaultStreamGraphContextTest { streamGraph, forwardGroupsByEndpointNodeIdCache, frozenNodeToStartNodeMap, - opIntermediateOutputsCaches); + opIntermediateOutputsCaches, + new HashMap<>()); StreamNode sourceNode = streamGraph.getStreamNode(streamGraph.getSourceIDs().iterator().next()); @@ -136,7 +137,8 @@ class DefaultStreamGraphContextTest { streamGraph, forwardGroupsByEndpointNodeIdCache, frozenNodeToStartNodeMap, - opIntermediateOutputsCaches); + opIntermediateOutputsCaches, + new HashMap<>()); StreamNode sourceNode = streamGraph.getStreamNode(streamGraph.getSourceIDs().iterator().next()); diff --git a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java index 6654c1eb17e..21b9a7645a9 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java @@ -31,8 +31,18 @@ import org.apache.flink.configuration.MemorySize; import org.apache.flink.configuration.RestOptions; import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.runtime.scheduler.adaptivebatch.AdaptiveBatchScheduler; +import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished; +import org.apache.flink.runtime.scheduler.adaptivebatch.StreamGraphOptimizationStrategy; import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.graph.StreamGraph; +import org.apache.flink.streaming.api.graph.StreamGraphContext; +import org.apache.flink.streaming.api.graph.StreamNode; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -40,8 +50,10 @@ import org.junit.jupiter.api.Test; import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.function.Function; import java.util.stream.Collectors; @@ -130,6 
+142,72 @@ class AdaptiveBatchSchedulerITCase { env.execute(); } + @Test + void testAdaptiveOptimizeStreamGraph() throws Exception { + final Configuration configuration = createConfiguration(); + configuration.set( + StreamGraphOptimizationStrategy.STREAM_GRAPH_OPTIMIZATION_STRATEGY, + List.of(TestingStreamGraphOptimizerStrategy.class.getName())); + final StreamExecutionEnvironment env = + StreamExecutionEnvironment.getExecutionEnvironment(configuration); + env.setRuntimeMode(RuntimeExecutionMode.BATCH); + env.disableOperatorChaining(); + env.setParallelism(8); + + SingleOutputStreamOperator<Long> source1 = + env.fromSequence(0, NUMBERS_TO_PRODUCE - 1) + .setParallelism(SOURCE_PARALLELISM_1) + .name("source1"); + SingleOutputStreamOperator<Long> source2 = + env.fromSequence(0, NUMBERS_TO_PRODUCE - 1) + .setParallelism(SOURCE_PARALLELISM_2) + .name("source2"); + + source1.keyBy(i -> i % SOURCE_PARALLELISM_1) + .map(i -> i) + .name("map1") + .rebalance() + .union(source2) + .rebalance() + .map(new NumberCounter()) + .name("map2") + .setParallelism(1); + + StreamGraph streamGraph = env.getStreamGraph(); + StreamNode sourceNode1 = + streamGraph.getStreamNodes().stream() + .filter(node -> node.getOperatorName().contains("source1")) + .findFirst() + .get(); + StreamNode mapNode1 = + streamGraph.getStreamNodes().stream() + .filter(node -> node.getOperatorName().contains("map1")) + .findFirst() + .get(); + + TestingStreamGraphOptimizerStrategy.convertToRescaleEdgeIds.add( + sourceNode1.getOutEdges().get(0).getEdgeId()); + TestingStreamGraphOptimizerStrategy.convertToBroadcastEdgeIds.add( + mapNode1.getOutEdges().get(0).getEdgeId()); + + env.execute(streamGraph); + + Map<Long, Long> numberCountResultMap = + numberCountResults.stream() + .flatMap(map -> map.entrySet().stream()) + .collect( + Collectors.toMap( + Map.Entry::getKey, Map.Entry::getValue, Long::sum)); + + // One part comes from source1, while the other parts come from the broadcast results of + // source2. 
+ Map<Long, Long> expectedResult = + LongStream.range(0, NUMBERS_TO_PRODUCE) + .boxed() + .collect(Collectors.toMap(Function.identity(), i -> 2L)); + assertThat(numberCountResultMap).isEqualTo(expectedResult); + } + private void testSchedulingBase(Boolean useSourceParallelismInference) throws Exception { executeJob(useSourceParallelismInference); @@ -257,4 +335,38 @@ class AdaptiveBatchSchedulerITCase { return expectedParallelism; } } + + public static final class TestingStreamGraphOptimizerStrategy + implements StreamGraphOptimizationStrategy { + + private static final Set<String> convertToBroadcastEdgeIds = new HashSet<>(); + private static final Set<String> convertToRescaleEdgeIds = new HashSet<>(); + + @Override + public boolean onOperatorsFinished( + OperatorsFinished operatorsFinished, StreamGraphContext context) throws Exception { + List<Integer> finishedStreamNodeIds = operatorsFinished.getFinishedStreamNodeIds(); + List<StreamEdgeUpdateRequestInfo> requestInfos = new ArrayList<>(); + for (Integer finishedStreamNodeId : finishedStreamNodeIds) { + for (ImmutableStreamEdge outEdge : + context.getStreamGraph() + .getStreamNode(finishedStreamNodeId) + .getOutEdges()) { + StreamEdgeUpdateRequestInfo requestInfo = + new StreamEdgeUpdateRequestInfo( + outEdge.getEdgeId(), + outEdge.getSourceId(), + outEdge.getTargetId()); + if (convertToBroadcastEdgeIds.contains(outEdge.getEdgeId())) { + requestInfo.outputPartitioner(new BroadcastPartitioner<>()); + requestInfos.add(requestInfo); + } else if (convertToRescaleEdgeIds.contains(outEdge.getEdgeId())) { + requestInfo.outputPartitioner(new RescalePartitioner<>()); + requestInfos.add(requestInfo); + } + } + } + return context.modifyStreamEdge(requestInfos); + } + } }
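
For readers following the new feature end to end: the integration test above enables stream-graph optimization by listing strategy class names under StreamGraphOptimizationStrategy.STREAM_GRAPH_OPTIMIZATION_STRATEGY in the job configuration. The snippet below is a minimal, illustrative sketch of that wiring only; it is not part of the patch. The class names OptimizationStrategyWiringSketch and MyOptimizerStrategy are hypothetical stand-ins (the latter for any user-provided StreamGraphOptimizationStrategy implementation such as the TestingStreamGraphOptimizerStrategy defined in the tests above).

    import java.util.List;

    import org.apache.flink.api.common.RuntimeExecutionMode;
    import org.apache.flink.configuration.Configuration;
    import org.apache.flink.runtime.scheduler.adaptivebatch.StreamGraphOptimizationStrategy;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

    public class OptimizationStrategyWiringSketch {

        public static void main(String[] args) throws Exception {
            // Register a custom optimization strategy so the adaptive batch scheduler
            // can rewrite stream edges after upstream operators finish, mirroring the
            // setup used by AdaptiveBatchSchedulerITCase#testAdaptiveOptimizeStreamGraph.
            Configuration configuration = new Configuration();
            configuration.set(
                    StreamGraphOptimizationStrategy.STREAM_GRAPH_OPTIMIZATION_STRATEGY,
                    List.of(MyOptimizerStrategy.class.getName())); // hypothetical strategy class

            StreamExecutionEnvironment env =
                    StreamExecutionEnvironment.getExecutionEnvironment(configuration);
            env.setRuntimeMode(RuntimeExecutionMode.BATCH);

            // Illustrative job body; any batch pipeline would do here.
            env.fromSequence(0, 100).rebalance().map(i -> i).print();
            env.execute("optimization-strategy-wiring-sketch");
        }
    }

When a configured strategy's onOperatorsFinished callback returns edge-update requests that the StreamGraphContext accepts, the requested output partitioners (e.g. rescale or broadcast in the tests above) replace the original ones on the not-yet-scheduled part of the graph, which is exactly what the new testOptimizeStreamGraph and testAdaptiveOptimizeStreamGraph cases verify.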
