This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit a1555b0c53493fa80d0cf4b933396941c7de7073
Author: noorall <[email protected]>
AuthorDate: Fri Oct 25 17:50:55 2024 +0800

    [FLINK-36066][runtime] Introduce StreamGraphContext to provide the ability 
to modify the StreamGraph
---
 .../jobgraph/forwardgroup/ForwardGroup.java        | 135 ++++-------
 .../forwardgroup/ForwardGroupComputeUtil.java      | 114 +++++++--
 ...orwardGroup.java => JobVertexForwardGroup.java} |  34 ++-
 .../forwardgroup/StreamNodeForwardGroup.java       | 174 +++++++++++++
 .../adaptivebatch/AdaptiveBatchScheduler.java      |  11 +-
 .../AdaptiveBatchSchedulerFactory.java             |   4 +-
 .../api/graph/DefaultStreamGraphContext.java       | 270 +++++++++++++++++++++
 .../streaming/api/graph/StreamGraphContext.java    |  65 +++++
 .../flink/streaming/api/graph/StreamNode.java      |   4 +-
 .../api/graph/StreamingJobGraphGenerator.java      |   3 +-
 .../graph/util/StreamEdgeUpdateRequestInfo.java    |  59 +++++
 .../forwardgroup/ForwardGroupComputeUtilTest.java  | 163 ++++++++++++-
 .../forwardgroup/StreamNodeForwardGroupTest.java   | 134 ++++++++++
 .../api/graph/DefaultStreamGraphContextTest.java   | 217 +++++++++++++++++
 14 files changed, 1254 insertions(+), 133 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java
index 922acf231e2..0f81075f95d 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java
@@ -19,90 +19,61 @@
 
 package org.apache.flink.runtime.jobgraph.forwardgroup;
 
-import org.apache.flink.api.common.ExecutionConfig;
-import org.apache.flink.runtime.jobgraph.JobVertex;
-import org.apache.flink.runtime.jobgraph.JobVertexID;
-
-import java.util.Collections;
-import java.util.HashSet;
 import java.util.Set;
-import java.util.stream.Collectors;
-
-import static org.apache.flink.util.Preconditions.checkNotNull;
-import static org.apache.flink.util.Preconditions.checkState;
 
 /**
- * A forward group is a set of job vertices connected via forward edges. 
Parallelisms of all job
- * vertices in the same {@link ForwardGroup} must be the same.
+ * A forward group is a set of job vertices or stream nodes connected via 
forward edges.
+ * Parallelisms of all job vertices or stream nodes in the same {@link 
ForwardGroup} must be the
+ * same.
  */
