This is an automated email from the ASF dual-hosted git repository.

fanrui pushed a commit to branch release-1.20
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 76156cc9a1076754ee2db009a757881606cf3eab
Author: Rui Fan <1996fan...@gmail.com>
AuthorDate: Wed Aug 20 12:35:49 2025 +0200

    [FLINK-38267][checkpoint] Refactor hasInputState and hasOutputState related logic in TaskStateAssignment
---
 .../checkpoint/StateAssignmentOperation.java       |  4 +-
 .../runtime/checkpoint/TaskStateAssignment.java    | 81 +++++++++++++++++++---
 2 files changed, 73 insertions(+), 12 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 07088d901f6..54478c8c237 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -357,7 +357,7 @@ public class StateAssignmentOperation {
     public void reDistributeResultSubpartitionStates(TaskStateAssignment 
assignment) {
         // FLINK-31963: We can skip this phase if there is no output state AND 
downstream has no
         // input states
-        if (!assignment.hasOutputState && 
!assignment.hasDownstreamInputStates()) {
+        if (!assignment.hasOutputState() && 
!assignment.hasDownstreamInputStates()) {
             return;
         }
 
@@ -406,7 +406,7 @@ public class StateAssignmentOperation {
     public void reDistributeInputChannelStates(TaskStateAssignment 
stateAssignment) {
         // FLINK-31963: We can skip this phase only if there is no input state 
AND upstream has no
         // output states
-        if (!stateAssignment.hasInputState && 
!stateAssignment.hasUpstreamOutputStates()) {
+        if (!stateAssignment.hasInputState() && 
!stateAssignment.hasUpstreamOutputStates()) {
             return;
         }
 
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
index c4e30204030..98f9082b064 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
@@ -20,6 +20,8 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.runtime.OperatorIDPair;
 import 
org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
 import 
org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
+import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.IntermediateResult;
 import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
@@ -28,6 +30,8 @@ import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
 import org.apache.flink.runtime.state.InputChannelStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.MergedInputChannelStateHandle;
+import org.apache.flink.runtime.state.MergedResultSubpartitionStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
 import org.apache.flink.runtime.state.StateObject;
@@ -40,6 +44,7 @@ import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -48,6 +53,7 @@ import java.util.Optional;
 import java.util.Set;
 import java.util.function.BiFunction;
 import java.util.function.Function;
+import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import static java.util.Collections.emptySet;
@@ -65,12 +71,16 @@ class TaskStateAssignment {
     final Map<OperatorID, OperatorState> oldState;
     final boolean hasNonFinishedState;
     final boolean isFullyFinished;
-    final boolean hasInputState;
-    final boolean hasOutputState;
     final int newParallelism;
     final OperatorID inputOperatorID;
     final OperatorID outputOperatorID;
 
+    /** Indexes of the input gates that contain input buffer state. */
+    private final Set<Integer> inputStateGates;
+
+    /** Indexes of the result partitions that contain output buffer state. */
+    private final Set<Integer> outputStatePartitions;
+
     final Map<OperatorInstanceID, List<OperatorStateHandle>> 
subManagedOperatorState;
     final Map<OperatorInstanceID, List<OperatorStateHandle>> 
subRawOperatorState;
     final Map<OperatorInstanceID, List<KeyedStateHandle>> subManagedKeyedState;
@@ -127,12 +137,63 @@ class TaskStateAssignment {
         outputOperatorID = operatorIDs.get(0).getGeneratedOperatorID();
         inputOperatorID = operatorIDs.get(operatorIDs.size() - 
1).getGeneratedOperatorID();
 
-        hasInputState =
-                oldState.get(inputOperatorID).getStates().stream()
-                        .anyMatch(subState -> 
!subState.getInputChannelState().isEmpty());
-        hasOutputState =
-                oldState.get(outputOperatorID).getStates().stream()
-                        .anyMatch(subState -> 
!subState.getResultSubpartitionState().isEmpty());
+        inputStateGates = 
extractInputStateGates(oldState.get(inputOperatorID));
+        outputStatePartitions = 
extractOutputStatePartitions(oldState.get(outputOperatorID));
+    }
+
+    private static Set<Integer> extractInputStateGates(OperatorState 
operatorState) {
+        return operatorState.getStates().stream()
+                .map(OperatorSubtaskState::getInputChannelState)
+                .flatMap(Collection::stream)
+                .flatMapToInt(
+                        handle -> {
+                            if (handle instanceof InputChannelStateHandle) {
+                                return IntStream.of(
+                                        ((InputChannelStateHandle) 
handle).getInfo().getGateIdx());
+                            } else if (handle instanceof 
MergedInputChannelStateHandle) {
+                                return ((MergedInputChannelStateHandle) handle)
+                                        
.getInfos().stream().mapToInt(InputChannelInfo::getGateIdx);
+                            } else {
+                                throw new IllegalStateException(
+                                        "Invalid input channel state : " + 
handle.getClass());
+                            }
+                        })
+                .distinct()
+                .boxed()
+                .collect(Collectors.toSet());
+    }
+
+    private static Set<Integer> extractOutputStatePartitions(OperatorState 
operatorState) {
+        return operatorState.getStates().stream()
+                .map(OperatorSubtaskState::getResultSubpartitionState)
+                .flatMap(Collection::stream)
+                .flatMapToInt(
+                        handle -> {
+                            if (handle instanceof 
ResultSubpartitionStateHandle) {
+                                return IntStream.of(
+                                        ((ResultSubpartitionStateHandle) 
handle)
+                                                .getInfo()
+                                                .getPartitionIdx());
+                            } else if (handle instanceof 
MergedResultSubpartitionStateHandle) {
+                                return ((MergedResultSubpartitionStateHandle) 
handle)
+                                        .getInfos().stream()
+                                                
.mapToInt(ResultSubpartitionInfo::getPartitionIdx);
+                            } else {
+                                throw new IllegalStateException(
+                                        "Invalid output channel state : " + 
handle.getClass());
+                            }
+                        })
+                .distinct()
+                .boxed()
+                .collect(Collectors.toSet());
+    }
+
+    public boolean hasInputState() {
+        return !inputStateGates.isEmpty();
+    }
+
+    public boolean hasOutputState() {
+        return !outputStatePartitions.isEmpty();
     }
 
     public TaskStateAssignment[] getDownstreamAssignments() {
@@ -210,7 +271,7 @@ class TaskStateAssignment {
         if (hasUpstreamOutputStates == null) {
             hasUpstreamOutputStates =
                     Arrays.stream(getUpstreamAssignments())
-                            .anyMatch(assignment -> assignment.hasOutputState);
+                            .anyMatch(TaskStateAssignment::hasOutputState);
         }
         return hasUpstreamOutputStates;
     }
@@ -219,7 +280,7 @@ class TaskStateAssignment {
         if (hasDownstreamInputStates == null) {
             hasDownstreamInputStates =
                     Arrays.stream(getDownstreamAssignments())
-                            .anyMatch(assignment -> assignment.hasInputState);
+                            .anyMatch(TaskStateAssignment::hasInputState);
         }
         return hasDownstreamInputStates;
     }

Reply via email to