This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 691ce18e4843d29bc502483beb61d466ed9c7031
Author: Dmitriy Linevich <[email protected]>
AuthorDate: Tue May 14 16:30:03 2024 +0700

    [FLINK-35351][checkpoint] Fix fail during restore from unaligned checkpoint 
with custom partitioner
    
    Co-authored-by: Andrey Gaskov <[email protected]>
---
 .../checkpoint/StateAssignmentOperation.java       | 28 +++++-
 .../checkpoint/StateAssignmentOperationTest.java   | 39 +++++++++
 .../runtime/checkpoint/StateHandleDummyUtil.java   |  2 +-
 .../UnalignedCheckpointRescaleITCase.java          | 99 +++++++++++++++++++++-
 .../checkpointing/UnalignedCheckpointTestBase.java |  2 +-
 5 files changed, 165 insertions(+), 5 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 3517277c255..07088d901f6 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -27,8 +27,10 @@ import 
org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.IntermediateResult;
+import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.JobEdge;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
 import org.apache.flink.runtime.state.AbstractChannelStateHandle;
@@ -421,7 +423,31 @@ public class StateAssignmentOperation {
                 stateAssignment.oldState.get(stateAssignment.inputOperatorID);
         final List<List<InputChannelStateHandle>> inputOperatorState =
                 splitBySubtasks(inputState, 
OperatorSubtaskState::getInputChannelState);
-        if (inputState.getParallelism() == 
executionJobVertex.getParallelism()) {
+
+        boolean hasAnyFullMapper =
+                executionJobVertex.getJobVertex().getInputs().stream()
+                        .map(JobEdge::getDownstreamSubtaskStateMapper)
+                        .anyMatch(m -> m.equals(SubtaskStateMapper.FULL));
+        boolean hasAnyPreviousOperatorChanged =
+                executionJobVertex.getInputs().stream()
+                        .map(IntermediateResult::getProducer)
+                        .map(vertexAssignments::get)
+                        .anyMatch(
+                                taskStateAssignment -> {
+                                    final int oldParallelism =
+                                            stateAssignment
+                                                    .oldState
+                                                    
.get(stateAssignment.inputOperatorID)
+                                                    .getParallelism();
+                                    return oldParallelism
+                                            != 
taskStateAssignment.executionJobVertex
+                                                    .getParallelism();
+                                });
+
+        // need to rescale if any input operator's parallelism was changed and
+        // there is any input with a FULL subtask state mapper
+        if (inputState.getParallelism() == executionJobVertex.getParallelism()
+                && !(hasAnyFullMapper && hasAnyPreviousOperatorChanged)) {
             stateAssignment.inputChannelStates.putAll(
                     toInstanceMap(stateAssignment.inputOperatorID, 
inputOperatorState));
             return;
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
index 95dd2555da6..e1a3e399c82 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
@@ -82,6 +82,7 @@ import static 
org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNew
 import static 
org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewOperatorStateHandle;
 import static 
org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewResultSubpartitionStateHandle;
 import static 
org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ARBITRARY;
+import static 
org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.FULL;
 import static 
org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.RANGE;
 import static 
org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ROUND_ROBIN;
 import static org.apache.flink.util.Preconditions.checkArgument;
@@ -561,6 +562,44 @@ class StateAssignmentOperationTest {
                 oldIndices, rescaleMapping, ambiguousSubtaskIndexes, 
mappingType);
     }
 
+    @Test
+    void testChannelStateAssignmentTwoGatesPartiallyDownscaling()
+            throws JobException, JobExecutionException {
+        JobVertex upstream1 = createJobVertex(new OperatorID(), 2);
+        JobVertex upstream2 = createJobVertex(new OperatorID(), 2);
+        JobVertex downstream = createJobVertex(new OperatorID(), 3);
+        List<OperatorID> operatorIds =
+                Stream.of(upstream1, upstream2, downstream)
+                        .map(v -> 
v.getOperatorIDs().get(0).getGeneratedOperatorID())
+                        .collect(Collectors.toList());
+        Map<OperatorID, OperatorState> states = 
buildOperatorStates(operatorIds, 3);
+
+        connectVertices(upstream1, downstream, ARBITRARY, FULL);
+        connectVertices(upstream2, downstream, ROUND_ROBIN, ROUND_ROBIN);
+
+        Map<OperatorID, ExecutionJobVertex> vertices =
+                toExecutionVertices(upstream1, upstream2, downstream);
+
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), 
states, false)
+                .assignStates();
+
+        assertThat(
+                        getAssignedState(vertices.get(operatorIds.get(2)), 
operatorIds.get(2), 0)
+                                .getInputChannelState()
+                                .size())
+                .isEqualTo(6);
+        assertThat(
+                        getAssignedState(vertices.get(operatorIds.get(2)), 
operatorIds.get(2), 1)
+                                .getInputChannelState()
+                                .size())
+                .isEqualTo(6);
+        assertThat(
+                        getAssignedState(vertices.get(operatorIds.get(2)), 
operatorIds.get(2), 2)
+                                .getInputChannelState()
+                                .size())
+                .isEqualTo(6);
+    }
+
     @Test
     void testChannelStateAssignmentDownscaling() throws JobException, 
