This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 415e07ebfe0695c871ef91618c17f910eb74e6a6
Author: Yun Gao <[email protected]>
AuthorDate: Tue Jun 15 17:34:47 2021 +0800

    [FLINK-21085][runtime][checkpoint] Allows triggering checkpoint for non-source tasks
    
    This closes #14820
---
 .../runtime/tasks/MultipleInputStreamTask.java     |  10 ++
 .../flink/streaming/runtime/tasks/StreamTask.java  |  51 ++++++++-
 ...tStreamTaskChainedSourcesCheckpointingTest.java |  55 ++++++++++
 .../runtime/tasks/MultipleInputStreamTaskTest.java | 114 ++++++++++++++++++++-
 .../streaming/runtime/tasks/StreamTaskTest.java    |  96 +++++++++++++++++
 5 files changed, 321 insertions(+), 5 deletions(-)

diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java
index 1ec5906..179aedb 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java
@@ -184,6 +184,16 @@ public class MultipleInputStreamTask<OUT>
     public Future<Boolean> triggerCheckpointAsync(
             CheckpointMetaData metadata, CheckpointOptions options) {
 
+        if (operatorChain.getSourceTaskInputs().size() == 0) {
+            return super.triggerCheckpointAsync(metadata, options);
+        }
+
+        // If there are chained sources, we would always only trigger the
+        // chained sources for checkpoint. This means that for the checkpoints
+        // during the upstream task finished and this task receives the
+        // EndOfPartitionEvent, we would not complement barriers for the
+        // unfinished network inputs, and the checkpoint would be triggered
+        // after received all the EndOfPartitionEvent.
         CompletableFuture<Boolean> resultFuture = new CompletableFuture<>();
         mainMailboxExecutor.execute(
                 () -> {
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 29ad693..ea42452 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -31,11 +31,13 @@ import 
org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import 
org.apache.flink.runtime.checkpoint.channel.SequentialChannelStateReader;
 import org.apache.flink.runtime.concurrent.FutureUtils;
 import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
+import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.writer.MultipleRecordWriters;
 import org.apache.flink.runtime.io.network.api.writer.NonRecordWriter;
 import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
@@ -74,6 +76,7 @@ import 
org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
 import org.apache.flink.streaming.api.operators.StreamTaskStateInitializerImpl;
 import org.apache.flink.streaming.runtime.io.RecordWriterOutput;
 import org.apache.flink.streaming.runtime.io.StreamInputProcessor;
+import 
org.apache.flink.streaming.runtime.io.checkpointing.CheckpointBarrierHandler;
 import 
org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -98,7 +101,9 @@ import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
+import java.util.Optional;
 import java.util.OptionalLong;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionException;
@@ -946,9 +951,19 @@ public abstract class StreamTask<OUT, OP extends 
StreamOperator<OUT>> extends Ab
         mainMailboxExecutor.execute(
                 () -> {
                     try {
-                        result.complete(
-                                triggerCheckpointAsyncInMailbox(
-                                        checkpointMetaData, 
checkpointOptions));
+                        boolean noUnfinishedInputGates =
+                                
Arrays.stream(getEnvironment().getAllInputGates())
+                                        .allMatch(InputGate::isFinished);
+
+                        if (noUnfinishedInputGates) {
+                            result.complete(
+                                    triggerCheckpointAsyncInMailbox(
+                                            checkpointMetaData, 
checkpointOptions));
+                        } else {
+                            result.complete(
+                                    triggerUnfinishedChannelsCheckpoint(
+                                            checkpointMetaData, 
checkpointOptions));
+                        }
                     } catch (Exception ex) {
                         // Report the failure both via the Future result but 
also to the mailbox
                         result.completeExceptionally(ex);
@@ -1012,6 +1027,36 @@ public abstract class StreamTask<OUT, OP extends 
StreamOperator<OUT>> extends Ab
         }
     }
 
+    /**
+     * Triggers a checkpoint for a task that still has at least one unfinished input gate.
+     *
+     * <p>A single {@link CheckpointBarrier} is built from the given metadata and options. It is
+     * then handed to the task's {@link CheckpointBarrierHandler} once for every unfinished channel
+     * of every unfinished input gate. Finished gates are skipped.
+     *
+     * @param checkpointMetaData supplies the checkpoint id and timestamp of the barrier.
+     * @param checkpointOptions options carried by the barrier.
+     * @return always {@code true}.
+     * @throws IllegalStateException if this task does not provide a {@link
+     *     CheckpointBarrierHandler}.
+     * @throws Exception if the barrier handler fails to process the barrier.
+     */
+    private boolean triggerUnfinishedChannelsCheckpoint(
+            CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions)
+            throws Exception {
+        // Resolve the handler once instead of calling Optional#get() for every channel.
+        CheckpointBarrierHandler barrierHandler =
+                getCheckpointBarrierHandler()
+                        .orElseThrow(
+                                () ->
+                                        new IllegalStateException(
+                                                "CheckpointBarrierHandler should exist for tasks with network inputs."));
+
+        CheckpointBarrier barrier =
+                new CheckpointBarrier(
+                        checkpointMetaData.getCheckpointId(),
+                        checkpointMetaData.getTimestamp(),
+                        checkpointOptions);
+
+        for (IndexedInputGate inputGate : getEnvironment().getAllInputGates()) {
+            if (!inputGate.isFinished()) {
+                for (InputChannelInfo channelInfo : inputGate.getUnfinishedChannels()) {
+                    barrierHandler.processBarrier(barrier, channelInfo);
+                }
+            }
+        }
+
+        return true;
+    }
+
+    /**
+     * Returns the {@link CheckpointBarrierHandler} associated with this stream task, if any.
+     *
+     * <p>The handler should exist if the task has data inputs and needs to align checkpoint
+     * barriers. Subclasses in that situation override this method. Checkpoints triggered while
+     * some input gates are still unfinished require the handler to be present. This default
+     * implementation returns {@link Optional#empty()}.
+     *
+     * @return the task's barrier handler, or an empty {@code Optional} if the task has none.
+     */
     protected Optional<CheckpointBarrierHandler> getCheckpointBarrierHandler() 
{
         return Optional.empty();
     }
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java
index 95ab281..c541e67 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java
@@ -18,15 +18,18 @@
 
 package org.apache.flink.streaming.runtime.tasks;
 
+import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.eventtime.TimestampAssigner;
 import org.apache.flink.api.common.eventtime.WatermarkStrategy;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.api.connector.source.Boundedness;
 import org.apache.flink.api.connector.source.mocks.MockSource;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointType;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.SourceOperatorFactory;
@@ -38,10 +41,12 @@ import org.junit.Test;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Supplier;
 
 import static 
org.apache.flink.streaming.runtime.tasks.MultipleInputStreamTaskTest.addSourceRecords;
 import static 
org.apache.flink.streaming.runtime.tasks.MultipleInputStreamTaskTest.buildTestHarness;
+import static 
org.apache.flink.streaming.runtime.tasks.MultipleInputStreamTaskTest.triggerCheckpoint;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -230,6 +235,56 @@ public class 
MultipleInputStreamTaskChainedSourcesCheckpointingTest {
         }
     }
 
+    /**
+     * Checks that an RPC-triggered checkpoint is completed for a multiple-input task with chained
+     * sources after both of its network inputs have received {@link EndOfPartitionEvent}.
+     */
+    @Test
+    public void testRpcTriggerCheckpointWithSourceChain() throws Exception {
+        AtomicReference<Future<?>> lastCheckpointTriggerFuture = new 
AtomicReference<>();
+
+        try (StreamTaskMailboxTestHarness<String> testHarness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                env ->
+                                        new MultipleInputStreamTaskTest
+                                                
.HoldingOnAfterInvokeMultipleInputStreamTask(
+                                                env, 
lastCheckpointTriggerFuture),
+                                BasicTypeInfo.STRING_TYPE_INFO)
+                        .modifyStreamConfig(config -> 
config.setCheckpointingEnabled(true))
+                        
.modifyExecutionConfig(ExecutionConfig::enableObjectReuse)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.STRING_TYPE_INFO)
+                        .addSourceInput(
+                                new SourceOperatorFactory<>(
+                                        new 
MultipleInputStreamTaskTest.LifeCycleTrackingMockSource(
+                                                Boundedness.BOUNDED, 1),
+                                        WatermarkStrategy.noWatermarks()))
+                        .addSourceInput(
+                                new SourceOperatorFactory<>(
+                                        new 
MultipleInputStreamTaskTest.LifeCycleTrackingMockSource(
+                                                Boundedness.BOUNDED, 1),
+                                        WatermarkStrategy.noWatermarks()))
+                        .setupOperatorChain(new 
MapToStringMultipleInputOperatorFactory(4))
+                        
.finishForSingletonOperatorChain(StringSerializer.INSTANCE)
+                        .build()) {
+
+            testHarness
+                    .getStreamTask()
+                    .getCheckpointCoordinator()
+                    .setEnableCheckpointAfterTasksFinished(true);
+
+            // TODO: Also cover the case where only some of the network channels have finished,
+            // once pending checkpoints can be completed when an EndOfPartitionEvent is received.
+
+            // Send EndOfPartition to both network inputs, then trigger checkpoint 4.
+            testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
+            testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 1, 0);
+            Future<Boolean> checkpointFuture = triggerCheckpoint(testHarness, 
4);
+            // The task keeps yielding to the mailbox in afterInvoke() until this future is done.
+            lastCheckpointTriggerFuture.set(checkpointFuture);
+
+            // Checkpoint 4 is expected to be triggered successfully.
+            // TODO: Also verify that the checkpoint completes, once finishing the task waits for
+            // the asynchronous checkpoint phase to finish.
+            testHarness.finishProcessing();
+            assertTrue(checkpointFuture.isDone());
+        }
+    }
+
     private void addRecordsAndBarriers(
             StreamTaskMailboxTestHarness<String> testHarness, 
CheckpointBarrier checkpointBarrier)
             throws Exception {
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java
index 352557c..4075ba6 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java
@@ -24,6 +24,7 @@ import 
org.apache.flink.api.common.eventtime.WatermarkGenerator;
 import org.apache.flink.api.common.eventtime.WatermarkOutput;
 import org.apache.flink.api.common.eventtime.WatermarkStrategy;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.api.connector.source.Boundedness;
 import org.apache.flink.api.connector.source.SourceReader;
 import org.apache.flink.api.connector.source.SourceReaderContext;
@@ -31,12 +32,17 @@ import 
org.apache.flink.api.connector.source.mocks.MockSource;
 import org.apache.flink.api.connector.source.mocks.MockSourceReader;
 import org.apache.flink.api.connector.source.mocks.MockSourceSplit;
 import org.apache.flink.api.connector.source.mocks.MockSourceSplitSerializer;
+import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.metrics.Counter;
 import org.apache.flink.metrics.Gauge;
 import org.apache.flink.metrics.Metric;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.CheckpointType;
+import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.metrics.MetricNames;
 import org.apache.flink.runtime.metrics.NoOpMetricRegistry;
@@ -48,6 +54,7 @@ import 
org.apache.flink.runtime.metrics.util.InterceptingOperatorMetricGroup;
 import org.apache.flink.runtime.metrics.util.InterceptingTaskMetricGroup;
 import org.apache.flink.runtime.source.event.AddSplitEvent;
 import org.apache.flink.runtime.source.event.NoMoreSplitsEvent;
+import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractInput;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
@@ -86,8 +93,11 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -840,6 +850,75 @@ public class MultipleInputStreamTaskTest {
         }
     }
 
