This is an automated email from the ASF dual-hosted git repository.

dmvk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 03e4f8eb13d [FLINK-31370] Prevent more timers from being fired if the StreamTask has been canceled.
03e4f8eb13d is described below

commit 03e4f8eb13d8263a4a1f4947f5d1cb55a3d368e2
Author: David Moravek <[email protected]>
AuthorDate: Wed Mar 8 15:02:12 2023 +0100

    [FLINK-31370] Prevent more timers from being fired if the StreamTask has been canceled.
---
 .../api/operators/InternalTimeServiceManager.java  |   4 +-
 .../operators/InternalTimeServiceManagerImpl.java  |  19 ++--
 .../api/operators/InternalTimerServiceImpl.java    |  16 +++-
 .../operators/StreamTaskStateInitializerImpl.java  |  18 ++--
 .../BatchExecutionInternalTimeServiceManager.java  |   4 +-
 .../flink/streaming/runtime/tasks/StreamTask.java  |   3 +-
 .../tasks/StreamTaskCancellationContext.java       |  43 +++++++++
 .../operators/InternalTimerServiceImplTest.java    |   4 +-
 .../StateInitializationContextImplTest.java        |   7 +-
 .../StreamTaskStateInitializerImplTest.java        |   7 +-
 .../BatchExecutionInternalTimeServiceTest.java     |  22 +++--
 .../runtime/tasks/StreamTaskCancellationTest.java  | 100 +++++++++++++++++++++
 .../tasks/StreamTaskMailboxTestHarness.java        |   4 +
 .../util/AbstractStreamOperatorTestHarness.java    |  13 ++-
 .../restore/StreamOperatorSnapshotRestoreTest.java |   4 +-
 15 files changed, 233 insertions(+), 35 deletions(-)

diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
index e374fdb0fd8..439789c3709 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
@@ -25,6 +25,7 @@ import 
org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
 import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskCancellationContext;
 
 import java.io.Serializable;
 
@@ -76,7 +77,8 @@ public interface InternalTimeServiceManager<K> {
                 ClassLoader userClassloader,
                 KeyContext keyContext,
                 ProcessingTimeService processingTimeService,
-                Iterable<KeyGroupStatePartitionStreamProvider> rawKeyedStates)
+                Iterable<KeyGroupStatePartitionStreamProvider> rawKeyedStates,
+                StreamTaskCancellationContext cancellationContext)
                 throws Exception;
     }
 }
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManagerImpl.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManagerImpl.java
index 9cbadd1c2a9..51a280bdda2 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManagerImpl.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManagerImpl.java
@@ -32,6 +32,7 @@ import 
org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
 import org.apache.flink.runtime.state.PriorityQueueSetFactory;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskCancellationContext;
 import org.apache.flink.util.Preconditions;
 
 import org.slf4j.Logger;
@@ -67,9 +68,9 @@ public class InternalTimeServiceManagerImpl<K> implements 
InternalTimeServiceMan
 
     private final KeyGroupRange localKeyGroupRange;
     private final KeyContext keyContext;
-
     private final PriorityQueueSetFactory priorityQueueSetFactory;
     private final ProcessingTimeService processingTimeService;
+    private final StreamTaskCancellationContext cancellationContext;
 
     private final Map<String, InternalTimerServiceImpl<K, ?>> timerServices;
 
@@ -77,12 +78,14 @@ public class InternalTimeServiceManagerImpl<K> implements 
InternalTimeServiceMan
             KeyGroupRange localKeyGroupRange,
             KeyContext keyContext,
             PriorityQueueSetFactory priorityQueueSetFactory,
-            ProcessingTimeService processingTimeService) {
+            ProcessingTimeService processingTimeService,
+            StreamTaskCancellationContext cancellationContext) {
 
         this.localKeyGroupRange = 
Preconditions.checkNotNull(localKeyGroupRange);
         this.priorityQueueSetFactory = 
Preconditions.checkNotNull(priorityQueueSetFactory);
         this.keyContext = Preconditions.checkNotNull(keyContext);
         this.processingTimeService = 
Preconditions.checkNotNull(processingTimeService);
+        this.cancellationContext = cancellationContext;
 
         this.timerServices = new HashMap<>();
     }
