This is an automated email from the ASF dual-hosted git repository. pnowojski pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 691ce18e4843d29bc502483beb61d466ed9c7031 Author: Dmitriy Linevich <[email protected]> AuthorDate: Tue May 14 16:30:03 2024 +0700 [FLINK-35351][checkpoint] Fix fail during restore from unaligned checkpoint with custom partitioner Co-authored-by: Andrey Gaskov <[email protected]> --- .../checkpoint/StateAssignmentOperation.java | 28 +++++- .../checkpoint/StateAssignmentOperationTest.java | 39 +++++++++ .../runtime/checkpoint/StateHandleDummyUtil.java | 2 +- .../UnalignedCheckpointRescaleITCase.java | 99 +++++++++++++++++++++- .../checkpointing/UnalignedCheckpointTestBase.java | 2 +- 5 files changed, 165 insertions(+), 5 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java index 3517277c255..07088d901f6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java @@ -27,8 +27,10 @@ import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo; import org.apache.flink.runtime.executiongraph.Execution; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.IntermediateResult; +import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper; import org.apache.flink.runtime.jobgraph.IntermediateDataSet; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.JobEdge; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.OperatorInstanceID; import org.apache.flink.runtime.state.AbstractChannelStateHandle; @@ -421,7 +423,31 @@ public class StateAssignmentOperation { stateAssignment.oldState.get(stateAssignment.inputOperatorID); final List<List<InputChannelStateHandle>> inputOperatorState = splitBySubtasks(inputState, OperatorSubtaskState::getInputChannelState); - if (inputState.getParallelism() == executionJobVertex.getParallelism()) { + + boolean hasAnyFullMapper = + executionJobVertex.getJobVertex().getInputs().stream() + .map(JobEdge::getDownstreamSubtaskStateMapper) + .anyMatch(m -> m.equals(SubtaskStateMapper.FULL)); + boolean hasAnyPreviousOperatorChanged = + executionJobVertex.getInputs().stream() + .map(IntermediateResult::getProducer) + .map(vertexAssignments::get) + .anyMatch( + taskStateAssignment -> { + final int oldParallelism = + stateAssignment + .oldState + .get(stateAssignment.inputOperatorID) + .getParallelism(); + return oldParallelism + != taskStateAssignment.executionJobVertex + .getParallelism(); + }); + + // need rescale if any input operator parallelism was changed and have any input with FULL + // subtask state mapper + if (inputState.getParallelism() == executionJobVertex.getParallelism() + && !(hasAnyFullMapper && hasAnyPreviousOperatorChanged)) { stateAssignment.inputChannelStates.putAll( toInstanceMap(stateAssignment.inputOperatorID, inputOperatorState)); return; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java index 95dd2555da6..e1a3e399c82 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java @@ -82,6 +82,7 @@ import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNew import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewOperatorStateHandle; import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewResultSubpartitionStateHandle; import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ARBITRARY; +import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.FULL; import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.RANGE; import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ROUND_ROBIN; import static org.apache.flink.util.Preconditions.checkArgument; @@ -561,6 +562,44 @@ class StateAssignmentOperationTest { oldIndices, rescaleMapping, ambiguousSubtaskIndexes, mappingType); } + @Test + void testChannelStateAssignmentTwoGatesPartiallyDownscaling() + throws JobException, JobExecutionException { + JobVertex upstream1 = createJobVertex(new OperatorID(), 2); + JobVertex upstream2 = createJobVertex(new OperatorID(), 2); + JobVertex downstream = createJobVertex(new OperatorID(), 3); + List<OperatorID> operatorIds = + Stream.of(upstream1, upstream2, downstream) + .map(v -> v.getOperatorIDs().get(0).getGeneratedOperatorID()) + .collect(Collectors.toList()); + Map<OperatorID, OperatorState> states = buildOperatorStates(operatorIds, 3); + + connectVertices(upstream1, downstream, ARBITRARY, FULL); + connectVertices(upstream2, downstream, ROUND_ROBIN, ROUND_ROBIN); + + Map<OperatorID, ExecutionJobVertex> vertices = + toExecutionVertices(upstream1, upstream2, downstream); + + new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + .assignStates(); + + assertThat( + getAssignedState(vertices.get(operatorIds.get(2)), operatorIds.get(2), 0) + .getInputChannelState() + .size()) + .isEqualTo(6); + assertThat( + getAssignedState(vertices.get(operatorIds.get(2)), operatorIds.get(2), 1) + .getInputChannelState() + .size()) + .isEqualTo(6); + assertThat( + getAssignedState(vertices.get(operatorIds.get(2)), operatorIds.get(2), 2) + .getInputChannelState() + .size()) + .isEqualTo(6); + } + @Test void testChannelStateAssignmentDownscaling() throws JobException, JobExecutionException { List<OperatorID> operatorIds = buildOperatorIds(2); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java index 52a8bf032b6..60ec98e7e0f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java @@ -145,7 +145,7 @@ public class StateHandleDummyUtil { public static InputChannelStateHandle createNewInputChannelStateHandle( int numNamedStates, Random random) { return new InputChannelStateHandle( - new InputChannelInfo(random.nextInt(), random.nextInt()), + new InputChannelInfo(0, random.nextInt()), createStreamStateHandle(numNamedStates, random), genOffsets(numNamedStates, random)); } diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java index fa5b61a8910..bb217b5a8cd 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.JobExecutionResult; import org.apache.flink.api.common.accumulators.LongCounter; import org.apache.flink.api.common.functions.FilterFunction; import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.Partitioner; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; @@ -44,6 +45,7 @@ import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction; import org.apache.flink.streaming.api.functions.co.CoMapFunction; import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction; import org.apache.flink.streaming.api.functions.co.KeyedCoProcessFunction; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; import org.apache.flink.util.Collector; import org.apache.commons.lang3.ArrayUtils; @@ -52,8 +54,10 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import java.io.File; +import java.util.ArrayList; import java.util.Arrays; import java.util.BitSet; +import java.util.Collection; import java.util.Collections; import static org.apache.flink.api.common.eventtime.WatermarkStrategy.noWatermarks; @@ -329,6 +333,44 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas addFailingSink(joined, minCheckpoints, slotSharing); } + }, + CUSTOM_PARTITIONER { + final int sinkParallelism = 3; + final int numberElements = 1000; + + @Override + public void create( + StreamExecutionEnvironment environment, + int minCheckpoints, + boolean slotSharing, + int expectedFailuresUntilSourceFinishes, + long sourceSleepMs) { + int parallelism = environment.getParallelism(); + environment + .fromData(generateStrings(numberElements / parallelism, sinkParallelism)) + .name("source") + .setParallelism(parallelism) + .partitionCustom(new StringPartitioner(), str -> str.split(" ")[0]) + .addSink(new StringSink(numberElements / sinkParallelism)) + .name("sink") + .setParallelism(sinkParallelism); + } + + private Collection<String> generateStrings( + int producePerPartition, int partitionCount) { + Collection<String> list = new ArrayList<>(); + for (int i = 0; i < producePerPartition; i++) { + for (int partition = 0; partition < partitionCount; partition++) { + list.add(buildString(partition, i)); + } + } + return list; + } + + private String buildString(int partition, int index) { + String longStr = new String(new char[3713]).replace('\0', '\uFFFF'); + return partition + " " + index + " " + longStr; + } }; void addFailingSink( @@ -485,6 +527,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas // captured in-flight records, see FLINK-31963. Object[][] parameters = new Object[][] { + new Object[] {"downscale", Topology.CUSTOM_PARTITIONER, 3, 2, 0L}, new Object[] {"downscale", Topology.KEYED_DIFFERENT_PARALLELISM, 12, 7, 0L}, new Object[] {"upscale", Topology.KEYED_DIFFERENT_PARALLELISM, 7, 12, 0L}, new Object[] {"downscale", Topology.KEYED_DIFFERENT_PARALLELISM, 5, 3, 5L}, @@ -561,6 +604,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas @Test public void shouldRescaleUnalignedCheckpoint() throws Exception { + StringSink.failed = false; final UnalignedSettings prescaleSettings = new UnalignedSettings(topology) .setParallelism(oldParallelism) @@ -585,8 +629,12 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas "NUM_OUTPUTS = NUM_INPUTS", result.<Long>getAccumulatorResult(NUM_OUTPUTS), equalTo(result.getAccumulatorResult(NUM_INPUTS))); - collector.checkThat( - "NUM_DUPLICATES", result.<Long>getAccumulatorResult(NUM_DUPLICATES), equalTo(0L)); + if (!topology.equals(Topology.CUSTOM_PARTITIONER)) { + collector.checkThat( + "NUM_DUPLICATES", + result.<Long>getAccumulatorResult(NUM_DUPLICATES), + equalTo(0L)); + } } /** @@ -705,4 +753,51 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas return checkHeader(value); } } + + private static class StringPartitioner implements Partitioner<String> { + @Override + public int partition(String key, int numPartitions) { + return Integer.parseInt(key) % numPartitions; + } + } + + private static class StringSink implements SinkFunction<String>, CheckpointedFunction { + + static volatile boolean failed = false; + + int checkpointConsumed = 0; + + int recordsConsumed = 0; + + final int numberElements; + + public StringSink(int numberElements) { + this.numberElements = numberElements; + } + + @Override + public void invoke(String value, Context ctx) throws Exception { + if (!failed && checkpointConsumed > 1) { + failed = true; + throw new TestException("FAIL"); + } + recordsConsumed++; + if (!failed && recordsConsumed == (numberElements / 3)) { + Thread.sleep(1000); + } + if (recordsConsumed == (numberElements - 1)) { + Thread.sleep(1000); + } + } + + @Override + public void snapshotState(FunctionSnapshotContext context) { + checkpointConsumed++; + } + + @Override + public void initializeState(FunctionInitializationContext context) { + // do nothing + } + } } diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java index 188bc9fdbc3..a936f2171b7 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java @@ -1133,7 +1133,7 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger { return value; } - private static class TestException extends Exception { + static class TestException extends Exception { public TestException(String s) { super(s); }
