This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new dc72f6ec02 KAFKA-10199: Handle task closure and recycling from state 
updater (#12466)
dc72f6ec02 is described below

commit dc72f6ec02c7a7fbda083cd8a5a9f0081c7e58fd
Author: Guozhang Wang <wangg...@gmail.com>
AuthorDate: Mon Aug 15 19:33:46 2022 -0700

    KAFKA-10199: Handle task closure and recycling from state updater (#12466)
    
    1. Within the tryCompleteRestore function of the thread, try to drain the 
removed tasks from state updater and handle accordingly: 1) for recycle, 2) for 
closure, 3) for update input partitions.
    2. Catch up on some unit test coverage from previous PRs.
    3. Some minor cleanups around exception handling.
    
    Reviewers: Bruno Cadonna <cado...@apache.org>
---
 .../processor/internals/ActiveTaskCreator.java     |   1 +
 .../streams/processor/internals/StandbyTask.java   |   2 +-
 .../processor/internals/StandbyTaskCreator.java    |   1 +
 .../streams/processor/internals/StreamTask.java    |   2 +-
 .../streams/processor/internals/TaskManager.java   | 311 +++++++++++++--------
 .../kafka/streams/processor/internals/Tasks.java   |  90 +++---
 .../internals/ProcessorTopologyFactories.java      |   1 -
 .../processor/internals/StandbyTaskTest.java       |   5 +-
 .../processor/internals/StreamTaskTest.java        |  12 +-
 .../processor/internals/TaskManagerTest.java       | 178 +++++++++++-
 .../streams/processor/internals/TasksTest.java     |  64 ++++-
 11 files changed, 471 insertions(+), 196 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
index 46455111db..d28d0d4444 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
@@ -135,6 +135,7 @@ class ActiveTaskCreator {
         return threadProducer;
     }
 
