This is an automated email from the ASF dual-hosted git repository.

wanglijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 1d433fb4b72d74b1aa0cfa8ac7a9f7fed22bce32
Author: Lijie Wang <[email protected]>
AuthorDate: Tue Oct 25 17:44:37 2022 +0800

    [FLINK-29665][runtime] Support flexible subpartition range division
    
    This closes #21162
---
 .../TaskDeploymentDescriptorFactory.java           |  73 ++----
 .../executiongraph/DefaultExecutionGraph.java      |  34 ++-
 .../executiongraph/EdgeManagerBuildUtil.java       | 233 ++++++++---------
 .../runtime/executiongraph/ExecutionGraph.java     |  16 +-
 .../runtime/executiongraph/ExecutionJobVertex.java |   2 +-
 .../runtime/executiongraph/ExecutionVertex.java    |   8 +
 .../IntermediateResultInfo.java}                   |  36 +--
 .../IntermediateResultPartition.java               |   6 +-
 .../InternalExecutionGraphAccessor.java            |  10 +
 .../VertexInputInfoComputationUtils.java           | 275 +++++++++++++++++++++
 .../executiongraph/VertexInputInfoStore.java       |  78 ++++++
 .../SsgNetworkMemoryCalculationUtils.java          |   8 +-
 .../adaptivebatch/AllToAllBlockingResultInfo.java  |  10 +
 .../adaptivebatch/BlockingResultInfo.java          |  25 +-
 .../adaptivebatch/PointwiseBlockingResultInfo.java |  10 +
 .../TaskDeploymentDescriptorFactoryTest.java       |  86 -------
 .../DefaultExecutionGraphConstructionTest.java     |   5 +-
 .../executiongraph/EdgeManagerBuildUtilTest.java   | 166 ++++++++++++-
 .../executiongraph/PointwisePatternTest.java       |   2 +-
 .../VertexInputInfoComputationUtilsTest.java       | 165 +++++++++++++
 .../runtime/scheduler/adaptive/ExecutingTest.java  |   8 +
 .../adaptive/StateTrackingMockExecutionGraph.java  |   6 +-
 .../DefaultVertexParallelismDeciderTest.java       |  10 +
 23 files changed, 917 insertions(+), 355 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
index b55b8e7b71e..26897a9aff2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
@@ -29,6 +29,7 @@ import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
 import org.apache.flink.runtime.executiongraph.IndexRange;
 import org.apache.flink.runtime.executiongraph.IntermediateResult;
 import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
@@ -60,7 +61,7 @@ import java.util.Map;
 import java.util.Optional;
 import java.util.function.Function;
 
-import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
  * Factory of {@link TaskDeploymentDescriptor} to deploy {@link
@@ -78,6 +79,8 @@ public class TaskDeploymentDescriptorFactory {
     private final BlobWriter blobWriter;
     private final Map<IntermediateDataSetID, ShuffleDescriptor[]>
             consumedClusterPartitionShuffleDescriptors;
+    private final Function<IntermediateDataSetID, ExecutionVertexInputInfo>
+            executionVertexInputInfoRetriever;
 
     private TaskDeploymentDescriptorFactory(
             ExecutionAttemptID executionId,
@@ -90,7 +93,9 @@ public class TaskDeploymentDescriptorFactory {
                     resultPartitionRetriever,
             BlobWriter blobWriter,
             Map<IntermediateDataSetID, ShuffleDescriptor[]>
-                    consumedClusterPartitionShuffleDescriptors) {
+                    consumedClusterPartitionShuffleDescriptors,
+            Function<IntermediateDataSetID, ExecutionVertexInputInfo>
+                    executionVertexInputInfoRetriever) {
         this.executionId = executionId;
         this.serializedJobInformation = serializedJobInformation;
         this.taskInfo = taskInfo;
@@ -101,6 +106,7 @@ public class TaskDeploymentDescriptorFactory {
         this.blobWriter = blobWriter;
         this.consumedClusterPartitionShuffleDescriptors =
                 consumedClusterPartitionShuffleDescriptors;
+        this.executionVertexInputInfoRetriever = checkNotNull(executionVertexInputInfoRetriever);
     }
 
     public TaskDeploymentDescriptor createDeploymentDescriptor(
@@ -128,24 +134,22 @@ public class TaskDeploymentDescriptorFactory {
             // If the produced partition has multiple consumers registered, we
             // need to request the one matching our sub task index.
            // TODO Refactor after removing the consumers from the intermediate result partitions
-            IntermediateResultPartition resultPartition =
-                    resultPartitionRetriever.apply(consumedPartitionGroup.getFirst());
 
-            IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
-            IndexRange consumedSubpartitionRange =
-                    computeConsumedSubpartitionRange(
-                            consumedPartitionGroup.getNumConsumers(),
-                            resultPartition,
-                            executionId.getSubtaskIndex());
+            IntermediateResult consumedIntermediateResult =
+                    resultPartitionRetriever
+                            .apply(consumedPartitionGroup.getFirst())
+                            .getIntermediateResult();
 
             IntermediateDataSetID resultId = consumedIntermediateResult.getId();
             ResultPartitionType partitionType = consumedIntermediateResult.getResultType();
+            IndexRange subpartitionRange =
+                    executionVertexInputInfoRetriever.apply(resultId).getSubpartitionIndexRange();
 
             inputGates.add(
                     new InputGateDeploymentDescriptor(
                             resultId,
                             partitionType,
-                            consumedSubpartitionRange,
+                            subpartitionRange,
                             getConsumedPartitionShuffleDescriptors(
                                     consumedIntermediateResult, consumedPartitionGroup)));
         }
@@ -166,50 +170,6 @@ public class TaskDeploymentDescriptorFactory {
         return inputGates;
     }
 
-    public static IndexRange computeConsumedSubpartitionRange(
-            int numConsumers,
-            IntermediateResultPartition resultPartition,
-            int consumerSubtaskIndex) {
-        int consumerIndex = consumerSubtaskIndex % numConsumers;
-        IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
-        int numSubpartitions = resultPartition.getNumberOfSubpartitions();
-        return computeConsumedSubpartitionRange(
-                consumerIndex,
-                numConsumers,
-                numSubpartitions,
-                consumedIntermediateResult.getProducer().getGraph().isDynamic(),
-                consumedIntermediateResult.isBroadcast());
-    }
-
-    @VisibleForTesting
-    static IndexRange computeConsumedSubpartitionRange(
-            int consumerIndex,
-            int numConsumers,
-            int numSubpartitions,
-            boolean isDynamicGraph,
-            boolean isBroadcast) {
-
-        if (!isDynamicGraph) {
-            checkArgument(numConsumers == numSubpartitions);
-            return new IndexRange(consumerIndex, consumerIndex);
-        } else {
-            if (isBroadcast) {
-                // broadcast result should have only one subpartition, and be consumed multiple
-                // times.
-                checkArgument(numSubpartitions == 1);
-                return new IndexRange(0, 0);
-            } else {
-                checkArgument(consumerIndex < numConsumers);
-                checkArgument(numConsumers <= numSubpartitions);
-
-                int start = consumerIndex * numSubpartitions / numConsumers;
-                int nextStart = (consumerIndex + 1) * numSubpartitions / numConsumers;
-
-                return new IndexRange(start, nextStart - 1);
-            }
-        }
-    }
-
     private MaybeOffloaded<ShuffleDescriptor[]> getConsumedPartitionShuffleDescriptors(
             IntermediateResult intermediateResult, ConsumedPartitionGroup consumedPartitionGroup)
             throws IOException {
@@ -285,7 +245,8 @@ public class TaskDeploymentDescriptorFactory {
                 executionVertex.getAllConsumedPartitionGroups(),
                 internalExecutionGraphAccessor::getResultPartitionOrThrow,
                 internalExecutionGraphAccessor.getBlobWriter(),
-                clusterPartitionShuffleDescriptors);
+                clusterPartitionShuffleDescriptors,
+                executionVertex::getExecutionVertexInputInfo);
     }
 
     private static Map<IntermediateDataSetID, ShuffleDescriptor[]>
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
index 187bf88f35b..2f992dc5faf 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
@@ -289,6 +289,7 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
     private final Map<IntermediateResultPartitionID, IntermediateResultPartition>
             resultPartitionsById;
 
+    private final VertexInputInfoStore vertexInputInfoStore;
     private final boolean isDynamic;
 
     private final ExecutionJobVertex.Factory executionJobVertexFactory;
@@ -386,6 +387,7 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
         this.edgeManager = new EdgeManager();
         this.executionVerticesById = new HashMap<>();
         this.resultPartitionsById = new HashMap<>();
+        this.vertexInputInfoStore = new VertexInputInfoStore();
 
         this.isDynamic = isDynamic;
 
@@ -828,17 +830,7 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
     }
 
     @Override
-    public void attachJobGraph(List<JobVertex> topologicallySorted) throws JobException {
-        if (isDynamic) {
-            attachJobGraph(topologicallySorted, Collections.emptyList());
-        } else {
-            attachJobGraph(topologicallySorted, topologicallySorted);
-        }
-    }
-
-    private void attachJobGraph(
-            List<JobVertex> verticesToAttach, List<JobVertex> verticesToInitialize)
-            throws JobException {
+    public void attachJobGraph(List<JobVertex> verticesToAttach) throws JobException {
 
         assertRunningInJobMasterMainThread();
 
@@ -850,7 +842,9 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
                 intermediateResults.size());
 
         attachJobVertices(verticesToAttach);
-        initializeJobVertices(verticesToInitialize);
+        if (!isDynamic) {
+            initializeJobVertices(verticesToAttach);
+        }
 
         // the topology assigning should happen before notifying new vertices to failoverStrategy
         executionTopology = DefaultExecutionTopology.fromExecutionGraph(this);
@@ -898,10 +892,18 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
     }
 
     @Override
-    public void initializeJobVertex(ExecutionJobVertex ejv, long createTimestamp)
+    public void initializeJobVertex(
+            ExecutionJobVertex ejv,
+            long createTimestamp,
+            Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos)
             throws JobException {
 
         checkNotNull(ejv);
+        checkNotNull(jobVertexInputInfos);
+
+        jobVertexInputInfos.forEach(
+                (resultId, info) ->
+                        this.vertexInputInfoStore.put(ejv.getJobVertexId(), resultId, info));
 
         ejv.initialize(
                 executionHistorySizeLimit,
@@ -1694,4 +1696,10 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
     public MarkPartitionFinishedStrategy getMarkPartitionFinishedStrategy() {
         return markPartitionFinishedStrategy;
     }
+
+    @Override
+    public JobVertexInputInfo getJobVertexInputInfo(
+            JobVertexID jobVertexId, IntermediateDataSetID resultId) {
+        return vertexInputInfoStore.get(jobVertexId, resultId);
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
index cb2918f4089..a3613eba27c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.executiongraph;
 
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
@@ -26,10 +27,14 @@ import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;
 
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkState;
+
 /** Utilities for building {@link EdgeManager}. */
 public class EdgeManagerBuildUtil {
 
@@ -39,20 +44,21 @@ public class EdgeManagerBuildUtil {
      *
      * @param vertex the downstream consumer {@link ExecutionJobVertex}
      * @param intermediateResult the upstream consumed {@link IntermediateResult}
-     * @param distributionPattern the {@link DistributionPattern} of the edge that connects the
-     *     upstream {@link IntermediateResult} and the downstream {@link IntermediateResult}
      */
     static void connectVertexToResult(
-            ExecutionJobVertex vertex,
-            IntermediateResult intermediateResult,
-            DistributionPattern distributionPattern) {
+            ExecutionJobVertex vertex, IntermediateResult intermediateResult) {
+        final DistributionPattern distributionPattern =
+                intermediateResult.getConsumingDistributionPattern();
+        final JobVertexInputInfo jobVertexInputInfo =
+                vertex.getGraph()
+                        .getJobVertexInputInfo(vertex.getJobVertexId(), intermediateResult.getId());
 
         switch (distributionPattern) {
             case POINTWISE:
-                connectPointwise(vertex.getTaskVertices(), intermediateResult);
+                connectPointwise(vertex, intermediateResult, jobVertexInputInfo);
                 break;
             case ALL_TO_ALL:
-                connectAllToAll(vertex.getTaskVertices(), intermediateResult);
+                connectAllToAll(vertex, intermediateResult, jobVertexInputInfo);
                 break;
             default:
                throw new IllegalArgumentException("Unrecognized distribution pattern.");
@@ -82,155 +88,120 @@ public class EdgeManagerBuildUtil {
     }
 
     private static void connectAllToAll(
-            ExecutionVertex[] taskVertices, IntermediateResult intermediateResult) {
+            ExecutionJobVertex jobVertex,
+            IntermediateResult result,
+            JobVertexInputInfo jobVertexInputInfo) {
+        // check the vertex input info is legal
+        jobVertexInputInfo
+                .getExecutionVertexInputInfos()
+                .forEach(
+                        executionVertexInputInfo -> {
+                            IndexRange partitionRange =
+                                    executionVertexInputInfo.getPartitionIndexRange();
+                            checkArgument(partitionRange.getStartIndex() == 0);
+                            checkArgument(
+                                    partitionRange.getEndIndex()
+                                            == (result.getNumberOfAssignedPartitions() - 1));
+                        });
+
+        connectInternal(
+                Arrays.asList(jobVertex.getTaskVertices()),
+                Arrays.asList(result.getPartitions()),
+                result.getResultType(),
+                jobVertex.getGraph().getEdgeManager());
+    }
+
+    private static void connectPointwise(
+            ExecutionJobVertex jobVertex,
+            IntermediateResult result,
+            JobVertexInputInfo jobVertexInputInfo) {
+
+        Map<IndexRange, List<Integer>> consumersByPartition = new LinkedHashMap<>();
+
+        for (ExecutionVertexInputInfo executionVertexInputInfo :
+                jobVertexInputInfo.getExecutionVertexInputInfos()) {
+            int consumerIndex = executionVertexInputInfo.getSubtaskIndex();
+            IndexRange range = executionVertexInputInfo.getPartitionIndexRange();
+            consumersByPartition.compute(
+                    range,
+                    (ignore, consumers) -> {
+                        if (consumers == null) {
+                            consumers = new ArrayList<>();
+                        }
+                        consumers.add(consumerIndex);
+                        return consumers;
+                    });
+        }
+
+        consumersByPartition.forEach(
+                (range, subtasks) -> {
+                    List<ExecutionVertex> taskVertices = new ArrayList<>();
+                    List<IntermediateResultPartition> partitions = new ArrayList<>();
+                    for (int index : subtasks) {
+                        taskVertices.add(jobVertex.getTaskVertices()[index]);
+                    }
+                    for (int i = range.getStartIndex(); i <= range.getEndIndex(); ++i) {
+                        partitions.add(result.getPartitions()[i]);
+                    }
+                    connectInternal(
+                            taskVertices,
+                            partitions,
+                            result.getResultType(),
+                            jobVertex.getGraph().getEdgeManager());
+                });
+    }
+
+    /** Connect all execution vertices to all partitions. */
+    private static void connectInternal(
+            List<ExecutionVertex> taskVertices,
+            List<IntermediateResultPartition> partitions,
+            ResultPartitionType resultPartitionType,
+            EdgeManager edgeManager) {
+        checkState(!taskVertices.isEmpty());
+        checkState(!partitions.isEmpty());
 
-        List<IntermediateResultPartitionID> consumedPartitions =
-                Arrays.stream(intermediateResult.getPartitions())
-                        .map(IntermediateResultPartition::getPartitionId)
-                        .collect(Collectors.toList());
         ConsumedPartitionGroup consumedPartitionGroup =
                 createAndRegisterConsumedPartitionGroupToEdgeManager(
-                        taskVertices.length, consumedPartitions, intermediateResult);
+                        taskVertices.size(), partitions, resultPartitionType, edgeManager);
         for (ExecutionVertex ev : taskVertices) {
             ev.addConsumedPartitionGroup(consumedPartitionGroup);
         }
 
         List<ExecutionVertexID> consumerVertices =
-                Arrays.stream(taskVertices)
-                        .map(ExecutionVertex::getID)
-                        .collect(Collectors.toList());
+                taskVertices.stream().map(ExecutionVertex::getID).collect(Collectors.toList());
         ConsumerVertexGroup consumerVertexGroup =
-                ConsumerVertexGroup.fromMultipleVertices(
-                        consumerVertices, intermediateResult.getResultType());
-        for (IntermediateResultPartition partition : intermediateResult.getPartitions()) {
+                ConsumerVertexGroup.fromMultipleVertices(consumerVertices, resultPartitionType);
+        for (IntermediateResultPartition partition : partitions) {
             partition.addConsumers(consumerVertexGroup);
         }
     }
 
-    private static void connectPointwise(
-            ExecutionVertex[] taskVertices, IntermediateResult intermediateResult) {
-
-        final int sourceCount = intermediateResult.getPartitions().length;
-        final int targetCount = taskVertices.length;
-
-        if (sourceCount == targetCount) {
-            for (int i = 0; i < sourceCount; i++) {
-                ExecutionVertex executionVertex = taskVertices[i];
-                IntermediateResultPartition partition = intermediateResult.getPartitions()[i];
-
-                ConsumerVertexGroup consumerVertexGroup =
-                        ConsumerVertexGroup.fromSingleVertex(
-                                executionVertex.getID(), intermediateResult.getResultType());
-                partition.addConsumers(consumerVertexGroup);
-
-                ConsumedPartitionGroup consumedPartitionGroup =
-                        createAndRegisterConsumedPartitionGroupToEdgeManager(
-                                consumerVertexGroup.size(),
-                                partition.getPartitionId(),
-                                intermediateResult);
-                executionVertex.addConsumedPartitionGroup(consumedPartitionGroup);
-            }
-        } else if (sourceCount > targetCount) {
-            for (int index = 0; index < targetCount; index++) {
-
-                ExecutionVertex executionVertex = taskVertices[index];
-                ConsumerVertexGroup consumerVertexGroup =
-                        ConsumerVertexGroup.fromSingleVertex(
-                                executionVertex.getID(), intermediateResult.getResultType());
-
-                int start = index * sourceCount / targetCount;
-                int end = (index + 1) * sourceCount / targetCount;
-
-                List<IntermediateResultPartitionID> consumedPartitions =
-                        new ArrayList<>(end - start);
-
-                for (int i = start; i < end; i++) {
-                    IntermediateResultPartition partition = intermediateResult.getPartitions()[i];
-                    partition.addConsumers(consumerVertexGroup);
-
-                    consumedPartitions.add(partition.getPartitionId());
-                }
-
-                ConsumedPartitionGroup consumedPartitionGroup =
-                        createAndRegisterConsumedPartitionGroupToEdgeManager(
-                                consumerVertexGroup.size(), consumedPartitions, intermediateResult);
-                executionVertex.addConsumedPartitionGroup(consumedPartitionGroup);
-            }
-        } else {
-            for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {
-                int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
-                int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;
-
-                IntermediateResultPartition partition =
-                        intermediateResult.getPartitions()[partitionNum];
-                ConsumedPartitionGroup consumedPartitionGroup =
-                        createAndRegisterConsumedPartitionGroupToEdgeManager(
-                                end - start, partition.getPartitionId(), intermediateResult);
-
-                List<ExecutionVertexID> consumers = new ArrayList<>(end - start);
-
-                for (int i = start; i < end; i++) {
-                    ExecutionVertex executionVertex = taskVertices[i];
-                    executionVertex.addConsumedPartitionGroup(consumedPartitionGroup);
-
-                    consumers.add(executionVertex.getID());
-                }
-
-                ConsumerVertexGroup consumerVertexGroup =
-                        ConsumerVertexGroup.fromMultipleVertices(
-                                consumers, intermediateResult.getResultType());
-                partition.addConsumers(consumerVertexGroup);
-            }
-        }
-    }
-
     private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager(
             int numConsumers,
-            IntermediateResultPartitionID consumedPartitionId,
-            IntermediateResult intermediateResult) {
-        ConsumedPartitionGroup consumedPartitionGroup =
-                ConsumedPartitionGroup.fromSinglePartition(
-                        numConsumers, consumedPartitionId, intermediateResult.getResultType());
-        finishAllDataProducedPartitions(
-                intermediateResult,
-                Collections.singletonList(consumedPartitionId),
-                consumedPartitionGroup);
-        registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult);
-        return consumedPartitionGroup;
-    }
-
-    private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager(
-            int numConsumers,
-            List<IntermediateResultPartitionID> consumedPartitions,
-            IntermediateResult intermediateResult) {
+            List<IntermediateResultPartition> partitions,
+            ResultPartitionType resultPartitionType,
+            EdgeManager edgeManager) {
+        List<IntermediateResultPartitionID> partitionIds =
+                partitions.stream()
+                        .map(IntermediateResultPartition::getPartitionId)
+                        .collect(Collectors.toList());
         ConsumedPartitionGroup consumedPartitionGroup =
                 ConsumedPartitionGroup.fromMultiplePartitions(
-                        numConsumers, consumedPartitions, intermediateResult.getResultType());
-        finishAllDataProducedPartitions(
-                intermediateResult, consumedPartitions, consumedPartitionGroup);
-        registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult);
+                        numConsumers, partitionIds, resultPartitionType);
+        finishAllDataProducedPartitions(partitions, consumedPartitionGroup);
+        edgeManager.registerConsumedPartitionGroup(consumedPartitionGroup);
         return consumedPartitionGroup;
     }
 
     private static void finishAllDataProducedPartitions(
-            IntermediateResult intermediateResult,
-            List<IntermediateResultPartitionID> consumedPartitionIds,
+            List<IntermediateResultPartition> partitions,
             ConsumedPartitionGroup consumedPartitionGroup) {
-        for (IntermediateResultPartitionID consumedPartitionId : consumedPartitionIds) {
+        for (IntermediateResultPartition partition : partitions) {
            // this is for dynamic graph as consumedPartitionGroup has not been created when the
             // partition becomes finished.
-            if (intermediateResult.getPartitionById(consumedPartitionId).hasDataAllProduced()) {
+            if (partition.hasDataAllProduced()) {
                 consumedPartitionGroup.partitionFinished();
             }
         }
     }
-
-    private static void registerConsumedPartitionGroupToEdgeManager(
-            ConsumedPartitionGroup consumedPartitionGroup, IntermediateResult intermediateResult) {
-        intermediateResult
-                .getProducer()
-                .getGraph()
-                .getEdgeManager()
-                .registerConsumedPartitionGroup(consumedPartitionGroup);
-    }
 }
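
The rewritten connectPointwise collapses the three hard-coded source/target cases into one grouping step: subtasks whose ExecutionVertexInputInfo carries the same partition range are batched together, and connectInternal runs once per distinct range. Below is a minimal, self-contained sketch of just that grouping step, with plain strings and ints standing in for Flink's IndexRange and ExecutionVertex (an assumption made only to keep the demo standalone):

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Sketch of the grouping inside the new connectPointwise. Range keys are
// rendered as strings instead of Flink's IndexRange to stay standalone.
public class PointwiseGroupingDemo {
    public static void main(String[] args) {
        // Hypothetical per-subtask partition ranges: 4 consumers over
        // 2 partitions, so subtasks 0/1 read partition 0 and 2/3 partition 1.
        String[] partitionRangeOfSubtask = {"[0,0]", "[0,0]", "[1,1]", "[1,1]"};

        // Same pattern as the patch: a LinkedHashMap keyed by range, each
        // value collecting the subtask indices that consume that range.
        Map<String, List<Integer>> consumersByPartition = new LinkedHashMap<>();
        for (int subtask = 0; subtask < partitionRangeOfSubtask.length; subtask++) {
            consumersByPartition
                    .computeIfAbsent(partitionRangeOfSubtask[subtask], k -> new ArrayList<>())
                    .add(subtask);
        }

        // Prints {[0,0]=[0, 1], [1,1]=[2, 3]}: one consumer group per distinct
        // range, which is what connectInternal is then called with.
        System.out.println(consumersByPartition);
    }
}

The LinkedHashMap preserves subtask encounter order, so the groups, and with them the ConsumedPartitionGroup registrations, are created deterministically.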
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
index 77802620b1d..2f90291e75c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
@@ -203,6 +203,15 @@ public interface ExecutionGraph extends AccessExecutionGraph {
     @Nonnull
     ComponentMainThreadExecutor getJobMasterMainThreadExecutor();
 
+    default void initializeJobVertex(ExecutionJobVertex ejv, long createTimestamp)
+            throws JobException {
+        initializeJobVertex(
+                ejv,
+                createTimestamp,
+                VertexInputInfoComputationUtils.computeVertexInputInfos(
+                        ejv, getAllIntermediateResults()::get));
+    }
+
     /**
     * Initialize the given execution job vertex, mainly includes creating execution vertices
      * according to the parallelism, and connecting to the predecessors.
@@ -210,8 +219,13 @@ public interface ExecutionGraph extends AccessExecutionGraph {
      * @param ejv The execution job vertex that needs to be initialized.
     * @param createTimestamp The timestamp for creating execution vertices, used to initialize the
      *     first Execution with.
+     * @param jobVertexInputInfos The input infos of this job vertex.
      */
-    void initializeJobVertex(ExecutionJobVertex ejv, long createTimestamp) throws JobException;
+    void initializeJobVertex(
+            ExecutionJobVertex ejv,
+            long createTimestamp,
+            Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos)
+            throws JobException;
 
     /**
     * Notify that some job vertices have been newly initialized, execution graph will try to update
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java
index 1a24eda0c67..a914fd1630a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java
@@ -487,7 +487,7 @@ public class ExecutionJobVertex
 
             this.inputs.add(ires);
 
-            EdgeManagerBuildUtil.connectVertexToResult(this, ires, edge.getDistributionPattern());
+            EdgeManagerBuildUtil.connectVertexToResult(this, ires);
         }
     }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
index 06bc7e6b9f4..bd3a9794db0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
@@ -28,6 +28,7 @@ import org.apache.flink.runtime.JobException;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
 import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobmaster.LogicalSlot;
@@ -161,6 +162,13 @@ public class ExecutionVertex
                 timeout);
     }
 
+    public ExecutionVertexInputInfo getExecutionVertexInputInfo(IntermediateDataSetID resultId) {
+        return getExecutionGraphAccessor()
+                .getJobVertexInputInfo(getJobvertexId(), resultId)
+                .getExecutionVertexInputInfos()
+                .get(subTaskIndex);
+    }
+
     public Execution getPartitionProducer() {
         return currentExecution;
     }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java
similarity index 50%
copy from flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
copy to flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java
index 0fd14e4439e..26829893b5a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java
@@ -16,16 +16,11 @@
  * limitations under the License.
  */
 
-package org.apache.flink.runtime.scheduler.adaptivebatch;
+package org.apache.flink.runtime.executiongraph;
 
-import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 
-/**
- * The blocking result info, which will be used to calculate the vertex parallelism and input infos.
- */
-public interface BlockingResultInfo {
-
+public interface IntermediateResultInfo {
     /**
      * Get the intermediate result id.
      *
@@ -48,30 +43,17 @@ public interface BlockingResultInfo {
     boolean isPointwise();
 
     /**
-     * Return the num of bytes produced(numBytesProduced) by the producer.
-     *
-     * <p>The difference between numBytesProduced and numBytesOut : numBytesProduced represents the
-     * number of bytes actually produced, and numBytesOut represents the number of bytes sent to
-     * downstream tasks. In unicast scenarios, these two values should be equal. In broadcast
-     * scenarios, numBytesOut should be (N * numBytesProduced), where N refers to the number of
-     * subpartitions.
-     *
-     * @return the num of bytes produced by the producer
-     */
-    long getNumBytesProduced();
-
-    /**
-     * Record the information of the result partition.
+     * Get number of partitions for this result.
      *
-     * @param partitionIndex the intermediate result partition index
-     * @param partitionBytes the {@link ResultPartitionBytes} of the partition
+     * @return the number of partitions in this result
      */
-    void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes);
+    int getNumPartitions();
 
     /**
-     * Reset the information of the result partition.
+     * Get number of subpartitions for the given partition.
      *
-     * @param partitionIndex the intermediate result partition index
+     * @param partitionIndex the partition index
+     * @return the number of subpartitions of the partition
      */
-    void resetPartitionInfo(int partitionIndex);
+    int getNumSubpartitions(int partitionIndex);
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
index fd5d91bbd3f..0d69cfeed96 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
@@ -33,7 +33,7 @@ import static org.apache.flink.util.Preconditions.checkState;
 
 public class IntermediateResultPartition {
 
-    private static final int UNKNOWN = -1;
+    static final int NUM_SUBPARTITIONS_UNKNOWN = -1;
 
     private final IntermediateResult totalResult;
 
@@ -44,7 +44,7 @@ public class IntermediateResultPartition {
     private final EdgeManager edgeManager;
 
     /** Number of subpartitions. Initialized lazily and will not change once set. */
-    private int numberOfSubpartitions = UNKNOWN;
+    private int numberOfSubpartitions = NUM_SUBPARTITIONS_UNKNOWN;
 
     /** Whether this partition has produced all data. */
     private boolean dataAllProduced = false;
@@ -114,7 +114,7 @@ public class IntermediateResultPartition {
     }
 
     public int getNumberOfSubpartitions() {
-        if (numberOfSubpartitions == UNKNOWN) {
+        if (numberOfSubpartitions == NUM_SUBPARTITIONS_UNKNOWN) {
             numberOfSubpartitions = computeNumberOfSubpartitions();
             checkState(
                     numberOfSubpartitions > 0,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/InternalExecutionGraphAccessor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/InternalExecutionGraphAccessor.java
index 5a814fb4211..bd2c7d70e04 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/InternalExecutionGraphAccessor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/InternalExecutionGraphAccessor.java
@@ -122,4 +122,14 @@ public interface InternalExecutionGraphAccessor {
             IntermediateDataSetID intermediateResultPartition);
 
     MarkPartitionFinishedStrategy getMarkPartitionFinishedStrategy();
+
+    /**
+     * Get the input info of a certain input of a certain job vertex.
+     *
+     * @param jobVertexId the job vertex id
+     * @param resultId the input(intermediate result) id
+     * @return the input info
+     */
+    JobVertexInputInfo getJobVertexInputInfo(
+            JobVertexID jobVertexId, IntermediateDataSetID resultId);
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java
new file mode 100644
index 00000000000..04d3175d10f
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java
@@ -0,0 +1,275 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.executiongraph;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.JobException;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.JobEdge;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.function.Supplier;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Util to compute {@link JobVertexInputInfo}s for execution job vertex. */
+public class VertexInputInfoComputationUtils {
+
+    public static Map<IntermediateDataSetID, JobVertexInputInfo> computeVertexInputInfos(
+            ExecutionJobVertex ejv,
+            Function<IntermediateDataSetID, IntermediateResult> intermediateResultRetriever)
+            throws JobException {
+        checkState(ejv.isParallelismDecided());
+        final List<IntermediateResultInfo> intermediateResultInfos = new ArrayList<>();
+        for (JobEdge edge : ejv.getJobVertex().getInputs()) {
+            IntermediateResult ires = intermediateResultRetriever.apply(edge.getSourceId());
+            if (ires == null) {
+                throw new JobException(
+                        "Cannot connect this job graph to the previous graph. 
No previous intermediate result found for ID "
+                                + edge.getSourceId());
+            }
+            intermediateResultInfos.add(new IntermediateResultWrapper(ires));
+        }
+        return computeVertexInputInfos(
+                ejv.getParallelism(), intermediateResultInfos, ejv.getGraph().isDynamic());
+    }
+
+    public static Map<IntermediateDataSetID, JobVertexInputInfo> computeVertexInputInfos(
+            int parallelism, List<IntermediateResultInfo> inputs, boolean isDynamicGraph) {
+
+        checkArgument(parallelism > 0);
+        final Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos =
+                new LinkedHashMap<>();
+
+        for (IntermediateResultInfo input : inputs) {
+            int sourceParallelism = input.getNumPartitions();
+
+            if (input.isPointwise()) {
+                jobVertexInputInfos.putIfAbsent(
+                        input.getResultId(),
+                        computeVertexInputInfoForPointwise(
+                                sourceParallelism,
+                                parallelism,
+                                input::getNumSubpartitions,
+                                isDynamicGraph));
+            } else {
+                jobVertexInputInfos.putIfAbsent(
+                        input.getResultId(),
+                        computeVertexInputInfoForAllToAll(
+                                sourceParallelism,
+                                parallelism,
+                                input::getNumSubpartitions,
+                                isDynamicGraph,
+                                input.isBroadcast()));
+            }
+        }
+
+        return jobVertexInputInfos;
+    }
+
+    /**
+     * Compute the {@link JobVertexInputInfo} for a {@link DistributionPattern#POINTWISE} edge. This
+     * computation algorithm will evenly distribute subpartitions to downstream subtasks according
+     * to the number of subpartitions. Different downstream subtasks consume roughly the same number
+     * of subpartitions.
+     *
+     * @param sourceCount the parallelism of upstream
+     * @param targetCount the parallelism of downstream
+     * @param numOfSubpartitionsRetriever a retriever to get the number of subpartitions
+     * @param isDynamicGraph whether is dynamic graph
+     * @return the computed {@link JobVertexInputInfo}
+     */
+    @VisibleForTesting
+    static JobVertexInputInfo computeVertexInputInfoForPointwise(
+            int sourceCount,
+            int targetCount,
+            Function<Integer, Integer> numOfSubpartitionsRetriever,
+            boolean isDynamicGraph) {
+
+        final List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
+
+        if (sourceCount >= targetCount) {
+            for (int index = 0; index < targetCount; index++) {
+
+                int start = index * sourceCount / targetCount;
+                int end = (index + 1) * sourceCount / targetCount;
+
+                IndexRange partitionRange = new IndexRange(start, end - 1);
+                IndexRange subpartitionRange =
+                        computeConsumedSubpartitionRange(
+                                index,
+                                1,
+                                () -> numOfSubpartitionsRetriever.apply(start),
+                                isDynamicGraph,
+                                false);
+                executionVertexInputInfos.add(
+                        new ExecutionVertexInputInfo(index, partitionRange, subpartitionRange));
+            }
+        } else {
+            for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {
+
+                int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
+                int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;
+                int numConsumers = end - start;
+
+                IndexRange partitionRange = new IndexRange(partitionNum, partitionNum);
+                // Variable used in lambda expression should be final or effectively final
+                final int finalPartitionNum = partitionNum;
+                for (int i = start; i < end; i++) {
+                    IndexRange subpartitionRange =
+                            computeConsumedSubpartitionRange(
+                                    i,
+                                    numConsumers,
+                            () -> numOfSubpartitionsRetriever.apply(finalPartitionNum),
+                                    isDynamicGraph,
+                                    false);
+                    executionVertexInputInfos.add(
+                            new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange));
+                }
+            }
+        }
+        return new JobVertexInputInfo(executionVertexInputInfos);
+    }
+
+    /**
+     * Compute the {@link JobVertexInputInfo} for a {@link DistributionPattern#ALL_TO_ALL} edge.
+     * This computation algorithm will evenly distribute subpartitions to downstream subtasks
+     * according to the number of subpartitions. Different downstream subtasks consume roughly the
+     * same number of subpartitions.
+     *
+     * @param sourceCount the parallelism of upstream
+     * @param targetCount the parallelism of downstream
+     * @param numOfSubpartitionsRetriever a retriever to get the number of subpartitions
+     * @param isDynamicGraph whether is dynamic graph
+     * @param isBroadcast whether the edge is broadcast
+     * @return the computed {@link JobVertexInputInfo}
+     */
+    @VisibleForTesting
+    static JobVertexInputInfo computeVertexInputInfoForAllToAll(
+            int sourceCount,
+            int targetCount,
+            Function<Integer, Integer> numOfSubpartitionsRetriever,
+            boolean isDynamicGraph,
+            boolean isBroadcast) {
+        final List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
+        IndexRange partitionRange = new IndexRange(0, sourceCount - 1);
+        for (int i = 0; i < targetCount; ++i) {
+            IndexRange subpartitionRange =
+                    computeConsumedSubpartitionRange(
+                            i,
+                            targetCount,
+                            () -> numOfSubpartitionsRetriever.apply(0),
+                            isDynamicGraph,
+                            isBroadcast);
+            executionVertexInputInfos.add(
+                    new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange));
+        }
+        return new JobVertexInputInfo(executionVertexInputInfos);
+    }
+
+    /**
+     * Compute the consumed subpartition range for a subtask. This computation algorithm will evenly
+     * distribute subpartitions to downstream subtasks according to the number of subpartitions.
+     * Different downstream subtasks consume roughly the same number of subpartitions.
+     *
+     * @param consumerSubtaskIndex the subtask index
+     * @param numConsumers the total number of consumers
+     * @param numOfSubpartitionsSupplier a supplier to get the number of subpartitions
+     * @param isDynamicGraph whether is dynamic graph
+     * @param isBroadcast whether the edge is broadcast
+     * @return the computed subpartition range
+     */
+    @VisibleForTesting
+    static IndexRange computeConsumedSubpartitionRange(
+            int consumerSubtaskIndex,
+            int numConsumers,
+            Supplier<Integer> numOfSubpartitionsSupplier,
+            boolean isDynamicGraph,
+            boolean isBroadcast) {
+        int consumerIndex = consumerSubtaskIndex % numConsumers;
+        if (!isDynamicGraph) {
+            return new IndexRange(consumerIndex, consumerIndex);
+        } else {
+            int numSubpartitions = numOfSubpartitionsSupplier.get();
+            if (isBroadcast) {
+                // broadcast results have only one subpartition, and can be consumed multiple times.
+                checkArgument(numSubpartitions == 1);
+                return new IndexRange(0, 0);
+            } else {
+                checkArgument(consumerIndex < numConsumers);
+                checkArgument(numConsumers <= numSubpartitions);
+
+                int start = consumerIndex * numSubpartitions / numConsumers;
+                int nextStart = (consumerIndex + 1) * numSubpartitions / numConsumers;
+
+                return new IndexRange(start, nextStart - 1);
+            }
+        }
+    }
+
+    private static class IntermediateResultWrapper implements IntermediateResultInfo {
+        private final IntermediateResult intermediateResult;
+
+        IntermediateResultWrapper(IntermediateResult intermediateResult) {
+            this.intermediateResult = checkNotNull(intermediateResult);
+        }
+
+        @Override
+        public IntermediateDataSetID getResultId() {
+            return intermediateResult.getId();
+        }
+
+        @Override
+        public boolean isBroadcast() {
+            return intermediateResult.isBroadcast();
+        }
+
+        @Override
+        public boolean isPointwise() {
+            return intermediateResult.getConsumingDistributionPattern()
+                    == DistributionPattern.POINTWISE;
+        }
+
+        @Override
+        public int getNumPartitions() {
+            return intermediateResult.getNumberOfAssignedPartitions();
+        }
+
+        @Override
+        public int getNumSubpartitions(int partitionIndex) {
+            // Note that this method should only be called for dynamic graphs. It is used to
+            // compute which sub-partitions a consumer vertex should consume; for non-dynamic
+            // graphs it is not needed, and the number of sub-partitions is not decided at this
+            // stage, because the execution edges have not been created yet.
+            checkState(
+                    intermediateResult.getProducer().getGraph().isDynamic(),
+                    "This method should only be called for dynamic graph.");
+            return intermediateResult.getPartitions()[partitionIndex].getNumberOfSubpartitions();
+        }
+    }
+
+    /** Private default constructor to avoid being instantiated. */
+    private VertexInputInfoComputationUtils() {}
+}
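
The three helpers in this new class share one even-division rule: on a dynamic graph, consumer i of n receives subpartitions [i * m / n, (i + 1) * m / n - 1] out of m, and the POINTWISE partition ranges follow the same formula with partitions in place of subpartitions. Here is a standalone demo of the resulting ranges, using plain int pairs instead of Flink's IndexRange (the concrete numbers are illustrative only):

// Demo of the even-division arithmetic used by
// computeConsumedSubpartitionRange (dynamic, non-broadcast branch) and by
// the POINTWISE partition ranges. Plain int pairs replace IndexRange.
public class RangeDivisionDemo {

    static int[] consumedSubpartitionRange(int consumerIndex, int numConsumers, int numSubpartitions) {
        int start = consumerIndex * numSubpartitions / numConsumers;
        int nextStart = (consumerIndex + 1) * numSubpartitions / numConsumers;
        return new int[] {start, nextStart - 1};
    }

    public static void main(String[] args) {
        // 10 subpartitions over 4 consumers yields [0,1] [2,4] [5,6] [7,9],
        // sizes 2/3/2/3, i.e. "roughly the same number of subpartitions".
        // Each range starts where the previous one ended, so no subpartition
        // is skipped or assigned twice.
        for (int i = 0; i < 4; i++) {
            int[] r = consumedSubpartitionRange(i, 4, 10);
            System.out.println("consumer " + i + " -> subpartitions [" + r[0] + ", " + r[1] + "]");
        }

        // The POINTWISE case with 5 upstream partitions and 2 downstream
        // subtasks assigns partition ranges [0,1] and [2,4] the same way.
        int sourceCount = 5, targetCount = 2;
        for (int index = 0; index < targetCount; index++) {
            int start = index * sourceCount / targetCount;
            int end = (index + 1) * sourceCount / targetCount;
            System.out.println("subtask " + index + " -> partitions [" + start + ", " + (end - 1) + "]");
        }
    }
}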
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoStore.java
new file mode 100644
index 00000000000..21638f2166d
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoStore.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.executiongraph;
+
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * A store that contains all the {@link JobVertexInputInfo}s. Note that if a vertex has multiple job
+ * edges connecting to the same intermediate result, their {@link DistributionPattern} must be the
+ * same and therefore the {@link JobVertexInputInfo} will be the same.
+ */
+public class VertexInputInfoStore {
+
+    private final Map<JobVertexID, Map<IntermediateDataSetID, JobVertexInputInfo>>
+            jobVertexInputInfos = new HashMap<>();
+
+    /**
+     * Put a {@link JobVertexInputInfo}.
+     *
+     * @param jobVertexId the job vertex id
+     * @param resultId the intermediate result id
+     * @param info the {@link JobVertexInputInfo} to put
+     */
+    public void put(
+            JobVertexID jobVertexId, IntermediateDataSetID resultId, JobVertexInputInfo info) {
+        checkNotNull(jobVertexId);
+        checkNotNull(resultId);
+        checkNotNull(info);
+
+        jobVertexInputInfos.compute(
+                jobVertexId,
+                (ignored, inputInfos) -> {
+                    if (inputInfos == null) {
+                        inputInfos = new HashMap<>();
+                    }
+
+                    inputInfos.putIfAbsent(resultId, info);
+                    return inputInfos;
+                });
+    }
+
+    /**
+     * Get a {@link JobVertexInputInfo}.
+     *
+     * @param jobVertexId the job vertex id
+     * @param resultId the intermediate result id
+     * @return the {@link JobVertexInputInfo} identified by the job vertex id and intermediate
+     *     result id
+     */
+    public JobVertexInputInfo get(JobVertexID jobVertexId, IntermediateDataSetID resultId) {
+        checkNotNull(jobVertexId);
+        checkNotNull(resultId);
+        return checkNotNull(jobVertexInputInfos.get(jobVertexId).get(resultId));
+    }
+}
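
Semantically the store is a two-level map with first-write-wins per (job vertex, intermediate result) pair, which is why repeated puts from multiple job edges to the same intermediate result are harmless. A standalone sketch of that behavior, with string keys standing in for JobVertexID and IntermediateDataSetID (an illustration, not the Flink class):

import java.util.HashMap;
import java.util.Map;

// Sketch of VertexInputInfoStore's semantics: a nested map in which the
// first value stored for a (vertex, result) pair wins, mirroring the
// putIfAbsent in the patch.
public class InputInfoStoreDemo {
    private final Map<String, Map<String, String>> store = new HashMap<>();

    void put(String vertexId, String resultId, String info) {
        store.computeIfAbsent(vertexId, k -> new HashMap<>()).putIfAbsent(resultId, info);
    }

    String get(String vertexId, String resultId) {
        return store.get(vertexId).get(resultId);
    }

    public static void main(String[] args) {
        InputInfoStoreDemo demo = new InputInfoStoreDemo();
        demo.put("v1", "r1", "info-a");
        // A second edge to the same result must carry the same
        // DistributionPattern, so this duplicate put is a no-op.
        demo.put("v1", "r1", "info-b");
        System.out.println(demo.get("v1", "r1")); // prints info-a
    }
}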
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
index ecd723a4246..bca093c7219 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
@@ -21,7 +21,6 @@ package org.apache.flink.runtime.scheduler;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.configuration.MemorySize;
 import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
-import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory;
 import org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
@@ -184,10 +183,9 @@ public class SsgNetworkMemoryCalculationUtils {
                IntermediateResultPartition resultPartition =
                        ejv.getGraph().getResultPartitionOrThrow((partitionGroup.getFirst()));
                 IndexRange subpartitionIndexRange =
-                        TaskDeploymentDescriptorFactory.computeConsumedSubpartitionRange(
-                                partitionGroup.getNumConsumers(),
-                                resultPartition,
-                                vertex.getParallelSubtaskIndex());
+                        vertex.getExecutionVertexInputInfo(
+                                        resultPartition.getIntermediateResult().getId())
+                                .getSubpartitionIndexRange();
 
                 tmp.merge(
                         partitionGroup.getIntermediateDataSetID(),
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
index a4a06efc57a..bae270a608a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
@@ -61,6 +61,16 @@ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo {
         return false;
     }
 
+    @Override
+    public int getNumPartitions() {
+        return numOfPartitions;
+    }
+
+    @Override
+    public int getNumSubpartitions(int partitionIndex) {
+        return numOfSubpartitions;
+    }
+
     @Override
     public long getNumBytesProduced() {
         checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready");
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
index 0fd14e4439e..2eeb2a90a20 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
@@ -18,34 +18,13 @@
 
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
+import org.apache.flink.runtime.executiongraph.IntermediateResultInfo;
 import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
-import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 
 /**
  * The blocking result info, which will be used to calculate the vertex parallelism and input infos.
  */
-public interface BlockingResultInfo {
-
-    /**
-     * Get the intermediate result id.
-     *
-     * @return the intermediate result id
-     */
-    IntermediateDataSetID getResultId();
-
-    /**
-     * Whether it is a broadcast result.
-     *
-     * @return whether it is a broadcast result
-     */
-    boolean isBroadcast();
-
-    /**
-     * Whether it is a pointwise result.
-     *
-     * @return whether it is a pointwise result
-     */
-    boolean isPointwise();
+public interface BlockingResultInfo extends IntermediateResultInfo {
 
     /**
      * Return the num of bytes produced(numBytesProduced) by the producer.
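
Read together with the AllToAll and Pointwise hunks around it, this hunk
extracts a common interface. A sketch of the resulting hierarchy (inferred
from this diff, not quoted from the sources):

    // IntermediateResultInfo (new): getResultId(), isBroadcast(), isPointwise(),
    //     getNumPartitions(), getNumSubpartitions(int partitionIndex)
    // BlockingResultInfo extends IntermediateResultInfo and keeps only the
    //     produced-bytes accessors used by the adaptive batch scheduler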
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
index 287b180df00..225b653064b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
@@ -41,6 +41,16 @@ public class PointwiseBlockingResultInfo extends AbstractBlockingResultInfo {
         return true;
     }
 
+    @Override
+    public int getNumPartitions() {
+        return numOfPartitions;
+    }
+
+    @Override
+    public int getNumSubpartitions(int partitionIndex) {
+        return numOfSubpartitions;
+    }
+
     @Override
     public long getNumBytesProduced() {
         checkState(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactoryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactoryTest.java
index 30938b355da..7ea1f77be79 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactoryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactoryTest.java
@@ -30,7 +30,6 @@ import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.Offloaded;
 import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
-import org.apache.flink.runtime.executiongraph.IndexRange;
 import org.apache.flink.runtime.executiongraph.IntermediateResult;
 import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
@@ -57,8 +56,6 @@ import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.ScheduledExecutorService;
 
-import static org.hamcrest.CoreMatchers.is;
-import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
 
 /** Tests for {@link TaskDeploymentDescriptorFactory}. */
@@ -188,87 +185,4 @@ public class TaskDeploymentDescriptorFactoryTest extends TestLogger {
             return compressedSerializedValue.deserializeValue(ClassLoader.getSystemClassLoader());
         }
     }
-
-    @Test
-    public void testComputeConsumedSubpartitionRange3to1() {
-        final IndexRange range = computeConsumedSubpartitionRange(0, 1, 3);
-        assertThat(range.getStartIndex(), is(0));
-        assertThat(range.getEndIndex(), is(2));
-    }
-
-    @Test
-    public void testComputeConsumedSubpartitionRange3to2() {
-        final IndexRange range1 = computeConsumedSubpartitionRange(0, 2, 3);
-        assertThat(range1.getStartIndex(), is(0));
-        assertThat(range1.getEndIndex(), is(0));
-
-        final IndexRange range2 = computeConsumedSubpartitionRange(1, 2, 3);
-        assertThat(range2.getStartIndex(), is(1));
-        assertThat(range2.getEndIndex(), is(2));
-    }
-
-    @Test
-    public void testComputeConsumedSubpartitionRange6to4() {
-        final IndexRange range1 = computeConsumedSubpartitionRange(0, 4, 6);
-        assertThat(range1.getStartIndex(), is(0));
-        assertThat(range1.getEndIndex(), is(0));
-
-        final IndexRange range2 = computeConsumedSubpartitionRange(1, 4, 6);
-        assertThat(range2.getStartIndex(), is(1));
-        assertThat(range2.getEndIndex(), is(2));
-
-        final IndexRange range3 = computeConsumedSubpartitionRange(2, 4, 6);
-        assertThat(range3.getStartIndex(), is(3));
-        assertThat(range3.getEndIndex(), is(3));
-
-        final IndexRange range4 = computeConsumedSubpartitionRange(3, 4, 6);
-        assertThat(range4.getStartIndex(), is(4));
-        assertThat(range4.getEndIndex(), is(5));
-    }
-
-    @Test
-    public void testComputeBroadcastConsumedSubpartitionRange() {
-        final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, 1, true, true);
-        assertThat(range1.getStartIndex(), is(0));
-        assertThat(range1.getEndIndex(), is(0));
-
-        final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, 1, true, true);
-        assertThat(range2.getStartIndex(), is(0));
-        assertThat(range2.getEndIndex(), is(0));
-
-        final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, 1, true, true);
-        assertThat(range3.getStartIndex(), is(0));
-        assertThat(range3.getEndIndex(), is(0));
-    }
-
-    @Test
-    public void testComputeConsumedSubpartitionRangeForNonDynamicGraph() {
-        final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, 3, false, false);
-        assertThat(range1.getStartIndex(), is(0));
-        assertThat(range1.getEndIndex(), is(0));
-
-        final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, 3, false, false);
-        assertThat(range2.getStartIndex(), is(1));
-        assertThat(range2.getEndIndex(), is(1));
-
-        final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, 3, false, false);
-        assertThat(range3.getStartIndex(), is(2));
-        assertThat(range3.getEndIndex(), is(2));
-    }
-
-    private static IndexRange computeConsumedSubpartitionRange(
-            int consumerIndex, int numConsumers, int numSubpartitions) {
-        return computeConsumedSubpartitionRange(
-                consumerIndex, numConsumers, numSubpartitions, true, false);
-    }
-
-    private static IndexRange computeConsumedSubpartitionRange(
-            int consumerIndex,
-            int numConsumers,
-            int numSubpartitions,
-            boolean isDynamicGraph,
-            boolean isBroadcast) {
-        return TaskDeploymentDescriptorFactory.computeConsumedSubpartitionRange(
-                consumerIndex, numConsumers, numSubpartitions, isDynamicGraph, isBroadcast);
-    }
 }
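
(The removed subpartition-range tests reappear below in
VertexInputInfoComputationUtilsTest, updated for the supplier-based signature
of the relocated computeConsumedSubpartitionRange.)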
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
index 47fce0af13b..2443a4f2b42 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
@@ -55,8 +55,9 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /**
  * This class contains test concerning the correct conversion from {@link JobGraph} to {@link
- * ExecutionGraph} objects. It also tests that {@link EdgeManagerBuildUtil#connectVertexToResult}
- * builds {@link DistributionPattern#ALL_TO_ALL} connections correctly.
+ * ExecutionGraph} objects. It also tests that {@link
+ * VertexInputInfoComputationUtils#computeVertexInputInfoForAllToAll} builds {@link
+ * DistributionPattern#ALL_TO_ALL} connections correctly.
  */
 class DefaultExecutionGraphConstructionTest {
     @RegisterExtension
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
index 28b44595a26..41422bb48b1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
@@ -22,7 +22,6 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
-import org.apache.flink.runtime.scheduler.SchedulerBase;
 import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
 import org.apache.flink.testutils.TestingUtils;
 import org.apache.flink.testutils.executor.TestExecutorExtension;
@@ -33,9 +32,13 @@ import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.Iterator;
 import java.util.List;
+import java.util.Objects;
 import java.util.concurrent.ScheduledExecutorService;
 
+import static org.apache.flink.runtime.executiongraph.IntermediateResultPartitionTest.computeVertexParallelismStoreConsideringDynamicGraph;
 import static org.apache.flink.runtime.jobgraph.DistributionPattern.ALL_TO_ALL;
 import static org.apache.flink.runtime.jobgraph.DistributionPattern.POINTWISE;
 import static org.assertj.core.api.Assertions.assertThat;
@@ -68,6 +71,142 @@ class EdgeManagerBuildUtilTest {
         testGetMaxNumEdgesToTarget(23, 17, ALL_TO_ALL);
     }
 
+    @Test
+    void testConnectAllToAll() throws Exception {
+        int upstream = 3;
+        int downstream = 2;
+
+        // use dynamic graph to specify the vertex input info
+        ExecutionGraph eg = setupExecutionGraph(upstream, downstream, ALL_TO_ALL, true);
+
+        List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
+        for (int i = 0; i < downstream; i++) {
+            executionVertexInputInfos.add(
+                    new ExecutionVertexInputInfo(
+                            i,
+                            new IndexRange(0, upstream - 1),
+                            // the subpartition range will not be used in edge manager,
+                            // so set (0, 0)
+                            new IndexRange(0, 0)));
+        }
+        final JobVertexInputInfo jobVertexInputInfo =
+                new JobVertexInputInfo(executionVertexInputInfos);
+
+        final Iterator<ExecutionJobVertex> vertexIterator =
+                eg.getVerticesTopologically().iterator();
+        final ExecutionJobVertex producer = vertexIterator.next();
+        final ExecutionJobVertex consumer = vertexIterator.next();
+
+        // initialize producer and consumer
+        eg.initializeJobVertex(producer, 1L, Collections.emptyMap());
+        eg.initializeJobVertex(
+                consumer,
+                1L,
+                Collections.singletonMap(
+                        producer.getProducedDataSets()[0].getId(), jobVertexInputInfo));
+
+        IntermediateResult result =
+                Objects.requireNonNull(eg.getJobVertex(producer.getJobVertexId()))
+                        .getProducedDataSets()[0];
+        IntermediateResultPartition partition1 = result.getPartitions()[0];
+        IntermediateResultPartition partition2 = result.getPartitions()[1];
+        IntermediateResultPartition partition3 = result.getPartitions()[2];
+
+        ExecutionVertex vertex1 = consumer.getTaskVertices()[0];
+        ExecutionVertex vertex2 = consumer.getTaskVertices()[1];
+
+        // check consumers of the partitions
+        assertThat(partition1.getConsumerVertexGroups().get(0))
+                .containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID());
+        assertThat(partition2.getConsumerVertexGroups().get(0))
+                .isEqualTo(partition1.getConsumerVertexGroups().get(0));
+        assertThat(partition3.getConsumerVertexGroups().get(0))
+                .isEqualTo(partition1.getConsumerVertexGroups().get(0));
+
+        // check inputs of the execution vertices
+        assertThat(vertex1.getConsumedPartitionGroup(0))
+                .containsExactlyInAnyOrder(
+                        partition1.getPartitionId(),
+                        partition2.getPartitionId(),
+                        partition3.getPartitionId());
+        assertThat(vertex2.getConsumedPartitionGroup(0))
+                .isEqualTo(vertex1.getConsumedPartitionGroup(0));
+    }
+
+    @Test
+    void testConnectPointwise() throws Exception {
+        int upstream = 4;
+        int downstream = 4;
+
+        // use dynamic graph to specify the vertex input info
+        ExecutionGraph eg = setupExecutionGraph(upstream, downstream, POINTWISE, true);
+
+        // set partition ranges
+        List<IndexRange> partitionRanges =
+                Arrays.asList(
+                        new IndexRange(0, 0),
+                        new IndexRange(0, 0),
+                        new IndexRange(1, 2),
+                        new IndexRange(3, 3));
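+        // i.e. consumers 0 and 1 both read partition 0, consumer 2 reads
+        // partitions 1 and 2, and consumer 3 reads partition 3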
+        List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
+        for (int i = 0; i < downstream; i++) {
+            executionVertexInputInfos.add(
+                    new ExecutionVertexInputInfo(
+                            // the subpartition range will not be used in edge manager,
+                            // so set (0, 0)
+                            i, partitionRanges.get(i), new IndexRange(0, 0)));
+        }
+        final JobVertexInputInfo jobVertexInputInfo =
+                new JobVertexInputInfo(executionVertexInputInfos);
+
+        final Iterator<ExecutionJobVertex> vertexIterator =
+                eg.getVerticesTopologically().iterator();
+        final ExecutionJobVertex producer = vertexIterator.next();
+        final ExecutionJobVertex consumer = vertexIterator.next();
+
+        // initialize producer and consumer
+        eg.initializeJobVertex(producer, 1L, Collections.emptyMap());
+        eg.initializeJobVertex(
+                consumer,
+                1L,
+                Collections.singletonMap(
+                        producer.getProducedDataSets()[0].getId(), jobVertexInputInfo));
+
+        IntermediateResult result =
+                Objects.requireNonNull(eg.getJobVertex(producer.getJobVertexId()))
+                        .getProducedDataSets()[0];
+        IntermediateResultPartition partition1 = result.getPartitions()[0];
+        IntermediateResultPartition partition2 = result.getPartitions()[1];
+        IntermediateResultPartition partition3 = result.getPartitions()[2];
+        IntermediateResultPartition partition4 = result.getPartitions()[3];
+
+        ExecutionVertex vertex1 = consumer.getTaskVertices()[0];
+        ExecutionVertex vertex2 = consumer.getTaskVertices()[1];
+        ExecutionVertex vertex3 = consumer.getTaskVertices()[2];
+        ExecutionVertex vertex4 = consumer.getTaskVertices()[3];
+
+        // check consumers of the partitions
+        assertThat(partition1.getConsumerVertexGroups().get(0))
+                .containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID());
+        assertThat(partition2.getConsumerVertexGroups().get(0))
+                .containsExactlyInAnyOrder(vertex3.getID());
+        assertThat(partition3.getConsumerVertexGroups().get(0))
+                .isEqualTo(partition2.getConsumerVertexGroups().get(0));
+        assertThat(partition4.getConsumerVertexGroups().get(0))
+                .containsExactlyInAnyOrder(vertex4.getID());
+
+        // check inputs of the execution vertices
+        assertThat(vertex1.getConsumedPartitionGroup(0))
+                .containsExactlyInAnyOrder(partition1.getPartitionId());
+        assertThat(vertex2.getConsumedPartitionGroup(0))
+                .isEqualTo(vertex1.getConsumedPartitionGroup(0));
+        assertThat(vertex3.getConsumedPartitionGroup(0))
+                .containsExactlyInAnyOrder(
+                        partition2.getPartitionId(), partition3.getPartitionId());
+        assertThat(vertex4.getConsumedPartitionGroup(0))
+                .containsExactlyInAnyOrder(partition4.getPartitionId());
+    }
+
     private void testGetMaxNumEdgesToTarget(
             int upstream, int downstream, DistributionPattern pattern) throws Exception {
 
@@ -110,6 +249,16 @@ class EdgeManagerBuildUtilTest {
 
     private Pair<ExecutionJobVertex, ExecutionJobVertex> setupExecutionGraph(
             int upstream, int downstream, DistributionPattern pattern) throws Exception {
+        Iterator<ExecutionJobVertex> jobVertices =
+                setupExecutionGraph(upstream, downstream, pattern, false)
+                        .getVerticesTopologically()
+                        .iterator();
+        return Pair.of(jobVertices.next(), jobVertices.next());
+    }
+
+    private ExecutionGraph setupExecutionGraph(
+            int upstream, int downstream, DistributionPattern pattern, boolean isDynamicGraph)
+            throws Exception {
         JobVertex v1 = new JobVertex("vertex1");
         JobVertex v2 = new JobVertex("vertex2");
 
@@ -123,12 +272,19 @@ class EdgeManagerBuildUtilTest {
 
         List<JobVertex> ordered = new ArrayList<>(Arrays.asList(v1, v2));
 
-        ExecutionGraph eg =
+        TestingDefaultExecutionGraphBuilder builder =
                 TestingDefaultExecutionGraphBuilder.newBuilder()
                         .setVertexParallelismStore(
-                                SchedulerBase.computeVertexParallelismStore(ordered))
-                        .build(EXECUTOR_RESOURCE.getExecutor());
+                                computeVertexParallelismStoreConsideringDynamicGraph(
+                                        ordered, isDynamicGraph, 128));
+        ExecutionGraph eg;
+        if (isDynamicGraph) {
+            eg = builder.buildDynamicGraph(EXECUTOR_RESOURCE.getExecutor());
+        } else {
+            eg = builder.build(EXECUTOR_RESOURCE.getExecutor());
+        }
+
         eg.attachJobGraph(ordered);
-        return Pair.of(eg.getAllVertices().get(v1.getID()), eg.getAllVertices().get(v2.getID()));
+        return eg;
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/PointwisePatternTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/PointwisePatternTest.java
index d471cac83f5..c1ad5384c0a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/PointwisePatternTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/PointwisePatternTest.java
@@ -43,7 +43,7 @@ import static org.junit.Assert.fail;
 
 /**
  * Tests for building {@link DistributionPattern#POINTWISE} connections in 
{@link
- * EdgeManagerBuildUtil#connectVertexToResult}.
+ * VertexInputInfoComputationUtils#computeVertexInputInfoForPointwise}.
  */
 public class PointwisePatternTest {
     @ClassRule
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java
new file mode 100644
index 00000000000..e0f4d6e2fad
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.executiongraph;
+
+import org.junit.jupiter.api.Test;
+
+import static org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils.computeVertexInputInfoForAllToAll;
+import static org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils.computeVertexInputInfoForPointwise;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for {@link VertexInputInfoComputationUtils}. */
+class VertexInputInfoComputationUtilsTest {
+
+    @Test
+    void testComputeConsumedSubpartitionRange3to1() {
+        final IndexRange range = computeConsumedSubpartitionRange(0, 1, 3);
+        assertThat(range).isEqualTo(new IndexRange(0, 2));
+    }
+
+    @Test
+    void testComputeConsumedSubpartitionRange3to2() {
+        final IndexRange range1 = computeConsumedSubpartitionRange(0, 2, 3);
+        assertThat(range1).isEqualTo(new IndexRange(0, 0));
+
+        final IndexRange range2 = computeConsumedSubpartitionRange(1, 2, 3);
+        assertThat(range2).isEqualTo(new IndexRange(1, 2));
+    }
+
+    @Test
+    void testComputeConsumedSubpartitionRange6to4() {
+        final IndexRange range1 = computeConsumedSubpartitionRange(0, 4, 6);
+        assertThat(range1).isEqualTo(new IndexRange(0, 0));
+
+        final IndexRange range2 = computeConsumedSubpartitionRange(1, 4, 6);
+        assertThat(range2).isEqualTo(new IndexRange(1, 2));
+
+        final IndexRange range3 = computeConsumedSubpartitionRange(2, 4, 6);
+        assertThat(range3).isEqualTo(new IndexRange(3, 3));
+
+        final IndexRange range4 = computeConsumedSubpartitionRange(3, 4, 6);
+        assertThat(range4).isEqualTo(new IndexRange(4, 5));
+    }
+
+    @Test
+    void testComputeBroadcastConsumedSubpartitionRange() {
+        final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, 1, true, true);
+        assertThat(range1).isEqualTo(new IndexRange(0, 0));
+
+        final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, 1, true, true);
+        assertThat(range2).isEqualTo(new IndexRange(0, 0));
+
+        final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, 1, true, true);
+        assertThat(range3).isEqualTo(new IndexRange(0, 0));
+    }
+
+    @Test
+    void testComputeConsumedSubpartitionRangeForNonDynamicGraph() {
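+        // -1 subpartitions: in the non-dynamic case each consumer simply reads
+        // subpartition [i, i], so the subpartition count is never queried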
+        final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, -1, false, false);
+        assertThat(range1).isEqualTo(new IndexRange(0, 0));
+
+        final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, -1, false, false);
+        assertThat(range2).isEqualTo(new IndexRange(1, 1));
+
+        final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, -1, false, false);
+        assertThat(range3).isEqualTo(new IndexRange(2, 2));
+    }
+
+    @Test
+    void testComputeVertexInputInfoForAllToAllWithNonDynamicGraph() {
+        final JobVertexInputInfo nonBroadcast =
+                computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, false);
+        assertThat(nonBroadcast.getExecutionVertexInputInfos())
+                .containsExactlyInAnyOrder(
+                        new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 0)),
+                        new ExecutionVertexInputInfo(1, new IndexRange(0, 1), new IndexRange(1, 1)),
+                        new ExecutionVertexInputInfo(
+                                2, new IndexRange(0, 1), new IndexRange(2, 2)));
+
+        final JobVertexInputInfo broadcast =
+                computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, true);
+        assertThat(broadcast.getExecutionVertexInputInfos())
+                .containsExactlyInAnyOrder(
+                        new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 0)),
+                        new ExecutionVertexInputInfo(1, new IndexRange(0, 1), new IndexRange(1, 1)),
+                        new ExecutionVertexInputInfo(
+                                2, new IndexRange(0, 1), new IndexRange(2, 2)));
+    }
+
+    @Test
+    void testComputeVertexInputInfoForAllToAllWithDynamicGraph() {
+        final JobVertexInputInfo nonBroadcast =
+                computeVertexInputInfoForAllToAll(2, 3, ignored -> 10, true, false);
+        assertThat(nonBroadcast.getExecutionVertexInputInfos())
+                .containsExactlyInAnyOrder(
+                        new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 2)),
+                        new ExecutionVertexInputInfo(1, new IndexRange(0, 1), new IndexRange(3, 5)),
+                        new ExecutionVertexInputInfo(
+                                2, new IndexRange(0, 1), new IndexRange(6, 9)));
+
+        final JobVertexInputInfo broadcast =
+                computeVertexInputInfoForAllToAll(2, 3, ignored -> 1, true, true);
+        assertThat(broadcast.getExecutionVertexInputInfos())
+                .containsExactlyInAnyOrder(
+                        new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 0)),
+                        new ExecutionVertexInputInfo(1, new IndexRange(0, 1), new IndexRange(0, 0)),
+                        new ExecutionVertexInputInfo(
+                                2, new IndexRange(0, 1), new IndexRange(0, 0)));
+    }
+
+    @Test
+    void testComputeVertexInputInfoForPointwiseWithNonDynamicGraph() {
+        final JobVertexInputInfo jobVertexInputInfo =
+                computeVertexInputInfoForPointwise(2, 3, ignored -> 3, false);
+        assertThat(jobVertexInputInfo.getExecutionVertexInputInfos())
+                .containsExactlyInAnyOrder(
+                        new ExecutionVertexInputInfo(0, new IndexRange(0, 0), new IndexRange(0, 0)),
+                        new ExecutionVertexInputInfo(1, new IndexRange(0, 0), new IndexRange(1, 1)),
+                        new ExecutionVertexInputInfo(
+                                2, new IndexRange(1, 1), new IndexRange(0, 0)));
+    }
+
+    @Test
+    void testComputeVertexInputInfoForPointwiseWithDynamicGraph() {
+        final JobVertexInputInfo jobVertexInputInfo =
+                computeVertexInputInfoForPointwise(2, 3, ignored -> 4, true);
+        assertThat(jobVertexInputInfo.getExecutionVertexInputInfos())
+                .containsExactlyInAnyOrder(
+                        new ExecutionVertexInputInfo(0, new IndexRange(0, 0), new IndexRange(0, 1)),
+                        new ExecutionVertexInputInfo(1, new IndexRange(0, 0), new IndexRange(2, 3)),
+                        new ExecutionVertexInputInfo(
+                                2, new IndexRange(1, 1), new IndexRange(0, 3)));
+    }
+
+    private static IndexRange computeConsumedSubpartitionRange(
+            int consumerIndex, int numConsumers, int numSubpartitions) {
+        return computeConsumedSubpartitionRange(
+                consumerIndex, numConsumers, numSubpartitions, true, false);
+    }
+
+    private static IndexRange computeConsumedSubpartitionRange(
+            int consumerIndex,
+            int numConsumers,
+            int numSubpartitions,
+            boolean isDynamicGraph,
+            boolean isBroadcast) {
+        return VertexInputInfoComputationUtils.computeConsumedSubpartitionRange(
+                consumerIndex, numConsumers, () -> numSubpartitions, isDynamicGraph, isBroadcast);
+    }
+}
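
The computeConsumedSubpartitionRange expectations above (3-to-1, 3-to-2,
6-to-4) all follow one division rule for the dynamic, non-broadcast case. A
minimal sketch of that rule (inferred from the test data, not quoted from the
production code; the helper name is hypothetical, IndexRange is the class used
above):

    // Divide m subpartitions evenly among n consumers: consumer i gets
    // [i * m / n, (i + 1) * m / n - 1] with integer division, so range sizes
    // differ by at most one. E.g. m = 6, n = 4 yields (0,0), (1,2), (3,3), (4,5).
    static IndexRange divide(int consumerIndex, int numConsumers, int numSubpartitions) {
        return new IndexRange(
                consumerIndex * numSubpartitions / numConsumers,
                (consumerIndex + 1) * numSubpartitions / numConsumers - 1);
    }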
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java
index bc366f8d478..9281d910485 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java
@@ -48,6 +48,7 @@ import org.apache.flink.runtime.executiongraph.IntermediateResult;
 import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
 import org.apache.flink.runtime.executiongraph.InternalExecutionGraphAccessor;
 import org.apache.flink.runtime.executiongraph.JobInformation;
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
 import org.apache.flink.runtime.executiongraph.MarkPartitionFinishedStrategy;
 import org.apache.flink.runtime.executiongraph.TaskExecutionStateTransition;
 import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder;
@@ -1037,5 +1038,12 @@ public class ExecutingTest extends TestLogger {
             throw new UnsupportedOperationException(
                     "This method is not supported by the 
MockInternalExecutionGraphAccessor.");
         }
+
+        @Override
+        public JobVertexInputInfo getJobVertexInputInfo(
+                JobVertexID jobVertexId, IntermediateDataSetID resultId) {
+            throw new UnsupportedOperationException(
+                    "This method is not supported by the 
MockInternalExecutionGraphAccessor.");
+        }
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java
index 54a65ccddcd..639c846cbec 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java
@@ -44,6 +44,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.executiongraph.IntermediateResult;
 import org.apache.flink.runtime.executiongraph.JobStatusListener;
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
 import org.apache.flink.runtime.executiongraph.TaskExecutionStateTransition;
 import org.apache.flink.runtime.executiongraph.failover.flip1.ResultPartitionAvailabilityChecker;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
@@ -370,7 +371,10 @@ class StateTrackingMockExecutionGraph implements ExecutionGraph {
     }
 
     @Override
-    public void initializeJobVertex(ExecutionJobVertex ejv, long createTimestamp)
+    public void initializeJobVertex(
+            ExecutionJobVertex ejv,
+            long createTimestamp,
+            Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos)
             throws JobException {
         throw new UnsupportedOperationException();
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java
index 0f588bcc5e6..1cb8f3d05b3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java
@@ -170,6 +170,16 @@ class DefaultVertexParallelismDeciderTest {
             return false;
         }
 
+        @Override
+        public int getNumPartitions() {
+            return 0;
+        }
+
+        @Override
+        public int getNumSubpartitions(int partitionIndex) {
+            return 0;
+        }
+
         @Override
         public long getNumBytesProduced() {
             return producedBytes;

