This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new c48c50d3e80 KAFKA-19831: Improved error handling in DefaultStateUpdater. (#20767)
c48c50d3e80 is described below

commit c48c50d3e806f2d0dedbff7570ed11687799dd41
Author: Nikita Shupletsov <[email protected]>
AuthorDate: Fri Nov 21 15:02:28 2025 -0800

    KAFKA-19831: Improved error handling in DefaultStateUpdater. (#20767)
    
    - Improved error handling in DefaultStateUpdater to take potential
    failures in Task#maybeCheckpoint into account.
    - Improved TaskManager#shutdownStateUpdater to not hang indefinitely if
    the State Updater thread is dead.
    
    Reviewers: Matthias J. Sax <[email protected]>, Lucas Brutschy <[email protected]>
    
    ---------
    
    Co-authored-by: Matthias J. Sax <[email protected]>
---
 .../StateUpdaterFailureIntegrationTest.java        | 154 +++++++++++++++++++++
 .../processor/internals/DefaultStateUpdater.java   |  58 +++++---
 .../streams/processor/internals/TaskManager.java   |  24 +++-
 .../internals/DefaultStateUpdaterTest.java         | 134 ++++++++++++++++--
 .../processor/internals/TaskManagerTest.java       |  53 +++++--
 5 files changed, 377 insertions(+), 46 deletions(-)

diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java
new file mode 100644
index 00000000000..c5f4190931e
--- /dev/null
+++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.errors.ProcessorStateException;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.internals.AbstractStoreBuilder;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.MockKeyValueStore;
+import org.apache.kafka.test.TestUtils;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class StateUpdaterFailureIntegrationTest {
+
+    private static final int NUM_BROKERS = 1;
+    protected static final String INPUT_TOPIC_NAME = "input-topic";
+    private static final int NUM_PARTITIONS = 6;
+
+    private final EmbeddedKafkaCluster cluster = new 
EmbeddedKafkaCluster(NUM_BROKERS);
+
+    private Properties streamsConfiguration;
+    private final MockTime mockTime = cluster.time;
+    private KafkaStreams streams;
+
+    @BeforeEach
+    public void before(final TestInfo testInfo) throws InterruptedException, 
IOException {
+        cluster.start();
+        cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1);
+        streamsConfiguration = new Properties();
+        final String safeTestName = safeUniqueTestName(testInfo);
+        streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + 
safeTestName);
+        streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, 
cluster.bootstrapServers());
+        streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, 
"earliest");
+        streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, 
TestUtils.tempDirectory().getPath());
+        
streamsConfiguration.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
+        streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 
100L);
+        streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, 
Serdes.IntegerSerde.class);
+        
streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, 
Serdes.StringSerde.class);
+        streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2);
+
+    }
+
+    @AfterEach
+    public void after() {
+        cluster.stop();
+        if (streams != null) {
+            streams.close(Duration.ofSeconds(30));
+        }
+    }
+
+    /**
+     * The conditions that we need to meet:
+     * <p><ul>
+     * <li>We have an unhandled task in {@link 
org.apache.kafka.streams.processor.internals.DefaultStateUpdater}</li>
+     * <li>StreamThread is not running, so {@link 
org.apache.kafka.streams.processor.internals.TaskManager#handleExceptionsFromStateUpdater}
 is not called anymore</li>
+     * <li>The task throws exception in {@link 
org.apache.kafka.streams.processor.internals.Task#maybeCheckpoint(boolean)} 
while being processed by {@code DefaultStateUpdater}</li>
+     * <li>{@link 
org.apache.kafka.streams.processor.internals.TaskManager#shutdownStateUpdater} 
tries to clean up all tasks that are left in the {@code 
DefaultStateUpdater}</li>
+     * </ul><p>
+     * If all conditions are met, {@code TaskManager} needs to be able to 
handle the failed task from the {@code DefaultStateUpdater} correctly and not 
hang.
+     */
+    @Test
+    public void correctlyHandleFlushErrorsDuringRebalance() throws Exception {
+        final AtomicInteger numberOfStoreInits = new AtomicInteger();
+        final CountDownLatch pendingShutdownLatch = new CountDownLatch(1);
+
+        final StoreBuilder<KeyValueStore<Object, Object>> storeBuilder = new 
AbstractStoreBuilder<>("testStateStore", Serdes.Integer(), Serdes.ByteArray(), 
new MockTime()) {
+
+            @Override
+            public KeyValueStore<Object, Object> build() {
+                return new MockKeyValueStore(name, false) {
+
+                    @Override
+                    public void init(final StateStoreContext 
stateStoreContext, final StateStore root) {
+                        super.init(stateStoreContext, root);
+                        numberOfStoreInits.incrementAndGet();
+                    }
+
+                    @Override
+                    public void flush() {
+                        // we want to throw the ProcessorStateException here 
only when the rebalance finished(we reassigned the 3 tasks from the removed 
thread to the existing thread)
+                        // we use waitForCondition to wait until the current 
state is PENDING_SHUTDOWN to make sure the Stream Thread will not handle the 
exception and we can get to in TaskManager#shutdownStateUpdater
+                        if (numberOfStoreInits.get() == 9) {
+                            try {
+                                pendingShutdownLatch.await();
+                            } catch (final InterruptedException e) {
+                                throw new RuntimeException(e);
+                            }
+                            throw new ProcessorStateException("flush");
+                        }
+                    }
+                };
+            }
+        };
+
+        final TopologyWrapper topology = new TopologyWrapper();
+        topology.addSource("ingest", INPUT_TOPIC_NAME);
+        topology.addProcessor("my-processor", new 
MockApiProcessorSupplier<>(), "ingest");
+        topology.addStateStore(storeBuilder, "my-processor");
+
+        streams = new KafkaStreams(topology, streamsConfiguration);
+        streams.setStateListener((newState, oldState) -> {
+            if (newState == KafkaStreams.State.PENDING_SHUTDOWN) {
+                pendingShutdownLatch.countDown();
+            }
+        });
+        streams.start();
+
+        TestUtils.waitForCondition(() -> streams.state() == 
KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        streams.removeStreamThread();
+
+        // Before shutting down, we want the tasks to be reassigned
+        TestUtils.waitForCondition(() -> numberOfStoreInits.get() == 9, 
"Streams never reinitialized the store enough times");
+
+        assertTrue(streams.close(Duration.ofSeconds(60)));
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
index 880980a17c5..1bfe5eaceac 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
@@ -345,23 +345,26 @@ public class DefaultStateUpdater implements StateUpdater {
         // TODO: we can let the exception encode the actual corrupted 
changelog partitions and only
         //       mark those instead of marking all changelogs
         private void removeCheckpointForCorruptedTask(final Task task) {
-            task.markChangelogAsCorrupted(task.changelogPartitions());
+            try {
+                task.markChangelogAsCorrupted(task.changelogPartitions());
 
-            // we need to enforce a checkpoint that removes the corrupted 
partitions
-            measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+                // we need to enforce a checkpoint that removes the corrupted 
partitions
+                measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+            } catch (final StreamsException swallow) {
+                log.warn("Checkpoint failed for corrupted task {}", task.id(), 
swallow);
+            }
         }
 
         private void handleStreamsException(final StreamsException 
streamsException) {
             log.info("Encountered streams exception: ", streamsException);
             if (streamsException.taskId().isPresent()) {
-                handleStreamsExceptionWithTask(streamsException);
+                handleStreamsExceptionWithTask(streamsException, 
streamsException.taskId().get());
             } else {
                 handleStreamsExceptionWithoutTask(streamsException);
             }
         }
 
-        private void handleStreamsExceptionWithTask(final StreamsException 
streamsException) {
-            final TaskId failedTaskId = streamsException.taskId().get();
+        private void handleStreamsExceptionWithTask(final StreamsException 
streamsException, final TaskId failedTaskId) {
             if (updatingTasks.containsKey(failedTaskId)) {
                 addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(
                     new ExceptionAndTask(streamsException, 
updatingTasks.get(failedTaskId))
@@ -514,7 +517,7 @@ public class DefaultStateUpdater implements StateUpdater {
                         + " own this task.", taskId);
                 }
             } catch (final StreamsException streamsException) {
-                handleStreamsException(streamsException);
+                handleStreamsExceptionWithTask(streamsException, taskId);
                 future.completeExceptionally(streamsException);
             } catch (final RuntimeException runtimeException) {
                 handleRuntimeException(runtimeException);
@@ -606,14 +609,19 @@ public class DefaultStateUpdater implements StateUpdater {
         private void pauseTask(final Task task) {
             final TaskId taskId = task.id();
             // do not need to unregister changelog partitions for paused tasks
-            measureCheckpointLatency(() -> task.maybeCheckpoint(true));
-            pausedTasks.put(taskId, task);
-            updatingTasks.remove(taskId);
-            if (task.isActive()) {
-                transitToUpdateStandbysIfOnlyStandbysLeft();
+            try {
+                measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+                pausedTasks.put(taskId, task);
+                updatingTasks.remove(taskId);
+                if (task.isActive()) {
+                    transitToUpdateStandbysIfOnlyStandbysLeft();
+                }
+                log.info((task.isActive() ? "Active" : "Standby")
+                    + " task " + task.id() + " was paused from the updating 
tasks and added to the paused tasks.");
+
+            } catch (final StreamsException streamsException) {
+                handleStreamsExceptionWithTask(streamsException, taskId);
             }
-            log.info((task.isActive() ? "Active" : "Standby")
-                + " task " + task.id() + " was paused from the updating tasks 
and added to the paused tasks.");
         }
 
         private void resumeTask(final Task task) {
@@ -640,11 +648,15 @@ public class DefaultStateUpdater implements StateUpdater {
                                               final Set<TopicPartition> 
restoredChangelogs) {
             final Collection<TopicPartition> changelogPartitions = 
task.changelogPartitions();
             if (restoredChangelogs.containsAll(changelogPartitions)) {
-                measureCheckpointLatency(() -> task.maybeCheckpoint(true));
-                changelogReader.unregister(changelogPartitions);
-                addToRestoredTasks(task);
-                log.info("Stateful active task " + task.id() + " completed 
restoration");
-                transitToUpdateStandbysIfOnlyStandbysLeft();
+                try {
+                    measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+                    changelogReader.unregister(changelogPartitions);
+                    addToRestoredTasks(task);
+                    log.info("Stateful active task " + task.id() + " completed 
restoration");
+                    transitToUpdateStandbysIfOnlyStandbysLeft();
+                } catch (final StreamsException streamsException) {
+                    handleStreamsExceptionWithTask(streamsException, 
task.id());
+                }
             }
         }
 
@@ -676,8 +688,12 @@ public class DefaultStateUpdater implements StateUpdater {
 
                 measureCheckpointLatency(() -> {
                     for (final Task task : updatingTasks.values()) {
-                        // do not enforce checkpointing during restoration if 
its position has not advanced much
-                        task.maybeCheckpoint(false);
+                        try {
+                            // do not enforce checkpointing during restoration 
if its position has not advanced much
+                            task.maybeCheckpoint(false);
+                        } catch (final StreamsException streamsException) {
+                            handleStreamsExceptionWithTask(streamsException, 
task.id());
+                        }
                     }
                 });
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 232487e2ebc..c1b1c06379e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -65,6 +65,7 @@ import java.util.Set;
 import java.util.TreeSet;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -772,7 +773,7 @@ public class TaskManager {
                                                          final 
CompletableFuture<StateUpdater.RemovedTaskResult> future) {
         final StateUpdater.RemovedTaskResult removedTaskResult;
         try {
-            removedTaskResult = future.get();
+            removedTaskResult = future.get(5, TimeUnit.MINUTES);
             if (removedTaskResult == null) {
                 throw new IllegalStateException("Task " + taskId + " was not 
found in the state updater. "
                     + BUG_ERROR_MESSAGE);
@@ -787,6 +788,10 @@ public class TaskManager {
             Thread.currentThread().interrupt();
             log.error(INTERRUPTED_ERROR_MESSAGE, shouldNotHappen);
             throw new IllegalStateException(INTERRUPTED_ERROR_MESSAGE, 
shouldNotHappen);
+        } catch (final java.util.concurrent.TimeoutException timeoutException) 
{
+            log.warn("The state updater wasn't able to remove task {} in time. 
The state updater thread may be dead. "
+                    + BUG_ERROR_MESSAGE, taskId, timeoutException);
+            return null;
         }
     }
 
@@ -1576,6 +1581,12 @@ public class TaskManager {
 
     private void shutdownStateUpdater() {
         if (stateUpdater != null) {
+            // If there are failed tasks handling them first
+            for (final StateUpdater.ExceptionAndTask exceptionAndTask : 
stateUpdater.drainExceptionsAndFailedTasks()) {
+                final Task failedTask = exceptionAndTask.task();
+                closeTaskDirty(failedTask, false);
+            }
+
             final Map<TaskId, 
CompletableFuture<StateUpdater.RemovedTaskResult>> futures = new 
LinkedHashMap<>();
             for (final Task task : stateUpdater.tasks()) {
                 final CompletableFuture<StateUpdater.RemovedTaskResult> future 
= stateUpdater.remove(task.id());
@@ -1584,7 +1595,8 @@ public class TaskManager {
             final Set<Task> tasksToCloseClean = new HashSet<>();
             final Set<Task> tasksToCloseDirty = new HashSet<>();
             addToTasksToClose(futures, tasksToCloseClean, tasksToCloseDirty);
-            stateUpdater.shutdown(Duration.ofMillis(Long.MAX_VALUE));
+            // at this point we removed all tasks, so the shutdown should not 
take a lot of time
+            stateUpdater.shutdown(Duration.ofMinutes(1L));
 
             for (final Task task : tasksToCloseClean) {
                 tasks.addTask(task);
@@ -1592,16 +1604,22 @@ public class TaskManager {
             for (final Task task : tasksToCloseDirty) {
                 closeTaskDirty(task, false);
             }
+            // Handling all failures that occurred during the remove process
             for (final StateUpdater.ExceptionAndTask exceptionAndTask : 
stateUpdater.drainExceptionsAndFailedTasks()) {
                 final Task failedTask = exceptionAndTask.task();
                 closeTaskDirty(failedTask, false);
             }
+
+            // If there is anything left unhandled due to timeouts, handling 
now
+            for (final Task task : stateUpdater.tasks()) {
+                closeTaskDirty(task, false);
+            }
         }
     }
 
     private void shutdownSchedulingTaskManager() {
         if (schedulingTaskManager != null) {
-            schedulingTaskManager.shutdown(Duration.ofMillis(Long.MAX_VALUE));
+            schedulingTaskManager.shutdown(Duration.ofMinutes(5L));
         }
     }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
index 8f767af93e0..4f9d1b3c0e6 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
@@ -23,6 +23,7 @@ import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.processor.TaskId;
@@ -73,6 +74,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyMap;
 import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.inOrder;
@@ -1717,6 +1719,114 @@ class DefaultStateUpdaterTest {
         }
     }
 
+    @Test
+    public void shouldNotFailTheThreadIfMaybeCheckpointFails() throws 
Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new 
ProcessorStateException("flush");
+        
doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyExceptionsAndFailedTasks(new 
ExceptionAndTask(processorStateException, failedStatefulTask));
+        verifyUpdatingTasks(activeTask1);
+
+        stateUpdater.add(activeTask2);
+        verifyUpdatingTasks(activeTask1, activeTask2);
+    }
+
+    @Test
+    public void shouldNotFailTheThreadIfMaybeCheckpointFailsForCorruptedTask() 
throws Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new 
ProcessorStateException("flush");
+        
doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+
+        final TaskCorruptedException taskCorruptedException = new 
TaskCorruptedException(Set.of(TASK_0_2));
+        when(changelogReader.restore(Map.of(
+                TASK_0_0, activeTask1,
+                TASK_0_2, failedStatefulTask))
+        ).thenThrow(taskCorruptedException);
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyExceptionsAndFailedTasks(new 
ExceptionAndTask(taskCorruptedException, failedStatefulTask));
+        verifyUpdatingTasks(activeTask1);
+
+        stateUpdater.add(activeTask2);
+        verifyUpdatingTasks(activeTask1, activeTask2);
+    }
+
+    @Test
+    public void 
shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskRemoval() throws 
Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new 
ProcessorStateException("flush");
+        final AtomicBoolean throwException = new AtomicBoolean(false);
+        doAnswer(invocation -> {
+            if (throwException.get()) {
+                throw processorStateException;
+            }
+            return null;
+        }).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+        when(changelogReader.allChangelogsCompleted()).thenReturn(true);
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyUpdatingTasks(failedStatefulTask, activeTask1);
+
+        throwException.set(true);
+        final ExecutionException exception = 
assertThrows(ExecutionException.class, () -> 
stateUpdater.remove(TASK_0_2).get());
+        assertEquals(processorStateException, exception.getCause());
+
+        stateUpdater.add(activeTask2);
+        verifyUpdatingTasks(activeTask1, activeTask2);
+    }
+
+    @Test
+    public void shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskPause() 
throws Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new 
ProcessorStateException("flush");
+        
doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+        
when(topologyMetadata.isPaused(null)).thenReturn(false).thenReturn(false).thenReturn(true);
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyExceptionsAndFailedTasks(new 
ExceptionAndTask(processorStateException, failedStatefulTask));
+        verifyPausedTasks(activeTask1);
+
+        stateUpdater.add(activeTask2);
+        verifyPausedTasks(activeTask1, activeTask2);
+    }
+
+    @Test
+    public void 
shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskRestore() throws 
Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, 
Set.of(TOPIC_PARTITION_B_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new 
ProcessorStateException("flush");
+        
doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+        
when(changelogReader.completedChangelogs()).thenReturn(Set.of(TOPIC_PARTITION_B_0));
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyExceptionsAndFailedTasks(new 
ExceptionAndTask(processorStateException, failedStatefulTask));
+        verifyUpdatingTasks(activeTask1);
+
+        stateUpdater.add(activeTask2);
+        verifyUpdatingTasks(activeTask1, activeTask2);
+    }
+
     private static List<MetricName> getMetricNames(final String threadId) {
         final Map<String, String> tagMap = Map.of("thread-id", threadId);
         return List.of(
@@ -1779,7 +1889,8 @@ class DefaultStateUpdaterTest {
                         && restoredTasks.size() == 
expectedRestoredTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not get all restored active task within the given 
timeout!"
+                () -> "Did not get all restored active task within the given 
timeout! Expected: "
+                        + expectedRestoredTasks + ", actual: " + restoredTasks
             );
         }
     }
@@ -1794,7 +1905,8 @@ class DefaultStateUpdaterTest {
                     && restoredTasks.size() == expectedRestoredTasks.size();
             },
             VERIFICATION_TIMEOUT,
-            "Did not get all restored active task within the given timeout!"
+            () -> "Did not get all restored active task within the given 
timeout! Expected: "
+                    + expectedRestoredTasks + ", actual: " + restoredTasks
         );
         
assertTrue(stateUpdater.drainRestoredActiveTasks(Duration.ZERO).isEmpty());
     }
@@ -1816,7 +1928,8 @@ class DefaultStateUpdaterTest {
                         && updatingTasks.size() == 
expectedUpdatingTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not get all updating task within the given timeout!"
+                () -> "Did not get all updating task within the given timeout! 
Expected: "
+                        + expectedUpdatingTasks + ", actual: " + updatingTasks
             );
         }
     }
@@ -1831,7 +1944,8 @@ class DefaultStateUpdaterTest {
                     && standbyTasks.size() == expectedStandbyTasks.size();
             },
             VERIFICATION_TIMEOUT,
-            "Did not see all standby task within the given timeout!"
+            () -> "Did not see all standby task within the given timeout! 
Expected: "
+                    + expectedStandbyTasks + ", actual: " + standbyTasks
         );
     }
 
@@ -1860,7 +1974,8 @@ class DefaultStateUpdaterTest {
                         && pausedTasks.size() == expectedPausedTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not get all paused task within the given timeout!"
+                () -> "Did not get all paused task within the given timeout! 
Expected: "
+                        + expectedPausedTasks + ", actual: " + pausedTasks
             );
         }
     }
@@ -1875,7 +1990,8 @@ class DefaultStateUpdaterTest {
                     && failedTasks.size() == expectedExceptionAndTasks.size();
             },
             VERIFICATION_TIMEOUT,
-            "Did not get all exceptions and failed tasks within the given 
timeout!"
+            () -> "Did not get all exceptions and failed tasks within the 
given timeout! Expected: "
+                    + expectedExceptionAndTasks + ", actual: " + failedTasks
         );
     }
 
@@ -1893,7 +2009,8 @@ class DefaultStateUpdaterTest {
                     && failedTasks.size() == expectedFailedTasks.size();
             },
             VERIFICATION_TIMEOUT,
-            "Did not get all exceptions and failed tasks within the given 
timeout!"
+            () -> "Did not get all exceptions and failed tasks within the 
given timeout! Expected: "
+                        + expectedFailedTasks + ", actual: " + failedTasks
         );
     }
 
@@ -1911,7 +2028,8 @@ class DefaultStateUpdaterTest {
                     && failedTasks.size() == expectedExceptionAndTasks.size();
             },
             VERIFICATION_TIMEOUT,
-            "Did not get all exceptions and failed tasks within the given 
timeout!"
+            () -> "Did not get all exceptions and failed tasks within the 
given timeout! Expected: "
+                    + expectedExceptionAndTasks + ", actual: " + failedTasks
         );
         assertFalse(stateUpdater.hasExceptionsAndFailedTasks());
         assertTrue(stateUpdater.drainExceptionsAndFailedTasks().isEmpty());
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 12a33558401..3e87eebe733 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -123,6 +123,7 @@ import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.lenient;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
@@ -3339,7 +3340,7 @@ public class TaskManagerTest {
         assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
         verify(activeTaskCreator).close();
-        verify(stateUpdater).shutdown(Duration.ofMillis(Long.MAX_VALUE));
+        verify(stateUpdater).shutdown(Duration.ofMinutes(1L));
     }
 
     @Test
@@ -3369,7 +3370,29 @@ public class TaskManagerTest {
         assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
         verify(activeTaskCreator).close();
-        verify(stateUpdater).shutdown(Duration.ofMillis(Long.MAX_VALUE));
+        verify(stateUpdater).shutdown(Duration.ofMinutes(1L));
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldCloseTasksIfStateUpdaterTimesOutOnRemove() throws 
Exception {
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, null, false);
+        final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
+                mkEntry(taskId00, taskId00Partitions)
+        );
+        final Task task00 = spy(new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager));
+
+        when(activeTaskCreator.createTasks(any(), 
eq(assignment))).thenReturn(singletonList(task00));
+        taskManager.handleAssignment(assignment, emptyMap());
+
+        when(stateUpdater.tasks()).thenReturn(singleton(task00));
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
mock(CompletableFuture.class);
+        when(stateUpdater.remove(eq(taskId00))).thenReturn(future);
+        when(future.get(anyLong(), any())).thenThrow(new 
java.util.concurrent.TimeoutException());
+
+        taskManager.shutdown(true);
+
+        verify(task00).closeDirty();
     }
 
     @Test
@@ -3513,7 +3536,7 @@ public class TaskManagerTest {
             .withInputPartitions(taskId00Partitions)
             .build();
 
-        when(stateUpdater.tasks()).thenReturn(Set.of(standbyTask00));
+        
when(stateUpdater.tasks()).thenReturn(Set.of(standbyTask00)).thenReturn(Set.of());
         when(stateUpdater.standbyTasks()).thenReturn(Set.of(standbyTask00));
 
         final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForStandbyTask = new CompletableFuture<>();
@@ -3525,7 +3548,7 @@ public class TaskManagerTest {
 
         taskManager.shutdown(true);
 
-        verify(stateUpdater).shutdown(Duration.ofMillis(Long.MAX_VALUE));
+        verify(stateUpdater).shutdown(Duration.ofMinutes(1L));
 
         verify(tasks).addTask(standbyTask00);
 
@@ -3550,13 +3573,14 @@ public class TaskManagerTest {
             .thenReturn(Arrays.asList(
                 new ExceptionAndTask(new RuntimeException(), 
failedStatefulTask),
                 new ExceptionAndTask(new RuntimeException(), 
failedStandbyTask))
-            );
+            )
+            .thenReturn(Collections.emptyList());
         final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         taskManager.shutdown(true);
 
         verify(activeTaskCreator).close();
-        verify(stateUpdater).shutdown(Duration.ofMillis(Long.MAX_VALUE));
+        verify(stateUpdater).shutdown(Duration.ofMinutes(1L));
         verify(failedStatefulTask).prepareCommit(false);
         verify(failedStatefulTask).suspend();
         verify(failedStatefulTask).closeDirty();
@@ -3569,7 +3593,7 @@ public class TaskManagerTest {
 
         taskManager.shutdown(true);
 
-        
verify(schedulingTaskManager).shutdown(Duration.ofMillis(Long.MAX_VALUE));
+        verify(schedulingTaskManager).shutdown(Duration.ofMinutes(5L));
     }
 
     @Test
@@ -3594,8 +3618,8 @@ public class TaskManagerTest {
                 removedFailedStatefulTask,
                 removedFailedStandbyTask,
                 removedFailedStatefulTaskDuringRemoval,
-                removedFailedStandbyTaskDuringRemoval
-            ));
+                removedFailedStandbyTaskDuringRemoval)
+            ).thenReturn(Collections.emptySet());
         final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForRemovedStatefulTask = new CompletableFuture<>();
         final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForRemovedStandbyTask = new CompletableFuture<>();
         final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForRemovedFailedStatefulTask = new CompletableFuture<>();
