This is an automated email from the ASF dual-hosted git repository.

wanglijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 1d433fb4b72d74b1aa0cfa8ac7a9f7fed22bce32
Author: Lijie Wang <[email protected]>
AuthorDate: Tue Oct 25 17:44:37 2022 +0800

    [FLINK-29665][runtime] Support flexible subpartition range division

    This closes #21162
---
 .../TaskDeploymentDescriptorFactory.java           |  73 ++----
 .../executiongraph/DefaultExecutionGraph.java      |  34 ++-
 .../executiongraph/EdgeManagerBuildUtil.java       | 233 ++++++++---------
 .../runtime/executiongraph/ExecutionGraph.java     |  16 +-
 .../runtime/executiongraph/ExecutionJobVertex.java |   2 +-
 .../runtime/executiongraph/ExecutionVertex.java    |   8 +
 .../IntermediateResultInfo.java}                   |  36 +--
 .../IntermediateResultPartition.java               |   6 +-
 .../InternalExecutionGraphAccessor.java            |  10 +
 .../VertexInputInfoComputationUtils.java           | 275 +++++++++++++++++++++
 .../executiongraph/VertexInputInfoStore.java       |  78 ++++++
 .../SsgNetworkMemoryCalculationUtils.java          |   8 +-
 .../adaptivebatch/AllToAllBlockingResultInfo.java  |  10 +
 .../adaptivebatch/BlockingResultInfo.java          |  25 +-
 .../adaptivebatch/PointwiseBlockingResultInfo.java |  10 +
 .../TaskDeploymentDescriptorFactoryTest.java       |  86 -------
 .../DefaultExecutionGraphConstructionTest.java     |   5 +-
 .../executiongraph/EdgeManagerBuildUtilTest.java   | 166 ++++++++++++-
 .../executiongraph/PointwisePatternTest.java       |   2 +-
 .../VertexInputInfoComputationUtilsTest.java       | 165 +++++++++++++
 .../runtime/scheduler/adaptive/ExecutingTest.java  |   8 +
 .../adaptive/StateTrackingMockExecutionGraph.java  |   6 +-
 .../DefaultVertexParallelismDeciderTest.java       |  10 +
 23 files changed, 917 insertions(+), 355 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java index b55b8e7b71e..26897a9aff2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java @@ -29,6 +29,7 @@ import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.executiongraph.Execution; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo; import org.apache.flink.runtime.executiongraph.IndexRange; import org.apache.flink.runtime.executiongraph.IntermediateResult; import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; @@ -60,7 +61,7 @@ import java.util.Map; import java.util.Optional; import java.util.function.Function; -import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; /** * Factory of {@link TaskDeploymentDescriptor} to deploy {@link @@ -78,6 +79,8 @@ public class TaskDeploymentDescriptorFactory { private final BlobWriter blobWriter; private final Map<IntermediateDataSetID, ShuffleDescriptor[]> consumedClusterPartitionShuffleDescriptors; + private final Function<IntermediateDataSetID, ExecutionVertexInputInfo> + executionVertexInputInfoRetriever; private TaskDeploymentDescriptorFactory( ExecutionAttemptID executionId, @@ -90,7 +93,9 @@ public class TaskDeploymentDescriptorFactory { resultPartitionRetriever, BlobWriter blobWriter, Map<IntermediateDataSetID, ShuffleDescriptor[]> - consumedClusterPartitionShuffleDescriptors) { + consumedClusterPartitionShuffleDescriptors, + Function<IntermediateDataSetID,
ExecutionVertexInputInfo> + executionVertexInputInfoRetriever) { this.executionId = executionId; this.serializedJobInformation = serializedJobInformation; this.taskInfo = taskInfo; @@ -101,6 +106,7 @@ public class TaskDeploymentDescriptorFactory { this.blobWriter = blobWriter; this.consumedClusterPartitionShuffleDescriptors = consumedClusterPartitionShuffleDescriptors; + this.executionVertexInputInfoRetriever = checkNotNull(executionVertexInputInfoRetriever); } public TaskDeploymentDescriptor createDeploymentDescriptor( @@ -128,24 +134,22 @@ public class TaskDeploymentDescriptorFactory { // If the produced partition has multiple consumers registered, we // need to request the one matching our sub task index. // TODO Refactor after removing the consumers from the intermediate result partitions - IntermediateResultPartition resultPartition = - resultPartitionRetriever.apply(consumedPartitionGroup.getFirst()); - IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult(); - IndexRange consumedSubpartitionRange = - computeConsumedSubpartitionRange( - consumedPartitionGroup.getNumConsumers(), - resultPartition, - executionId.getSubtaskIndex()); + IntermediateResult consumedIntermediateResult = + resultPartitionRetriever + .apply(consumedPartitionGroup.getFirst()) + .getIntermediateResult(); IntermediateDataSetID resultId = consumedIntermediateResult.getId(); ResultPartitionType partitionType = consumedIntermediateResult.getResultType(); + IndexRange subpartitionRange = + executionVertexInputInfoRetriever.apply(resultId).getSubpartitionIndexRange(); inputGates.add( new InputGateDeploymentDescriptor( resultId, partitionType, - consumedSubpartitionRange, + subpartitionRange, getConsumedPartitionShuffleDescriptors( consumedIntermediateResult, consumedPartitionGroup))); } @@ -166,50 +170,6 @@ public class TaskDeploymentDescriptorFactory { return inputGates; } - public static IndexRange computeConsumedSubpartitionRange( - int numConsumers, - IntermediateResultPartition resultPartition, - int consumerSubtaskIndex) { - int consumerIndex = consumerSubtaskIndex % numConsumers; - IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult(); - int numSubpartitions = resultPartition.getNumberOfSubpartitions(); - return computeConsumedSubpartitionRange( - consumerIndex, - numConsumers, - numSubpartitions, - consumedIntermediateResult.getProducer().getGraph().isDynamic(), - consumedIntermediateResult.isBroadcast()); - } - - @VisibleForTesting - static IndexRange computeConsumedSubpartitionRange( - int consumerIndex, - int numConsumers, - int numSubpartitions, - boolean isDynamicGraph, - boolean isBroadcast) { - - if (!isDynamicGraph) { - checkArgument(numConsumers == numSubpartitions); - return new IndexRange(consumerIndex, consumerIndex); - } else { - if (isBroadcast) { - // broadcast result should have only one subpartition, and be consumed multiple - // times. 
- checkArgument(numSubpartitions == 1); - return new IndexRange(0, 0); - } else { - checkArgument(consumerIndex < numConsumers); - checkArgument(numConsumers <= numSubpartitions); - - int start = consumerIndex * numSubpartitions / numConsumers; - int nextStart = (consumerIndex + 1) * numSubpartitions / numConsumers; - - return new IndexRange(start, nextStart - 1); - } - } - } - private MaybeOffloaded<ShuffleDescriptor[]> getConsumedPartitionShuffleDescriptors( IntermediateResult intermediateResult, ConsumedPartitionGroup consumedPartitionGroup) throws IOException { @@ -285,7 +245,8 @@ public class TaskDeploymentDescriptorFactory { executionVertex.getAllConsumedPartitionGroups(), internalExecutionGraphAccessor::getResultPartitionOrThrow, internalExecutionGraphAccessor.getBlobWriter(), - clusterPartitionShuffleDescriptors); + clusterPartitionShuffleDescriptors, + executionVertex::getExecutionVertexInputInfo); } private static Map<IntermediateDataSetID, ShuffleDescriptor[]> diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java index 187bf88f35b..2f992dc5faf 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java @@ -289,6 +289,7 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG private final Map<IntermediateResultPartitionID, IntermediateResultPartition> resultPartitionsById; + private final VertexInputInfoStore vertexInputInfoStore; private final boolean isDynamic; private final ExecutionJobVertex.Factory executionJobVertexFactory; @@ -386,6 +387,7 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG this.edgeManager = new EdgeManager(); this.executionVerticesById = new HashMap<>(); this.resultPartitionsById = new HashMap<>(); + this.vertexInputInfoStore = new VertexInputInfoStore(); this.isDynamic = isDynamic; @@ -828,17 +830,7 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG } @Override - public void attachJobGraph(List<JobVertex> topologicallySorted) throws JobException { - if (isDynamic) { - attachJobGraph(topologicallySorted, Collections.emptyList()); - } else { - attachJobGraph(topologicallySorted, topologicallySorted); - } - } - - private void attachJobGraph( - List<JobVertex> verticesToAttach, List<JobVertex> verticesToInitialize) - throws JobException { + public void attachJobGraph(List<JobVertex> verticesToAttach) throws JobException { assertRunningInJobMasterMainThread(); @@ -850,7 +842,9 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG intermediateResults.size()); attachJobVertices(verticesToAttach); - initializeJobVertices(verticesToInitialize); + if (!isDynamic) { + initializeJobVertices(verticesToAttach); + } // the topology assigning should happen before notifying new vertices to failoverStrategy executionTopology = DefaultExecutionTopology.fromExecutionGraph(this); @@ -898,10 +892,18 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG } @Override - public void initializeJobVertex(ExecutionJobVertex ejv, long createTimestamp) + public void initializeJobVertex( + ExecutionJobVertex ejv, + long createTimestamp, + Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos) throws JobException { 
checkNotNull(ejv); + checkNotNull(jobVertexInputInfos); + + jobVertexInputInfos.forEach( + (resultId, info) -> + this.vertexInputInfoStore.put(ejv.getJobVertexId(), resultId, info)); ejv.initialize( executionHistorySizeLimit, @@ -1694,4 +1696,10 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG public MarkPartitionFinishedStrategy getMarkPartitionFinishedStrategy() { return markPartitionFinishedStrategy; } + + @Override + public JobVertexInputInfo getJobVertexInputInfo( + JobVertexID jobVertexId, IntermediateDataSetID resultId) { + return vertexInputInfoStore.get(jobVertexId, resultId); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java index cb2918f4089..a3613eba27c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.executiongraph; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; @@ -26,10 +27,14 @@ import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkState; + /** Utilities for building {@link EdgeManager}. 
*/ public class EdgeManagerBuildUtil { @@ -39,20 +44,21 @@ public class EdgeManagerBuildUtil { * * @param vertex the downstream consumer {@link ExecutionJobVertex} * @param intermediateResult the upstream consumed {@link IntermediateResult} - * @param distributionPattern the {@link DistributionPattern} of the edge that connects the - * upstream {@link IntermediateResult} and the downstream {@link IntermediateResult} */ static void connectVertexToResult( - ExecutionJobVertex vertex, - IntermediateResult intermediateResult, - DistributionPattern distributionPattern) { + ExecutionJobVertex vertex, IntermediateResult intermediateResult) { + final DistributionPattern distributionPattern = + intermediateResult.getConsumingDistributionPattern(); + final JobVertexInputInfo jobVertexInputInfo = + vertex.getGraph() + .getJobVertexInputInfo(vertex.getJobVertexId(), intermediateResult.getId()); switch (distributionPattern) { case POINTWISE: - connectPointwise(vertex.getTaskVertices(), intermediateResult); + connectPointwise(vertex, intermediateResult, jobVertexInputInfo); break; case ALL_TO_ALL: - connectAllToAll(vertex.getTaskVertices(), intermediateResult); + connectAllToAll(vertex, intermediateResult, jobVertexInputInfo); break; default: throw new IllegalArgumentException("Unrecognized distribution pattern."); @@ -82,155 +88,120 @@ public class EdgeManagerBuildUtil { } private static void connectAllToAll( - ExecutionVertex[] taskVertices, IntermediateResult intermediateResult) { + ExecutionJobVertex jobVertex, + IntermediateResult result, + JobVertexInputInfo jobVertexInputInfo) { + // check the vertex input info is legal + jobVertexInputInfo + .getExecutionVertexInputInfos() + .forEach( + executionVertexInputInfo -> { + IndexRange partitionRange = + executionVertexInputInfo.getPartitionIndexRange(); + checkArgument(partitionRange.getStartIndex() == 0); + checkArgument( + partitionRange.getEndIndex() + == (result.getNumberOfAssignedPartitions() - 1)); + }); + + connectInternal( + Arrays.asList(jobVertex.getTaskVertices()), + Arrays.asList(result.getPartitions()), + result.getResultType(), + jobVertex.getGraph().getEdgeManager()); + } + + private static void connectPointwise( + ExecutionJobVertex jobVertex, + IntermediateResult result, + JobVertexInputInfo jobVertexInputInfo) { + + Map<IndexRange, List<Integer>> consumersByPartition = new LinkedHashMap<>(); + + for (ExecutionVertexInputInfo executionVertexInputInfo : + jobVertexInputInfo.getExecutionVertexInputInfos()) { + int consumerIndex = executionVertexInputInfo.getSubtaskIndex(); + IndexRange range = executionVertexInputInfo.getPartitionIndexRange(); + consumersByPartition.compute( + range, + (ignore, consumers) -> { + if (consumers == null) { + consumers = new ArrayList<>(); + } + consumers.add(consumerIndex); + return consumers; + }); + } + + consumersByPartition.forEach( + (range, subtasks) -> { + List<ExecutionVertex> taskVertices = new ArrayList<>(); + List<IntermediateResultPartition> partitions = new ArrayList<>(); + for (int index : subtasks) { + taskVertices.add(jobVertex.getTaskVertices()[index]); + } + for (int i = range.getStartIndex(); i <= range.getEndIndex(); ++i) { + partitions.add(result.getPartitions()[i]); + } + connectInternal( + taskVertices, + partitions, + result.getResultType(), + jobVertex.getGraph().getEdgeManager()); + }); + } + + /** Connect all execution vertices to all partitions. 
*/ + private static void connectInternal( + List<ExecutionVertex> taskVertices, + List<IntermediateResultPartition> partitions, + ResultPartitionType resultPartitionType, + EdgeManager edgeManager) { + checkState(!taskVertices.isEmpty()); + checkState(!partitions.isEmpty()); - List<IntermediateResultPartitionID> consumedPartitions = - Arrays.stream(intermediateResult.getPartitions()) - .map(IntermediateResultPartition::getPartitionId) - .collect(Collectors.toList()); ConsumedPartitionGroup consumedPartitionGroup = createAndRegisterConsumedPartitionGroupToEdgeManager( - taskVertices.length, consumedPartitions, intermediateResult); + taskVertices.size(), partitions, resultPartitionType, edgeManager); for (ExecutionVertex ev : taskVertices) { ev.addConsumedPartitionGroup(consumedPartitionGroup); } List<ExecutionVertexID> consumerVertices = - Arrays.stream(taskVertices) - .map(ExecutionVertex::getID) - .collect(Collectors.toList()); + taskVertices.stream().map(ExecutionVertex::getID).collect(Collectors.toList()); ConsumerVertexGroup consumerVertexGroup = - ConsumerVertexGroup.fromMultipleVertices( - consumerVertices, intermediateResult.getResultType()); - for (IntermediateResultPartition partition : intermediateResult.getPartitions()) { + ConsumerVertexGroup.fromMultipleVertices(consumerVertices, resultPartitionType); + for (IntermediateResultPartition partition : partitions) { partition.addConsumers(consumerVertexGroup); } } - private static void connectPointwise( - ExecutionVertex[] taskVertices, IntermediateResult intermediateResult) { - - final int sourceCount = intermediateResult.getPartitions().length; - final int targetCount = taskVertices.length; - - if (sourceCount == targetCount) { - for (int i = 0; i < sourceCount; i++) { - ExecutionVertex executionVertex = taskVertices[i]; - IntermediateResultPartition partition = intermediateResult.getPartitions()[i]; - - ConsumerVertexGroup consumerVertexGroup = - ConsumerVertexGroup.fromSingleVertex( - executionVertex.getID(), intermediateResult.getResultType()); - partition.addConsumers(consumerVertexGroup); - - ConsumedPartitionGroup consumedPartitionGroup = - createAndRegisterConsumedPartitionGroupToEdgeManager( - consumerVertexGroup.size(), - partition.getPartitionId(), - intermediateResult); - executionVertex.addConsumedPartitionGroup(consumedPartitionGroup); - } - } else if (sourceCount > targetCount) { - for (int index = 0; index < targetCount; index++) { - - ExecutionVertex executionVertex = taskVertices[index]; - ConsumerVertexGroup consumerVertexGroup = - ConsumerVertexGroup.fromSingleVertex( - executionVertex.getID(), intermediateResult.getResultType()); - - int start = index * sourceCount / targetCount; - int end = (index + 1) * sourceCount / targetCount; - - List<IntermediateResultPartitionID> consumedPartitions = - new ArrayList<>(end - start); - - for (int i = start; i < end; i++) { - IntermediateResultPartition partition = intermediateResult.getPartitions()[i]; - partition.addConsumers(consumerVertexGroup); - - consumedPartitions.add(partition.getPartitionId()); - } - - ConsumedPartitionGroup consumedPartitionGroup = - createAndRegisterConsumedPartitionGroupToEdgeManager( - consumerVertexGroup.size(), consumedPartitions, intermediateResult); - executionVertex.addConsumedPartitionGroup(consumedPartitionGroup); - } - } else { - for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) { - int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount; - int end = ((partitionNum + 1) * targetCount + 
sourceCount - 1) / sourceCount; - - IntermediateResultPartition partition = - intermediateResult.getPartitions()[partitionNum]; - ConsumedPartitionGroup consumedPartitionGroup = - createAndRegisterConsumedPartitionGroupToEdgeManager( - end - start, partition.getPartitionId(), intermediateResult); - - List<ExecutionVertexID> consumers = new ArrayList<>(end - start); - - for (int i = start; i < end; i++) { - ExecutionVertex executionVertex = taskVertices[i]; - executionVertex.addConsumedPartitionGroup(consumedPartitionGroup); - - consumers.add(executionVertex.getID()); - } - - ConsumerVertexGroup consumerVertexGroup = - ConsumerVertexGroup.fromMultipleVertices( - consumers, intermediateResult.getResultType()); - partition.addConsumers(consumerVertexGroup); - } - } - } - private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager( int numConsumers, - IntermediateResultPartitionID consumedPartitionId, - IntermediateResult intermediateResult) { - ConsumedPartitionGroup consumedPartitionGroup = - ConsumedPartitionGroup.fromSinglePartition( - numConsumers, consumedPartitionId, intermediateResult.getResultType()); - finishAllDataProducedPartitions( - intermediateResult, - Collections.singletonList(consumedPartitionId), - consumedPartitionGroup); - registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult); - return consumedPartitionGroup; - } - - private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager( - int numConsumers, - List<IntermediateResultPartitionID> consumedPartitions, - IntermediateResult intermediateResult) { + List<IntermediateResultPartition> partitions, + ResultPartitionType resultPartitionType, + EdgeManager edgeManager) { + List<IntermediateResultPartitionID> partitionIds = + partitions.stream() + .map(IntermediateResultPartition::getPartitionId) + .collect(Collectors.toList()); ConsumedPartitionGroup consumedPartitionGroup = ConsumedPartitionGroup.fromMultiplePartitions( - numConsumers, consumedPartitions, intermediateResult.getResultType()); - finishAllDataProducedPartitions( - intermediateResult, consumedPartitions, consumedPartitionGroup); - registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult); + numConsumers, partitionIds, resultPartitionType); + finishAllDataProducedPartitions(partitions, consumedPartitionGroup); + edgeManager.registerConsumedPartitionGroup(consumedPartitionGroup); return consumedPartitionGroup; } private static void finishAllDataProducedPartitions( - IntermediateResult intermediateResult, - List<IntermediateResultPartitionID> consumedPartitionIds, + List<IntermediateResultPartition> partitions, ConsumedPartitionGroup consumedPartitionGroup) { - for (IntermediateResultPartitionID consumedPartitionId : consumedPartitionIds) { + for (IntermediateResultPartition partition : partitions) { // this is for dynamic graph as consumedPartitionGroup has not been created when the // partition becomes finished. 
- if (intermediateResult.getPartitionById(consumedPartitionId).hasDataAllProduced()) { + if (partition.hasDataAllProduced()) { consumedPartitionGroup.partitionFinished(); } } } - - private static void registerConsumedPartitionGroupToEdgeManager( - ConsumedPartitionGroup consumedPartitionGroup, IntermediateResult intermediateResult) { - intermediateResult - .getProducer() - .getGraph() - .getEdgeManager() - .registerConsumedPartitionGroup(consumedPartitionGroup); - } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java index 77802620b1d..2f90291e75c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java @@ -203,6 +203,15 @@ public interface ExecutionGraph extends AccessExecutionGraph { @Nonnull ComponentMainThreadExecutor getJobMasterMainThreadExecutor(); + default void initializeJobVertex(ExecutionJobVertex ejv, long createTimestamp) + throws JobException { + initializeJobVertex( + ejv, + createTimestamp, + VertexInputInfoComputationUtils.computeVertexInputInfos( + ejv, getAllIntermediateResults()::get)); + } + /** * Initialize the given execution job vertex, mainly includes creating execution vertices * according to the parallelism, and connecting to the predecessors. @@ -210,8 +219,13 @@ public interface ExecutionGraph extends AccessExecutionGraph { * @param ejv The execution job vertex that needs to be initialized. * @param createTimestamp The timestamp for creating execution vertices, used to initialize the * first Execution with. + * @param jobVertexInputInfos The input infos of this job vertex. 
*/ - void initializeJobVertex(ExecutionJobVertex ejv, long createTimestamp) throws JobException; + void initializeJobVertex( + ExecutionJobVertex ejv, + long createTimestamp, + Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos) + throws JobException; /** * Notify that some job vertices have been newly initialized, execution graph will try to update diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java index 1a24eda0c67..a914fd1630a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java @@ -487,7 +487,7 @@ public class ExecutionJobVertex this.inputs.add(ires); - EdgeManagerBuildUtil.connectVertexToResult(this, ires, edge.getDistributionPattern()); + EdgeManagerBuildUtil.connectVertexToResult(this, ires); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java index 06bc7e6b9f4..bd3a9794db0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java @@ -28,6 +28,7 @@ import org.apache.flink.runtime.JobException; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.clusterframework.types.ResourceProfile; import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobmaster.LogicalSlot; @@ -161,6 +162,13 @@ public class ExecutionVertex timeout); } + public ExecutionVertexInputInfo getExecutionVertexInputInfo(IntermediateDataSetID resultId) { + return getExecutionGraphAccessor() + .getJobVertexInputInfo(getJobvertexId(), resultId) + .getExecutionVertexInputInfos() + .get(subTaskIndex); + } + public Execution getPartitionProducer() { return currentExecution; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java similarity index 50% copy from flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java copy to flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java index 0fd14e4439e..26829893b5a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java @@ -16,16 +16,11 @@ * limitations under the License. */ -package org.apache.flink.runtime.scheduler.adaptivebatch; +package org.apache.flink.runtime.executiongraph; -import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; -/** - * The blocking result info, which will be used to calculate the vertex parallelism and input infos. 
- */ -public interface BlockingResultInfo { - +public interface IntermediateResultInfo { /** * Get the intermediate result id. * @@ -48,30 +43,17 @@ public interface BlockingResultInfo { boolean isPointwise(); /** - * Return the num of bytes produced(numBytesProduced) by the producer. - * - * <p>The difference between numBytesProduced and numBytesOut : numBytesProduced represents the - * number of bytes actually produced, and numBytesOut represents the number of bytes sent to - * downstream tasks. In unicast scenarios, these two values should be equal. In broadcast - * scenarios, numBytesOut should be (N * numBytesProduced), where N refers to the number of - * subpartitions. - * - * @return the num of bytes produced by the producer - */ - long getNumBytesProduced(); - - /** - * Record the information of the result partition. + * Get number of partitions for this result. * - * @param partitionIndex the intermediate result partition index - * @param partitionBytes the {@link ResultPartitionBytes} of the partition + * @return the number of partitions in this result */ - void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes); + int getNumPartitions(); /** - * Reset the information of the result partition. + * Get number of subpartitions for the given partition. * - * @param partitionIndex the intermediate result partition index + * @param partitionIndex the partition index + * @return the number of subpartitions of the partition */ - void resetPartitionInfo(int partitionIndex); + int getNumSubpartitions(int partitionIndex); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java index fd5d91bbd3f..0d69cfeed96 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java @@ -33,7 +33,7 @@ import static org.apache.flink.util.Preconditions.checkState; public class IntermediateResultPartition { - private static final int UNKNOWN = -1; + static final int NUM_SUBPARTITIONS_UNKNOWN = -1; private final IntermediateResult totalResult; @@ -44,7 +44,7 @@ public class IntermediateResultPartition { private final EdgeManager edgeManager; /** Number of subpartitions. Initialized lazily and will not change once set. */ - private int numberOfSubpartitions = UNKNOWN; + private int numberOfSubpartitions = NUM_SUBPARTITIONS_UNKNOWN; /** Whether this partition has produced all data. 
*/ private boolean dataAllProduced = false; @@ -114,7 +114,7 @@ public class IntermediateResultPartition { } public int getNumberOfSubpartitions() { - if (numberOfSubpartitions == UNKNOWN) { + if (numberOfSubpartitions == NUM_SUBPARTITIONS_UNKNOWN) { numberOfSubpartitions = computeNumberOfSubpartitions(); checkState( numberOfSubpartitions > 0, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/InternalExecutionGraphAccessor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/InternalExecutionGraphAccessor.java index 5a814fb4211..bd2c7d70e04 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/InternalExecutionGraphAccessor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/InternalExecutionGraphAccessor.java @@ -122,4 +122,14 @@ public interface InternalExecutionGraphAccessor { IntermediateDataSetID intermediateResultPartition); MarkPartitionFinishedStrategy getMarkPartitionFinishedStrategy(); + + /** + * Get the input info of a certain input of a certain job vertex. + * + * @param jobVertexId the job vertex id + * @param resultId the input(intermediate result) id + * @return the input info + */ + JobVertexInputInfo getJobVertexInputInfo( + JobVertexID jobVertexId, IntermediateDataSetID resultId); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java new file mode 100644 index 00000000000..04d3175d10f --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.executiongraph; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.runtime.JobException; +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.JobEdge; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Supplier; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** Util to compute {@link JobVertexInputInfo}s for execution job vertex. 
*/ +public class VertexInputInfoComputationUtils { + + public static Map<IntermediateDataSetID, JobVertexInputInfo> computeVertexInputInfos( + ExecutionJobVertex ejv, + Function<IntermediateDataSetID, IntermediateResult> intermediateResultRetriever) + throws JobException { + checkState(ejv.isParallelismDecided()); + final List<IntermediateResultInfo> intermediateResultInfos = new ArrayList<>(); + for (JobEdge edge : ejv.getJobVertex().getInputs()) { + IntermediateResult ires = intermediateResultRetriever.apply(edge.getSourceId()); + if (ires == null) { + throw new JobException( + "Cannot connect this job graph to the previous graph. No previous intermediate result found for ID " + + edge.getSourceId()); + } + intermediateResultInfos.add(new IntermediateResultWrapper(ires)); + } + return computeVertexInputInfos( + ejv.getParallelism(), intermediateResultInfos, ejv.getGraph().isDynamic()); + } + + public static Map<IntermediateDataSetID, JobVertexInputInfo> computeVertexInputInfos( + int parallelism, List<IntermediateResultInfo> inputs, boolean isDynamicGraph) { + + checkArgument(parallelism > 0); + final Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos = + new LinkedHashMap<>(); + + for (IntermediateResultInfo input : inputs) { + int sourceParallelism = input.getNumPartitions(); + + if (input.isPointwise()) { + jobVertexInputInfos.putIfAbsent( + input.getResultId(), + computeVertexInputInfoForPointwise( + sourceParallelism, + parallelism, + input::getNumSubpartitions, + isDynamicGraph)); + } else { + jobVertexInputInfos.putIfAbsent( + input.getResultId(), + computeVertexInputInfoForAllToAll( + sourceParallelism, + parallelism, + input::getNumSubpartitions, + isDynamicGraph, + input.isBroadcast())); + } + } + + return jobVertexInputInfos; + } + + /** + * Compute the {@link JobVertexInputInfo} for a {@link DistributionPattern#POINTWISE} edge. This + * computation algorithm will evenly distribute subpartitions to downstream subtasks according + * to the number of subpartitions. Different downstream subtasks consume roughly the same number + * of subpartitions. 
+ * + * @param sourceCount the parallelism of upstream + * @param targetCount the parallelism of downstream + * @param numOfSubpartitionsRetriever a retriever to get the number of subpartitions + * @param isDynamicGraph whether is dynamic graph + * @return the computed {@link JobVertexInputInfo} + */ + @VisibleForTesting + static JobVertexInputInfo computeVertexInputInfoForPointwise( + int sourceCount, + int targetCount, + Function<Integer, Integer> numOfSubpartitionsRetriever, + boolean isDynamicGraph) { + + final List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>(); + + if (sourceCount >= targetCount) { + for (int index = 0; index < targetCount; index++) { + + int start = index * sourceCount / targetCount; + int end = (index + 1) * sourceCount / targetCount; + + IndexRange partitionRange = new IndexRange(start, end - 1); + IndexRange subpartitionRange = + computeConsumedSubpartitionRange( + index, + 1, + () -> numOfSubpartitionsRetriever.apply(start), + isDynamicGraph, + false); + executionVertexInputInfos.add( + new ExecutionVertexInputInfo(index, partitionRange, subpartitionRange)); + } + } else { + for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) { + + int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount; + int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount; + int numConsumers = end - start; + + IndexRange partitionRange = new IndexRange(partitionNum, partitionNum); + // Variable used in lambda expression should be final or effectively final + final int finalPartitionNum = partitionNum; + for (int i = start; i < end; i++) { + IndexRange subpartitionRange = + computeConsumedSubpartitionRange( + i, + numConsumers, + () -> numOfSubpartitionsRetriever.apply(finalPartitionNum), + isDynamicGraph, + false); + executionVertexInputInfos.add( + new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange)); + } + } + } + return new JobVertexInputInfo(executionVertexInputInfos); + } + + /** + * Compute the {@link JobVertexInputInfo} for a {@link DistributionPattern#ALL_TO_ALL} edge. + * This computation algorithm will evenly distribute subpartitions to downstream subtasks + * according to the number of subpartitions. Different downstream subtasks consume roughly the + * same number of subpartitions. + * + * @param sourceCount the parallelism of upstream + * @param targetCount the parallelism of downstream + * @param numOfSubpartitionsRetriever a retriever to get the number of subpartitions + * @param isDynamicGraph whether is dynamic graph + * @param isBroadcast whether the edge is broadcast + * @return the computed {@link JobVertexInputInfo} + */ + @VisibleForTesting + static JobVertexInputInfo computeVertexInputInfoForAllToAll( + int sourceCount, + int targetCount, + Function<Integer, Integer> numOfSubpartitionsRetriever, + boolean isDynamicGraph, + boolean isBroadcast) { + final List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>(); + IndexRange partitionRange = new IndexRange(0, sourceCount - 1); + for (int i = 0; i < targetCount; ++i) { + IndexRange subpartitionRange = + computeConsumedSubpartitionRange( + i, + targetCount, + () -> numOfSubpartitionsRetriever.apply(0), + isDynamicGraph, + isBroadcast); + executionVertexInputInfos.add( + new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange)); + } + return new JobVertexInputInfo(executionVertexInputInfos); + } + + /** + * Compute the consumed subpartition range for a subtask. 
This computation algorithm will evenly + * distribute subpartitions to downstream subtasks according to the number of subpartitions. + * Different downstream subtasks consume roughly the same number of subpartitions. + * + * @param consumerSubtaskIndex the subtask index + * @param numConsumers the total number of consumers + * @param numOfSubpartitionsSupplier a supplier to get the number of subpartitions + * @param isDynamicGraph whether is dynamic graph + * @param isBroadcast whether the edge is broadcast + * @return the computed subpartition range + */ + @VisibleForTesting + static IndexRange computeConsumedSubpartitionRange( + int consumerSubtaskIndex, + int numConsumers, + Supplier<Integer> numOfSubpartitionsSupplier, + boolean isDynamicGraph, + boolean isBroadcast) { + int consumerIndex = consumerSubtaskIndex % numConsumers; + if (!isDynamicGraph) { + return new IndexRange(consumerIndex, consumerIndex); + } else { + int numSubpartitions = numOfSubpartitionsSupplier.get(); + if (isBroadcast) { + // broadcast results have only one subpartition, and be consumed multiple times. + checkArgument(numSubpartitions == 1); + return new IndexRange(0, 0); + } else { + checkArgument(consumerIndex < numConsumers); + checkArgument(numConsumers <= numSubpartitions); + + int start = consumerIndex * numSubpartitions / numConsumers; + int nextStart = (consumerIndex + 1) * numSubpartitions / numConsumers; + + return new IndexRange(start, nextStart - 1); + } + } + } + + private static class IntermediateResultWrapper implements IntermediateResultInfo { + private final IntermediateResult intermediateResult; + + IntermediateResultWrapper(IntermediateResult intermediateResult) { + this.intermediateResult = checkNotNull(intermediateResult); + } + + @Override + public IntermediateDataSetID getResultId() { + return intermediateResult.getId(); + } + + @Override + public boolean isBroadcast() { + return intermediateResult.isBroadcast(); + } + + @Override + public boolean isPointwise() { + return intermediateResult.getConsumingDistributionPattern() + == DistributionPattern.POINTWISE; + } + + @Override + public int getNumPartitions() { + return intermediateResult.getNumberOfAssignedPartitions(); + } + + @Override + public int getNumSubpartitions(int partitionIndex) { + // Note that this method should only be called for dynamic graph.This method is used to + // compute which sub-partitions a consumer vertex should consume, however, for + // non-dynamic graph it is not needed, and the number of sub-partitions is not decided + // at this stage, due to the execution edge are not created. + checkState( + intermediateResult.getProducer().getGraph().isDynamic(), + "This method should only be called for dynamic graph."); + return intermediateResult.getPartitions()[partitionIndex].getNumberOfSubpartitions(); + } + } + + /** Private default constructor to avoid being instantiated. */ + private VertexInputInfoComputationUtils() {} +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoStore.java new file mode 100644 index 00000000000..21638f2166d --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoStore.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.executiongraph; + +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.JobVertexID; + +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * A store contains all the {@link JobVertexInputInfo}s. Note that if a vertex has multiple job + * edges connecting to the same intermediate result, their {@link DistributionPattern} must be the + * same and therefore the {@link JobVertexInputInfo} will be the same. + */ +public class VertexInputInfoStore { + + private final Map<JobVertexID, Map<IntermediateDataSetID, JobVertexInputInfo>> + jobVertexInputInfos = new HashMap<>(); + + /** + * Put a {@link JobVertexInputInfo}. + * + * @param jobVertexId the job vertex id + * @param resultId the intermediate result id + * @param info the {@link JobVertexInputInfo} to put + */ + public void put( + JobVertexID jobVertexId, IntermediateDataSetID resultId, JobVertexInputInfo info) { + checkNotNull(jobVertexId); + checkNotNull(resultId); + checkNotNull(info); + + jobVertexInputInfos.compute( + jobVertexId, + (ignored, inputInfos) -> { + if (inputInfos == null) { + inputInfos = new HashMap<>(); + } + + inputInfos.putIfAbsent(resultId, info); + return inputInfos; + }); + } + + /** + * Get a {@link JobVertexInputInfo}. 
+ * + * @param jobVertexId the job vertex id + * @param resultId the intermediate result id + * @return the {@link JobVertexInputInfo} identified by the job vertex id and intermediate + * result id + */ + public JobVertexInputInfo get(JobVertexID jobVertexId, IntermediateDataSetID resultId) { + checkNotNull(jobVertexId); + checkNotNull(resultId); + return checkNotNull(jobVertexInputInfos.get(jobVertexId).get(resultId)); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java index ecd723a4246..bca093c7219 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java @@ -21,7 +21,6 @@ package org.apache.flink.runtime.scheduler; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.configuration.MemorySize; import org.apache.flink.runtime.clusterframework.types.ResourceProfile; -import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory; import org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; @@ -184,10 +183,9 @@ public class SsgNetworkMemoryCalculationUtils { IntermediateResultPartition resultPartition = ejv.getGraph().getResultPartitionOrThrow((partitionGroup.getFirst())); IndexRange subpartitionIndexRange = - TaskDeploymentDescriptorFactory.computeConsumedSubpartitionRange( - partitionGroup.getNumConsumers(), - resultPartition, - vertex.getParallelSubtaskIndex()); + vertex.getExecutionVertexInputInfo( + resultPartition.getIntermediateResult().getId()) + .getSubpartitionIndexRange(); tmp.merge( partitionGroup.getIntermediateDataSetID(), diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java index a4a06efc57a..bae270a608a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java @@ -61,6 +61,16 @@ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo { return false; } + @Override + public int getNumPartitions() { + return numOfPartitions; + } + + @Override + public int getNumSubpartitions(int partitionIndex) { + return numOfSubpartitions; + } + @Override public long getNumBytesProduced() { checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready"); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java index 0fd14e4439e..2eeb2a90a20 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java @@ -18,34 +18,13 @@ package org.apache.flink.runtime.scheduler.adaptivebatch; +import org.apache.flink.runtime.executiongraph.IntermediateResultInfo; import 
org.apache.flink.runtime.executiongraph.ResultPartitionBytes; -import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; /** * The blocking result info, which will be used to calculate the vertex parallelism and input infos. */ -public interface BlockingResultInfo { - - /** - * Get the intermediate result id. - * - * @return the intermediate result id - */ - IntermediateDataSetID getResultId(); - - /** - * Whether it is a broadcast result. - * - * @return whether it is a broadcast result - */ - boolean isBroadcast(); - - /** - * Whether it is a pointwise result. - * - * @return whether it is a pointwise result - */ - boolean isPointwise(); +public interface BlockingResultInfo extends IntermediateResultInfo { /** * Return the num of bytes produced(numBytesProduced) by the producer. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java index 287b180df00..225b653064b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java @@ -41,6 +41,16 @@ public class PointwiseBlockingResultInfo extends AbstractBlockingResultInfo { return true; } + @Override + public int getNumPartitions() { + return numOfPartitions; + } + + @Override + public int getNumSubpartitions(int partitionIndex) { + return numOfSubpartitions; + } + @Override public long getNumBytesProduced() { checkState( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactoryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactoryTest.java index 30938b355da..7ea1f77be79 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactoryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactoryTest.java @@ -30,7 +30,6 @@ import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.Offloaded; import org.apache.flink.runtime.executiongraph.ExecutionGraph; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; -import org.apache.flink.runtime.executiongraph.IndexRange; import org.apache.flink.runtime.executiongraph.IntermediateResult; import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; @@ -57,8 +56,6 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.ScheduledExecutorService; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.Assert.assertEquals; /** Tests for {@link TaskDeploymentDescriptorFactory}. 
*/ @@ -188,87 +185,4 @@ public class TaskDeploymentDescriptorFactoryTest extends TestLogger { return compressedSerializedValue.deserializeValue(ClassLoader.getSystemClassLoader()); } } - - @Test - public void testComputeConsumedSubpartitionRange3to1() { - final IndexRange range = computeConsumedSubpartitionRange(0, 1, 3); - assertThat(range.getStartIndex(), is(0)); - assertThat(range.getEndIndex(), is(2)); - } - - @Test - public void testComputeConsumedSubpartitionRange3to2() { - final IndexRange range1 = computeConsumedSubpartitionRange(0, 2, 3); - assertThat(range1.getStartIndex(), is(0)); - assertThat(range1.getEndIndex(), is(0)); - - final IndexRange range2 = computeConsumedSubpartitionRange(1, 2, 3); - assertThat(range2.getStartIndex(), is(1)); - assertThat(range2.getEndIndex(), is(2)); - } - - @Test - public void testComputeConsumedSubpartitionRange6to4() { - final IndexRange range1 = computeConsumedSubpartitionRange(0, 4, 6); - assertThat(range1.getStartIndex(), is(0)); - assertThat(range1.getEndIndex(), is(0)); - - final IndexRange range2 = computeConsumedSubpartitionRange(1, 4, 6); - assertThat(range2.getStartIndex(), is(1)); - assertThat(range2.getEndIndex(), is(2)); - - final IndexRange range3 = computeConsumedSubpartitionRange(2, 4, 6); - assertThat(range3.getStartIndex(), is(3)); - assertThat(range3.getEndIndex(), is(3)); - - final IndexRange range4 = computeConsumedSubpartitionRange(3, 4, 6); - assertThat(range4.getStartIndex(), is(4)); - assertThat(range4.getEndIndex(), is(5)); - } - - @Test - public void testComputeBroadcastConsumedSubpartitionRange() { - final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, 1, true, true); - assertThat(range1.getStartIndex(), is(0)); - assertThat(range1.getEndIndex(), is(0)); - - final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, 1, true, true); - assertThat(range2.getStartIndex(), is(0)); - assertThat(range2.getEndIndex(), is(0)); - - final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, 1, true, true); - assertThat(range3.getStartIndex(), is(0)); - assertThat(range3.getEndIndex(), is(0)); - } - - @Test - public void testComputeConsumedSubpartitionRangeForNonDynamicGraph() { - final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, 3, false, false); - assertThat(range1.getStartIndex(), is(0)); - assertThat(range1.getEndIndex(), is(0)); - - final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, 3, false, false); - assertThat(range2.getStartIndex(), is(1)); - assertThat(range2.getEndIndex(), is(1)); - - final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, 3, false, false); - assertThat(range3.getStartIndex(), is(2)); - assertThat(range3.getEndIndex(), is(2)); - } - - private static IndexRange computeConsumedSubpartitionRange( - int consumerIndex, int numConsumers, int numSubpartitions) { - return computeConsumedSubpartitionRange( - consumerIndex, numConsumers, numSubpartitions, true, false); - } - - private static IndexRange computeConsumedSubpartitionRange( - int consumerIndex, - int numConsumers, - int numSubpartitions, - boolean isDynamicGraph, - boolean isBroadcast) { - return TaskDeploymentDescriptorFactory.computeConsumedSubpartitionRange( - consumerIndex, numConsumers, numSubpartitions, isDynamicGraph, isBroadcast); - } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java index 
47fce0af13b..2443a4f2b42 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java @@ -55,8 +55,9 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * This class contains test concerning the correct conversion from {@link JobGraph} to {@link - * ExecutionGraph} objects. It also tests that {@link EdgeManagerBuildUtil#connectVertexToResult} - * builds {@link DistributionPattern#ALL_TO_ALL} connections correctly. + * ExecutionGraph} objects. It also tests that {@link + * VertexInputInfoComputationUtils#computeVertexInputInfoForAllToAll} builds {@link + * DistributionPattern#ALL_TO_ALL} connections correctly. */ class DefaultExecutionGraphConstructionTest { @RegisterExtension diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java index 28b44595a26..41422bb48b1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java @@ -22,7 +22,6 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; -import org.apache.flink.runtime.scheduler.SchedulerBase; import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup; import org.apache.flink.testutils.TestingUtils; import org.apache.flink.testutils.executor.TestExecutorExtension; @@ -33,9 +32,13 @@ import org.junit.jupiter.api.extension.RegisterExtension; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; import java.util.List; +import java.util.Objects; import java.util.concurrent.ScheduledExecutorService; +import static org.apache.flink.runtime.executiongraph.IntermediateResultPartitionTest.computeVertexParallelismStoreConsideringDynamicGraph; import static org.apache.flink.runtime.jobgraph.DistributionPattern.ALL_TO_ALL; import static org.apache.flink.runtime.jobgraph.DistributionPattern.POINTWISE; import static org.assertj.core.api.Assertions.assertThat; @@ -68,6 +71,142 @@ class EdgeManagerBuildUtilTest { testGetMaxNumEdgesToTarget(23, 17, ALL_TO_ALL); } + @Test + void testConnectAllToAll() throws Exception { + int upstream = 3; + int downstream = 2; + + // use dynamic graph to specify the vertex input info + ExecutionGraph eg = setupExecutionGraph(upstream, downstream, POINTWISE, true); + + List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>(); + for (int i = 0; i < downstream; i++) { + executionVertexInputInfos.add( + new ExecutionVertexInputInfo( + i, + new IndexRange(0, upstream - 1), + // the subpartition range will not be used in edge manager, so set (0, + // 0) + new IndexRange(0, 0))); + } + final JobVertexInputInfo jobVertexInputInfo = + new JobVertexInputInfo(executionVertexInputInfos); + + final Iterator<ExecutionJobVertex> vertexIterator = + eg.getVerticesTopologically().iterator(); + final ExecutionJobVertex producer = vertexIterator.next(); + final ExecutionJobVertex consumer = vertexIterator.next(); + + // initialize producer and consumer + 
eg.initializeJobVertex(producer, 1L, Collections.emptyMap()); + eg.initializeJobVertex( + consumer, + 1L, + Collections.singletonMap( + producer.getProducedDataSets()[0].getId(), jobVertexInputInfo)); + + IntermediateResult result = + Objects.requireNonNull(eg.getJobVertex(producer.getJobVertexId())) + .getProducedDataSets()[0]; + IntermediateResultPartition partition1 = result.getPartitions()[0]; + IntermediateResultPartition partition2 = result.getPartitions()[1]; + IntermediateResultPartition partition3 = result.getPartitions()[2]; + + ExecutionVertex vertex1 = consumer.getTaskVertices()[0]; + ExecutionVertex vertex2 = consumer.getTaskVertices()[1]; + + // check consumers of the partitions + assertThat(partition1.getConsumerVertexGroups().get(0)) + .containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID()); + assertThat(partition2.getConsumerVertexGroups().get(0)) + .isEqualTo(partition1.getConsumerVertexGroups().get(0)); + assertThat(partition3.getConsumerVertexGroups().get(0)) + .isEqualTo(partition1.getConsumerVertexGroups().get(0)); + + // check inputs of the execution vertices + assertThat(vertex1.getConsumedPartitionGroup(0)) + .containsExactlyInAnyOrder( + partition1.getPartitionId(), + partition2.getPartitionId(), + partition3.getPartitionId()); + assertThat(vertex2.getConsumedPartitionGroup(0)) + .isEqualTo(vertex1.getConsumedPartitionGroup(0)); + } + + @Test + void testConnectPointwise() throws Exception { + int upstream = 4; + int downstream = 4; + + // use dynamic graph to specify the vertex input info + ExecutionGraph eg = setupExecutionGraph(upstream, downstream, POINTWISE, true); + + // set partition ranges + List<IndexRange> partitionRanges = + Arrays.asList( + new IndexRange(0, 0), + new IndexRange(0, 0), + new IndexRange(1, 2), + new IndexRange(3, 3)); + List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>(); + for (int i = 0; i < downstream; i++) { + executionVertexInputInfos.add( + new ExecutionVertexInputInfo( + // the subpartition range will not be used in edge manager, so set (0, + // 0) + i, partitionRanges.get(i), new IndexRange(0, 0))); + } + final JobVertexInputInfo jobVertexInputInfo = + new JobVertexInputInfo(executionVertexInputInfos); + + final Iterator<ExecutionJobVertex> vertexIterator = + eg.getVerticesTopologically().iterator(); + final ExecutionJobVertex producer = vertexIterator.next(); + final ExecutionJobVertex consumer = vertexIterator.next(); + + // initialize producer and consumer + eg.initializeJobVertex(producer, 1L, Collections.emptyMap()); + eg.initializeJobVertex( + consumer, + 1L, + Collections.singletonMap( + producer.getProducedDataSets()[0].getId(), jobVertexInputInfo)); + + IntermediateResult result = + Objects.requireNonNull(eg.getJobVertex(producer.getJobVertexId())) + .getProducedDataSets()[0]; + IntermediateResultPartition partition1 = result.getPartitions()[0]; + IntermediateResultPartition partition2 = result.getPartitions()[1]; + IntermediateResultPartition partition3 = result.getPartitions()[2]; + IntermediateResultPartition partition4 = result.getPartitions()[3]; + + ExecutionVertex vertex1 = consumer.getTaskVertices()[0]; + ExecutionVertex vertex2 = consumer.getTaskVertices()[1]; + ExecutionVertex vertex3 = consumer.getTaskVertices()[2]; + ExecutionVertex vertex4 = consumer.getTaskVertices()[3]; + + // check consumers of the partitions + assertThat(partition1.getConsumerVertexGroups().get(0)) + .containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID()); + 
assertThat(partition2.getConsumerVertexGroups().get(0)) + .containsExactlyInAnyOrder(vertex3.getID()); + assertThat(partition3.getConsumerVertexGroups().get(0)) + .isEqualTo(partition2.getConsumerVertexGroups().get(0)); + assertThat(partition4.getConsumerVertexGroups().get(0)) + .containsExactlyInAnyOrder(vertex4.getID()); + + // check inputs of the execution vertices + assertThat(vertex1.getConsumedPartitionGroup(0)) + .containsExactlyInAnyOrder(partition1.getPartitionId()); + assertThat(vertex2.getConsumedPartitionGroup(0)) + .isEqualTo(vertex1.getConsumedPartitionGroup(0)); + assertThat(vertex3.getConsumedPartitionGroup(0)) + .containsExactlyInAnyOrder( + partition2.getPartitionId(), partition3.getPartitionId()); + assertThat(vertex4.getConsumedPartitionGroup(0)) + .containsExactlyInAnyOrder(partition4.getPartitionId()); + } + private void testGetMaxNumEdgesToTarget( int upstream, int downstream, DistributionPattern pattern) throws Exception { @@ -110,6 +249,16 @@ class EdgeManagerBuildUtilTest { private Pair<ExecutionJobVertex, ExecutionJobVertex> setupExecutionGraph( int upstream, int downstream, DistributionPattern pattern) throws Exception { + Iterator<ExecutionJobVertex> jobVertices = + setupExecutionGraph(upstream, downstream, pattern, false) + .getVerticesTopologically() + .iterator(); + return Pair.of(jobVertices.next(), jobVertices.next()); + } + + private ExecutionGraph setupExecutionGraph( + int upstream, int downstream, DistributionPattern pattern, boolean isDynamicGraph) + throws Exception { JobVertex v1 = new JobVertex("vertex1"); JobVertex v2 = new JobVertex("vertex2"); @@ -123,12 +272,19 @@ class EdgeManagerBuildUtilTest { List<JobVertex> ordered = new ArrayList<>(Arrays.asList(v1, v2)); - ExecutionGraph eg = + TestingDefaultExecutionGraphBuilder builder = TestingDefaultExecutionGraphBuilder.newBuilder() .setVertexParallelismStore( - SchedulerBase.computeVertexParallelismStore(ordered)) - .build(EXECUTOR_RESOURCE.getExecutor()); + computeVertexParallelismStoreConsideringDynamicGraph( + ordered, isDynamicGraph, 128)); + ExecutionGraph eg; + if (isDynamicGraph) { + eg = builder.buildDynamicGraph(EXECUTOR_RESOURCE.getExecutor()); + } else { + eg = builder.build(EXECUTOR_RESOURCE.getExecutor()); + } + eg.attachJobGraph(ordered); - return Pair.of(eg.getAllVertices().get(v1.getID()), eg.getAllVertices().get(v2.getID())); + return eg; } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/PointwisePatternTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/PointwisePatternTest.java index d471cac83f5..c1ad5384c0a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/PointwisePatternTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/PointwisePatternTest.java @@ -43,7 +43,7 @@ import static org.junit.Assert.fail; /** * Tests for building {@link DistributionPattern#POINTWISE} connections in {@link - * EdgeManagerBuildUtil#connectVertexToResult}. + * VertexInputInfoComputationUtils#computeVertexInputInfoForPointwise}. 
*/ public class PointwisePatternTest { @ClassRule diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java new file mode 100644 index 00000000000..e0f4d6e2fad --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.executiongraph; + +import org.junit.jupiter.api.Test; + +import static org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils.computeVertexInputInfoForAllToAll; +import static org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils.computeVertexInputInfoForPointwise; +import static org.assertj.core.api.Assertions.assertThat; + +/** Test for {@link VertexInputInfoComputationUtils}. */ +class VertexInputInfoComputationUtilsTest { + + @Test + void testComputeConsumedSubpartitionRange3to1() { + final IndexRange range = computeConsumedSubpartitionRange(0, 1, 3); + assertThat(range).isEqualTo(new IndexRange(0, 2)); + } + + @Test + void testComputeConsumedSubpartitionRange3to2() { + final IndexRange range1 = computeConsumedSubpartitionRange(0, 2, 3); + assertThat(range1).isEqualTo(new IndexRange(0, 0)); + + final IndexRange range2 = computeConsumedSubpartitionRange(1, 2, 3); + assertThat(range2).isEqualTo(new IndexRange(1, 2)); + } + + @Test + void testComputeConsumedSubpartitionRange6to4() { + final IndexRange range1 = computeConsumedSubpartitionRange(0, 4, 6); + assertThat(range1).isEqualTo(new IndexRange(0, 0)); + + final IndexRange range2 = computeConsumedSubpartitionRange(1, 4, 6); + assertThat(range2).isEqualTo(new IndexRange(1, 2)); + + final IndexRange range3 = computeConsumedSubpartitionRange(2, 4, 6); + assertThat(range3).isEqualTo(new IndexRange(3, 3)); + + final IndexRange range4 = computeConsumedSubpartitionRange(3, 4, 6); + assertThat(range4).isEqualTo(new IndexRange(4, 5)); + } + + @Test + void testComputeBroadcastConsumedSubpartitionRange() { + final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, 1, true, true); + assertThat(range1).isEqualTo(new IndexRange(0, 0)); + + final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, 1, true, true); + assertThat(range2).isEqualTo(new IndexRange(0, 0)); + + final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, 1, true, true); + assertThat(range3).isEqualTo(new IndexRange(0, 0)); + } + + @Test + void testComputeConsumedSubpartitionRangeForNonDynamicGraph() { + final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, -1, false, false); + 
assertThat(range1).isEqualTo(new IndexRange(0, 0)); + + final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, -1, false, false); + assertThat(range2).isEqualTo(new IndexRange(1, 1)); + + final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, -1, false, false); + assertThat(range3).isEqualTo(new IndexRange(2, 2)); + } + + @Test + void testComputeVertexInputInfoForAllToAllWithNonDynamicGraph() { + final JobVertexInputInfo nonBroadcast = + computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, false); + assertThat(nonBroadcast.getExecutionVertexInputInfos()) + .containsExactlyInAnyOrder( + new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 0)), + new ExecutionVertexInputInfo(1, new IndexRange(0, 1), new IndexRange(1, 1)), + new ExecutionVertexInputInfo( + 2, new IndexRange(0, 1), new IndexRange(2, 2))); + + final JobVertexInputInfo broadcast = + computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, true); + assertThat(broadcast.getExecutionVertexInputInfos()) + .containsExactlyInAnyOrder( + new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 0)), + new ExecutionVertexInputInfo(1, new IndexRange(0, 1), new IndexRange(1, 1)), + new ExecutionVertexInputInfo( + 2, new IndexRange(0, 1), new IndexRange(2, 2))); + } + + @Test + void testComputeVertexInputInfoForAllToAllWithDynamicGraph() { + final JobVertexInputInfo nonBroadcast = + computeVertexInputInfoForAllToAll(2, 3, ignored -> 10, true, false); + assertThat(nonBroadcast.getExecutionVertexInputInfos()) + .containsExactlyInAnyOrder( + new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 2)), + new ExecutionVertexInputInfo(1, new IndexRange(0, 1), new IndexRange(3, 5)), + new ExecutionVertexInputInfo( + 2, new IndexRange(0, 1), new IndexRange(6, 9))); + + final JobVertexInputInfo broadcast = + computeVertexInputInfoForAllToAll(2, 3, ignored -> 1, true, true); + assertThat(broadcast.getExecutionVertexInputInfos()) + .containsExactlyInAnyOrder( + new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 0)), + new ExecutionVertexInputInfo(1, new IndexRange(0, 1), new IndexRange(0, 0)), + new ExecutionVertexInputInfo( + 2, new IndexRange(0, 1), new IndexRange(0, 0))); + } + + @Test + void testComputeVertexInputInfoForPointwiseWithNonDynamicGraph() { + final JobVertexInputInfo jobVertexInputInfo = + computeVertexInputInfoForPointwise(2, 3, ignored -> 3, false); + assertThat(jobVertexInputInfo.getExecutionVertexInputInfos()) + .containsExactlyInAnyOrder( + new ExecutionVertexInputInfo(0, new IndexRange(0, 0), new IndexRange(0, 0)), + new ExecutionVertexInputInfo(1, new IndexRange(0, 0), new IndexRange(1, 1)), + new ExecutionVertexInputInfo( + 2, new IndexRange(1, 1), new IndexRange(0, 0))); + } + + @Test + void testComputeVertexInputInfoForPointwiseWithDynamicGraph() { + final JobVertexInputInfo jobVertexInputInfo = + computeVertexInputInfoForPointwise(2, 3, ignored -> 4, true); + assertThat(jobVertexInputInfo.getExecutionVertexInputInfos()) + .containsExactlyInAnyOrder( + new ExecutionVertexInputInfo(0, new IndexRange(0, 0), new IndexRange(0, 1)), + new ExecutionVertexInputInfo(1, new IndexRange(0, 0), new IndexRange(2, 3)), + new ExecutionVertexInputInfo( + 2, new IndexRange(1, 1), new IndexRange(0, 3))); + } + + private static IndexRange computeConsumedSubpartitionRange( + int consumerIndex, int numConsumers, int numSubpartitions) { + return computeConsumedSubpartitionRange( + consumerIndex, numConsumers, numSubpartitions, true, false); + } + 
+ private static IndexRange computeConsumedSubpartitionRange( + int consumerIndex, + int numConsumers, + int numSubpartitions, + boolean isDynamicGraph, + boolean isBroadcast) { + return VertexInputInfoComputationUtils.computeConsumedSubpartitionRange( + consumerIndex, numConsumers, () -> numSubpartitions, isDynamicGraph, isBroadcast); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java index bc366f8d478..9281d910485 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java @@ -48,6 +48,7 @@ import org.apache.flink.runtime.executiongraph.IntermediateResult; import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; import org.apache.flink.runtime.executiongraph.InternalExecutionGraphAccessor; import org.apache.flink.runtime.executiongraph.JobInformation; +import org.apache.flink.runtime.executiongraph.JobVertexInputInfo; import org.apache.flink.runtime.executiongraph.MarkPartitionFinishedStrategy; import org.apache.flink.runtime.executiongraph.TaskExecutionStateTransition; import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder; @@ -1037,5 +1038,12 @@ public class ExecutingTest extends TestLogger { throw new UnsupportedOperationException( "This method is not supported by the MockInternalExecutionGraphAccessor."); } + + @Override + public JobVertexInputInfo getJobVertexInputInfo( + JobVertexID jobVertexId, IntermediateDataSetID resultId) { + throw new UnsupportedOperationException( + "This method is not supported by the MockInternalExecutionGraphAccessor."); + } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java index 54a65ccddcd..639c846cbec 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java @@ -44,6 +44,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.executiongraph.IntermediateResult; import org.apache.flink.runtime.executiongraph.JobStatusListener; +import org.apache.flink.runtime.executiongraph.JobVertexInputInfo; import org.apache.flink.runtime.executiongraph.TaskExecutionStateTransition; import org.apache.flink.runtime.executiongraph.failover.flip1.ResultPartitionAvailabilityChecker; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; @@ -370,7 +371,10 @@ class StateTrackingMockExecutionGraph implements ExecutionGraph { } @Override - public void initializeJobVertex(ExecutionJobVertex ejv, long createTimestamp) + public void initializeJobVertex( + ExecutionJobVertex ejv, + long createTimestamp, + Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos) throws JobException { throw new UnsupportedOperationException(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java index 0f588bcc5e6..1cb8f3d05b3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java @@ -170,6 +170,16 @@ class DefaultVertexParallelismDeciderTest { return false; } + @Override + public int getNumPartitions() { + return 0; + } + + @Override + public int getNumSubpartitions(int partitionIndex) { + return 0; + } + @Override public long getNumBytesProduced() { return producedBytes;
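A note for readers of this patch: the expectations in VertexInputInfoComputationUtilsTest above (for example, 6 subpartitions over 4 consumers dividing into (0,0), (1,2), (3,3), (4,5)) are consistent with an even integer division in which consumer i of n takes subpartitions [i*m/n, (i+1)*m/n - 1]. The following is a minimal standalone sketch reconstructed from those test expectations, not code from the patch itself; the class name and the Range record are hypothetical stand-ins for IndexRange, and the real VertexInputInfoComputationUtils.computeConsumedSubpartitionRange takes a Supplier<Integer> for the subpartition count rather than a plain int.

    // Hypothetical sketch of the flexible subpartition range division,
    // reconstructed from the test expectations in this patch.
    public class SubpartitionRangeSketch {

        /** Stand-in for org.apache.flink.runtime.executiongraph.IndexRange. */
        record Range(int start, int end) {}

        static Range computeConsumedSubpartitionRange(
                int consumerIndex,
                int numConsumers,
                int numSubpartitions,
                boolean isDynamicGraph,
                boolean isBroadcast) {
            if (!isDynamicGraph) {
                // Non-dynamic graphs keep the fixed mapping: subtask i reads subpartition i.
                return new Range(consumerIndex, consumerIndex);
            }
            if (isBroadcast) {
                // A broadcast result has a single subpartition read by every consumer.
                return new Range(0, 0);
            }
            // Flexible division: split the subpartitions into numConsumers contiguous
            // ranges whose sizes differ by at most one.
            int start = consumerIndex * numSubpartitions / numConsumers;
            int end = (consumerIndex + 1) * numSubpartitions / numConsumers - 1;
            return new Range(start, end);
        }

        public static void main(String[] args) {
            // Reproduces testComputeConsumedSubpartitionRange6to4: prints [0,0] [1,2] [3,3] [4,5]
            for (int i = 0; i < 4; i++) {
                Range r = computeConsumedSubpartitionRange(i, 4, 6, true, false);
                System.out.printf("[%d,%d] ", r.start(), r.end());
            }
        }
    }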

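The POINTWISE expectations apply the same idea one level up: with 2 producers and 3 consumers, consumer i reads partition i*2/3, and the consumers that end up sharing a producer split its subpartitions with the even division sketched above. Again a hypothetical sketch under that assumption; it reproduces the expectations of testComputeVertexInputInfoForPointwiseWithDynamicGraph but is not the code of computeVertexInputInfoForPointwise.

    // Hypothetical sketch: pointwise mapping for numProducers < numConsumers.
    public class PointwiseMappingSketch {

        record Range(int start, int end) {}

        // Even division of numItems into numBuckets contiguous ranges.
        static Range evenRange(int index, int numBuckets, int numItems) {
            return new Range(
                    index * numItems / numBuckets,
                    (index + 1) * numItems / numBuckets - 1);
        }

        public static void main(String[] args) {
            int numProducers = 2, numConsumers = 3, numSubpartitions = 4;
            for (int consumer = 0; consumer < numConsumers; consumer++) {
                // With fewer producers than consumers, each consumer reads one partition.
                int partition = consumer * numProducers / numConsumers;
                // Find how many consumers share that partition and this consumer's rank among them.
                int firstSharing = -1;
                int numSharing = 0;
                for (int c = 0; c < numConsumers; c++) {
                    if (c * numProducers / numConsumers == partition) {
                        if (firstSharing < 0) {
                            firstSharing = c;
                        }
                        numSharing++;
                    }
                }
                Range subs = evenRange(consumer - firstSharing, numSharing, numSubpartitions);
                // Prints: consumer 0 -> partition 0, subpartitions [0,1]
                //         consumer 1 -> partition 0, subpartitions [2,3]
                //         consumer 2 -> partition 1, subpartitions [0,3]
                System.out.printf(
                        "consumer %d -> partition %d, subpartitions [%d,%d]%n",
                        consumer, partition, subs.start(), subs.end());
            }
        }
    }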