+    // TODO: convert to StreamTask when we remove TaskManager#StateMachineTask 
with mocks
     public Collection<Task> createTasks(final Consumer<byte[], byte[]> 
consumer,
                                         final Map<TaskId, Set<TopicPartition>> 
tasksToBeCreated) {
         final List<Task> createdTasks = new ArrayList<>();
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
index bb7aef1dcd..87f19c4b1f 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
@@ -234,7 +234,7 @@ public class StandbyTask extends AbstractTask implements 
Task {
         closeTaskSensor.record();
         transitionTo(State.CLOSED);
 
-        log.info("Closed and recycled state, and converted type to active");
+        log.info("Closed and recycled state");
     }
 
     private void close(final boolean clean) {
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
index 2f48cdb67f..26a3a49af3 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
@@ -67,6 +67,7 @@ class StandbyTaskCreator {
         );
     }
 
+    // TODO: convert to StandbyTask when we remove 
TaskManager#StateMachineTask with mocks
     Collection<Task> createTasks(final Map<TaskId, Set<TopicPartition>> 
tasksToBeCreated) {
         final List<Task> createdTasks = new ArrayList<>();
 
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index f7bf8a5e74..8a30bcf6ca 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -568,7 +568,7 @@ public class StreamTask extends AbstractTask implements 
ProcessorNodePunctuator,
         closeTaskSensor.record();
         transitionTo(State.CLOSED);
 
-        log.info("Closed and recycled state, and converted type to standby");
+        log.info("Closed and recycled state");
     }
 
     /**
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index cfd20d2299..03c36b0daf 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -298,14 +298,16 @@ public class TaskManager {
             logPrefix
         );
 
-        final LinkedHashMap<TaskId, RuntimeException> taskCloseExceptions = 
new LinkedHashMap<>();
         final Map<TaskId, Set<TopicPartition>> activeTasksToCreate = new 
HashMap<>(activeTasks);
         final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate = new 
HashMap<>(standbyTasks);
         final Map<Task, Set<TopicPartition>> tasksToRecycle = new HashMap<>();
         final Set<Task> tasksToCloseClean = new 
TreeSet<>(Comparator.comparing(Task::id));
 
-        tasks.purgePendingTasks(activeTasks.keySet(), standbyTasks.keySet());
-
+        // first put aside those unrecognized tasks because of unknown 
named-topologies
+        tasks.clearPendingTasksToCreate();
+        
tasks.addPendingActiveTasksToCreate(pendingTasksToCreate(activeTasksToCreate));
+        
tasks.addPendingStandbyTasksToCreate(pendingTasksToCreate(standbyTasksToCreate));
+        
         // first rectify all existing tasks:
         // 1. for tasks that are already owned, just update input partitions / 
resume and skip re-creating them
         // 2. for tasks that have changed active/standby status, just recycle 
and skip re-creating them
@@ -316,20 +318,18 @@ public class TaskManager {
             classifyTasksWithStateUpdater(activeTasksToCreate, 
standbyTasksToCreate, tasksToRecycle, tasksToCloseClean);
         }
 
-        tasks.addPendingActiveTasks(pendingTasksToCreate(activeTasksToCreate));
-        
tasks.addPendingStandbyTasks(pendingTasksToCreate(standbyTasksToCreate));
+        final Map<TaskId, RuntimeException> taskCloseExceptions = 
closeAndRecycleTasks(tasksToRecycle, tasksToCloseClean);
 
-        // close and recycle those tasks
-        closeAndRecycleTasks(
-            tasksToRecycle,
-            tasksToCloseClean,
-            taskCloseExceptions
-        );
+        throwTaskExceptions(taskCloseExceptions);
 
-        if (!taskCloseExceptions.isEmpty()) {
-            log.error("Hit exceptions while closing / recycling tasks: {}", 
taskCloseExceptions);
+        createNewTasks(activeTasksToCreate, standbyTasksToCreate);
+    }
 
-            for (final Map.Entry<TaskId, RuntimeException> entry : 
taskCloseExceptions.entrySet()) {
+    private void throwTaskExceptions(final Map<TaskId, RuntimeException> 
taskExceptions) {
+        if (!taskExceptions.isEmpty()) {
+            log.error("Get exceptions for the following tasks: {}", 
taskExceptions);
+
+            for (final Map.Entry<TaskId, RuntimeException> entry : 
taskExceptions.entrySet()) {
                 if (!(entry.getValue() instanceof TaskMigratedException)) {
                     final TaskId taskId = entry.getKey();
                     final RuntimeException exception = entry.getValue();
@@ -340,8 +340,8 @@ public class TaskManager {
                         throw new StreamsException(exception, taskId);
                     } else {
                         throw new StreamsException(
-                            "Unexpected failure to close " + 
taskCloseExceptions.size() +
-                                " task(s) [" + taskCloseExceptions.keySet() + 
"]. " +
+                            "Unexpected failure to close " + 
taskExceptions.size() +
+                                " task(s) [" + taskExceptions.keySet() + "]. " 
+
                                 "First unexpected exception (for task " + 
taskId + ") follows.",
                             exception,
                             taskId
@@ -352,11 +352,9 @@ public class TaskManager {
 
             // If all exceptions are task-migrated, we would just throw the 
first one. No need to wrap with a
             // StreamsException since TaskMigrated is handled explicitly by 
the StreamThread
-            final Map.Entry<TaskId, RuntimeException> first = 
taskCloseExceptions.entrySet().iterator().next();
+            final Map.Entry<TaskId, RuntimeException> first = 
taskExceptions.entrySet().iterator().next();
             throw first.getValue();
         }
-
-        createNewTasks(activeTasksToCreate, standbyTasksToCreate);
     }
 
     private void createNewTasks(final Map<TaskId, Set<TopicPartition>> 
activeTasksToCreate,
@@ -368,8 +366,8 @@ public class TaskManager {
             tasks.addNewActiveTasks(newActiveTasks);
             tasks.addNewStandbyTasks(newStandbyTask);
         } else {
-            tasks.addPendingTaskToRestore(newActiveTasks);
-            tasks.addPendingTaskToRestore(newStandbyTask);
+            tasks.addPendingTaskToInit(newActiveTasks);
+            tasks.addPendingTaskToInit(newStandbyTask);
         }
     }
 
@@ -442,26 +440,27 @@ public class TaskManager {
         classifyRunningTasks(activeTasksToCreate, standbyTasksToCreate, 
tasksToRecycle, tasksToCloseClean);
         for (final Task task : stateUpdater.getTasks()) {
             final TaskId taskId = task.id();
+            final Set<TopicPartition> topicPartitions = 
activeTasksToCreate.get(taskId);
             if (activeTasksToCreate.containsKey(taskId)) {
                 if (task.isActive()) {
-                    final Set<TopicPartition> topicPartitions = 
activeTasksToCreate.get(taskId);
                     if (!task.inputPartitions().equals(topicPartitions)) {
-                        
tasks.addPendingTaskThatNeedsInputPartitionsUpdate(taskId);
+                        stateUpdater.remove(taskId);
+                        tasks.addPendingTaskToUpdateInputPartitions(taskId, 
topicPartitions);
                     }
                 } else {
                     stateUpdater.remove(taskId);
-                    tasks.addPendingStandbyTaskToRecycle(taskId);
+                    tasks.addPendingTaskToRecycle(taskId, topicPartitions);
                 }
                 activeTasksToCreate.remove(taskId);
             } else if (standbyTasksToCreate.containsKey(taskId)) {
                 if (!task.isActive()) {
-                    final Set<TopicPartition> topicPartitions = 
standbyTasksToCreate.get(taskId);
                     if (!task.inputPartitions().equals(topicPartitions)) {
-                        
tasks.addPendingTaskThatNeedsInputPartitionsUpdate(taskId);
+                        stateUpdater.remove(taskId);
+                        tasks.addPendingTaskToUpdateInputPartitions(taskId, 
topicPartitions);
                     }
                 } else {
                     stateUpdater.remove(taskId);
-                    tasks.addPendingActiveTaskToRecycle(taskId);
+                    tasks.addPendingTaskToRecycle(taskId, topicPartitions);
                 }
                 standbyTasksToCreate.remove(taskId);
             } else {
@@ -478,6 +477,8 @@ public class TaskManager {
             final Map.Entry<TaskId, Set<TopicPartition>> entry = iter.next();
             final TaskId taskId = entry.getKey();
             if (taskId.topologyName() != null && 
!topologyMetadata.namedTopologiesView().contains(taskId.topologyName())) {
+                log.info("Cannot create the assigned task {} since its topology name cannot be recognized, will put it " +
+                        "aside as pending for now and create it later when topology metadata gets refreshed", taskId);
                 pendingTasks.put(taskId, entry.getValue());
                 iter.remove();
             }
@@ -485,9 +486,9 @@ public class TaskManager {
         return pendingTasks;
     }
 
-    private void closeAndRecycleTasks(final Map<Task, Set<TopicPartition>> 
tasksToRecycle,
-                                      final Set<Task> tasksToCloseClean,
-                                      final LinkedHashMap<TaskId, 
RuntimeException> taskCloseExceptions) {
+    private Map<TaskId, RuntimeException> closeAndRecycleTasks(final Map<Task, 
Set<TopicPartition>> tasksToRecycle,
+                                                               final Set<Task> 
tasksToCloseClean) {
+        final Map<TaskId, RuntimeException> taskCloseExceptions = new 
LinkedHashMap<>();
         final Set<Task> tasksToCloseDirty = new 
TreeSet<>(Comparator.comparing(Task::id));
 
         // for all tasks to close or recycle, we should first write a 
checkpoint as in post-commit
@@ -530,28 +531,32 @@ public class TaskManager {
         tasksToCloseClean.removeAll(tasksToCloseDirty);
         for (final Task task : tasksToCloseClean) {
             try {
-                final RuntimeException removeTaskException = 
completeTaskCloseClean(task);
-                if (removeTaskException != null) {
-                    taskCloseExceptions.putIfAbsent(task.id(), 
removeTaskException);
-                }
+                closeTaskClean(task);
             } catch (final RuntimeException closeTaskException) {
                 final String uncleanMessage = String.format(
-                        "Failed to close task %s cleanly. Attempting to close 
remaining tasks before re-throwing:",
-                        task.id());
+                    "Failed to close task %s cleanly. Attempting to close 
remaining tasks before re-throwing:",
+                    task.id());
                 log.error(uncleanMessage, closeTaskException);
+
+                if (task.state() != State.CLOSED) {
+                    tasksToCloseDirty.add(task);
+                }
+
                 taskCloseExceptions.putIfAbsent(task.id(), closeTaskException);
-                tasksToCloseDirty.add(task);
             }
         }
 
         tasksToRecycle.keySet().removeAll(tasksToCloseDirty);
         for (final Map.Entry<Task, Set<TopicPartition>> entry : 
tasksToRecycle.entrySet()) {
             final Task oldTask = entry.getKey();
+            final Set<TopicPartition> inputPartitions = entry.getValue();
             try {
                 if (oldTask.isActive()) {
-                    convertActiveToStandby((StreamTask) oldTask, 
entry.getValue());
+                    final StandbyTask standbyTask = 
convertActiveToStandby((StreamTask) oldTask, inputPartitions);
+                    tasks.replaceActiveWithStandby(standbyTask);
                 } else {
-                    convertStandbyToActive((StandbyTask) oldTask, 
entry.getValue());
+                    final StreamTask activeTask = 
convertStandbyToActive((StandbyTask) oldTask, inputPartitions);
+                    tasks.replaceStandbyWithActive(activeTask);
                 }
             } catch (final RuntimeException e) {
                 final String uncleanMessage = String.format("Failed to recycle 
task %s cleanly. " +
@@ -566,19 +571,18 @@ public class TaskManager {
         for (final Task task : tasksToCloseDirty) {
             closeTaskDirty(task);
         }
+
+        return taskCloseExceptions;
     }
 
-    private void convertActiveToStandby(final StreamTask activeTask,
-                                        final Set<TopicPartition> partitions) {
+    private StandbyTask convertActiveToStandby(final StreamTask activeTask, 
final Set<TopicPartition> partitions) {
         final StandbyTask standbyTask = 
standbyTaskCreator.createStandbyTaskFromActive(activeTask, partitions);
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(activeTask.id());
-        tasks.replaceActiveWithStandby(standbyTask);
+        return standbyTask;
     }
 
-    private void convertStandbyToActive(final StandbyTask standbyTask,
-                                        final Set<TopicPartition> partitions) {
-        final StreamTask activeTask = 
activeTaskCreator.createActiveTaskFromStandby(standbyTask, partitions, 
mainConsumer);
-        tasks.replaceStandbyWithActive(activeTask);
+    private StreamTask convertStandbyToActive(final StandbyTask standbyTask, 
final Set<TopicPartition> partitions) {
+        return activeTaskCreator.createActiveTaskFromStandby(standbyTask, 
partitions, mainConsumer);
     }
 
     /**
@@ -641,9 +645,9 @@ public class TaskManager {
                 }
             }
         } else {
-            for (final Task task : tasks.drainPendingTaskToRestore()) {
-                stateUpdater.add(task);
-            }
+            addTasksToStateUpdater();
+
+            handleRemovedTasksFromStateUpdater();
 
             // TODO: should add logic for checking and resuming when all 
active tasks have been restored
         }
@@ -656,6 +660,72 @@ public class TaskManager {
         return allRunning;
     }
 
+    private void addTasksToStateUpdater() {
+        for (final Task task : tasks.drainPendingTaskToInit()) {
+            task.initializeIfNeeded();
+            stateUpdater.add(task);
+        }
+    }
+
+    private void handleRemovedTasksFromStateUpdater() {
+        final Map<TaskId, RuntimeException> taskExceptions = new 
LinkedHashMap<>();
+        final Set<Task> tasksToCloseDirty = new 
TreeSet<>(Comparator.comparing(Task::id));
+
+        for (final Task task : stateUpdater.drainRemovedTasks()) {
+            final TaskId taskId = task.id();
+            Set<TopicPartition> inputPartitions;
+            if ((inputPartitions = 
tasks.removePendingTaskToRecycle(task.id())) != null) {
+                try {
+                    final Task newTask = task.isActive() ?
+                        convertActiveToStandby((StreamTask) task, 
inputPartitions) :
+                        convertStandbyToActive((StandbyTask) task, 
inputPartitions);
+                    newTask.initializeIfNeeded();
+                    stateUpdater.add(newTask);
+                } catch (final RuntimeException e) {
+                    final String uncleanMessage = String.format("Failed to 
recycle task %s cleanly. " +
+                        "Attempting to handle remaining tasks before 
re-throwing:", taskId);
+                    log.error(uncleanMessage, e);
+
+                    if (task.state() != State.CLOSED) {
+                        tasksToCloseDirty.add(task);
+                    }
+
+                    taskExceptions.putIfAbsent(taskId, e);
+                }
+            } else if (tasks.removePendingTaskToClose(task.id())) {
+                try {
+                    task.suspend();
+                    task.closeClean();
+                    if (task.isActive()) {
+                        
activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
+                    }
+                } catch (final RuntimeException e) {
+                    final String uncleanMessage = String.format("Failed to 
close task %s cleanly. " +
+                        "Attempting to handle remaining tasks before 
re-throwing:", task.id());
+                    log.error(uncleanMessage, e);
+
+                    if (task.state() != State.CLOSED) {
+                        tasksToCloseDirty.add(task);
+                    }
+
+                    taskExceptions.putIfAbsent(task.id(), e);
+                }
+            } else if ((inputPartitions = 
tasks.removePendingTaskToUpdateInputPartitions(task.id())) != null) {
+                task.updateInputPartitions(inputPartitions, 
topologyMetadata.nodeToSourceTopics(task.id()));
+                stateUpdater.add(task);
+            } else {
+                throw new IllegalStateException("Got a removed task " + task.id() + " from the state updater " +
+                    "that is not for recycle, closing, or updating input partitions; this should not happen");
+            }
+        }
+
+        for (final Task task : tasksToCloseDirty) {
+            closeTaskDirty(task);
+        }
+
+        throwTaskExceptions(taskExceptions);
+    }
+
     /**
      * Handle the revoked partitions and prepare for closing the associated 
tasks in {@link #handleAssignment(Map, Map)}
      * We should commit the revoking tasks first before suspending them as we 
will not officially own them anymore when
@@ -735,7 +805,7 @@ public class TaskManager {
                     task.postCommit(true);
                 } catch (final RuntimeException e) {
                     log.error("Exception caught while post-committing task " + 
task.id(), e);
-                    maybeWrapAndSetFirstException(firstException, e, 
task.id());
+                    maybeSetFirstException(false, maybeWrapTaskException(e, 
task.id()), firstException);
                 }
             }
         }
@@ -750,7 +820,7 @@ public class TaskManager {
                         task.postCommit(false);
                     } catch (final RuntimeException e) {
                         log.error("Exception caught while post-committing task 
" + task.id(), e);
-                        maybeWrapAndSetFirstException(firstException, e, 
task.id());
+                        maybeSetFirstException(false, 
maybeWrapTaskException(e, task.id()), firstException);
                     }
                 }
             }
@@ -761,7 +831,7 @@ public class TaskManager {
                 task.suspend();
             } catch (final RuntimeException e) {
                 log.error("Caught the following exception while trying to 
suspend revoked task " + task.id(), e);
-                maybeWrapAndSetFirstException(firstException, e, task.id());
+                maybeSetFirstException(false, maybeWrapTaskException(e, 
task.id()), firstException);
             }
         }
 
@@ -961,19 +1031,12 @@ public class TaskManager {
         }
     }
 
-    private RuntimeException completeTaskCloseClean(final Task task) {
+    private void closeTaskClean(final Task task) {
         task.closeClean();
-        try {
-            tasks.removeTask(task);
-
-            if (task.isActive()) {
-                
activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
-            }
-        } catch (final RuntimeException e) {
-            log.error("Error removing active task {}: {}", task.id(), 
e.getMessage());
-            return e;
+        tasks.removeTask(task);
+        if (task.isActive()) {
+            activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
         }
-        return null;
     }
 
     void shutdown(final boolean clean) {
@@ -1010,7 +1073,7 @@ public class TaskManager {
 
         final RuntimeException fatalException = firstException.get();
         if (fatalException != null) {
-            throw new RuntimeException("Unexpected exception while closing 
task", fatalException);
+            throw fatalException;
         }
     }
 
@@ -1076,59 +1139,51 @@ public class TaskManager {
         } else {
             try {
                 
taskExecutor.commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask);
-
-                for (final Task task : activeTaskIterable()) {
-                    try {
-                        task.postCommit(true);
-                    } catch (final RuntimeException e) {
-                        log.error("Exception caught while post-committing task 
" + task.id(), e);
-                        maybeWrapAndSetFirstException(firstException, e, 
task.id());
-                        tasksToCloseDirty.add(task);
-                        tasksToCloseClean.remove(task);
-                    }
-                }
-            } catch (final TimeoutException timeoutException) {
-                firstException.compareAndSet(null, timeoutException);
-
-                tasksToCloseClean.removeAll(tasksToCommit);
-                tasksToCloseDirty.addAll(tasksToCommit);
-            } catch (final TaskCorruptedException taskCorruptedException) {
-                firstException.compareAndSet(null, taskCorruptedException);
-
-                final Set<TaskId> corruptedTaskIds = 
taskCorruptedException.corruptedTasks();
-                final Set<Task> corruptedTasks = tasksToCommit
+            } catch (final RuntimeException e) {
+                log.error("Exception caught while committing tasks " + 
consumedOffsetsAndMetadataPerTask.keySet(), e);
+                // TODO: should record the task ids when handling this 
exception
+                maybeSetFirstException(false, e, firstException);
+
+                if (e instanceof TaskCorruptedException) {
+                    final TaskCorruptedException taskCorruptedException = 
(TaskCorruptedException) e;
+                    final Set<TaskId> corruptedTaskIds = 
taskCorruptedException.corruptedTasks();
+                    final Set<Task> corruptedTasks = tasksToCommit
                         .stream()
                         .filter(task -> corruptedTaskIds.contains(task.id()))
                         .collect(Collectors.toSet());
+                    tasksToCloseClean.removeAll(corruptedTasks);
+                    tasksToCloseDirty.addAll(corruptedTasks);
+                } else {
+                    // If the commit fails, everyone who participated in it 
must be closed dirty
+                    tasksToCloseClean.removeAll(tasksToCommit);
+                    tasksToCloseDirty.addAll(tasksToCommit);
+                }
+            }
 
-                tasksToCloseClean.removeAll(corruptedTasks);
-                tasksToCloseDirty.addAll(corruptedTasks);
-            } catch (final RuntimeException e) {
-                log.error("Exception caught while committing tasks during 
shutdown", e);
-                firstException.compareAndSet(null, e);
-
-                // If the commit fails, everyone who participated in it must 
be closed dirty
-                tasksToCloseClean.removeAll(tasksToCommit);
-                tasksToCloseDirty.addAll(tasksToCommit);
+            for (final Task task : activeTaskIterable()) {
+                try {
+                    task.postCommit(true);
+                } catch (final RuntimeException e) {
+                    log.error("Exception caught while post-committing task " + 
task.id(), e);
+                    maybeSetFirstException(false, maybeWrapTaskException(e, 
task.id()), firstException);
+                    tasksToCloseDirty.add(task);
+                    tasksToCloseClean.remove(task);
+                }
             }
         }
 
         for (final Task task : tasksToCloseClean) {
             try {
                 task.suspend();
-                final RuntimeException exception = 
completeTaskCloseClean(task);
-                if (exception != null) {
-                    firstException.compareAndSet(null, exception);
-                }
-            } catch (final StreamsException e) {
-                log.error("Exception caught while clean-closing task " + 
task.id(), e);
-                e.setTaskId(task.id());
-                firstException.compareAndSet(null, e);
-                tasksToCloseDirty.add(task);
+                closeTaskClean(task);
             } catch (final RuntimeException e) {
-                log.error("Exception caught while clean-closing task " + 
task.id(), e);
-                firstException.compareAndSet(null, new StreamsException(e, 
task.id()));
-                tasksToCloseDirty.add(task);
+                log.error("Exception caught while clean-closing active task 
{}: {}", task.id(), e.getMessage());
+
+                if (task.state() != State.CLOSED) {
+                    tasksToCloseDirty.add(task);
+                }
+                // ignore task migrated exception as it doesn't matter during 
shutdown
+                maybeSetFirstException(true, maybeWrapTaskException(e, 
task.id()), firstException);
             }
         }
 
@@ -1150,16 +1205,15 @@ public class TaskManager {
                 task.prepareCommit();
                 task.postCommit(true);
                 task.suspend();
-                final RuntimeException exception = 
completeTaskCloseClean(task);
-                if (exception != null) {
-                    maybeWrapAndSetFirstException(firstException, exception, 
task.id());
-                }
-            } catch (final TaskMigratedException e) {
-                // just ignore the exception as it doesn't matter during 
shutdown
-                tasksToCloseDirty.add(task);
+                closeTaskClean(task);
             } catch (final RuntimeException e) {
-                maybeWrapAndSetFirstException(firstException, e, task.id());
-                tasksToCloseDirty.add(task);
+                log.error("Exception caught while clean-closing standby task 
{}: {}", task.id(), e.getMessage());
+
+                if (task.state() != State.CLOSED) {
+                    tasksToCloseDirty.add(task);
+                }
+                // ignore task migrated exception as it doesn't matter during 
shutdown
+                maybeSetFirstException(true, maybeWrapTaskException(e, 
task.id()), firstException);
             }
         }
         return tasksToCloseDirty;
@@ -1340,8 +1394,8 @@ public class TaskManager {
     }
 
     void createPendingTasks(final Set<String> currentNamedTopologies) {
-        final Map<TaskId, Set<TopicPartition>> activeTasksToCreate = 
tasks.pendingActiveTasksForTopologies(currentNamedTopologies);
-        final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate = 
tasks.pendingStandbyTasksForTopologies(currentNamedTopologies);
+        final Map<TaskId, Set<TopicPartition>> activeTasksToCreate = 
tasks.drainPendingActiveTasksForTopologies(currentNamedTopologies);
+        final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate = 
tasks.drainPendingStandbyTasksForTopologies(currentNamedTopologies);
 
         createNewTasks(activeTasksToCreate, standbyTasksToCreate);
     }
@@ -1432,14 +1486,21 @@ public class TaskManager {
         return Collections.unmodifiableSet(lockedTaskDirectories);
     }
 
-    private void maybeWrapAndSetFirstException(final 
AtomicReference<RuntimeException> firstException,
-                                               final RuntimeException 
exception,
-                                               final TaskId taskId) {
-        if (exception instanceof StreamsException) {
-            ((StreamsException) exception).setTaskId(taskId);
+    private void maybeSetFirstException(final boolean ignoreTaskMigrated,
+                                        final RuntimeException exception,
+                                        final 
AtomicReference<RuntimeException> firstException) {
+        if (!ignoreTaskMigrated || !(exception instanceof 
TaskMigratedException)) {
             firstException.compareAndSet(null, exception);
+        }
+    }
+
+    private StreamsException maybeWrapTaskException(final RuntimeException 
exception, final TaskId taskId) {
+        if (exception instanceof StreamsException) {
+            final StreamsException streamsException = (StreamsException) 
exception;
+            streamsException.setTaskId(taskId);
+            return streamsException;
         } else {
-            firstException.compareAndSet(null, new StreamsException(exception, 
taskId));
+            return new StreamsException(exception, taskId);
         }
     }
 
@@ -1479,4 +1540,8 @@ public class TaskManager {
     void addTask(final Task task) {
         tasks.addTask(task);
     }
+
+    Tasks tasks() {
+        return tasks;
+    }
 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
index e360556658..9628b42d92 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
@@ -43,83 +43,90 @@ import static org.apache.kafka.common.utils.Utils.union;
 class Tasks {
     private final Logger log;
 
-    // TODO: change type to `StreamTask`
+    // TODO: convert to Stream/StandbyTask when we remove 
TaskManager#StateMachineTask with mocks
     private final Map<TaskId, Task> activeTasksPerId = new TreeMap<>();
-    // TODO: change type to `StandbyTask`
     private final Map<TaskId, Task> standbyTasksPerId = new TreeMap<>();
 
     // Tasks may have been assigned for a NamedTopology that is not yet known 
by this host. When that occurs we stash
     // these unknown tasks until either the corresponding NamedTopology is 
added and we can create them at last, or
     // we receive a new assignment and they are revoked from the thread.
-
-    // Tasks may have been assigned but not yet created because:
-    // 1. They are for a NamedTopology that is yet known by this host.
-    // 2. They are to be recycled from an existing restoring task yet to be 
returned from the state updater.
-    //
-    // When that occurs we stash these pending tasks until either they are 
finally clear to be created,
-    // or they are revoked from a new assignment.
     private final Map<TaskId, Set<TopicPartition>> pendingActiveTasksToCreate 
= new HashMap<>();
     private final Map<TaskId, Set<TopicPartition>> pendingStandbyTasksToCreate 
= new HashMap<>();
-
-    private final Set<Task> pendingTasksToRestore = new HashSet<>();
-
-    private final Set<TaskId> pendingActiveTasksToRecycle = new HashSet<>();
-    private final Set<TaskId> pendingStandbyTasksToRecycle = new HashSet<>();
-    private final Set<TaskId> pendingTasksThatNeedInputPartitionUpdate = new 
HashSet<>();
+    private final Map<TaskId, Set<TopicPartition>> pendingTasksToRecycle = new 
HashMap<>();
+    private final Map<TaskId, Set<TopicPartition>> 
pendingTasksToUpdateInputPartitions = new HashMap<>();
+    private final Set<Task> pendingTasksToInit = new HashSet<>();
     private final Set<TaskId> pendingTasksToClose = new HashSet<>();
 
-    // TODO: change type to `StreamTask`
+    // TODO: convert to Stream/StandbyTask when we remove 
TaskManager#StateMachineTask with mocks
     private final Map<TopicPartition, Task> activeTasksPerPartition = new 
HashMap<>();
 
     Tasks(final LogContext logContext) {
         this.log = logContext.logger(getClass());
     }
 
-    void purgePendingTasks(final Set<TaskId> assignedActiveTasks, final 
Set<TaskId> assignedStandbyTasks) {
-        pendingActiveTasksToCreate.keySet().retainAll(assignedActiveTasks);
-        pendingStandbyTasksToCreate.keySet().retainAll(assignedStandbyTasks);
+    void clearPendingTasksToCreate() {
+        pendingActiveTasksToCreate.clear();
+        pendingStandbyTasksToCreate.clear();
+    }
+
+    Map<TaskId, Set<TopicPartition>> 
drainPendingActiveTasksForTopologies(final Set<String> currentTopologies) {
+        final Map<TaskId, Set<TopicPartition>> pendingActiveTasksForTopologies 
=
+            filterMap(pendingActiveTasksToCreate, t -> 
currentTopologies.contains(t.getKey().topologyName()));
+
+        
pendingActiveTasksToCreate.keySet().removeAll(pendingActiveTasksForTopologies.keySet());
+
+        return pendingActiveTasksForTopologies;
+    }
+
+    Map<TaskId, Set<TopicPartition>> 
drainPendingStandbyTasksForTopologies(final Set<String> currentTopologies) {
+        final Map<TaskId, Set<TopicPartition>> pendingStandbyTasksForTopologies 
=
+            filterMap(pendingStandbyTasksToCreate, t -> 
currentTopologies.contains(t.getKey().topologyName()));
+
+        
pendingStandbyTasksToCreate.keySet().removeAll(pendingStandbyTasksForTopologies.keySet());
+
+        return pendingStandbyTasksForTopologies;
+    }
 
-    void addPendingActiveTasks(final Map<TaskId, Set<TopicPartition>> 
pendingTasks) {
+    void addPendingActiveTasksToCreate(final Map<TaskId, Set<TopicPartition>> 
pendingTasks) {
         pendingActiveTasksToCreate.putAll(pendingTasks);
     }
 
-    void addPendingStandbyTasks(final Map<TaskId, Set<TopicPartition>> 
pendingTasks) {
+    void addPendingStandbyTasksToCreate(final Map<TaskId, Set<TopicPartition>> 
pendingTasks) {
         pendingStandbyTasksToCreate.putAll(pendingTasks);
     }
 
-    void addPendingActiveTaskToRecycle(final TaskId taskId) {
-        pendingActiveTasksToRecycle.add(taskId);
+    Set<TopicPartition> removePendingTaskToRecycle(final TaskId taskId) {
+        return pendingTasksToRecycle.remove(taskId);
     }
 
-    void addPendingStandbyTaskToRecycle(final TaskId taskId) {
-        pendingStandbyTasksToRecycle.add(taskId);
+    void addPendingTaskToRecycle(final TaskId taskId, final 
Set<TopicPartition> inputPartitions) {
+        pendingTasksToRecycle.put(taskId, inputPartitions);
     }
 
-    void addPendingTaskThatNeedsInputPartitionsUpdate(final TaskId taskId) {
-        pendingTasksThatNeedInputPartitionUpdate.add(taskId);
+    Set<TopicPartition> removePendingTaskToUpdateInputPartitions(final TaskId 
taskId) {
+        return pendingTasksToUpdateInputPartitions.remove(taskId);
     }
 
-    void addPendingTaskToClose(final TaskId taskId) {
-        pendingTasksToClose.add(taskId);
+    void addPendingTaskToUpdateInputPartitions(final TaskId taskId, final 
Set<TopicPartition> inputPartitions) {
+        pendingTasksToUpdateInputPartitions.put(taskId, inputPartitions);
     }
 
-    void addPendingTaskToRestore(final Collection<Task> tasks) {
-        pendingTasksToRestore.addAll(tasks);
+    boolean removePendingTaskToClose(final TaskId taskId) {
+        return pendingTasksToClose.remove(taskId);
     }
 
-    Set<Task> drainPendingTaskToRestore() {
-        final Set<Task> result = new HashSet<>(pendingTasksToRestore);
-        pendingTasksToRestore.clear();
-        return result;
+    void addPendingTaskToClose(final TaskId taskId) {
+        pendingTasksToClose.add(taskId);
     }
 
-    Map<TaskId, Set<TopicPartition>> pendingActiveTasksForTopologies(final 
Set<String> currentTopologies) {
-        return filterMap(pendingActiveTasksToCreate, t -> 
currentTopologies.contains(t.getKey().topologyName()));
+    Set<Task> drainPendingTaskToInit() {
+        final Set<Task> result = new HashSet<>(pendingTasksToInit);
+        pendingTasksToInit.clear();
+        return result;
     }
 
-    Map<TaskId, Set<TopicPartition>> pendingStandbyTasksForTopologies(final 
Set<String> currentTopologies) {
-        return filterMap(pendingStandbyTasksToCreate, t -> 
currentTopologies.contains(t.getKey().topologyName()));
+    void addPendingTaskToInit(final Collection<Task> tasks) {
+        pendingTasksToInit.addAll(tasks);
     }
 
     void addNewActiveTasks(final Collection<Task> newTasks) {
@@ -136,7 +143,6 @@ class Tasks {
                 }
 
                 activeTasksPerId.put(activeTask.id(), activeTask);
-                pendingActiveTasksToCreate.remove(activeTask.id());
                 for (final TopicPartition topicPartition : 
activeTask.inputPartitions()) {
                     activeTasksPerPartition.put(topicPartition, activeTask);
                 }
@@ -158,7 +164,6 @@ class Tasks {
                 }
 
                 standbyTasksPerId.put(standbyTask.id(), standbyTask);
-                pendingStandbyTasksToCreate.remove(standbyTask.id());
             }
         }
     }
@@ -175,12 +180,10 @@ class Tasks {
                 throw new IllegalArgumentException("Attempted to remove an 
active task that is not owned: " + taskId);
             }
             removePartitionsForActiveTask(taskId);
-            pendingActiveTasksToCreate.remove(taskId);
         } else {
             if (standbyTasksPerId.remove(taskId) == null) {
                 throw new IllegalArgumentException("Attempted to remove a 
standby task that is not owned: " + taskId);
             }
-            pendingStandbyTasksToCreate.remove(taskId);
         }
     }
 
@@ -269,7 +272,6 @@ class Tasks {
         return tasks;
     }
 
-    // TODO: change return type to `StreamTask`
     Collection<Task> activeTasks() {
         return Collections.unmodifiableCollection(activeTasksPerId.values());
     }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyFactories.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyFactories.java
index 57e4490d73..ccc78a4873 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyFactories.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyFactories.java
@@ -25,7 +25,6 @@ import java.util.Map;
 public final class ProcessorTopologyFactories {
     private ProcessorTopologyFactories() {}
 
-
     public static ProcessorTopology with(final List<ProcessorNode<?, ?, ?, ?>> 
processorNodes,
                                          final Map<String, SourceNode<?, ?>> 
sourcesByTopic,
                                          final List<StateStore> 
stateStoresByName,
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
index 02d742d8ab..ba484d210c 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
@@ -69,6 +69,7 @@ import static 
org.apache.kafka.streams.processor.internals.Task.State.SUSPENDED;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.isA;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -565,9 +566,10 @@ public class StandbyTaskTest {
     }
 
     @Test
-    public void shouldRecycleTask() {
+    public void shouldPrepareRecycleSuspendedTask() {
         
EasyMock.expect(stateManager.changelogOffsets()).andStubReturn(Collections.emptyMap());
         stateManager.recycle();
+        EasyMock.expectLastCall().once();
         EasyMock.replay(stateManager);
 
         task = createStandbyTask();
@@ -578,6 +580,7 @@ public class StandbyTaskTest {
 
         task.suspend();
         task.prepareRecycle(); // SUSPENDED
+        assertThat(task.state(), is(Task.State.CLOSED));
 
         // Currently, there are no metrics registered for standby tasks.
         // This is a regression test so that, if we add some, we will be sure 
to deregister them.
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index 68d2def110..61b8791af7 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -2178,8 +2178,10 @@ public class StreamTaskTest {
     }
 
     @Test
-    public void shouldUnregisterMetricsInCloseCleanAndRecycleState() {
+    public void shouldUnregisterMetricsAndCloseInPrepareRecycle() {
         
EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes();
+        stateManager.recycle();
+        EasyMock.expectLastCall().once();
         
EasyMock.expect(recordCollector.offsets()).andReturn(Collections.emptyMap()).anyTimes();
         EasyMock.replay(stateManager, recordCollector);
 
@@ -2189,6 +2191,7 @@ public class StreamTaskTest {
         assertThat(getTaskMetrics(), not(empty()));
         task.prepareRecycle();
         assertThat(getTaskMetrics(), empty());
+        assertThat(task.state(), is(Task.State.CLOSED));
     }
 
     @Test
@@ -2270,10 +2273,12 @@ public class StreamTaskTest {
     }
 
     @Test
-    public void shouldOnlyRecycleSuspendedTasks() {
+    public void shouldPrepareRecycleSuspendedTask() {
+        
EasyMock.expect(stateManager.changelogOffsets()).andReturn(Collections.emptyMap()).anyTimes();
         stateManager.recycle();
+        EasyMock.expectLastCall().once();
         recordCollector.closeClean();
-        
EasyMock.expect(stateManager.changelogOffsets()).andReturn(Collections.emptyMap()).anyTimes();
+        EasyMock.expectLastCall().once();
         EasyMock.replay(stateManager, recordCollector);
 
         task = createStatefulTask(createConfig("100"), true);
@@ -2287,6 +2292,7 @@ public class StreamTaskTest {
 
         task.suspend();
         task.prepareRecycle(); // SUSPENDED
+        assertThat(task.state(), is(Task.State.CLOSED));
 
         EasyMock.verify(stateManager, recordCollector);
     }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index c3233152c0..12ea6477e5 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -257,23 +257,171 @@ public class TaskManagerTest {
 
     @Test
     public void shouldAddTasksToStateUpdater() {
+        final StreamTask task00 = statefulTask(taskId00, taskId00Partitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RESTORING)
+            .build();
+        final StandbyTask task01 = standbyTask(taskId01, taskId01Partitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
+        
expect(changeLogReader.completedChangelogs()).andReturn(emptySet()).anyTimes();
+        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
+        consumer.resume(anyObject());
+        expectLastCall().anyTimes();
+        expect(activeTaskCreator.createTasks(anyObject(), 
eq(taskId00Assignment))).andStubReturn(singletonList(task00));
+        
expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andStubReturn(singletonList(task01));
+        replay(activeTaskCreator, standbyTaskCreator, consumer, 
changeLogReader);
+
+        taskManager = 
setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, true);
+        taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
+        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter 
-> { });
+
+        Mockito.verify(task00).initializeIfNeeded();
+        Mockito.verify(task01).initializeIfNeeded();
+        Mockito.verify(stateUpdater).add(task00);
+        Mockito.verify(stateUpdater).add(task01);
+    }
+
+    @Test
+    public void shouldHandleRemovedTasksToRecycleFromStateUpdater() {
+        final StreamTask task00 = statefulTask(taskId00, taskId00Partitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RESTORING)
+            .build();
+        final StandbyTask task01 = standbyTask(taskId01, taskId01Partitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
+        final StandbyTask task00Converted = standbyTask(taskId00, 
taskId00Partitions)
+            .withInputPartitions(taskId00Partitions)
+            .build();
+        final StreamTask task01Converted = statefulTask(taskId01, 
taskId01Partitions)
+            .withInputPartitions(taskId01Partitions)
+            .build();
+        when(stateUpdater.drainRemovedTasks()).thenReturn(mkSet(task00, 
task01));
+
+        taskManager = 
setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, true);
+        expect(activeTaskCreator.createActiveTaskFromStandby(eq(task01), 
eq(taskId01Partitions), eq(consumer)))
+            .andStubReturn(task01Converted);
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(anyObject());
+        expectLastCall().once();
+        expect(standbyTaskCreator.createStandbyTaskFromActive(eq(task00), 
eq(taskId00Partitions)))
+            .andStubReturn(task00Converted);
+        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
+        consumer.resume(anyObject());
+        expectLastCall().anyTimes();
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, 
consumer);
+
+        taskManager.tasks().addPendingTaskToRecycle(taskId00, 
taskId00Partitions);
+        taskManager.tasks().addPendingTaskToRecycle(taskId01, 
taskId01Partitions);
+        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter 
-> { });
+
+        Mockito.verify(task00Converted).initializeIfNeeded();
+        Mockito.verify(task01Converted).initializeIfNeeded();
+        Mockito.verify(stateUpdater).add(task00Converted);
+        Mockito.verify(stateUpdater).add(task01Converted);
+    }
+
+    @Test
+    public void shouldHandleRemovedTasksToCloseFromStateUpdater() {
+        final StreamTask task00 = statefulTask(taskId00, taskId00Partitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RESTORING)
+            .build();
+        final StandbyTask task01 = standbyTask(taskId01, taskId01Partitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
+        when(stateUpdater.drainRemovedTasks()).thenReturn(mkSet(task00, 
task01));
+
+        taskManager = 
setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, true);
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(anyObject());
+        expectLastCall().once();
+        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
+        consumer.resume(anyObject());
+        expectLastCall().anyTimes();
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, 
consumer);
+
+        taskManager.tasks().addPendingTaskToClose(taskId00);
+        taskManager.tasks().addPendingTaskToClose(taskId01);
+
+        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter 
-> { });
+
+        Mockito.verify(task00).suspend();
+        Mockito.verify(task00).closeClean();
+        Mockito.verify(task01).suspend();
+        Mockito.verify(task01).closeClean();
+    }
+
+    @Test
+    public void 
shouldHandleRemovedTasksToUpdateInputPartitionsFromStateUpdater() {
+        final StreamTask task00 = statefulTask(taskId00, taskId00Partitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RESTORING)
+            .build();
+        final StandbyTask task01 = standbyTask(taskId01, taskId01Partitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
+        when(stateUpdater.drainRemovedTasks()).thenReturn(mkSet(task00, 
task01));
+
+        taskManager = 
setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, true);
+        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
+        consumer.resume(anyObject());
+        expectLastCall().anyTimes();
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, 
consumer);
+
+        taskManager.tasks().addPendingTaskToUpdateInputPartitions(taskId00, 
taskId02Partitions);
+        taskManager.tasks().addPendingTaskToUpdateInputPartitions(taskId01, 
taskId03Partitions);
+
+        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter 
-> { });
+
+        Mockito.verify(task00).updateInputPartitions(taskId02Partitions, 
emptyMap());
+        Mockito.verify(stateUpdater).add(task00);
+        Mockito.verify(task01).updateInputPartitions(taskId03Partitions, 
emptyMap());
+        Mockito.verify(stateUpdater).add(task01);
+    }
+
+    @Test
+    public void shouldHandleRemovedTasksFromStateUpdater() {
+        // tasks to recycle
         final StreamTask task00 = mock(StreamTask.class);
         final StandbyTask task01 = mock(StandbyTask.class);
+        final StandbyTask task00Converted = mock(StandbyTask.class);
+        final StreamTask task01Converted = mock(StreamTask.class);
+        // task to close
+        final StreamTask task02 = mock(StreamTask.class);
+        // task to update inputs
+        final StreamTask task03 = mock(StreamTask.class);
         when(task00.id()).thenReturn(taskId00);
         when(task01.id()).thenReturn(taskId01);
+        when(task02.id()).thenReturn(taskId02);
+        when(task03.id()).thenReturn(taskId03);
         when(task00.inputPartitions()).thenReturn(taskId00Partitions);
         when(task01.inputPartitions()).thenReturn(taskId01Partitions);
+        when(task02.inputPartitions()).thenReturn(taskId02Partitions);
+        when(task03.inputPartitions()).thenReturn(taskId03Partitions);
         when(task00.isActive()).thenReturn(true);
         when(task01.isActive()).thenReturn(false);
+        when(task02.isActive()).thenReturn(true);
+        when(task03.isActive()).thenReturn(true);
         when(task00.state()).thenReturn(State.RESTORING);
         when(task01.state()).thenReturn(State.RUNNING);
-        
expect(changeLogReader.completedChangelogs()).andReturn(emptySet()).anyTimes();
+        when(task02.state()).thenReturn(State.RESTORING);
+        when(task03.state()).thenReturn(State.RESTORING);
+        when(stateUpdater.drainRemovedTasks()).thenReturn(mkSet(task00, 
task01, task02, task03));
+
+        expect(activeTaskCreator.createActiveTaskFromStandby(eq(task01), 
eq(taskId01Partitions), eq(consumer)))
+            .andStubReturn(task01Converted);
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(anyObject());
+        expectLastCall().times(2);
+        expect(standbyTaskCreator.createStandbyTaskFromActive(eq(task00), 
eq(taskId00Partitions)))
+            .andStubReturn(task00Converted);
         expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
         consumer.resume(anyObject());
         expectLastCall().anyTimes();
-        expect(activeTaskCreator.createTasks(anyObject(), 
eq(taskId00Assignment))).andStubReturn(singletonList(task00));
-        
expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andStubReturn(singletonList(task01));
-        replay(activeTaskCreator, standbyTaskCreator, consumer, 
changeLogReader);
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, 
consumer);
 
         taskManager = new TaskManager(
             time,
@@ -288,12 +436,20 @@ public class TaskManagerTest {
             stateUpdater
         );
         taskManager.setMainConsumer(consumer);
-        taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
+        taskManager.tasks().addPendingTaskToClose(taskId02);
+        taskManager.tasks().addPendingTaskToRecycle(taskId00, 
taskId00Partitions);
+        taskManager.tasks().addPendingTaskToRecycle(taskId01, 
taskId01Partitions);
+        taskManager.tasks().addPendingTaskToUpdateInputPartitions(taskId03, 
taskId03Partitions);
 
         taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter 
-> { });
 
-        Mockito.verify(stateUpdater).add(task00);
-        Mockito.verify(stateUpdater).add(task01);
+        Mockito.verify(task00Converted).initializeIfNeeded();
+        Mockito.verify(task01Converted).initializeIfNeeded();
+        Mockito.verify(stateUpdater).add(task00Converted);
+        Mockito.verify(stateUpdater).add(task01Converted);
+        Mockito.verify(task02).closeClean();
+        Mockito.verify(task03).updateInputPartitions(taskId03Partitions, 
emptyMap());
+        Mockito.verify(stateUpdater).add(task03);
     }
 
     @Test
@@ -1825,9 +1981,7 @@ public class TaskManagerTest {
             RuntimeException.class,
             () -> taskManager.shutdown(true)
         );
-        assertThat(exception.getMessage(), equalTo("Unexpected exception while 
closing task"));
-        assertThat(exception.getCause().getMessage(), is("migrated; it means 
all tasks belonging to this thread should be migrated."));
-        assertThat(exception.getCause().getCause().getMessage(), is("cause"));
+        assertThat(exception.getCause().getMessage(), is("oops"));
 
         assertThat(closedDirtyTask01.get(), is(true));
         assertThat(closedDirtyTask02.get(), is(true));
@@ -1887,7 +2041,6 @@ public class TaskManagerTest {
         final RuntimeException exception = 
assertThrows(RuntimeException.class, () -> taskManager.shutdown(true));
 
         assertThat(task00.state(), is(Task.State.CLOSED));
-        assertThat(exception.getMessage(), is("Unexpected exception while 
closing task"));
         assertThat(exception.getCause().getMessage(), is("whatever"));
         assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
@@ -1938,8 +2091,7 @@ public class TaskManagerTest {
         final RuntimeException exception = 
assertThrows(RuntimeException.class, () -> taskManager.shutdown(true));
 
         assertThat(task00.state(), is(Task.State.CLOSED));
-        assertThat(exception.getMessage(), is("Unexpected exception while 
closing task"));
-        assertThat(exception.getCause().getMessage(), is("whatever"));
+        assertThat(exception.getMessage(), is("whatever"));
         assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
         // the active task creator should also get closed (so that it closes 
the thread producer if applicable)
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
index 756aa53f86..6265fd4ca2 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
@@ -22,7 +22,10 @@ import org.apache.kafka.streams.processor.TaskId;
 import org.junit.jupiter.api.Test;
 
 import java.util.Collections;
+import java.util.HashSet;
 
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.standbyTask;
 import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statefulTask;
@@ -34,6 +37,8 @@ public class TasksTest {
 
     private final static TopicPartition TOPIC_PARTITION_A_0 = new 
TopicPartition("topicA", 0);
     private final static TopicPartition TOPIC_PARTITION_A_1 = new 
TopicPartition("topicA", 1);
+    private final static TopicPartition TOPIC_PARTITION_B_0 = new 
TopicPartition("topicB", 0);
+    private final static TopicPartition TOPIC_PARTITION_B_1 = new 
TopicPartition("topicB", 1);
     private final static TaskId TASK_0_0 = new TaskId(0, 0);
     private final static TaskId TASK_0_1 = new TaskId(0, 1);
     private final static TaskId TASK_1_0 = new TaskId(1, 0);
@@ -41,7 +46,7 @@ public class TasksTest {
     private final LogContext logContext = new LogContext();
 
     @Test
-    public void shouldCreateTasks() {
+    public void shouldKeepAddedTasks() {
         final Tasks tasks = new Tasks(logContext);
         final StreamTask statefulTask = statefulTask(TASK_0_0, 
mkSet(TOPIC_PARTITION_A_0)).build();
         final StandbyTask standbyTask = standbyTask(TASK_0_1, 
mkSet(TOPIC_PARTITION_A_1)).build();
@@ -51,15 +56,56 @@ public class TasksTest {
         tasks.addNewStandbyTasks(Collections.singletonList(standbyTask));
 
         assertEquals(statefulTask, tasks.task(statefulTask.id()));
-        assertTrue(tasks.activeTasks().contains(statefulTask));
-        assertTrue(tasks.allTasks().contains(statefulTask));
-        
assertTrue(tasks.tasks(mkSet(statefulTask.id())).contains(statefulTask));
         assertEquals(statelessTask, tasks.task(statelessTask.id()));
-        assertTrue(tasks.activeTasks().contains(statelessTask));
-        assertTrue(tasks.allTasks().contains(statelessTask));
-        
assertTrue(tasks.tasks(mkSet(statelessTask.id())).contains(statelessTask));
         assertEquals(standbyTask, tasks.task(standbyTask.id()));
-        assertTrue(tasks.allTasks().contains(standbyTask));
-        assertTrue(tasks.tasks(mkSet(standbyTask.id())).contains(standbyTask));
+
+        assertEquals(mkSet(statefulTask, statelessTask), new 
HashSet<>(tasks.activeTasks()));
+        assertEquals(mkSet(statefulTask, statelessTask, standbyTask), 
tasks.allTasks());
+        assertEquals(mkSet(statefulTask, standbyTask), 
tasks.tasks(mkSet(statefulTask.id(), standbyTask.id())));
+        assertEquals(mkSet(statefulTask.id(), statelessTask.id(), 
standbyTask.id()), tasks.allTaskIds());
+        assertEquals(
+            mkMap(
+                mkEntry(statefulTask.id(), statefulTask),
+                mkEntry(statelessTask.id(), statelessTask),
+                mkEntry(standbyTask.id(), standbyTask)
+            ),
+            tasks.allTasksPerId());
+        assertTrue(tasks.owned(statefulTask.id()));
+        assertTrue(tasks.owned(statelessTask.id()));
+        assertTrue(tasks.owned(standbyTask.id()));
+    }
+
+    @Test
+    public void shouldDrainPendingTasksToCreate() {
+        final Tasks tasks = new Tasks(logContext);
+
+        tasks.addPendingActiveTasksToCreate(mkMap(
+            mkEntry(new TaskId(0, 0, "A"), mkSet(TOPIC_PARTITION_A_0)),
+            mkEntry(new TaskId(0, 1, "A"), mkSet(TOPIC_PARTITION_A_1)),
+            mkEntry(new TaskId(0, 0, "B"), mkSet(TOPIC_PARTITION_B_0)),
+            mkEntry(new TaskId(0, 1, "B"), mkSet(TOPIC_PARTITION_B_1))
+        ));
+
+        tasks.addPendingStandbyTasksToCreate(mkMap(
+            mkEntry(new TaskId(0, 0, "A"), mkSet(TOPIC_PARTITION_A_0)),
+            mkEntry(new TaskId(0, 1, "A"), mkSet(TOPIC_PARTITION_A_1)),
+            mkEntry(new TaskId(0, 0, "B"), mkSet(TOPIC_PARTITION_B_0)),
+            mkEntry(new TaskId(0, 1, "B"), mkSet(TOPIC_PARTITION_B_1))
+        ));
+
+        assertEquals(mkMap(
+            mkEntry(new TaskId(0, 0, "A"), mkSet(TOPIC_PARTITION_A_0)),
+            mkEntry(new TaskId(0, 1, "A"), mkSet(TOPIC_PARTITION_A_1))
+        ), tasks.drainPendingActiveTasksForTopologies(mkSet("A")));
+
+        assertEquals(mkMap(
+            mkEntry(new TaskId(0, 0, "A"), mkSet(TOPIC_PARTITION_A_0)),
+            mkEntry(new TaskId(0, 1, "A"), mkSet(TOPIC_PARTITION_A_1))
+        ), tasks.drainPendingStandbyTasksForTopologies(mkSet("A")));
+
+        tasks.clearPendingTasksToCreate();
+
+        assertEquals(Collections.emptyMap(), 
tasks.drainPendingActiveTasksForTopologies(mkSet("B")));
+        assertEquals(Collections.emptyMap(), 
tasks.drainPendingStandbyTasksForTopologies(mkSet("B")));
     }
 }
\ No newline at end of file

Reply via email to