This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 354c0f455b9 [FLINK-31963][state] Fix rescaling bug in recovery from unaligned checkpoints. (#22584)
354c0f455b9 is described below

commit 354c0f455b92c083299d8028f161f0dd113ab614
Author: Stefan Richter <srich...@apache.org>
AuthorDate: Tue May 16 13:06:05 2023 +0200

    [FLINK-31963][state] Fix rescaling bug in recovery from unaligned checkpoints. (#22584)

    This commit fixes problems in StateAssignmentOperation for unaligned checkpoints with stateless operators that have upstream operators with output partition state or downstream operators with input channel state.
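
    In short, state repartitioning must now also run for a stateless task whose neighbors carry in-flight channel state. A minimal sketch of the new condition, paraphrased from the StateAssignmentOperation diff below (the helper name needsRepartitioning is illustrative only and not part of the commit):

        boolean needsRepartitioning(TaskStateAssignment assignment) {
            return assignment.hasNonFinishedState
                    // FLINK-31963: a stateless operator still needs repartitioning when an
                    // upstream task carries output (result subpartition) state ...
                    || assignment.hasUpstreamOutputStates()
                    // ... or a downstream task carries input channel state.
                    || assignment.hasDownstreamInputStates();
        }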
---
 .../checkpoint/StateAssignmentOperation.java       |  28 ++--
 .../runtime/checkpoint/TaskStateAssignment.java    |  19 ++-
 .../checkpoint/StateAssignmentOperationTest.java   | 178 ++++++++++++++++++++-
 .../checkpointing/UnalignedCheckpointITCase.java   |  18 ++-
 .../UnalignedCheckpointRescaleITCase.java          | 137 ++++++++++------
 .../checkpointing/UnalignedCheckpointTestBase.java |  32 +++-
 6 files changed, 335 insertions(+), 77 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 681e0b18df1..e476c6b65ec 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -136,19 +136,24 @@ public class StateAssignmentOperation {
 
         // repartition state
         for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
-            if (stateAssignment.hasNonFinishedState) {
+            if (stateAssignment.hasNonFinishedState
+                    // FLINK-31963: We need to run repartitioning for stateless operators that have
+                    // upstream output or downstream input states.
+                    || stateAssignment.hasUpstreamOutputStates()
+                    || stateAssignment.hasDownstreamInputStates()) {
                 assignAttemptState(stateAssignment);
             }
         }
 
         // actually assign the state
         for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
-            // If upstream has output states, even the empty task state should be assigned for the
-            // current task in order to notify this task that the old states will send to it which
-            // likely should be filtered.
+            // If upstream has output states or downstream has input states, even the empty task
+            // state should be assigned for the current task in order to notify this task that the
+            // old states will send to it which likely should be filtered.
             if (stateAssignment.hasNonFinishedState
                     || stateAssignment.isFullyFinished
-                    || stateAssignment.hasUpstreamOutputStates()) {
+                    || stateAssignment.hasUpstreamOutputStates()
+                    || stateAssignment.hasDownstreamInputStates()) {
                 assignTaskStateToExecutionJobVertices(stateAssignment);
             }
         }
@@ -345,9 +350,10 @@ public class StateAssignmentOperation {
                                         newParallelism)));
     }
 
-    public <I, T extends AbstractChannelStateHandle<I>> void reDistributeResultSubpartitionStates(
-            TaskStateAssignment assignment) {
-        if (!assignment.hasOutputState) {
+    public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment) {
+        // FLINK-31963: We can skip this phase if there is no output state AND downstream has no
+        // input states
+        if (!assignment.hasOutputState && !assignment.hasDownstreamInputStates()) {
             return;
         }
 
@@ -394,7 +400,9 @@ public class StateAssignmentOperation {
     }
 
     public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment) {
-        if (!stateAssignment.hasInputState) {
+        // FLINK-31963: We can skip this phase only if there is no input state AND upstream has no
+        // output states
+        if (!stateAssignment.hasInputState && !stateAssignment.hasUpstreamOutputStates()) {
             return;
         }
 
@@ -435,7 +443,7 @@ public class StateAssignmentOperation {
                             : getPartitionState(
                                     inputOperatorState, InputChannelInfo::getGateIdx, gateIndex);
             final MappingBasedRepartitioner<InputChannelStateHandle> repartitioner =
-                    new MappingBasedRepartitioner(mapping);
+                    new MappingBasedRepartitioner<>(mapping);
             final Map<OperatorInstanceID, List<InputChannelStateHandle>> repartitioned =
                     applyRepartitioner(
                             stateAssignment.inputOperatorID,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
index 75ffc71d058..e9f9d11421e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
@@ -84,6 +84,8 @@ class TaskStateAssignment {
 
     @Nullable private TaskStateAssignment[] downstreamAssignments;
     @Nullable private TaskStateAssignment[] upstreamAssignments;
+    @Nullable private Boolean hasUpstreamOutputStates;
+    @Nullable private Boolean hasDownstreamInputStates;
 
     private final Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment;
     private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;
@@ -202,8 +204,21 @@ class TaskStateAssignment {
     }
 
     public boolean hasUpstreamOutputStates() {
-        return Arrays.stream(getUpstreamAssignments())
-                .anyMatch(assignment -> assignment.hasOutputState);
+        if (hasUpstreamOutputStates == null) {
+            hasUpstreamOutputStates =
+                    Arrays.stream(getUpstreamAssignments())
+                            .anyMatch(assignment -> assignment.hasOutputState);
+        }
+        return hasUpstreamOutputStates;
+    }
+
+    public boolean hasDownstreamInputStates() {
+        if (hasDownstreamInputStates == null) {
+            hasDownstreamInputStates =
+                    Arrays.stream(getDownstreamAssignments())
+                            .anyMatch(assignment -> assignment.hasInputState);
+        }
+        return hasDownstreamInputStates;
     }
 
     private InflightDataGateOrPartitionRescalingDescriptor log(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
index f9cb551c5ad..bffdd8686ae 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.runtime.JobException;
 import org.apache.flink.runtime.OperatorIDPair;
 import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
 import org.apache.flink.runtime.client.JobExecutionException;
+import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
@@ -51,6 +52,9 @@ import org.junit.Assert;
 import org.junit.ClassRule;
 import org.junit.Test;
 
+import javax.annotation.Nullable;
+
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.EnumMap;
@@ -82,6 +86,7 @@ import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNew
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ARBITRARY;
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.RANGE;
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ROUND_ROBIN;
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -785,6 +790,129 @@ public class StateAssignmentOperationTest extends TestLogger {
         }
     }
 
+    /** FLINK-31963: Tests rescaling for stateless operators and upstream result partition state. */
+    @Test
+    public void testOnlyUpstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState upstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setResultSubpartitionState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                createNewResultSubpartitionStateHandle(10, random),
+                                                createNewResultSubpartitionStateHandle(
+                                                        10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(upstreamOpState, null, 5, 7);
+    }
+
+    /** FLINK-31963: Tests rescaling for stateless operators and downstream input channel state. */
+    @Test
+    public void testOnlyDownstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState downstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setInputChannelState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                createNewInputChannelStateHandle(10, random),
+                                                createNewInputChannelStateHandle(10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(null, downstreamOpState, 5, 5);
+    }
+
+    private void testOnlyUpstreamOrDownstreamRescalingInternal(
+            @Nullable OperatorSubtaskState upstreamOpState,
+            @Nullable OperatorSubtaskState downstreamOpState,
+            int expectedUpstreamCount,
+            int expectedDownstreamCount)
+            throws JobException, JobExecutionException {
+
+        checkArgument(
+                upstreamOpState != downstreamOpState
+                        && (upstreamOpState == null || downstreamOpState == null),
+                "Either upstream or downstream state must exist, but not both");
+
+        // Start from parallelism 5 for both operators
+        int upstreamParallelism = 5;
+        int downstreamParallelism = 5;
+
+        // Build states
+        List<OperatorID> operatorIds = buildOperatorIds(2);
+        Map<OperatorID, OperatorState> states = new HashMap<>();
+        OperatorState upstreamState =
+                new OperatorState(operatorIds.get(0), upstreamParallelism, MAX_P);
+        OperatorState downstreamState =
+                new OperatorState(operatorIds.get(1), downstreamParallelism, MAX_P);
+
+        states.put(operatorIds.get(0), upstreamState);
+        states.put(operatorIds.get(1), downstreamState);
+
+        if (upstreamOpState != null) {
+            upstreamState.putState(0, upstreamOpState);
+            // rescale downstream 5 -> 3
+            downstreamParallelism = 3;
+        }
+
+        if (downstreamOpState != null) {
+            downstreamState.putState(0, downstreamOpState);
+            // rescale upstream 5 -> 3
+            upstreamParallelism = 3;
+        }
+
+        List<OperatorIdWithParallelism> opIdWithParallelism = new ArrayList<>(2);
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(0), upstreamParallelism));
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(1), downstreamParallelism));
+
+        Map<OperatorID, ExecutionJobVertex> vertices =
+                buildVertices(opIdWithParallelism, RANGE, ROUND_ROBIN);
+
+        // Run state assignment
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+                .assignStates();
+
+        // Check results
+        ExecutionJobVertex upstreamExecutionJobVertex = vertices.get(operatorIds.get(0));
+        ExecutionJobVertex downstreamExecutionJobVertex = vertices.get(operatorIds.get(1));
+
+        List<TaskStateSnapshot> upstreamTaskStateSnapshots =
+                getTaskStateSnapshotFromVertex(upstreamExecutionJobVertex);
+        List<TaskStateSnapshot> downstreamTaskStateSnapshots =
+                getTaskStateSnapshotFromVertex(downstreamExecutionJobVertex);
+
+        checkMappings(
+                upstreamTaskStateSnapshots,
+                TaskStateSnapshot::getOutputRescalingDescriptor,
+                expectedUpstreamCount);
+
+        checkMappings(
+                downstreamTaskStateSnapshots,
+                TaskStateSnapshot::getInputRescalingDescriptor,
+                expectedDownstreamCount);
+    }
+
+    private void checkMappings(
+            List<TaskStateSnapshot> taskStateSnapshots,
+            Function<TaskStateSnapshot, InflightDataRescalingDescriptor> extractFun,
+            int expectedCount) {
+        Assert.assertEquals(
+                expectedCount,
+                taskStateSnapshots.stream()
+                        .map(extractFun)
+                        .mapToInt(
+                                x -> {
+                                    int len = x.getOldSubtaskIndexes(0).length;
+                                    // Assert that there is a mapping.
+                                    Assert.assertTrue(len > 0);
+                                    return len;
+                                })
+                        .sum());
+    }
+
     @Test
     public void testStateWithFullyFinishedOperators() throws JobException, JobExecutionException {
         List<OperatorID> operatorIds = buildOperatorIds(2);
@@ -949,15 +1077,50 @@ public class StateAssignmentOperationTest extends TestLogger {
                                 }));
     }
 
+    private static class OperatorIdWithParallelism {
+        private final OperatorID operatorID;
+        private final int parallelism;
+
+        public OperatorID getOperatorID() {
+            return operatorID;
+        }
+
+        public int getParallelism() {
+            return parallelism;
+        }
+
+        public OperatorIdWithParallelism(OperatorID operatorID, int parallelism) {
+            this.operatorID = operatorID;
+            this.parallelism = parallelism;
+        }
+    }
+
     private Map<OperatorID, ExecutionJobVertex> buildVertices(
             List<OperatorID> operatorIds,
-            int parallelism,
+            int parallelisms,
             SubtaskStateMapper downstreamRescaler,
             SubtaskStateMapper upstreamRescaler)
             throws JobException, JobExecutionException {
-        final JobVertex[] jobVertices =
+        List<OperatorIdWithParallelism> opIdsWithParallelism =
                 operatorIds.stream()
-                        .map(id -> createJobVertex(id, id, parallelism))
+                        .map(operatorID -> new OperatorIdWithParallelism(operatorID, parallelisms))
+                        .collect(Collectors.toList());
+        return buildVertices(opIdsWithParallelism, downstreamRescaler, upstreamRescaler);
+    }
+
+    private Map<OperatorID, ExecutionJobVertex> buildVertices(
+            List<OperatorIdWithParallelism> operatorIdsAndParallelism,
+            SubtaskStateMapper downstreamRescaler,
+            SubtaskStateMapper upstreamRescaler)
+            throws JobException, JobExecutionException {
+        final JobVertex[] jobVertices =
+                operatorIdsAndParallelism.stream()
+                        .map(
+                                idWithParallelism ->
+                                        createJobVertex(
+                                                idWithParallelism.getOperatorID(),
+                                                idWithParallelism.getOperatorID(),
+                                                idWithParallelism.getParallelism()))
                         .toArray(JobVertex[]::new);
         for (int index = 1; index < jobVertices.length; index++) {
             connectVertices(
@@ -1029,6 +1192,15 @@ public class StateAssignmentOperationTest extends TestLogger {
         return jobVertex;
     }
 
+    private List<TaskStateSnapshot> getTaskStateSnapshotFromVertex(
+            ExecutionJobVertex executionJobVertex) {
+        return Arrays.stream(executionJobVertex.getTaskVertices())
+                .map(ExecutionVertex::getCurrentExecutionAttempt)
+                .map(Execution::getTaskRestore)
+                .map(JobManagerTaskRestore::getTaskStateSnapshot)
+                .collect(Collectors.toList());
+    }
+
     private OperatorSubtaskState getAssignedState(
             ExecutionJobVertex executionJobVertex, OperatorID operatorId, int subtaskIdx) {
         return executionJobVertex
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java
index 5a6efd174e1..dc1b21f07b4 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java
@@ -113,7 +113,8 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase {
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
                 final int parallelism = env.getParallelism();
                 final SingleOutputStreamOperator<Long> stream =
                         env.fromSource(
@@ -121,7 +122,8 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase {
                                                 minCheckpoints,
                                                 parallelism,
                                                 expectedRestarts,
-                                                env.getCheckpointInterval()),
+                                                env.getCheckpointInterval(),
+                                                sourceSleepMs),
                                         noWatermarks(),
                                         "source")
                                 .slotSharingGroup(slotSharing ? "default" : "source")
@@ -144,7 +146,8 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase {
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
                 final int parallelism = env.getParallelism();
                 DataStream<Long> combinedSource = null;
                 for (int inputIndex = 0; inputIndex < NUM_SOURCES; inputIndex++) {
@@ -154,7 +157,8 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase {
                                                     minCheckpoints,
                                                     parallelism,
                                                     expectedRestarts,
-                                                    env.getCheckpointInterval()),
+                                                    env.getCheckpointInterval(),
+                                                    sourceSleepMs),
                                             noWatermarks(),
                                             "source" + inputIndex)
                                     .slotSharingGroup(
@@ -182,7 +186,8 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase {
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
                 final int parallelism = env.getParallelism();
                 DataStream<Tuple2<Integer, Long>> combinedSource = null;
                 for (int inputIndex = 0; inputIndex < NUM_SOURCES; inputIndex++) {
@@ -193,7 +198,8 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase {
                                                     minCheckpoints,
                                                     parallelism,
                                                     expectedRestarts,
-                                                    env.getCheckpointInterval()),
+                                                    env.getCheckpointInterval(),
+                                                    sourceSleepMs),
                                             noWatermarks(),
                                             "source" + inputIndex)
                                     .slotSharingGroup(
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
index 4216cc5469b..1f5d4c73afc 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
@@ -68,6 +68,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
     private final int oldParallelism;
     private final int newParallelism;
     private final int buffersPerChannel;
+    private final long sourceSleepMs;
 
     enum Topology implements DagCreator {
         PIPELINE {
@@ -76,7 +77,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMillis) {
                 final int parallelism = env.getParallelism();
                 final DataStream<Long> source =
                         createSourcePipeline(
@@ -86,6 +88,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                 expectedRestarts,
                                 parallelism,
                                 0,
+                                sourceSleepMillis,
                                 val -> true);
                 addFailingSink(source, minCheckpoints, slotSharing);
             }
@@ -97,7 +100,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
                 DataStream<Long> combinedSource = null;
@@ -111,6 +115,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                     expectedRestarts,
                                     parallelism,
                                     inputIndex,
+                                    sourceSleepMs,
                                     val -> withoutHeader(val) % NUM_SOURCES == finalInputIndex);
                     combinedSource =
                             combinedSource == null
@@ -134,10 +139,10 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
-                checkState(parallelism >= 4);
                 final DataStream<Long> source1 =
                         createSourcePipeline(
                                 env,
@@ -146,6 +151,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                 expectedRestarts,
                                 parallelism / 2,
                                 0,
+                                sourceSleepMs,
                                 val -> withoutHeader(val) % 2 == 0);
                 final DataStream<Long> source2 =
                         createSourcePipeline(
@@ -155,6 +161,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                 expectedRestarts,
                                 parallelism / 3,
                                 1,
+                                sourceSleepMs,
                                 val -> withoutHeader(val) % 2 == 1);
 
                 KeySelector<Long, Long> keySelector = i -> withoutHeader(i) % NUM_GROUPS;
@@ -174,7 +181,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
                 DataStream<Long> combinedSource = null;
@@ -188,6 +196,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                     expectedRestarts,
                                     parallelism,
                                     inputIndex,
+                                    sourceSleepMs,
                                     val -> withoutHeader(val) % NUM_SOURCES == finalInputIndex);
                     combinedSource = combinedSource == null ? source : combinedSource.union(source);
                 }
@@ -202,7 +211,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
                 final DataStream<Long> broadcastSide =
@@ -211,7 +221,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                         minCheckpoints,
                                         parallelism,
                                         expectedRestarts,
-                                        env.getCheckpointInterval()),
+                                        env.getCheckpointInterval(),
+                                        sourceSleepMs),
                                 noWatermarks(),
                                 "source");
                 final DataStream<Long> source =
@@ -222,6 +233,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                         expectedRestarts,
                                         parallelism,
                                         0,
+                                        sourceSleepMs,
                                         val -> true)
                                 .map(i -> checkHeader(i))
                                 .name("map")
@@ -249,7 +261,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
                 final DataStream<Long> broadcastSide1 =
@@ -258,7 +271,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                                 minCheckpoints,
                                                 1,
                                                 expectedRestarts,
-                                                env.getCheckpointInterval()),
+                                                env.getCheckpointInterval(),
+                                                sourceSleepMs),
                                         noWatermarks(),
                                         "source-1")
                                 .setParallelism(1);
@@ -268,7 +282,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                                 minCheckpoints,
                                                 1,
                                                 expectedRestarts,
-                                                env.getCheckpointInterval()),
+                                                env.getCheckpointInterval(),
+                                                sourceSleepMs),
                                         noWatermarks(),
                                         "source-2")
                                 .setParallelism(1);
@@ -278,7 +293,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                                 minCheckpoints,
                                                 1,
                                                 expectedRestarts,
-                                                env.getCheckpointInterval()),
+                                                env.getCheckpointInterval(),
+                                                sourceSleepMs),
                                         noWatermarks(),
                                         "source-3")
                                 .setParallelism(1);
@@ -290,6 +306,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                         expectedRestarts,
                                         parallelism,
                                         0,
+                                        sourceSleepMs,
                                         val -> true)
                                 .map(i -> checkHeader(i))
                                 .name("map")
@@ -349,13 +366,15 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                 int expectedRestarts,
                 int parallelism,
                 int inputIndex,
+                long sourceSleepMs,
                 FilterFunction<Long> sourceFilter) {
             return env.fromSource(
                             new LongSource(
                                     minCheckpoints,
                                     parallelism,
                                     expectedRestarts,
-                                    env.getCheckpointInterval()),
+                                    env.getCheckpointInterval(),
+                                    sourceSleepMs),
                             noWatermarks(),
                             "source" + inputIndex)
                     .uid("source" + inputIndex)
@@ -459,46 +478,61 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
         }
     }
 
-    @Parameterized.Parameters(name = "{0} {1} from {2} to {3}, buffersPerChannel = {4}")
+    @Parameterized.Parameters(
+            name = "{0} {1} from {2} to {3}, sourceSleepMs = {4}, buffersPerChannel = {5}")
     public static Object[][] getScaleFactors() {
+        // We use `sourceSleepMs` > 0 to test rescaling without backpressure and only very few
+        // captured in-flight records, see FLINK-31963.
         Object[][] parameters =
                 new Object[][] {
-                    new Object[] {"downscale", Topology.KEYED_DIFFERENT_PARALLELISM, 12, 7},
-                    new Object[] {"upscale", Topology.KEYED_DIFFERENT_PARALLELISM, 7, 12},
-                    new Object[] {"downscale", Topology.KEYED_BROADCAST, 7, 2},
-                    new Object[] {"upscale", Topology.KEYED_BROADCAST, 2, 7},
-                    new Object[] {"downscale", Topology.BROADCAST, 5, 2},
-                    new Object[] {"upscale", Topology.BROADCAST, 2, 5},
-                    new Object[] {"upscale", Topology.PIPELINE, 1, 2},
-                    new Object[] {"upscale", Topology.PIPELINE, 2, 3},
-                    new Object[] {"upscale", Topology.PIPELINE, 3, 7},
-                    new Object[] {"upscale", Topology.PIPELINE, 4, 8},
-                    new Object[] {"upscale", Topology.PIPELINE, 20, 21},
-                    new Object[] {"downscale", Topology.PIPELINE, 2, 1},
-                    new Object[] {"downscale", Topology.PIPELINE, 3, 2},
-                    new Object[] {"downscale", Topology.PIPELINE, 7, 3},
-                    new Object[] {"downscale", Topology.PIPELINE, 8, 4},
-                    new Object[] {"downscale", Topology.PIPELINE, 21, 20},
-                    new Object[] {"no scale", Topology.PIPELINE, 1, 1},
-                    new Object[] {"no scale", Topology.PIPELINE, 3, 3},
-                    new Object[] {"no scale", Topology.PIPELINE, 7, 7},
-                    new Object[] {"no scale", Topology.PIPELINE, 20, 20},
-                    new Object[] {"upscale", Topology.UNION, 1, 2},
-                    new Object[] {"upscale", Topology.UNION, 2, 3},
-                    new Object[] {"upscale", Topology.UNION, 3, 7},
-                    new Object[] {"downscale", Topology.UNION, 2, 1},
-                    new Object[] {"downscale", Topology.UNION, 3, 2},
-                    new Object[] {"downscale", Topology.UNION, 7, 3},
-                    new Object[] {"no scale", Topology.UNION, 1, 1},
-                    new Object[] {"no scale", Topology.UNION, 7, 7},
-                    new Object[] {"upscale", Topology.MULTI_INPUT, 1, 2},
-                    new Object[] {"upscale", Topology.MULTI_INPUT, 2, 3},
-                    new Object[] {"upscale", Topology.MULTI_INPUT, 3, 7},
-                    new Object[] {"downscale", Topology.MULTI_INPUT, 2, 1},
-                    new Object[] {"downscale", Topology.MULTI_INPUT, 3, 2},
-                    new Object[] {"downscale", Topology.MULTI_INPUT, 7, 3},
-                    new Object[] {"no scale", Topology.MULTI_INPUT, 1, 1},
-                    new Object[] {"no scale", Topology.MULTI_INPUT, 7, 7},
+                    new Object[] {"downscale", Topology.KEYED_DIFFERENT_PARALLELISM, 12, 7, 0L},
+                    new Object[] {"upscale", Topology.KEYED_DIFFERENT_PARALLELISM, 7, 12, 0L},
+                    new Object[] {"downscale", Topology.KEYED_DIFFERENT_PARALLELISM, 5, 3, 5L},
+                    new Object[] {"upscale", Topology.KEYED_DIFFERENT_PARALLELISM, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.KEYED_BROADCAST, 7, 2, 0L},
+                    new Object[] {"upscale", Topology.KEYED_BROADCAST, 2, 7, 0L},
+                    new Object[] {"downscale", Topology.KEYED_BROADCAST, 5, 3, 5L},
+                    new Object[] {"upscale", Topology.KEYED_BROADCAST, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.BROADCAST, 5, 2, 0L},
+                    new Object[] {"upscale", Topology.BROADCAST, 2, 5, 0L},
+                    new Object[] {"downscale", Topology.BROADCAST, 5, 3, 5L},
+                    new Object[] {"upscale", Topology.BROADCAST, 3, 5, 5L},
+                    new Object[] {"upscale", Topology.PIPELINE, 1, 2, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 2, 3, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 3, 7, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 4, 8, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 20, 21, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.PIPELINE, 2, 1, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 3, 2, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 7, 3, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 8, 4, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 21, 20, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 5, 3, 5L},
+                    new Object[] {"no scale", Topology.PIPELINE, 1, 1, 0L},
+                    new Object[] {"no scale", Topology.PIPELINE, 3, 3, 0L},
+                    new Object[] {"no scale", Topology.PIPELINE, 7, 7, 0L},
+                    new Object[] {"no scale", Topology.PIPELINE, 20, 20, 0L},
+                    new Object[] {"upscale", Topology.UNION, 1, 2, 0L},
+                    new Object[] {"upscale", Topology.UNION, 2, 3, 0L},
+                    new Object[] {"upscale", Topology.UNION, 3, 7, 0L},
+                    new Object[] {"upscale", Topology.UNION, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.UNION, 2, 1, 0L},
+                    new Object[] {"downscale", Topology.UNION, 3, 2, 0L},
+                    new Object[] {"downscale", Topology.UNION, 7, 3, 0L},
+                    new Object[] {"downscale", Topology.UNION, 5, 3, 5L},
+                    new Object[] {"no scale", Topology.UNION, 1, 1, 0L},
+                    new Object[] {"no scale", Topology.UNION, 7, 7, 0L},
+                    new Object[] {"upscale", Topology.MULTI_INPUT, 1, 2, 0L},
+                    new Object[] {"upscale", Topology.MULTI_INPUT, 2, 3, 0L},
+                    new Object[] {"upscale", Topology.MULTI_INPUT, 3, 7, 0L},
+                    new Object[] {"upscale", Topology.MULTI_INPUT, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.MULTI_INPUT, 2, 1, 0L},
+                    new Object[] {"downscale", Topology.MULTI_INPUT, 3, 2, 0L},
+                    new Object[] {"downscale", Topology.MULTI_INPUT, 7, 3, 0L},
+                    new Object[] {"downscale", Topology.MULTI_INPUT, 5, 3, 5L},
+                    new Object[] {"no scale", Topology.MULTI_INPUT, 1, 1, 0L},
+                    new Object[] {"no scale", Topology.MULTI_INPUT, 7, 7, 0L},
                 };
         return Arrays.stream(parameters)
                 .map(
@@ -516,10 +550,12 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
             Topology topology,
             int oldParallelism,
             int newParallelism,
+            long sourceSleepMs,
             int buffersPerChannel) {
         this.topology = topology;
         this.oldParallelism = oldParallelism;
         this.newParallelism = newParallelism;
+        this.sourceSleepMs = sourceSleepMs;
         this.buffersPerChannel = buffersPerChannel;
     }
 
@@ -529,7 +565,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                 new UnalignedSettings(topology)
                         .setParallelism(oldParallelism)
                         .setExpectedFailures(1)
-                        .setBuffersPerChannel(buffersPerChannel);
+                        .setBuffersPerChannel(buffersPerChannel)
+                        .setSourceSleepMs(sourceSleepMs);
         prescaleSettings.setGenerateCheckpoint(true);
         final File checkpointDir = super.execute(prescaleSettings);
 
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
index 3420cf6f090..7ae3ad63146 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
@@ -207,7 +207,8 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
                 setupEnv,
                 settings.minCheckpoints,
                 settings.channelType.slotSharing,
-                settings.expectedFailures - settings.failuresAfterSourceFinishes);
+                settings.expectedFailures - settings.failuresAfterSourceFinishes,
+                settings.sourceSleepMs);
 
         return setupEnv.getStreamGraph();
     }
@@ -221,16 +222,19 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
         private final int numSplits;
         private final int expectedRestarts;
         private final long checkpointingInterval;
+        private final long sourceSleepMs;
 
         protected LongSource(
                 int minCheckpoints,
                 int numSplits,
                 int expectedRestarts,
-                long checkpointingInterval) {
+                long checkpointingInterval,
+                long sourceSleepMs) {
             this.minCheckpoints = minCheckpoints;
             this.numSplits = numSplits;
             this.expectedRestarts = expectedRestarts;
             this.checkpointingInterval = checkpointingInterval;
+            this.sourceSleepMs = sourceSleepMs;
         }
 
         @Override
@@ -244,7 +248,8 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
                     readerContext.getIndexOfSubtask(),
                     minCheckpoints,
                     expectedRestarts,
-                    checkpointingInterval);
+                    checkpointingInterval,
+                    sourceSleepMs);
         }
 
         @Override
@@ -285,17 +290,20 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
             private int numCompletedCheckpoints;
             private boolean finishing;
             private boolean recovered;
+            private final long sourceSleepMs;
             @Nullable private Deadline pumpingUntil = null;
 
             public LongSourceReader(
                     int subtaskIndex,
                     int minCheckpoints,
                     int expectedRestarts,
-                    long checkpointingInterval) {
+                    long checkpointingInterval,
+                    long sourceSleepMs) {
                 this.subtaskIndex = subtaskIndex;
                 this.minCheckpoints = minCheckpoints;
                 this.expectedRestarts = expectedRestarts;
-                pumpInterval = Duration.ofMillis(checkpointingInterval);
+                this.pumpInterval = Duration.ofMillis(checkpointingInterval);
+                this.sourceSleepMs = sourceSleepMs;
             }
 
             @Override
@@ -304,6 +312,9 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
             @Override
             public InputStatus pollNext(ReaderOutput<Long> output) throws InterruptedException {
                 for (LongSplit split : splits) {
+                    if (sourceSleepMs > 0L) {
+                        Thread.sleep(sourceSleepMs);
+                    }
                     output.collect(withHeader(split.nextNumber), split.nextNumber);
                     split.nextNumber += split.increment;
                 }
@@ -627,7 +638,8 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
                 StreamExecutionEnvironment environment,
                 int minCheckpoints,
                 boolean slotSharing,
-                int expectedFailuresUntilSourceFinishes);
+                int expectedFailuresUntilSourceFinishes,
+                long sourceSleepMs);
     }
 
     /** Which channels are used to connect the tasks. */
@@ -664,6 +676,7 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
         private int failuresAfterSourceFinishes = 0;
         private ChannelType channelType = ChannelType.MIXED;
         private int buffersPerChannel = 1;
+        private long sourceSleepMs = 0;
 
         public UnalignedSettings(DagCreator dagCreator) {
             this.dagCreator = dagCreator;
@@ -719,6 +732,11 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
             return this;
         }
 
+        public UnalignedSettings setSourceSleepMs(long sourceSleepMs) {
+            this.sourceSleepMs = sourceSleepMs;
+            return this;
+        }
+
         public void configure(StreamExecutionEnvironment env) {
             env.enableCheckpointing(Math.max(100L, parallelism * 50L));
             env.getCheckpointConfig().setAlignmentTimeout(Duration.ofMillis(alignmentTimeout));
@@ -791,6 +809,8 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
                     + failuresAfterSourceFinishes
                     + ", channelType="
                     + channelType
+                    + ", sourceSleepMs="
+                    + sourceSleepMs
                     + '}';
         }
     }