-public class ForwardGroup {
-
-    private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT;
-
-    private int maxParallelism = JobVertex.MAX_PARALLELISM_DEFAULT;
-    private final Set<JobVertexID> jobVertexIds = new HashSet<>();
-
-    public ForwardGroup(final Set<JobVertex> jobVertices) {
-        checkNotNull(jobVertices);
-
-        Set<Integer> configuredParallelisms =
-                jobVertices.stream()
-                        .filter(
-                                jobVertex -> {
-                                    jobVertexIds.add(jobVertex.getID());
-                                    return jobVertex.getParallelism() > 0;
-                                })
-                        .map(JobVertex::getParallelism)
-                        .collect(Collectors.toSet());
-
-        checkState(configuredParallelisms.size() <= 1);
-        if (configuredParallelisms.size() == 1) {
-            this.parallelism = configuredParallelisms.iterator().next();
-        }
-
-        Set<Integer> configuredMaxParallelisms =
-                jobVertices.stream()
-                        .map(JobVertex::getMaxParallelism)
-                        .filter(val -> val > 0)
-                        .collect(Collectors.toSet());
-
-        if (!configuredMaxParallelisms.isEmpty()) {
-            this.maxParallelism = Collections.min(configuredMaxParallelisms);
-            checkState(
-                    parallelism == ExecutionConfig.PARALLELISM_DEFAULT
-                            || maxParallelism >= parallelism,
-                    "There is a job vertex in the forward group whose maximum 
parallelism is smaller than the group's parallelism");
-        }
-    }
-
-    public void setParallelism(int parallelism) {
-        checkState(this.parallelism == ExecutionConfig.PARALLELISM_DEFAULT);
-        this.parallelism = parallelism;
-    }
-
-    public boolean isParallelismDecided() {
-        return parallelism > 0;
-    }
-
-    public int getParallelism() {
-        checkState(isParallelismDecided());
-        return parallelism;
-    }
-
-    public boolean isMaxParallelismDecided() {
-        return maxParallelism > 0;
-    }
-
-    public int getMaxParallelism() {
-        checkState(isMaxParallelismDecided());
-        return maxParallelism;
-    }
-
-    public int size() {
-        return jobVertexIds.size();
-    }
-
-    public Set<JobVertexID> getJobVertexIds() {
-        return jobVertexIds;
-    }
+public interface ForwardGroup<T> {
+
+    /**
+     * Sets the parallelism for this forward group.
+     *
+     * @param parallelism the parallelism to set.
+     */
+    void setParallelism(int parallelism);
+
+    /**
+     * Returns if parallelism has been decided for this forward group.
+     *
+     * @return is parallelism decided for this forward group.
+     */
+    boolean isParallelismDecided();
+
+    /**
+     * Returns the parallelism for this forward group.
+     *
+     * @return parallelism for this forward group.
+     */
+    int getParallelism();
+
+    /**
+     * Sets the max parallelism for this forward group.
+     *
+     * @param maxParallelism the max parallelism to set.
+     */
+    void setMaxParallelism(int maxParallelism);
+
+    /**
+     * Returns if max parallelism has been decided for this forward group.
+     *
+     * @return is max parallelism decided for this forward group.
+     */
+    boolean isMaxParallelismDecided();
+
+    /**
+     * Returns the max parallelism for this forward group.
+     *
+     * @return max parallelism for this forward group.
+     */
+    int getMaxParallelism();
+
+    /**
+     * Returns the vertex ids in this forward group.
+     *
+     * @return vertex ids in this forward group.
+     */
+    Set<T> getVertexIds();
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java
index dc2dc702e34..0db296b62ac 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java
@@ -24,6 +24,7 @@ import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
 import org.apache.flink.runtime.jobgraph.JobEdge;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.streaming.api.graph.StreamNode;
 
 import java.util.HashMap;
 import java.util.HashSet;
@@ -38,16 +39,17 @@ import static 
org.apache.flink.util.Preconditions.checkState;
 /** Common utils for computing forward groups. */
 public class ForwardGroupComputeUtil {
 
-    public static Map<JobVertexID, ForwardGroup> 
computeForwardGroupsAndCheckParallelism(
+    public static Map<JobVertexID, JobVertexForwardGroup> 
computeForwardGroupsAndCheckParallelism(
             final Iterable<JobVertex> topologicallySortedVertices) {
-        final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId =
+        final Map<JobVertexID, JobVertexForwardGroup> 
forwardGroupsByJobVertexId =
                 computeForwardGroups(
                         topologicallySortedVertices, 
ForwardGroupComputeUtil::getForwardProducers);
         // the vertex's parallelism in parallelism-decided forward group 
should have been set at
         // compilation phase
         topologicallySortedVertices.forEach(
                 jobVertex -> {
-                    ForwardGroup forwardGroup = 
forwardGroupsByJobVertexId.get(jobVertex.getID());
+                    JobVertexForwardGroup forwardGroup =
+                            forwardGroupsByJobVertexId.get(jobVertex.getID());
                     if (forwardGroup != null && 
forwardGroup.isParallelismDecided()) {
                         checkState(jobVertex.getParallelism() == 
forwardGroup.getParallelism());
                     }
@@ -55,28 +57,73 @@ public class ForwardGroupComputeUtil {
         return forwardGroupsByJobVertexId;
     }
 
-    public static Map<JobVertexID, ForwardGroup> computeForwardGroups(
+    public static Map<JobVertexID, JobVertexForwardGroup> computeForwardGroups(
             final Iterable<JobVertex> topologicallySortedVertices,
             final Function<JobVertex, Set<JobVertex>> 
forwardProducersRetriever) {
 
-        final Map<JobVertex, Set<JobVertex>> vertexToGroup = new 
IdentityHashMap<>();
+        final Map<JobVertex, Set<JobVertex>> vertexToGroup =
+                computeVertexToGroup(topologicallySortedVertices, 
forwardProducersRetriever);
+
+        final Map<JobVertexID, JobVertexForwardGroup> ret = new HashMap<>();
+        for (Set<JobVertex> vertexGroup :
+                VertexGroupComputeUtil.uniqueVertexGroups(vertexToGroup)) {
+            if (vertexGroup.size() > 1) {
+                JobVertexForwardGroup forwardGroup = new 
JobVertexForwardGroup(vertexGroup);
+                for (JobVertexID jobVertexId : forwardGroup.getVertexIds()) {
+                    ret.put(jobVertexId, forwardGroup);
+                }
+            }
+        }
+        return ret;
+    }
+
+    /**
+     * Calculates the forward groups for a set of stream nodes.
+     *
+     * @param topologicallySortedStreamNodes topologically sorted chained 
stream nodes
+     * @param forwardProducersRetriever records all upstream stream nodes 
which connected to the
+     *     given stream node with forward edge
+     * @return a map of forward groups, with the stream node id as the key
+     */
+    public static Map<Integer, StreamNodeForwardGroup> 
computeStreamNodeForwardGroup(
+            final Iterable<StreamNode> topologicallySortedStreamNodes,
+            final Function<StreamNode, Set<StreamNode>> 
forwardProducersRetriever) {
+        // In the forwardProducersRetriever, only the upstream nodes connected 
to the given start
+        // node by the forward edge are saved. We need to calculate the chain 
groups that can be
+        // accessed with consecutive forward edges and put them in the same 
forward group.
+        final Map<StreamNode, Set<StreamNode>> nodeToGroup =
+                computeVertexToGroup(topologicallySortedStreamNodes, 
forwardProducersRetriever);
+        final Map<Integer, StreamNodeForwardGroup> ret = new HashMap<>();
+        for (Set<StreamNode> nodeGroup : 
VertexGroupComputeUtil.uniqueVertexGroups(nodeToGroup)) {
+            StreamNodeForwardGroup forwardGroup = new 
StreamNodeForwardGroup(nodeGroup);
+            for (Integer vertexId : forwardGroup.getVertexIds()) {
+                ret.put(vertexId, forwardGroup);
+            }
+        }
+        return ret;
+    }
+
+    private static <T> Map<T, Set<T>> computeVertexToGroup(
+            final Iterable<T> topologicallySortedVertices,
+            final Function<T, Set<T>> forwardProducersRetriever) {
+        final Map<T, Set<T>> vertexToGroup = new IdentityHashMap<>();
 
         // iterate all the vertices which are topologically sorted
-        for (JobVertex vertex : topologicallySortedVertices) {
-            Set<JobVertex> currentGroup = new HashSet<>();
+        for (T vertex : topologicallySortedVertices) {
+            Set<T> currentGroup = new HashSet<>();
             currentGroup.add(vertex);
             vertexToGroup.put(vertex, currentGroup);
 
-            for (JobVertex producerVertex : 
forwardProducersRetriever.apply(vertex)) {
-                final Set<JobVertex> producerGroup = 
vertexToGroup.get(producerVertex);
+            for (T producerVertex : forwardProducersRetriever.apply(vertex)) {
+                final Set<T> producerGroup = vertexToGroup.get(producerVertex);
 
                 if (producerGroup == null) {
                     throw new IllegalStateException(
                             "Producer task "
-                                    + producerVertex.getID()
+                                    + producerVertex
                                     + " forward group is null"
                                     + " while calculating forward group for 
the consumer task "
-                                    + vertex.getID()
+                                    + vertex
                                     + ". This should be a forward group 
building bug.");
                 }
 
@@ -87,18 +134,43 @@ public class ForwardGroupComputeUtil {
                 }
             }
         }
+        return vertexToGroup;
+    }
 
-        final Map<JobVertexID, ForwardGroup> ret = new HashMap<>();
-        for (Set<JobVertex> vertexGroup :
-                VertexGroupComputeUtil.uniqueVertexGroups(vertexToGroup)) {
-            if (vertexGroup.size() > 1) {
-                ForwardGroup forwardGroup = new ForwardGroup(vertexGroup);
-                for (JobVertexID jobVertexId : forwardGroup.getJobVertexIds()) 
{
-                    ret.put(jobVertexId, forwardGroup);
-                }
-            }
+    /**
+     * Determines whether the target forward group can be merged into the 
source forward group.
+     *
+     * @param sourceForwardGroup The source forward group.
+     * @param forwardGroupToMerge The forward group that needs to be merged.
+     * @return whether the merge is valid.
+     */
+    public static boolean canTargetMergeIntoSourceForwardGroup(
+            ForwardGroup<?> sourceForwardGroup, ForwardGroup<?> 
forwardGroupToMerge) {
+        if (sourceForwardGroup == null || forwardGroupToMerge == null) {
+            return false;
         }
-        return ret;
+
+        if (sourceForwardGroup == forwardGroupToMerge) {
+            return true;
+        }
+
+        if (sourceForwardGroup.isParallelismDecided()
+                && forwardGroupToMerge.isParallelismDecided()
+                && sourceForwardGroup.getParallelism() != 
forwardGroupToMerge.getParallelism()) {
+            return false;
+        }
+
+        // When the parallelism of source forward groups is determined, the 
maximum
+        // parallelism of the forwardGroupToMerge should not be less than the 
parallelism of the
+        // sourceForwardGroup to ensure the forwardGroupToMerge can also 
achieve the same
+        // parallelism.
+        if (sourceForwardGroup.isParallelismDecided()
+                && forwardGroupToMerge.isMaxParallelismDecided()
+                && sourceForwardGroup.getParallelism() > 
forwardGroupToMerge.getMaxParallelism()) {
+            return false;
+        }
+
+        return true;
     }
 
     static Set<JobVertex> getForwardProducers(final JobVertex jobVertex) {
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/JobVertexForwardGroup.java
similarity index 81%
copy from 
flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java
copy to 
flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/JobVertexForwardGroup.java
index 922acf231e2..d9ad77a61de 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/JobVertexForwardGroup.java
@@ -19,6 +19,7 @@
 
 package org.apache.flink.runtime.jobgraph.forwardgroup;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
@@ -31,18 +32,15 @@ import java.util.stream.Collectors;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
-/**
- * A forward group is a set of job vertices connected via forward edges. 
Parallelisms of all job
- * vertices in the same {@link ForwardGroup} must be the same.
- */
-public class ForwardGroup {
+/** Job vertex level implementation of {@link ForwardGroup}. */
+public class JobVertexForwardGroup implements ForwardGroup<JobVertexID> {
 
     private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT;
 
     private int maxParallelism = JobVertex.MAX_PARALLELISM_DEFAULT;
     private final Set<JobVertexID> jobVertexIds = new HashSet<>();
 
-    public ForwardGroup(final Set<JobVertex> jobVertices) {
+    public JobVertexForwardGroup(final Set<JobVertex> jobVertices) {
         checkNotNull(jobVertices);
 
         Set<Integer> configuredParallelisms =
@@ -75,34 +73,50 @@ public class ForwardGroup {
         }
     }
 
+    @Override
     public void setParallelism(int parallelism) {
         checkState(this.parallelism == ExecutionConfig.PARALLELISM_DEFAULT);
         this.parallelism = parallelism;
     }
 
+    @Override
     public boolean isParallelismDecided() {
         return parallelism > 0;
     }
 
+    @Override
     public int getParallelism() {
         checkState(isParallelismDecided());
         return parallelism;
     }
 
+    @Override
+    public void setMaxParallelism(int maxParallelism) {
+        checkState(
+                maxParallelism == ExecutionConfig.PARALLELISM_DEFAULT
+                        || maxParallelism >= parallelism,
+                "There is a job vertex in the forward group whose maximum 
parallelism is smaller than the group's parallelism");
+        this.maxParallelism = maxParallelism;
+    }
+
+    @Override
     public boolean isMaxParallelismDecided() {
         return maxParallelism > 0;
     }
 
+    @Override
     public int getMaxParallelism() {
         checkState(isMaxParallelismDecided());
         return maxParallelism;
     }
 
-    public int size() {
-        return jobVertexIds.size();
+    @Override
+    public Set<JobVertexID> getVertexIds() {
+        return jobVertexIds;
     }
 
-    public Set<JobVertexID> getJobVertexIds() {
-        return jobVertexIds;
+    @VisibleForTesting
+    public int size() {
+        return jobVertexIds.size();
     }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroup.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroup.java
new file mode 100644
index 00000000000..af37d4262f7
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroup.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.jobgraph.forwardgroup;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.streaming.api.graph.StreamNode;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Stream node level implementation of {@link ForwardGroup}. */
+public class StreamNodeForwardGroup implements ForwardGroup<Integer> {
+
+    private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT;
+
+    private int maxParallelism = JobVertex.MAX_PARALLELISM_DEFAULT;
+
+    private final Set<StreamNode> streamNodes = new HashSet<>();
+
+    // For a group of chained stream nodes, their parallelism is consistent. 
In order to make
+    // calculation and usage easier, we only use the start node to calculate 
forward group.
+    public StreamNodeForwardGroup(final Set<StreamNode> streamNodes) {
+        checkNotNull(streamNodes);
+
+        Set<Integer> configuredParallelisms =
+                streamNodes.stream()
+                        .map(StreamNode::getParallelism)
+                        .filter(v -> v > 0)
+                        .collect(Collectors.toSet());
+
+        checkState(configuredParallelisms.size() <= 1);
+        if (configuredParallelisms.size() == 1) {
+            this.parallelism = configuredParallelisms.iterator().next();
+        }
+
+        Set<Integer> configuredMaxParallelisms =
+                streamNodes.stream()
+                        .map(StreamNode::getMaxParallelism)
+                        .filter(val -> val > 0)
+                        .collect(Collectors.toSet());
+
+        if (!configuredMaxParallelisms.isEmpty()) {
+            this.maxParallelism = Collections.min(configuredMaxParallelisms);
+            checkState(
+                    parallelism == ExecutionConfig.PARALLELISM_DEFAULT
+                            || maxParallelism >= parallelism,
+                    "There is a start node in the forward group whose maximum 
parallelism is smaller than the group's parallelism");
+        }
+
+        this.streamNodes.addAll(streamNodes);
+    }
+
+    @Override
+    public void setParallelism(int parallelism) {
+        checkState(this.parallelism == ExecutionConfig.PARALLELISM_DEFAULT);
+        this.parallelism = parallelism;
+        this.streamNodes.forEach(
+                streamNode -> {
+                    streamNode.setParallelism(parallelism);
+                });
+    }
+
+    @Override
+    public boolean isParallelismDecided() {
+        return parallelism > 0;
+    }
+
+    @Override
+    public int getParallelism() {
+        checkState(isParallelismDecided());
+        return parallelism;
+    }
+
+    @Override
+    public void setMaxParallelism(int maxParallelism) {
+        checkState(
+                maxParallelism == ExecutionConfig.PARALLELISM_DEFAULT
+                        || maxParallelism >= parallelism,
+                "There is a job vertex in the forward group whose maximum 
parallelism is smaller than the group's parallelism");
+        this.maxParallelism = maxParallelism;
+        this.streamNodes.forEach(
+                streamNode -> {
+                    streamNode.setMaxParallelism(maxParallelism);
+                });
+    }
+
+    @Override
+    public boolean isMaxParallelismDecided() {
+        return maxParallelism > 0;
+    }
+
+    @Override
+    public int getMaxParallelism() {
+        checkState(isMaxParallelismDecided());
+        return maxParallelism;
+    }
+
+    @Override
+    public Set<Integer> getVertexIds() {
+        return 
streamNodes.stream().map(StreamNode::getId).collect(Collectors.toSet());
+    }
+
+    /**
+     * Merges forwardGroupToMerge into this and updates the parallelism 
information for stream nodes
+     * in merged forward group.
+     *
+     * @param forwardGroupToMerge The forward group to be merged.
+     * @return whether the merge was successful.
+     */
+    public boolean mergeForwardGroup(final StreamNodeForwardGroup 
forwardGroupToMerge) {
+        checkNotNull(forwardGroupToMerge);
+
+        if (forwardGroupToMerge == this) {
+            return true;
+        }
+
+        if (!ForwardGroupComputeUtil.canTargetMergeIntoSourceForwardGroup(
+                this, forwardGroupToMerge)) {
+            return false;
+        }
+
+        if (this.isParallelismDecided() && 
!forwardGroupToMerge.isParallelismDecided()) {
+            forwardGroupToMerge.setParallelism(this.parallelism);
+        } else if (!this.isParallelismDecided() && 
forwardGroupToMerge.isParallelismDecided()) {
+            this.setParallelism(forwardGroupToMerge.parallelism);
+        } else {
+            checkState(this.parallelism == forwardGroupToMerge.parallelism);
+        }
+
+        if (forwardGroupToMerge.isMaxParallelismDecided()
+                && (!this.isMaxParallelismDecided()
+                        || this.maxParallelism > 
forwardGroupToMerge.maxParallelism)) {
+            this.setMaxParallelism(forwardGroupToMerge.maxParallelism);
+        } else if (this.isMaxParallelismDecided()
+                && (!forwardGroupToMerge.isMaxParallelismDecided()
+                        || forwardGroupToMerge.maxParallelism > 
this.maxParallelism)) {
+            forwardGroupToMerge.setMaxParallelism(this.maxParallelism);
+        } else {
+            checkState(this.maxParallelism == 
forwardGroupToMerge.maxParallelism);
+        }
+
+        this.streamNodes.addAll(forwardGroupToMerge.streamNodes);
+
+        return true;
+    }
+
+    @VisibleForTesting
+    public int size() {
+        return streamNodes.size();
+    }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
index a3d3e582c65..e012712739a 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
@@ -58,7 +58,7 @@ import org.apache.flink.runtime.jobgraph.JobEdge;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroup;
+import org.apache.flink.runtime.jobgraph.forwardgroup.JobVertexForwardGroup;
 import org.apache.flink.runtime.jobgraph.jsonplan.JsonPlanGenerator;
 import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalResult;
 import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalTopology;
@@ -117,7 +117,7 @@ public class AdaptiveBatchScheduler extends 
DefaultScheduler {
 
     private final VertexParallelismAndInputInfosDecider 
vertexParallelismAndInputInfosDecider;
 
-    private final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId;
+    private final Map<JobVertexID, JobVertexForwardGroup> 
forwardGroupsByJobVertexId;
 
     private final Map<IntermediateDataSetID, BlockingResultInfo> 
blockingResultInfos;
 
@@ -166,7 +166,7 @@ public class AdaptiveBatchScheduler extends 
DefaultScheduler {
             final int defaultMaxParallelism,
             final BlocklistOperations blocklistOperations,
             final HybridPartitionDataConsumeConstraint 
hybridPartitionDataConsumeConstraint,
-            final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId,
+            final Map<JobVertexID, JobVertexForwardGroup> 
forwardGroupsByJobVertexId,
             final BatchJobRecoveryHandler jobRecoveryHandler)
             throws Exception {
 
@@ -626,7 +626,8 @@ public class AdaptiveBatchScheduler extends 
DefaultScheduler {
     private ParallelismAndInputInfos tryDecideParallelismAndInputInfos(
             final ExecutionJobVertex jobVertex, List<BlockingResultInfo> 
inputs) {
         int vertexInitialParallelism = jobVertex.getParallelism();
-        ForwardGroup forwardGroup = 
forwardGroupsByJobVertexId.get(jobVertex.getJobVertexId());
+        JobVertexForwardGroup forwardGroup =
+                forwardGroupsByJobVertexId.get(jobVertex.getJobVertexId());
         if (!jobVertex.isParallelismDecided() && forwardGroup != null) {
             checkState(!forwardGroup.isParallelismDecided());
         }
@@ -681,7 +682,7 @@ public class AdaptiveBatchScheduler extends 
DefaultScheduler {
             // the ordering of these elements received by the committer cannot 
be assured, which
             // would break the assumption that CommittableSummary is received 
before
             // CommittableWithLineage.
-            for (JobVertexID jobVertexId : forwardGroup.getJobVertexIds()) {
+            for (JobVertexID jobVertexId : forwardGroup.getVertexIds()) {
                 ExecutionJobVertex executionJobVertex = 
getExecutionJobVertex(jobVertexId);
                 if (!executionJobVertex.isParallelismDecided()) {
                     log.info(
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java
index 24d9f64b2b8..91a578ea995 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java
@@ -45,8 +45,8 @@ import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobType;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroup;
 import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil;
+import org.apache.flink.runtime.jobgraph.forwardgroup.JobVertexForwardGroup;
 import org.apache.flink.runtime.jobmaster.ExecutionDeploymentTracker;
 import org.apache.flink.runtime.jobmaster.event.FileSystemJobEventStore;
 import org.apache.flink.runtime.jobmaster.event.JobEventManager;
@@ -269,7 +269,7 @@ public class AdaptiveBatchSchedulerFactory implements 
SchedulerNGFactory {
         int defaultMaxParallelism =
                 getDefaultMaxParallelism(jobMasterConfiguration, 
executionConfig);
 
-        final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId =
+        final Map<JobVertexID, JobVertexForwardGroup> 
forwardGroupsByJobVertexId =
                 
ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism(
                         jobGraph.getVerticesSortedTopologicallyFromSources());
 
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
new file mode 100644
index 00000000000..07d8631bf92
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.graph;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph;
+import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import 
org.apache.flink.streaming.runtime.partitioner.ForwardForConsecutiveHashPartitioner;
+import 
org.apache.flink.streaming.runtime.partitioner.ForwardForUnspecifiedPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static 
org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil.canTargetMergeIntoSourceForwardGroup;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Default implementation for {@link StreamGraphContext}. */
+@Internal
+public class DefaultStreamGraphContext implements StreamGraphContext {
+
+    private static final Logger LOG = 
LoggerFactory.getLogger(DefaultStreamGraphContext.class);
+
+    private final StreamGraph streamGraph;
+    private final ImmutableStreamGraph immutableStreamGraph;
+
+    // The attributes below are reused from AdaptiveGraphManager as 
AdaptiveGraphManager also needs
+    // to use the modified information to create the job vertex.
+
+    // A modifiable map which records the ids of stream nodes to their forward 
groups.
+    // When a stream edge's partitioner is modified to forward, we need to get the forward groups
+    // by source and target node id and merge them.
+    private final Map<Integer, StreamNodeForwardGroup> 
steamNodeIdToForwardGroupMap;
+    // A read-only map which records the ids of stream nodes whose job vertices have been created,
+    // used to ensure that the stream nodes involved in the modification have not yet created job
+    // vertices.
+    private final Map<Integer, Integer> frozenNodeToStartNodeMap;
+    // A modifiable map whose key is the id of the stream node which creates the non-chained
+    // output, and whose value maps each stream edge connected to that stream node to the
+    // non-chained output subscribed by the edge. It is used to verify whether the edge being
+    // modified is subscribed to a reused output and ensures that modifications to StreamEdge can
+    // be synchronized to NonChainedOutput as they reuse some attributes.
+    private final Map<Integer, Map<StreamEdge, NonChainedOutput>> 
opIntermediateOutputsCaches;
+
+    public DefaultStreamGraphContext(
+            StreamGraph streamGraph,
+            Map<Integer, StreamNodeForwardGroup> steamNodeIdToForwardGroupMap,
+            Map<Integer, Integer> frozenNodeToStartNodeMap,
+            Map<Integer, Map<StreamEdge, NonChainedOutput>> 
opIntermediateOutputsCaches) {
+        this.streamGraph = checkNotNull(streamGraph);
+        this.steamNodeIdToForwardGroupMap = 
checkNotNull(steamNodeIdToForwardGroupMap);
+        this.frozenNodeToStartNodeMap = checkNotNull(frozenNodeToStartNodeMap);
+        this.opIntermediateOutputsCaches = 
checkNotNull(opIntermediateOutputsCaches);
+        this.immutableStreamGraph = new ImmutableStreamGraph(this.streamGraph);
+    }
+
+    @Override
+    public ImmutableStreamGraph getStreamGraph() {
+        return immutableStreamGraph;
+    }
+
+    @Override
+    public @Nullable StreamOperatorFactory<?> getOperatorFactory(Integer 
streamNodeId) {
+        return streamGraph.getStreamNode(streamNodeId).getOperatorFactory();
+    }
+
+    @Override
+    public boolean modifyStreamEdge(List<StreamEdgeUpdateRequestInfo> 
requestInfos) {
+        // We first verify the legality of all requestInfos to ensure that all 
requests can be
+        // modified atomically.
+        for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) {
+            if (!validateStreamEdgeUpdateRequest(requestInfo)) {
+                return false;
+            }
+        }
+
+        for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) {
+            StreamEdge targetEdge =
+                    getStreamEdge(
+                            requestInfo.getSourceId(),
+                            requestInfo.getTargetId(),
+                            requestInfo.getEdgeId());
+            StreamPartitioner<?> newPartitioner = 
requestInfo.getOutputPartitioner();
+            if (newPartitioner != null) {
+                modifyOutputPartitioner(targetEdge, newPartitioner);
+            }
+        }
+
+        return true;
+    }
+
+    private boolean 
validateStreamEdgeUpdateRequest(StreamEdgeUpdateRequestInfo requestInfo) {
+        Integer sourceNodeId = requestInfo.getSourceId();
+        Integer targetNodeId = requestInfo.getTargetId();
+
+        StreamEdge targetEdge = getStreamEdge(sourceNodeId, targetNodeId, 
requestInfo.getEdgeId());
+
+        if (targetEdge == null) {
+            return false;
+        }
+
+        // Modification is not allowed when the subscribing output is reused.
+        Map<StreamEdge, NonChainedOutput> opIntermediateOutputs =
+                opIntermediateOutputsCaches.get(sourceNodeId);
+        NonChainedOutput output =
+                opIntermediateOutputs != null ? 
opIntermediateOutputs.get(targetEdge) : null;
+        if (output != null) {
+            Set<StreamEdge> consumerStreamEdges =
+                    opIntermediateOutputs.entrySet().stream()
+                            .filter(entry -> entry.getValue().equals(output))
+                            .map(Map.Entry::getKey)
+                            .collect(Collectors.toSet());
+            if (consumerStreamEdges.size() != 1) {
+                LOG.info(
+                        "Skip modifying edge {} because the subscribing output 
is reused.",
+                        targetEdge);
+                return false;
+            }
+        }
+
+        if (frozenNodeToStartNodeMap.containsKey(targetNodeId)) {
+            LOG.info(
+                    "Skip modifying edge {} because the target node with id {} 
is in frozen list.",
+                    targetEdge,
+                    targetNodeId);
+            return false;
+        }
+
+        StreamPartitioner<?> newPartitioner = 
requestInfo.getOutputPartitioner();
+
+        if (newPartitioner != null) {
+            if 
(targetEdge.getPartitioner().getClass().equals(ForwardPartitioner.class)) {
+                LOG.info(
+                        "Modification for edge {} is not allowed as the origin 
partitioner is ForwardPartitioner.",
+                        targetEdge);
+                return false;
+            }
+            if (newPartitioner.getClass().equals(ForwardPartitioner.class)
+                    && !canTargetMergeIntoSourceForwardGroup(
+                            
steamNodeIdToForwardGroupMap.get(targetEdge.getSourceId()),
+                            
steamNodeIdToForwardGroupMap.get(targetEdge.getTargetId()))) {
+                LOG.info(
+                        "Skip modifying edge {} because forward groups can not 
be merged.",
+                        targetEdge);
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    private void modifyOutputPartitioner(
+            StreamEdge targetEdge, StreamPartitioner<?> newPartitioner) {
+        if (newPartitioner == null || targetEdge == null) {
+            return;
+        }
+        StreamPartitioner<?> oldPartitioner = targetEdge.getPartitioner();
+        targetEdge.setPartitioner(newPartitioner);
+
+        if (targetEdge.getPartitioner() instanceof ForwardPartitioner) {
+            tryConvertForwardPartitionerAndMergeForwardGroup(targetEdge);
+        }
+
+        // The partitioner in NonChainedOutput derived from the consumer edge, 
so we need to ensure
+        // that any modifications to the partitioner of consumer edge are 
synchronized with
+        // NonChainedOutput.
+        Map<StreamEdge, NonChainedOutput> opIntermediateOutputs =
+                opIntermediateOutputsCaches.get(targetEdge.getSourceId());
+        NonChainedOutput output =
+                opIntermediateOutputs != null ? 
opIntermediateOutputs.get(targetEdge) : null;
+        if (output != null) {
+            output.setPartitioner(targetEdge.getPartitioner());
+        }
+        LOG.info(
+                "The original partitioner of the edge {} is: {} , requested 
change to: {} , and finally modified to: {}.",
+                targetEdge,
+                oldPartitioner,
+                newPartitioner,
+                targetEdge.getPartitioner());
+    }
+
+    private void tryConvertForwardPartitionerAndMergeForwardGroup(StreamEdge 
targetEdge) {
+        checkState(targetEdge.getPartitioner() instanceof ForwardPartitioner);
+        Integer sourceNodeId = targetEdge.getSourceId();
+        Integer targetNodeId = targetEdge.getTargetId();
+        if (canConvertToForwardPartitioner(targetEdge)) {
+            targetEdge.setPartitioner(new ForwardPartitioner<>());
+            checkState(mergeForwardGroups(sourceNodeId, targetNodeId));
+        } else if (targetEdge.getPartitioner() instanceof 
ForwardForUnspecifiedPartitioner) {
+            targetEdge.setPartitioner(new RescalePartitioner<>());
+        } else if (targetEdge.getPartitioner() instanceof 
ForwardForConsecutiveHashPartitioner) {
+            targetEdge.setPartitioner(
+                    ((ForwardForConsecutiveHashPartitioner<?>) 
targetEdge.getPartitioner())
+                            .getHashPartitioner());
+        } else {
+            // For ForwardPartitioner, StreamGraphContext can ensure the 
success of the merge.
+            checkState(mergeForwardGroups(sourceNodeId, targetNodeId));
+        }
+    }
+
+    private boolean canConvertToForwardPartitioner(StreamEdge targetEdge) {
+        Integer sourceNodeId = targetEdge.getSourceId();
+        Integer targetNodeId = targetEdge.getTargetId();
+        if (targetEdge.getPartitioner() instanceof 
ForwardForUnspecifiedPartitioner) {
+            return !frozenNodeToStartNodeMap.containsKey(sourceNodeId)
+                    && StreamingJobGraphGenerator.isChainable(targetEdge, 
streamGraph, true)
+                    && canTargetMergeIntoSourceForwardGroup(
+                            steamNodeIdToForwardGroupMap.get(sourceNodeId),
+                            steamNodeIdToForwardGroupMap.get(targetNodeId));
+        } else if (targetEdge.getPartitioner() instanceof 
ForwardForConsecutiveHashPartitioner) {
+            return canTargetMergeIntoSourceForwardGroup(
+                    steamNodeIdToForwardGroupMap.get(sourceNodeId),
+                    steamNodeIdToForwardGroupMap.get(targetNodeId));
+        } else {
+            return false;
+        }
+    }
+
+    private boolean mergeForwardGroups(Integer sourceNodeId, Integer 
targetNodeId) {
+        StreamNodeForwardGroup sourceForwardGroup = 
steamNodeIdToForwardGroupMap.get(sourceNodeId);
+        StreamNodeForwardGroup forwardGroupToMerge = 
steamNodeIdToForwardGroupMap.get(targetNodeId);
+        if (sourceForwardGroup == null || forwardGroupToMerge == null) {
+            return false;
+        }
+        if (!sourceForwardGroup.mergeForwardGroup(forwardGroupToMerge)) {
+            return false;
+        }
+        // Update steamNodeIdToForwardGroupMap.
+        forwardGroupToMerge
+                .getVertexIds()
+                .forEach(nodeId -> steamNodeIdToForwardGroupMap.put(nodeId, 
sourceForwardGroup));
+        return true;
+    }
+
+    private StreamEdge getStreamEdge(Integer sourceId, Integer targetId, 
String edgeId) {
+        for (StreamEdge edge : streamGraph.getStreamEdges(sourceId, targetId)) 
{
+            if (edge.getEdgeId().equals(edgeId)) {
+                return edge;
+            }
+        }
+        return null;
+    }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java
new file mode 100644
index 00000000000..82cbac56135
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.graph;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph;
+import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+
+import javax.annotation.Nullable;
+
+import java.util.List;
+
+/**
+ * Defines a context for optimizing and working with a read-only view of a 
StreamGraph. It provides
+ * methods to modify StreamEdges and StreamNodes within the StreamGraph.
+ */
+@Internal
+public interface StreamGraphContext {
+
+    /**
+     * Returns a read-only view of the StreamGraph.
+     *
+     * @return a read-only view of the StreamGraph.
+     */
+    ImmutableStreamGraph getStreamGraph();
+
+    /**
+     * Retrieves the {@link StreamOperatorFactory} for the specified stream 
node id.
+     *
+     * @param streamNodeId the id of the stream node
+     * @return the {@link StreamOperatorFactory} associated with the given 
{@code streamNodeId}, or
+     *     {@code null} if no operator factory is available.
+     */
+    @Nullable
+    StreamOperatorFactory<?> getOperatorFactory(Integer streamNodeId);
+
+    /**
+     * Atomically modifies stream edges within the StreamGraph.
+     *
+     * <p>This method ensures that all the requested modifications to stream 
edges are applied
+     * atomically. This means that if any modification fails, none of the 
changes will be applied,
+     * maintaining the consistency of the StreamGraph.
+     *
+     * @param requestInfos the stream edges to be modified.
+     * @return true if all modifications were successful and applied 
atomically, false otherwise.
+     */
+    boolean modifyStreamEdge(List<StreamEdgeUpdateRequestInfo> requestInfos);
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
index dbc5c90db5f..90431aa6b83 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
@@ -224,7 +224,7 @@ public class StreamNode implements Serializable {
      *
      * @return Maximum parallelism
      */
-    int getMaxParallelism() {
+    public int getMaxParallelism() {
         return maxParallelism;
     }
 
@@ -233,7 +233,7 @@ public class StreamNode implements Serializable {
      *
      * @param maxParallelism Maximum parallelism to be set
      */
-    void setMaxParallelism(int maxParallelism) {
+    public void setMaxParallelism(int maxParallelism) {
         this.maxParallelism = maxParallelism;
     }
 
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
index 6476cfdfb52..d8d89a3a68b 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
@@ -46,6 +46,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroup;
 import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil;
+import org.apache.flink.runtime.jobgraph.forwardgroup.JobVertexForwardGroup;
 import org.apache.flink.runtime.jobgraph.tasks.TaskInvokable;
 import 
org.apache.flink.runtime.jobgraph.topology.DefaultLogicalPipelinedRegion;
 import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalTopology;
@@ -850,7 +851,7 @@ public class StreamingJobGraphGenerator {
                 });
 
         // compute forward groups
-        final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId =
+        final Map<JobVertexID, JobVertexForwardGroup> 
forwardGroupsByJobVertexId =
                 ForwardGroupComputeUtil.computeForwardGroups(
                         topologicalOrderVertices,
                         jobVertex ->
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java
new file mode 100644
index 00000000000..7a33e6d9265
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.graph.util;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+
+/** Helper class carrying the data required to update a stream edge. */
+@Internal
+public class StreamEdgeUpdateRequestInfo {
+    private final String edgeId;
+    private final Integer sourceId;
+    private final Integer targetId;
+
+    private StreamPartitioner<?> outputPartitioner;
+
+    public StreamEdgeUpdateRequestInfo(String edgeId, Integer sourceId, 
Integer targetId) {
+        this.edgeId = edgeId;
+        this.sourceId = sourceId;
+        this.targetId = targetId;
+    }
+
+    public StreamEdgeUpdateRequestInfo outputPartitioner(StreamPartitioner<?> 
outputPartitioner) {
+        this.outputPartitioner = outputPartitioner;
+        return this;
+    }
+
+    public String getEdgeId() {
+        return edgeId;
+    }
+
+    public Integer getSourceId() {
+        return sourceId;
+    }
+
+    public Integer getTargetId() {
+        return targetId;
+    }
+
+    public StreamPartitioner<?> getOutputPartitioner() {
+        return outputPartitioner;
+    }
+}
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java
index e9a90e98918..3e8e31da8b6 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java
@@ -22,19 +22,27 @@ import 
org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
+import org.apache.flink.streaming.api.graph.StreamNode;
+import org.apache.flink.streaming.api.operators.StreamOperator;
 
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
 import java.util.Set;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
+import static 
org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil.computeStreamNodeForwardGroup;
+import static org.apache.flink.util.Preconditions.checkState;
 import static org.assertj.core.api.Assertions.assertThat;
 
-/**
- * Unit tests for {@link 
org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil}.
- */
+/** Unit tests for {@link ForwardGroupComputeUtil}. */
 class ForwardGroupComputeUtilTest {
 
     /**
@@ -54,11 +62,26 @@ class ForwardGroupComputeUtilTest {
         JobVertex v2 = new JobVertex("v2");
         JobVertex v3 = new JobVertex("v3");
 
-        Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3);
+        Set<ForwardGroup<?>> groups = computeForwardGroups(v1, v2, v3);
 
         checkGroupSize(groups, 0);
     }
 
+    @Test
+    void testIsolatedChainedStreamNodeGroups() throws Exception {
+        List<StreamNode> topologicallySortedStreamNodes = createStreamNodes(3);
+        Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId = 
Collections.emptyMap();
+
+        Set<ForwardGroup<?>> groups =
+                computeForwardGroups(
+                        topologicallySortedStreamNodes, 
forwardProducersByConsumerNodeId);
+
+        // Different from a job vertex forward group, a stream node forward group is allowed to
+        // contain a single stream node, as these groups may merge with other groups in the
+        // future.
+        checkGroupSize(groups, 3, 1, 1, 1);
+    }
+
     /**
      * Tests that the computation of the vertices connected with edges which 
have various result
      * partition types works correctly.
@@ -96,7 +119,39 @@ class ForwardGroupComputeUtilTest {
             
v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
         }
 
-        Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3);
+        Set<ForwardGroup<?>> groups = computeForwardGroups(v1, v2, v3);
+
+        checkGroupSize(groups, numOfGroups, groupSizes);
+    }
+
+    @Test
+    void testVariousConnectTypesBetweenChainedStreamNodeGroup() throws 
Exception {
+        testThreeChainedStreamNodeGroupsConnectSequentially(false, true, 2, 1, 
2);
+        testThreeChainedStreamNodeGroupsConnectSequentially(false, false, 3, 
1, 1, 1);
+        testThreeChainedStreamNodeGroupsConnectSequentially(true, true, 1, 3);
+    }
+
+    private void testThreeChainedStreamNodeGroupsConnectSequentially(
+            boolean isForward1, boolean isForward2, int numOfGroups, 
Integer... groupSizes)
+            throws Exception {
+        List<StreamNode> topologicallySortedStreamNodes = createStreamNodes(3);
+        Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId = 
new HashMap<>();
+
+        if (isForward1) {
+            forwardProducersByConsumerNodeId
+                    .computeIfAbsent(topologicallySortedStreamNodes.get(1), k 
-> new HashSet<>())
+                    .add(topologicallySortedStreamNodes.get(0));
+        }
+
+        if (isForward2) {
+            forwardProducersByConsumerNodeId
+                    .computeIfAbsent(topologicallySortedStreamNodes.get(2), k 
-> new HashSet<>())
+                    .add(topologicallySortedStreamNodes.get(1));
+        }
+
+        Set<ForwardGroup<?>> groups =
+                computeForwardGroups(
+                        topologicallySortedStreamNodes, 
forwardProducersByConsumerNodeId);
 
         checkGroupSize(groups, numOfGroups, groupSizes);
     }
@@ -131,11 +186,31 @@ class ForwardGroupComputeUtilTest {
         v4.connectNewDataSetAsInput(
                 v3, DistributionPattern.ALL_TO_ALL, 
ResultPartitionType.BLOCKING);
 
-        Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3, v4);
+        Set<ForwardGroup<?>> groups = computeForwardGroups(v1, v2, v3, v4);
 
         checkGroupSize(groups, 1, 3);
     }
 
+    @Test
+    void testTwoInputsMergesIntoOneForStreamNodeForwardGroup() throws 
Exception {
+        List<StreamNode> topologicallySortedStreamNodes = createStreamNodes(4);
+        Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId = 
new HashMap<>();
+
+        forwardProducersByConsumerNodeId
+                .computeIfAbsent(topologicallySortedStreamNodes.get(2), k -> 
new HashSet<>())
+                .add(topologicallySortedStreamNodes.get(0));
+
+        forwardProducersByConsumerNodeId
+                .computeIfAbsent(topologicallySortedStreamNodes.get(2), k -> 
new HashSet<>())
+                .add(topologicallySortedStreamNodes.get(1));
+
+        Set<ForwardGroup<?>> groups =
+                computeForwardGroups(
+                        topologicallySortedStreamNodes, 
forwardProducersByConsumerNodeId);
+
+        checkGroupSize(groups, 2, 3, 1);
+    }
+
     /**
      * Tests that the computation of the job graph where one upstream vertex 
connect with two
      * downstream vertices works correctly.
@@ -166,12 +241,28 @@ class ForwardGroupComputeUtilTest {
         v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
         v2.getProducedDataSets().get(1).getConsumers().get(0).setForward(true);
 
-        Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3, v4);
+        Set<ForwardGroup<?>> groups = computeForwardGroups(v1, v2, v3, v4);
 
         checkGroupSize(groups, 1, 3);
     }
 
-    private static Set<ForwardGroup> computeForwardGroups(JobVertex... 
vertices) {
+    @Test
+    void testOneInputSplitsIntoTwoForStreamNodeForwardGroup() throws Exception 
{
+        List<StreamNode> topologicallySortedStreamNodes = createStreamNodes(4);
+        Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId = 
new HashMap<>();
+        forwardProducersByConsumerNodeId
+                .computeIfAbsent(topologicallySortedStreamNodes.get(3), k -> 
new HashSet<>())
+                .add(topologicallySortedStreamNodes.get(1));
+        forwardProducersByConsumerNodeId
+                .computeIfAbsent(topologicallySortedStreamNodes.get(2), k -> 
new HashSet<>())
+                .add(topologicallySortedStreamNodes.get(1));
+        Set<ForwardGroup<?>> groups =
+                computeForwardGroups(
+                        topologicallySortedStreamNodes, 
forwardProducersByConsumerNodeId);
+        checkGroupSize(groups, 2, 3, 1);
+    }
+
+    private static Set<ForwardGroup<?>> computeForwardGroups(JobVertex... 
vertices) {
         Arrays.asList(vertices).forEach(vertex -> 
vertex.setInvokableClass(NoOpInvokable.class));
         return new HashSet<>(
                 
ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism(
@@ -180,9 +271,61 @@ class ForwardGroupComputeUtilTest {
     }
 
     private static void checkGroupSize(
-            Set<ForwardGroup> groups, int numOfGroups, Integer... sizes) {
+            Set<ForwardGroup<?>> groups, int numOfGroups, Integer... sizes) {
         assertThat(groups.size()).isEqualTo(numOfGroups);
-        
assertThat(groups.stream().map(ForwardGroup::size).collect(Collectors.toList()))
+        assertThat(
+                        groups.stream()
+                                .map(
+                                        group -> {
+                                            if (group instanceof 
JobVertexForwardGroup) {
+                                                return 
((JobVertexForwardGroup) group).size();
+                                            } else {
+                                                return 
((StreamNodeForwardGroup) group).size();
+                                            }
+                                        })
+                                .collect(Collectors.toList()))
                 .contains(sizes);
     }
+
+    private static StreamNode createStreamNode(int id) {
+        return new StreamNode(id, null, null, (StreamOperator<?>) null, null, 
null);
+    }
+
+    private static List<StreamNode> createStreamNodes(int count) {
+        List<StreamNode> streamNodes = new ArrayList<>();
+        for (int i = 1; i <= count; i++) {
+            streamNodes.add(new StreamNode(i, null, null, (StreamOperator<?>) 
null, null, null));
+        }
+        return streamNodes;
+    }
+
+    private static Set<ForwardGroup<?>> computeForwardGroups(
+            List<StreamNode> topologicallySortedStreamNodes,
+            Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId) 
{
+        return new HashSet<>(
+                computeStreamNodeForwardGroupAndCheckParallelism(
+                                topologicallySortedStreamNodes,
+                                id ->
+                                        
forwardProducersByConsumerNodeId.getOrDefault(
+                                                id, Collections.emptySet()))
+                        .values());
+    }
+
+    public static Map<Integer, StreamNodeForwardGroup>
+            computeStreamNodeForwardGroupAndCheckParallelism(
+                    final Iterable<StreamNode> topologicallySortedStreamNodes,
+                    final Function<StreamNode, Set<StreamNode>> 
forwardProducersRetriever) {
+        final Map<Integer, StreamNodeForwardGroup> forwardGroupsByStartNodeId =
+                computeStreamNodeForwardGroup(
+                        topologicallySortedStreamNodes, 
forwardProducersRetriever);
+        topologicallySortedStreamNodes.forEach(
+                startNode -> {
+                    StreamNodeForwardGroup forwardGroup =
+                            forwardGroupsByStartNodeId.get(startNode.getId());
+                    if (forwardGroup != null && 
forwardGroup.isParallelismDecided()) {
+                        checkState(startNode.getParallelism() == 
forwardGroup.getParallelism());
+                    }
+                });
+        return forwardGroupsByStartNodeId;
+    }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroupTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroupTest.java
new file mode 100644
index 00000000000..cc68923cf76
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroupTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.jobgraph.forwardgroup;
+
+import org.apache.flink.streaming.api.graph.StreamNode;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Unit tests for {@link StreamNodeForwardGroup}. */
+class StreamNodeForwardGroupTest {
+    @Test
+    void testStreamNodeForwardGroup() {
+        Set<StreamNode> streamNodes = new HashSet<>();
+
+        streamNodes.add(createStreamNode(0, 1, 1));
+        streamNodes.add(createStreamNode(1, 1, 1));
+
+        StreamNodeForwardGroup forwardGroup = new 
StreamNodeForwardGroup(streamNodes);
+        assertThat(forwardGroup.getParallelism()).isEqualTo(1);
+        assertThat(forwardGroup.getMaxParallelism()).isEqualTo(1);
+        assertThat(forwardGroup.size()).isEqualTo(2);
+
+        streamNodes.add(createStreamNode(3, 1, 1));
+
+        StreamNodeForwardGroup forwardGroup2 = new 
StreamNodeForwardGroup(streamNodes);
+        assertThat(forwardGroup2.size()).isEqualTo(3);
+    }
+
+    @Test
+    void testMergeForwardGroup() {
+        Map<Integer, StreamNode> streamNodeRetriever = new HashMap<>();
+        StreamNodeForwardGroup forwardGroup =
+                createForwardGroupAndUpdateStreamNodeRetriever(
+                        createStreamNode(0, -1, -1), streamNodeRetriever);
+
+        StreamNodeForwardGroup forwardGroupWithUnDecidedParallelism =
+                createForwardGroupAndUpdateStreamNodeRetriever(
+                        createStreamNode(1, -1, -1), streamNodeRetriever);
+
+        forwardGroup.mergeForwardGroup(forwardGroupWithUnDecidedParallelism);
+        assertThat(forwardGroup.isParallelismDecided()).isFalse();
+        assertThat(forwardGroup.isMaxParallelismDecided()).isFalse();
+
+        StreamNodeForwardGroup forwardGroupWithDecidedParallelism =
+                createForwardGroupAndUpdateStreamNodeRetriever(
+                        createStreamNode(2, 2, 4), streamNodeRetriever);
+        forwardGroup.mergeForwardGroup(forwardGroupWithDecidedParallelism);
+        assertThat(forwardGroup.getParallelism()).isEqualTo(2);
+        assertThat(forwardGroup.getMaxParallelism()).isEqualTo(4);
+
+        StreamNodeForwardGroup forwardGroupWithLargerMaxParallelism =
+                createForwardGroupAndUpdateStreamNodeRetriever(
+                        createStreamNode(3, 2, 5), streamNodeRetriever);
+        // The target max parallelism is larger than source.
+        
assertThat(forwardGroup.mergeForwardGroup(forwardGroupWithLargerMaxParallelism)).isTrue();
+        assertThat(forwardGroup.getMaxParallelism()).isEqualTo(4);
+
+        StreamNodeForwardGroup forwardGroupWithSmallerMaxParallelism =
+                createForwardGroupAndUpdateStreamNodeRetriever(
+                        createStreamNode(4, 2, 3), streamNodeRetriever);
+        
assertThat(forwardGroup.mergeForwardGroup(forwardGroupWithSmallerMaxParallelism)).isTrue();
+        assertThat(forwardGroup.getMaxParallelism()).isEqualTo(3);
+
+        StreamNodeForwardGroup 
forwardGroupWithMaxParallelismSmallerThanSourceParallelism =
+                createForwardGroupAndUpdateStreamNodeRetriever(
+                        createStreamNode(5, -1, 1), streamNodeRetriever);
+        assertThat(
+                        forwardGroup.mergeForwardGroup(
+                                
forwardGroupWithMaxParallelismSmallerThanSourceParallelism))
+                .isFalse();
+
+        StreamNodeForwardGroup forwardGroupWithDifferentParallelism =
+                createForwardGroupAndUpdateStreamNodeRetriever(
+                        createStreamNode(6, 1, 3), streamNodeRetriever);
+        
assertThat(forwardGroup.mergeForwardGroup(forwardGroupWithDifferentParallelism)).isFalse();
+
+        StreamNodeForwardGroup forwardGroupWithUndefinedParallelism =
+                createForwardGroupAndUpdateStreamNodeRetriever(
+                        createStreamNode(7, -1, 2), streamNodeRetriever);
+        
assertThat(forwardGroup.mergeForwardGroup(forwardGroupWithUndefinedParallelism)).isTrue();
+        assertThat(forwardGroup.size()).isEqualTo(6);
+        assertThat(forwardGroup.getParallelism()).isEqualTo(2);
+        assertThat(forwardGroup.getMaxParallelism()).isEqualTo(2);
+
+        for (Integer nodeId : forwardGroup.getVertexIds()) {
+            StreamNode node = streamNodeRetriever.get(nodeId);
+            
assertThat(node.getParallelism()).isEqualTo(forwardGroup.getParallelism());
+            
assertThat(node.getMaxParallelism()).isEqualTo(forwardGroup.getMaxParallelism());
+        }
+    }
+
+    private static StreamNode createStreamNode(int id, int parallelism, int 
maxParallelism) {
+        StreamNode streamNode =
+                new StreamNode(id, null, null, (StreamOperator<?>) null, null, 
null);
+        if (parallelism > 0) {
+            streamNode.setParallelism(parallelism);
+        }
+        if (maxParallelism > 0) {
+            streamNode.setMaxParallelism(maxParallelism);
+        }
+        return streamNode;
+    }
+
+    private StreamNodeForwardGroup 
createForwardGroupAndUpdateStreamNodeRetriever(
+            StreamNode streamNode, Map<Integer, StreamNode> 
streamNodeRetriever) {
+        streamNodeRetriever.put(streamNode.getId(), streamNode);
+        return new StreamNodeForwardGroup(Collections.singleton(streamNode));
+    }
+}
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java
new file mode 100644
index 00000000000..468f6883b73
--- /dev/null
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.graph;
+
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
+import org.apache.flink.streaming.api.transformations.PartitionTransformation;
+import org.apache.flink.streaming.api.transformations.StreamExchangeMode;
+import 
org.apache.flink.streaming.runtime.partitioner.ForwardForUnspecifiedPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
+
+/** Unit tests for {@link DefaultStreamGraphContext}. */
+class DefaultStreamGraphContextTest {
+    @Test
+    void testModifyStreamEdge() {
+        StreamGraph streamGraph = createStreamGraphForModifyStreamEdgeTest();
+        Map<Integer, StreamNodeForwardGroup> 
forwardGroupsByEndpointNodeIdCache = new HashMap<>();
+        Map<Integer, Integer> frozenNodeToStartNodeMap = new HashMap<>();
+        Map<Integer, Map<StreamEdge, NonChainedOutput>> 
opIntermediateOutputsCaches =
+                new HashMap<>();
+        StreamGraphContext streamGraphContext =
+                new DefaultStreamGraphContext(
+                        streamGraph,
+                        forwardGroupsByEndpointNodeIdCache,
+                        frozenNodeToStartNodeMap,
+                        opIntermediateOutputsCaches);
+
+        StreamNode sourceNode =
+                
streamGraph.getStreamNode(streamGraph.getSourceIDs().iterator().next());
+        StreamNode targetNode =
+                
streamGraph.getStreamNode(sourceNode.getOutEdges().get(0).getTargetId());
+        targetNode.setParallelism(1);
+        StreamEdge targetEdge = sourceNode.getOutEdges().get(0);
+
+        StreamNodeForwardGroup forwardGroup1 =
+                new StreamNodeForwardGroup(Collections.singleton(sourceNode));
+        StreamNodeForwardGroup forwardGroup2 =
+                new StreamNodeForwardGroup(Collections.singleton(targetNode));
+        forwardGroupsByEndpointNodeIdCache.put(sourceNode.getId(), 
forwardGroup1);
+        forwardGroupsByEndpointNodeIdCache.put(targetNode.getId(), 
forwardGroup2);
+
+        StreamEdgeUpdateRequestInfo streamEdgeUpdateRequestInfo =
+                new StreamEdgeUpdateRequestInfo(
+                                targetEdge.getEdgeId(),
+                                targetEdge.getSourceId(),
+                                targetEdge.getTargetId())
+                        .outputPartitioner(new 
ForwardForUnspecifiedPartitioner<>());
+
+        assertThat(
+                        streamGraphContext.modifyStreamEdge(
+                                
Collections.singletonList(streamEdgeUpdateRequestInfo)))
+                .isTrue();
+        assertThat(targetEdge.getPartitioner() instanceof 
ForwardPartitioner).isTrue();
+
+        // We cannot modify when partitioner is forward partitioner.
+        assertThat(
+                        streamGraphContext.modifyStreamEdge(
+                                
Collections.singletonList(streamEdgeUpdateRequestInfo)))
+                .isEqualTo(false);
+
+        // We cannot modify when target node job vertex is created.
+        frozenNodeToStartNodeMap.put(targetEdge.getTargetId(), 
targetEdge.getTargetId());
+        assertThat(
+                        streamGraphContext.modifyStreamEdge(
+                                
Collections.singletonList(streamEdgeUpdateRequestInfo)))
+                .isEqualTo(false);
+
+        NonChainedOutput nonChainedOutput =
+                new NonChainedOutput(
+                        targetEdge.supportsUnalignedCheckpoints(),
+                        targetEdge.getSourceId(),
+                        targetNode.getParallelism(),
+                        targetNode.getMaxParallelism(),
+                        targetEdge.getBufferTimeout(),
+                        false,
+                        new IntermediateDataSetID(),
+                        targetEdge.getOutputTag(),
+                        targetEdge.getPartitioner(),
+                        ResultPartitionType.BLOCKING);
+        opIntermediateOutputsCaches.put(
+                targetEdge.getSourceId(),
+                Map.of(
+                        targetEdge,
+                        nonChainedOutput,
+                        targetNode.getOutEdges().get(0),
+                        nonChainedOutput));
+
+        // We cannot modify when target edge is consumed by multi edges.
+        frozenNodeToStartNodeMap.put(targetEdge.getTargetId(), 
targetEdge.getTargetId());
+        assertThat(
+                        streamGraphContext.modifyStreamEdge(
+                                
Collections.singletonList(streamEdgeUpdateRequestInfo)))
+                .isEqualTo(false);
+    }
+
+    @Test
+    void testModifyToForwardPartitionerButResultIsRescale() {
+        StreamGraph streamGraph = createStreamGraphForModifyStreamEdgeTest();
+
+        Map<Integer, StreamNodeForwardGroup> 
forwardGroupsByEndpointNodeIdCache = new HashMap<>();
+        Map<Integer, Integer> frozenNodeToStartNodeMap = new HashMap<>();
+        Map<Integer, Map<StreamEdge, NonChainedOutput>> 
opIntermediateOutputsCaches =
+                new HashMap<>();
+
+        StreamGraphContext streamGraphContext =
+                new DefaultStreamGraphContext(
+                        streamGraph,
+                        forwardGroupsByEndpointNodeIdCache,
+                        frozenNodeToStartNodeMap,
+                        opIntermediateOutputsCaches);
+
+        StreamNode sourceNode =
+                
streamGraph.getStreamNode(streamGraph.getSourceIDs().iterator().next());
+        StreamNode targetNode =
+                
streamGraph.getStreamNode(sourceNode.getOutEdges().get(0).getTargetId());
+        StreamEdge targetEdge = sourceNode.getOutEdges().get(0);
+
+        StreamNodeForwardGroup forwardGroup1 =
+                new StreamNodeForwardGroup(Collections.singleton(sourceNode));
+        StreamNodeForwardGroup forwardGroup2 =
+                new StreamNodeForwardGroup(Collections.singleton(targetNode));
+        forwardGroupsByEndpointNodeIdCache.put(sourceNode.getId(), 
forwardGroup1);
+        forwardGroupsByEndpointNodeIdCache.put(targetNode.getId(), 
forwardGroup2);
+
+        StreamEdgeUpdateRequestInfo streamEdgeUpdateRequestInfo =
+                new StreamEdgeUpdateRequestInfo(
+                                targetEdge.getEdgeId(),
+                                targetEdge.getSourceId(),
+                                targetEdge.getTargetId())
+                        .outputPartitioner(new 
ForwardForUnspecifiedPartitioner<>());
+
+        // Modify rescale partitioner to forward partitioner.
+
+        // 1. If the source and target are non-chainable.
+        assertThat(
+                        streamGraphContext.modifyStreamEdge(
+                                
Collections.singletonList(streamEdgeUpdateRequestInfo)))
+                .isTrue();
+        assertThat(targetEdge.getPartitioner() instanceof 
RescalePartitioner).isTrue();
+
+        // 2. If the forward group cannot be merged.
+        targetNode.setParallelism(1);
+        assertThat(
+                        streamGraphContext.modifyStreamEdge(
+                                
Collections.singletonList(streamEdgeUpdateRequestInfo)))
+                .isTrue();
+        assertThat(targetEdge.getPartitioner() instanceof 
RescalePartitioner).isTrue();
+
+        // 3. If the upstream job vertex is created.
+        frozenNodeToStartNodeMap.put(
+                streamGraph.getSourceIDs().iterator().next(),
+                streamGraph.getSourceIDs().iterator().next());
+        assertThat(
+                        streamGraphContext.modifyStreamEdge(
+                                
Collections.singletonList(streamEdgeUpdateRequestInfo)))
+                .isTrue();
+        assertThat(targetEdge.getPartitioner() instanceof 
RescalePartitioner).isTrue();
+    }
+
+    private StreamGraph createStreamGraphForModifyStreamEdgeTest() {
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+        // fromElements(1) -> Map(2) -> Print
+        DataStream<Integer> sourceDataStream = env.fromData(1, 2, 
3).setParallelism(1);
+
+        DataStream<Integer> partitionAfterSourceDataStream =
+                new DataStream<>(
+                        env,
+                        new PartitionTransformation<>(
+                                sourceDataStream.getTransformation(),
+                                new RescalePartitioner<>(),
+                                StreamExchangeMode.PIPELINED));
+
+        DataStream<Integer> mapDataStream =
+                partitionAfterSourceDataStream.map(value -> 
value).setParallelism(2);
+
+        DataStream<Integer> partitionAfterMapDataStream =
+                new DataStream<>(
+                        env,
+                        new PartitionTransformation<>(
+                                mapDataStream.getTransformation(),
+                                new RescalePartitioner<>(),
+                                StreamExchangeMode.PIPELINED));
+
+        partitionAfterMapDataStream.print();
+
+        return env.getStreamGraph();
+    }
+}

Reply via email to