@@ -97,13 +100,18 @@ public class InternalTimeServiceManagerImpl<K> implements 
InternalTimeServiceMan
             ClassLoader userClassloader,
             KeyContext keyContext,
             ProcessingTimeService processingTimeService,
-            Iterable<KeyGroupStatePartitionStreamProvider> rawKeyedStates)
+            Iterable<KeyGroupStatePartitionStreamProvider> rawKeyedStates,
+            StreamTaskCancellationContext cancellationContext)
             throws Exception {
         final KeyGroupRange keyGroupRange = 
keyedStateBackend.getKeyGroupRange();
 
         final InternalTimeServiceManagerImpl<K> timeServiceManager =
                 new InternalTimeServiceManagerImpl<>(
-                        keyGroupRange, keyContext, keyedStateBackend, 
processingTimeService);
+                        keyGroupRange,
+                        keyContext,
+                        keyedStateBackend,
+                        processingTimeService,
+                        cancellationContext);
 
         // and then initialize the timer services
         for (KeyGroupStatePartitionStreamProvider streamProvider : 
rawKeyedStates) {
@@ -157,7 +165,8 @@ public class InternalTimeServiceManagerImpl<K> implements 
InternalTimeServiceMan
                             processingTimeService,
                             createTimerPriorityQueue(
                                     PROCESSING_TIMER_PREFIX + name, 
timerSerializer),
-                            createTimerPriorityQueue(EVENT_TIMER_PREFIX + 
name, timerSerializer));
+                            createTimerPriorityQueue(EVENT_TIMER_PREFIX + 
name, timerSerializer),
+                            cancellationContext);
 
             timerServices.put(name, timerService);
         }
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImpl.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImpl.java
index b2cf4c6b94a..e2c7e4139b2 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImpl.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImpl.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.state.InternalPriorityQueue;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskCancellationContext;
 import org.apache.flink.util.CloseableIterator;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.Preconditions;
