[ https://issues.apache.org/jira/browse/FLINK-35351?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

Piotr Nowojski reassigned FLINK-35351:
--------------------------------------

    Assignee: Dmitriy Linevich

> Restore from unaligned checkpoints with a custom partitioner fails.
> -------------------------------------------------------------------
>
>                 Key: FLINK-35351
>                 URL: https://issues.apache.org/jira/browse/FLINK-35351
>             Project: Flink
>          Issue Type: Bug
>          Components: Runtime / Checkpointing
>            Reporter: Dmitriy Linevich
>            Assignee: Dmitriy Linevich
>            Priority: Major
>
> We encountered a problem when using a custom partitioner with unaligned checkpoints. The bug reproduces with the following steps:
>  # Run a job with the graph Source[2]->Sink[3], with the custom partitioner applied after the Source.
>  # Take a checkpoint.
>  # Restore from the checkpoint with a different Source parallelism: Source[1]->Sink[3].
>  # An exception is thrown.
> The issue does not occur when restoring with the same parallelism or when changing the Sink parallelism; it only occurs when the Source parallelism is changed while the Sink parallelism stays the same.
> See the exception below and the test code at the end.
> {code:java}
> [db13789c52b80aad852c53a0afa26247] Task [Sink: sink (3/3)#0] WARN  Sink: sink (3/3)#0 (be1d158c2e77fc9ed9e3e5d9a8431dc2_0a448493b4782967b150582570326227_2_0) switched from RUNNING to FAILED with failure cause:
> java.io.IOException: Can't get next record for channel InputChannelInfo{gateIdx=0, inputChannelIdx=0}
>     at org.apache.flink.streaming.runtime.io.AbstractStreamTaskNetworkInput.emitNext(AbstractStreamTaskNetworkInput.java:106) ~[classes/:?]
>     at org.apache.flink.streaming.runtime.io.StreamOneInputProcessor.processInput(StreamOneInputProcessor.java:65) ~[classes/:?]
>     at org.apache.flink.streaming.runtime.tasks.StreamTask.processInput(StreamTask.java:600) ~[classes/:?]
>     at org.apache.flink.streaming.runtime.tasks.mailbox.MailboxProcessor.runMailboxLoop(MailboxProcessor.java:231) ~[classes/:?]
>     at org.apache.flink.streaming.runtime.tasks.StreamTask.runMailboxLoop(StreamTask.java:930) ~[classes/:?]
>     at org.apache.flink.streaming.runtime.tasks.StreamTask.invoke(StreamTask.java:879) ~[classes/:?]
>     at org.apache.flink.runtime.taskmanager.Task.runWithSystemExitMonitoring(Task.java:960) ~[classes/:?]
>     at org.apache.flink.runtime.taskmanager.Task.restoreAndInvoke(Task.java:939) [classes/:?]
>     at org.apache.flink.runtime.taskmanager.Task.doRun(Task.java:753) [classes/:?]
>     at org.apache.flink.runtime.taskmanager.Task.run(Task.java:568) [classes/:?]
>     at java.lang.Thread.run(Thread.java:745) [?:1.8.0_121]
> Caused by: java.io.IOException: Corrupt stream, found tag: -1
>     at org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer.deserialize(StreamElementSerializer.java:222) ~[classes/:?]
>     at org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer.deserialize(StreamElementSerializer.java:44) ~[classes/:?]
>     at org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate.read(NonReusingDeserializationDelegate.java:53) ~[classes/:?]
>     at org.apache.flink.runtime.io.network.api.serialization.NonSpanningWrapper.readInto(NonSpanningWrapper.java:337) ~[classes/:?]
>     at org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer.readNonSpanningRecord(SpillingAdaptiveSpanningRecordDeserializer.java:128) ~[classes/:?]
>     at org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer.readNextRecord(SpillingAdaptiveSpanningRecordDeserializer.java:103) ~[classes/:?]
>     at org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer.getNextRecord(SpillingAdaptiveSpanningRecordDeserializer.java:93) ~[classes/:?]
>     at org.apache.flink.streaming.runtime.io.recovery.DemultiplexingRecordDeserializer$VirtualChannel.getNextRecord(DemultiplexingRecordDeserializer.java:79) ~[classes/:?]
>     at org.apache.flink.streaming.runtime.io.recovery.DemultiplexingRecordDeserializer.getNextRecord(DemultiplexingRecordDeserializer.java:154) ~[classes/:?]
>     at org.apache.flink.streaming.runtime.io.recovery.DemultiplexingRecordDeserializer.getNextRecord(DemultiplexingRecordDeserializer.java:54) ~[classes/:?]
>     at org.apache.flink.streaming.runtime.io.AbstractStreamTaskNetworkInput.emitNext(AbstractStreamTaskNetworkInput.java:103) ~[classes/:?]
>     ... 10 more {code}
> We discovered that this issue is caused by an optimization in [StateAssignmentOperation::reDistributeInputChannelStates|https://github.com/apache/flink/blame/master/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java#L424]:
> {code:java}
>         if (inputState.getParallelism() == executionJobVertex.getParallelism()) {
>             stateAssignment.inputChannelStates.putAll(
>                     toInstanceMap(stateAssignment.inputOperatorID, inputOperatorState));
>             return;
>         }
>  {code}
> At the moment of checkpointing, some in-flight records may be only partially transferred: a Sink sub-task's input channel state holds the first half of a record, while the second half is still in a Source sub-task's output state. When the in-flight data is restored, the recovered records are replayed from all Source sub-tasks to all Sink sub-tasks. However, because of the optimization above, the Sink input channel state is assigned one-to-one, so some Sink sub-tasks end up without the beginnings of such in-flight records.
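> For illustration, here is a minimal sketch, not part of the fix, of the FULL mapper semantics involved; it assumes the SubtaskStateMapper#getOldSubtasks(int, int, int) signature from the Flink runtime:
> {code:java}
> import java.util.Arrays;
> import org.apache.flink.runtime.checkpoint.SubtaskStateMapper;
> 
> // Illustrative only: the FULL mapper maps every new subtask to ALL old subtasks,
> // so restored in-flight data is fanned out instead of staying on its original channel.
> public class FullMapperSketch {
>     public static void main(String[] args) {
>         int oldParallelism = 2; // parallelism before the restore (illustrative)
>         int newParallelism = 3; // parallelism after the restore (illustrative)
>         for (int newSubtask = 0; newSubtask < newParallelism; newSubtask++) {
>             int[] oldSubtasks =
>                     SubtaskStateMapper.FULL.getOldSubtasks(
>                             newSubtask, oldParallelism, newParallelism);
>             // Prints "new subtask N reads state of old subtasks [0, 1]" for every N.
>             System.out.println(
>                     "new subtask " + newSubtask + " reads state of old subtasks "
>                             + Arrays.toString(oldSubtasks));
>         }
>     }
> }
> {code}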
> The proposed fix adds a check for the situations in which the optimization must not be applied:
> {code:java}
> boolean noNeedRescale =
>         stateAssignment.executionJobVertex.getJobVertex().getInputs().stream()
>                 .map(JobEdge::getDownstreamSubtaskStateMapper)
>                 .anyMatch(m -> !m.equals(SubtaskStateMapper.FULL))
>                 && stateAssignment.executionJobVertex.getInputs().stream()
>                         .map(IntermediateResult::getProducer)
>                         .map(vertexAssignments::get)
>                         .anyMatch(
>                                 taskStateAssignment -> {
>                                     final int oldParallelism =
>                                             stateAssignment
>                                                     .oldState
>                                                     .get(stateAssignment.inputOperatorID)
>                                                     .getParallelism();
>                                     return oldParallelism
>                                             == taskStateAssignment.executionJobVertex.getParallelism();
>                                 });
> 
> if (inputState.getParallelism() == executionJobVertex.getParallelism() && noNeedRescale)
> {code}
>  
> Test to reproduce the issue:
> {code:java}
> package org.apache.flink.test.checkpointing;
> 
> import org.apache.flink.api.common.JobID;
> import org.apache.flink.api.common.JobStatus;
> import org.apache.flink.api.common.functions.Partitioner;
> import org.apache.flink.configuration.Configuration;
> import org.apache.flink.core.fs.Path;
> import org.apache.flink.runtime.jobgraph.JobGraph;
> import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
> import org.apache.flink.runtime.minicluster.MiniCluster;
> import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
> import org.apache.flink.streaming.api.environment.CheckpointConfig;
> import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
> import org.apache.flink.streaming.api.functions.sink.SinkFunction;
> import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
> import org.apache.flink.util.FileUtils;
> 
> import org.junit.After;
> import org.junit.Rule;
> import org.junit.Test;
> import org.junit.rules.TemporaryFolder;
> 
> import java.io.File;
> 
> import static org.apache.flink.streaming.api.environment.StreamExecutionEnvironment.createLocalEnvironmentWithWebUI;
> import static org.junit.Assert.fail;
> 
> /** Integration test for rescaling an unaligned checkpoint taken with a custom partitioner. */
> public class UnalignedCheckpointCustomRescaleITCase {
> 
>     @Rule
>     public TemporaryFolder tempFolder = new TemporaryFolder();
> 
>     private static final File CHECKPOINT_FILE = new File("src/test/resources/custom-checkpoint");
> 
>     /** Takes a checkpoint of Source[2]->Sink[3] and copies it to CHECKPOINT_FILE. */
>     @Test
>     public void createCheckpoint() {
>         runJob(2, 3, true);
>     }
> 
>     /** Restores from CHECKPOINT_FILE with Source[1]->Sink[3]; fails with the reported exception. */
>     @Test
>     public void restoreFromCheckpoint() {
>         runJob(1, 3, false);
>     }
> 
>     @After
>     public void after() {
>         tempFolder.delete();
>     }
> 
>     private void runJob(int sourceParallelism, int sinkParallelism, boolean createCheckpoint) {
>         try (MiniCluster miniCluster = new MiniCluster(buildMiniClusterConfig())) {
>             miniCluster.start();
>             Configuration configuration = new Configuration();
>             StreamExecutionEnvironment env = createLocalEnvironmentWithWebUI(configuration);
>             CheckpointConfig checkpointConfig = env.getCheckpointConfig();
>             env.enableCheckpointing(Integer.MAX_VALUE);
>             checkpointConfig.setForceUnalignedCheckpoints(true);
>             checkpointConfig.enableUnalignedCheckpoints();
>             checkpointConfig.setMaxConcurrentCheckpoints(1);
>             checkpointConfig.setExternalizedCheckpointCleanup(
>                     CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION);
>             checkpointConfig.setCheckpointStorage("file://" + tempFolder.newFolder() + "/checkPoints");
> 
>             env.addSource(new StringsSource(createCheckpoint ? 10 : 0, sinkParallelism))
>                     .name("source")
>                     .setParallelism(sourceParallelism)
>                     .partitionCustom(new StringPartitioner(), str -> str.split(" ")[0])
>                     .addSink(new StringSink(createCheckpoint ? 16 : 100000))
>                     .name("sink")
>                     .setParallelism(sinkParallelism);
> 
>             JobGraph job = env.getStreamGraph().getJobGraph();
>             if (!createCheckpoint) {
>                 job.setSavepointRestoreSettings(
>                         SavepointRestoreSettings.forPath(
>                                 "file://" + CHECKPOINT_FILE.getAbsolutePath(), false));
>             }
>             JobID jobId = miniCluster.submitJob(job).get().getJobID();
> 
>             if (createCheckpoint) {
>                 while (!miniCluster.getJobStatus(jobId).get().equals(JobStatus.RUNNING)) {
>                     Thread.sleep(1000);
>                 }
>                 String savepointPath = miniCluster.triggerCheckpoint(jobId).get();
>                 System.out.println("SAVE PATH " + savepointPath);
>                 Thread.sleep(1000);
>                 miniCluster.cancelJob(jobId);
>                 FileUtils.copy(new Path(savepointPath), Path.fromLocalFile(CHECKPOINT_FILE), false);
>             } else {
>                 int count = 0;
>                 while (!miniCluster.getJobStatus(jobId).get().equals(JobStatus.RUNNING)) {
>                     Thread.sleep(1000);
>                     count++;
>                     if (count > 10) {
>                         break;
>                     }
>                 }
>                 Thread.sleep(10000);
>                 boolean fail = !miniCluster.getJobStatus(jobId).get().equals(JobStatus.RUNNING);
>                 miniCluster.cancelJob(jobId);
>                 if (fail) {
>                     fail("Job failed after restore");
>                 }
>             }
>         } catch (Exception e) {
>             throw new RuntimeException(e);
>         }
>     }
> 
>     private static MiniClusterConfiguration buildMiniClusterConfig() {
>         return new MiniClusterConfiguration.Builder()
>                 .setNumTaskManagers(2)
>                 .setNumSlotsPerTaskManager(4)
>                 .build();
>     }
> 
>     /** Emits a fixed number of long records per partition and then idles until cancelled. */
>     private static class StringsSource implements ParallelSourceFunction<String> {
>         volatile boolean isCanceled;
>         final int producePerPartition;
>         final int partitionCount;
> 
>         public StringsSource(int producePerPartition, int partitionCount) {
>             this.producePerPartition = producePerPartition;
>             this.partitionCount = partitionCount;
>         }
> 
>         private String buildString(int partition, int index) {
>             String longStr = new String(new char[3713]).replace('\0', '\uFFFF');
>             return partition + " " + index + " " + longStr;
>         }
> 
>         @Override
>         public void run(SourceContext<String> ctx) throws Exception {
>             for (int i = 0; i < producePerPartition; i++) {
>                 for (int partition = 0; partition < partitionCount; partition++) {
>                     ctx.collect(buildString(partition, i));
>                 }
>             }
>             while (!isCanceled) {
>                 Thread.sleep(1000);
>             }
>         }
> 
>         @Override
>         public void cancel() {
>             isCanceled = true;
>         }
>     }
> 
>     /** Blocks after consuming a fixed number of records so the checkpoint captures in-flight data. */
>     private static class StringSink implements SinkFunction<String> {
>         final int consumeBeforeCheckpoint;
>         int consumed = 0;
> 
>         public StringSink(int consumeBeforeCheckpoint) {
>             this.consumeBeforeCheckpoint = consumeBeforeCheckpoint;
>         }
> 
>         @Override
>         public void invoke(String value, Context ctx) throws InterruptedException {
>             consumed++;
>             System.out.println("--- CONSUMED --- " + value.substring(0, 10));
>             if (consumed == consumeBeforeCheckpoint) {
>                 System.out.println("--- WAITING FOR CHECKPOINT START ---");
>                 Thread.sleep(4000);
>             }
>         }
>     }
> 
>     /** Routes each record by the numeric partition prefix of its key. */
>     public static class StringPartitioner implements Partitioner<String> {
>         @Override
>         public int partition(String key, int numPartitions) {
>             return Integer.parseInt(key) % numPartitions;
>         }
>     }
> }
>  {code}
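> Note that createCheckpoint must be run before restoreFromCheckpoint, since the latter restores from the checkpoint file that the former copies to src/test/resources/custom-checkpoint.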
>  



--
This message was sent by Atlassian Jira
(v8.20.10#820010)
