This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 5d0cb98fe7843796e6ca0598060839f2b48f0882
Author: sunxia <xingbe...@gmail.com>
AuthorDate: Fri Jan 3 15:01:29 2025 +0800

    [FLINK-36608][runtime] Expand the adaptive graph component to support adaptive broadcast join capability.
---
 .../DefaultAdaptiveExecutionHandler.java           |   9 +-
 .../StreamGraphOptimizationStrategy.java           |   8 ++
 .../adaptivebatch/StreamGraphOptimizer.java        |   4 +
 .../streaming/api/graph/AdaptiveGraphManager.java  |  10 ++
 .../api/graph/DefaultStreamGraphContext.java       | 105 ++++++++++++++++-----
 .../flink/streaming/api/graph/StreamEdge.java      |   6 +-
 .../streaming/api/graph/StreamGraphContext.java    |  27 ++++++
 .../api/graph/util/ImmutableStreamGraph.java       |  13 ++-
 .../api/graph/util/ImmutableStreamNode.java        |  16 ++++
 .../graph/util/StreamEdgeUpdateRequestInfo.java    |  17 +++-
 ...tInfo.java => StreamNodeUpdateRequestInfo.java} |  41 ++++----
 .../DefaultAdaptiveExecutionHandlerTest.java       |   2 +-
 .../adaptivebatch/StreamGraphOptimizerTest.java    |  19 ++++
 .../api/graph/DefaultStreamGraphContextTest.java   |  13 ++-
 .../api/graph/util/ImmutableStreamGraphTest.java   |   4 +-
 .../scheduling/AdaptiveBatchSchedulerITCase.java   |   4 +-
 16 files changed, 236 insertions(+), 62 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java
index c0942f65372..b55734220d3 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java
@@ -67,6 +67,8 @@ public class DefaultAdaptiveExecutionHandler implements 
AdaptiveExecutionHandler
 
         this.streamGraphOptimizer =
                 new StreamGraphOptimizer(streamGraph.getJobConfiguration(), 
userClassloader);
+        this.streamGraphOptimizer.initializeStrategies(
+                adaptiveGraphManager.getStreamGraphContext());
     }
 
     @Override
@@ -107,10 +109,11 @@ public class DefaultAdaptiveExecutionHandler implements 
AdaptiveExecutionHandler
                                                 return existing;
                                             }));
 
+            List<Integer> finishedStreamNodeIds =
+                    
adaptiveGraphManager.getStreamNodeIdsByJobVertexId(vertexId);
             OperatorsFinished operatorsFinished =