@@ -54,6 +55,9 @@ public class InternalTimerServiceImpl<K, N> implements 
InternalTimerService<N> {
     private final KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>>
             eventTimeTimersQueue;
 
+    /** Context that allows us to stop firing timers if the containing task 
has been cancelled. */
+    private final StreamTaskCancellationContext cancellationContext;
+
     /** Information concerning the local key-group range. */
     private final KeyGroupRange localKeyGroupRange;
 
@@ -93,13 +97,15 @@ public class InternalTimerServiceImpl<K, N> implements 
InternalTimerService<N> {
             KeyContext keyContext,
             ProcessingTimeService processingTimeService,
             KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>> 
processingTimeTimersQueue,
-            KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>> 
eventTimeTimersQueue) {
+            KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>> 
eventTimeTimersQueue,
+            StreamTaskCancellationContext cancellationContext) {
 
         this.keyContext = checkNotNull(keyContext);
         this.processingTimeService = checkNotNull(processingTimeService);
         this.localKeyGroupRange = checkNotNull(localKeyGroupRange);
         this.processingTimeTimersQueue = 
checkNotNull(processingTimeTimersQueue);
         this.eventTimeTimersQueue = checkNotNull(eventTimeTimersQueue);
+        this.cancellationContext = cancellationContext;
 
         // find the starting index of the local key-group range
         int startIdx = Integer.MAX_VALUE;
@@ -278,7 +284,9 @@ public class InternalTimerServiceImpl<K, N> implements 
InternalTimerService<N> {
 
         InternalTimer<K, N> timer;
 
-        while ((timer = processingTimeTimersQueue.peek()) != null && 
timer.getTimestamp() <= time) {
+        while ((timer = processingTimeTimersQueue.peek()) != null
+                && timer.getTimestamp() <= time
+                && !cancellationContext.isCancelled()) {
             keyContext.setCurrentKey(timer.getKey());
             processingTimeTimersQueue.poll();
             triggerTarget.onProcessingTime(timer);
@@ -296,7 +304,9 @@ public class InternalTimerServiceImpl<K, N> implements 
InternalTimerService<N> {
 
         InternalTimer<K, N> timer;
 
-        while ((timer = eventTimeTimersQueue.peek()) != null && 
timer.getTimestamp() <= time) {
+        while ((timer = eventTimeTimersQueue.peek()) != null
+                && timer.getTimestamp() <= time
+                && !cancellationContext.isCancelled()) {
             keyContext.setCurrentKey(timer.getKey());
             eventTimeTimersQueue.poll();
             triggerTarget.onEventTime(timer);
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
index 18d91f828c0..552683856ef 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
@@ -46,12 +46,11 @@ import org.apache.flink.runtime.state.TaskStateManager;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.runtime.util.OperatorSubtaskDescriptionText;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskCancellationContext;
 import org.apache.flink.util.CloseableIterable;
 import org.apache.flink.util.Preconditions;
 
 import org.apache.commons.io.IOUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
@@ -79,9 +78,6 @@ import static 
org.apache.flink.runtime.state.StateUtil.unexpectedStateHandleExce
  */
 public class StreamTaskStateInitializerImpl implements 
StreamTaskStateInitializer {
 
-    /** The logger for this class. */
-    private static final Logger LOG = 
LoggerFactory.getLogger(StreamTaskStateInitializerImpl.class);
-
     /**
      * The environment of the task. This is required as parameter to construct 
state backends via
      * their factory.
@@ -101,13 +97,16 @@ public class StreamTaskStateInitializerImpl implements 
StreamTaskStateInitialize
 
     private final InternalTimeServiceManager.Provider 
timeServiceManagerProvider;
 
+    private final StreamTaskCancellationContext cancellationContext;
+
     public StreamTaskStateInitializerImpl(Environment environment, 
StateBackend stateBackend) {
 
         this(
                 environment,
                 stateBackend,
                 TtlTimeProvider.DEFAULT,
-                InternalTimeServiceManagerImpl::create);
+                InternalTimeServiceManagerImpl::create,
+                StreamTaskCancellationContext.alwaysRunning());
     }
 
     @VisibleForTesting
@@ -115,13 +114,15 @@ public class StreamTaskStateInitializerImpl implements 
StreamTaskStateInitialize
             Environment environment,
             StateBackend stateBackend,
             TtlTimeProvider ttlTimeProvider,
-            InternalTimeServiceManager.Provider timeServiceManagerProvider) {
+            InternalTimeServiceManager.Provider timeServiceManagerProvider,
+            StreamTaskCancellationContext cancellationContext) {
 
         this.environment = environment;
         this.taskStateManager = 
Preconditions.checkNotNull(environment.getTaskStateManager());
         this.stateBackend = Preconditions.checkNotNull(stateBackend);
         this.ttlTimeProvider = ttlTimeProvider;
         this.timeServiceManagerProvider = 
Preconditions.checkNotNull(timeServiceManagerProvider);
+        this.cancellationContext = cancellationContext;
     }
 
     // 
-----------------------------------------------------------------------------------------------------------------
@@ -213,7 +214,8 @@ public class StreamTaskStateInitializerImpl implements 
StreamTaskStateInitialize
                                 
environment.getUserCodeClassLoader().asClassLoader(),
                                 keyContext,
                                 processingTimeService,
-                                restoredRawKeyedStateTimers);
+                                restoredRawKeyedStateTimers,
+                                cancellationContext);
             } else {
                 timeServiceManager = null;
             }
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sorted/state/BatchExecutionInternalTimeServiceManager.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sorted/state/BatchExecutionInternalTimeServiceManager.java
index 8666f05abd9..c1152725e05 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sorted/state/BatchExecutionInternalTimeServiceManager.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sorted/state/BatchExecutionInternalTimeServiceManager.java
@@ -29,6 +29,7 @@ import org.apache.flink.streaming.api.operators.KeyContext;
 import org.apache.flink.streaming.api.operators.Triggerable;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskCancellationContext;
 import org.apache.flink.util.WrappingRuntimeException;
 
 import java.util.HashMap;
@@ -88,7 +89,8 @@ public class BatchExecutionInternalTimeServiceManager<K>
             ClassLoader userClassloader,
             KeyContext keyContext, // the operator
             ProcessingTimeService processingTimeService,
-            Iterable<KeyGroupStatePartitionStreamProvider> rawKeyedStates) {
+            Iterable<KeyGroupStatePartitionStreamProvider> rawKeyedStates,
+            StreamTaskCancellationContext cancellationContext) {
         checkState(
                 keyedStatedBackend instanceof BatchExecutionKeyedStateBackend,
                 "Batch execution specific time service can work only with 
BatchExecutionKeyedStateBackend");
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 8cc529720d8..f71794db58f 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -670,7 +670,8 @@ public abstract class StreamTask<OUT, OP extends 
StreamOperator<OUT>>
                 TtlTimeProvider.DEFAULT,
                 timerServiceProvider != null
                         ? timerServiceProvider
-                        : InternalTimeServiceManagerImpl::create);
+                        : InternalTimeServiceManagerImpl::create,
+                () -> canceled);
     }
 
     protected Counter setupNumRecordsInCounter(StreamOperator streamOperator) {
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationContext.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationContext.java
new file mode 100644
index 00000000000..d47e0419107
--- /dev/null
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationContext.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.tasks;
+
+import org.apache.flink.annotation.Internal;
+
+/** Context on the {@link StreamTask} for figuring out whether it has been cancelled. */
+@FunctionalInterface
+@Internal
+public interface StreamTaskCancellationContext {
+
+    /**
+     * Factory for a context that always returns {@code false} when {@link #isCancelled()} is
+     * called, i.e. the task is never reported as cancelled.
+     *
+     * @return a context that never reports cancellation
+     */
+    static StreamTaskCancellationContext alwaysRunning() {
+        return () -> false;
+    }
+
+    /**
+     * Find out whether the {@link StreamTask} this context belongs to has been cancelled.
+     *
+     * @return {@code true} if the {@code StreamTask} has been cancelled, {@code false} otherwise
+     */
+    boolean isCancelled();
+}
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImplTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImplTest.java
index 8b5b7afbd62..915b2732982 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImplTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImplTest.java
@@ -30,6 +30,7 @@ import 
org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
 import org.apache.flink.runtime.state.PriorityQueueSetFactory;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskCancellationContext;
 import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService;
 
 import org.junit.Assert;
@@ -1100,7 +1101,8 @@ public class InternalTimerServiceImplTest {
                 processingTimeService,
                 createTimerQueue(
                         "__test_processing_timers", timerSerializer, 
priorityQueueSetFactory),
-                createTimerQueue("__test_event_timers", timerSerializer, 
priorityQueueSetFactory));
+                createTimerQueue("__test_event_timers", timerSerializer, 
priorityQueueSetFactory),
+                StreamTaskCancellationContext.alwaysRunning());
     }
 
     private static <K, N>
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
index 916caba04b1..f8e0b0ccfc2 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
@@ -58,6 +58,7 @@ import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.runtime.util.LongArrayList;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskCancellationContext;
 
 import org.junit.Assert;
 import org.junit.Before;
@@ -191,7 +192,8 @@ public class StateInitializationContextImplTest {
                                     ClassLoader userClassloader,
                                     KeyContext keyContext,
                                     ProcessingTimeService 
processingTimeService,
-                                    
Iterable<KeyGroupStatePartitionStreamProvider> rawKeyedStates)
+                                    
Iterable<KeyGroupStatePartitionStreamProvider> rawKeyedStates,
+                                    StreamTaskCancellationContext 
cancellationContext)
                                     throws Exception {
                                 // We do not initialize a timer service 
manager here, because it
                                 // would already consume the raw keyed
@@ -200,7 +202,8 @@ public class StateInitializationContextImplTest {
                                 // stream.
                                 return null;
                             }
-                        });
+                        },
+                        StreamTaskCancellationContext.alwaysRunning());
 
         AbstractStreamOperator<?> mockOperator = 
mock(AbstractStreamOperator.class);
         when(mockOperator.getOperatorID()).thenReturn(operatorID);
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImplTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImplTest.java
index 58840b81212..e0a6498cf1b 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImplTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImplTest.java
@@ -52,6 +52,7 @@ import 
org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.runtime.taskmanager.TestCheckpointResponder;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskCancellationContext;
 import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService;
 import org.apache.flink.util.CloseableIterable;
 
@@ -317,11 +318,13 @@ public class StreamTaskStateInitializerImplTest {
                                 ClassLoader userClassloader,
                                 KeyContext keyContext,
                                 ProcessingTimeService processingTimeService,
-                                Iterable<KeyGroupStatePartitionStreamProvider> 
rawKeyedStates)
+                                Iterable<KeyGroupStatePartitionStreamProvider> 
rawKeyedStates,
+                                StreamTaskCancellationContext 
cancellationContext)
                                 throws Exception {
                             return null;
                         }
-                    });
+                    },
+                    StreamTaskCancellationContext.alwaysRunning());
         }
     }
 }
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/sorted/state/BatchExecutionInternalTimeServiceTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/sorted/state/BatchExecutionInternalTimeServiceTest.java
index a7703286bdf..35932523ba6 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/sorted/state/BatchExecutionInternalTimeServiceTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/sorted/state/BatchExecutionInternalTimeServiceTest.java
@@ -36,6 +36,7 @@ import 
org.apache.flink.streaming.api.operators.InternalTimerService;
 import org.apache.flink.streaming.api.operators.KeyContext;
 import org.apache.flink.streaming.api.operators.Triggerable;
 import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskCancellationContext;
 import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService;
 import org.apache.flink.util.TestLogger;
 
