This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch 4.0
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/4.0 by this push:
     new 37042cd7924  KAFKA-19831: Improved error handling in DefaultStateUpdater. (#20767)  (#21438)
37042cd7924 is described below

commit 37042cd79240a1bfc635c3f5558bfcc1d917fb8b
Author: Nikita Shupletsov <[email protected]>
AuthorDate: Mon Feb 9 20:14:27 2026 -0800

    KAFKA-19831: Improved error handling in DefaultStateUpdater. (#20767)  (#21438)
    
    - Improved error handling in DefaultStateUpdater to take potential
    failures in Task#maybeCheckpoint into account.
    - Improved TaskManager#shutdownStateUpdater to not hang indefinitely if
    the State Updater thread is dead.
    
    Reviewers: Matthias J. Sax <[email protected]>, Lucas Brutschy <[email protected]>
---
 .../StateUpdaterFailureIntegrationTest.java        | 154 +++++++++++++++++++++
 .../processor/internals/DefaultStateUpdater.java   |  83 +++++------
 .../streams/processor/internals/TaskAndAction.java |   7 +-
 .../streams/processor/internals/TaskManager.java   |  24 +++-
 .../internals/DefaultStateUpdaterTest.java         | 134 ++++++++++++++++--
 .../processor/internals/TaskManagerTest.java       |  40 ++++--
 6 files changed, 368 insertions(+), 74 deletions(-)

diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java
new file mode 100644
index 00000000000..c5f4190931e
--- /dev/null
+++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.errors.ProcessorStateException;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.internals.AbstractStoreBuilder;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.MockKeyValueStore;
+import org.apache.kafka.test.TestUtils;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class StateUpdaterFailureIntegrationTest {
+
+    private static final int NUM_BROKERS = 1;
+    protected static final String INPUT_TOPIC_NAME = "input-topic";
+    private static final int NUM_PARTITIONS = 6;
+
+    private final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS);
+
+    private Properties streamsConfiguration;
+    private final MockTime mockTime = cluster.time;
+    private KafkaStreams streams;
+
+    @BeforeEach
+    public void before(final TestInfo testInfo) throws InterruptedException, IOException {
+        cluster.start();
+        cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1);
+        streamsConfiguration = new Properties();
+        final String safeTestName = safeUniqueTestName(testInfo);
+        streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName);
+        streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers());
+        streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
+        streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath());
+        streamsConfiguration.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
+        streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L);
+        streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.IntegerSerde.class);
+        streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class);
+        streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2);
+
+    }
+
+    @AfterEach
+    public void after() {
+        if (streams != null) {
+            streams.close(Duration.ofSeconds(30)); // close Streams while the broker is still reachable
+        }
+        cluster.stop();
+    }
+
+    /**
+     * The conditions that we need to meet:
+     * <p><ul>
+     * <li>We have an unhandled task in {@link org.apache.kafka.streams.processor.internals.DefaultStateUpdater}</li>
+     * <li>StreamThread is not running, so {@link org.apache.kafka.streams.processor.internals.TaskManager#handleExceptionsFromStateUpdater} is not called anymore</li>
+     * <li>The task throws an exception in {@link org.apache.kafka.streams.processor.internals.Task#maybeCheckpoint(boolean)} while being processed by {@code DefaultStateUpdater}</li>
+     * <li>{@link org.apache.kafka.streams.processor.internals.TaskManager#shutdownStateUpdater} tries to clean up all tasks that are left in the {@code DefaultStateUpdater}</li>
+     * </ul><p>
+     * If all conditions are met, {@code TaskManager} needs to be able to handle the failed task from the {@code DefaultStateUpdater} correctly and not hang.
+     */
+    @Test
+    public void correctlyHandleFlushErrorsDuringRebalance() throws Exception {
+        final AtomicInteger numberOfStoreInits = new AtomicInteger();
+        final CountDownLatch pendingShutdownLatch = new CountDownLatch(1);
+
+        final StoreBuilder<KeyValueStore<Object, Object>> storeBuilder = new AbstractStoreBuilder<>("testStateStore", Serdes.Integer(), Serdes.ByteArray(), new MockTime()) {
+
+            @Override
+            public KeyValueStore<Object, Object> build() {
+                return new MockKeyValueStore(name, false) {
+
+                    @Override
+                    public void init(final StateStoreContext stateStoreContext, final StateStore root) {
+                        super.init(stateStoreContext, root);
+                        numberOfStoreInits.incrementAndGet();
+                    }
+
+                    @Override
+                    public void flush() {
+                        // Throw only after the rebalance has finished, i.e. once the 3 tasks of the removed thread were re-initialized on the remaining thread (9 store inits in total).
+                        // Block on pendingShutdownLatch until KafkaStreams reaches PENDING_SHUTDOWN, so no StreamThread handles the exception and it is left for TaskManager#shutdownStateUpdater.
+                        if (numberOfStoreInits.get() == 9) {
+                            try {
+                                pendingShutdownLatch.await();
+                            } catch (final InterruptedException e) {
+                                throw new RuntimeException(e);
+                            }
+                            throw new ProcessorStateException("flush");
+                        }
+                    }
+                };
+            }
+        };
+
+        final TopologyWrapper topology = new TopologyWrapper();
+        topology.addSource("ingest", INPUT_TOPIC_NAME);
+        topology.addProcessor("my-processor", new MockApiProcessorSupplier<>(), "ingest");
+        topology.addStateStore(storeBuilder, "my-processor");
+
+        streams = new KafkaStreams(topology, streamsConfiguration);
+        streams.setStateListener((newState, oldState) -> {
+            if (newState == KafkaStreams.State.PENDING_SHUTDOWN) {
+                pendingShutdownLatch.countDown();
+            }
+        });
+        streams.start();
+
+        TestUtils.waitForCondition(() -> streams.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        streams.removeStreamThread();
+
+        // Before shutting down, we want the tasks to be reassigned
+        TestUtils.waitForCondition(() -> numberOfStoreInits.get() == 9, "Streams never reinitialized the store enough times");
+
+        assertTrue(streams.close(Duration.ofSeconds(60)));
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
index 48d42590c1e..cb983f02bce 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
@@ -208,11 +208,7 @@ public class DefaultStateUpdater implements StateUpdater {
                             addTask(taskAndAction.task());
                             break;
                         case REMOVE:
-                            if (taskAndAction.futureForRemove() == null) {
-                                removeTask(taskAndAction.taskId());
-                            } else {
-                                removeTask(taskAndAction.taskId(), 
taskAndAction.futureForRemove());
-                            }
+                            removeTask(taskAndAction.taskId(), 
taskAndAction.futureForRemove());
                             break;
                         default:
                             throw new IllegalStateException("Unknown action 
type " + action);
@@ -349,23 +345,26 @@ public class DefaultStateUpdater implements StateUpdater {
         // TODO: we can let the exception encode the actual corrupted 
changelog partitions and only
         //       mark those instead of marking all changelogs
         private void removeCheckpointForCorruptedTask(final Task task) {
-            task.markChangelogAsCorrupted(task.changelogPartitions());
+            try {
+                task.markChangelogAsCorrupted(task.changelogPartitions());
 
-            // we need to enforce a checkpoint that removes the corrupted 
partitions
-            measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+                // we need to enforce a checkpoint that removes the corrupted 
partitions
+                measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+            } catch (final StreamsException swallow) {
+                log.warn("Checkpoint failed for corrupted task {}", task.id(), 
swallow);
+            }
         }
 
         private void handleStreamsException(final StreamsException 
streamsException) {
             log.info("Encountered streams exception: ", streamsException);
             if (streamsException.taskId().isPresent()) {
-                handleStreamsExceptionWithTask(streamsException);
+                handleStreamsExceptionWithTask(streamsException, 
streamsException.taskId().get());
             } else {
                 handleStreamsExceptionWithoutTask(streamsException);
             }
         }
 
-        private void handleStreamsExceptionWithTask(final StreamsException 
streamsException) {
-            final TaskId failedTaskId = streamsException.taskId().get();
+        private void handleStreamsExceptionWithTask(final StreamsException 
streamsException, final TaskId failedTaskId) {
             if (updatingTasks.containsKey(failedTaskId)) {
                 addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(
                     new ExceptionAndTask(streamsException, 
updatingTasks.get(failedTaskId))
@@ -518,7 +517,7 @@ public class DefaultStateUpdater implements StateUpdater {
                         + " own this task.", taskId);
                 }
             } catch (final StreamsException streamsException) {
-                handleStreamsException(streamsException);
+                handleStreamsExceptionWithTask(streamsException, taskId);
                 future.completeExceptionally(streamsException);
             } catch (final RuntimeException runtimeException) {
                 handleRuntimeException(runtimeException);
@@ -607,44 +606,22 @@ public class DefaultStateUpdater implements StateUpdater {
             }
         }
 
-        private void removeTask(final TaskId taskId) {
-            final Task task;
-            if (updatingTasks.containsKey(taskId)) {
-                task = updatingTasks.get(taskId);
+        private void pauseTask(final Task task) {
+            final TaskId taskId = task.id();
+            // do not need to unregister changelog partitions for paused tasks
+            try {
                 measureCheckpointLatency(() -> task.maybeCheckpoint(true));
-                final Collection<TopicPartition> changelogPartitions = 
task.changelogPartitions();
-                changelogReader.unregister(changelogPartitions);
-                removedTasks.add(task);
+                pausedTasks.put(taskId, task);
                 updatingTasks.remove(taskId);
                 if (task.isActive()) {
                     transitToUpdateStandbysIfOnlyStandbysLeft();
                 }
                 log.info((task.isActive() ? "Active" : "Standby")
-                    + " task " + task.id() + " was removed from the updating 
tasks and added to the removed tasks.");
-            } else if (pausedTasks.containsKey(taskId)) {
-                task = pausedTasks.get(taskId);
-                final Collection<TopicPartition> changelogPartitions = 
task.changelogPartitions();
-                changelogReader.unregister(changelogPartitions);
-                removedTasks.add(task);
-                pausedTasks.remove(taskId);
-                log.info((task.isActive() ? "Active" : "Standby")
-                    + " task " + task.id() + " was removed from the paused 
tasks and added to the removed tasks.");
-            } else {
-                log.info("Task " + taskId + " was not removed since it is not 
updating or paused.");
-            }
-        }
+                    + " task " + task.id() + " was paused from the updating 
tasks and added to the paused tasks.");
 
-        private void pauseTask(final Task task) {
-            final TaskId taskId = task.id();
-            // do not need to unregister changelog partitions for paused tasks
-            measureCheckpointLatency(() -> task.maybeCheckpoint(true));
-            pausedTasks.put(taskId, task);
-            updatingTasks.remove(taskId);
-            if (task.isActive()) {
-                transitToUpdateStandbysIfOnlyStandbysLeft();
+            } catch (final StreamsException streamsException) {
+                handleStreamsExceptionWithTask(streamsException, taskId);
             }
-            log.info((task.isActive() ? "Active" : "Standby")
-                + " task " + task.id() + " was paused from the updating tasks 
and added to the paused tasks.");
         }
 
         private void resumeTask(final Task task) {
@@ -671,11 +648,15 @@ public class DefaultStateUpdater implements StateUpdater {
                                               final Set<TopicPartition> 
restoredChangelogs) {
             final Collection<TopicPartition> changelogPartitions = 
task.changelogPartitions();
             if (restoredChangelogs.containsAll(changelogPartitions)) {
-                measureCheckpointLatency(() -> task.maybeCheckpoint(true));
-                changelogReader.unregister(changelogPartitions);
-                addToRestoredTasks(task);
-                log.info("Stateful active task " + task.id() + " completed 
restoration");
-                transitToUpdateStandbysIfOnlyStandbysLeft();
+                try {
+                    measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+                    changelogReader.unregister(changelogPartitions);
+                    addToRestoredTasks(task);
+                    log.info("Stateful active task " + task.id() + " completed 
restoration");
+                    transitToUpdateStandbysIfOnlyStandbysLeft();
+                } catch (final StreamsException streamsException) {
+                    handleStreamsExceptionWithTask(streamsException, 
task.id());
+                }
             }
         }
 
@@ -707,8 +688,12 @@ public class DefaultStateUpdater implements StateUpdater {
 
                 measureCheckpointLatency(() -> {
                     for (final Task task : updatingTasks.values()) {
-                        // do not enforce checkpointing during restoration if 
its position has not advanced much
-                        task.maybeCheckpoint(false);
+                        try {
+                            // do not enforce checkpointing during restoration 
if its position has not advanced much
+                            task.maybeCheckpoint(false);
+                        } catch (final StreamsException streamsException) {
+                            handleStreamsExceptionWithTask(streamsException, 
task.id());
+                        }
                     }
                 });
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java
index b9c07151cfa..ec6c6830bbd 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java
@@ -55,11 +55,6 @@ public class TaskAndAction {
         return new TaskAndAction(null, taskId, Action.REMOVE, future);
     }
 
-    public static TaskAndAction createRemoveTask(final TaskId taskId) {
-        Objects.requireNonNull(taskId, "Task ID of task to remove is null!");
-        return new TaskAndAction(null, taskId, Action.REMOVE, null);
-    }
-
     public Task task() {
         if (action != Action.ADD) {
             throw new IllegalStateException("Action type " + action + " cannot 
have a task!");
@@ -84,4 +79,4 @@ public class TaskAndAction {
     public Action action() {
         return action;
     }
-}
\ No newline at end of file
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 0a36bdba67b..47bb52184ba 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -66,6 +66,7 @@ import java.util.Set;
 import java.util.TreeSet;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -701,7 +702,7 @@ public class TaskManager {
                                                          final 
CompletableFuture<StateUpdater.RemovedTaskResult> future) {
         final StateUpdater.RemovedTaskResult removedTaskResult;
         try {
-            removedTaskResult = future.get();
+            removedTaskResult = future.get(5, TimeUnit.MINUTES);
             if (removedTaskResult == null) {
                 throw new IllegalStateException("Task " + taskId + " was not 
found in the state updater. "
                     + BUG_ERROR_MESSAGE);
@@ -716,6 +717,10 @@ public class TaskManager {
             Thread.currentThread().interrupt();
             log.error(INTERRUPTED_ERROR_MESSAGE, shouldNotHappen);
             throw new IllegalStateException(INTERRUPTED_ERROR_MESSAGE, 
shouldNotHappen);
+        } catch (final java.util.concurrent.TimeoutException timeoutException) 
{
+            log.warn("The state updater wasn't able to remove task {} in time. 
The state updater thread may be dead. "
+                    + BUG_ERROR_MESSAGE, taskId, timeoutException);
+            return null;
         }
     }
 
@@ -1517,6 +1522,12 @@ public class TaskManager {
 
     private void shutdownStateUpdater() {
         if (stateUpdater != null) {
+            // If there are failed tasks handling them first
+            for (final StateUpdater.ExceptionAndTask exceptionAndTask : 
stateUpdater.drainExceptionsAndFailedTasks()) {
+                final Task failedTask = exceptionAndTask.task();
+                closeTaskDirty(failedTask, false);
+            }
+
             final Map<TaskId, 
CompletableFuture<StateUpdater.RemovedTaskResult>> futures = new 
LinkedHashMap<>();
             for (final Task task : stateUpdater.tasks()) {
                 final CompletableFuture<StateUpdater.RemovedTaskResult> future 
= stateUpdater.remove(task.id());
@@ -1525,7 +1536,8 @@ public class TaskManager {
             final Set<Task> tasksToCloseClean = new HashSet<>();
             final Set<Task> tasksToCloseDirty = new HashSet<>();
             addToTasksToClose(futures, tasksToCloseClean, tasksToCloseDirty);
-            stateUpdater.shutdown(Duration.ofMillis(Long.MAX_VALUE));
+            // at this point we removed all tasks, so the shutdown should not 
take a lot of time
+            stateUpdater.shutdown(Duration.ofMinutes(1L));
 
             for (final Task task : tasksToCloseClean) {
                 tasks.addTask(task);
@@ -1533,16 +1545,22 @@ public class TaskManager {
             for (final Task task : tasksToCloseDirty) {
                 closeTaskDirty(task, false);
             }
+            // Handling all failures that occurred during the remove process
             for (final StateUpdater.ExceptionAndTask exceptionAndTask : 
stateUpdater.drainExceptionsAndFailedTasks()) {
                 final Task failedTask = exceptionAndTask.task();
                 closeTaskDirty(failedTask, false);
             }
+
+            // If there is anything left unhandled due to timeouts, handling 
now
+            for (final Task task : stateUpdater.tasks()) {
+                closeTaskDirty(task, false);
+            }
         }
     }
 
     private void shutdownSchedulingTaskManager() {
         if (schedulingTaskManager != null) {
-            schedulingTaskManager.shutdown(Duration.ofMillis(Long.MAX_VALUE));
+            schedulingTaskManager.shutdown(Duration.ofMinutes(5L));
         }
     }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
index 4f540fcba64..95f273bfc90 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
@@ -23,6 +23,7 @@ import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.processor.TaskId;
@@ -73,6 +74,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyMap;
 import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.inOrder;
@@ -1709,6 +1711,114 @@ class DefaultStateUpdaterTest {
         }
     }
 
+    @Test // maybeCheckpoint always throws for TASK_0_2: that task is reported as failed, while the updater thread keeps updating the other tasks
+    public void shouldNotFailTheThreadIfMaybeCheckpointFails() throws 
Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new 
ProcessorStateException("flush");
+        
doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyExceptionsAndFailedTasks(new 
ExceptionAndTask(processorStateException, failedStatefulTask));
+        verifyUpdatingTasks(activeTask1);
+
+        stateUpdater.add(activeTask2);
+        verifyUpdatingTasks(activeTask1, activeTask2);
+    }
+
+    @Test // restore() throws TaskCorruptedException for TASK_0_2 and its forced checkpoint also throws: the corruption exception is what gets reported, and the thread keeps running
+    public void shouldNotFailTheThreadIfMaybeCheckpointFailsForCorruptedTask() 
throws Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new 
ProcessorStateException("flush");
+        
doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+
+        final TaskCorruptedException taskCorruptedException = new 
TaskCorruptedException(Set.of(TASK_0_2));
+        when(changelogReader.restore(Map.of(
+                TASK_0_0, activeTask1,
+                TASK_0_2, failedStatefulTask))
+        ).thenThrow(taskCorruptedException);
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyExceptionsAndFailedTasks(new 
ExceptionAndTask(taskCorruptedException, failedStatefulTask));
+        verifyUpdatingTasks(activeTask1);
+
+        stateUpdater.add(activeTask2);
+        verifyUpdatingTasks(activeTask1, activeTask2);
+    }
+
+    @Test // once throwException is set, remove(TASK_0_2) completes its future exceptionally with the checkpoint error, and the thread keeps accepting new tasks
+    public void 
shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskRemoval() throws 
Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new 
ProcessorStateException("flush");
+        final AtomicBoolean throwException = new AtomicBoolean(false);
+        doAnswer(invocation -> {
+            if (throwException.get()) {
+                throw processorStateException;
+            }
+            return null;
+        }).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+        when(changelogReader.allChangelogsCompleted()).thenReturn(true);
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyUpdatingTasks(failedStatefulTask, activeTask1);
+
+        throwException.set(true);
+        final ExecutionException exception = 
assertThrows(ExecutionException.class, () -> 
stateUpdater.remove(TASK_0_2).get());
+        assertEquals(processorStateException, exception.getCause());
+
+        stateUpdater.add(activeTask2);
+        verifyUpdatingTasks(activeTask1, activeTask2);
+    }
+
+    @Test // maybeCheckpoint always throws for TASK_0_2 and isPaused flips to true on its third call: the failed task is reported, and the other tasks still end up paused
+    public void shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskPause() 
throws Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new 
ProcessorStateException("flush");
+        
doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+        
when(topologyMetadata.isPaused(null)).thenReturn(false).thenReturn(false).thenReturn(true);
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyExceptionsAndFailedTasks(new 
ExceptionAndTask(processorStateException, failedStatefulTask));
+        verifyPausedTasks(activeTask1);
+
+        stateUpdater.add(activeTask2);
+        verifyPausedTasks(activeTask1, activeTask2);
+    }
+
+    @Test // TASK_0_2's only changelog (B_0) is reported completed, so its restore-completion checkpoint throws: only that task fails, and the thread keeps updating others
+    public void 
shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskRestore() throws 
Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, 
Set.of(TOPIC_PARTITION_B_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new 
ProcessorStateException("flush");
+        
doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+        
when(changelogReader.completedChangelogs()).thenReturn(Set.of(TOPIC_PARTITION_B_0));
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyExceptionsAndFailedTasks(new 
ExceptionAndTask(processorStateException, failedStatefulTask));
+        verifyUpdatingTasks(activeTask1);
+
+        stateUpdater.add(activeTask2);
+        verifyUpdatingTasks(activeTask1, activeTask2);
+    }
+
     private static List<MetricName> getMetricNames(final String threadId) {
         final Map<String, String> tagMap = Map.of("thread-id", threadId);
         return List.of(
@@ -1771,7 +1881,8 @@ class DefaultStateUpdaterTest {
                         && restoredTasks.size() == 
expectedRestoredTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not get all restored active task within the given 
timeout!"
+                () -> "Did not get all restored active task within the given 
timeout! Expected: "
+                        + expectedRestoredTasks + ", actual: " + restoredTasks
             );
         }
     }
@@ -1786,7 +1897,8 @@ class DefaultStateUpdaterTest {
                     && restoredTasks.size() == expectedRestoredTasks.size();
             },
             VERIFICATION_TIMEOUT,
-            "Did not get all restored active task within the given timeout!"
+            () -> "Did not get all restored active task within the given 
timeout! Expected: "
+                    + expectedRestoredTasks + ", actual: " + restoredTasks
         );
         
assertTrue(stateUpdater.drainRestoredActiveTasks(Duration.ZERO).isEmpty());
     }
@@ -1808,7 +1920,8 @@ class DefaultStateUpdaterTest {
                         && updatingTasks.size() == 
expectedUpdatingTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not get all updating task within the given timeout!"
+                () -> "Did not get all updating task within the given timeout! 
Expected: "
+                        + expectedUpdatingTasks + ", actual: " + updatingTasks
             );
         }
     }
@@ -1823,7 +1936,8 @@ class DefaultStateUpdaterTest {
                     && standbyTasks.size() == expectedStandbyTasks.size();
             },
             VERIFICATION_TIMEOUT,
-            "Did not see all standby task within the given timeout!"
+            () -> "Did not see all standby task within the given timeout! 
Expected: "
+                    + expectedStandbyTasks + ", actual: " + standbyTasks
         );
     }
 
@@ -1852,7 +1966,8 @@ class DefaultStateUpdaterTest {
                         && pausedTasks.size() == expectedPausedTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not get all paused task within the given timeout!"
+                () -> "Did not get all paused task within the given timeout! 
Expected: "
+                        + expectedPausedTasks + ", actual: " + pausedTasks
             );
         }
     }
@@ -1867,7 +1982,8 @@ class DefaultStateUpdaterTest {
                     && failedTasks.size() == expectedExceptionAndTasks.size();
             },
             VERIFICATION_TIMEOUT,
-            "Did not get all exceptions and failed tasks within the given 
timeout!"
+            () -> "Did not get all exceptions and failed tasks within the 
given timeout! Expected: "
+                    + expectedExceptionAndTasks + ", actual: " + failedTasks
         );
     }
 
@@ -1885,7 +2001,8 @@ class DefaultStateUpdaterTest {
                     && failedTasks.size() == expectedFailedTasks.size();
             },
             VERIFICATION_TIMEOUT,
-            "Did not get all exceptions and failed tasks within the given 
timeout!"
+            () -> "Did not get all exceptions and failed tasks within the 
given timeout! Expected: "
+                        + expectedFailedTasks + ", actual: " + failedTasks
         );
     }
 
@@ -1903,7 +2020,8 @@ class DefaultStateUpdaterTest {
                     && failedTasks.size() == expectedExceptionAndTasks.size();
             },
             VERIFICATION_TIMEOUT,
-            "Did not get all exceptions and failed tasks within the given 
timeout!"
+            () -> "Did not get all exceptions and failed tasks within the 
given timeout! Expected: "
+                    + expectedExceptionAndTasks + ", actual: " + failedTasks
         );
         assertFalse(stateUpdater.hasExceptionsAndFailedTasks());
         assertTrue(stateUpdater.drainExceptionsAndFailedTasks().isEmpty());
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 73d9b878970..80208775a32 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -126,6 +126,7 @@ import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.lenient;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
@@ -3431,6 +3432,28 @@ public class TaskManagerTest {
         verify(activeTaskCreator).close();
     }
 
+    @SuppressWarnings("unchecked") // for the raw mock(CompletableFuture.class) below
+    @Test // shutdown must not wait forever on a state-updater remove() that times out; the task is closed dirty instead
+    public void shouldCloseTasksIfStateUpdaterTimesOutOnRemove() throws 
Exception {
+        final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, mock(TasksRegistry.class), true, 
false); // NOTE(review): the boolean flags are interpreted by setUpTaskManager (not shown); presumably the first enables the state updater - confirm
+        final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
+                mkEntry(taskId00, taskId00Partitions)
+        );
+        final Task task00 = spy(new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager)); // spy so that closeDirty() can be verified
+
+        when(activeTaskCreator.createTasks(any(), 
eq(assignment))).thenReturn(singletonList(task00));
+        taskManager.handleAssignment(assignment, emptyMap());
+
+        when(stateUpdater.tasks()).thenReturn(singleton(task00)); // task00 is still owned by the state updater at shutdown
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
mock(CompletableFuture.class);
+        when(stateUpdater.remove(eq(taskId00))).thenReturn(future);
+        when(future.get(anyLong(), any())).thenThrow(new 
java.util.concurrent.TimeoutException()); // the bounded wait for the removal times out
+
+        taskManager.shutdown(true);
+
+        verify(task00).closeDirty(); // the timed-out task is still closed, dirty
+    }
+
     @Test
     public void 
shouldOnlyCommitRevokedStandbyTaskAndPropagatePrepareCommitException() {
         setUpTaskManager(ProcessingMode.EXACTLY_ONCE_V2, false);
@@ -3598,13 +3621,14 @@ public class TaskManagerTest {
             .thenReturn(Arrays.asList(
                 new ExceptionAndTask(new RuntimeException(), 
failedStatefulTask),
                 new ExceptionAndTask(new RuntimeException(), 
failedStandbyTask))
-            );
+            )
+            .thenReturn(Collections.emptyList());
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
 
         taskManager.shutdown(true);
 
         verify(activeTaskCreator).close();
-        verify(stateUpdater).shutdown(Duration.ofMillis(Long.MAX_VALUE));
+        verify(stateUpdater).shutdown(Duration.ofMinutes(1L));
         verify(failedStatefulTask).prepareCommit();
         verify(failedStatefulTask).suspend();
         verify(failedStatefulTask).closeDirty();
@@ -3617,7 +3641,7 @@ public class TaskManagerTest {
 
         taskManager.shutdown(true);
 
-        
verify(schedulingTaskManager).shutdown(Duration.ofMillis(Long.MAX_VALUE));
+        verify(schedulingTaskManager).shutdown(Duration.ofMinutes(5L));
     }
 
     @Test
@@ -3642,8 +3666,8 @@ public class TaskManagerTest {
                 removedFailedStatefulTask,
                 removedFailedStandbyTask,
                 removedFailedStatefulTaskDuringRemoval,
-                removedFailedStandbyTaskDuringRemoval
-            ));
+                removedFailedStandbyTaskDuringRemoval)
+            ).thenReturn(Collections.emptySet());
         final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForRemovedStatefulTask = new CompletableFuture<>();
         final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForRemovedStandbyTask = new CompletableFuture<>();
         final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForRemovedFailedStatefulTask = new CompletableFuture<>();
@@ -3660,8 +3684,8 @@ public class TaskManagerTest {
             .thenReturn(futureForRemovedFailedStandbyTaskDuringRemoval);
         
when(stateUpdater.drainExceptionsAndFailedTasks()).thenReturn(Arrays.asList(
             new ExceptionAndTask(new StreamsException("KABOOM!"), 
removedFailedStatefulTaskDuringRemoval),
-            new ExceptionAndTask(new StreamsException("KABOOM!"), 
removedFailedStandbyTaskDuringRemoval)
-        ));
+            new ExceptionAndTask(new StreamsException("KABOOM!"), 
removedFailedStandbyTaskDuringRemoval))
+        ).thenReturn(Collections.emptyList());
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
         futureForRemovedStatefulTask.complete(new 
StateUpdater.RemovedTaskResult(removedStatefulTask));
         futureForRemovedStandbyTask.complete(new 
StateUpdater.RemovedTaskResult(removedStandbyTask));
@@ -3676,7 +3700,7 @@ public class TaskManagerTest {
 
         taskManager.shutdown(true);
 
-        verify(stateUpdater).shutdown(Duration.ofMillis(Long.MAX_VALUE));
+        verify(stateUpdater).shutdown(Duration.ofMinutes(1L));
         verify(tasks).addTask(removedStatefulTask);
         verify(tasks).addTask(removedStandbyTask);
         verify(removedFailedStatefulTask).prepareCommit();


Reply via email to