+    /**
+     * Checks RPC-triggered checkpoints on a three-input task without chained sources in three
+     * states: all inputs alive, one input finished, and all inputs finished.
+     */
+    @Test
+    public void testRpcTriggerCheckpointWithoutSourceChain() throws Exception {
+        AtomicReference<Future<?>> lastCheckpointTriggerFuture = new 
AtomicReference<>();
+
+        try (StreamTaskMailboxTestHarness<String> testHarness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                env ->
+                                        new 
HoldingOnAfterInvokeMultipleInputStreamTask(
+                                                env, 
lastCheckpointTriggerFuture),
+                                BasicTypeInfo.STRING_TYPE_INFO)
+                        .addInput(BasicTypeInfo.STRING_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.DOUBLE_TYPE_INFO)
+                        .modifyStreamConfig(config -> 
config.setCheckpointingEnabled(true))
+                        .setupOperatorChain(new 
MapToStringMultipleInputOperatorFactory(3))
+                        
.finishForSingletonOperatorChain(StringSerializer.INSTANCE)
+                        .build()) {
+
+            testHarness
+                    .getStreamTask()
+                    .getCheckpointCoordinator()
+                    .setEnableCheckpointAfterTasksFinished(true);
+
+            // Trigger a checkpoint while all the inputs are still alive.
+            Future<Boolean> checkpointFuture = triggerCheckpoint(testHarness, 
2);
+            processMailTillCheckpointSuccess(testHarness, checkpointFuture);
+            assertEquals(2, 
testHarness.getTaskStateManager().getReportedCheckpointId());
+
+            // Trigger a checkpoint after the first input has received EndOfPartition.
+            testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
+            checkpointFuture = triggerCheckpoint(testHarness, 4);
+            processMailTillCheckpointSuccess(testHarness, checkpointFuture);
+            assertEquals(4, 
testHarness.getTaskStateManager().getReportedCheckpointId());
+
+            // Trigger a checkpoint after all the inputs have received EndOfPartition.
+            testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 1, 0);
+            testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 2, 0);
+            checkpointFuture = triggerCheckpoint(testHarness, 6);
+            // The task keeps yielding to the mailbox in afterInvoke() until this future is done.
+            lastCheckpointTriggerFuture.set(checkpointFuture);
+
+            // Checkpoint 6 is expected to be triggered successfully.
+            // TODO: Also verify that the checkpoint completes, once finishing the task waits for
+            // the asynchronous checkpoint phase to finish.
+            testHarness.finishProcessing();
+            assertTrue(checkpointFuture.isDone());
+        }
+    }
+
+    /**
+     * Installs a fresh wait-for-report latch on the harness's task state manager, then triggers a
+     * checkpoint on the stream task using {@code CheckpointOptions.alignedNoTimeout} options for a
+     * regular {@code CheckpointType.CHECKPOINT}.
+     *
+     * @param testHarness harness wrapping the task under test.
+     * @param checkpointId id of the checkpoint; its timestamp is {@code checkpointId * 1000}.
+     * @return the future returned by {@code triggerCheckpointAsync}.
+     */
+    static Future<Boolean> triggerCheckpoint(
+            StreamTaskMailboxTestHarness<String> testHarness, long checkpointId) {
+        OneShotLatch reportLatch = new OneShotLatch();
+        testHarness.getTaskStateManager().setWaitForReportLatch(reportLatch);
+
+        long timestamp = checkpointId * 1000;
+        CheckpointMetaData metaData = new CheckpointMetaData(checkpointId, timestamp);
+        CheckpointOptions options =
+                CheckpointOptions.alignedNoTimeout(
+                        CheckpointType.CHECKPOINT,
+                        CheckpointStorageLocationReference.getDefault());
+        return testHarness.getStreamTask().triggerCheckpointAsync(metaData, options);
+    }
+
+    /**
+     * Processes mailbox steps one at a time until {@code checkpointFuture} is done. Then waits on
+     * the task state manager's wait-for-report latch, which {@code triggerCheckpoint} installs.
+     */
+    static void processMailTillCheckpointSuccess(
+            StreamTaskMailboxTestHarness<String> testHarness, Future<Boolean> checkpointFuture)
+            throws Exception {
+        for (boolean done = checkpointFuture.isDone(); !done; done = checkpointFuture.isDone()) {
+            testHarness.processSingleStep();
+        }
+        OneShotLatch reportLatch = testHarness.getTaskStateManager().getWaitForReportLatch();
+        reportLatch.await();
+    }
+
     /** Test implementation of {@link MultipleInputStreamOperator}. */
     protected static class MapToStringMultipleInputOperator extends 