@@ -90,7 +91,8 @@ public class BatchExecutionInternalTimeServiceTest extends 
TestLogger {
                 this.getClass().getClassLoader(),
                 new DummyKeyContext(),
                 new TestProcessingTimeService(),
-                Collections.emptyList());
+                Collections.emptyList(),
+                StreamTaskCancellationContext.alwaysRunning());
     }
 
     @Test
@@ -134,7 +136,8 @@ public class BatchExecutionInternalTimeServiceTest extends 
TestLogger {
                         this.getClass().getClassLoader(),
                         new DummyKeyContext(),
                         new TestProcessingTimeService(),
-                        Collections.emptyList());
+                        Collections.emptyList(),
+                        StreamTaskCancellationContext.alwaysRunning());
 
         List<Long> timers = new ArrayList<>();
         InternalTimerService<VoidNamespace> timerService =
@@ -169,7 +172,8 @@ public class BatchExecutionInternalTimeServiceTest extends 
TestLogger {
                         this.getClass().getClassLoader(),
                         new DummyKeyContext(),
                         new TestProcessingTimeService(),
-                        Collections.emptyList());
+                        Collections.emptyList(),
+                        StreamTaskCancellationContext.alwaysRunning());
 
         List<Long> timers = new ArrayList<>();
         InternalTimerService<VoidNamespace> timerService =
