This is an automated email from the ASF dual-hosted git repository.

junrui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 95f9a1605321a55fc22bb81fa503d2a9eac9d433
Author: JunRuiLee <[email protected]>
AuthorDate: Tue Dec 24 20:00:27 2024 +0800

    [FLINK-36067][runtime] Support optimizing the stream graph based on input info.
---
 .../runtime/executiongraph/IntermediateResult.java |  25 ++++
 .../executiongraph/IntermediateResultInfo.java     |  16 ++-
 .../IntermediateResultPartition.java               |   2 +-
 .../VertexInputInfoComputationUtils.java           |  32 ++++--
 .../runtime/jobgraph/IntermediateDataSet.java      |  12 ++
 .../adaptivebatch/AbstractBlockingResultInfo.java  |  13 ++-
 .../adaptivebatch/AdaptiveBatchScheduler.java      |  68 ++++++++++-
 .../AdaptiveExecutionHandlerFactory.java           |   4 +-
 .../adaptivebatch/AllToAllBlockingResultInfo.java  |  41 ++++++-
 .../adaptivebatch/BlockingResultInfo.java          |  10 ++
 .../DefaultAdaptiveExecutionHandler.java           |  48 +++++++-
 ...faultVertexParallelismAndInputInfosDecider.java |  15 ++-
 .../adaptivebatch/PointwiseBlockingResultInfo.java |  20 +++-
 .../flink/runtime/shuffle/PartitionDescriptor.java |   2 +-
 .../streaming/api/graph/AdaptiveGraphManager.java  |   7 +-
 .../api/graph/DefaultStreamGraphContext.java       |  33 +++++-
 .../VertexInputInfoComputationUtilsTest.java       |  94 +++++++++++----
 ...AdaptiveExecutionPlanSchedulingContextTest.java |  16 ++-
 .../AllToAllBlockingResultInfoTest.java            |  24 ++--
 .../DefaultAdaptiveExecutionHandlerTest.java       | 128 ++++++++++++++++++++-
 ...tVertexParallelismAndInputInfosDeciderTest.java |  97 ++++++++++++----
 .../api/graph/DefaultStreamGraphContextTest.java   |   6 +-
 .../scheduling/AdaptiveBatchSchedulerITCase.java   | 112 ++++++++++++++++++
 23 files changed, 724 insertions(+), 101 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
index f00539b5307..c010057ef4f 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
@@ -63,6 +63,7 @@ public class IntermediateResult {
     private final int numParallelProducers;
 
     private final ExecutionPlanSchedulingContext 
executionPlanSchedulingContext;
+    private final boolean singleSubpartitionContainsAllData;
 
     private int partitionsAssigned;
 
@@ -102,6 +103,8 @@ public class IntermediateResult {
         this.shuffleDescriptorCache = new HashMap<>();
 
         this.executionPlanSchedulingContext = 
checkNotNull(executionPlanSchedulingContext);
+
+        this.singleSubpartitionContainsAllData = 
intermediateDataSet.isBroadcast();
     }
 
     public boolean areAllConsumerVerticesCreated() {
@@ -199,6 +202,16 @@ public class IntermediateResult {
         return intermediateDataSet.getDistributionPattern();
     }
 
+    /**
+     * Determines whether the associated intermediate data set uses a 
broadcast distribution
+     * pattern.
+     *
+     * <p>A broadcast distribution pattern indicates that all data produced by 
this intermediate
+     * data set should be broadcast to every downstream consumer.
+     *
+     * @return true if the intermediate data set is using a broadcast 
distribution pattern; false
+     *     otherwise.
+     */
     public boolean isBroadcast() {
         return intermediateDataSet.isBroadcast();
     }
@@ -207,6 +220,18 @@ public class IntermediateResult {
         return intermediateDataSet.isForward();
     }
 
+    /**
+     * Checks if a single subpartition contains all the produced data. This condition indicates
+     * that the data was intended to be broadcast to all consumers. If the decision to broadcast
+     * was made before the data was produced, this flag would likely be set accordingly.
+     * Conversely, if the broadcast decision was made after the data was produced, this flag will
+     * be false.
+     *
+     * @return true if a single subpartition contains all the data; false 
otherwise.
+     */
+    public boolean isSingleSubpartitionContainsAllData() {
+        return singleSubpartitionContainsAllData;
+    }
+
     public int getConnectionIndex() {
         return connectionIndex;
     }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java
index 26829893b5a..2c52d340afc 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java
@@ -29,9 +29,21 @@ public interface IntermediateResultInfo {
     IntermediateDataSetID getResultId();
 
     /**
-     * Whether it is a broadcast result.
+     * Checks whether there is a single subpartition that contains all the 
produced data.
      *
-     * @return whether it is a broadcast result
+     * @return true if a single subpartition contains all the data; false otherwise.
+     */
+    boolean isSingleSubpartitionContainsAllData();
+
+    /**
+     * Determines whether the associated intermediate data set uses a 
broadcast distribution
+     * pattern.
+     *
+     * <p>A broadcast distribution pattern indicates that all data produced by 
this intermediate
+     * data set should be broadcast to every downstream consumer.
+     *
+     * @return true if the intermediate data set is using a broadcast 
distribution pattern; false
+     *     otherwise.
      */
     boolean isBroadcast();
 
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
index 19ce7753e3f..6962e6641a7 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
@@ -151,7 +151,7 @@ public class IntermediateResultPartition {
     }
 
     private int computeNumberOfSubpartitionsForDynamicGraph() {
-        if (totalResult.isBroadcast() || totalResult.isForward()) {
+        if (totalResult.isSingleSubpartitionContainsAllData() || 
totalResult.isForward()) {
             // for dynamic graph and broadcast result, and forward result, we 
only produced one
             // subpartition, and all the downstream vertices should consume 
this subpartition.
             return 1;
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java
index 680a0bb1634..3c8dfc50e9b 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java
@@ -84,7 +84,8 @@ public class VertexInputInfoComputationUtils {
                                 parallelism,
                                 input::getNumSubpartitions,
                                 isDynamicGraph,
-                                input.isBroadcast()));
+                                input.isBroadcast(),
+                                input.isSingleSubpartitionContainsAllData()));
             }
         }
 
@@ -124,6 +125,7 @@ public class VertexInputInfoComputationUtils {
                                 1,
                                 () -> numOfSubpartitionsRetriever.apply(start),
                                 isDynamicGraph,
+                                false,
                                 false);
                 executionVertexInputInfos.add(
                         new ExecutionVertexInputInfo(index, partitionRange, 
subpartitionRange));
@@ -145,6 +147,7 @@ public class VertexInputInfoComputationUtils {
                                     numConsumers,
                                     () -> 
numOfSubpartitionsRetriever.apply(finalPartitionNum),
                                     isDynamicGraph,
+                                    false,
                                     false);
                     executionVertexInputInfos.add(
                             new ExecutionVertexInputInfo(i, partitionRange, 
subpartitionRange));
@@ -165,6 +168,7 @@ public class VertexInputInfoComputationUtils {
      * @param numOfSubpartitionsRetriever a retriever to get the number of 
subpartitions
      * @param isDynamicGraph whether is dynamic graph
      * @param isBroadcast whether the edge is broadcast
+     * @param isSingleSubpartitionContainsAllData whether single subpartition 
contains all data
      * @return the computed {@link JobVertexInputInfo}
      */
     static JobVertexInputInfo computeVertexInputInfoForAllToAll(
@@ -172,7 +176,8 @@ public class VertexInputInfoComputationUtils {
             int targetCount,
             Function<Integer, Integer> numOfSubpartitionsRetriever,
             boolean isDynamicGraph,
-            boolean isBroadcast) {
+            boolean isBroadcast,
+            boolean isSingleSubpartitionContainsAllData) {
         final List<ExecutionVertexInputInfo> executionVertexInputInfos = new 
ArrayList<>();
         IndexRange partitionRange = new IndexRange(0, sourceCount - 1);
         for (int i = 0; i < targetCount; ++i) {
@@ -182,7 +187,8 @@ public class VertexInputInfoComputationUtils {
                             targetCount,
                             () -> numOfSubpartitionsRetriever.apply(0),
                             isDynamicGraph,
-                            isBroadcast);
+                            isBroadcast,
+                            isSingleSubpartitionContainsAllData);
             executionVertexInputInfos.add(
                     new ExecutionVertexInputInfo(i, partitionRange, 
subpartitionRange));
         }
@@ -199,6 +205,7 @@ public class VertexInputInfoComputationUtils {
      * @param numOfSubpartitionsSupplier a supplier to get the number of 
subpartitions
      * @param isDynamicGraph whether is dynamic graph
      * @param isBroadcast whether the edge is broadcast
+     * @param isSingleSubpartitionContainsAllData whether single subpartition 
contains all data
      * @return the computed subpartition range
      */
     @VisibleForTesting
@@ -207,16 +214,22 @@ public class VertexInputInfoComputationUtils {
             int numConsumers,
             Supplier<Integer> numOfSubpartitionsSupplier,
             boolean isDynamicGraph,
-            boolean isBroadcast) {
+            boolean isBroadcast,
+            boolean isSingleSubpartitionContainsAllData) {
         int consumerIndex = consumerSubtaskIndex % numConsumers;
         if (!isDynamicGraph) {
             return new IndexRange(consumerIndex, consumerIndex);
         } else {
             int numSubpartitions = numOfSubpartitionsSupplier.get();
             if (isBroadcast) {
-                // broadcast results have only one subpartition, and be 
consumed multiple times.
-                checkArgument(numSubpartitions == 1);
-                return new IndexRange(0, 0);
+                if (isSingleSubpartitionContainsAllData) {
+                    // early-decided broadcast results have only one subpartition, which is
+                    // consumed multiple times.
+                    checkArgument(numSubpartitions == 1);
+                    return new IndexRange(0, 0);
+                } else {
+                    return new IndexRange(0, numSubpartitions - 1);
+                }
             } else {
                 checkArgument(consumerIndex < numConsumers);
                 checkArgument(numConsumers <= numSubpartitions);
@@ -246,6 +259,11 @@ public class VertexInputInfoComputationUtils {
             return intermediateResult.isBroadcast();
         }
 
+        @Override
+        public boolean isSingleSubpartitionContainsAllData() {
+            return intermediateResult.isSingleSubpartitionContainsAllData();
+        }
+
         @Override
         public boolean isPointwise() {
             return intermediateResult.getConsumingDistributionPattern()
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java
index c5d1187d230..ec73f25e283 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java
@@ -134,6 +134,18 @@ public class IntermediateDataSet implements 
java.io.Serializable {
         }
     }
 
+    public void updateOutputPattern(
+            DistributionPattern distributionPattern, boolean isBroadcast, 
boolean isForward) {
+        checkState(consumers.isEmpty(), "The output job edges have already 
been added.");
+        checkState(
+                numJobEdgesToCreate == 1,
+                "Modification is not allowed when the subscribing output is 
reused.");
+
+        this.distributionPattern = distributionPattern;
+        this.isBroadcast = isBroadcast;
+        this.isForward = isForward;
+    }
+
     public void increaseNumJobEdgesToCreate() {
         this.numJobEdgesToCreate++;
     }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java
index 33147bcdc16..3844c50d84a 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -44,11 +45,14 @@ abstract class AbstractBlockingResultInfo implements 
BlockingResultInfo {
     protected final Map<Integer, long[]> subpartitionBytesByPartitionIndex;
 
     AbstractBlockingResultInfo(
-            IntermediateDataSetID resultId, int numOfPartitions, int 
numOfSubpartitions) {
+            IntermediateDataSetID resultId,
+            int numOfPartitions,
+            int numOfSubpartitions,
+            Map<Integer, long[]> subpartitionBytesByPartitionIndex) {
         this.resultId = checkNotNull(resultId);
         this.numOfPartitions = numOfPartitions;
         this.numOfSubpartitions = numOfSubpartitions;
-        this.subpartitionBytesByPartitionIndex = new HashMap<>();
+        this.subpartitionBytesByPartitionIndex = new 
HashMap<>(subpartitionBytesByPartitionIndex);
     }
 
     @Override
@@ -72,4 +76,9 @@ abstract class AbstractBlockingResultInfo implements 
BlockingResultInfo {
     int getNumOfRecordedPartitions() {
         return subpartitionBytesByPartitionIndex.size();
     }
+
+    @Override
+    public Map<Integer, long[]> getSubpartitionBytesByPartitionIndex() {
+        return Collections.unmodifiableMap(subpartitionBytesByPartitionIndex);
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
index a46a210446f..966c6eea998 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
@@ -274,9 +274,14 @@ public class AdaptiveBatchScheduler extends 
DefaultScheduler implements JobGraph
         // 4. update json plan
         
getExecutionGraph().setJsonPlan(JsonPlanGenerator.generatePlan(getJobGraph()));
 
-        // 5. try aggregate subpartition bytes
+        // 5. In broadcast join optimization, results might be written first 
with a hash
+        // method and then read with a broadcast method. Therefore, we need to 
update the
+        // result info:
+        // 1. Update the DistributionPattern to reflect the optimized data 
distribution.
+        // 2. Aggregate subpartition bytes when possible for efficiency.
         for (JobVertex newVertex : newVertices) {
             for (JobEdge input : newVertex.getInputs()) {
+                tryUpdateResultInfo(input.getSourceId(), 
input.getDistributionPattern());
                 
Optional.ofNullable(blockingResultInfos.get(input.getSourceId()))
                         .ifPresent(this::maybeAggregateSubpartitionBytes);
             }
@@ -490,7 +495,8 @@ public class AdaptiveBatchScheduler extends 
DefaultScheduler implements JobGraph
                             result.getId(),
                             (ignored, resultInfo) -> {
                                 if (resultInfo == null) {
-                                    resultInfo = 
createFromIntermediateResult(result);
+                                    resultInfo =
+                                            
createFromIntermediateResult(result, new HashMap<>());
                                 }
                                 resultInfo.recordPartitionInfo(
                                         partitionId.getPartitionNumber(), 
partitionBytes);
@@ -500,6 +506,16 @@ public class AdaptiveBatchScheduler extends 
DefaultScheduler implements JobGraph
                 });
     }
 
+    /**
+     * Aggregates subpartition bytes if all conditions are met. This method checks whether the
+     * result info instance is of type {@link AllToAllBlockingResultInfo}, whether all consumer
+     * vertices are created, and whether all consumer vertices are initialized. If these conditions
+     * are satisfied, the fine-grained statistics are no longer required by the consumer vertices,
+     * so the subpartition bytes can be aggregated.
+     *
+     * @param resultInfo the BlockingResultInfo instance to potentially 
aggregate subpartition bytes
+     *     for.
+     */
     private void maybeAggregateSubpartitionBytes(BlockingResultInfo 
resultInfo) {
         IntermediateResult intermediateResult =
                 
getExecutionGraph().getAllIntermediateResults().get(resultInfo.getResultId());
@@ -937,7 +953,8 @@ public class AdaptiveBatchScheduler extends 
DefaultScheduler implements JobGraph
         }
     }
 
-    private static BlockingResultInfo 
createFromIntermediateResult(IntermediateResult result) {
+    private static BlockingResultInfo createFromIntermediateResult(
+            IntermediateResult result, Map<Integer, long[]> 
subpartitionBytesByPartitionIndex) {
         checkArgument(result != null);
         // Note that for dynamic graph, different partitions in the same 
result have the same number
         // of subpartitions.
@@ -945,13 +962,15 @@ public class AdaptiveBatchScheduler extends 
DefaultScheduler implements JobGraph
             return new PointwiseBlockingResultInfo(
                     result.getId(),
                     result.getNumberOfAssignedPartitions(),
-                    result.getPartitions()[0].getNumberOfSubpartitions());
+                    result.getPartitions()[0].getNumberOfSubpartitions(),
+                    subpartitionBytesByPartitionIndex);
         } else {
             return new AllToAllBlockingResultInfo(
                     result.getId(),
                     result.getNumberOfAssignedPartitions(),
                     result.getPartitions()[0].getNumberOfSubpartitions(),
-                    result.isBroadcast());
+                    result.isSingleSubpartitionContainsAllData(),
+                    subpartitionBytesByPartitionIndex);
         }
     }
 
@@ -965,6 +984,45 @@ public class AdaptiveBatchScheduler extends 
DefaultScheduler implements JobGraph
         return speculativeExecutionHandler;
     }
 
+    /**
+     * Tries to update the result information for a given 
IntermediateDataSetID according to the
+     * specified DistributionPattern. This ensures consistency between the 
distribution pattern and
+     * the stored result information.
+     *
+     * <p>The result information is updated under the following conditions:
+     *
+     * <ul>
+     *   <li>If the target pattern is ALL_TO_ALL and the current result info 
is POINTWISE, a new
+     *       BlockingResultInfo is created and stored.
+     *   <li>If the target pattern is POINTWISE and the current result info is 
ALL_TO_ALL, a
+     *       conversion is similarly triggered.
+     *   <li>Additionally, for ALL_TO_ALL patterns, the broadcast status of the result info is
+     *       updated.
+     * </ul>
+     *
+     * @param id The ID of the intermediate dataset to update.
+     * @param targetPattern The target distribution pattern to apply.
+     */
+    private void tryUpdateResultInfo(IntermediateDataSetID id, 
DistributionPattern targetPattern) {
+        if (blockingResultInfos.containsKey(id)) {
+            BlockingResultInfo resultInfo = blockingResultInfos.get(id);
+            IntermediateResult result = 
getExecutionGraph().getAllIntermediateResults().get(id);
+
+            if ((targetPattern == DistributionPattern.ALL_TO_ALL && 
resultInfo.isPointwise())
+                    || (targetPattern == DistributionPattern.POINTWISE
+                            && !resultInfo.isPointwise())) {
+
+                BlockingResultInfo newInfo =
+                        createFromIntermediateResult(
+                                result, 
resultInfo.getSubpartitionBytesByPartitionIndex());
+
+                blockingResultInfos.put(id, newInfo);
+            } else if (resultInfo instanceof AllToAllBlockingResultInfo) {
+                ((AllToAllBlockingResultInfo) 
resultInfo).setBroadcast(result.isBroadcast());
+            }
+        }
+    }
+
     private class DefaultBatchJobRecoveryContext implements 
BatchJobRecoveryContext {
 
         private final FailoverStrategy restartStrategyOnResultConsumable =
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionHandlerFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionHandlerFactory.java
index b6113012f00..2d7be76c729 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionHandlerFactory.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionHandlerFactory.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.scheduler.adaptivebatch;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.streaming.api.graph.ExecutionPlan;
 import org.apache.flink.streaming.api.graph.StreamGraph;
+import org.apache.flink.util.DynamicCodeLoadingException;
 
 import java.util.concurrent.Executor;
 
@@ -46,7 +47,8 @@ public class AdaptiveExecutionHandlerFactory {
     public static AdaptiveExecutionHandler create(
             ExecutionPlan executionPlan,
             ClassLoader userClassLoader,
-            Executor serializationExecutor) {
+            Executor serializationExecutor)
+            throws DynamicCodeLoadingException {
         if (executionPlan instanceof JobGraph) {
             return new NonAdaptiveExecutionHandler((JobGraph) executionPlan);
         } else {
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
index ed1e945912f..b7b4b8ef1cd 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.runtime.executiongraph.IndexRange;
 import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
@@ -26,7 +27,9 @@ import javax.annotation.Nullable;
 
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import java.util.stream.Collectors;
 
@@ -35,29 +38,57 @@ import static 
org.apache.flink.util.Preconditions.checkState;
 /** Information of All-To-All result. */
 public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo {
 
-    private final boolean isBroadcast;
+    private final boolean singleSubpartitionContainsAllData;
+
+    private boolean isBroadcast;
 
     /**
      * Aggregated subpartition bytes, which aggregates the subpartition bytes 
with the same
      * subpartition index in different partitions. Note that We can aggregate 
them because they will
      * be consumed by the same downstream task.
      */
-    @Nullable private List<Long> aggregatedSubpartitionBytes;
+    @Nullable protected List<Long> aggregatedSubpartitionBytes;
 
+    @VisibleForTesting
     AllToAllBlockingResultInfo(
             IntermediateDataSetID resultId,
             int numOfPartitions,
             int numOfSubpartitions,
-            boolean isBroadcast) {
-        super(resultId, numOfPartitions, numOfSubpartitions);
+            boolean isBroadcast,
+            boolean singleSubpartitionContainsAllData) {
+        this(
+                resultId,
+                numOfPartitions,
+                numOfSubpartitions,
+                singleSubpartitionContainsAllData,
+                new HashMap<>());
         this.isBroadcast = isBroadcast;
     }
 
+    AllToAllBlockingResultInfo(
+            IntermediateDataSetID resultId,
+            int numOfPartitions,
+            int numOfSubpartitions,
+            boolean singleSubpartitionContainsAllData,
+            Map<Integer, long[]> subpartitionBytesByPartitionIndex) {
+        super(resultId, numOfPartitions, numOfSubpartitions, 
subpartitionBytesByPartitionIndex);
+        this.singleSubpartitionContainsAllData = 
singleSubpartitionContainsAllData;
+    }
+
     @Override
     public boolean isBroadcast() {
         return isBroadcast;
     }
 
+    @Override
+    public boolean isSingleSubpartitionContainsAllData() {
+        return singleSubpartitionContainsAllData;
+    }
+
+    void setBroadcast(boolean isBroadcast) {
+        this.isBroadcast = isBroadcast;
+    }
+
     @Override
     public boolean isPointwise() {
         return false;
@@ -83,7 +114,7 @@ public class AllToAllBlockingResultInfo extends 
AbstractBlockingResultInfo {
         List<Long> bytes =
                 Optional.ofNullable(aggregatedSubpartitionBytes)
                         .orElse(getAggregatedSubpartitionBytesInternal());
-        if (isBroadcast) {
+        if (singleSubpartitionContainsAllData) {
             return bytes.get(0);
         } else {
             return bytes.stream().reduce(0L, Long::sum);
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
index e836d993869..0669417cc95 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
@@ -22,6 +22,8 @@ import org.apache.flink.runtime.executiongraph.IndexRange;
 import org.apache.flink.runtime.executiongraph.IntermediateResultInfo;
 import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 
+import java.util.Map;
+
 /**
  * The blocking result info, which will be used to calculate the vertex 
parallelism and input infos.
  */
@@ -64,4 +66,12 @@ public interface BlockingResultInfo extends 
IntermediateResultInfo {
      * @param partitionIndex the intermediate result partition index
      */
     void resetPartitionInfo(int partitionIndex);
+
+    /**
+     * Gets subpartition bytes by partition index.
+     *
+     * @return a map with integer keys representing partition indices and long 
array values
+     *     representing subpartition bytes.
+     */
+    Map<Integer, long[]> getSubpartitionBytesByPartitionIndex();
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java
index b365db8d0e0..c0942f65372 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
@@ -27,13 +28,17 @@ import 
org.apache.flink.runtime.jobmaster.event.ExecutionJobVertexFinishedEvent;
 import org.apache.flink.runtime.jobmaster.event.JobEvent;
 import org.apache.flink.streaming.api.graph.AdaptiveGraphManager;
 import org.apache.flink.streaming.api.graph.StreamGraph;
+import org.apache.flink.util.DynamicCodeLoadingException;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.Executor;
+import java.util.stream.Collectors;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 
@@ -52,10 +57,16 @@ public class DefaultAdaptiveExecutionHandler implements 
AdaptiveExecutionHandler
 
     private final AdaptiveGraphManager adaptiveGraphManager;
 
+    private final StreamGraphOptimizer streamGraphOptimizer;
+
     public DefaultAdaptiveExecutionHandler(
-            ClassLoader userClassloader, StreamGraph streamGraph, Executor 
serializationExecutor) {
+            ClassLoader userClassloader, StreamGraph streamGraph, Executor 
serializationExecutor)
+            throws DynamicCodeLoadingException {
         this.adaptiveGraphManager =
                 new AdaptiveGraphManager(userClassloader, streamGraph, 
serializationExecutor);
+
+        this.streamGraphOptimizer =
+                new StreamGraphOptimizer(streamGraph.getJobConfiguration(), 
userClassloader);
     }
 
     @Override
@@ -66,6 +77,7 @@ public class DefaultAdaptiveExecutionHandler implements 
AdaptiveExecutionHandler
     @Override
     public void handleJobEvent(JobEvent jobEvent) {
         try {
+            tryOptimizeStreamGraph(jobEvent);
             tryUpdateJobGraph(jobEvent);
         } catch (Exception e) {
             log.error("Failed to handle job event {}.", jobEvent, e);
@@ -73,6 +85,40 @@ public class DefaultAdaptiveExecutionHandler implements 
AdaptiveExecutionHandler
         }
     }
 
+    private void tryOptimizeStreamGraph(JobEvent jobEvent) throws Exception {
+        if (jobEvent instanceof ExecutionJobVertexFinishedEvent) {
+            ExecutionJobVertexFinishedEvent event = 
(ExecutionJobVertexFinishedEvent) jobEvent;
+
+            JobVertexID vertexId = event.getVertexId();
+            Map<IntermediateDataSetID, BlockingResultInfo> resultInfo = 
event.getResultInfo();
+            Map<Integer, List<BlockingResultInfo>> resultInfoMap =
+                    resultInfo.entrySet().stream()
+                            .collect(
+                                    Collectors.toMap(
+                                            entry ->
+                                                    
adaptiveGraphManager.getProducerStreamNodeId(
+                                                            entry.getKey()),
+                                            entry ->
+                                                    new ArrayList<>(
+                                                            
Collections.singletonList(
+                                                                    
entry.getValue())),
+                                            (existing, replacement) -> {
+                                                existing.addAll(replacement);
+                                                return existing;
+                                            }));
+
+            OperatorsFinished operatorsFinished =
+                    new OperatorsFinished(
+                            
adaptiveGraphManager.getStreamNodeIdsByJobVertexId(vertexId),
+                            resultInfoMap);
+
+            streamGraphOptimizer.onOperatorsFinished(
+                    operatorsFinished, 
adaptiveGraphManager.getStreamGraphContext());
+        } else {
+            throw new IllegalArgumentException("Unsupported job event " + 
jobEvent);
+        }
+    }
+
     private void tryUpdateJobGraph(JobEvent jobEvent) throws Exception {
         if (jobEvent instanceof ExecutionJobVertexFinishedEvent) {
             ExecutionJobVertexFinishedEvent event = 
(ExecutionJobVertexFinishedEvent) jobEvent;
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java
index eb78b9cd7a3..bcdaf0b5176 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java
@@ -200,7 +200,8 @@ public class DefaultVertexParallelismAndInputInfosDecider
     }
 
     private static boolean areAllInputsBroadcast(List<BlockingResultInfo> 
consumedResults) {
-        return 
consumedResults.stream().allMatch(BlockingResultInfo::isBroadcast);
+        return consumedResults.stream()
+                
.allMatch(BlockingResultInfo::isSingleSubpartitionContainsAllData);
     }
 
     /**
@@ -468,7 +469,15 @@ public class DefaultVertexParallelismAndInputInfosDecider
                     for (int i = 0; i < subpartitionRanges.size(); ++i) {
                         IndexRange subpartitionRange;
                         if (resultInfo.isBroadcast()) {
-                            subpartitionRange = new IndexRange(0, 0);
+                            if 
(resultInfo.isSingleSubpartitionContainsAllData()) {
+                                subpartitionRange = new IndexRange(0, 0);
+                            } else {
+                                // The partitions of the all-to-all result 
have the same number of
+                                // subpartitions. So we can use the first 
partition's subpartition
+                                // number.
+                                subpartitionRange =
+                                        new IndexRange(0, 
resultInfo.getNumSubpartitions(0) - 1);
+                            }
                         } else {
                             subpartitionRange = subpartitionRanges.get(i);
                         }
@@ -546,7 +555,7 @@ public class DefaultVertexParallelismAndInputInfosDecider
     private static List<BlockingResultInfo> getNonBroadcastResultInfos(
             List<BlockingResultInfo> consumedResults) {
         return consumedResults.stream()
-                .filter(resultInfo -> !resultInfo.isBroadcast())
+                .filter(resultInfo -> 
!resultInfo.isSingleSubpartitionContainsAllData())
                 .collect(Collectors.toList());
     }
 
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
index ed993af9d81..9a3f3b418aa 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
@@ -18,18 +18,31 @@
 
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.runtime.executiongraph.IndexRange;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 
 import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
 
 import static org.apache.flink.util.Preconditions.checkState;
 
 /** Information of Pointwise result. */
 public class PointwiseBlockingResultInfo extends AbstractBlockingResultInfo {
+
+    @VisibleForTesting
     PointwiseBlockingResultInfo(
             IntermediateDataSetID resultId, int numOfPartitions, int 
numOfSubpartitions) {
-        super(resultId, numOfPartitions, numOfSubpartitions);
+        this(resultId, numOfPartitions, numOfSubpartitions, new HashMap<>());
+    }
+
+    PointwiseBlockingResultInfo(
+            IntermediateDataSetID resultId,
+            int numOfPartitions,
+            int numOfSubpartitions,
+            Map<Integer, long[]> subpartitionBytesByPartitionIndex) {
+        super(resultId, numOfPartitions, numOfSubpartitions, 
subpartitionBytesByPartitionIndex);
     }
 
     @Override
@@ -37,6 +50,11 @@ public class PointwiseBlockingResultInfo extends 
AbstractBlockingResultInfo {
         return false;
     }
 
+    @Override
+    public boolean isSingleSubpartitionContainsAllData() {
+        return false;
+    }
+
     @Override
     public boolean isPointwise() {
         return true;
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/PartitionDescriptor.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/PartitionDescriptor.java
index c3277932ae2..40296d27cb6 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/PartitionDescriptor.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/PartitionDescriptor.java
@@ -151,7 +151,7 @@ public class PartitionDescriptor implements Serializable {
                 result.getResultType(),
                 partition.getNumberOfSubpartitions(),
                 result.getConnectionIndex(),
-                result.isBroadcast(),
+                result.isSingleSubpartitionContainsAllData(),
                 result.getConsumingDistributionPattern() == 
DistributionPattern.ALL_TO_ALL,
                 partition.isNumberOfPartitionConsumersUndefined());
     }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java
index bec248898b1..f683f586d4e 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java
@@ -107,6 +107,9 @@ public class AdaptiveGraphManager implements 
AdaptiveGraphGenerator {
     private final Map<IntermediateDataSetID, List<StreamEdge>>
             intermediateDataSetIdToOutputEdgesMap;
 
+    private final Map<String, IntermediateDataSet> 
consumerEdgeIdToIntermediateDataSetMap =
+            new HashMap<>();
+
     // Records the ids of stream nodes in the StreamNodeForwardGroup.
     // When stream edge's partitioner is modified to forward, we need get 
forward groups by source
     // and target node id.
@@ -167,7 +170,8 @@ public class AdaptiveGraphManager implements 
AdaptiveGraphGenerator {
                         streamGraph,
                         steamNodeIdToForwardGroupMap,
                         frozenNodeToStartNodeMap,
-                        intermediateOutputsCaches);
+                        intermediateOutputsCaches,
+                        consumerEdgeIdToIntermediateDataSetMap);
 
         this.jobGraph = createAndInitializeJobGraph(streamGraph, 
streamGraph.getJobID());
 
@@ -382,6 +386,7 @@ public class AdaptiveGraphManager implements 
AdaptiveGraphGenerator {
                 intermediateDataSetIdToOutputEdgesMap
                         .computeIfAbsent(dataSet.getId(), ignored -> new 
ArrayList<>())
                         .add(edge);
+                consumerEdgeIdToIntermediateDataSetMap.put(edge.getEdgeId(), 
dataSet);
                 // we cache the output here for downstream vertex to create 
jobEdge.
                 intermediateOutputsCaches
                         .computeIfAbsent(edge.getSourceId(), k -> new 
HashMap<>())
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
index 07d8631bf92..324f3466e7c 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
@@ -19,6 +19,8 @@
 package org.apache.flink.streaming.api.graph;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
 import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup;
 import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph;
 import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
@@ -36,6 +38,7 @@ import javax.annotation.Nullable;
 
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -69,16 +72,21 @@ public class DefaultStreamGraphContext implements 
StreamGraphContext {
     // as they reuse some attributes.
     private final Map<Integer, Map<StreamEdge, NonChainedOutput>> 
opIntermediateOutputsCaches;
 
+    private final Map<String, IntermediateDataSet> 
consumerEdgeIdToIntermediateDataSetMap;
+
     public DefaultStreamGraphContext(
             StreamGraph streamGraph,
             Map<Integer, StreamNodeForwardGroup> steamNodeIdToForwardGroupMap,
             Map<Integer, Integer> frozenNodeToStartNodeMap,
-            Map<Integer, Map<StreamEdge, NonChainedOutput>> 
opIntermediateOutputsCaches) {
+            Map<Integer, Map<StreamEdge, NonChainedOutput>> 
opIntermediateOutputsCaches,
+            Map<String, IntermediateDataSet> 
consumerEdgeIdToIntermediateDataSetMap) {
         this.streamGraph = checkNotNull(streamGraph);
         this.steamNodeIdToForwardGroupMap = 
checkNotNull(steamNodeIdToForwardGroupMap);
         this.frozenNodeToStartNodeMap = checkNotNull(frozenNodeToStartNodeMap);
         this.opIntermediateOutputsCaches = 
checkNotNull(opIntermediateOutputsCaches);
         this.immutableStreamGraph = new ImmutableStreamGraph(this.streamGraph);
+        this.consumerEdgeIdToIntermediateDataSetMap =
+                checkNotNull(consumerEdgeIdToIntermediateDataSetMap);
     }
 
     @Override
@@ -188,9 +196,9 @@ public class DefaultStreamGraphContext implements 
StreamGraphContext {
             tryConvertForwardPartitionerAndMergeForwardGroup(targetEdge);
         }
 
-        // The partitioner in NonChainedOutput derived from the consumer edge, 
so we need to ensure
-        // that any modifications to the partitioner of consumer edge are 
synchronized with
-        // NonChainedOutput.
+        // The partitioners in NonChainedOutput and IntermediateDataSet are derived from the
+        // consumer edge, so we need to ensure that any modifications to the partitioner of the
+        // consumer edge are synchronized with NonChainedOutput and IntermediateDataSet.
         Map<StreamEdge, NonChainedOutput> opIntermediateOutputs =
                 opIntermediateOutputsCaches.get(targetEdge.getSourceId());
         NonChainedOutput output =
@@ -198,6 +206,23 @@ public class DefaultStreamGraphContext implements 
StreamGraphContext {
         if (output != null) {
             output.setPartitioner(targetEdge.getPartitioner());
         }
+
+        
Optional.ofNullable(consumerEdgeIdToIntermediateDataSetMap.get(targetEdge.getEdgeId()))
+                .ifPresent(
+                        dataSet -> {
+                            DistributionPattern distributionPattern =
+                                    targetEdge.getPartitioner().isPointwise()
+                                            ? DistributionPattern.POINTWISE
+                                            : DistributionPattern.ALL_TO_ALL;
+                            dataSet.updateOutputPattern(
+                                    distributionPattern,
+                                    targetEdge.getPartitioner().isBroadcast(),
+                                    targetEdge
+                                            .getPartitioner()
+                                            .getClass()
+                                            .equals(ForwardPartitioner.class));
+                        });
+
         LOG.info(
                 "The original partitioner of the edge {} is: {} , requested 
change to: {} , and finally modified to: {}.",
                 targetEdge,
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java
index e0f4d6e2fad..72de899aaf5 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java
@@ -19,10 +19,13 @@
 package org.apache.flink.runtime.executiongraph;
 
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 
 import static 
org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils.computeVertexInputInfoForAllToAll;
 import static 
org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils.computeVertexInputInfoForPointwise;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Test for {@link VertexInputInfoComputationUtils}. */
 class VertexInputInfoComputationUtilsTest {
@@ -57,34 +60,55 @@ class VertexInputInfoComputationUtilsTest {
         assertThat(range4).isEqualTo(new IndexRange(4, 5));
     }
 
-    @Test
-    void testComputeBroadcastConsumedSubpartitionRange() {
-        final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, 1, 
true, true);
-        assertThat(range1).isEqualTo(new IndexRange(0, 0));
-
-        final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, 1, 
true, true);
-        assertThat(range2).isEqualTo(new IndexRange(0, 0));
-
-        final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, 1, 
true, true);
-        assertThat(range3).isEqualTo(new IndexRange(0, 0));
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void testComputeBroadcastConsumedSubpartitionRange(boolean 
singleSubpartitionContainsAllData) {
+        int numSubpartitions = singleSubpartitionContainsAllData ? 1 : 3;
+        final IndexRange range1 =
+                computeConsumedSubpartitionRange(
+                        0, 3, numSubpartitions, true, true, 
singleSubpartitionContainsAllData);
+        assertThat(range1).isEqualTo(new IndexRange(0, numSubpartitions - 1));
+
+        final IndexRange range2 =
+                computeConsumedSubpartitionRange(
+                        1, 3, numSubpartitions, true, true, 
singleSubpartitionContainsAllData);
+        assertThat(range2).isEqualTo(new IndexRange(0, numSubpartitions - 1));
+
+        final IndexRange range3 =
+                computeConsumedSubpartitionRange(
+                        2, 3, numSubpartitions, true, true, 
singleSubpartitionContainsAllData);
+        assertThat(range3).isEqualTo(new IndexRange(0, numSubpartitions - 1));
+
+        if (singleSubpartitionContainsAllData) {
+            assertThatThrownBy(
+                            () ->
+                                    computeConsumedSubpartitionRange(
+                                            2,
+                                            3,
+                                            numSubpartitions + 1,
+                                            true,
+                                            true,
+                                            singleSubpartitionContainsAllData))
+                    .isInstanceOf(IllegalArgumentException.class);
+        }
     }
 
     @Test
     void testComputeConsumedSubpartitionRangeForNonDynamicGraph() {
-        final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, -1, 
false, false);
+        final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, -1, 
false, false, false);
         assertThat(range1).isEqualTo(new IndexRange(0, 0));
 
-        final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, -1, 
false, false);
+        final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, -1, 
false, false, false);
         assertThat(range2).isEqualTo(new IndexRange(1, 1));
 
-        final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, -1, 
false, false);
+        final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, -1, 
false, false, false);
         assertThat(range3).isEqualTo(new IndexRange(2, 2));
     }
 
     @Test
     void testComputeVertexInputInfoForAllToAllWithNonDynamicGraph() {
         final JobVertexInputInfo nonBroadcast =
-                computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, 
false);
+                computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, 
false, false);
         assertThat(nonBroadcast.getExecutionVertexInputInfos())
                 .containsExactlyInAnyOrder(
                         new ExecutionVertexInputInfo(0, new IndexRange(0, 1), 
new IndexRange(0, 0)),
@@ -93,7 +117,7 @@ class VertexInputInfoComputationUtilsTest {
                                 2, new IndexRange(0, 1), new IndexRange(2, 
2)));
 
         final JobVertexInputInfo broadcast =
-                computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, 
true);
+                computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, 
true, false);
         assertThat(broadcast.getExecutionVertexInputInfos())
                 .containsExactlyInAnyOrder(
                         new ExecutionVertexInputInfo(0, new IndexRange(0, 1), 
new IndexRange(0, 0)),
@@ -102,10 +126,13 @@ class VertexInputInfoComputationUtilsTest {
                                 2, new IndexRange(0, 1), new IndexRange(2, 
2)));
     }
 
-    @Test
-    void testComputeVertexInputInfoForAllToAllWithDynamicGraph() {
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void testComputeVertexInputInfoForAllToAllWithDynamicGraph(
+            boolean singleSubpartitionContainsAllData) {
         final JobVertexInputInfo nonBroadcast =
-                computeVertexInputInfoForAllToAll(2, 3, ignored -> 10, true, 
false);
+                computeVertexInputInfoForAllToAll(
+                        2, 3, ignored -> 10, true, false, 
singleSubpartitionContainsAllData);
         assertThat(nonBroadcast.getExecutionVertexInputInfos())
                 .containsExactlyInAnyOrder(
                         new ExecutionVertexInputInfo(0, new IndexRange(0, 1), 
new IndexRange(0, 2)),
@@ -114,13 +141,30 @@ class VertexInputInfoComputationUtilsTest {
                                 2, new IndexRange(0, 1), new IndexRange(6, 
9)));
 
         final JobVertexInputInfo broadcast =
-                computeVertexInputInfoForAllToAll(2, 3, ignored -> 1, true, 
true);
+                computeVertexInputInfoForAllToAll(
+                        2, 3, ignored -> 1, true, true, 
singleSubpartitionContainsAllData);
         assertThat(broadcast.getExecutionVertexInputInfos())
                 .containsExactlyInAnyOrder(
                         new ExecutionVertexInputInfo(0, new IndexRange(0, 1), 
new IndexRange(0, 0)),
                         new ExecutionVertexInputInfo(1, new IndexRange(0, 1), 
new IndexRange(0, 0)),
                         new ExecutionVertexInputInfo(
                                 2, new IndexRange(0, 1), new IndexRange(0, 
0)));
+
+        if (!singleSubpartitionContainsAllData) {
+            final JobVertexInputInfo 
broadcastAndNotSingleSubpartitionContainsAllData =
+                    computeVertexInputInfoForAllToAll(
+                            2, 3, ignored -> 4, true, true, 
singleSubpartitionContainsAllData);
+            assertThat(
+                            broadcastAndNotSingleSubpartitionContainsAllData
+                                    .getExecutionVertexInputInfos())
+                    .containsExactlyInAnyOrder(
+                            new ExecutionVertexInputInfo(
+                                    0, new IndexRange(0, 1), new IndexRange(0, 
3)),
+                            new ExecutionVertexInputInfo(
+                                    1, new IndexRange(0, 1), new IndexRange(0, 
3)),
+                            new ExecutionVertexInputInfo(
+                                    2, new IndexRange(0, 1), new IndexRange(0, 
3)));
+        }
     }
 
     @Test
@@ -150,7 +194,7 @@ class VertexInputInfoComputationUtilsTest {
     private static IndexRange computeConsumedSubpartitionRange(
             int consumerIndex, int numConsumers, int numSubpartitions) {
         return computeConsumedSubpartitionRange(
-                consumerIndex, numConsumers, numSubpartitions, true, false);
+                consumerIndex, numConsumers, numSubpartitions, true, false, 
false);
     }
 
     private static IndexRange computeConsumedSubpartitionRange(
@@ -158,8 +202,14 @@ class VertexInputInfoComputationUtilsTest {
             int numConsumers,
             int numSubpartitions,
             boolean isDynamicGraph,
-            boolean isBroadcast) {
+            boolean isBroadcast,
+            boolean broadcast) {
         return 
VertexInputInfoComputationUtils.computeConsumedSubpartitionRange(
-                consumerIndex, numConsumers, () -> numSubpartitions, 
isDynamicGraph, isBroadcast);
+                consumerIndex,
+                numConsumers,
+                () -> numSubpartitions,
+                isDynamicGraph,
+                isBroadcast,
+                broadcast);
     }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionPlanSchedulingContextTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionPlanSchedulingContextTest.java
index 500d65b67ff..b1dc2260ae3 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionPlanSchedulingContextTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionPlanSchedulingContextTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.streaming.api.graph.StreamGraph;
 import org.apache.flink.streaming.api.graph.StreamNode;
 import org.apache.flink.testutils.TestingUtils;
 import org.apache.flink.testutils.executor.TestExecutorExtension;
+import org.apache.flink.util.DynamicCodeLoadingException;
 
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
@@ -45,7 +46,7 @@ class AdaptiveExecutionPlanSchedulingContextTest {
             TestingUtils.defaultExecutorExtension();
 
     @Test
-    void testGetParallelismAndMaxParallelism() {
+    void testGetParallelismAndMaxParallelism() throws DynamicCodeLoadingException {
         int sinkParallelism = 4;
         int sinkMaxParallelism = 5;
 
@@ -73,7 +74,8 @@ class AdaptiveExecutionPlanSchedulingContextTest {
     }
 
     @Test
-    void testGetDefaultMaxParallelismWhenParallelismGreaterThanZero() {
+    void testGetDefaultMaxParallelismWhenParallelismGreaterThanZero()
+            throws DynamicCodeLoadingException {
         int sinkParallelism = 4;
         int sinkMaxParallelism = -1;
         int defaultMaxParallelism = 100;
@@ -94,7 +96,8 @@ class AdaptiveExecutionPlanSchedulingContextTest {
     }
 
     @Test
-    void testGetDefaultMaxParallelismWhenParallelismLessThanZero() {
+    void testGetDefaultMaxParallelismWhenParallelismLessThanZero()
+            throws DynamicCodeLoadingException {
         int sinkParallelism = -1;
         int sinkMaxParallelism = -1;
         int defaultMaxParallelism = 100;
@@ -115,7 +118,7 @@ class AdaptiveExecutionPlanSchedulingContextTest {
     }
 
     @Test
-    public void testGetPendingOperatorCount() {
+    public void testGetPendingOperatorCount() throws DynamicCodeLoadingException {
         DefaultAdaptiveExecutionHandler adaptiveExecutionHandler =
                 getDefaultAdaptiveExecutionHandler();
         ExecutionPlanSchedulingContext schedulingContext =
@@ -131,12 +134,13 @@ class AdaptiveExecutionPlanSchedulingContextTest {
         assertThat(schedulingContext.getPendingOperatorCount()).isEqualTo(0);
     }
 
-    private static DefaultAdaptiveExecutionHandler getDefaultAdaptiveExecutionHandler() {
+    private static DefaultAdaptiveExecutionHandler getDefaultAdaptiveExecutionHandler()
+            throws DynamicCodeLoadingException {
         return getDefaultAdaptiveExecutionHandler(2, 2);
     }
 
     private static DefaultAdaptiveExecutionHandler getDefaultAdaptiveExecutionHandler(
-            int sinkParallelism, int sinkMaxParallelism) {
+            int sinkParallelism, int sinkMaxParallelism) throws DynamicCodeLoadingException {
         StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
         env.fromSequence(0L, 1L).disableChaining().print();
         StreamGraph streamGraph = env.getStreamGraph();
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
index e298b4a065a..6930940fd7c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
@@ -32,18 +32,19 @@ class AllToAllBlockingResultInfoTest {
 
     @Test
     void testGetNumBytesProducedForNonBroadcast() {
-        testGetNumBytesProduced(false, 192L);
+        testGetNumBytesProduced(false, false, 192L);
     }
 
     @Test
     void testGetNumBytesProducedForBroadcast() {
-        testGetNumBytesProduced(true, 96L);
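+        // When singleSubpartitionContainsAllData is false, every subpartition contributes to the total.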
+        testGetNumBytesProduced(true, true, 96L);
+        testGetNumBytesProduced(true, false, 192L);
     }
 
     @Test
     void testGetNumBytesProducedWithIndexRange() {
         AllToAllBlockingResultInfo resultInfo =
-                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
+                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false, false);
         resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
         resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L}));
 
@@ -57,7 +58,7 @@ class AllToAllBlockingResultInfoTest {
     @Test
     void testGetAggregatedSubpartitionBytes() {
         AllToAllBlockingResultInfo resultInfo =
-                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
+                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false, false);
         resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
         resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L}));
 
@@ -67,8 +68,9 @@ class AllToAllBlockingResultInfoTest {
     @Test
     void testGetBytesWithPartialPartitionInfos() {
         AllToAllBlockingResultInfo resultInfo =
-                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
+                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false, false);
         resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
+        resultInfo.aggregateSubpartitionBytes();
 
         assertThatThrownBy(resultInfo::getNumBytesProduced)
                 .isInstanceOf(IllegalStateException.class);
@@ -79,7 +81,7 @@ class AllToAllBlockingResultInfoTest {
     @Test
     void testRecordPartitionInfoMultiTimes() {
         AllToAllBlockingResultInfo resultInfo =
-                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
+                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false, false);
 
         ResultPartitionBytes partitionBytes1 = new ResultPartitionBytes(new long[] {32L, 64L});
         ResultPartitionBytes partitionBytes2 = new ResultPartitionBytes(new long[] {64L, 128L});
@@ -115,9 +117,15 @@ class AllToAllBlockingResultInfoTest {
         assertThat(resultInfo.getNumOfRecordedPartitions()).isZero();
     }
 
-    private void testGetNumBytesProduced(boolean isBroadcast, long expectedBytes) {
+    private void testGetNumBytesProduced(
+            boolean isBroadcast, boolean singleSubpartitionContainsAllData, long expectedBytes) {
         AllToAllBlockingResultInfo resultInfo =
-                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, isBroadcast);
+                new AllToAllBlockingResultInfo(
+                        new IntermediateDataSetID(),
+                        2,
+                        2,
+                        isBroadcast,
+                        singleSubpartitionContainsAllData);
         resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 32L}));
         resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {64L, 64L}));
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java
index 6049006d245..3ad14047b0e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java
@@ -24,16 +24,27 @@ import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobmaster.event.ExecutionJobVertexFinishedEvent;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.graph.StreamGraph;
+import org.apache.flink.streaming.api.graph.StreamGraphContext;
+import org.apache.flink.streaming.api.graph.StreamNode;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge;
+import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
+import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
+import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
 import org.apache.flink.testutils.TestingUtils;
 import org.apache.flink.testutils.executor.TestExecutorExtension;
+import org.apache.flink.util.DynamicCodeLoadingException;
 
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashSet;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Random;
+import java.util.Set;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -47,7 +58,7 @@ class DefaultAdaptiveExecutionHandlerTest {
             TestingUtils.defaultExecutorExtension();
 
     @Test
-    void testGetJobGraph() {
+    void testGetJobGraph() throws DynamicCodeLoadingException {
         JobGraph jobGraph = createAdaptiveExecutionHandler().getJobGraph();
 
         assertThat(jobGraph).isNotNull();
@@ -56,7 +67,7 @@ class DefaultAdaptiveExecutionHandlerTest {
     }
 
     @Test
-    void testHandleJobEvent() {
+    void testHandleJobEvent() throws DynamicCodeLoadingException {
         List<JobVertex> newAddedJobVertices = new ArrayList<>();
         AtomicInteger pendingOperators = new AtomicInteger();
 
@@ -95,7 +106,82 @@ class DefaultAdaptiveExecutionHandlerTest {
     }
 
     @Test
-    void testGetInitialParallelismAndNotifyJobVertexParallelismDecided() {
+    void testOptimizeStreamGraph() throws DynamicCodeLoadingException {
+        StreamGraph streamGraph = createStreamGraph();
+        StreamNode source =
+                streamGraph.getStreamNodes().stream()
+                        .filter(node -> node.getOperatorName().contains("Source"))
+                        .findFirst()
+                        .get();
+        StreamNode map =
+                streamGraph.getStreamNodes().stream()
+                        .filter(node -> node.getOperatorName().contains("Map"))
+                        .findFirst()
+                        .get();
+
+        assertThat(source.getOutEdges().get(0).getPartitioner())
+                .isInstanceOf(ForwardPartitioner.class);
+        assertThat(map.getOutEdges().get(0).getPartitioner())
+                .isInstanceOf(RescalePartitioner.class);
+
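+        // Register the testing optimization strategy and mark both output edges for conversion to rebalance.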
+        streamGraph
+                .getJobConfiguration()
+                .set(
+                        StreamGraphOptimizationStrategy.STREAM_GRAPH_OPTIMIZATION_STRATEGY,
+                        Collections.singletonList(
+                                TestingStreamGraphOptimizerStrategy.class.getName()));
+        TestingStreamGraphOptimizerStrategy.convertToReBalanceEdgeIds.add(
+                source.getOutEdges().get(0).getEdgeId());
+        TestingStreamGraphOptimizerStrategy.convertToReBalanceEdgeIds.add(
+                map.getOutEdges().get(0).getEdgeId());
+
+        DefaultAdaptiveExecutionHandler handler =
+                createAdaptiveExecutionHandler(
+                        (newVertices, pendingOperatorsCount) -> {}, streamGraph);
+
+        JobGraph jobGraph = handler.getJobGraph();
+        JobVertex sourceVertex = jobGraph.getVertices().iterator().next();
+
+        // notify Source node is finished
+        ExecutionJobVertexFinishedEvent event1 =
+                new ExecutionJobVertexFinishedEvent(sourceVertex.getID(), Collections.emptyMap());
+        handler.handleJobEvent(event1);
+
+        // verify that the source output edge is not updated because the original edge is forward.
+        assertThat(sourceVertex.getProducedDataSets().get(0).getConsumers()).hasSize(1);
+        assertThat(
+                        sourceVertex
+                                .getProducedDataSets()
+                                .get(0)
+                                .getConsumers()
+                                .get(0)
+                                .getShipStrategyName())
+                .isEqualToIgnoringCase("forward");
+
+        // notify Map node is finished
+        Iterator<JobVertex> jobVertexIterator = jobGraph.getVertices().iterator();
+        jobVertexIterator.next();
+        JobVertex mapVertex = jobVertexIterator.next();
+
+        ExecutionJobVertexFinishedEvent event2 =
+                new ExecutionJobVertexFinishedEvent(mapVertex.getID(), Collections.emptyMap());
+        handler.handleJobEvent(event2);
+
+        // verify that the map output edge is updated to rebalance.
+        assertThat(mapVertex.getProducedDataSets().get(0).getConsumers()).hasSize(1);
+        assertThat(
+                        mapVertex
+                                .getProducedDataSets()
+                                .get(0)
+                                .getConsumers()
+                                .get(0)
+                                .getShipStrategyName())
+                .isEqualToIgnoringCase("rebalance");
+    }
+
+    @Test
+    void testGetInitialParallelismAndNotifyJobVertexParallelismDecided()
+            throws DynamicCodeLoadingException {
         StreamGraph streamGraph = createStreamGraph();
         DefaultAdaptiveExecutionHandler handler =
                 createAdaptiveExecutionHandler(
@@ -123,7 +209,8 @@ class DefaultAdaptiveExecutionHandlerTest {
         assertThat(handler.getInitialParallelism(map.getID())).isEqualTo(parallelism);
     }
 
-    private DefaultAdaptiveExecutionHandler createAdaptiveExecutionHandler() {
+    private DefaultAdaptiveExecutionHandler createAdaptiveExecutionHandler()
+            throws DynamicCodeLoadingException {
         return createAdaptiveExecutionHandler(
                 (newVertices, pendingOperatorsCount) -> {}, createStreamGraph());
     }
@@ -159,7 +246,8 @@ class DefaultAdaptiveExecutionHandlerTest {
      * and a given {@link StreamGraph}.
      */
     private DefaultAdaptiveExecutionHandler createAdaptiveExecutionHandler(
-            JobGraphUpdateListener listener, StreamGraph streamGraph) {
+            JobGraphUpdateListener listener, StreamGraph streamGraph)
+            throws DynamicCodeLoadingException {
         DefaultAdaptiveExecutionHandler handler =
                 new DefaultAdaptiveExecutionHandler(
                         getClass().getClassLoader(), streamGraph, EXECUTOR_RESOURCE.getExecutor());
@@ -167,4 +255,34 @@ class DefaultAdaptiveExecutionHandlerTest {
 
         return handler;
     }
+
+    public static final class TestingStreamGraphOptimizerStrategy
+            implements StreamGraphOptimizationStrategy {
+
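+        // Edge ids whose output partitioner should be rewritten to rebalance once their producer finishes.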
+        private static final Set<String> convertToReBalanceEdgeIds = new HashSet<>();
+
+        @Override
+        public boolean onOperatorsFinished(
+                OperatorsFinished operatorsFinished, StreamGraphContext context) {
+            List<Integer> finishedStreamNodeIds = operatorsFinished.getFinishedStreamNodeIds();
+            List<StreamEdgeUpdateRequestInfo> requestInfos = new ArrayList<>();
+            for (Integer finishedStreamNodeId : finishedStreamNodeIds) {
+                for (ImmutableStreamEdge outEdge :
+                        context.getStreamGraph()
+                                .getStreamNode(finishedStreamNodeId)
+                                .getOutEdges()) {
+                    if (convertToReBalanceEdgeIds.contains(outEdge.getEdgeId())) {
+                        StreamEdgeUpdateRequestInfo requestInfo =
+                                new StreamEdgeUpdateRequestInfo(
+                                        outEdge.getEdgeId(),
+                                        outEdge.getSourceId(),
+                                        outEdge.getTargetId());
+                        requestInfo.outputPartitioner(new RebalancePartitioner<>());
+                        requestInfos.add(requestInfo);
+                    }
+                }
+            }
+            return context.modifyStreamEdge(requestInfos);
+        }
+    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java
index 23c70f317bd..c1ea24e43aa 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java
@@ -32,11 +32,14 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.shaded.guava32.com.google.common.collect.Iterables;
 
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -104,8 +107,9 @@ class DefaultVertexParallelismAndInputInfosDeciderTest {
 
     @Test
     void testDecideParallelismWithMaxSubpartitionLimitation() {
-        BlockingResultInfo resultInfo1 = new TestingBlockingResultInfo(false, 1L, 1024, 1024);
-        BlockingResultInfo resultInfo2 = new TestingBlockingResultInfo(false, 1L, 512, 512);
+        BlockingResultInfo resultInfo1 =
+                new TestingBlockingResultInfo(false, false, 1L, 1024, 1024);
+        BlockingResultInfo resultInfo2 = new TestingBlockingResultInfo(false, false, 1L, 512, 512);
 
         int parallelism =
                 createDeciderAndDecideParallelism(
@@ -206,13 +210,19 @@ class DefaultVertexParallelismAndInputInfosDeciderTest {
                         new IndexRange(8, 9)));
     }
 
-    @Test
-    void testAllEdgesAllToAllAndOneIsBroadcast() {
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void testAllEdgesAllToAllAndOneIsBroadcast(boolean singleSubpartitionContainsAllData) {
         AllToAllBlockingResultInfo resultInfo1 =
                 createAllToAllBlockingResultInfo(
-                        new long[] {10L, 15L, 13L, 12L, 1L, 10L, 8L, 20L, 12L, 17L});
+                        new long[] {10L, 15L, 13L, 12L, 1L, 10L, 8L, 20L, 12L, 17L}, false, false);
         AllToAllBlockingResultInfo resultInfo2 =
-                createAllToAllBlockingResultInfo(new long[] {10L}, true);
+                createAllToAllBlockingResultInfo(
+                        singleSubpartitionContainsAllData
+                                ? new long[] {10L}
+                                : new long[] {1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L},
+                        true,
+                        singleSubpartitionContainsAllData);
 
         ParallelismAndInputInfos parallelismAndInputInfos =
                 createDeciderAndDecideParallelismAndInputInfos(
@@ -224,17 +234,30 @@ class DefaultVertexParallelismAndInputInfosDeciderTest {
         checkAllToAllJobVertexInputInfo(
                 parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()),
                 Arrays.asList(new IndexRange(0, 4), new IndexRange(5, 8), new IndexRange(9, 9)));
-        checkAllToAllJobVertexInputInfo(
-                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()),
-                Arrays.asList(new IndexRange(0, 0), new IndexRange(0, 0), new IndexRange(0, 0)));
+        if (singleSubpartitionContainsAllData) {
+            checkAllToAllJobVertexInputInfo(
+                    parallelismAndInputInfos
+                            .getJobVertexInputInfos()
+                            .get(resultInfo2.getResultId()),
+                    Arrays.asList(
+                            new IndexRange(0, 0), new IndexRange(0, 0), new IndexRange(0, 0)));
+        } else {
+            checkAllToAllJobVertexInputInfo(
+                    parallelismAndInputInfos
+                            .getJobVertexInputInfos()
+                            .get(resultInfo2.getResultId()),
+                    Arrays.asList(
+                            new IndexRange(0, 9), new IndexRange(0, 9), new IndexRange(0, 9)));
+        }
     }
 
     @Test
     void testAllEdgesBroadcast() {
-        AllToAllBlockingResultInfo resultInfo1 =
-                createAllToAllBlockingResultInfo(new long[] {10L}, true);
-        AllToAllBlockingResultInfo resultInfo2 =
-                createAllToAllBlockingResultInfo(new long[] {10L}, true);
+        AllToAllBlockingResultInfo resultInfo1 =
+                createAllToAllBlockingResultInfo(new long[] {10L}, true, false);
+        AllToAllBlockingResultInfo resultInfo2 =
+                createAllToAllBlockingResultInfo(new long[] {10L}, true, false);
+
         ParallelismAndInputInfos parallelismAndInputInfos =
                 createDeciderAndDecideParallelismAndInputInfos(
                         1, 10, 60L, Arrays.asList(resultInfo1, resultInfo2));
@@ -242,12 +265,15 @@ class DefaultVertexParallelismAndInputInfosDeciderTest {
         assertThat(parallelismAndInputInfos.getParallelism()).isOne();
         assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(2);
 
+        List<IndexRange> expectedSubpartitionRanges =
+                Collections.singletonList(new IndexRange(0, 0));
+
         checkAllToAllJobVertexInputInfo(
                 parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()),
-                Collections.singletonList(new IndexRange(0, 0)));
+                expectedSubpartitionRanges);
         checkAllToAllJobVertexInputInfo(
                 parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()),
-                Collections.singletonList(new IndexRange(0, 0)));
+                expectedSubpartitionRanges);
     }
 
     @Test
@@ -359,7 +385,8 @@ class DefaultVertexParallelismAndInputInfosDeciderTest {
         long[] subpartitionBytes = new long[1024];
         Arrays.fill(subpartitionBytes, 1L);
         AllToAllBlockingResultInfo resultInfo =
-                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 1024, 1024, false);
+                new AllToAllBlockingResultInfo(
+                        new IntermediateDataSetID(), 1024, 1024, false, false);
         for (int i = 0; i < 1024; ++i) {
             resultInfo.recordPartitionInfo(i, new ResultPartitionBytes(subpartitionBytes));
         }
@@ -507,11 +534,13 @@ class DefaultVertexParallelismAndInputInfosDeciderTest {
 
     private AllToAllBlockingResultInfo createAllToAllBlockingResultInfo(
             long[] aggregatedSubpartitionBytes) {
-        return createAllToAllBlockingResultInfo(aggregatedSubpartitionBytes, false);
+        return createAllToAllBlockingResultInfo(aggregatedSubpartitionBytes, false, false);
     }
 
     private AllToAllBlockingResultInfo createAllToAllBlockingResultInfo(
-            long[] aggregatedSubpartitionBytes, boolean isBroadcast) {
+            long[] aggregatedSubpartitionBytes,
+            boolean isBroadcast,
+            boolean isSingleSubpartitionContainsAllData) {
         // For simplicity, we configure only one partition here, so the aggregatedSubpartitionBytes
         // is equivalent to the subpartition bytes of partition0
         AllToAllBlockingResultInfo resultInfo =
@@ -519,7 +548,8 @@ class DefaultVertexParallelismAndInputInfosDeciderTest {
                         new IntermediateDataSetID(),
                         1,
                         aggregatedSubpartitionBytes.length,
-                        isBroadcast);
+                        isBroadcast,
+                        isSingleSubpartitionContainsAllData);
         resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(aggregatedSubpartitionBytes));
         return resultInfo;
     }
@@ -552,17 +582,31 @@ class DefaultVertexParallelismAndInputInfosDeciderTest {
     private static class TestingBlockingResultInfo implements BlockingResultInfo {
 
         private final boolean isBroadcast;
+        private final boolean singleSubpartitionContainsAllData;
         private final long producedBytes;
         private final int numPartitions;
         private final int numSubpartitions;
 
-        private TestingBlockingResultInfo(boolean isBroadcast, long producedBytes) {
-            this(isBroadcast, producedBytes, MAX_PARALLELISM, MAX_PARALLELISM);
+        private TestingBlockingResultInfo(
+                boolean isBroadcast,
+                boolean singleSubpartitionContainsAllData,
+                long producedBytes) {
+            this(
+                    isBroadcast,
+                    singleSubpartitionContainsAllData,
+                    producedBytes,
+                    MAX_PARALLELISM,
+                    MAX_PARALLELISM);
         }
 
         private TestingBlockingResultInfo(
-                boolean isBroadcast, long producedBytes, int numPartitions, int numSubpartitions) {
+                boolean isBroadcast,
+                boolean singleSubpartitionContainsAllData,
+                long producedBytes,
+                int numPartitions,
+                int numSubpartitions) {
             this.isBroadcast = isBroadcast;
+            this.singleSubpartitionContainsAllData = singleSubpartitionContainsAllData;
             this.producedBytes = producedBytes;
             this.numPartitions = numPartitions;
             this.numSubpartitions = numSubpartitions;
@@ -578,6 +622,11 @@ class DefaultVertexParallelismAndInputInfosDeciderTest {
             return isBroadcast;
         }
 
+        @Override
+        public boolean isSingleSubpartitionContainsAllData() {
+            return singleSubpartitionContainsAllData;
+        }
+
         @Override
         public boolean isPointwise() {
             return false;
@@ -617,10 +666,10 @@ class DefaultVertexParallelismAndInputInfosDeciderTest {
     }
 
     private static BlockingResultInfo createFromBroadcastResult(long producedBytes) {
-        return new TestingBlockingResultInfo(true, producedBytes);
+        return new TestingBlockingResultInfo(true, true, producedBytes);
     }
 
     private static BlockingResultInfo createFromNonBroadcastResult(long producedBytes) {
-        return new TestingBlockingResultInfo(false, producedBytes);
+        return new TestingBlockingResultInfo(false, false, producedBytes);
     }
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java
index 468f6883b73..0be8c246893 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java
@@ -52,7 +52,8 @@ class DefaultStreamGraphContextTest {
                         streamGraph,
                         forwardGroupsByEndpointNodeIdCache,
                         frozenNodeToStartNodeMap,
-                        opIntermediateOutputsCaches);
+                        opIntermediateOutputsCaches,
+                        new HashMap<>());
 
         StreamNode sourceNode =
                 streamGraph.getStreamNode(streamGraph.getSourceIDs().iterator().next());
@@ -136,7 +137,8 @@ class DefaultStreamGraphContextTest {
                         streamGraph,
                         forwardGroupsByEndpointNodeIdCache,
                         frozenNodeToStartNodeMap,
-                        opIntermediateOutputsCaches);
+                        opIntermediateOutputsCaches,
+                        new HashMap<>());
 
         StreamNode sourceNode =
                 streamGraph.getStreamNode(streamGraph.getSourceIDs().iterator().next());
diff --git a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java
index 6654c1eb17e..21b9a7645a9 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java
@@ -31,8 +31,18 @@ import org.apache.flink.configuration.MemorySize;
 import org.apache.flink.configuration.RestOptions;
 import org.apache.flink.configuration.TaskManagerOptions;
 import org.apache.flink.runtime.scheduler.adaptivebatch.AdaptiveBatchScheduler;
+import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished;
+import org.apache.flink.runtime.scheduler.adaptivebatch.StreamGraphOptimizationStrategy;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.graph.StreamGraph;
+import org.apache.flink.streaming.api.graph.StreamGraphContext;
+import org.apache.flink.streaming.api.graph.StreamNode;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge;
+import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
+import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -40,8 +50,10 @@ import org.junit.jupiter.api.Test;
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -130,6 +142,72 @@ class AdaptiveBatchSchedulerITCase {
         env.execute();
     }
 
+    @Test
+    void testAdaptiveOptimizeStreamGraph() throws Exception {
+        final Configuration configuration = createConfiguration();
+        configuration.set(
+                StreamGraphOptimizationStrategy.STREAM_GRAPH_OPTIMIZATION_STRATEGY,
+                List.of(TestingStreamGraphOptimizerStrategy.class.getName()));
+        final StreamExecutionEnvironment env =
+                StreamExecutionEnvironment.getExecutionEnvironment(configuration);
+        env.setRuntimeMode(RuntimeExecutionMode.BATCH);
+        env.disableOperatorChaining();
+        env.setParallelism(8);
+
+        SingleOutputStreamOperator<Long> source1 =
+                env.fromSequence(0, NUMBERS_TO_PRODUCE - 1)
+                        .setParallelism(SOURCE_PARALLELISM_1)
+                        .name("source1");
+        SingleOutputStreamOperator<Long> source2 =
+                env.fromSequence(0, NUMBERS_TO_PRODUCE - 1)
+                        .setParallelism(SOURCE_PARALLELISM_2)
+                        .name("source2");
+
+        source1.keyBy(i -> i % SOURCE_PARALLELISM_1)
+                .map(i -> i)
+                .name("map1")
+                .rebalance()
+                .union(source2)
+                .rebalance()
+                .map(new NumberCounter())
+                .name("map2")
+                .setParallelism(1);
+
+        StreamGraph streamGraph = env.getStreamGraph();
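+        // Look up source1 and map1 in the stream graph so the strategy can target their output edges.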
+        StreamNode sourceNode1 =
+                streamGraph.getStreamNodes().stream()
+                        .filter(node -> node.getOperatorName().contains("source1"))
+                        .findFirst()
+                        .get();
+        StreamNode mapNode1 =
+                streamGraph.getStreamNodes().stream()
+                        .filter(node -> node.getOperatorName().contains("map1"))
+                        .findFirst()
+                        .get();
+
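+        // Once the producers finish, the strategy rewrites source1's output edge to rescale and
+        // map1's output edge to broadcast.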
+        TestingStreamGraphOptimizerStrategy.convertToRescaleEdgeIds.add(
+                sourceNode1.getOutEdges().get(0).getEdgeId());
+        TestingStreamGraphOptimizerStrategy.convertToBroadcastEdgeIds.add(
+                mapNode1.getOutEdges().get(0).getEdgeId());
+
+        env.execute(streamGraph);
+
+        Map<Long, Long> numberCountResultMap =
+                numberCountResults.stream()
+                        .flatMap(map -> map.entrySet().stream())
+                        .collect(
+                                Collectors.toMap(
+                                        Map.Entry::getKey, Map.Entry::getValue, Long::sum));
+
+        // Each number is counted twice: once from source2 and once from the broadcast output of
+        // map1 (fed by source1).
+        Map<Long, Long> expectedResult =
+                LongStream.range(0, NUMBERS_TO_PRODUCE)
+                        .boxed()
+                        .collect(Collectors.toMap(Function.identity(), i -> 2L));
+        assertThat(numberCountResultMap).isEqualTo(expectedResult);
+    }
+
     private void testSchedulingBase(Boolean useSourceParallelismInference) throws Exception {
         executeJob(useSourceParallelismInference);
 
@@ -257,4 +335,38 @@ class AdaptiveBatchSchedulerITCase {
             return expectedParallelism;
         }
     }
+
+    public static final class TestingStreamGraphOptimizerStrategy
+            implements StreamGraphOptimizationStrategy {
+
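+        // Edge ids whose output partitioner should be rewritten to broadcast or rescale once their
+        // producer finishes.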
+        private static final Set<String> convertToBroadcastEdgeIds = new HashSet<>();
+        private static final Set<String> convertToRescaleEdgeIds = new HashSet<>();
+
+        @Override
+        public boolean onOperatorsFinished(
+                OperatorsFinished operatorsFinished, StreamGraphContext context) throws Exception {
+            List<Integer> finishedStreamNodeIds = operatorsFinished.getFinishedStreamNodeIds();
+            List<StreamEdgeUpdateRequestInfo> requestInfos = new ArrayList<>();
+            for (Integer finishedStreamNodeId : finishedStreamNodeIds) {
+                for (ImmutableStreamEdge outEdge :
+                        context.getStreamGraph()
+                                .getStreamNode(finishedStreamNodeId)
+                                .getOutEdges()) {
+                    StreamEdgeUpdateRequestInfo requestInfo =
+                            new StreamEdgeUpdateRequestInfo(
+                                    outEdge.getEdgeId(),
+                                    outEdge.getSourceId(),
+                                    outEdge.getTargetId());
+                    if (convertToBroadcastEdgeIds.contains(outEdge.getEdgeId())) {
+                        requestInfo.outputPartitioner(new BroadcastPartitioner<>());
+                        requestInfos.add(requestInfo);
+                    } else if (convertToRescaleEdgeIds.contains(outEdge.getEdgeId())) {
+                        requestInfo.outputPartitioner(new RescalePartitioner<>());
+                        requestInfos.add(requestInfo);
+                    }
+                }
+            }
+            return context.modifyStreamEdge(requestInfos);
+        }
+    }
 }