AbstractStreamOperatorV2<String>
             implements MultipleInputStreamOperator<String> {
@@ -875,11 +954,12 @@ public class MultipleInputStreamTaskTest {
 
         @Override
         public List<Input> getInputs() {
+            // Returns the first numberOfInputs of the four inputs declared below
+            // (String, Integer, Double, String).
-            checkArgument(numberOfInputs <= 3);
+            // Fail fast with a clear precondition error instead of an exception from subList().
+            checkArgument(numberOfInputs <= 4);
             return Arrays.<Input>asList(
                             new MapToStringInput<String>(this, 1),
                             new MapToStringInput<Integer>(this, 2),
-                            new MapToStringInput<Double>(this, 3))
+                            new MapToStringInput<Double>(this, 3),
+                            new MapToStringInput<String>(this, 4))
                     .subList(0, numberOfInputs);
         }
 
@@ -1161,4 +1241,34 @@ public class MultipleInputStreamTaskTest {
         @Override
         public void onPeriodicEmit(WatermarkOutput output) {}
     }
+
+    /**
+     * Special stream task implementation that, before actually finishing, waits until the most
+     * recently registered checkpoint trigger future is done. While waiting, it keeps processing
+     * mailbox mails so that the pending checkpoint trigger can run.
+     *
+     * <p>TODO: It will be removed once there is a mechanism that makes the upstream tasks wait for
+     * the downstream tasks to process all the records before finishing.
+     */
+    static class HoldingOnAfterInvokeMultipleInputStreamTask
+            extends MultipleInputStreamTask<String> {
+
+        /** Pause between polls when the mailbox currently has no mail to process. */
+        private static final long IDLE_POLL_INTERVAL_MS = 200;
+
+        private final AtomicReference<Future<?>> lastCheckpointTriggerFuture;
+
+        public HoldingOnAfterInvokeMultipleInputStreamTask(
+                Environment env, AtomicReference<Future<?>> lastCheckpointTriggerFuture)
+                throws Exception {
+            super(env);
+            this.lastCheckpointTriggerFuture = checkNotNull(lastCheckpointTriggerFuture);
+        }
+
+        @Override
+        protected void afterInvoke() throws Exception {
+            // Re-read the reference on every iteration so a newly registered future is picked up.
+            // An unset (null) reference means there is nothing to wait for.
+            for (Future<?> pending = lastCheckpointTriggerFuture.get();
+                    pending != null && !pending.isDone();
+                    pending = lastCheckpointTriggerFuture.get()) {
+                // Process available mails (e.g. the checkpoint trigger) eagerly; only back off
+                // while the mailbox is empty.
+                if (!mainMailboxExecutor.tryYield()) {
+                    Thread.sleep(IDLE_POLL_INTERVAL_MS);
+                }
+            }
+
+            super.afterInvoke();
+        }
+    }
 }
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index 03864d1..90cdcd5 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -48,6 +48,7 @@ import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import 
org.apache.flink.runtime.io.network.api.writer.AvailabilityTestResultPartitionWriter;
 import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