@@ -197,7 +201,8 @@ public class BatchExecutionInternalTimeServiceTest extends 
TestLogger {
                         this.getClass().getClassLoader(),
                         new DummyKeyContext(),
                         new TestProcessingTimeService(),
-                        Collections.emptyList());
+                        Collections.emptyList(),
+                        StreamTaskCancellationContext.alwaysRunning());
 
         List<Long> timers = new ArrayList<>();
         TriggerWithTimerServiceAccess<Integer, VoidNamespace> eventTimeTrigger 
=
@@ -243,7 +248,8 @@ public class BatchExecutionInternalTimeServiceTest extends 
TestLogger {
                         this.getClass().getClassLoader(),
                         new DummyKeyContext(),
                         processingTimeService,
-                        Collections.emptyList());
+                        Collections.emptyList(),
+                        StreamTaskCancellationContext.alwaysRunning());
 
         List<Long> timers = new ArrayList<>();
         InternalTimerService<VoidNamespace> timerService =
@@ -277,7 +283,8 @@ public class BatchExecutionInternalTimeServiceTest extends 
TestLogger {
                         this.getClass().getClassLoader(),
                         new DummyKeyContext(),
                         processingTimeService,
-                        Collections.emptyList());
+                        Collections.emptyList(),
+                        StreamTaskCancellationContext.alwaysRunning());
 
         List<Long> timers = new ArrayList<>();
         TriggerWithTimerServiceAccess<Integer, VoidNamespace> trigger =
@@ -316,7 +323,8 @@ public class BatchExecutionInternalTimeServiceTest extends 
TestLogger {
                         this.getClass().getClassLoader(),
                         new DummyKeyContext(),
                         processingTimeService,
-                        Collections.emptyList());
+                        Collections.emptyList(),
+                        StreamTaskCancellationContext.alwaysRunning());
 
         List<Long> timers = new ArrayList<>();
         TriggerWithTimerServiceAccess<Integer, VoidNamespace> trigger =
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationTest.java
index 7589224b69a..8bf560c7605 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationTest.java
@@ -25,22 +25,33 @@ import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.VoidNamespace;
+import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.InternalTimer;
+import org.apache.flink.streaming.api.operators.InternalTimerService;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.Triggerable;
+import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxDefaultAction;
 import org.apache.flink.testutils.TestingUtils;
 import org.apache.flink.testutils.executor.TestExecutorResource;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.TestLogger;