-                    new OperatorsFinished(
-                            
adaptiveGraphManager.getStreamNodeIdsByJobVertexId(vertexId),
-                            resultInfoMap);
+                    new OperatorsFinished(finishedStreamNodeIds, 
resultInfoMap);
+            
adaptiveGraphManager.addFinishedStreamNodeIds(finishedStreamNodeIds);
 
             streamGraphOptimizer.onOperatorsFinished(
                     operatorsFinished, 
adaptiveGraphManager.getStreamGraphContext());
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizationStrategy.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizationStrategy.java
index 24d9f3252e0..58308614910 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizationStrategy.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizationStrategy.java
@@ -42,6 +42,14 @@ public interface StreamGraphOptimizationStrategy {
                             "Defines a comma-separated list of fully qualified 
class names "
                                     + "implementing the 
StreamGraphOptimizationStrategy interface.");
 
+    /**
+     * Initializes the StreamGraphOptimizationStrategy with the provided 
{@link StreamGraphContext}.
+     *
+     * @param context the StreamGraphContext with a read-only view of a 
StreamGraph, providing
+     *     methods to modify StreamEdges and StreamNodes within the 
StreamGraph.
+     */
+    default void initialize(StreamGraphContext context) {}
+
     /**
      * Tries to optimize the StreamGraph using the provided {@link 
OperatorsFinished} and {@link
      * StreamGraphContext}. The method returns a boolean indicating whether 
the StreamGraph was
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizer.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizer.java
index 1ac7583d6f2..cf459479e05 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizer.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizer.java
@@ -56,6 +56,10 @@ public class StreamGraphOptimizer {
         }
     }
 
+    public void initializeStrategies(StreamGraphContext context) {
+        checkNotNull(optimizationStrategies).forEach(strategy -> 
strategy.initialize(context));
+    }
+
     /**
      * Applies all loaded optimization strategies to the StreamGraph.
      *
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java
index 965187fae5b..cb327ac9600 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java
@@ -135,6 +135,9 @@ public class AdaptiveGraphManager
     // Records the ID of the job vertex that has completed execution.
     private final Set<JobVertexID> finishedJobVertices;
 
+    // Records the IDs of the stream nodes that have completed execution.
+    private final Set<Integer> finishedStreamNodeIds;
+
     private final AtomicBoolean hasHybridResultPartition;
 
     private final SlotSharingGroup defaultSlotSharingGroup;
@@ -171,6 +174,7 @@ public class AdaptiveGraphManager
         this.streamNodeIdsToJobVertexMap = new HashMap<>();
 
         this.finishedJobVertices = new HashSet<>();
+        this.finishedStreamNodeIds = new HashSet<>();
 
         this.streamGraphContext =
                 new DefaultStreamGraphContext(
@@ -179,6 +183,8 @@ public class AdaptiveGraphManager
                         frozenNodeToStartNodeMap,
                         intermediateOutputsCaches,
                         consumerEdgeIdToIntermediateDataSetMap,
+                        finishedStreamNodeIds,
+                        userClassloader,
                         this);
 
         this.jobGraph = createAndInitializeJobGraph(streamGraph, 
streamGraph.getJobID());
@@ -208,6 +214,10 @@ public class AdaptiveGraphManager
         return createJobVerticesAndUpdateGraph(streamNodes);
     }
 
+    public void addFinishedStreamNodeIds(List<Integer> finishedStreamNodeIds) {
+        this.finishedStreamNodeIds.addAll(finishedStreamNodeIds);
+    }
+
     /**
      * Retrieves the StreamNodeForwardGroup which provides a stream node level 
ForwardGroup.
      *
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
index f0ce8f8cbd7..b7017b36e77 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
@@ -22,9 +22,13 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge;
 import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode;
 import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
+import org.apache.flink.streaming.api.graph.util.StreamNodeUpdateRequestInfo;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import 
org.apache.flink.streaming.runtime.partitioner.ForwardForConsecutiveHashPartitioner;
 import 
org.apache.flink.streaming.runtime.partitioner.ForwardForUnspecifiedPartitioner;
@@ -74,6 +78,7 @@ public class DefaultStreamGraphContext implements 
StreamGraphContext {
     private final Map<Integer, Map<StreamEdge, NonChainedOutput>> 
opIntermediateOutputsCaches;
 
     private final Map<String, IntermediateDataSet> 
consumerEdgeIdToIntermediateDataSetMap;
+    private final Set<Integer> finishedStreamNodeIds;
 
     @Nullable private final StreamGraphUpdateListener 
streamGraphUpdateListener;
 
@@ -83,13 +88,17 @@ public class DefaultStreamGraphContext implements 
StreamGraphContext {
             Map<Integer, StreamNodeForwardGroup> steamNodeIdToForwardGroupMap,
             Map<Integer, Integer> frozenNodeToStartNodeMap,
             Map<Integer, Map<StreamEdge, NonChainedOutput>> 
opIntermediateOutputsCaches,
-            Map<String, IntermediateDataSet> 
consumerEdgeIdToIntermediateDataSetMap) {
+            Map<String, IntermediateDataSet> 
consumerEdgeIdToIntermediateDataSetMap,
+            Set<Integer> finishedStreamNodeIds,
+            ClassLoader userClassloader) {
         this(
                 streamGraph,
                 steamNodeIdToForwardGroupMap,
                 frozenNodeToStartNodeMap,
                 opIntermediateOutputsCaches,
                 consumerEdgeIdToIntermediateDataSetMap,
+                finishedStreamNodeIds,
+                userClassloader,
                 null);
     }
 
@@ -99,14 +108,17 @@ public class DefaultStreamGraphContext implements 
StreamGraphContext {
             Map<Integer, Integer> frozenNodeToStartNodeMap,
             Map<Integer, Map<StreamEdge, NonChainedOutput>> 
opIntermediateOutputsCaches,
             Map<String, IntermediateDataSet> 
consumerEdgeIdToIntermediateDataSetMap,
+            Set<Integer> finishedStreamNodeIds,
+            ClassLoader userClassloader,
             @Nullable StreamGraphUpdateListener streamGraphUpdateListener) {
         this.streamGraph = checkNotNull(streamGraph);
         this.steamNodeIdToForwardGroupMap = 
checkNotNull(steamNodeIdToForwardGroupMap);
         this.frozenNodeToStartNodeMap = checkNotNull(frozenNodeToStartNodeMap);
         this.opIntermediateOutputsCaches = 
checkNotNull(opIntermediateOutputsCaches);
-        this.immutableStreamGraph = new ImmutableStreamGraph(this.streamGraph);
+        this.immutableStreamGraph = new ImmutableStreamGraph(this.streamGraph, 
userClassloader);
         this.consumerEdgeIdToIntermediateDataSetMap =
                 checkNotNull(consumerEdgeIdToIntermediateDataSetMap);
+        this.finishedStreamNodeIds = finishedStreamNodeIds;
         this.streamGraphUpdateListener = streamGraphUpdateListener;
     }
 
@@ -140,6 +152,33 @@ public class DefaultStreamGraphContext implements 
StreamGraphContext {
             if (newPartitioner != null) {
                 modifyOutputPartitioner(targetEdge, newPartitioner);
             }
+            if (requestInfo.getTypeNumber() != 0) {
+                targetEdge.setTypeNumber(requestInfo.getTypeNumber());
+            }
+        }
+
+        // Notify the listener that the StreamGraph has been updated.
+        if (streamGraphUpdateListener != null) {
+            streamGraphUpdateListener.onStreamGraphUpdated();
+        }
+
+        return true;
+    }
+
+    @Override
+    public boolean modifyStreamNode(List<StreamNodeUpdateRequestInfo> 
requestInfos) {
+        for (StreamNodeUpdateRequestInfo requestInfo : requestInfos) {
+            StreamNode streamNode = 
streamGraph.getStreamNode(requestInfo.getNodeId());
+            if (requestInfo.getTypeSerializersIn() != null) {
+                if (requestInfo.getTypeSerializersIn().length
+                        != streamNode.getTypeSerializersIn().length) {
+                    LOG.info(
+                            "Modification for node {} is not allowed as the 
array size of typeSerializersIn is not matched.",
+                            requestInfo.getNodeId());
+                    return false;
+                }
+                
streamNode.setSerializersIn(requestInfo.getTypeSerializersIn());
+            }
         }
 
         // Notify the listener that the StreamGraph has been updated.
@@ -150,32 +189,46 @@ public class DefaultStreamGraphContext implements 
StreamGraphContext {
         return true;
     }
 
+    @Override
+    public boolean areAllUpstreamNodesFinished(ImmutableStreamNode streamNode) 
{
+        for (ImmutableStreamEdge streamEdge : streamNode.getInEdges()) {
+            if (!finishedStreamNodeIds.contains(streamEdge.getSourceId())) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    @Override
+    public IntermediateDataSetID getConsumedIntermediateDataSetId(String 
edgeId) {
+        return consumerEdgeIdToIntermediateDataSetMap.get(edgeId).getId();
+    }
+
     private boolean 
validateStreamEdgeUpdateRequest(StreamEdgeUpdateRequestInfo requestInfo) {
         Integer sourceNodeId = requestInfo.getSourceId();
         Integer targetNodeId = requestInfo.getTargetId();
 
         StreamEdge targetEdge = getStreamEdge(sourceNodeId, targetNodeId, 
requestInfo.getEdgeId());
 
-        if (targetEdge == null) {
-            return false;
-        }
-
-        // Modification is not allowed when the subscribing output is reused.
-        Map<StreamEdge, NonChainedOutput> opIntermediateOutputs =
-                opIntermediateOutputsCaches.get(sourceNodeId);
-        NonChainedOutput output =
-                opIntermediateOutputs != null ? 
opIntermediateOutputs.get(targetEdge) : null;
-        if (output != null) {
-            Set<StreamEdge> consumerStreamEdges =
-                    opIntermediateOutputs.entrySet().stream()
-                            .filter(entry -> entry.getValue().equals(output))
-                            .map(Map.Entry::getKey)
-                            .collect(Collectors.toSet());
-            if (consumerStreamEdges.size() != 1) {
-                LOG.info(
-                        "Skip modifying edge {} because the subscribing output 
is reused.",
-                        targetEdge);
-                return false;
+        // Modification to output partitioner is not allowed when the 
subscribing output is reused.
+        if (requestInfo.getOutputPartitioner() != null) {
+            Map<StreamEdge, NonChainedOutput> opIntermediateOutputs =
+                    opIntermediateOutputsCaches.get(sourceNodeId);
+            NonChainedOutput output =
+                    opIntermediateOutputs != null ? 
opIntermediateOutputs.get(targetEdge) : null;
+            if (output != null) {
+                Set<StreamEdge> consumerStreamEdges =
+                        opIntermediateOutputs.entrySet().stream()
+                                .filter(entry -> 
entry.getValue().equals(output))
+                                .map(Map.Entry::getKey)
+                                .collect(Collectors.toSet());
+                if (consumerStreamEdges.size() != 1) {
+                    LOG.info(
+                            "Skip modifying edge {} because the subscribing 
output is reused.",
+                            targetEdge);
+                    return false;
+                }
             }
         }
 
@@ -212,7 +265,7 @@ public class DefaultStreamGraphContext implements 
StreamGraphContext {
 
     private void modifyOutputPartitioner(
             StreamEdge targetEdge, StreamPartitioner<?> newPartitioner) {
-        if (newPartitioner == null || targetEdge == null) {
+        if (newPartitioner == null) {
             return;
         }
         StreamPartitioner<?> oldPartitioner = targetEdge.getPartitioner();
@@ -316,6 +369,10 @@ public class DefaultStreamGraphContext implements 
StreamGraphContext {
                 return edge;
             }
         }
-        return null;
+
+        throw new RuntimeException(
+                String.format(
+                        "Stream edge with id '%s' is not found whose source id 
is %d, target id is %d.",
+                        edgeId, sourceId, targetId));
     }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamEdge.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamEdge.java
index fc5e8c440c0..65472508cb4 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamEdge.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamEdge.java
@@ -55,7 +55,7 @@ public class StreamEdge implements Serializable {
     private final int uniqueId;
 
     /** The type number of the input for co-tasks. */
-    private final int typeNumber;
+    private int typeNumber;
 
     /** The side-output tag (if any) of this {@link StreamEdge}. */
     private final OutputTag outputTag;
@@ -193,6 +193,10 @@ public class StreamEdge implements Serializable {
         this.supportsUnalignedCheckpoints = supportsUnalignedCheckpoints;
     }
 
+    public void setTypeNumber(int typeNumber) {
+        this.typeNumber = typeNumber;
+    }
+
     public boolean supportsUnalignedCheckpoints() {
         return supportsUnalignedCheckpoints;
     }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java
index 13d1c8aa513..2368c7fa329 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java
@@ -19,8 +19,11 @@
 package org.apache.flink.streaming.api.graph;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode;
 import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
+import org.apache.flink.streaming.api.graph.util.StreamNodeUpdateRequestInfo;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 
 import javax.annotation.Nullable;
@@ -63,6 +66,30 @@ public interface StreamGraphContext {
      */
     boolean modifyStreamEdge(List<StreamEdgeUpdateRequestInfo> requestInfos);
 
+    /**
+     * Modifies stream nodes within the StreamGraph.
+     *
+     * @param requestInfos the stream nodes to be modified.
+     * @return true if the modification was successful, false otherwise.
+     */
+    boolean modifyStreamNode(List<StreamNodeUpdateRequestInfo> requestInfos);
+
+    /**
+     * Checks whether all upstream nodes of the stream node have finished executing.
+     *
+     * @param streamNode the stream node whose upstream nodes are to be checked.
+     * @return true if all upstream nodes are finished, false otherwise.
+     */
+    boolean areAllUpstreamNodesFinished(ImmutableStreamNode streamNode);
+
+    /**
+     * Retrieves the IntermediateDataSetID consumed by the specified edge.
+     *
+     * @param edgeId id of the edge
+     * @return the consumed IntermediateDataSetID
+     */
+    IntermediateDataSetID getConsumedIntermediateDataSetId(String edgeId);
+
     /** Interface for observers that monitor the status of a StreamGraph. */
     @Internal
     interface StreamGraphUpdateListener {
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamGraph.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamGraph.java
index eef1f001c31..6820cab0ac6 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamGraph.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamGraph.java
@@ -19,6 +19,7 @@
 package org.apache.flink.streaming.api.graph.util;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.ReadableConfig;
 import org.apache.flink.streaming.api.graph.StreamGraph;
 
 import java.util.HashMap;
@@ -28,11 +29,13 @@ import java.util.Map;
 @Internal
 public class ImmutableStreamGraph {
     private final StreamGraph streamGraph;
+    private final ClassLoader userClassloader;
 
     private final Map<Integer, ImmutableStreamNode> immutableStreamNodes;
 
-    public ImmutableStreamGraph(StreamGraph streamGraph) {
+    public ImmutableStreamGraph(StreamGraph streamGraph, ClassLoader 
userClassloader) {
         this.streamGraph = streamGraph;
+        this.userClassloader = userClassloader;
         this.immutableStreamNodes = new HashMap<>();
     }
 
@@ -43,4 +46,12 @@ public class ImmutableStreamGraph {
         return immutableStreamNodes.computeIfAbsent(
                 vertexId, id -> new 
ImmutableStreamNode(streamGraph.getStreamNode(id)));
     }
+
+    public ReadableConfig getConfiguration() {
+        return streamGraph.getJobConfiguration();
+    }
+
+    public ClassLoader getUserClassLoader() {
+        return userClassloader;
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamNode.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamNode.java
index d4d3d76af19..e5dae3fbc1a 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamNode.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamNode.java
@@ -19,12 +19,18 @@
 package org.apache.flink.streaming.api.graph.util;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.graph.StreamNode;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+
+import javax.annotation.Nullable;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 
 /** Helper class that provides read-only StreamNode. */
 @Internal
@@ -61,6 +67,10 @@ public class ImmutableStreamNode {
         return streamNode.getId();
     }
 
+    public @Nullable StreamOperatorFactory<?> getOperatorFactory() {
+        return streamNode.getOperatorFactory();
+    }
+
     public int getMaxParallelism() {
         return streamNode.getMaxParallelism();
     }
@@ -68,4 +78,10 @@ public class ImmutableStreamNode {
     public int getParallelism() {
         return streamNode.getParallelism();
     }
+
+    public TypeSerializer<?>[] getTypeSerializersIn() {
+        return Arrays.stream(streamNode.getTypeSerializersIn())
+                .filter(Objects::nonNull)
+                .toArray(TypeSerializer<?>[]::new);
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java
index 7a33e6d9265..4d0c4e78565 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java
@@ -30,17 +30,28 @@ public class StreamEdgeUpdateRequestInfo {
 
     private StreamPartitioner<?> outputPartitioner;
 
+    // The type number for the input of co-tasks.
+    // For two or more inputs, typeNumber must be >= 1, and 0 means the 
request will not change the
+    // typeNumber.
+    private int typeNumber;
+
     public StreamEdgeUpdateRequestInfo(String edgeId, Integer sourceId, 
Integer targetId) {
         this.edgeId = edgeId;
         this.sourceId = sourceId;
         this.targetId = targetId;
     }
 
-    public StreamEdgeUpdateRequestInfo outputPartitioner(StreamPartitioner<?> 
outputPartitioner) {
+    public StreamEdgeUpdateRequestInfo withOutputPartitioner(
+            StreamPartitioner<?> outputPartitioner) {
         this.outputPartitioner = outputPartitioner;
         return this;
     }
 
+    public StreamEdgeUpdateRequestInfo withTypeNumber(int typeNumber) {
+        this.typeNumber = typeNumber;
+        return this;
+    }
+
     public String getEdgeId() {
         return edgeId;
     }
@@ -56,4 +67,8 @@ public class StreamEdgeUpdateRequestInfo {
     public StreamPartitioner<?> getOutputPartitioner() {
         return outputPartitioner;
     }
+
+    public int getTypeNumber() {
+        return typeNumber;
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamNodeUpdateRequestInfo.java
similarity index 53%
copy from 
flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java
copy to 
flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamNodeUpdateRequestInfo.java
index 7a33e6d9265..fd210665084 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamNodeUpdateRequestInfo.java
@@ -19,41 +19,34 @@
 package org.apache.flink.streaming.api.graph.util;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+
+import javax.annotation.Nullable;
 
 /** Helper class carries the data required to updates a stream edge. */
 @Internal
-public class StreamEdgeUpdateRequestInfo {
-    private final String edgeId;
-    private final Integer sourceId;
-    private final Integer targetId;
+public class StreamNodeUpdateRequestInfo {
+    private final Integer nodeId;
 
-    private StreamPartitioner<?> outputPartitioner;
+    // Null means it does not request to change the typeSerializersIn.
+    @Nullable private TypeSerializer<?>[] typeSerializersIn;
 
-    public StreamEdgeUpdateRequestInfo(String edgeId, Integer sourceId, 
Integer targetId) {
-        this.edgeId = edgeId;
-        this.sourceId = sourceId;
-        this.targetId = targetId;
+    public StreamNodeUpdateRequestInfo(Integer nodeId) {
+        this.nodeId = nodeId;
     }
 
-    public StreamEdgeUpdateRequestInfo outputPartitioner(StreamPartitioner<?> 
outputPartitioner) {
-        this.outputPartitioner = outputPartitioner;
+    public StreamNodeUpdateRequestInfo withTypeSerializersIn(
+            TypeSerializer<?>[] typeSerializersIn) {
+        this.typeSerializersIn = typeSerializersIn;
         return this;
     }
 
-    public String getEdgeId() {
-        return edgeId;
-    }
-
-    public Integer getSourceId() {
-        return sourceId;
-    }
-
-    public Integer getTargetId() {
-        return targetId;
+    public Integer getNodeId() {
+        return nodeId;
     }
 
-    public StreamPartitioner<?> getOutputPartitioner() {
-        return outputPartitioner;
+    @Nullable
+    public TypeSerializer<?>[] getTypeSerializersIn() {
+        return typeSerializersIn;
     }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java
index 3ad14047b0e..791ab68e0b0 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java
@@ -277,7 +277,7 @@ class DefaultAdaptiveExecutionHandlerTest {
                                         outEdge.getEdgeId(),
                                         outEdge.getSourceId(),
                                         outEdge.getTargetId());
-                        requestInfo.outputPartitioner(new 
RebalancePartitioner<>());
+                        requestInfo.withOutputPartitioner(new 
RebalancePartitioner<>());
                         requestInfos.add(requestInfo);
                     }
                 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizerTest.java
index 2f063db1642..e2316a09e62 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizerTest.java
@@ -19,9 +19,12 @@
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.streaming.api.graph.StreamGraphContext;
 import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph;
+import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode;
 import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
+import org.apache.flink.streaming.api.graph.util.StreamNodeUpdateRequestInfo;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 
 import org.junit.jupiter.api.BeforeEach;
@@ -77,6 +80,22 @@ class StreamGraphOptimizerTest {
                             List<StreamEdgeUpdateRequestInfo> requestInfos) {
                         return false;
                     }
+
+                    @Override
+                    public boolean modifyStreamNode(
+                            List<StreamNodeUpdateRequestInfo> requestInfos) {
+                        return false;
+                    }
+
+                    @Override
+                    public boolean 
areAllUpstreamNodesFinished(ImmutableStreamNode streamNode) {
+                        return false;
+                    }
+
+                    @Override
+                    public IntermediateDataSetID 
getConsumedIntermediateDataSetId(String edgeId) {
+                        return null;
+                    }
                 };
 
         optimizer.onOperatorsFinished(operatorsFinished, context);
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java
index 0be8c246893..dcb1566ddd2 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java
@@ -34,6 +34,7 @@ import org.junit.jupiter.api.Test;
 
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Map;
 
 import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
@@ -53,7 +54,9 @@ class DefaultStreamGraphContextTest {
                         forwardGroupsByEndpointNodeIdCache,
                         frozenNodeToStartNodeMap,
                         opIntermediateOutputsCaches,
-                        new HashMap<>());
+                        new HashMap<>(),
+                        new HashSet<>(),
+                        Thread.currentThread().getContextClassLoader());
 
         StreamNode sourceNode =
                 streamGraph.getStreamNode(streamGraph.getSourceIDs().iterator().next());
@@ -74,7 +77,7 @@ class DefaultStreamGraphContextTest {
                                 targetEdge.getEdgeId(),
                                 targetEdge.getSourceId(),
                                 targetEdge.getTargetId())
-                        .outputPartitioner(new ForwardForUnspecifiedPartitioner<>());
+                        .withOutputPartitioner(new ForwardForUnspecifiedPartitioner<>());
 
         assertThat(
                         streamGraphContext.modifyStreamEdge(
@@ -138,7 +141,9 @@ class DefaultStreamGraphContextTest {
                         forwardGroupsByEndpointNodeIdCache,
                         frozenNodeToStartNodeMap,
                         opIntermediateOutputsCaches,
-                        new HashMap<>());
+                        new HashMap<>(),
+                        new HashSet<>(),
+                        Thread.currentThread().getContextClassLoader());
 
         StreamNode sourceNode =
                 streamGraph.getStreamNode(streamGraph.getSourceIDs().iterator().next());
@@ -158,7 +163,7 @@ class DefaultStreamGraphContextTest {
                                 targetEdge.getEdgeId(),
                                 targetEdge.getSourceId(),
                                 targetEdge.getTargetId())
-                        .outputPartitioner(new ForwardForUnspecifiedPartitioner<>());
+                        .withOutputPartitioner(new ForwardForUnspecifiedPartitioner<>());
 
         // Modify rescale partitioner to forward partitioner.
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamGraphTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamGraphTest.java
index 53777f2efdd..91960050bfd 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamGraphTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamGraphTest.java
@@ -36,7 +36,9 @@ class ImmutableStreamGraphTest {
         StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
         env.fromSequence(1L, 3L).map(value -> value).print().setParallelism(env.getParallelism());
         StreamGraph streamGraph = env.getStreamGraph();
-        ImmutableStreamGraph immutableStreamGraph = new ImmutableStreamGraph(streamGraph);
+        ImmutableStreamGraph immutableStreamGraph =
+                new ImmutableStreamGraph(
+                        streamGraph, Thread.currentThread().getContextClassLoader());
 
         for (StreamNode streamNode : streamGraph.getStreamNodes()) {
             isStreamNodeEquals(streamNode, immutableStreamGraph.getStreamNode(streamNode.getId()));
diff --git a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java
index 21b9a7645a9..c7f9b3c97cc 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java
@@ -358,10 +358,10 @@ class AdaptiveBatchSchedulerITCase {
                                     outEdge.getSourceId(),
                                     outEdge.getTargetId());
                     if (convertToBroadcastEdgeIds.contains(outEdge.getEdgeId())) {
-                        requestInfo.outputPartitioner(new BroadcastPartitioner<>());
+                        requestInfo.withOutputPartitioner(new BroadcastPartitioner<>());
                         requestInfos.add(requestInfo);
                     } else if (convertToRescaleEdgeIds.contains(outEdge.getEdgeId())) {
-                        requestInfo.outputPartitioner(new RescalePartitioner<>());
+                        requestInfo.withOutputPartitioner(new RescalePartitioner<>());
                         requestInfos.add(requestInfo);
                     }
                 }


Reply via email to