This is an automated email from the ASF dual-hosted git repository. fanrui pushed a commit to branch release-1.20 in repository https://gitbox.apache.org/repos/asf/flink.git
commit 76156cc9a1076754ee2db009a757881606cf3eab Author: Rui Fan <1996fan...@gmail.com> AuthorDate: Wed Aug 20 12:35:49 2025 +0200 [FLINK-38267][checkpoint] Refactor hasInputState and hasOutputState related logic in TaskStateAssignment --- .../checkpoint/StateAssignmentOperation.java | 4 +- .../runtime/checkpoint/TaskStateAssignment.java | 81 +++++++++++++++++++--- 2 files changed, 73 insertions(+), 12 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java index 07088d901f6..54478c8c237 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java @@ -357,7 +357,7 @@ public class StateAssignmentOperation { public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment) { // FLINK-31963: We can skip this phase if there is no output state AND downstream has no // input states - if (!assignment.hasOutputState && !assignment.hasDownstreamInputStates()) { + if (!assignment.hasOutputState() && !assignment.hasDownstreamInputStates()) { return; } @@ -406,7 +406,7 @@ public class StateAssignmentOperation { public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment) { // FLINK-31963: We can skip this phase only if there is no input state AND upstream has no // output states - if (!stateAssignment.hasInputState && !stateAssignment.hasUpstreamOutputStates()) { + if (!stateAssignment.hasInputState() && !stateAssignment.hasUpstreamOutputStates()) { return; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java index c4e30204030..98f9082b064 100644 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java @@ -20,6 +20,8 @@ package org.apache.flink.runtime.checkpoint; import org.apache.flink.runtime.OperatorIDPair; import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor; import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.IntermediateResult; import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper; @@ -28,6 +30,8 @@ import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.OperatorInstanceID; import org.apache.flink.runtime.state.InputChannelStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.MergedInputChannelStateHandle; +import org.apache.flink.runtime.state.MergedResultSubpartitionStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.ResultSubpartitionStateHandle; import org.apache.flink.runtime.state.StateObject; @@ -40,6 +44,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.util.Arrays; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -48,6 +53,7 @@ import java.util.Optional; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.stream.Collectors; import java.util.stream.IntStream; import static java.util.Collections.emptySet; @@ -65,12 +71,16 @@ class 
TaskStateAssignment { final Map<OperatorID, OperatorState> oldState; final boolean hasNonFinishedState; final boolean isFullyFinished; - final boolean hasInputState; - final boolean hasOutputState; final int newParallelism; final OperatorID inputOperatorID; final OperatorID outputOperatorID; + /** The set of InputGate indexes that contain input buffer state. */ + private final Set<Integer> inputStateGates; + + /** The set of ResultPartition indexes that contain output buffer state. */ + private final Set<Integer> outputStatePartitions; + final Map<OperatorInstanceID, List<OperatorStateHandle>> subManagedOperatorState; final Map<OperatorInstanceID, List<OperatorStateHandle>> subRawOperatorState; final Map<OperatorInstanceID, List<KeyedStateHandle>> subManagedKeyedState; @@ -127,12 +137,63 @@ class TaskStateAssignment { outputOperatorID = operatorIDs.get(0).getGeneratedOperatorID(); inputOperatorID = operatorIDs.get(operatorIDs.size() - 1).getGeneratedOperatorID(); - hasInputState = - oldState.get(inputOperatorID).getStates().stream() - .anyMatch(subState -> !subState.getInputChannelState().isEmpty()); - hasOutputState = - oldState.get(outputOperatorID).getStates().stream() - .anyMatch(subState -> !subState.getResultSubpartitionState().isEmpty()); + inputStateGates = extractInputStateGates(oldState.get(inputOperatorID)); + outputStatePartitions = extractOutputStatePartitions(oldState.get(outputOperatorID)); + } + + private static Set<Integer> extractInputStateGates(OperatorState operatorState) { + return operatorState.getStates().stream() + .map(OperatorSubtaskState::getInputChannelState) + .flatMap(Collection::stream) + .flatMapToInt( + handle -> { + if (handle instanceof InputChannelStateHandle) { + return IntStream.of( + ((InputChannelStateHandle) handle).getInfo().getGateIdx()); + } else if (handle instanceof MergedInputChannelStateHandle) { + return ((MergedInputChannelStateHandle) handle) + .getInfos().stream().mapToInt(InputChannelInfo::getGateIdx); + } else { + throw new 
IllegalStateException( + "Invalid input channel state : " + handle.getClass()); + } + }) + .distinct() + .boxed() + .collect(Collectors.toSet()); + } + + private static Set<Integer> extractOutputStatePartitions(OperatorState operatorState) { + return operatorState.getStates().stream() + .map(OperatorSubtaskState::getResultSubpartitionState) + .flatMap(Collection::stream) + .flatMapToInt( + handle -> { + if (handle instanceof ResultSubpartitionStateHandle) { + return IntStream.of( + ((ResultSubpartitionStateHandle) handle) + .getInfo() + .getPartitionIdx()); + } else if (handle instanceof MergedResultSubpartitionStateHandle) { + return ((MergedResultSubpartitionStateHandle) handle) + .getInfos().stream() + .mapToInt(ResultSubpartitionInfo::getPartitionIdx); + } else { + throw new IllegalStateException( + "Invalid output channel state : " + handle.getClass()); + } + }) + .distinct() + .boxed() + .collect(Collectors.toSet()); + } + + public boolean hasInputState() { + return !inputStateGates.isEmpty(); + } + + public boolean hasOutputState() { + return !outputStatePartitions.isEmpty(); } public TaskStateAssignment[] getDownstreamAssignments() { @@ -210,7 +271,7 @@ class TaskStateAssignment { if (hasUpstreamOutputStates == null) { hasUpstreamOutputStates = Arrays.stream(getUpstreamAssignments()) - .anyMatch(assignment -> assignment.hasOutputState); + .anyMatch(TaskStateAssignment::hasOutputState); } return hasUpstreamOutputStates; } @@ -219,7 +280,7 @@ class TaskStateAssignment { if (hasDownstreamInputStates == null) { hasDownstreamInputStates = Arrays.stream(getDownstreamAssignments()) - .anyMatch(assignment -> assignment.hasInputState); + .anyMatch(TaskStateAssignment::hasInputState); } return hasDownstreamInputStates; }