+import org.apache.flink.util.function.ThrowingConsumer;
 
+import org.assertj.core.api.Assertions;
 import org.junit.ClassRule;
 import org.junit.Test;
 
 import java.io.Closeable;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static 
org.apache.flink.api.common.typeinfo.BasicTypeInfo.STRING_TYPE_INFO;
 import static 
org.apache.flink.streaming.runtime.tasks.StreamTaskTest.createTask;
@@ -257,4 +268,93 @@ public class StreamTaskCancellationTest extends TestLogger 
{
             throw new CancelTaskException();
         }
     }
+
+    @Test
+    public void 
testCancelTaskShouldPreventAdditionalEventTimeTimersFromBeingFired()
+            throws Exception {
+        testCancelTaskShouldPreventAdditionalTimersFromBeingFired(false);
+    }
+
+    @Test
+    public void 
testCancelTaskShouldPreventAdditionalProcessingTimeTimersFromBeingFired()
+            throws Exception {
+        testCancelTaskShouldPreventAdditionalTimersFromBeingFired(true);
+    }
+
+    private void 
testCancelTaskShouldPreventAdditionalTimersFromBeingFired(boolean 
processingTime)
+            throws Exception {
+        final int numKeyedTimersToRegister = 100;
+        final int numKeyedTimersToFire = 10;
+        final AtomicInteger numKeyedTimersFired = new AtomicInteger(0);
+        try (StreamTaskMailboxTestHarness<String> harness =
+                new 
StreamTaskMailboxTestHarnessBuilder<>(OneInputStreamTask::new, STRING_TYPE_INFO)
+                        .addInput(STRING_TYPE_INFO)
+                        .setKeyType(STRING_TYPE_INFO)
+                        .setupOutputForSingletonOperatorChain(
+                                new TaskWithPreRegisteredTimers(
+                                        numKeyedTimersToRegister, 
processingTime))
+                        .build()) {
+            TaskWithPreRegisteredTimers.setOnTimerListener(
+                    key -> {
+                        if (numKeyedTimersFired.incrementAndGet() >= 
numKeyedTimersToFire) {
+                            harness.cancel();
+                        }
+                    });
+            harness.processElement(new Watermark(Long.MAX_VALUE));
+        }
+        
Assertions.assertThat(numKeyedTimersFired).hasValue(numKeyedTimersToFire);
+    }
+
+    private static class TaskWithPreRegisteredTimers extends 
AbstractStreamOperator<String>
+            implements OneInputStreamOperator<String, String>, 
Triggerable<String, VoidNamespace> {
+
+        private static ThrowingConsumer<String, Exception> onTimerListener;
+
+        private final int numTimersToRegister;
+        private final boolean processingTime;
+
+        private TaskWithPreRegisteredTimers(int numTimersToRegister, boolean 
processingTime) {
+            this.numTimersToRegister = numTimersToRegister;
+            this.processingTime = processingTime;
+        }
+
+        @Override
+        public void open() throws Exception {
+            final InternalTimerService<VoidNamespace> timerService =
+                    getInternalTimerService("test-timers", 
VoidNamespaceSerializer.INSTANCE, this);
+            final KeyedStateBackend<String> keyedStateBackend = 
getKeyedStateBackend();
+            for (int keyIdx = 0; keyIdx < numTimersToRegister; keyIdx++) {
+                final String key = "key-" + keyIdx;
+                keyedStateBackend.setCurrentKey(key);
+                if (processingTime) {
+                    
timerService.registerProcessingTimeTimer(VoidNamespace.INSTANCE, 0);
+                } else {
+                    
timerService.registerEventTimeTimer(VoidNamespace.INSTANCE, 0);
+                }
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<String> element) throws 
Exception {
+            // No-op.
+        }
+
+        @Override
+        public void onEventTime(InternalTimer<String, VoidNamespace> timer) 
throws Exception {
+            Preconditions.checkState(!processingTime);
+            Preconditions.checkNotNull(onTimerListener).accept(timer.getKey());
+        }
+
+        @Override
+        public void onProcessingTime(InternalTimer<String, VoidNamespace> 
timer) throws Exception {
+            Preconditions.checkState(processingTime);
+            Preconditions.checkNotNull(onTimerListener).accept(timer.getKey());
+        }
+
+        private static void setOnTimerListener(
+                ThrowingConsumer<String, Exception> onTimerListener) {
+            TaskWithPreRegisteredTimers.onTimerListener =
+                    Preconditions.checkNotNull(onTimerListener);
+        }
+    }
 }
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarness.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarness.java
index f450af4c216..02398b4e747 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarness.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarness.java
@@ -167,6 +167,10 @@ public class StreamTaskMailboxTestHarness<OUT> implements 
AutoCloseable {
         streamTask.cleanUp(null);
     }
 