@@ -64,6 +65,7 @@ import 
org.apache.flink.runtime.shuffle.PartitionDescriptorBuilder;
 import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
+import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.CheckpointableKeyedStateBackend;
 import org.apache.flink.runtime.state.DoneFuture;
@@ -152,11 +154,13 @@ import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.concurrent.RunnableFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 
 import static java.util.Arrays.asList;
@@ -168,6 +172,7 @@ import static 
org.apache.flink.runtime.checkpoint.StateObjectCollection.singleto
 import static 
org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault;
 import static 
org.apache.flink.streaming.runtime.tasks.mailbox.TaskMailbox.MAX_PRIORITY;
 import static org.apache.flink.streaming.util.StreamTaskUtil.waitTaskIsRunning;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
@@ -1723,6 +1728,64 @@ public class StreamTaskTest extends TestLogger {
         assertTrue(OpenFailingOperator.wasClosed);
     }
 
+    /**
+     * Checks checkpoint triggering on a one-input task whose input has three channels, as the
+     * channels receive EndOfPartition one after another.
+     */
+    @Test
+    public void testTriggeringCheckpointWithFinishedChannels() throws 
Exception {
+        AtomicReference<Future<?>> lastCheckpointTriggerFuture = new 
AtomicReference<>();
+
+        try (StreamTaskMailboxTestHarness<String> testHarness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                env ->
+                                        new HoldingOnAfterInvokeStreamTask(
+                                                env, 
lastCheckpointTriggerFuture),
+                                BasicTypeInfo.STRING_TYPE_INFO)
+                        .addInput(BasicTypeInfo.STRING_TYPE_INFO, 3)
+                        .setupOutputForSingletonOperatorChain(new 
EmptyOperator())
+                        .build()) {
+            // Trigger a checkpoint while all the channels are still alive.
+            Future<Boolean> checkpointFuture = triggerCheckpoint(testHarness, 
2);
+            processMailTillCheckpointSucceeds(testHarness, checkpointFuture);
+            assertEquals(2, 
testHarness.getTaskStateManager().getReportedCheckpointId());
+
+            // Trigger a checkpoint after the first channel has received EndOfPartition.
+            testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
+            checkpointFuture = triggerCheckpoint(testHarness, 4);
+            processMailTillCheckpointSucceeds(testHarness, checkpointFuture);
+            assertEquals(4, 
testHarness.getTaskStateManager().getReportedCheckpointId());
+
+            // Trigger a checkpoint after the remaining channels have also received EndOfPartition.
+            testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 1);
+            testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 2);
+            checkpointFuture = triggerCheckpoint(testHarness, 6);
+            // The task keeps yielding to the mailbox in afterInvoke() until this future is done.
+            lastCheckpointTriggerFuture.set(checkpointFuture);
+
+            // Checkpoint 6 is expected to be triggered successfully.
+            // TODO: Also verify that the checkpoint completes, once finishing the task waits for
+            // the asynchronous checkpoint phase to finish.
+            testHarness.finishProcessing();
+            assertTrue(checkpointFuture.isDone());
+        }
+    }
+
+    /**
+     * Installs a fresh wait-for-report latch on the harness's task state manager, then triggers a
+     * checkpoint on the stream task using {@code CheckpointOptions.alignedNoTimeout} options for a
+     * regular {@code CheckpointType.CHECKPOINT}.
+     *
+     * @param testHarness harness wrapping the task under test.
+     * @param checkpointId id of the checkpoint; its timestamp is {@code checkpointId * 1000}.
+     * @return the future returned by {@code triggerCheckpointAsync}.
+     */
+    private static Future<Boolean> triggerCheckpoint(
+            StreamTaskMailboxTestHarness<String> testHarness, long checkpointId) {
+        OneShotLatch reportLatch = new OneShotLatch();
+        testHarness.getTaskStateManager().setWaitForReportLatch(reportLatch);
+
+        long timestamp = checkpointId * 1000;
+        CheckpointMetaData metaData = new CheckpointMetaData(checkpointId, timestamp);
+        CheckpointOptions options =
+                CheckpointOptions.alignedNoTimeout(
+                        CheckpointType.CHECKPOINT,
+                        CheckpointStorageLocationReference.getDefault());
+        return testHarness.getStreamTask().triggerCheckpointAsync(metaData, options);
+    }
+
+    /**
+     * Processes mailbox steps one at a time until {@code checkpointFuture} is done. Then waits on
+     * the task state manager's wait-for-report latch, which {@code triggerCheckpoint} installs.
+     */
+    private static void processMailTillCheckpointSucceeds(
+            StreamTaskMailboxTestHarness<String> testHarness, Future<Boolean> checkpointFuture)
+            throws Exception {
+        for (boolean done = checkpointFuture.isDone(); !done; done = checkpointFuture.isDone()) {
+            testHarness.processSingleStep();
+        }
+        OneShotLatch reportLatch = testHarness.getTaskStateManager().getWaitForReportLatch();
+        reportLatch.await();
+    }
+
     private MockEnvironment setupEnvironment(boolean... outputAvailabilities) {
         final Configuration configuration = new Configuration();
         new MockStreamConfig(configuration, outputAvailabilities.length);
@@ -2726,4 +2789,37 @@ public class StreamTaskTest extends TestLogger {
             };
         }
     }