@@ -3610,10 +3634,11 @@ public class TaskManagerTest {
             .thenReturn(futureForRemovedFailedStatefulTaskDuringRemoval);
         when(stateUpdater.remove(removedFailedStandbyTaskDuringRemoval.id()))
             .thenReturn(futureForRemovedFailedStandbyTaskDuringRemoval);
-        
when(stateUpdater.drainExceptionsAndFailedTasks()).thenReturn(Arrays.asList(
-            new ExceptionAndTask(new StreamsException("KABOOM!"), 
removedFailedStatefulTaskDuringRemoval),
-            new ExceptionAndTask(new StreamsException("KABOOM!"), 
removedFailedStandbyTaskDuringRemoval)
-        ));
+        when(stateUpdater.drainExceptionsAndFailedTasks())
+                .thenReturn(Arrays.asList(
+                    new ExceptionAndTask(new StreamsException("KABOOM!"), 
removedFailedStatefulTaskDuringRemoval),
+                    new ExceptionAndTask(new StreamsException("KABOOM!"), 
removedFailedStandbyTaskDuringRemoval))
+                ).thenReturn(Collections.emptyList());
         final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
         futureForRemovedStatefulTask.complete(new 
StateUpdater.RemovedTaskResult(removedStatefulTask));
         futureForRemovedStandbyTask.complete(new 
StateUpdater.RemovedTaskResult(removedStandbyTask));
@@ -3628,7 +3653,7 @@ public class TaskManagerTest {
 
         taskManager.shutdown(true);
 
-        verify(stateUpdater).shutdown(Duration.ofMillis(Long.MAX_VALUE));
+        verify(stateUpdater).shutdown(Duration.ofMinutes(1L));
         verify(tasks).addTask(removedStatefulTask);
         verify(tasks).addTask(removedStandbyTask);
         verify(removedFailedStatefulTask).prepareCommit(false);


Reply via email to