+    public void cancel() throws Exception {
+        streamTask.cancel();
+    }
+
     @Override
     public void close() throws Exception {
         if (streamTask.isRunning()) {
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
index 86f5c0ad02c..af4ac3d3986 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
@@ -76,6 +76,7 @@ import 
org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.OperatorEventDispatcherImpl;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskCancellationContext;
 import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService;
 import org.apache.flink.streaming.runtime.tasks.mailbox.TaskMailbox;
 import org.apache.flink.streaming.runtime.tasks.mailbox.TaskMailboxImpl;
@@ -149,7 +150,8 @@ public class AbstractStreamOperatorTestHarness<OUT> 
implements AutoCloseable {
                         ClassLoader userClassloader,
                         KeyContext keyContext,
                         ProcessingTimeService processingTimeService,
-                        Iterable<KeyGroupStatePartitionStreamProvider> 
rawKeyedStates)
+                        Iterable<KeyGroupStatePartitionStreamProvider> 
rawKeyedStates,
+                        StreamTaskCancellationContext cancellationContext)
                         throws Exception {
                     InternalTimeServiceManagerImpl<K> typedTimeServiceManager =
                             InternalTimeServiceManagerImpl.create(
@@ -157,7 +159,8 @@ public class AbstractStreamOperatorTestHarness<OUT> 
implements AutoCloseable {
                                     userClassloader,
                                     keyContext,
                                     processingTimeService,
-                                    rawKeyedStates);
+                                    rawKeyedStates,
+                                    cancellationContext);
                     timeServiceManager = typedTimeServiceManager;
                     return typedTimeServiceManager;
                 }
@@ -331,7 +334,11 @@ public class AbstractStreamOperatorTestHarness<OUT> 
implements AutoCloseable {
             TtlTimeProvider ttlTimeProvider,
             InternalTimeServiceManager.Provider timeServiceManagerProvider) {
         return new StreamTaskStateInitializerImpl(
-                env, stateBackend, ttlTimeProvider, 
timeServiceManagerProvider);
+                env,
+                stateBackend,
+                ttlTimeProvider,
+                timeServiceManagerProvider,
+                StreamTaskCancellationContext.alwaysRunning());
     }
 
     public void setStateBackend(StateBackend stateBackend) {
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/state/operator/restore/StreamOperatorSnapshotRestoreTest.java
 
b/flink-tests/src/test/java/org/apache/flink/test/state/operator/restore/StreamOperatorSnapshotRestoreTest.java
index 93e09bc3cd1..69b8ce40260 100644
--- 
a/flink-tests/src/test/java/org/apache/flink/test/state/operator/restore/StreamOperatorSnapshotRestoreTest.java
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/state/operator/restore/StreamOperatorSnapshotRestoreTest.java
@@ -58,6 +58,7 @@ import 
org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskCancellationContext;
 import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
 import org.apache.flink.util.TernaryBoolean;
 import org.apache.flink.util.TestLogger;
@@ -236,7 +237,8 @@ public class StreamOperatorSnapshotRestoreTest extends 
TestLogger {
                             ClassLoader userClassloader,
                             KeyContext keyContext,
                             ProcessingTimeService processingTimeService,
-                            Iterable<KeyGroupStatePartitionStreamProvider> 
rawKeyedStates)
+                            Iterable<KeyGroupStatePartitionStreamProvider> 
rawKeyedStates,
+                            StreamTaskCancellationContext cancellationContext)
                             throws IOException {
                         return null;
                     }


Reply via email to