+
+    /**
+     * Special stream task implementation that, before actually finishing, waits until the most
+     * recently registered checkpoint trigger future is done. While waiting, it keeps processing
+     * mailbox mails so that the pending checkpoint trigger can run.
+     */
+    private static class HoldingOnAfterInvokeStreamTask extends OneInputStreamTask<String, String> {
+
+        /** Pause between polls when the mailbox currently has no mail to process. */
+        private static final long IDLE_POLL_INTERVAL_MS = 200;
+
+        private final AtomicReference<Future<?>> lastCheckpointTriggerFuture;
+
+        public HoldingOnAfterInvokeStreamTask(
+                Environment env, AtomicReference<Future<?>> lastCheckpointTriggerFuture)
+                throws Exception {
+            super(env);
+            this.lastCheckpointTriggerFuture = checkNotNull(lastCheckpointTriggerFuture);
+        }
+
+        @Override
+        protected void afterInvoke() throws Exception {
+            // Re-read the reference on every iteration so a newly registered future is picked up.
+            // An unset (null) reference means there is nothing to wait for.
+            for (Future<?> pending = lastCheckpointTriggerFuture.get();
+                    pending != null && !pending.isDone();
+                    pending = lastCheckpointTriggerFuture.get()) {
+                // Process available mails (e.g. the checkpoint trigger) eagerly; only back off
+                // while the mailbox is empty.
+                if (!mainMailboxExecutor.tryYield()) {
+                    Thread.sleep(IDLE_POLL_INTERVAL_MS);
+                }
+            }
+
+            super.afterInvoke();
+        }
+    }
+
+    private static class EmptyOperator extends AbstractStreamOperator<String>
+            implements OneInputStreamOperator<String, String> {
+
+        @Override
+        public void processElement(StreamRecord<String> element) throws 
Exception {}
+    }
 }

Reply via email to