JobExecutionException {
         List<OperatorID> operatorIds = buildOperatorIds(2);
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
index 52a8bf032b6..60ec98e7e0f 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
@@ -145,7 +145,7 @@ public class StateHandleDummyUtil {
     public static InputChannelStateHandle createNewInputChannelStateHandle(
             int numNamedStates, Random random) {
         return new InputChannelStateHandle(
-                new InputChannelInfo(random.nextInt(), random.nextInt()),
+                new InputChannelInfo(0, random.nextInt()),
                 createStreamStateHandle(numNamedStates, random),
                 genOffsets(numNamedStates, random));
     }
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
index fa5b61a8910..bb217b5a8cd 100644
--- 
a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.JobExecutionResult;
 import org.apache.flink.api.common.accumulators.LongCounter;
 import org.apache.flink.api.common.functions.FilterFunction;
 import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.api.common.functions.Partitioner;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
@@ -44,6 +45,7 @@ import 
org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
 import org.apache.flink.streaming.api.functions.co.CoMapFunction;
 import 
org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction;
 import org.apache.flink.streaming.api.functions.co.KeyedCoProcessFunction;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.util.Collector;
 
 import org.apache.commons.lang3.ArrayUtils;
@@ -52,8 +54,10 @@ import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
 import java.io.File;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.BitSet;
+import java.util.Collection;
 import java.util.Collections;
 
 import static 
org.apache.flink.api.common.eventtime.WatermarkStrategy.noWatermarks;
@@ -329,6 +333,44 @@ public class UnalignedCheckpointRescaleITCase extends 
UnalignedCheckpointTestBas
 
                 addFailingSink(joined, minCheckpoints, slotSharing);
             }
+        },
+        CUSTOM_PARTITIONER {
+            final int sinkParallelism = 3;
+            final int numberElements = 1000;
+
+            @Override
+            public void create(
+                    StreamExecutionEnvironment environment,
+                    int minCheckpoints,
+                    boolean slotSharing,
+                    int expectedFailuresUntilSourceFinishes,
+                    long sourceSleepMs) {
+                int parallelism = environment.getParallelism();
+                environment
+                        .fromData(generateStrings(numberElements / 
parallelism, sinkParallelism))
+                        .name("source")
+                        .setParallelism(parallelism)
+                        .partitionCustom(new StringPartitioner(), str -> 
str.split(" ")[0])
+                        .addSink(new StringSink(numberElements / 
sinkParallelism))
+                        .name("sink")
+                        .setParallelism(sinkParallelism);
+            }
+
+            private Collection<String> generateStrings(
+                    int producePerPartition, int partitionCount) {
+                Collection<String> list = new ArrayList<>();
+                for (int i = 0; i < producePerPartition; i++) {
+                    for (int partition = 0; partition < partitionCount; 
partition++) {
+                        list.add(buildString(partition, i));
+                    }
+                }
+                return list;
+            }
+
+            private String buildString(int partition, int index) {
+                String longStr = new String(new char[3713]).replace('\0', 
'\uFFFF');
+                return partition + " " + index + " " + longStr;
+            }
         };
 
         void addFailingSink(
@@ -485,6 +527,7 @@ public class UnalignedCheckpointRescaleITCase extends 
UnalignedCheckpointTestBas
         // captured in-flight records, see FLINK-31963.
         Object[][] parameters =
                 new Object[][] {
+                    new Object[] {"downscale", Topology.CUSTOM_PARTITIONER, 3, 
2, 0L},
                     new Object[] {"downscale", 
Topology.KEYED_DIFFERENT_PARALLELISM, 12, 7, 0L},
                     new Object[] {"upscale", 
Topology.KEYED_DIFFERENT_PARALLELISM, 7, 12, 0L},
                     new Object[] {"downscale", 
Topology.KEYED_DIFFERENT_PARALLELISM, 5, 3, 5L},
@@ -561,6 +604,7 @@ public class UnalignedCheckpointRescaleITCase extends 
UnalignedCheckpointTestBas
 
     @Test
     public void shouldRescaleUnalignedCheckpoint() throws Exception {
+        StringSink.failed = false;
         final UnalignedSettings prescaleSettings =
                 new UnalignedSettings(topology)
                         .setParallelism(oldParallelism)
@@ -585,8 +629,12 @@ public class UnalignedCheckpointRescaleITCase extends 
UnalignedCheckpointTestBas
                 "NUM_OUTPUTS = NUM_INPUTS",
                 result.<Long>getAccumulatorResult(NUM_OUTPUTS),
                 equalTo(result.getAccumulatorResult(NUM_INPUTS)));
-        collector.checkThat(
-                "NUM_DUPLICATES", 
result.<Long>getAccumulatorResult(NUM_DUPLICATES), equalTo(0L));
+        if (!topology.equals(Topology.CUSTOM_PARTITIONER)) {
+            collector.checkThat(
+                    "NUM_DUPLICATES",
+                    result.<Long>getAccumulatorResult(NUM_DUPLICATES),
+                    equalTo(0L));
+        }
     }
 
     /**
@@ -705,4 +753,51 @@ public class UnalignedCheckpointRescaleITCase extends 
UnalignedCheckpointTestBas
             return checkHeader(value);
         }
     }
+
+    private static class StringPartitioner implements Partitioner<String> {
+        @Override
+        public int partition(String key, int numPartitions) {
+            return Integer.parseInt(key) % numPartitions;
+        }
+    }
+
+    private static class StringSink implements SinkFunction<String>, 
CheckpointedFunction {
+
+        static volatile boolean failed = false;
+
+        int checkpointConsumed = 0;
+
+        int recordsConsumed = 0;
+
+        final int numberElements;
+
+        public StringSink(int numberElements) {
+            this.numberElements = numberElements;
+        }
+
+        @Override
+        public void invoke(String value, Context ctx) throws Exception {
+            if (!failed && checkpointConsumed > 1) {
+                failed = true;
+                throw new TestException("FAIL");
+            }
+            recordsConsumed++;
+            if (!failed && recordsConsumed == (numberElements / 3)) {
+                Thread.sleep(1000);
+            }
+            if (recordsConsumed == (numberElements - 1)) {
+                Thread.sleep(1000);
+            }
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext context) {
+            checkpointConsumed++;
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) {
+            // do nothing
+        }
+    }
 }
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
index 188bc9fdbc3..a936f2171b7 100644
--- 
a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
@@ -1133,7 +1133,7 @@ public abstract class UnalignedCheckpointTestBase extends 
TestLogger {
         return value;
     }
 
-    private static class TestException extends Exception {
+    static class TestException extends Exception {
         public TestException(String s) {
             super(s);
         }

Reply via email to